Comments (5)
Thanks for confirming @davisyoshida! I will close this bug as fixed.
Thanks for reporting this, Davis. I sent a PR fixing this. In the meantime you can do
from jaxlib.triton import dialect
# Temporary monkeypatch: make `dialect.permute` an alias of `dialect.trans`
dialect.permute = dialect.trans
to work around the issue.
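The workaround above is a plain attribute alias. A slightly more defensive variant of the same monkeypatch, sketched here, adds the alias only when the module does not already define it, so it becomes a no-op once a fixed jaxlib is installed. The helper name `alias_if_missing` is hypothetical, and a `SimpleNamespace` stands in for `jaxlib.triton.dialect` so the sketch runs without a GPU build of jaxlib.

```python
from types import SimpleNamespace


def alias_if_missing(module, new_name, existing_name):
    """Hypothetical helper: bind module.<new_name> to module.<existing_name>
    only if <new_name> is not already defined. Returns True if it patched."""
    if hasattr(module, new_name):
        return False  # already present (e.g. a fixed jaxlib): leave it alone
    setattr(module, new_name, getattr(module, existing_name))
    return True


# Stand-in for jaxlib.triton.dialect that only provides `trans`.
def trans(x, *args, **kwargs):
    return ("trans", x)


fake_dialect = SimpleNamespace(trans=trans)
patched = alias_if_missing(fake_dialect, "permute", "trans")
print(patched)                                     # True
print(fake_dialect.permute is fake_dialect.trans)  # True
print(alias_if_missing(fake_dialect, "permute", "trans"))  # False: already aliased
```

Against the real module, the equivalent call would be `alias_if_missing(dialect, "permute", "trans")` after `from jaxlib.triton import dialect`.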
One caveat, though: the Triton version we have internally crashes if any of the dot operands is transposed. It is possible that the jaxlib version did not pick up the problematic upstream changes, but if it did, we'll have to wait until the issue is fixed in Triton. I started a discussion in their Slack.
> the Triton version we have internally crashes if any of the dot operands is transposed
Good to know, I'm sure that just saved me an hour of fruitless debugging. Thanks!
According to the folks on the Triton Slack channel, dot with transposed inputs is not supported on GPUs older than Turing (compute capability 7.5). I was testing on a V100, which is Volta (compute capability 7.0), so this seems to explain the crashes.
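Since the failure depends on the GPU generation, a guard along these lines could choose a code path before launching a kernel. This is a sketch under stated assumptions: `supports_transposed_dot` and `parse_compute_capability` are hypothetical helpers, the compute capability is assumed to arrive as a string such as "7.0", and the only rule encoded is the one relayed from the Triton Slack, namely that Turing (7.5) is the minimum.

```python
def parse_compute_capability(cc: str) -> tuple:
    """Parse a compute-capability string such as "7.0" into (major, minor)."""
    major, minor = cc.split(".")
    return int(major), int(minor)


# Turing GPUs are compute capability 7.5; Volta (e.g. V100) is 7.0.
TURING = (7, 5)


def supports_transposed_dot(cc: str) -> bool:
    """Hypothetical guard: True if the GPU is Turing or newer, per the
    Triton Slack guidance that transposed dot inputs need Turing+."""
    return parse_compute_capability(cc) >= TURING


for name, cc in [("V100", "7.0"), ("T4", "7.5"), ("A100", "8.0")]:
    print(name, supports_transposed_dot(cc))  # V100 False, T4 True, A100 True
```

Tuple comparison handles the major/minor ordering, so a future "9.0" device also passes the check.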
@superbobry The monkeypatch works for me. Thanks for making it so much easier to get Pallas up and running on GPU!