lucidrains / flash-cosine-sim-attention

Implementation of fused cosine similarity attention in the same style as Flash Attention

License: MIT License

Languages: Python 32.56%, Cuda 63.01%, Makefile 0.67%, C++ 3.76%
Topics: artificial-intelligence, attention-mechanisms, deep-learning

flash-cosine-sim-attention's Issues

GPU Benchmarks

Hi Phil,

Firstly, thank you for the amazing work yet again!

I was wondering if you had done any benchmarking with mid-tier GPUs. I ran the benchmarks on my local system with a few RTX 3090s and received these results:

python3 benchmark.py --only-forwards

float32 batch: 4 heads: 8 dim 64

seq_len: 128 slower: 0.96x kernel: 0.23ms baseline: 0.24ms
seq_len: 256 slower: 1.32x kernel: 0.38ms baseline: 0.28ms
seq_len: 512 slower: 1.85x kernel: 0.82ms baseline: 0.44ms
seq_len: 1024 slower: 1.57x kernel: 2.15ms baseline: 1.37ms
seq_len: 2048 slower: 1.17x kernel: 5.94ms baseline: 5.06ms
seq_len: 4096 slower: 1.20x kernel: 22.70ms baseline: 18.84ms
seq_len: 8192 slower: 0.00x kernel: 90.47ms baseline: oom

float16 batch: 4 heads: 8 dim 64

seq_len: 128 slower: 0.72x kernel: 0.19ms baseline: 0.26ms
seq_len: 256 slower: 1.04x kernel: 0.24ms baseline: 0.23ms
seq_len: 512 slower: 1.04x kernel: 0.30ms baseline: 0.29ms
seq_len: 1024 slower: 1.00x kernel: 0.70ms baseline: 0.70ms
seq_len: 2048 slower: 0.71x kernel: 1.83ms baseline: 2.59ms
seq_len: 4096 slower: 0.67x kernel: 6.23ms baseline: 9.36ms
seq_len: 8192 slower: 0.65x kernel: 23.78ms baseline: 36.45ms

python3 benchmark.py --only-backwards

float32 batch: 4 heads: 8 dim 64

seq_len: 128 slower: 0.96x kernel: 0.55ms baseline: 0.57ms
seq_len: 256 slower: 1.76x kernel: 0.89ms baseline: 0.50ms
seq_len: 512 slower: 2.18x kernel: 2.09ms baseline: 0.96ms
seq_len: 1024 slower: 1.83x kernel: 5.16ms baseline: 2.82ms
seq_len: 2048 slower: 1.74x kernel: 17.56ms baseline: 10.12ms
seq_len: 4096 slower: 1.71x kernel: 64.56ms baseline: 37.74ms
seq_len: 8192 slower: 0.00x kernel: 250.87ms baseline: oom

float16 batch: 4 heads: 8 dim 64

seq_len: 128 slower: 0.92x kernel: 0.55ms baseline: 0.60ms
seq_len: 256 slower: 1.03x kernel: 0.60ms baseline: 0.58ms
seq_len: 512 slower: 1.54x kernel: 0.89ms baseline: 0.58ms
seq_len: 1024 slower: 1.34x kernel: 2.03ms baseline: 1.52ms
seq_len: 2048 slower: 1.20x kernel: 6.06ms baseline: 5.04ms
seq_len: 4096 slower: 1.25x kernel: 23.19ms baseline: 18.58ms
seq_len: 8192 slower: 1.22x kernel: 90.73ms baseline: 74.51ms
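
The figures above come from the repository's bundled benchmark.py. For anyone wanting a quick spot check on their own GPU, here is a minimal, hedged timing sketch; it assumes the documented flash_cosine_sim_attention(q, k, v) call with (batch, heads, seq, dim) tensors and is not the script's actual code:

import torch
from flash_cosine_sim_attention import flash_cosine_sim_attention

def time_fn(fn, *args, warmup = 3, iters = 10):
    # crude CUDA-event timer, returns milliseconds per call
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing = True)
    end = torch.cuda.Event(enable_timing = True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

batch, heads, seq_len, dim = 4, 8, 1024, 64
q, k, v = (torch.randn(batch, heads, seq_len, dim, device = 'cuda', dtype = torch.float16) for _ in range(3))

def baseline(q, k, v):
    # plain, non-fused attention as a rough reference (materializes the full attention matrix)
    sim = torch.einsum('b h i d, b h j d -> b h i j', q, k) * dim ** -0.5
    return torch.einsum('b h i j, b h j d -> b h i d', sim.softmax(dim = -1), v)

print('kernel   :', time_fn(flash_cosine_sim_attention, q, k, v), 'ms')
print('baseline :', time_fn(baseline, q, k, v), 'ms')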

Is the speedup only seen on A100s?

I am going to train a small model on Wikitext-103 on an A100 cluster next and report the results.

Thank you,

Enrico

Cannot import debug

Got this error:

"ImportError: cannot import name 'debug' from 'flash_cosine_sim_attention.flash_cosine_sim_attention' (/home/administrator/.local/lib/python3.8/site-packages/flash_cosine_sim_attention/flash_cosine_sim_attention.py)"

make install fails

$ make install
python setup.py install --user
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/setuptools/installer.py:27: SetuptoolsDeprecationWarning: setuptools.installer is deprecated. Requirements should be satisfied by a PEP 517 installer.
  warnings.warn(
running install
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/setuptools/command/easy_install.py:144: EasyInstallDeprecationWarning: easy_install command is deprecated. Use build and pip and other standards-based tools.
  warnings.warn(
running bdist_egg
running egg_info
writing flash_cosine_sim_attention.egg-info/PKG-INFO
writing dependency_links to flash_cosine_sim_attention.egg-info/dependency_links.txt
writing requirements to flash_cosine_sim_attention.egg-info/requires.txt
writing top-level names to flash_cosine_sim_attention.egg-info/top_level.txt
/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/utils/cpp_extension.py:472: UserWarning: Attempted to use ninja as the BuildExtension backend but we could not find ninja.. Falling back to using the slow distutils backend.
  warnings.warn(msg.format('we could not find ninja.'))
reading manifest file 'flash_cosine_sim_attention.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'flash_cosine_sim_attention.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
creating build
creating build/lib.linux-x86_64-cpython-39
creating build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/__init__.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/benchmark.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/transformer.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
copying flash_cosine_sim_attention/flash_cosine_sim_attention.py -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
creating build/lib.linux-x86_64-cpython-39/tests
copying tests/__init__.py -> build/lib.linux-x86_64-cpython-39/tests
copying tests/test.py -> build/lib.linux-x86_64-cpython-39/tests
copying flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu -> build/lib.linux-x86_64-cpython-39/flash_cosine_sim_attention
running build_ext
building 'flash_cosine_sim_attention_cuda' extension
creating build/temp.linux-x86_64-cpython-39
creating build/temp.linux-x86_64-cpython-39/flash_cosine_sim_attention
/usr/local/cuda/bin/nvcc -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include/torch/csrc/api/include -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include/TH -I/home/antor/anaconda3/envs/open/lib/python3.9/site-packages/torch/include/THC -I/usr/local/cuda/include -I/home/antor/anaconda3/envs/open/include/python3.9 -c flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu -o build/temp.linux-x86_64-cpython-39/flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options '-fPIC' -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -DTORCH_EXTENSION_NAME=flash_cosine_sim_attention_cuda -D_GLIBCXX_USE_CXX11_ABI=0 -gencode=arch=compute_86,code=compute_86 -gencode=arch=compute_86,code=sm_86 -std=c++14
flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu(405): error: no instance of overloaded function "atomicAdd" matches the argument list
            argument types are: (c10::Half *, float)
          detected during:
            instantiation of "void mma_warp_tile<scalar_t, tmpl_N_thread, tmpl_M_thread>::atomic_add(accessor, int, int, int, int) [with scalar_t=c10::Half, tmpl_N_thread=2, tmpl_M_thread=2, accessor=at::TensorAccessor<c10::Half, 2UL, at::RestrictPtrTraits, signed int>]"
(1016): here
            instantiation of "void backward_kernel(PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<__nv_bool, 2>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, float, __nv_bool, __nv_bool, __nv_bool, __nv_bool) [with scalar_t=c10::Half]"
(1118): here

flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu(405): error: no instance of overloaded function "atomicAdd" matches the argument list
            argument types are: (c10::Half *, float)
          detected during:
            instantiation of "void mma_warp_tile<scalar_t, tmpl_N_thread, tmpl_M_thread>::atomic_add(accessor, int, int, int, int) [with scalar_t=c10::Half, tmpl_N_thread=2, tmpl_M_thread=4, accessor=at::TensorAccessor<c10::Half, 2UL, at::RestrictPtrTraits, signed int>]"
(1051): here
            instantiation of "void backward_kernel(PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<__nv_bool, 2>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 3>, PackedAccessor<scalar_t, 4>, PackedAccessor<scalar_t, 4>, float, __nv_bool, __nv_bool, __nv_bool, __nv_bool) [with scalar_t=c10::Half]"
(1118): here

2 errors detected in the compilation of "flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu".
error: command '/usr/local/cuda/bin/nvcc' failed with exit code 1
make: *** [Makefile:3: install] Error 1

Environment

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Jun__8_16:49:14_PDT_2022
Cuda compilation tools, release 11.7, V11.7.99
Build cuda_11.7.r11.7/compiler.31442593_0
$ python -c 'import torch;print(torch.__version__)'
1.13.0.dev20220918
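
Not part of the original report, but a small diagnostic that can help rule out a toolchain mismatch between the system nvcc above and the CUDA version PyTorch was built with (a common, though not the only, cause of extension build failures; every attribute used here is a standard torch API):

import torch
from torch.utils import cpp_extension

print('torch            :', torch.__version__)
print('torch built CUDA :', torch.version.cuda)                  # should line up with nvcc --version
print('CUDA_HOME        :', cpp_extension.CUDA_HOME)             # toolkit the extension build will use
print('ninja available  :', cpp_extension.is_ninja_available())  # the log above fell back to distutils
if torch.cuda.is_available():
    print('device           :', torch.cuda.get_device_name(0))
    print('compute cap      :', torch.cuda.get_device_capability(0))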

failed building wheel for flash-cosine-sim-attention

Hi, the package is currently failing to build, both locally (100 errors detected in the compilation of "flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu") and when building the wheel.

Wheel build error logs:

Collecting flash-cosine-sim-attention
  Downloading flash-cosine-sim-attention-0.1.40.tar.gz (25 kB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: torch>=1.10 in /usr/local/lib/python3.10/dist-packages (from flash-cosine-sim-attention) (2.0.1+cu118)
Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (3.12.2)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (4.6.3)
Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (1.11.1)
Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (3.1)
Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (3.1.2)
Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10->flash-cosine-sim-attention) (2.0.0)
Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->flash-cosine-sim-attention) (3.25.2)
Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.10->flash-cosine-sim-attention) (16.0.6)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10->flash-cosine-sim-attention) (2.1.3)
Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10->flash-cosine-sim-attention) (1.3.0)
Building wheels for collected packages: flash-cosine-sim-attention
  error: subprocess-exited-with-error
  
  × python setup.py bdist_wheel did not run successfully.
  │ exit code: 1
  ╰─> See above for output.
  
  note: This error originates from a subprocess, and is likely not a problem with pip.
  Building wheel for flash-cosine-sim-attention (setup.py) ... error
  ERROR: Failed building wheel for flash-cosine-sim-attention
  Running setup.py clean for flash-cosine-sim-attention
Failed to build flash-cosine-sim-attention

Training Loss and Experiments

Hi @lucidrains,

Here are the results for training the GPT2 model on an A100 (40 GB). This is a different A100 that I have not used before. I left everything the same other than logging the loss. After around 65k steps there seems to be an exploding/vanishing gradient and the loss goes to NaN. In my few runs, training became noticeably more unstable past the 20k-step mark.

[Screenshot from 2022-10-30 14-00-45]

I will have to test training on A100 (80 GB) as well.

Thank you,

Enrico

Support head dimension 16?

Hello and thanks for your work.
Can flash-cosine-sim-attention support a head dimension of 16? 32 is too big for my model, so I wonder whether you have any plans to support head dimension 16, as flash-attention did.
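
Not something the thread answers, but one hedged workaround, assuming the (batch, heads, seq, dim) layout the package uses: zero-padding the head dimension from 16 up to a supported 32 leaves the (l2-normalized) q and k dot products unchanged and only adds all-zero output columns, which can be sliced off afterwards. This only works around the API limit; compute and memory still behave as if the head dimension were 32.

import torch
import torch.nn.functional as F
from flash_cosine_sim_attention import flash_cosine_sim_attention

def attend_head_dim_16(q, k, v, supported_dim = 32):
    # hypothetical helper, not part of the package; q, k, v are (batch, heads, seq, 16)
    pad = supported_dim - q.shape[-1]
    q, k, v = (F.pad(t, (0, pad)) for t in (q, k, v))  # zero-pad the last (head) dimension
    out = flash_cosine_sim_attention(q, k, v)          # runs at the supported head dimension
    return out[..., :-pad]                             # the padded output columns are exactly zero

q, k, v = (torch.randn(4, 8, 1024, 16, device = 'cuda', dtype = torch.float16) for _ in range(3))
out = attend_head_dim_16(q, k, v)  # shape (4, 8, 1024, 16)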

Pip Install Fails

Hey!

I attempted to install this package using pip, but ran into:

flash_cosine_sim_attention/flash_cosine_sim_attention_cuda.cu:10:10: fatal error: dispatch.h: No such file or directory
         10 | #include "dispatch.h"
            |          ^~~~~~~~~~~~
      compilation terminated.
      error: command '/usr/bin/nvcc' failed with exit status 1

Cloning the repo and running make install worked just fine, so I'm assuming setup.py (or whatever makes the pip package) just isn't including the dispatch.h header file.
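
Not verified against the project's actual setup.py, but if the header is simply missing from the published sdist, the standard fix is to declare it in package_data and/or the MANIFEST.in the repo already has (it shows up in the make install log above). A hedged sketch of the relevant setup.py fragment, with file patterns assumed from the error message:

# sketch of the packaging-related fragment of setup.py; the CUDA extension definition is omitted
from setuptools import setup, find_packages

setup(
    name = 'flash-cosine-sim-attention',
    packages = find_packages(),
    include_package_data = True,
    package_data = {
        # ship the CUDA source and its headers (dispatch.h) alongside the Python code
        'flash_cosine_sim_attention': ['*.cu', '*.h'],
    },
)

For the sdist that pip downloads, an "include flash_cosine_sim_attention/*.h" line in MANIFEST.in would cover dispatch.h as well.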

Import fails

Something seems to have changed overnight -- I'm getting an error when running import flash_cosine_sim_attention (on 0.1.15, installed via pip):

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/.local/lib/python3.8/site-packages/flash_cosine_sim_attention/__init__.py", line 1, in <module>
    from flash_cosine_sim_attention.flash_cosine_sim_attention import flash_cosine_sim_attention, plain_cosine_sim_attention, l2norm_tensors
  File "/home/ubuntu/.local/lib/python3.8/site-packages/flash_cosine_sim_attention/flash_cosine_sim_attention.py", line 10, in <module>
    exec(open('flash_cosine_sim_attention/version.py').read())
FileNotFoundError: [Errno 2] No such file or directory: 'flash_cosine_sim_attention/version.py'

When the package is make install'd, everything works as expected (I'm assuming it's another bundling issue :^))

Edit: actually, the make install version only works when running from inside the repository -- I get the same issue when running outside of it
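
The traceback shows version.py being opened with a path relative to the current working directory, which only resolves when Python is launched from the repository root. Not the project's actual fix, but the usual pattern is to resolve the path against the module file itself; a minimal sketch of what the top of flash_cosine_sim_attention.py could do instead:

# cwd-independent way to load version.py from inside the package module
import os

module_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(module_dir, 'version.py')) as f:
    exec(f.read())  # defines __version__ no matter where the interpreter was launched from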
