Comments (28)
This is roughly repeating what @dlwh just said, but I just figured it out and came back to explain: this use of custom_vjp is buggy in that the flash_attention_forward output needs to be a pair where the first element has the same type as the output of flash_attention. Yet we can see that, whereas the output of flash_attention includes three arrays, the first element of the return value of flash_attention_forward has only one array. There's a JAX bug in that this was a terrible error message to raise, but the fundamental bug is in that use of custom_vjp.
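As a minimal sketch of that contract (a toy function, not this repo's code): the fwd rule must return a pair whose first element has exactly the same pytree structure as the primal output.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    return jnp.sin(x), jnp.cos(x)       # primal output: a tuple of two arrays

def f_fwd(x):
    # pair: (same structure as f's output, residuals for the backward pass)
    return f(x), (x,)

def f_bwd(res, cts):
    (x,) = res
    ct_sin, ct_cos = cts                # cotangents mirror the primal output
    return (ct_sin * jnp.cos(x) - ct_cos * jnp.sin(x),)

f.defvjp(f_fwd, f_bwd)

jax.grad(lambda x: f(x)[0])(1.0)        # ok: structures agree

Returning just jnp.sin(x) as the first element of f_fwd's output would reproduce the mismatch described above: a one-array first element against a primal output containing more than one array.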
google/jax#12611 should improve the error message we got here! With the same repro (i.e. before the fix #7 was merged here), the error will be:
TypeError: Custom VJP fwd rule flash_attention_forward for function
flash_attention must produce a pair (list or tuple of length two) where the
first element represents the primal output (equal to the output of the
custom_vjp-decorated function flash_attention) and the second element
represents residuals (i.e. values stored from the forward pass for use on the
backward pass), but instead the fwd rule output's first element had
container/pytree structure:
float32[3,16,5,19]
while the custom_vjp-decorated function flash_attention had output
container/pytree structure:
(float32[3,16,5,19], (float32[3,16,5], float32[3,16,5])).
can confirm that this error also appears under jax.lax.scan. Example (imports and shape constants filled in for completeness; values follow the repro later in this thread, and the scan length l is assumed, since it isn't given in the original):

import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention

l = 2                              # scan length (assumed; not in the original)
b, lq, lkv, h, d = 3, 16, 17, 5, 19
keys = jax.random.split(jax.random.PRNGKey(0), 4)

q = jax.random.normal(keys[0], (l, b, lq, h, d))
k = jax.random.normal(keys[1], (l, b, lkv, h, d))
v = jax.random.normal(keys[2], (l, b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (l, b, lkv))

def scan_fn(carry, qkv):
    out = flash_attention(*qkv)[0]
    carry += out
    return carry, out

@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(
        lambda q, k, v, mask: jnp.sum(
            jax.lax.scan(
                scan_fn,
                jnp.zeros_like(q[0]),
                (q, k, v, mask),
            )[0],
        )
    )(q, k, v, mask)

bench_flash_bwd(q, k, v, mask)
Thanks for raising this! It looks like a JAX core bug most likely. Could you provide a self-contained runnable repro, in particular including the import or definition for flash_attention? (Sorry, I'm not the developer of this repo, so I'm not familiar with that function.)
from flash_attention_jax import flash_attention
ran into this and failed to upstream it. The trick to fix it is basically to do this (see the diff below):
@dlwh looks like you also ran an autoformatter so there's a ton of other changes here - can you say a bit more about how you fixed it?
Yeah sorry, the line linked is the key one. Basically just rename the method called "causal_flash_attention" to "_causal_flash_attention" and make causal_flash_attention return just the first result. Then make flash_attention_forward call _causal_flash_attention instead, and you're done.
 @custom_vjp
 def causal_flash_attention(q, k, v):
+    return _causal_flash_attention(q, k, v)[0]
+
+
+def _causal_flash_attention(q, k, v):
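Spelled out (a hedged sketch with placeholder bodies, not the repo's actual code; names follow the comment above): the private helper returns all three arrays, the public custom_vjp-decorated function returns only the first, and the forward rule calls the helper so the extra arrays remain available as residuals.

from jax import custom_vjp
import jax.numpy as jnp

def _causal_flash_attention(q, k, v):
    # placeholder standing in for the real kernel, which returns
    # (out, row_sum, row_max)
    out = q
    row_sum = jnp.sum(out, axis=-1)
    row_max = jnp.max(out, axis=-1)
    return out, row_sum, row_max

@custom_vjp
def causal_flash_attention(q, k, v):
    # the public entry point now returns a single array, matching its fwd rule
    return _causal_flash_attention(q, k, v)[0]

def flash_attention_forward(q, k, v):
    out, row_sum, row_max = _causal_flash_attention(q, k, v)
    # first element matches causal_flash_attention's output structure exactly;
    # second element carries the residuals for the backward pass
    return out, (q, k, v, out, row_sum, row_max)

def flash_attention_backward(res, g):
    q, k, v, out, row_sum, row_max = res
    # placeholder gradients; the real rule recomputes attention in chunks
    return g, jnp.zeros_like(k), jnp.zeros_like(v)

causal_flash_attention.defvjp(flash_attention_forward, flash_attention_backward)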
won't that make flash_attention always do causal masking? I'm using this in a context where that's not appropriate
you'll need to make the analogous change to flash_attention then. As @mattjj said, it's really just a buggy use of custom_vjp. (Though despite it not running, the code was otherwise correct according to my gradient testing!)
Shall I send a PR fix to this repo (maybe you both could review it), and then separately fix the JAX error message? Or @dlwh do you want to send the fix to this repo?
I can probably get to it tonight or tomorrow, but I'm about to go dark for several hours. Totally up to you!
I'll take the first stab, and cc you!
so the relevant fix would be to replace the return in flash_attention_forward with return (out, (row_sum, row_max)), (q, k, v, key_mask, out, row_sum, row_max)?
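(For reference, the pytree structure in the error message above says flash_attention's primal output is (out, (row_sum, row_max)), so the fwd rule's first element would have to mirror that. A sketch of the proposed return, assuming a hypothetical helper _flash_attention that computes the three arrays:)

def flash_attention_forward(q, k, v, key_mask):
    out, row_sum, row_max = _flash_attention(q, k, v, key_mask)  # hypothetical helper
    # first element mirrors flash_attention's full output structure;
    # second element is the residuals tuple for the backward pass
    return (out, (row_sum, row_max)), (q, k, v, key_mask, out, row_sum, row_max)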
interesting that this works with grad outside of scan and remat - probably it should fail under grad alone without either of those?
@GallagherCommaJack Yes, that'd work! It's probably the simplest fix, though we could also look at the call sites of flash_attention to see if some other organization would be more natural.

What's a repro for the behavior you're describing? I tried removing jax.checkpoint from the repro in the OP and I still got an error. That is, this still errors for me:
import jax
import jax.numpy as jnp
from flash_attention_jax import flash_attention
b = 3
lq = 16
lkv = 17
h = 5
d = 19
keys = jax.random.split(jax.random.PRNGKey(0), 4)
q = jax.random.normal(keys[0], (b, lq, h, d))
k = jax.random.normal(keys[1], (b, lkv, h, d))
v = jax.random.normal(keys[2], (b, lkv, h, d))
mask = jax.random.bernoulli(keys[3], 0.5, (b, lkv))
@jax.jit
def bench_flash_bwd(q, k, v, mask):
    return jax.grad(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0]))(q)
bench_flash_bwd(q, k, v, mask)
Ah, I think it was just a shape bug; if I set lq = lkv = 16 then I see what you mean.
I think by adding the better JAX error message I described, we'll catch this much earlier and get an error in both cases. I'll be sure to test both with and without checkpoint/scan.
> Yes, that'd work!

Actually, I think it would not work, just because the callers expect only a single output there.

I think the issue here was that the custom_vjp-decorated function (i.e. the "primal function") didn't agree with the custom_vjp rule (i.e. their output types didn't agree in the way that they should), but when we only use grad (possibly together with jit) we never actually run the primal function; we only run its forward rule. When grad is applied, we only actually run the primal function when under a jax.checkpoint or jax.lax.scan (or jax.lax.cond, etc.); that's just because of a JAX implementation detail (these are "initial-style higher-order primitives") which is usually invisible, except apparently when there's a type error in a custom_vjp rule!
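One way to observe this (a toy sketch, assuming a JAX version from around the time of this thread; newer releases may trace the primal more eagerly): put a print in both the primal function and the fwd rule and see which fires at trace time.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x):
    print("primal function traced")
    return jnp.sin(x)

def f_fwd(x):
    print("fwd rule traced")
    return jnp.sin(x), (x,)

def f_bwd(res, ct):
    (x,) = res
    return (ct * jnp.cos(x),)

f.defvjp(f_fwd, f_bwd)

jax.grad(f)(1.0)                  # only the fwd rule's print fires
jax.grad(jax.checkpoint(f))(1.0)  # the primal's print fires too: checkpoint,
                                  # an initial-style primitive, traces it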
with the fix it's working with lq = lkv under jax.checkpoint! still fails with lq != lkv, which I'm trying to debug now
the error with lq = 16; lkv = 17 is TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19). Full backtrace:
TypeError Traceback (most recent call last)
Cell In [5], line 22
18 @jax.jit
19 def bench_flash_bwd(q, k, v, mask):
20 return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)
---> 22 bench_flash_bwd(q, k, v, mask)
[... skipping hidden 14 frame]
Cell In [5], line 20, in bench_flash_bwd(q, k, v, mask)
18 @jax.jit
19 def bench_flash_bwd(q, k, v, mask):
---> 20 return jax.grad(jax.checkpoint(lambda x: jnp.sum(flash_attention(x, k, v, mask)[0])))(q)
[... skipping hidden 30 frame]
File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:172, in flash_attention_backward(res, do)
169 dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
170 return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk
--> 172 (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))
174 dq = rearrange(dq, 'c n b h d -> b h (c n) d')
175 dk, dv = map(lambda t: rearrange(t, 'n b h d -> b h n d'), (dk, dv))
[... skipping hidden 11 frame]
File ~/code/flash-attention-jax/flash_attention_jax/flash_attention.py:170, in flash_attention_backward.<locals>.chunk_scanner(carries, _)
167 do_chunk = lax.dynamic_slice(do, (chunk_idx, batch, heads, 0), slice_sizes = (chunk_sizes, batch, heads, do.shape[-1]))
169 dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
--> 170 return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk
[... skipping hidden 1 frame]
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:4658, in _defer_to_unrecognized_arg.<locals>.deferring_binary_op(self, other)
4656 args = (other, self) if swap else (self, other)
4657 if isinstance(other, _accepted_binop_types):
-> 4658 return binary_op(*args)
4659 if isinstance(other, _rejected_binop_types):
4660 raise TypeError(f"unsupported operand type(s) for {opchar}: "
4661 f"{type(args[0]).__name__!r} and {type(args[1]).__name__!r}")
[... skipping hidden 7 frame]
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py:84, in _maybe_bool_binop.<locals>.fn(x1, x2)
82 def fn(x1, x2):
83 x1, x2 = _promote_args(numpy_fn.__name__, x1, x2)
---> 84 return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2)
[... skipping hidden 7 frame]
File ~/.cache/pypoetry/virtualenvs/k-diffusion-jax-doaswkte-py3.9/lib/python3.9/site-packages/jax/_src/lax/lax.py:1537, in broadcasting_shape_rule(name, *avals)
1535 result_shape.append(non_1s[0])
1536 else:
-> 1537 raise TypeError(f'{name} got incompatible shapes for broadcasting: '
1538 f'{", ".join(map(str, map(tuple, shapes)))}.')
1540 return tuple(result_shape)
TypeError: add got incompatible shapes for broadcasting: (5, 3, 17, 19), (5, 3, 16, 19).
It looks like one of chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk has a shape error, in flash_attention_backward. (EDIT: I don't feel comfortable debugging that without learning what this code is actually doing, so hopefully someone who knows the code/algorithm can help!)
debugging a bit, it looks like the issue is that dk has shape (h, b, lkv, d) and dk_chunk has shape (h, b, lq, d)
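To see the mismatch in isolation (shapes taken from the traceback above, i.e. h = 5, b = 3, lkv = 17, lq = 16, d = 19):

import jax.numpy as jnp

dk = jnp.zeros((5, 3, 17, 19))        # (h, b, lkv, d): one gradient row per key position
dk_chunk = jnp.zeros((5, 3, 16, 19))  # (h, b, lq, d): sized by the query chunk instead
dk + dk_chunk                         # TypeError: add got incompatible shapes for broadcasting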
@lucidrains looks like there's an implicit assumption somewhere in here that lq == lkv in the backwards pass, in _query_chunk_flash_attention_backward
@GallagherCommaJack the fix I proposed in #8 is different from the commit you sent, just FYI.
does that work with lq != lkv?
looks like it does not
Indeed I think the shape issue is unrelated.