HyperAttention: Long-context Attention in Near-Linear Time

This repository is the PyTorch implementation of the HyperAttention paper:

HyperAttention: Long-context Attention in Near-Linear Time (https://arxiv.org/pdf/2310.05869.pdf)
Insu Han, Rajesh Jayaram, Amin Karbasi, Vahab Mirrokni, David P. Woodruff, Amir Zandieh ($\alpha$-$\beta$)

Requirements

The code requires pytorch and triton. We tested with pytorch version 2.0.1, but any version >= 2.0.0 should work. We also make use of the Triton implementation of FlashAttention. Please make sure to install triton version 2.0.0.dev20221202, as the Triton version of FlashAttention works on this version. (We will update the code to remove the dependency on this specific triton version.)

pip install triton==2.0.0.dev20221202 --no-deps

Benchmarks

The repository contains two benchmark experiments in the following files:

  1. benchmark_single_attention_layer.py: This script benchmarks (GPU) runtimes of HyperAttention and FlashAttention, exploring sequence lengths from 1k to 131k. To run:

    python benchmark_single_attention_layer.py --attn_method hyper 

    You can choose the computation mode among forward, backward, or both: add --mode fwd for forward, --mode bwd for backward, or --mode fwd+bwd for both. The default is fwd+bwd. Additionally, to simulate attention without causal masking, add --no_causal. A timing sketch illustrating the idea appears after this list.

  2. benchmark_patch_llm.py: This script benchmarks the perplexity of a pretrained language model whose self-attention is patched with HyperAttention. We choose the chatglm2-6b-32k model and the LongBench datasets. To run with sequence length 32768:

    python benchmark_patch_llm.py --attn_method hyper --seq_len 32768

    You can also run FlashAttention by specifying --attn_method flash, and try other sequence lengths (e.g., --seq_len 65536) as long as the VRAM allows. A sketch of the patching idea appears at the end of this section.
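
For intuition about the first benchmark, here is a minimal sketch of timing forward and backward passes of a single attention layer with CUDA events. This is not the benchmark script itself, and the (batch, heads, seq_len, head_dim) input layout is an assumption based on the usage example in the next section:

import torch
from models.attention.hyper_attn import HyperAttention

def time_fwd_bwd(attn, q, k, v, n_iters=10):
    # Average milliseconds per forward+backward iteration, via CUDA events.
    for _ in range(3):  # warm-up: compile and cache kernels before timing
        attn(q, k, v, causal=True).sum().backward()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(n_iters):
        attn(q, k, v, causal=True).sum().backward()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iters

attn = HyperAttention(input_dim=64).cuda()
q, k, v = (torch.randn(1, 32, 4096, 64, device='cuda', requires_grad=True)
           for _ in range(3))
print(f'{time_fwd_bwd(attn, q, k, v):.2f} ms / iteration')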

We ran all experiments on a single NVIDIA A100 with 40GB VRAM.
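
For intuition about the second benchmark, the patching boils down to loading the pretrained model, routing its attention computation through HyperAttention, and evaluating perplexity. The sketch below is illustrative only: the module path model.transformer.encoder.layers and the attribute core_attention are hypothetical placeholders, not chatglm2-6b-32k's actual internals (see benchmark_patch_llm.py for the real patching):

import torch
from models.attention.hyper_attn import HyperAttention

def patch_model(model, head_dim=128):
    # Illustrative monkey-patch: replace each layer's attention call.
    hyper = HyperAttention(input_dim=head_dim, min_seq_len=4096)

    def hyper_core_attention(q, k, v, *args, **kwargs):
        return hyper(q, k, v, causal=True)

    for layer in model.transformer.encoder.layers:  # hypothetical path
        layer.self_attention.core_attention = hyper_core_attention
    return model

@torch.no_grad()
def perplexity(model, input_ids):
    # HF-style causal LM: passing labels computes shifted cross-entropy.
    return torch.exp(model(input_ids, labels=input_ids).loss)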

How to use

The implementation of HyperAttention can be found in models/attention/hyper_attn.py. An example of usage:

from models.attention.hyper_attn import HyperAttention

attn = HyperAttention(
    input_dim=64,
    lsh_num_projs=7,
    block_size=256,
    sample_size=256,
    min_seq_len=4096)

attn_output = attn(query, key, value, causal=True)

The module has the following parameters:

  • input_dim: the dimension of input query and key. (Required)
  • lsh_num_projs: the number of dimensions in the hashing space. The default is 7.
  • block_size: the size of blocks for the block-diagonal approximation. The default is 256.
  • sample_size: the number of sampled columns in the attention matrix $A$. The default is 256.
  • min_seq_len: the minimum sequence length at which HyperAttention is applied. When the sequence length is smaller than this value, we compute attention exactly using FlashAttention, because the overhead of HyperAttention's additional operations may not be negligible. The default is 4096.
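
Putting it together, here is a self-contained sketch of calling the module on random inputs. The (batch, heads, seq_len, dim) layout is an assumption and should be checked against models/attention/hyper_attn.py:

import torch
from models.attention.hyper_attn import HyperAttention

attn = HyperAttention(
    input_dim=64,
    lsh_num_projs=7,
    block_size=256,
    sample_size=256,
    min_seq_len=4096).cuda()

batch, heads, seq_len, dim = 1, 32, 32768, 64
query, key, value = (torch.randn(batch, heads, seq_len, dim, device='cuda')
                     for _ in range(3))

attn_output = attn(query, key, value, causal=True)  # same shape as query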

How it works

The algorithm consists of (1) finding heavy entries in the attention matrix and (2) column subsampling. For (1), we use sorted locality-sensitive hashing (sortLSH) based on the Hamming distance. Sorting rows/columns by their sortLSH buckets moves the heavy entries of the attention matrix near the diagonal, so we can apply a block-diagonal approximation, which can be computed fast.
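
To make the sortLSH step concrete, here is a toy sketch (not the repository's implementation): rows are hashed by the sign pattern of shared random projections, and sorting by bucket id clusters similar queries and keys together:

import torch

def sort_lsh_order(x, projs):
    # Permutation that sorts rows of x (n, d) by their LSH bucket id.
    bits = (x @ projs > 0).long()          # (n, num_projs) sign bits
    powers = 2 ** torch.arange(projs.shape[-1])
    buckets = (bits * powers).sum(-1)      # Hamming-style hash code
    return torch.argsort(buckets)

d, num_projs = 64, 7
projs = torch.randn(d, num_projs)  # must be shared between queries and keys
q_perm = sort_lsh_order(torch.randn(1024, d), projs)
k_perm = sort_lsh_order(torch.randn(1024, d), projs)
# After permuting the rows of Q by q_perm and of K by k_perm, the heavy
# attention entries concentrate near the diagonal, so attention is computed
# only within block_size x block_size diagonal blocks.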

Causal masking

To support causal masking, we (implicitly) split the attention matrix into three parts: (1) upper-left, (2) lower-right, and (3) lower-left. Parts (1) and (2) still require causal masking because they lie near the diagonal, so we apply this process to them recursively. Submatrix (3) lies entirely below the diagonal, so we run HyperAttention on it without masking.
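
In code, the recursion can be made precise. The sketch below uses exact softmax attention everywhere, standing in for FlashAttention on small blocks and for unmasked HyperAttention on part (3), and shows how the two bottom parts are merged by softmax renormalization. It is a conceptual illustration, not the repository's implementation:

import torch

def exact_attn(q, k, v, causal):
    # Plain softmax attention that also returns the log-sum-exp (lse),
    # so partial results over disjoint key sets can be merged exactly.
    s = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if causal:
        n = s.shape[-1]
        mask = torch.triu(torch.ones(n, n, device=s.device), diagonal=1).bool()
        s = s.masked_fill(mask, float('-inf'))
    lse = torch.logsumexp(s, dim=-1, keepdim=True)
    return torch.exp(s - lse) @ v, lse

def causal_attn(q, k, v, min_len=256):
    n = q.shape[-2]
    if n <= min_len:
        return exact_attn(q, k, v, causal=True)  # stands in for FlashAttention
    m = n // 2
    # (1) upper-left: crosses the diagonal, recurse with masking.
    out_tl, lse_tl = causal_attn(q[..., :m, :], k[..., :m, :], v[..., :m, :], min_len)
    # (3) lower-left: entirely below the diagonal, no masking needed;
    # in the real algorithm this part is unmasked HyperAttention.
    out_bl, lse_bl = exact_attn(q[..., m:, :], k[..., :m, :], v[..., :m, :], causal=False)
    # (2) lower-right: crosses the diagonal, recurse with masking.
    out_br, lse_br = causal_attn(q[..., m:, :], k[..., m:, :], v[..., m:, :], min_len)
    # Merge (2) and (3) by renormalizing with the combined log-sum-exp.
    lse = torch.logaddexp(lse_bl, lse_br)
    out_bot = torch.exp(lse_bl - lse) * out_bl + torch.exp(lse_br - lse) * out_br
    return torch.cat([out_tl, out_bot], dim=-2), torch.cat([lse_tl, lse], dim=-2)

q, k, v = (torch.randn(2, 1024, 64) for _ in range(3))
out, _ = causal_attn(q, k, v)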

License

The code is licensed under the Apache 2.0 license.

Citation

@article{hyperattention,
  title={HyperAttention: Long-context Attention in Near-Linear Time},
  author={Han, Insu and Jayaram, Rajesh and Karbasi, Amin and Mirrokni, Vahab and Woodruff, David P. and Zandieh, Amir},
  journal={arXiv preprint arXiv:2310.05869},
  year={2023}
}
