Comments (20)

jon-chuang commented on May 30, 2024 2

I think these newish attention replacements will take time to be adopted, particularly because the dust has not settled on them and it takes a while for wide-scale experimentation and large-scale training to truly prove them out.

IMO all it takes is a leap for a highly-funded industrial lab to go out on a limb and train an LLM with one of these...

For instance, Mistral AI essentially has a linear-cost attention mechanism based on SWA (sliding window attention); one could of course argue how effective it really is at capturing information across a long context.
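
For illustration, here is a minimal JAX sketch of a sliding-window causal mask (my own toy example, not Mistral's implementation; the window size and shapes are made up, and the mask is applied to a full score matrix, so it only shows the semantics, not the linear-cost memory layout):

```python
import jax
import jax.numpy as jnp

def sliding_window_attention(q, k, v, window=4):
    # q, k, v: [seq_len, d]. Each query attends only to the `window`
    # most recent positions (itself included), causally.
    seq_len, d = q.shape
    scores = q @ k.T / jnp.sqrt(d)            # [seq_len, seq_len]
    i = jnp.arange(seq_len)[:, None]
    j = jnp.arange(seq_len)[None, :]
    mask = (j <= i) & (j > i - window)        # causal + sliding window
    scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v

q = k = v = jax.random.normal(jax.random.PRNGKey(0), (16, 8))
out = sliding_window_attention(q, k, v)       # [16, 8]
```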

> all these frameworks cannot do.

I think this is an overstatement? It simply has not been tried out in Triton yet, and it should not be that hard; whether the performance matches is an open question, though.

I just hope that more devs become aware of how powerful Triton is so that there's more experimentation with implementing these kernels.

SamuelGabriel commented on May 30, 2024 1

Wow, this has been open for almost a year...

I think someone could get a lot of citations / clicks if they did a proper benchmark of transformer training/inference across platforms (torch/jax, GPU/TPU) with the standard tricks. I would cite you straight away for sure. It would also be nice to finally settle the dispute that Google employees seem to have with everyone else over whether JAX is meaningfully more efficient (it might just come down to TPU vs GPU!?).

evanatyourservice commented on May 30, 2024 1

@niemiaszek I just recently saw they named and added docs for Pallas; it looks very interesting. The JAX team is also improving our ability to customize how networks are sharded across accelerators, and they are publishing papers on their efficiency results, which is pretty cool I think. Unfortunately I don't have time to do a fair comparison between torch and jax with attention, but whoever takes the time to dig into it, especially JAX's recent improvements, would certainly benefit if they have the need.

Even if we don't take the time, it looks like the JAX team continually folds their efficiency findings into JAX as defaults, so we don't have to implement them ourselves.
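
As a rough illustration of the sharding customization mentioned above (a minimal sketch using jax.sharding with a 1-D mesh and made-up shapes, not a full training setup):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Lay out whatever devices are available as a 1-D mesh.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Shard a batch along its leading axis across the mesh
# (the leading dimension should be divisible by the device count).
x = jnp.ones((8, 1024))
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))

# jit-compiled computations consume the sharding and run distributed.
y = jax.jit(lambda a: a @ a.T)(x)
print(y.sharding)
```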

evanatyourservice commented on May 30, 2024 1

@lucidrains I'd agree as far as single-device optimizations go. I solely use jax because my work deals mainly with RL and I've already built everything out, but for things like language and vision models, resources like xformers are hard to beat. I do like jax's work toward multi-device customization, especially from an RL perspective.

jakubMitura14 commented on May 30, 2024

I am also curious. Additionally, maybe it is possible to use CUDA code with JAX?

https://github.com/dfm/extending-jax

OhadRubin commented on May 30, 2024

https://colab.research.google.com/drive/1-YCU9ps4gNuROJ3_8MLjSpbICGHaySxh?usp=sharing

jakubMitura14 commented on May 30, 2024

Fantastic! Have you run the same experiment with the same data against the original flash attention?

OhadRubin commented on May 30, 2024

Not yet

jon-chuang commented on May 30, 2024

Hello, could I ask if this works with TPUs?

evanatyourservice commented on May 30, 2024

Here's an updated notebook that pre-compiles the jitted functions and blocks on the results until they are ready, for anyone interested:

https://colab.research.google.com/drive/11QKRdgMtcivrJNmjTrf2bXTE5yXkXl_Z?usp=sharing

It looks like JAX compiles vanilla attention in a way that is faster than jax flash attention, so there is no need to switch to flash attention if you use JAX.
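
(For anyone reproducing this kind of comparison, here is a minimal sketch of the warm-up-then-block timing pattern the notebook refers to; the attention function and shapes are made up for illustration.)

```python
import time
import jax
import jax.numpy as jnp

def vanilla_attention(q, k, v):
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

attn = jax.jit(vanilla_attention)
q = k = v = jax.random.normal(jax.random.PRNGKey(0), (8, 1024, 64))

# Warm-up call: triggers XLA compilation, which should not be timed.
attn(q, k, v).block_until_ready()

start = time.perf_counter()
for _ in range(10):
    out = attn(q, k, v)
out.block_until_ready()  # dispatch is async, so block before reading the clock
print(f"{(time.perf_counter() - start) / 10 * 1e3:.2f} ms per call")
```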

niemiaszek commented on May 30, 2024

> Wow, this has been open for almost a year...
>
> I think someone could get a lot of citations / clicks if they did a proper benchmark of transformer training/inference across platforms (torch/jax, GPU/TPU) with the standard tricks. I would cite you straight away for sure. It would also be nice to finally settle the dispute that Google employees seem to have with everyone else over whether JAX is meaningfully more efficient (it might just come down to TPU vs GPU!?).

It would definitely be nice to see such a benchmark, but I can imagine how hard comparing JAX vs PyTorch (GPU/TPU) is, with many optimized implementations for each device. For PyTorch on GPU we have Triton/CUDA, but JAX has also recently added a Triton-like mechanism for writing custom kernels for GPU/TPU - Pallas. You can even find an implementation of attention in it here.
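
(To give a flavour of what Pallas code looks like, here is a minimal element-wise kernel in the style of the public Pallas quickstart; it is not the attention implementation linked above, and the shapes are made up.)

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_kernel(x_ref, y_ref, o_ref):
    # Refs point at blocks staged into fast on-chip memory:
    # read the inputs, compute, and write the result back.
    o_ref[...] = x_ref[...] + y_ref[...]

@jax.jit
def add(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    )(x, y)

x = jnp.arange(8.0)
print(add(x, x))  # [ 0.  2.  4.  6.  8. 10. 12. 14.]
```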

lucidrains commented on May 30, 2024

from what i've heard, flash attention doesn't work well on TPUs, but i haven't kept up with the latest iteration of their chip design.

Pallas is just a wrapper around Triton, which was developed at OpenAI for GPUs. you will basically always be limited by what the Triton compiler can do

lucidrains commented on May 30, 2024

while this is a hard pill to swallow, i think the existence of flash attention is a clear victory for having finely controlled GPGPU programming.

jon-chuang commented on May 30, 2024

> while this is a hard pill to swallow, i think the existence of flash attention is a clear victory for having finely controlled GPGPU programming.

Well, I would argue that these days it's no longer such a hard pill, given the wide adoption of tiled programming paradigms like Triton (e.g. PyTorch, with both codegen and incoming custom kernels; JAX, e.g. Pallas; and hardware vendors including NVIDIA, AMD, and Intel), which greatly reduce the effort and complexity of getting SOTA perf on GPUs.
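
(For concreteness, "tiled programming" here means writing a kernel at the level of blocks of elements rather than individual threads. A minimal sketch following the standard Triton vector-add tutorial pattern, nothing to do with flash attention itself:)

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance owns one tile of BLOCK_SIZE contiguous elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements                 # guard the ragged last tile
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)              # one program per tile
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out
```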

lucidrains commented on May 30, 2024

@jon-chuang hmm, still a bit early to declare that imho

we'll see, i hope so!

jon-chuang commented on May 30, 2024

Yes, Triton is still not 100% there (some matmul kernel sizes and certain kernels, like the flash attention backward pass, are still not SOTA). But it's certainly the direction the industry is investing in, and IMO that's good news for developers and tinkerers who want hackability at every layer of the stack.

I've already heard of some success stories with customizing flash attention kernels via Triton.

lucidrains commented on May 30, 2024

@jon-chuang yea, let us just agree that we both wish for Triton and the like to succeed so us non-CUDA experts can have control over the entire stack

i just know it isn't there yet.

jon-chuang commented on May 30, 2024

Interestingly, a basic building block for Mamba (associative scan) already has support in Triton: pytorch/pytorch#95408 (comment)
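
(To make the building block concrete: the associative scan lets a first-order linear recurrence, the core of Mamba-style SSMs, be evaluated as a parallel prefix scan. Here is a minimal JAX sketch of the math using jax.lax.associative_scan; this only illustrates the primitive, not the Triton implementation linked above.)

```python
import jax
import jax.numpy as jnp

# h_t = a_t * h_{t-1} + b_t, rewritten with an associative combine so the
# whole sequence can be computed as a parallel (prefix) scan.
def combine(left, right):
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (1024,))
b = jax.random.normal(key_b, (1024,))

_, h = jax.lax.associative_scan(combine, (a, b))  # h[t] is the state at step t (with h_{-1} = 0)
```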

lucidrains commented on May 30, 2024

the Triton associative scan doesn't support multiple inputs. also i heard it is still buggy in its current state

lucidrains commented on May 30, 2024

@jon-chuang anyways, let us take the discussion elsewhere, as this is about flash attention
