flash-pytorch's Introduction

FLASH - Pytorch

Implementation of the Transformer variant proposed in the paper Transformer Quality in Linear Time

Install

$ pip install FLASH-pytorch

Usage

The main novel circuit in this paper is the "Gated Attention Unit", which they claim can replace multi-headed attention while reducing it to just one head.

It uses a ReLU squared activation in place of the softmax. This activation was first seen in the Primer paper, while the use of ReLU as the attention function appeared in the ReLA Transformer. The gating style seems mostly inspired by gMLPs.

import torch
from flash_pytorch import GAU

gau = GAU(
    dim = 512,
    query_key_dim = 128,     # query / key dimension
    causal = True,           # autoregressive or not
    expansion_factor = 2,    # hidden dimension = dim * expansion_factor
    laplace_attn_fn = True   # new Mega paper claims this is more stable than relu squared as attention function
)

x = torch.randn(1, 1024, 512)
out = gau(x) # (1, 1024, 512)
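
Under the hood, the attention weights in GAU are computed with ReLU squared rather than softmax. Here is a minimal sketch of that step (an illustration, not the repo's exact code):

import torch
import torch.nn.functional as F

# toy single-head attention scores with ReLU squared in place of softmax
q, k = torch.randn(1, 1024, 128), torch.randn(1, 1024, 128)
seq_len = q.shape[-2]

sim = torch.einsum('b i d, b j d -> b i j', q, k)
attn = F.relu(sim / seq_len) ** 2   # relu squared activation, with the similarity scaled by sequence length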

The authors then combine GAU with Katharopoulos linear attention, using grouping of the sequences to overcome a known issue with autoregressive linear attention.

They named this combination of the quadratic gated attention unit and grouped linear attention FLASH.

You can use it just as easily:

import torch
from flash_pytorch import FLASH

flash = FLASH(
    dim = 512,
    group_size = 256,             # group size
    causal = True,                # autoregressive or not
    query_key_dim = 128,          # query / key dimension
    expansion_factor = 2.,        # hidden dimension = dim * expansion_factor
    laplace_attn_fn = True        # new Mega paper claims this is more stable than relu squared as attention function
)

x = torch.randn(1, 1111, 512)     # sequence will be auto-padded to nearest group size
out = flash(x) # (1, 1111, 512)
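
As a side note on the comment above, the padding just rounds the sequence length up to the next multiple of group_size before grouping, and the output is trimmed back afterwards. A hypothetical calculation for the numbers used here:

import math

seq_len, group_size = 1111, 256
padded_len = math.ceil(seq_len / group_size) * group_size   # 1280; the output above is returned at length 1111
print(padded_len)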

Finally, you can use the full FLASH transformer as described in the paper. This contains all the positional embeddings proposed there. Absolute positional embedding uses scaled sinusoidal. The GAU quadratic attention gets a one-headed T5 relative positional bias. On top of all this, both the GAU attention and the linear attention are rotary embedded (RoPE).

import torch
from flash_pytorch import FLASHTransformer

model = FLASHTransformer(
    num_tokens = 20000,          # number of tokens
    dim = 512,                   # model dimension
    depth = 12,                  # depth
    causal = True,               # autoregressive or not
    group_size = 256,            # size of the groups
    query_key_dim = 128,         # dimension of queries / keys
    expansion_factor = 2.,       # hidden dimension = dim * expansion_factor
    norm_type = 'scalenorm',     # in the paper, they claimed scalenorm led to faster training at no performance hit. the other option is 'layernorm' (also default)
    shift_tokens = True          # discovered by an independent researcher in Shenzhen @BlinkDL, this simply shifts half of the feature space forward one step along the sequence dimension - greatly improved convergence even more in my local experiments
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x) # (1, 1024, 20000)
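
As a hypothetical example (not taken from train.py), the logits above could be used for an autoregressive language-modeling step with a standard next-token cross-entropy loss:

import torch
import torch.nn.functional as F

# assumes `model` and `x` from the snippet above
logits = model(x)                                      # (1, 1024, 20000)
loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.shape[-1]),      # predictions for positions 0 .. n-2
    x[:, 1:].reshape(-1)                               # targets are the tokens shifted left by one
)
loss.backward()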

Test on Autoregressive Enwik8

$ python train.py

Citations

@article{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.10447}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
@inproceedings{Ma2022MegaMA,
    title   = {Mega: Moving Average Equipped Gated Attention},
    author  = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
    year    = {2022}
}


flash-pytorch's Issues

AttributeError: module 'torch' has no attribute 'special'

I have torch==1.7.0a0 installed, but when I ran out = gau(x), I got the following error:

Traceback (most recent call last):
  File "", line 1, in
  File "MYPATH/work/venv/pytorch-rocm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "MYPATH/work/flash-pytorch/git-FLASH-pytorch/flash_pytorch/flash_pytorch.py", line 201, in forward
    attn = self.attn_fn(sim / seq_len)
  File "MYPATH/work/venv/pytorch-rocm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "MYPATH/work/flash-pytorch/git-FLASH-pytorch/flash_pytorch/flash_pytorch.py", line 131, in forward
    return (1 + torch.special.erf((x - mu) / (std * math.sqrt(2)))) * 0.5
AttributeError: module 'torch' has no attribute 'special'

It seems that torch.special was introduced in PyTorch 1.9.

Laplace Activation Function Implementation

It seems like the Laplace activation implementation deviates from what the paper describes:
I think it should be std = 1 / math.sqrt(4 * math.pi) rather than std = math.sqrt(0.25 * math.pi), since the former is the value that approximates relu^2.
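
For reference, a minimal sketch of the Laplace attention function being discussed, with both values noted in the comments (which one is correct is exactly what this issue raises):

import math
import torch

def laplace_attn_fn(x, mu = math.sqrt(0.5), std = math.sqrt(0.25 * math.pi)):
    # std = math.sqrt(0.25 * math.pi) ~ 0.886 is the value currently in the repo
    # std = 1 / math.sqrt(4 * math.pi) ~ 0.282 is the value proposed in this issue
    return 0.5 * (1 + torch.erf((x - mu) / (std * math.sqrt(2))))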

Cross-Attention?

Hi, @lucidrains. Thank you for sharing this excellent implementation with us all!
Do you have any thoughts as to what changes would need to be made to make cross-attention possible with your FLASH model?

About negative values in my input sentence embeddings

Does FLASHTransformer accept negative values in sentence embeddings?
I passed BERT sentence embeddings into the FLASHTransformer and got the following error:

IndexError: index out of range in self

I then tried the README method:
x = torch.randint(0, 20000, (1, 1024)) <-- works fine
x = torch.randint(-1, 1, (1, 1024)) <-- same error

Thx

mask error

x = torch.randint(0, 20000, (1, 1024))
mask = x.ne(0)
logits = model(x, mask=mask)

RuntimeError: The size of tensor a (1024) must match the size of tensor b (128) at non-singleton dimension 2

rel_pos_bias in GAU

Hello @lucidrains, thanks for generously sharing this implementation. According to Figure 2 of the original paper, there is a rel_pos_bias(q, k) term used to obtain the final attention weights. I can find this function in your FLASH, but the operation is missing in GAU. Could you please explain this, or is the operation unnecessary in GAU?

Thanks!

The speed.

Thanks for your excellent work.
However, GAU is slower than the original MHSA in my implementation, 3.5s vs 0.7s, even though I simply use "from flash_pytorch import GAU" with the default settings.
Is there something wrong with my implementation?

About the "shift_tokens"

Thank you for your amazing code.

In the FLASH class, I found a flag, shift_tokens, with the corresponding code as follows:

if self.shift_tokens:
    x_shift, x_pass = normed_x.chunk(2, dim = -1)
    x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
    normed_x = torch.cat((x_shift, x_pass), dim = -1)

Assume normed_x has shape [1024, 512]; then x_shift and x_pass each have shape [1024, 256]. The pad adds a row of zeros at the front of x_shift and removes its last row, and then x_shift and x_pass are concatenated to give the new normed_x.
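
Concretely, a toy example (illustrative values, not from the repo) of what that padding does along the sequence dimension:

import torch
import torch.nn.functional as F

x_shift = torch.arange(1., 5.).view(1, 4, 1)          # toy (batch, seq, feature) tensor: 1, 2, 3, 4
shifted = F.pad(x_shift, (0, 0, 1, -1), value = 0.)   # pad one zero step at the front, drop the last step
print(shifted.squeeze(-1))                            # tensor([[0., 1., 2., 3.]])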

In my opinion, the F.pad operation means the rows of x_shift and x_pass no longer line up with each other.

May I know why it works?

Kang

Speed on TPU

Hi,
Thanks for the code!
I tested it on a Google TPU v3, and the training speed seems slower than I expected. Maybe there is some operation that is not lowered efficiently on TPU.

About the "/n"

Hi @lucidrains, thanks for your excellent work.
However, I have a small question: why does there need to be a "/ n" here? It does not seem to appear in the paper.

line 374: lin_kv = einsum(f'b g n d, b g n e -> {context_einsum_eq}', lin_k, v) / n

Looking forward to your reply.

einsum operation in Linear Attention Part

Hi,
Thanks a lot for your FLASH_pytorch, which helps a lot.
I found that there are some differences from the paper in the Linear Attention Part:
https://github.com/lucidrains/FLASH-pytorch/blob/main/flash_pytorch/flash_pytorch.py#L342-L343

lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
lin_out = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

Here lin_kv is three-dimensional (b d e), whereas the code in the paper is

lin_kv = tf.einsum('bhke,bgh->bgke', lin_kv, mask)
linear = tf.einsum('bgnk,bgke->bgne', lin_q, lin_kv)

where lin_kv is four-dimensional (b g k e).
It seems that the two ways are not equivalent.

Looking forward to your reply.
Best,
