Comments (11)

jurahul commented on June 30, 2024

I suspect the AllReduceCombiner enable/disable is a red herring. It's possible that disabling it changes other things (likely buffer assignment) in a way that hides the issue.

wenscarl commented on June 30, 2024

Update: comparing the HLO dumped by good runs vs. bad runs, one pass, AllReduceCombiner, stands out: the good run includes this pass while the bad run doesn't. Disabling it (recompilation required) triggers the IMA for all matrix sizes.
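
As a side note for anyone reproducing this: instead of recompiling, a single HLO pass can usually be disabled via the xla_disable_hlo_passes debug option. The registered pass name used below is an assumption and should be checked against AllReduceCombiner's name():

# Assumption: the combiner registers itself as "all-reduce-combiner"; verify before relying on it.
XLA_FLAGS=--xla_disable_hlo_passes=all-reduce-combiner <usual repro command>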

reedwm commented on June 30, 2024

I think I found the problem. @wenscarl, you were right that the issue is with XLA and not cublas.

In BlasLt::DoMatmul, right before running the matmul, we set the scales on the cublasLtMatmulDesc_t here. But the cublasLtMatmulDesc_t is shared across devices, so if two devices try to modify it at the same time, one device will see the other device's scales.

We should either copy the cublasLtMatmulDesc_t or have a unique cublasLtMatmulDesc_t per device. I'll create a fix.
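
To make the failure mode concrete, here is a minimal sketch of the race; this is not XLA's actual BlasLt code, and the function and argument names are invented for illustration:

#include <cublasLt.h>

// Sketch only: one cublasLtMatmulDesc_t shared by every device.
// If two device threads run this concurrently, the SetAttribute calls can
// interleave so that one device's matmul launches with the other device's
// scale pointers.
void set_scales_then_matmul(cublasLtMatmulDesc_t shared_desc,
                            const void* a_scale, const void* b_scale,
                            const void* d_scale) {
  // Each call mutates the shared descriptor in place.
  cublasLtMatmulDescSetAttribute(shared_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
                                 &a_scale, sizeof(a_scale));
  cublasLtMatmulDescSetAttribute(shared_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                 &b_scale, sizeof(b_scale));
  cublasLtMatmulDescSetAttribute(shared_desc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
                                 &d_scale, sizeof(d_scale));
  // <-- if the other device's thread executes its SetAttribute calls here,
  //     the matmul launched below reads that device's scale pointers instead.
  // cublasLtMatmul(handle, shared_desc, ...);
}

Either copying the descriptor before mutating it or keeping one descriptor per device, as proposed above, removes the shared mutable state.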

cheshire commented on June 30, 2024

@reedwm

kaixih commented on June 30, 2024

Also, I think it is worth sharing that we also found the IMA issue was gone when --d is small (e.g. 512, 1024, or 2048), where the pointers of these scaling factors look correct at the cublasLt call sites. Can you confirm, @wenscarl?

wenscarl commented on June 30, 2024

> Also, I think it is worth sharing that we also found the IMA issue was gone when --d is small (e.g. 512, 1024, or 2048), where the pointers of these scaling factors look correct at the cublasLt call sites. Can you confirm, @wenscarl?

That's correct. This error is not deterministic. When reproducing it, please try different --d sizes; larger matrix sizes generally make the error easier to trigger.

pjannaty commented on June 30, 2024

> where the cScalePointer and dScalePointer are clearly wrong

Have we tracked down where the pointer mix-up occurs? My understanding is that we believe the rewritten HLO looks good, so does that mean the mix-up must be happening in the stream executor?

reedwm commented on June 30, 2024

To reproduce the error on my machine, I had to use a size of 8192 instead of 4096. To make it easier to debug, I created a small HLO program to reproduce:

HloModule m

ENTRY f {
  x = f32[8192,8192]{1,0} parameter(0), sharding={replicated}
  x_converted = f8e4m3fn[8192,8192]{1,0} convert(x)
  y = f32[8192,8192]{1,0} parameter(1), sharding={replicated}
  y_converted = f8e4m3fn[8192,8192]{1,0} convert(y)
  output = f8e4m3fn[8192,8192]{1,0} dot(y_converted, x_converted), lhs_contracting_dims={1}, rhs_contracting_dims={1}
  ROOT output_converted = f32[8192,8192]{1,0} convert(output)
}

Running the following crashes after a nondeterministic number of iterations (typically around 30):

bazel build //xla/tools/multihost_hlo_runner:hlo_runner_main && TF_CPP_VMODULE=functional_hlo_runner=1 CUDA_LAUNCH_BLOCKING=1 bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning=true --num_partitions=2 --num_replicas=1 --hlo_file=/home/reedwm/v/fp8_hlo/fp8_crash.hlo --num_repeats=100

Running with compute-sanitizer causes it to not crash.

To see if this is potentially a cublas issue, I replaced the cublas matmul call with a fake matmul kernel which simply reads its inputs and writes arbitrary outputs (reedwm@18a2a9a). This caused the crash to disappear (of course the matmul outputs were bogus), which indicates this is likely a cublas issue. This means presumably the scaling factors are on the correct devices.
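
The real change is the one linked above (reedwm@18a2a9a); purely to illustrate the idea, a stand-in kernel of that flavor might look like the sketch below, with an invented name and signature, treating the fp8 buffers as raw bytes:

#include <stddef.h>
#include <stdint.h>

// Illustrative stand-in for the cuBLASLt matmul: touch the same input and
// scale buffers the real matmul would, but write an arbitrary output so
// cuBLASLt itself is taken out of the picture.
__global__ void fake_fp8_matmul(const uint8_t* a, const uint8_t* b,
                                const float* a_scale, const float* b_scale,
                                uint8_t* d, size_t a_bytes, size_t b_bytes,
                                size_t d_bytes) {
  size_t i = blockIdx.x * (size_t)blockDim.x + threadIdx.x;
  uint8_t acc = (uint8_t)(*a_scale + *b_scale);  // read the scale pointers
  if (i < a_bytes) acc ^= a[i];                  // read the inputs
  if (i < b_bytes) acc ^= b[i];
  if (i < d_bytes) d[i] = acc;                   // bogus output
}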

@wenscarl, how do you know what devices the scaling factors are on? I'm not sure how to directly determine this, but my fake matmul shows that when I run with the small HLO above, the scales are likely on the correct device.

kaixih commented on June 30, 2024

Good find, thanks @reedwm. Today we also found that these SetAttr calls didn't correctly set the pointers in the desc (Shu inserted a GetAttr right after the SetAttr but saw wrong pointers in the desc). Yes, looking forward to your fix (it seems we should remove the mutex scope and let each device have its own desc, right?).

wenscarl commented on June 30, 2024

For the bad run on 2 GPUs, comparing the device pointers before cublasLtMatmulDescSetAttribute and after cublasLtMatmulDescGetAttribute:
SetAttr host pointer: 0x7f0967ff8240
SetAttr device pointer: 0x7f2300455f00
SetAttr host pointer: 0x7f08dbff9240
SetAttr device pointer: 0x7f2300255f00

GetAttr host pointer: 0x7f08dbff9280
GetAttr device pointer: 0x7f2300255f00
GetAttr host pointer: 0x7f0967ff8280
GetAttr device pointer: 0x7f2300255f00

The two GetAttr device pointers shouldn't be the same, yet both read back as 0x7f2300255f00.
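
A dump like the one above can be produced with a set-then-read-back check on the descriptor. A minimal sketch, using the D-scale attribute as an example, with simplified logging and no error handling:

#include <cstdio>
#include <cublasLt.h>

// Sketch: write a scale pointer into the descriptor, then immediately read it
// back. On a healthy run the read-back equals what was just written; in the
// bad run above, another device's SetAttr has already overwritten it.
void check_d_scale_roundtrip(cublasLtMatmulDesc_t desc, const void* d_scale) {
  std::printf("SetAttr device pointer: %p\n", d_scale);
  cublasLtMatmulDescSetAttribute(desc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
                                 &d_scale, sizeof(d_scale));
  const void* readback = nullptr;
  size_t written = 0;
  cublasLtMatmulDescGetAttribute(desc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER,
                                 &readback, sizeof(readback), &written);
  std::printf("GetAttr device pointer: %p\n", readback);
}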

wenscarl commented on June 30, 2024

> @wenscarl, how do you know what devices the scaling factors are on? I'm not sure how to directly determine this, but my fake matmul shows that when I run with the small HLO above, the scales are likely on the correct device.

I used the CUDA API cudaPointerGetAttributes to query the device pointer.
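
For anyone who wants to repeat that check, a minimal sketch of such a query (the helper name is invented):

#include <cstdio>
#include <cuda_runtime.h>

// Sketch: ask the CUDA runtime which device a pointer belongs to.
void print_pointer_device(const void* ptr) {
  cudaPointerAttributes attrs;
  if (cudaPointerGetAttributes(&attrs, ptr) == cudaSuccess &&
      attrs.type == cudaMemoryTypeDevice) {
    std::printf("%p lives on device %d\n", ptr, attrs.device);
  } else {
    std::printf("%p is not a plain device pointer\n", ptr);
  }
}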
