Flash Cosine Similarity Attention

Implementation of fused cosine similarity attention in the same style as Flash Attention. The observation is that by adopting l2 normalized queries and keys, you no longer need to keep track of the row maximums for numerical stability. This greatly simplifies the flash attention algorithm, assuming cosine similarity attention comes at no generalization cost.

In other words, stable, fast, memory efficient, and longer context attention with no downsides.
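
For reference, here is a minimal non-fused sketch in plain PyTorch of the computation the kernel performs (an illustrative helper only, using the fixed scale of 10 from the experiments described below; it is not the library's internal implementation):

import torch
import torch.nn.functional as F

def cosine_sim_attention_reference(q, k, v, scale = 10, causal = False):
    # l2-normalize queries and keys, so q @ k^T is bounded in [-1, 1]
    # and the softmax is stable without tracking row maximums
    q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))

    sim = torch.einsum('... i d, ... j d -> ... i j', q, k) * scale

    if causal:
        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = sim.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    return torch.einsum('... i j, ... j d -> ... i d', attn, v)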

Update: Unfortunately, Robin's experiments showed much worse evaluation FID scores that were not reflected in the loss. Pending more experiments. Use this library with caution.

Update 2: The only saving grace would be to use grouped l2norm, which could potentially allow for more expressivity. If anyone can evaluate this technique on their generative work and obtain some FID scores, it would be much appreciated.
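
One way to read "grouped l2norm" is to normalize within groups of the head dimension rather than over the entire head dimension. A minimal sketch of the idea (the helper name and the group count of 4 are arbitrary assumptions, not the library's API):

import torch
import torch.nn.functional as F

def grouped_l2norm(t, groups = 4):
    # split the head dimension into groups and l2-normalize each group independently,
    # which gives the attention logits more degrees of freedom than one global l2norm
    *lead, dim = t.shape
    t = t.reshape(*lead, groups, dim // groups)
    t = F.normalize(t, dim = -1)
    return t.reshape(*lead, dim)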

Update 3: An approach similar to cosine sim attention has been proven at scale, with a 22B parameter vision model from Brain.

Status (wip)

At the moment, autoregressive and variable-length sequences should be faster across all architectures. For sequences longer than 2048, it will also be memory efficient where regular attention would not be.

However, for non-autoregressive attention without masking, the kernel is still slower on the A100 for F16. The aim is to get it to perform faster on the A100, forwards and backwards, for both F32 and F16, as shared memory is not yet fully exploited.

On older graphics cards without enough shared memory, one will have to gauge the tradeoff between memory efficiency and speed depending on the sequence length being trained at.

Appreciation

  • Arthur Hennequin for coaching me through my first CUDA kernel, and for coding up a simple reference implementation, which helped me bootstrap the first kernel that comes within reasonable performance of the baseline. This work would not have been possible without his expertise.

  • Boris Dayma and Robin Rombach for running experiments on the simplified cosine sim attention with fixed scaling on some significant text-to-image models, and verifying that it indeed performs just as well as regular attention.

  • Markus Rabe for penning the paper that showed attention does not require O(n²) memory, and Tri Dao for putting it all together in a CUDA kernel implementation for regular attention, demonstrating superiority in speed using the tiled approach that minimizes HBM accesses (and for figuring out that dO * O == dP * P for the backwards pass; a short note on this identity follows this list). Would not have been able to complete my pilgrimage looking for the ultimate attention formulation without their discoveries.

  • Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research.
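
A short note on the dO * O == dP * P identity mentioned above, as described in the FlashAttention paper: with $P = \text{softmax}(S)$, $O = PV$ and $dP = dO\,V^\top$, the softmax backward pass needs the per-row sum $D_i = \sum_j P_{ij}\,dP_{ij}$, which can be rewritten as

$$D_i = \sum_j P_{ij}\,(dO_i \cdot v_j) = dO_i \cdot \sum_j P_{ij}\,v_j = dO_i \cdot O_i$$

so this reduction can be computed from $dO$ and $O$ alone, without rematerializing the attention matrix in the backward kernel.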

Install

$ pip install flash-cosine-sim-attention

Usage

Self Attention

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 1024, 64).cuda()
v = torch.randn(1, 8, 1024, 64).cuda()

out = flash_cosine_sim_attention(q, k, v)  # (1, 8, 1024, 64)

Cross attention

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 2048, 64).cuda()
v = torch.randn(1, 8, 2048, 64).cuda()

out = flash_cosine_sim_attention(q, k, v) # (1, 8, 1024, 64)

With key / value masking

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 2048, 64).cuda()
v = torch.randn(1, 8, 2048, 64).cuda()

mask = torch.ones(1, 2048).bool().cuda()

out = flash_cosine_sim_attention(q, k, v, mask = mask) # (1, 8, 1024, 64)

Autoregressive

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 8, 1024, 64).cuda()
v = torch.randn(4, 8, 1024, 64).cuda()

out = flash_cosine_sim_attention(q, k, v, causal = True)  # (4, 8, 1024, 64)

Miscellaneous

Single-headed key / values (Shazeer et al., as used in PaLM)

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 1024, 64).cuda()
v = torch.randn(4, 1024, 64).cuda()

out = flash_cosine_sim_attention(q, k, v, causal = True)  # (4, 8, 1024, 64)

If you need to do operations on the queries and keys in between the l2norm and the actual attention step, just set l2norm_qk = False.

For example:

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention, l2norm_tensors

q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 1024, 64).cuda()
v = torch.randn(4, 1024, 64).cuda()

q, k = l2norm_tensors(q, k)

# do your rotation of queries and keys
# say with https://github.com/lucidrains/rotary-embedding-torch

out = flash_cosine_sim_attention(q, k, v, l2norm_qk = False)  # (4, 8, 1024, 64)
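
For instance, filling in the rotation step with rotary embeddings (a sketch assuming the rotary-embedding-torch package linked above and its RotaryEmbedding.rotate_queries_or_keys helper; the rotation is norm-preserving, so the earlier l2norm still holds):

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention, l2norm_tensors
from rotary_embedding_torch import RotaryEmbedding  # assumed third-party dependency

q = torch.randn(4, 8, 1024, 64).cuda()
k = torch.randn(4, 8, 1024, 64).cuda()
v = torch.randn(4, 8, 1024, 64).cuda()

q, k = l2norm_tensors(q, k)

# rotate the already l2-normalized queries and keys
rotary_emb = RotaryEmbedding(dim = 32).cuda()
q = rotary_emb.rotate_queries_or_keys(q)
k = rotary_emb.rotate_queries_or_keys(k)

out = flash_cosine_sim_attention(q, k, v, l2norm_qk = False)  # (4, 8, 1024, 64)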

Cross attention with causal masking works as expected (e.g. caching of keys and values during autoregressive inference, or Transformer-XL style training)

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(1, 8, 1024, 64).cuda()
k = torch.randn(1, 8, 2048, 64).cuda()
v = torch.randn(1, 8, 2048, 64).cuda()

out = flash_cosine_sim_attention(q, k, v, causal = True) # (1, 8, 1024, 64)

If you have the batch and head dimensions merged, that is fine.

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

q = torch.randn(32, 1024, 64).cuda()
k = torch.randn(32, 2048, 64).cuda()
v = torch.randn(32, 2048, 64).cuda()

out = flash_cosine_sim_attention(q, k, v, causal = True) # (32, 1024, 64)

Supported head dimensions

  • 16 - f32

  • 32

  • 64

  • 96

  • 128

  • 16 - f16

  • 80 - in progress

Todo

  • bfloat16 support, use sfinae as recommended by Arthur

  • stream from qk_mma to shared memory in chunks to calculate out mma, see if freed smem can be used for caching more

  • support O(n) 1d dynamic positional bias

  • figure out why smem fragment caching would lead to performance degradation, it does not make sense

  • think about use of logsumexp - works but the extra log leads to degraded perf

  • prepare a smem fragment caching mechanism, to allow for as much caching as allowed on A100 (or f16)

  • make attention tile size processing customizable for backwards pass

  • move atomic add to overloaded function inside mma

  • flexible which type is used for accumulation

  • test out 64x96 tiles on f16

  • bring in a CPU memory efficient version (only for inference, as training does not make sense) using just plain pytorch code

  • figure out how to dispatch differently for architectures (say A100), in case backwards can make use of the increase in shared memory differently

  • decouple row and column sizes for attention tiles

  • dk and dv are now in f16 when possible (non single-headed kv)

  • support more standard head dimensions (wip)

  • debug and fix bias backwards gradients yet again for head size of 32

  • fix attention bias gradients

  • allow for single-headed key / values, as in PaLM

  • fix atomic add for f16

  • attention bias should be able to accept dimensions of an extra batch dimension, for Alphafold2 like attention biasing

  • automate cache-busting of kernel using version as suffix to package name

  • resolve f16 causal numerical issues

  • adopt all learnings from forward kernel to backwards kernel and make sure it outperforms at least on A100

Description

So far cosine similarity attention is not widely used in industry. The only large model that has been trained with it so far is SwinV2. If anyone can invalidate the approach, please open an issue or send me an email. You can run experiments against regular attention using the x-transformers repository.

Update: Boris Dayma has graciously kicked off an experiment (blue, with red as the baseline) to validate cosine similarity attention with a fixed scale of 10 in a real-world model setting.

Update 2: Cosine similarity attention has been proven out in a real-world text-to-image attention network, using a constant scale of 10. No worse than regular attention. Credit goes to Boris Dayma for investing the time to run the experiment and removing doubts surrounding the technique.

Update 3: Robin Rombach has tested out the kernel in this repository with head size of 64 and fixed scale of 10 in a text-to-image model, observing no difference from regular attention. More evaluations pending.

Update 4: The improvement in performance seen in Boris' experiments is likely due to the fact that cosine sim attention allows one to switch from a pre-layernorm to a post-layernorm configuration in the transformer (the l2norm effectively takes the place of the pre-layernorm). Cosine sim attention will likely yield results on par with regular attention, without any other changes to the transformer.
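
For illustration, here is a hypothetical post-layernorm attention block along those lines (a sketch with assumed names and dimensions, not an API of this library):

import torch
from torch import nn
from flash_cosine_sim_attention import flash_cosine_sim_attention

class CosineSimAttentionBlock(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)
        # post-layernorm only; no pre-layernorm, as queries and keys are l2-normalized inside the kernel
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        b, n, _ = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: t.reshape(b, n, self.heads, -1).transpose(1, 2), (q, k, v))

        out = flash_cosine_sim_attention(q, k, v, causal = True)

        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.norm(x + self.to_out(out))  # residual, then post-layernorm

block = CosineSimAttentionBlock(dim = 512).cuda()
x = torch.randn(1, 1024, 512).cuda()
out = block(x)  # (1, 1024, 512)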

Testing

To test that the output and gradients are equal for both non-autoregressive and autoregressive scenarios

$ python setup.py test
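
For a quick manual sanity check, one can also compare the fused kernel against the non-fused path exported by the package (a sketch; it assumes plain_cosine_sim_attention accepts the same q, k, v arguments and that this tolerance is loose enough for f32):

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention, plain_cosine_sim_attention

q, k, v = (torch.randn(1, 8, 256, 64).cuda() for _ in range(3))

out_fused = flash_cosine_sim_attention(q, k, v)
out_plain = plain_cosine_sim_attention(q, k, v)

assert torch.allclose(out_fused, out_plain, atol = 1e-4)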

Benchmarking

Make sure to first install the CUDA kernel

$ python setup.py install

Then

$ python benchmark.py

To benchmark only forwards or backwards, append either the --only-forwards or --only-backwards flag to the above. To benchmark autoregressive attention, append --causal.

Benchmarks - wip

RTX 2080 Ti

Forward

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 1.05x   kernel: 0.24ms  baseline: 0.23ms
seq_len: 256    slower: 1.27x   kernel: 0.38ms  baseline: 0.30ms
seq_len: 512    slower: 1.28x   kernel: 0.87ms  baseline: 0.68ms
seq_len: 1024   slower: 1.15x   kernel: 2.63ms  baseline: 2.28ms
seq_len: 2048   slower: 0.99x   kernel: 7.99ms  baseline: 8.10ms
seq_len: 4096   slower: 0.88x   kernel: 30.82ms baseline: 34.84ms
seq_len: 8192   slower: 0.00x   kernel: 121.96ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.85x   kernel: 0.20ms  baseline: 0.24ms
seq_len: 256    slower: 0.97x   kernel: 0.24ms  baseline: 0.25ms
seq_len: 512    slower: 1.22x   kernel: 0.43ms  baseline: 0.35ms
seq_len: 1024   slower: 0.95x   kernel: 0.93ms  baseline: 0.98ms
seq_len: 2048   slower: 0.90x   kernel: 3.16ms  baseline: 3.50ms
seq_len: 4096   slower: 0.85x   kernel: 11.06ms baseline: 13.07ms
seq_len: 8192   slower: 0.00x   kernel: 42.61ms baseline: oom

Backwards - still needs work

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 1.07x   kernel: 0.61ms  baseline: 0.57ms
seq_len: 256    slower: 1.40x   kernel: 0.91ms  baseline: 0.65ms
seq_len: 512    slower: 1.70x   kernel: 2.34ms  baseline: 1.38ms
seq_len: 1024   slower: 1.26x   kernel: 5.67ms  baseline: 4.50ms
seq_len: 2048   slower: 1.29x   kernel: 20.60ms baseline: 15.91ms
seq_len: 4096   slower: 1.30x   kernel: 78.93ms baseline: 60.81ms
seq_len: 8192   slower: 0.00x   kernel: 314.51ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.91x   kernel: 0.50ms  baseline: 0.55ms
seq_len: 256    slower: 1.06x   kernel: 0.58ms  baseline: 0.55ms
seq_len: 512    slower: 1.13x   kernel: 0.81ms  baseline: 0.72ms
seq_len: 1024   slower: 0.97x   kernel: 2.09ms  baseline: 2.16ms
seq_len: 2048   slower: 0.96x   kernel: 7.06ms  baseline: 7.35ms
seq_len: 4096   slower: 0.97x   kernel: 26.08ms baseline: 26.84ms
seq_len: 8192   slower: 0.00x   kernel: 101.02ms    baseline: oom

Forward & Backwards - F32 is definitely slower

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 1.05x   kernel: 0.83ms  baseline: 0.79ms
seq_len: 256    slower: 1.34x   kernel: 1.26ms  baseline: 0.95ms
seq_len: 512    slower: 1.44x   kernel: 3.14ms  baseline: 2.18ms
seq_len: 1024   slower: 1.15x   kernel: 7.83ms  baseline: 6.81ms
seq_len: 2048   slower: 1.20x   kernel: 28.83ms baseline: 24.03ms
seq_len: 4096   slower: 1.20x   kernel: 111.13ms    baseline: 92.51ms
seq_len: 8192   slower: 0.00x   kernel: 441.70ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.89x   kernel: 0.68ms  baseline: 0.77ms
seq_len: 256    slower: 1.03x   kernel: 0.80ms  baseline: 0.77ms
seq_len: 512    slower: 1.06x   kernel: 1.16ms  baseline: 1.10ms
seq_len: 1024   slower: 0.93x   kernel: 2.94ms  baseline: 3.16ms
seq_len: 2048   slower: 0.93x   kernel: 10.06ms baseline: 10.87ms
seq_len: 4096   slower: 0.93x   kernel: 37.09ms baseline: 39.96ms
seq_len: 8192   slower: 0.00x   kernel: 143.13ms    baseline: oom

For autoregressive, a clear win. Run python benchmark.py --causal

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.97x   kernel: 0.81ms  baseline: 0.84ms
seq_len: 256    slower: 1.07x   kernel: 1.12ms  baseline: 1.05ms
seq_len: 512    slower: 0.83x   kernel: 2.23ms  baseline: 2.68ms
seq_len: 1024   slower: 0.55x   kernel: 4.83ms  baseline: 8.82ms
seq_len: 2048   slower: 0.49x   kernel: 15.89ms baseline: 32.68ms
seq_len: 4096   slower: 0.46x   kernel: 57.50ms baseline: 126.00ms
seq_len: 8192   slower: 0.00x   kernel: 224.76ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.82x   kernel: 0.69ms  baseline: 0.84ms
seq_len: 256    slower: 0.95x   kernel: 0.79ms  baseline: 0.83ms
seq_len: 512    slower: 0.78x   kernel: 1.06ms  baseline: 1.37ms
seq_len: 1024   slower: 0.50x   kernel: 2.10ms  baseline: 4.24ms
seq_len: 2048   slower: 0.37x   kernel: 5.85ms  baseline: 15.92ms
seq_len: 4096   slower: 0.31x   kernel: 19.80ms baseline: 64.42ms
seq_len: 8192   slower: 0.00x   kernel: 75.25ms baseline: oom

For variable-length sequences with masking, also a clear win. Assuming on average 25% of tokens are masked out, run python benchmark.py --mask-prob 0.25

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.95x   kernel: 0.84ms  baseline: 0.89ms
seq_len: 256    slower: 1.19x   kernel: 1.28ms  baseline: 1.08ms
seq_len: 512    slower: 1.23x   kernel: 3.19ms  baseline: 2.59ms
seq_len: 1024   slower: 0.92x   kernel: 8.19ms  baseline: 8.88ms
seq_len: 2048   slower: 0.92x   kernel: 30.08ms baseline: 32.57ms
seq_len: 4096   slower: 0.94x   kernel: 123.20ms    baseline: 131.22ms
seq_len: 8192   slower: 0.00x   kernel: 461.77ms    baseline: oom
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.85x   kernel: 0.77ms  baseline: 0.90ms
seq_len: 256    slower: 0.93x   kernel: 0.86ms  baseline: 0.93ms
seq_len: 512    slower: 0.93x   kernel: 1.31ms  baseline: 1.40ms
seq_len: 1024   slower: 0.76x   kernel: 3.31ms  baseline: 4.35ms
seq_len: 2048   slower: 0.71x   kernel: 11.19ms baseline: 15.65ms
seq_len: 4096   slower: 0.70x   kernel: 41.27ms baseline: 59.01ms
seq_len: 8192   slower: 0.00x   kernel: 158.60ms    baseline: oom

A100 40GB (wip)

Thanks go out to Stability for providing access to A100s for testing, and to Enrico for taking the time to run some benchmarks when I had no access yet.

The A100 is still a work in progress. Shared memory is not fully exploited yet. Strangely enough, F32 seems to be doing better than F16.

Forwards

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.98x   kernel: 0.29ms  baseline: 0.30ms
seq_len: 256    slower: 1.19x   kernel: 0.35ms  baseline: 0.29ms
seq_len: 512    slower: 0.94x   kernel: 0.52ms  baseline: 0.55ms
seq_len: 1024   slower: 0.75x   kernel: 1.23ms  baseline: 1.65ms
seq_len: 2048   slower: 0.88x   kernel: 4.17ms  baseline: 4.73ms
seq_len: 4096   slower: 0.79x   kernel: 14.53ms baseline: 18.36ms
seq_len: 8192   slower: 0.64x   kernel: 55.01ms baseline: 85.93ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.84x   kernel: 0.24ms  baseline: 0.29ms
seq_len: 256    slower: 1.02x   kernel: 0.29ms  baseline: 0.29ms
seq_len: 512    slower: 1.24x   kernel: 0.36ms  baseline: 0.29ms
seq_len: 1024   slower: 1.48x   kernel: 0.79ms  baseline: 0.54ms
seq_len: 2048   slower: 1.31x   kernel: 2.08ms  baseline: 1.59ms
seq_len: 4096   slower: 1.21x   kernel: 6.89ms  baseline: 5.70ms
seq_len: 8192   slower: 1.07x   kernel: 24.80ms baseline: 23.15ms

Backwards

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.94x   kernel: 0.57ms  baseline: 0.60ms
seq_len: 256    slower: 1.29x   kernel: 0.75ms  baseline: 0.58ms
seq_len: 512    slower: 1.16x   kernel: 1.30ms  baseline: 1.12ms
seq_len: 1024   slower: 0.98x   kernel: 3.14ms  baseline: 3.19ms
seq_len: 2048   slower: 1.05x   kernel: 11.13ms baseline: 10.63ms
seq_len: 4096   slower: 0.98x   kernel: 40.11ms baseline: 40.79ms
seq_len: 8192   slower: 0.97x   kernel: 154.96ms    baseline: 159.70ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.91x   kernel: 0.55ms  baseline: 0.60ms
seq_len: 256    slower: 1.03x   kernel: 0.62ms  baseline: 0.60ms
seq_len: 512    slower: 1.36x   kernel: 0.82ms  baseline: 0.60ms
seq_len: 1024   slower: 1.52x   kernel: 1.52ms  baseline: 1.01ms
seq_len: 2048   slower: 1.37x   kernel: 4.14ms  baseline: 3.03ms
seq_len: 4096   slower: 1.33x   kernel: 14.23ms baseline: 10.71ms
seq_len: 8192   slower: 1.34x   kernel: 53.90ms baseline: 40.28ms

Forwards & Backwards

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.92x   kernel: 0.80ms  baseline: 0.87ms
seq_len: 256    slower: 1.23x   kernel: 1.07ms  baseline: 0.87ms
seq_len: 512    slower: 1.08x   kernel: 1.80ms  baseline: 1.66ms
seq_len: 1024   slower: 0.94x   kernel: 4.33ms  baseline: 4.62ms
seq_len: 2048   slower: 0.99x   kernel: 15.26ms baseline: 15.44ms
seq_len: 4096   slower: 0.93x   kernel: 54.78ms baseline: 59.21ms
seq_len: 8192   slower: 0.91x   kernel: 210.38ms    baseline: 230.97ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64
------------------------------------------------------------
seq_len: 128    slower: 0.90x   kernel: 0.78ms  baseline: 0.86ms
seq_len: 256    slower: 1.00x   kernel: 0.87ms  baseline: 0.87ms
seq_len: 512    slower: 1.36x   kernel: 1.18ms  baseline: 0.86ms
seq_len: 1024   slower: 1.49x   kernel: 2.31ms  baseline: 1.55ms
seq_len: 2048   slower: 1.33x   kernel: 6.17ms  baseline: 4.63ms
seq_len: 4096   slower: 1.28x   kernel: 21.08ms baseline: 16.44ms
seq_len: 8192   slower: 1.24x   kernel: 78.75ms baseline: 63.45ms

Autoregressive

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.82x   kernel: 0.82ms  baseline: 1.01ms
seq_len: 256    slower: 1.02x   kernel: 1.00ms  baseline: 0.98ms
seq_len: 512    slower: 0.82x   kernel: 1.55ms  baseline: 1.89ms
seq_len: 1024   slower: 0.51x   kernel: 2.79ms  baseline: 5.44ms
seq_len: 2048   slower: 0.45x   kernel: 8.37ms  baseline: 18.67ms
seq_len: 4096   slower: 0.40x   kernel: 29.16ms baseline: 72.97ms
seq_len: 8192   slower: 0.38x   kernel: 108.68ms    baseline: 285.47ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.82x   kernel: 0.81ms  baseline: 0.98ms
seq_len: 256    slower: 0.90x   kernel: 0.88ms  baseline: 0.98ms
seq_len: 512    slower: 1.16x   kernel: 1.13ms  baseline: 0.97ms
seq_len: 1024   slower: 0.80x   kernel: 1.68ms  baseline: 2.10ms
seq_len: 2048   slower: 0.54x   kernel: 3.66ms  baseline: 6.81ms
seq_len: 4096   slower: 0.45x   kernel: 11.43ms baseline: 25.32ms
seq_len: 8192   slower: 0.41x   kernel: 40.58ms baseline: 99.14ms

Variable-length sequences (up to 25% tokens masked out)

------------------------------------------------------------
float32     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.80x   kernel: 0.85ms  baseline: 1.07ms
seq_len: 256    slower: 1.07x   kernel: 1.15ms  baseline: 1.08ms
seq_len: 512    slower: 1.00x   kernel: 1.94ms  baseline: 1.94ms
seq_len: 1024   slower: 0.84x   kernel: 4.64ms  baseline: 5.55ms
seq_len: 2048   slower: 0.84x   kernel: 15.86ms baseline: 18.86ms
seq_len: 4096   slower: 0.76x   kernel: 55.19ms baseline: 72.47ms
seq_len: 8192   slower: 0.75x   kernel: 212.48ms    baseline: 282.71ms
------------------------------------------------------------
float16     batch: 4    heads: 8    dim 64  
------------------------------------------------------------
seq_len: 128    slower: 0.80x   kernel: 0.83ms  baseline: 1.04ms
seq_len: 256    slower: 0.90x   kernel: 0.93ms  baseline: 1.03ms
seq_len: 512    slower: 1.18x   kernel: 1.22ms  baseline: 1.04ms
seq_len: 1024   slower: 1.10x   kernel: 2.40ms  baseline: 2.17ms
seq_len: 2048   slower: 0.89x   kernel: 6.27ms  baseline: 7.06ms
seq_len: 4096   slower: 0.82x   kernel: 21.19ms baseline: 25.95ms
seq_len: 8192   slower: 0.78x   kernel: 79.45ms baseline: 101.83ms

Training a small GPT on Enwik8

$ make train

Try a sequence length of 8192. It'll be slow, but it will work (regular attention will break at sequence lengths > 2048; you'll see this if you remove the --use-cuda-kernel flag)

$ python train.py --seq-len 8192 --use-cuda-kernel

Citations

@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R'e},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory}, 
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@inproceedings{Henry2020QueryKeyNF,
    title   = {Query-Key Normalization for Transformers},
    author  = {Alex Henry and Prudhvi Raj Dachapally and Shubham Vivek Pawar and Yuxuan Chen},
    booktitle = {FINDINGS},
    year    = {2020}
}
@article{Wang2022DeepNetST,
    title   = {DeepNet: Scaling Transformers to 1,000 Layers},
    author  = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Dongdong Zhang and Furu Wei},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2203.00555}
}

flash-cosine-sim-attention's Issues

failed building wheel for flash-cosine-sim-attention

Hi, the package is currently erroring when building both locally (100 errors detected in the compilation of "flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu".) and as a wheel.

Wheel build error logs:

Collecting flash-cosine-sim-attention
  Downloading flash-cosine-sim-attention-0.1.40.tar.gz (25 kB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.10/dist-packages (from flash-cosine-sim-attention) (2.0.1+cu118)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (3.12.2)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (4.6.3)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (3.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (3.1.2)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->flash-cosine-sim-attention) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->flash-cosine-sim-attention) (16.0.6)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10->flash-cosine-sim-attention) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10->flash-cosine-sim-attention) (1.3.0)
Building wheels for collected packages: flash-cosine-sim-attention
  error: subprocess-exited-with-error
  
  × python setup.py bdist_wheel did not run successfully.
  │ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  Building wheel for flash-cosine-sim-attention (setup.py) ... error
  ERROR: Failed building wheel for flash-cosine-sim-attention
  Running setup.py clean for flash-cosine-sim-attention
Failed to build flash-cosine-sim-attention

Support head dimension 16?

Hello and thanks for your work.
Can flash-cosine-sim-attention support head dimension 16? 32 is too big for my model, so I wonder whether you have any plans to support head dimension 16 as flash-attention did?

make install fails

$ make install
python setup.py install --user
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/setuptools/installer.py:27: SetuptoolsDeprecationWarning: setuptools.installer is deprecated. Requirements should be satisfied by a PEP 517 installer.
  warnings.warn(
running install
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
running bdist_egg
running egg_info
writing flash_cosine_sim_attention.egg-info/PKG-INFO
writing dependency_links to flash_cosine_sim_attention.egg-info/dependency_links.txt
writing requirements to flash_cosine_sim_attention.egg-info/requires.txt
writing top-level names to flash_cosine_sim_attention.egg-info/top_level.txt
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/utils/cpp_extension.py:472: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
  warnings.warn(msg.format('we could not find ninja.'))
reading manifest file 'flash_cosine_sim_attention.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'flash_cosine_sim_attention.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib.linux-x86_64-cpython-39
creating build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/__init__.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/benchmark.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/transformer.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/flash_cosine_sim_attention.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
creating build/lib.linux-x86_64-cpython-39/tests
copying tests/__init__.py -> build/lib.linux-x86_64-cpython-39/tests
copying tests/test.py -> build/lib.linux-x86_64-cpython-39/tests
copying flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
running build_ext
building 'flash_cosine_sim_attention_cuda' extension
creating build/temp.linux-x86_64-cpython-39
creating build/temp.linux-x86_64-cpython-39/flash_cosine_sim_attention
/usr/local/cuda/bin/nvcc -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include/TH -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/antor/anaconda3/envs/open/include/python3.9 -c flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu -o build/temp.linux-x86_64-cpython-39/flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=flash_cosine_sim_attention_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++14
flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu(405): error: no instance of overloaded function "atomicAdd" matches the argument list
            argument types are: (c10::Half *, float)
          detected during:
            instantiation of "void mma_warp_tile<scalar_t, tmpl_N_thread, tmpl_M_thread>::atomic_add(accessor, int, int, int, int) [with scalar_t=c10::Half, tmpl_N_thread=2, tmpl_M_thread=2, accessor=at::TensorAccessor<c10::Half, 2UL, at::RestrictPtrTraits, signed int>]"
(1016): here
            instantiation of "void backward_kernel(PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<__nv_bool, 2>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, float, __nv_bool, __nv_bool, __nv_bool, __nv_bool) [with scalar_t=c10::Half]"
(1118): here

flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu(405): error: no instance of overloaded function "atomicAdd" matches the argument list
            argument types are: (c10::Half *, float)
          detected during:
            instantiation of "void mma_warp_tile<scalar_t, tmpl_N_thread, tmpl_M_thread>::atomic_add(accessor, int, int, int, int) [with scalar_t=c10::Half, tmpl_N_thread=2, tmpl_M_thread=4, accessor=at::TensorAccessor<c10::Half, 2UL, at::RestrictPtrTraits, signed int>]"
(1051): here
            instantiation of "void backward_kernel(PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<__nv_bool, 2>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, float, __nv_bool, __nv_bool, __nv_bool, __nv_bool) [with scalar_t=c10::Half]"
(1118): here

2 errors detected in the compilation of "flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu".
error: command '/usr/local/cuda/bin/nvcc' failed with exit code 1
make: *** [Makefile:3: install] Error 1

Environment

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0
$ python -c 'import torch;print(torch.__version__)'
1.13.0.dev20220918

GPU Benchmarks

Hi Phil,

Firstly, Thank you for the amazing work yet again!

I was wondering if you had done any benchmarking with mid-tier GPUs. I ran the benchmarks on my local system with a few RTX 3090s and received these results:

python3 benchmark.py --only-forwards

float32 batch: 4 heads: 8 dim 64

seq_len: 128 slower: 0.96x kernel: 0.23ms baseline: 0.24ms
seq_len: 256 slower: 1.32x kernel: 0.38ms baseline: 0.28ms
seq_len: 512 slower: 1.85x kernel: 0.82ms baseline: 0.44ms
seq_len: 1024 slower: 1.57x kernel: 2.15ms baseline: 1.37ms
seq_len: 2048 slower: 1.17x kernel: 5.94ms baseline: 5.06ms
seq_len: 4096 slower: 1.20x kernel: 22.70ms baseline: 18.84ms
seq_len: 8192 slower: 0.00x kernel: 90.47ms baseline: oom

float16 batch: 4 heads: 8 dim 64

seq_len: 128 slower: 0.72x kernel: 0.19ms baseline: 0.26ms
seq_len: 256 slower: 1.04x kernel: 0.24ms baseline: 0.23ms
seq_len: 512 slower: 1.04x kernel: 0.30ms baseline: 0.29ms
seq_len: 1024 slower: 1.00x kernel: 0.70ms baseline: 0.70ms
seq_len: 2048 slower: 0.71x kernel: 1.83ms baseline: 2.59ms
seq_len: 4096 slower: 0.67x kernel: 6.23ms baseline: 9.36ms
seq_len: 8192 slower: 0.65x kernel: 23.78ms baseline: 36.45ms

python3 benchmark.py --only-backwards

float32 batch: 4 heads: 8 dim 64

seq_len: 128 slower: 0.96x kernel: 0.55ms baseline: 0.57ms
seq_len: 256 slower: 1.76x kernel: 0.89ms baseline: 0.50ms
seq_len: 512 slower: 2.18x kernel: 2.09ms baseline: 0.96ms
seq_len: 1024 slower: 1.83x kernel: 5.16ms baseline: 2.82ms
seq_len: 2048 slower: 1.74x kernel: 17.56ms baseline: 10.12ms
seq_len: 4096 slower: 1.71x kernel: 64.56ms baseline: 37.74ms
seq_len: 8192 slower: 0.00x kernel: 250.87ms baseline: oom

float16 batch: 4 heads: 8 dim 64

seq_len: 128 slower: 0.92x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.60ms baseline: 0.58ms
seq_len: 512 slower: 1.54x kernel: 0.89ms baseline: 0.58ms
seq_len: 1024 slower: 1.34x kernel: 2.03ms baseline: 1.52ms
seq_len: 2048 slower: 1.20x kernel: 6.06ms baseline: 5.04ms
seq_len: 4096 slower: 1.25x kernel: 23.19ms baseline: 18.58ms
seq_len: 8192 slower: 1.22x kernel: 90.73ms baseline: 74.51ms

Is the speedup only seen on A100s?

I am going to train a small model on Wikitext-103 on an A100 cluster next and report the results.

Thank you,

Enrico

Import fails

Something seems to have changed overnight -- I'm getting an error when running import flash_cosine_sim_attention (running off 0.1.15 pip-installed):

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/.local/lib/python3.8/site-packages/flash_cosine_sim_attention/__init__.py", line 1, in <module>
    from flash_cosine_sim_attention.flash_cosine_sim_attention import flash_cosine_sim_attention, plain_cosine_sim_attention, l2norm_tensors
  File "/home/ubuntu/.local/lib/python3.8/site-packages/flash_cosine_sim_attention/flash_cosine_sim_attention.py", line 10, in <module>
    exec(open('flash_cosine_sim_attention/version.py').read())
FileNotFoundError: [Errno 2] No such file or directory: 'flash_cosine_sim_attention/version.py'

When the package is make install'd, everything works as expected (I'm assuming it's another bundling issue :^))

Edit: actually the make install version only works when running from the repository -- I get the same issue when running outside the repository

Cannot import debug

Got this error:

"ImportError: cannot import name 'debug' from 'flash_cosine_sim_attention.flash_cosine_sim_attention' (/home/administrator/.local/lib/python3.8/site-packages/flash_cosine_sim_attention/flash_cosine_sim_attention.py)"

Training Loss and Experiments

Hi @lucidrains,

Here are the results for training the GPT2 model on an A100 (40 GB). This is a different A100, one I have not used before. I left everything the same other than logging the loss. After around 65k steps there seems to be an exploding/vanishing gradient and the loss goes to NaN. Training became more unstable after the 20k step mark in my few runs.

Screenshot from 2022-10-30 14-00-45

I will have to test training on A100 (80 GB) as well.

Thank you,

Enrico

Pip Install Fails

Hey!

I attempted to install this package using pip, but ran into

flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu:10:10: fatal error: dispatch.h: No such file or directory
         10 | #include "dispatch.h"
            |          ^~~~~~~~~~~~
      compilation terminated.
      error: command '/usr/bin/nvcc' failed with exit status 1

Cloning the repo and running make install worked just fine, so I'm assuming setup.py (or whatever makes the pip package) just isn't including the dispatch.h header file.
