lucidrains / flash-attention-jax
Implementation of Flash Attention in Jax
License: MIT License
how hard would it be to add support for leading dimensions (e.g. for batching & multiple heads)? In my experience vmap is often less performant than batching by hand.
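For what it's worth, a minimal sketch of the two options being compared, using a toy single-head attend() rather than this repo's kernel; names and shapes here are purely illustrative:

import jax
import jax.numpy as jnp

def attend(q, k, v):
    # plain single-head attention over (len, dim) inputs
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v

# option 1: vmap twice to add leading (batch, heads) dimensions
batched_attend = jax.vmap(jax.vmap(attend))

# option 2: "batching by hand" -- write the leading dims into the einsum directly
def attend_by_hand(q, k, v):
    # q, k, v: (batch, heads, len, dim)
    scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(q.shape[-1])
    return jnp.einsum('bhqk,bhkd->bhqd', jax.nn.softmax(scores, axis=-1), v)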
Are there any benchmark results yet? Looking forward to performance comparisons with the original attention and the official torch+CUDA implementation.
Thought you might find this interesting. I did some benchmarking of the online softmax algorithm used in the flash attention paper: https://github.com/jenkspt/online-softmax-jax.
TL;DR: it is not reliably faster than the naive softmax.
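For context, the online softmax in question is the running-max / running-sum recurrence from the paper; a minimal one-dimensional sketch (not the linked repo's actual code) looks roughly like this:

import jax.numpy as jnp

def online_softmax(x, chunk_size=128):
    # stream over x in chunks, keeping a running max m and running normalizer l
    m = -jnp.inf
    l = 0.0
    for i in range(0, x.shape[0], chunk_size):
        c = x[i:i + chunk_size]
        m_new = jnp.maximum(m, c.max())
        l = l * jnp.exp(m - m_new) + jnp.exp(c - m_new).sum()
        m = m_new
    # second pass: normalize with the final statistics
    return jnp.exp(x - m) / l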
First off, thanks for writing this. It's been a substantial improvement, even if hand-written CUDA kernels would've been better.
I've discovered a bug with odd sequence lengths. For example, with a sequence length of 1025 you get TypeError: reshape total size must be unchanged, got new_sizes (1025, 256, 64) for shape (2, 1024, 256, 64), with a traceback pointing to causal_flash_attention.py:96, which is this line: out = out.reshape(q_len, bh, v_dim). AFAICT the problem occurs whenever the sequence length is greater than 1024 and not a multiple of 1024.
Repro:
import jax.numpy as jnp
from flash_attention_jax import causal_flash_attention
q = k = v = jnp.ones((1, 1, 1025, 16), dtype=jnp.float32)
_ = causal_flash_attention(q, k, v)
That fails; changing 1025 to 1024 works fine.
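Until the reshape is fixed, one possible workaround is to pad the sequence axis up to a multiple of 1024 and slice the padding back off the output. This is only a sketch around the public API, assuming the padded key positions are harmless because the causal mask already excludes keys that come after every real query:

import jax.numpy as jnp
from flash_attention_jax import causal_flash_attention

def padded_causal_flash_attention(q, k, v, multiple=1024):
    # pad the sequence axis to the next multiple, run the kernel, then drop the padded rows
    seq_len = q.shape[-2]
    pad = (-seq_len) % multiple
    pad_width = [(0, 0)] * (q.ndim - 2) + [(0, pad), (0, 0)]
    q_p, k_p, v_p = (jnp.pad(x, pad_width) for x in (q, k, v))
    out = causal_flash_attention(q_p, k_p, v_p)
    return out[..., :seq_len, :]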
I tried to compare this implementation with a no-bells-and-whistles implementation:

import time

import jax
import jax.numpy as jnp
import jax.random
from flash_attention_jax import flash_attention

@jax.jit
def jax_attention(q, k, v):
    n_seq = q.shape[-2]
    logits = jnp.matmul(q, k)
    mask = jnp.tril(jnp.ones((1, 1, n_seq, n_seq), dtype=q.dtype))
    mask = jnp.broadcast_to(mask, logits.shape)
    logits = jnp.where(mask, logits, float('-inf'))
    ref_qk = jax.nn.softmax(logits)
    return jnp.matmul(ref_qk, v)

BATCH, N_HEADS, N_CTX, D_HEAD = 8, 64, 2000, 64

def bench_jax_flash(batch, heads, seq_len, d_model):
    shape = (batch, heads, seq_len, d_model)
    q_jax = jnp.ones(shape, dtype=jnp.float16)
    k_jax = jnp.ones(shape, dtype=jnp.float16)
    v_jax = jnp.ones(shape, dtype=jnp.float16)
    mask = jnp.ones((batch, seq_len), dtype=jnp.int_)
    # warmup
    print('Warming up...')
    flash_attention(q_jax, k_jax, v_jax, mask).block_until_ready()
    flash_attention(q_jax, k_jax, v_jax, mask).block_until_ready()
    print('Benchmarking...')
    t1 = time.time()
    num_runs = 100
    for _ in range(num_runs):
        flash_attention(q_jax, k_jax, v_jax, mask).block_until_ready()
    estimate_ms = 1000 * (time.time() - t1) / num_runs
    return estimate_ms

print('Flash Jax implementation:')
print(bench_jax_flash(batch=BATCH, heads=N_HEADS, seq_len=N_CTX, d_model=D_HEAD))

def bench_jax(batch, heads, seq_len, d_model):
    q_jax = jnp.ones((batch, heads, seq_len, d_model), dtype=jnp.float16)
    # k is created pre-transposed so jax_attention can use a plain matmul
    k_jax = jnp.ones((batch, heads, d_model, seq_len), dtype=jnp.float16)
    v_jax = jnp.ones((batch, heads, seq_len, d_model), dtype=jnp.float16)
    # warmup
    print('Warming up...')
    jax_attention(q_jax, k_jax, v_jax).block_until_ready()
    jax_attention(q_jax, k_jax, v_jax).block_until_ready()
    print('Benchmarking...')
    t1 = time.time()
    num_runs = 100
    for _ in range(num_runs):
        jax_attention(q_jax, k_jax, v_jax).block_until_ready()
    estimate_ms = 1000 * (time.time() - t1) / num_runs
    return estimate_ms

print('Jax implementation:')
print(bench_jax(batch=BATCH, heads=N_HEADS, seq_len=N_CTX, d_model=D_HEAD))
Output:
Flash Jax implementation:
Warming up...
Benchmarking...
51.73063278198242
Jax implementation:
Warming up...
Benchmarking...
32.94001817703247
usually with cosine-sim models I'd train with learned per-head scales for the attention logits. I guess I can get this by multiplying q & k by sqrt(scale) before the dot product, but that's probably less stable
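For illustration, a minimal sketch of folding the scale into q and k, assuming (batch, heads, len, dim) tensors and a learned per-head scale; whether this is numerically worse than scaling the logits directly is exactly the open question above:

import jax.numpy as jnp

def fold_scale_into_qk(q, k, scale):
    # scale: learned per-head logit scale, shape (heads,)
    # multiplying both q and k by sqrt(scale) makes (q' @ k'.T) == scale * (q @ k.T)
    s = jnp.sqrt(scale)[None, :, None, None]
    return q * s, k * s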
Thanks for your effort in making this great library.
In normal attention, the input to the softmax is matmul(Q, K_T), with shape (batch, num_heads, q_len, k_len).
Also, the attention mask is triangular (its full shape would be q_len x k_len), and matmul(Q, K_T) is masked with it.
However, I don't understand how matmul(q_chunk, transposed k_chunk) ends up producing the same masked softmax input as the original attention algorithm, in the code lines below.
flash-attention-jax/flash_attention_jax/flash_attention.py
Lines 34 to 37 in 5727815
Can you explain it in detail?
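For intuition, each (q_chunk, k_chunk) pair just computes one tile of the full q_len x k_len logit matrix, so slicing the global triangular mask down to that tile (or rebuilding it from the chunk offsets) gives exactly the entries the un-chunked algorithm would have masked. A rough sketch, assuming a causal mask and explicit offsets rather than this repo's exact bookkeeping:

import jax.numpy as jnp

def masked_chunk_logits(q_chunk, k_chunk, q_offset, k_offset):
    # q_chunk: (q_blk, dim) starting at global row q_offset
    # k_chunk: (k_blk, dim) starting at global column k_offset
    logits = q_chunk @ k_chunk.T                      # one (q_blk, k_blk) tile of the full logit matrix
    q_idx = q_offset + jnp.arange(q_chunk.shape[0])[:, None]
    k_idx = k_offset + jnp.arange(k_chunk.shape[0])[None, :]
    causal = q_idx >= k_idx                           # the same triangular mask, restricted to this tile
    return jnp.where(causal, logits, -jnp.inf)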
the general case of attention is (using annotations from jaxtyping)
q: Float["lq d"]
k: Float["lkv d"]
v: Float["lkv o"]
mask: Bool["lq lkv"]
returns: Float["lq o"]
but it looks like right now this library only supports a 1-dimensional mask?
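As a reference point, here is the general case written as plain (non-flash) attention with a full 2-D mask, using the shapes above; judging from the examples elsewhere in this thread, the library's flash_attention appears to take a per-key mask of shape (batch, lkv) instead:

import jax
import jax.numpy as jnp

def general_attention(q, k, v, mask):
    # q: (lq, d), k: (lkv, d), v: (lkv, o), mask: (lq, lkv) boolean
    logits = q @ k.T / jnp.sqrt(q.shape[-1])
    logits = jnp.where(mask, logits, -jnp.inf)
    return jax.nn.softmax(logits, axis=-1) @ v        # (lq, o)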
Hello,
Does the implementation of causal flash attention support multi-head?
It seems not, since the shapes of query and key are (q_len, q_dim) and (k_len, k_dim).
It is currently impossible to use flash_attention within a function that uses gradient checkpointing.
Minimal example to reproduce:
import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention

keys = jax.random.split(jax.random.PRNGKey(0), 4)  # any four PRNG keys

b = 3
lq = 16
lkv = 17
h = 5
d = 19
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)
fails with error:
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in <cell line: 1>()
----> [1](vscode-notebook-cell://ssh-remote%2Bjackcpu/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0) get_ipython().run_line_magic('timeit', 'bench_flash_bwd(q, k, v, mask).block_until_ready()')
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/interactiveshell.py:2305, in InteractiveShell.run_line_magic(self, magic_name, line, _stack_depth)
2303 kwargs['local_ns'] = self.get_local_scope(stack_depth)
2304 with self.builtin_trap:
-> 2305 result = fn(*args, **kwargs)
2306 return result
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:1162, in ExecutionMagics.timeit(self, line, cell, local_ns)
1160 for index in range(0, 10):
1161 number = 10 ** index
-> 1162 time_number = timer.timeit(number)
1163 if time_number >= 0.2:
1164 break
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/IPython/core/magics/execution.py:156, in Timer.timeit(self, number)
154 gc.disable()
155 try:
--> 156 timing = self.inner(it, self.timer)
157 finally:
158 if gcold:
File <magic-timeit>:1, in inner(_it, _timer)
[... skipping hidden 14 frame]
/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb Cell 9 in bench_flash_bwd(q, k, v, mask)
[1](vscode-notebook-cell://ssh-remote%2Bjackcpu/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0) @jax.jit
[2](vscode-notebook-cell://ssh-remote%2Bjackcpu/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D?line=1) def bench_flash_bwd(q, k, v, mask):
----> [3](vscode-notebook-cell://ssh-remote%2Bjackcpu/home/jack/code/k-diffusion-jax/misc/test_flash.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D?line=2) return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]), policy=jax.checkpoint_policies.everything_saveable))(q)
[... skipping hidden 25 frame]
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/util.py:48, in safe_map(f, *args)
46 n = len(args[0])
47 for arg in args[1:]:
---> 48 assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
49 return list(map(f, *args))
AssertionError: length mismatch: [3, 1]
Hi lucidrain!
I wanted to use flash attention in one of my projects: a transformer model that works on sequences as long as 2400 with a batch size of 1000. The original attention does not fit in memory for me, so I wanted to use flash attention and found your implementation.
However, I found that I cannot simply pass your attention implementation to flax.MultiHeadDotProductAttention, because its attention_fn needs to be multi-headed and accept mask, dropout_rate, etc.
I was wondering if I could take your flash attention building block and add the required capabilities to it. I am not familiar with the flash attention internals, but I am familiar with jax and flax, so I was wondering whether this is doable without understanding the underlying algorithm. If you think it is possible, I can work on it and then create a pull request.
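A rough sketch of what such an adapter might look like, with several assumptions flagged: flash_attention_fn is a hypothetical name, flash_attention is assumed to take (batch, heads, len, head_dim) tensors plus a (batch, kv_len) key mask and to return the attention output directly (as in the benchmark above), and the signature only mirrors the parts of flax's attention_fn interface used here, silently ignoring bias and dropout:

import jax.numpy as jnp
from flash_attention_jax import flash_attention

def flash_attention_fn(query, key, value, bias=None, mask=None,
                       dropout_rate=0.0, deterministic=True, **kwargs):
    # flax hands us (batch, len, heads, head_dim); move heads next to batch
    q, k, v = (jnp.moveaxis(x, -2, -3) for x in (query, key, value))
    # no padding information is threaded through here, so attend to every key (hypothetical all-ones key mask)
    key_mask = jnp.ones((k.shape[0], k.shape[2]), dtype=bool)
    out = flash_attention(q, k, v, key_mask)          # assumed to return the attention output directly
    return jnp.moveaxis(out, -3, -2)                  # back to (batch, q_len, heads, head_dim)

Something like this could then be passed as attention_fn to flax.linen.MultiHeadDotProductAttention, with the caveat that attention bias, the 4-D flax mask, and dropout would all be dropped until properly wired up.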