lucidrains / flash-cosine-sim-attention
Implementation of fused cosine similarity attention in the same style as Flash Attention
License: MIT License
Hi Phil,
First, thank you for the amazing work yet again!
I was wondering if you had done any benchmarking with mid-tier GPUs. I ran the benchmarks on my local system with a few RTX 3090s and received these results:
python3 benchmark.py --only-forwards
seq_len: 128 slower: 0.96x kernel: 0.23ms baseline: 0.24ms
seq_len: 256 slower: 1.32x kernel: 0.38ms baseline: 0.28ms
seq_len: 512 slower: 1.85x kernel: 0.82ms baseline: 0.44ms
seq_len: 1024 slower: 1.57x kernel: 2.15ms baseline: 1.37ms
seq_len: 2048 slower: 1.17x kernel: 5.94ms baseline: 5.06ms
seq_len: 4096 slower: 1.20x kernel: 22.70ms baseline: 18.84ms
seq_len: 8192 slower: 0.00x kernel: 90.47ms baseline: oom
seq_len: 128 slower: 0.72x kernel: 0.19ms baseline: 0.26ms
seq_len: 256 slower: 1.04x kernel: 0.24ms baseline: 0.23ms
seq_len: 512 slower: 1.04x kernel: 0.30ms baseline: 0.29ms
seq_len: 1024 slower: 1.00x kernel: 0.70ms baseline: 0.70ms
seq_len: 2048 slower: 0.71x kernel: 1.83ms baseline: 2.59ms
seq_len: 4096 slower: 0.67x kernel: 6.23ms baseline: 9.36ms
seq_len: 8192 slower: 0.65x kernel: 23.78ms baseline: 36.45ms
python3 benchmark.py --only-backwards
seq_len: 128 slower: 0.96x kernel: 0.55ms baseline: 0.57ms
seq_len: 256 slower: 1.76x kernel: 0.89ms baseline: 0.50ms
seq_len: 512 slower: 2.18x kernel: 2.09ms baseline: 0.96ms
seq_len: 1024 slower: 1.83x kernel: 5.16ms baseline: 2.82ms
seq_len: 2048 slower: 1.74x kernel: 17.56ms baseline: 10.12ms
seq_len: 4096 slower: 1.71x kernel: 64.56ms baseline: 37.74ms
seq_len: 8192 slower: 0.00x kernel: 250.87ms baseline: oom
seq_len: 128 slower: 0.92x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.60ms baseline: 0.58ms
seq_len: 512 slower: 1.54x kernel: 0.89ms baseline: 0.58ms
seq_len: 1024 slower: 1.34x kernel: 2.03ms baseline: 1.52ms
seq_len: 2048 slower: 1.20x kernel: 6.06ms baseline: 5.04ms
seq_len: 4096 slower: 1.25x kernel: 23.19ms baseline: 18.58ms
seq_len: 8192 slower: 1.22x kernel: 90.73ms baseline: 74.51ms
Is the speedup only seen on A100s?
I am going to train a small model on Wikitext-103 on an A100 cluster next and report the results.
Thank you,
Enrico
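For context on what these benchmarks compare, here is a minimal non-fused reference of cosine-similarity attention, written in numpy as a sketch: q and k are l2-normalized, their similarities scaled by a fixed factor, then softmaxed. The scale value of 10 is an assumption for illustration, not necessarily the library's default.

```python
import numpy as np

def l2norm(x, eps=1e-12):
    # normalize along the feature dimension
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def cosine_sim_attention(q, k, v, scale=10):
    q, k = l2norm(q), l2norm(k)
    sim = scale * (q @ k.swapaxes(-1, -2))       # cosine similarities in [-scale, scale]
    sim = sim - sim.max(axis=-1, keepdims=True)  # numerically stable softmax
    attn = np.exp(sim)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    return attn @ v

q = np.random.randn(2, 8, 64)  # (heads, seq, dim)
k = np.random.randn(2, 8, 64)
v = np.random.randn(2, 8, 64)
out = cosine_sim_attention(q, k, v)
print(out.shape)  # (2, 8, 64)
```

Because the similarity values are bounded, the softmax needs no running-max bookkeeping across tiles, which is the property the fused kernel exploits.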
Got this error:
"ImportError: cannot import name 'debug' from 'flash_cosine_sim_attention.flash_cosine_sim_attention' (/home/administrator/.local/lib/python3.8/site-packages/flash_cosine_sim_attention/flash_cosine_sim_attention.py)"
$ make install
python setup.py install --user
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/setuptools/installer.py:27: SetuptoolsDeprecationWarning: setuptools.installer is deprecated. Requirements should be satisfied by a PEP 517 installer.
warnings.warn(
running install
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
warnings.warn(
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
warnings.warn(
running bdist_egg
running egg_info
writing flash_cosine_sim_attention.egg-info/PKG-INFO
writing dependency_links to flash_cosine_sim_attention.egg-info/dependency_links.txt
writing requirements to flash_cosine_sim_attention.egg-info/requires.txt
writing top-level names to flash_cosine_sim_attention.egg-info/top_level.txt
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/utils/cpp_extension.py:472: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
warnings.warn(msg.format('we could not find ninja.'))
reading manifest file 'flash_cosine_sim_attention.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'flash_cosine_sim_attention.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib.linux-x86_64-cpython-39
creating build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/__init__.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/benchmark.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/transformer.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/flash_cosine_sim_attention.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
creating build/lib.linux-x86_64-cpython-39/tests
copying tests/__init__.py -> build/lib.linux-x86_64-cpython-39/tests
copying tests/test.py -> build/lib.linux-x86_64-cpython-39/tests
copying flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
running build_ext
building 'flash_cosine_sim_attention_cuda' extension
creating build/temp.linux-x86_64-cpython-39
creating build/temp.linux-x86_64-cpython-39/flash_cosine_sim_attention
/usr/local/cuda/bin/nvcc -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include/TH -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/antor/anaconda3/envs/open/include/python3.9 -c flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu -o build/temp.linux-x86_64-cpython-39/flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=flash_cosine_sim_attention_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++14
flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu(405): error: no instance of overloaded function "atomicAdd" matches the argument list
argument types are: (c10::Half *, float)
detected during:
instantiation of "void mma_warp_tile<scalar_t, tmpl_N_thread, tmpl_M_thread>::atomic_add(accessor, int, int, int, int) [with scalar_t=c10::Half, tmpl_N_thread=2, tmpl_M_thread=2, accessor=at::TensorAccessor<c10::Half, 2UL, at::RestrictPtrTraits, signed int>]"
(1016): here
instantiation of "void backward_kernel(PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<__nv_bool, 2>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, float, __nv_bool, __nv_bool, __nv_bool, __nv_bool) [with scalar_t=c10::Half]"
(1118): here
flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu(405): error: no instance of overloaded function "atomicAdd" matches the argument list
argument types are: (c10::Half *, float)
detected during:
instantiation of "void mma_warp_tile<scalar_t, tmpl_N_thread, tmpl_M_thread>::atomic_add(accessor, int, int, int, int) [with scalar_t=c10::Half, tmpl_N_thread=2, tmpl_M_thread=4, accessor=at::TensorAccessor<c10::Half, 2UL, at::RestrictPtrTraits, signed int>]"
(1051): here
instantiation of "void backward_kernel(PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<__nv_bool, 2>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, float, __nv_bool, __nv_bool, __nv_bool, __nv_bool) [with scalar_t=c10::Half]"
(1118): here
2 errors detected in the compilation of "flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu".
error: command '/usr/local/cuda/bin/nvcc' failed with exit code 1
make: *** [Makefile:3: install] Error 1
Environment
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0
$ python -c 'import torch;print(torch.__version__)'
1.13.0.dev20220918
Hi, the package is currently failing to build both locally (100 errors detected in the compilation of "flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu") and as a wheel.
Wheel build error logs:
Collecting flash-cosine-sim-attention
Downloading flash-cosine-sim-attention-0.1.40.tar.gz (25 kB)
Preparing metadata (setup.py) ... done
Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.10/dist-packages (from flash-cosine-sim-attention) (2.0.1+cu118)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (3.12.2)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (4.6.3)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (3.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (3.1.2)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->flash-cosine-sim-attention) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->flash-cosine-sim-attention) (16.0.6)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10->flash-cosine-sim-attention) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10->flash-cosine-sim-attention) (1.3.0)
Building wheels for collected packages: flash-cosine-sim-attention
error: subprocess-exited-with-error
× python setup.py bdist_wheel did not run successfully.
│ exit code: 1
╰─> See above for output.
note: This error originates from a subprocess, and is likely not a problem with pip.
Building wheel for flash-cosine-sim-attention (setup.py) ... error
ERROR: Failed building wheel for flash-cosine-sim-attention
Running setup.py clean for flash-cosine-sim-attention
Failed to build flash-cosine-sim-attention
Hey! Great work as usual! :)
Correct me if I'm wrong, but it seems the CUDA version is still several times slower than the plain implementation?
Thanks for your work! I tested flash-cosine-sim-attention by running benchmark.py
on my 2080, and it shows the flash kernel is slower than the plain implementation, especially for large seq_len. Why does this result conflict with the paper?
Hi @lucidrains,
Here are the results for training the GPT2 model on an A100 (40 GB). This is a different A100 I have not used before. I left everything the same other than logging the loss. After around 65k steps there seems to be an exploding/vanishing gradient and the loss goes to NaN. From my few runs, training became more unstable after the 20k step mark.
I will have to test training on A100 (80 GB) as well.
Thank you,
Enrico
Hello and thanks for your work.
Can flash-cosine-sim-attention support head dimension 16? 32 is too big for my model, so I wonder whether you have any plans to support head dimension 16 as flash-attention did?
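Until a head dimension of 16 is supported directly, one mathematically equivalent workaround (a sketch, not a feature of this library) is to zero-pad the heads from 16 to 32: padding q and k with zeros changes neither the dot products nor the l2 norms, and padding v only adds output columns that can be sliced away. The numpy reference below demonstrates the equivalence; the scale of 10 is an assumption for illustration.

```python
import numpy as np

def l2norm(x, eps=1e-12):
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def cosine_attn(q, k, v, scale=10):
    # plain cosine-similarity attention reference
    q, k = l2norm(q), l2norm(k)
    sim = scale * (q @ k.swapaxes(-1, -2))
    sim = sim - sim.max(axis=-1, keepdims=True)
    attn = np.exp(sim)
    attn = attn / attn.sum(axis=-1, keepdims=True)
    return attn @ v

q = np.random.randn(4, 16)
k = np.random.randn(4, 16)
v = np.random.randn(4, 16)

# zero-pad the head dimension 16 -> 32, run attention, slice back
pad = lambda t: np.concatenate([t, np.zeros_like(t)], axis=-1)
out_padded = cosine_attn(pad(q), pad(k), pad(v))[..., :16]
out_direct = cosine_attn(q, k, v)
print(np.allclose(out_padded, out_direct))  # True
```

The cost is the wasted compute on the zero half of each head, but the results match the unpadded computation.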
Hey!
I attempted to install this package using pip, but ran into
flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu:10:10: fatal error: dispatch.h: No such file or directory
10 | #include "dispatch.h"
| ^~~~~~~~~~~~
compilation terminated.
error: command '/usr/bin/nvcc' failed with exit status 1
Cloning the repo and running make install worked just fine, so I'm assuming setup.py (or whatever makes the pip package) just isn't including the dispatch.h header file.
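For reference, missing headers in a pip source distribution are usually fixed by listing them in MANIFEST.in (or in package_data in setup.py). The fragment below is a sketch of that convention, an assumption about the likely fix rather than the repository's actual files:

```
# MANIFEST.in (sketch)
include flash_cosine_sim_attention/dispatch.h
recursive-include flash_cosine_sim_attention *.cu *.h
```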
Something seems to have changed overnight -- I'm getting an error when running import flash_cosine_sim_attention (running off 0.1.15, pip-installed):
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/ubuntu/.local/lib/python3.8/site-packages/flash_cosine_sim_attention/__init__.py", line 1, in <module>
from flash_cosine_sim_attention.flash_cosine_sim_attention import flash_cosine_sim_attention, plain_cosine_sim_attention, l2norm_tensors
File "/home/ubuntu/.local/lib/python3.8/site-packages/flash_cosine_sim_attention/flash_cosine_sim_attention.py", line 10, in <module>
exec(open('flash_cosine_sim_attention/version.py').read())
FileNotFoundError: [Errno 2] No such file or directory: 'flash_cosine_sim_attention/version.py'
When the package is make install'd, everything works as expected (I'm assuming it's another bundling issue :^))
Edit: actually the make install version only works when running from the repository -- I get the same issue when running outside the repository
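The traceback suggests the module reads version.py via a path relative to the current working directory (exec(open('flash_cosine_sim_attention/version.py').read())), which only resolves when run from the repo root. A common fix, sketched below with a hypothetical load_version helper, is to resolve the path relative to the module file instead:

```python
import os
import tempfile

def load_version(module_dir):
    # resolve version.py relative to the package directory,
    # not the process working directory
    scope = {}
    with open(os.path.join(module_dir, 'version.py')) as f:
        exec(f.read(), scope)
    return scope['__version__']

# demo with a throwaway directory standing in for the installed package;
# in the real module, module_dir would be os.path.dirname(os.path.abspath(__file__))
with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, 'version.py'), 'w') as f:
        f.write("__version__ = '0.1.15'\n")
    print(load_version(d))  # prints 0.1.15
```

This works no matter where the importing process was launched from, which matches the symptom that the package only imports cleanly from inside the repository.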