Comments (20)
I think these newish attention replacements will take time to be adopted, particularly because the dust has not settled on them, and it takes a while for wide-scale experimentation and large-scale training to truly prove them out.
IMO all it takes is a leap for a highly-funded industrial lab to go out on a limb and train an LLM with one of these...
For instance, Mistral AI essentially has a linear-cost attention mechanism based on SWA (sliding window attention), though one could of course argue about how effective it is at truly capturing information across long context.
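(For context, a rough JAX sketch of the banded causal mask SWA implies; the shapes and the naive reference below are illustrative and not Mistral's actual implementation:)

```python
import jax
import jax.numpy as jnp

def sliding_window_mask(seq_len: int, window: int) -> jnp.ndarray:
    # Token i may attend to tokens j with i - window < j <= i (causal + banded),
    # so each token only looks at `window` positions instead of the whole prefix.
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    return (j <= i) & (j > i - window)

def swa_attention(q, k, v, window: int):
    # q, k, v: [seq_len, dim]. Naive reference only: it still materializes the
    # full score matrix, so it is O(n^2); the linear cost comes from kernels
    # that compute only the band.
    scores = (q @ k.T) / jnp.sqrt(q.shape[-1])
    scores = jnp.where(sliding_window_mask(q.shape[0], window), scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v
```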
all these frameworks cannot do.
I think this is an overstatement? It simply has not been tried in Triton yet, and it should not be that hard, though whether the performance matches is an open question.
I just hope that more devs become aware of how powerful triton is so that there's more experimentation with implementing these kernels.
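(For a flavor of what a Triton kernel looks like, here is the standard vector-add pattern from the Triton tutorials; nothing here is attention-specific:)

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one BLOCK_SIZE-wide tile of the vectors.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements          # guard the ragged last tile
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)       # one program per tile
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out
```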
from flash-attention-jax.
Wow, this has been open for almost a year...
I think someone could get a lot of citations / clicks if they did a proper benchmark of transformer train/inference across platforms (torch/jax, GPU/TPU) with the standard tricks. I would cite you straight away for sure. It would also be nice to settle the dispute that Google employees seem to have with everyone else over whether Jax is meaningfully more efficient (it might just come down to TPU vs GPU!?).
from flash-attention-jax.
@niemiaszek I just recently saw they named and added docs for pallas, which looks very interesting. JAX is also improving our ability to customize how networks are sharded across accelerators, and the team is publishing papers on their efficiency results, which is pretty cool I think. Unfortunately I don't have time to do a fair comparison between torch and jax with attention, but it seems that whoever takes the time to delve into it, especially jax's recent improvements, would certainly benefit if they have a need.
Even if we don't take the time, it looks like the jax team continually adds their efficiency findings into jax as defaults, so we don't have to implement them ourselves.
from flash-attention-jax.
@lucidrains I'd agree as far as single-device optimizations go. I solely use jax because my work deals mainly with RL and I've already built everything out, but for things like language and vision models, resources like xformers are hard to beat. I do like jax's work toward multi-device customization especially from an RL perspective.
from flash-attention-jax.
I am also curious. Additionally, maybe it is possible to use CUDA code with JAX?
https://github.com/dfm/extending-jax
from flash-attention-jax.
https://colab.research.google.com/drive/1-YCU9ps4gNuROJ3_8MLjSpbICGHaySxh?usp=sharing
from flash-attention-jax.
Fantastic! Have you run the same experiment with the same data on the original flash attention?
from flash-attention-jax.
Not yet
from flash-attention-jax.
Hello, could I ask if this works with TPUs?
from flash-attention-jax.
Here's an updated notebook that precompiles the jitted functions and blocks on results until they are ready, for anyone interested:
https://colab.research.google.com/drive/11QKRdgMtcivrJNmjTrf2bXTE5yXkXl_Z?usp=sharing
Looks like JAX compiles vanilla attention in a way that ends up faster than this jax flash attention, so there's no need to switch to flash attention if you use JAX.
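(For anyone reproducing this outside the notebook, the measurement pattern is roughly the one below: jit, warm up once so compilation is excluded, then block_until_ready before reading the clock. The shapes are arbitrary and the attention function is just a naive reference.)

```python
import time
import jax
import jax.numpy as jnp

def vanilla_attention(q, k, v):
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

attn = jax.jit(vanilla_attention)

keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(key, (8, 2048, 64)) for key in keys)

# Warm-up call triggers XLA compilation so it is excluded from the timing.
attn(q, k, v).block_until_ready()

start = time.perf_counter()
for _ in range(10):
    out = attn(q, k, v)
out.block_until_ready()  # dispatch is async; block before reading the clock
print(f"{(time.perf_counter() - start) / 10 * 1e3:.2f} ms per call")
```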
from flash-attention-jax.
It would definitely be nice to see such a benchmark, but I can imagine how hard it is to compare JAX vs PyTorch (GPU/TPU), with many optimized implementations for each device. For PyTorch on GPU we have Triton/CUDA, but JAX has also recently added a Triton-like mechanism for writing custom kernels for GPU/TPU - Pallas. You can even find an implementation of attention in it here.
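(For a sense of what Pallas code looks like, here is a minimal elementwise kernel in the spirit of the Pallas quickstart; this is not the attention implementation linked above:)

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_kernel(x_ref, o_ref):
    # Refs are read and written like arrays; on GPU this lowers through Triton.
    o_ref[...] = x_ref[...] * 2.0

@jax.jit
def scale(x):
    return pl.pallas_call(
        scale_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x)

print(scale(jnp.arange(8, dtype=jnp.float32)))
```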
from flash-attention-jax.
from what i've heard, flash attention doesn't work well on TPUs, but i haven't kept up with the latest iteration of their chip design.
Pallas is just a wrapper around Triton, which was developed at OpenAI for GPUs. you will basically always be limited by what the Triton compiler can do
from flash-attention-jax.
while this is a hard pill to swallow, i think the existence of flash attention is a clear victory for having finely controlled GPGPU programming.
from flash-attention-jax.
Well, I would argue that these days it's no longer such a hard pill, given the wide adoption of tiled programming paradigms like Triton (e.g. PyTorch - both codegen and incoming custom kernels; JAX - e.g. Pallas; hardware vendors including NVIDIA, AMD, Intel), which greatly reduces the effort and complexity of getting SOTA perf on GPUs.
from flash-attention-jax.
@jon-chuang hmm, still a bit early to declare that imho
we'll see, i hope so!
from flash-attention-jax.
Yes, Triton is still not 100% there (some matmul kernel sizes and certain kernels like the flash attention backward pass are still not SOTA). But it's certainly the direction the industry is investing in, and IMO it's good news for developers and tinkerers who want hackability at each layer of the stack.
I've already heard of some success stories with customizing flash attention kernels via Triton.
from flash-attention-jax.
@jon-chuang yea, let us just agree that we both wish for Triton and the like to succeed so us non-CUDA experts can have control over the entire stack
i just know it isn't there yet.
from flash-attention-jax.
Interestingly, a basic building block for Mamba (associative scan) already has support in Triton: pytorch/pytorch#95408 (comment)
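(The same primitive exists in JAX as jax.lax.associative_scan; below is a small sketch of the affine recurrence h[t] = a[t] * h[t-1] + b[t] that Mamba-style models build on, just for illustration:)

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    # Compose two affine maps h -> a*h + b; the composition is associative,
    # which is what lets the scan run in parallel.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

a = jnp.full((8,), 0.9)               # per-step decay
b = jnp.arange(8, dtype=jnp.float32)  # per-step input
_, h = jax.lax.associative_scan(combine, (a, b))
# h[t] == a[t] * h[t-1] + b[t], with the initial state taken as 0
```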
from flash-attention-jax.
it doesn't support multiple inputs. also i heard it is still buggy in its current state
from flash-attention-jax.
@jon-chuang anyways, let us take the discussion elsewhere, as this is about flash attention
from flash-attention-jax.
Related Issues (11)
- Question about calculation of Q and transpose(K).
- Multi-head causal flash attention support? HOT 8
- Slower than non-flash attention HOT 1
- Reshape error in causal_flash_attention when sequence length is not a multiple of 1024
- Online Softmax from FlashAttention HOT 2
- can I work on making a flax attention function out of this repository? HOT 1
- batch & multihead support? HOT 3
- more general mask support HOT 1
- support for per-head scales for cosine sim attention HOT 6
- fix compatibility with jax transformations HOT 28