lucidrains / recurrent-memory-transformer-pytorch

Implementation of Recurrent Memory Transformer, NeurIPS 2022 paper, in PyTorch

License: MIT License

artificial-intelligence attention-mechanisms deep-learning transformers long-context memory recurrence

recurrent-memory-transformer-pytorch's Introduction

Recurrent Memory Transformer - Pytorch

Implementation of Recurrent Memory Transformer (openreview) in PyTorch. A short follow-up paper recently demonstrated that it can copy information across at least 1 million tokens.

There is no doubt in my mind that RMT would make a stronger RL agent than AdA, which is just a Transformer-XL - Update: Recurrent Memory Decision Transformer

Yannic Kilcher paper review

Appreciation

  • Stability and 🤗 Huggingface for their generous sponsorships to work on and open source cutting edge artificial intelligence research

Install

$ pip install recurrent-memory-transformer-pytorch

Usage

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 20000,               # number of tokens
    num_memory_tokens = 128,          # number of memory tokens; this determines the bottleneck for information passed forward to future segments
    dim = 512,                        # model dimensions
    depth = 6,                        # transformer depth
    causal = True,                    # autoregressive or not
    dim_head = 64,                    # dimension per head
    heads = 8,                        # heads
    seq_len = 1024,                   # sequence length of a segment
    use_flash_attn = True             # whether to use flash attention
)

x = torch.randint(0, 256, (1, 1024))

logits1, mem1, _ = model(x)        # (1, 1024, 20000), (1, 128, 512), None
logits2, mem2, _ = model(x, mem1)  # (1, 1024, 20000), (1, 128, 512), None
logits3, mem3, _ = model(x, mem2)  # (1, 1024, 20000), (1, 128, 512), None

# and so on ...
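
In practice, the segments would come from chunking one long token sequence and carrying the memories forward. A minimal sketch, reusing the model defined above and a hypothetical 4096-token document:

long_seq = torch.randint(0, 20000, (1, 4096))        # hypothetical long document of token ids

memories = None

for segment in long_seq.split(1024, dim = -1):        # split into seq_len-sized segments
    logits, memories, _ = model(segment, memories)    # memories from each segment are fed to the next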

With XL memories

import torch
from recurrent_memory_transformer_pytorch import RecurrentMemoryTransformer

model = RecurrentMemoryTransformer(
    num_tokens = 20000,
    num_memory_tokens = 128,
    dim = 512,
    depth = 6,
    causal = True,
    dim_head = 64,
    heads = 8,
    seq_len = 1024,
    use_flash_attn = True,
    use_xl_memories = True,    # set this to True
    xl_mem_len = 512           # can be shorter than the seq len - I think having just a bit of the past will prevent the RMT memories from memorizing the immediately preceding text
)

x = torch.randint(0, 256, (1, 1024))

logits1, mem1, xl_mem1 = model(x)                               # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)]
logits2, mem2, xl_mem2 = model(x, mem1, xl_memories = xl_mem1)  # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)]
logits3, mem3, xl_mem3 = model(x, mem2, xl_memories = xl_mem2)  # (1, 1024, 20000), (1, 128, 512), [(2, 1, 512, 512)]

# and so on ...
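
The same chunked loop works with the XL cache threaded through alongside the recurrent memories. A sketch, reusing the model defined just above; passing None for the caches on the first step is assumed to behave the same as omitting them:

long_seq = torch.randint(0, 20000, (1, 4096))

memories = None
xl_memories = None

for segment in long_seq.split(1024, dim = -1):
    logits, memories, xl_memories = model(segment, memories, xl_memories = xl_memories)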

Train on an absurdly long sequence

import torch
from recurrent_memory_transformer_pytorch import (
    RecurrentMemoryTransformer,
    RecurrentMemoryTransformerWrapper
)

model = RecurrentMemoryTransformer(
    num_tokens = 256,
    num_memory_tokens = 128,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    use_flash_attn = True,
    causal = True
)

model = RecurrentMemoryTransformerWrapper(model).cuda()

seq = torch.randint(0, 256, (4, 65536)).cuda()   # absurdly long sequence; in reality, they curriculum-learned this, starting with 1 segment and working up to about 7-8 segments

loss = model(seq, memory_replay_backprop = True) # memory efficient training from memformer paper
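
A rough training loop around this, as a sketch: the optimizer, learning rate, and gradient clipping below are illustrative, and it assumes that with memory_replay_backprop = True the wrapper accumulates gradients internally, so only the optimizer step remains to be done afterwards.

from torch.optim import Adam

optim = Adam(model.parameters(), lr = 3e-4)

for _ in range(100):
    seq = torch.randint(0, 256, (4, 65536)).cuda()

    loss = model(seq, memory_replay_backprop = True)   # assumed to backprop segment by segment internally

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()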

Todo

  • move the memory replay backprop into a torch.function, test out bidirectional, then test on a real problem

  • get rotary embeddings working properly with xl memories

  • add xl memories, detached

  • offer a way to turn off rotary embeddings, absolute positional embeddings, and add token shift

  • make memories being causally masked an option

  • add the memory replay backprop technique from memformer paper

  • relative positional encoding

Alternatives

Citations

@inproceedings{bulatov2022recurrent,
    title   = {Recurrent Memory Transformer},
    author  = {Aydar Bulatov and Yuri Kuratov and Mikhail Burtsev},
    booktitle = {Advances in Neural Information Processing Systems},
    editor  = {Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
    year    = {2022},
    url     = {https://openreview.net/forum?id=Uynr3iPhksa}
}
@misc{bulatov2023scaling,
    title   = {Scaling Transformer to 1M tokens and beyond with RMT},
    author  = {Aydar Bulatov and Yuri Kuratov and Mikhail S. Burtsev},
    year    = {2023},
    eprint  = {2304.11062},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{Wu2020MemformerAM,
    title   = {Memformer: A Memory-Augmented Transformer for Sequence Modeling},
    author  = {Qingyang Wu and Zhenzhong Lan and Kun Qian and Jing Gu and Alborz Geramifard and Zhou Yu},
    booktitle = {AACL/IJCNLP},
    year    = {2020}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
@misc{ding2021cogview,
    title   = {CogView: Mastering Text-to-Image Generation via Transformers},
    author  = {Ming Ding and Zhuoyi Yang and Wenyi Hong and Wendi Zheng and Chang Zhou and Da Yin and Junyang Lin and Xu Zou and Zhou Shao and Hongxia Yang and Jie Tang},
    year    = {2021},
    eprint  = {2105.13290},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@software{Dayma_DALLE_Mini_2021,
    author  = {Dayma, Boris and Patil, Suraj and Cuenca, Pedro and Saifullah, Khalid and Abraham, Tanishq and Lê Khắc, Phúc and Melas, Luke and Ghosh, Ritobrata},
    doi     = {10.5281/zenodo.5146400},
    license = {Apache-2.0},
    month   = {jul},
    title   = {{DALL·E Mini}},
    url     = {https://github.com/borisdayma/dalle-mini},
    version = {v0.1-alpha},
    year    = {2021}
}
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}
@misc{ding2021erniedoc,
    title   = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author  = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year    = {2021},
    eprint  = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@article{Xie2023ResiDualTW,
    title   = {ResiDual: Transformer with Dual Residual Connections},
    author  = {Shufang Xie and Huishuai Zhang and Junliang Guo and Xu Tan and Jiang Bian and Hany Hassan Awadalla and Arul Menezes and Tao Qin and Rui Yan},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2304.14802}
}

recurrent-memory-transformer-pytorch's People

Contributors

lucidrains

recurrent-memory-transformer-pytorch's Issues

Question: why do we need read_memory_emb

self.read_memory_emb = nn.Parameter(torch.zeros(num_memory_tokens, dim))

if exists(read_memories):
    read_mem_length = mem_length
    read_memories = read_memories + self.read_memory_emb
else:
    read_mem_length = 0
    read_memories = x[:, 0:0]

Why do we need self.read_memory_emb? They are not used if read_memories is None, so why not just use read_memories?

Question: first read memories

During the first run, mems == None, and the model doesn't attend to any "read" tokens, as per:

if exists(read_memories):
    read_mem_length = mem_length
    read_memories = read_memories + self.read_memory_emb
else:
    read_mem_length = 0
    read_memories = x[:, 0:0]

Why not attend to read_memory_emb, and replace with

read_mem_length = mem_length
read_mems = repeat(self.read_memory_emb, 'm d -> b m d', b = batch)
if exists(read_memories):
    read_mems += read_memories

have you had a chance to train it yet?

I've got a few days of full access to a cluster of about 8 A6000s and I'm itching to put them to some insane task. I hadn't even considered this, but I really want to ensemble this and see what happens in an extremely accelerated environment.

What happens if texts from the dataset don't have equal lengths

What would happen if the texts in the dataset didn't have equal lengths and the batch size was > 1? E.g. the first text has 2 segments and the second has 4. Would the loss be NaN for segments containing only padding tokens? I think your training script assumes that every segment is a full, non-padded sequence of length 2048.
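
One way to handle this - a sketch, not what the training script does: pad every text in the batch up to a multiple of seq_len, build a boolean mask over the real tokens, and pass it per segment. The base model accepts a mask argument, as used in the masking question further down; the pad id of 0 here is arbitrary, and model is the one from the usage example.

import torch

seq_len = 1024
pad_id = 0   # hypothetical pad token id

texts = [torch.randint(1, 256, (2048,)), torch.randint(1, 256, (4096,))]   # 2 segments and 4 segments

total_len = max(t.shape[0] for t in texts)
total_len = ((total_len + seq_len - 1) // seq_len) * seq_len   # round up to a multiple of seq_len

batch = torch.full((len(texts), total_len), pad_id)
mask = torch.zeros((len(texts), total_len), dtype = torch.bool)

for i, t in enumerate(texts):
    batch[i, :t.shape[0]] = t
    mask[i, :t.shape[0]] = True

memories = None

for segment, segment_mask in zip(batch.split(seq_len, dim = -1), mask.split(seq_len, dim = -1)):
    logits, memories, _ = model(segment, memories, mask = segment_mask)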

bptt depth implementation?

Hi, @lucidrains!
Sorry to bother you, but I wonder how you implemented the back-propagation-through-time depth here.
Is it replaced by the memory replay buffer? In that case, does the gradient flow unrestricted?
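
For reference, if the segments are stepped manually (outside the wrapper), the backprop-through-time depth can be bounded simply by detaching the memories every k segments. A sketch of that idea only, not the wrapper's actual mechanism; model is the one from the usage example:

bptt_depth = 2   # gradients flow back through at most this many segments

long_seq = torch.randint(0, 20000, (1, 8192))
memories = None

for i, segment in enumerate(long_seq.split(1024, dim = -1)):
    logits, memories, _ = model(segment, memories)

    if (i + 1) % bptt_depth == 0:
        memories = memories.detach()   # cut the graph here to bound the BPTT depth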

Question: masks

So I added some code to plot the attention mask. I use the following:

def lengths_to_padding_mask(T: int, lengths: torch.Tensor):
    # True for real (non-padded) positions, False for padding
    return torch.arange(T, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(-1)

net = RecurrentMemoryTransformer(
    seq_len=1024,
    num_tokens=259,
    num_memory_tokens=128,
    dim=512,
    depth=1,
    causal=True,
    heads=4,
    dim_head=128,
    use_flash_attn=True,
    rotary_pos_emb=True
).eval()

x = torch.randint(0, 256, (8, 1024))
l = torch.full((x.shape[0],), fill_value=800)
m = lengths_to_padding_mask(x.shape[1], l)
l, mems, _ = net(x, mask=m)
l, mems, _ = net(x, mems, mask=m)

and added

import matplotlib.pyplot as plt
img = mask[0,0]
plt.imshow(img)
plt.show()

at

On both forward passes I get the following attention masks:

[image: mask1_noreadmems]

[image: mask2_readmems]

@lucidrains Is this correct ?

Question: Global write tokens or recurrent

Hi. Thanks for sharing this work.

I have a question regarding the write tokens on the input side. In the current version, you always use the global parameter memory_tokens on the input side, as here. But in the original RMT paper, the equation is written as [H^mem H^0 H^mem], and in Fig. 4 of the new RMDT paper they draw an arrow from the last write output to both the read and write inputs of the next segment. I think they also intend to use the write memory in a recurrent fashion.
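
To make the two variants concrete, here is a rough, standalone sketch of how the segment input could be assembled in each case - the tensor names are illustrative, not the module's actual attributes:

import torch
from einops import repeat

batch, seq_len, dim, num_mem = 2, 1024, 512, 128

x = torch.randn(batch, seq_len, dim)                            # token embeddings for the current segment
memory_tokens = torch.nn.Parameter(torch.randn(num_mem, dim))   # global learned write tokens
prev_write_out = torch.randn(batch, num_mem, dim)               # write memories output by the previous segment

# variant 1 (what the question describes as the current behaviour):
# read from the recurrent memories, but always seed the write positions from the global parameter
read_mems = prev_write_out
write_mems = repeat(memory_tokens, 'm d -> b m d', b = batch)

# variant 2 ([H^mem H^0 H^mem] in the paper): the previous output memories
# seed both the read and the write positions
read_mems = write_mems = prev_write_out

inp = torch.cat((read_mems, x, write_mems), dim = -2)           # (batch, num_mem + seq_len + num_mem, dim)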

token_shift

Hello, I'm sorry to bother you, but this issue is troubling me. Specifically, what is the purpose of token_shift and why is it needed? I hope to receive a detailed explanation from you.
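
For context, token shift (from the RWKV line of work cited above) mixes each position with the previous position along half of the feature dimensions, giving the model a cheap, convolution-like sense of local order. A minimal sketch of the idea - not necessarily this repository's exact implementation:

import torch
import torch.nn.functional as F

def token_shift(t):
    # t: (batch, seq, dim) - shift half of the feature dims one step back along the sequence
    t, t_shifted = t.chunk(2, dim = -1)
    t_shifted = F.pad(t_shifted, (0, 0, 1, -1), value = 0.)
    return torch.cat((t, t_shifted), dim = -1)

x = torch.randn(1, 1024, 512)
out = token_shift(x)   # same shape, but half the features now come from the previous position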

Question: how does memory replay backpropagation work with multiple models in series

Say we have 2 instances of RecurrentMemoryTransformer, one after the other. The first returns embeddings, not logits, which are fed into the second instance, which is a continuous version of RecurrentMemoryTransformer (without the embedding layer). Suppose we have a loss function that uses the outputs of the second model and the input of the first model. How does memory replay backpropagation work with this system?
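
For what it's worth, the memory replay backprop idea from the Memformer paper can be summarized as: run all segments forward without building a graph while caching the incoming memories, then replay the segments in reverse with grad enabled, each time injecting the gradient that the next segment produced for its read memories. A rough standalone illustration of that scheme - not this repository's implementation, and ignoring the two-model setup in the question (labels_fn is a hypothetical per-segment loss):

import torch

def memory_replay_backprop(model, segments, labels_fn):
    # first pass: collect the memories entering each segment, with no graph
    mems_in = [None]
    with torch.no_grad():
        for segment in segments[:-1]:
            _, mems, _ = model(segment, mems_in[-1])
            mems_in.append(mems)

    grad_wrt_out_mems = None

    # replay in reverse, recomputing each segment with grad enabled
    for segment, read_mems in zip(reversed(segments), reversed(mems_in)):
        if read_mems is not None:
            read_mems = read_mems.detach().requires_grad_()

        logits, out_mems, _ = model(segment, read_mems)
        loss = labels_fn(logits, segment)

        if grad_wrt_out_mems is not None:
            # also push back the gradient the next segment produced for these output memories
            torch.autograd.backward([loss, out_mems], [None, grad_wrt_out_mems])
        else:
            loss.backward()

        grad_wrt_out_mems = read_mems.grad if read_mems is not None else None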

flash attention, and a potentially better improvement

Hello! Sorry to bother you, but considering flash attention isn't currently compatible with relative positional embeddings - do you intend to work around that or update the codebase?

The implementation I came up with last night (before there was a git, lol) turned out to be quite a bit different from this one. I wound up trying to adapt it to a T5; there's still more work to be done, but there may be room in there to utilize the flash attention stuff, especially considering the alternate ways I've got in there to adapt/simulate relative positional embeddings. Figured I'd link it for you to look at in case something happens to be useful to you - I found your GitHub account last night and it's going to be wildly useful to me, so I'm just wanting to return the favor as best I can :)

https://github.com/Alignment-Lab-AI/RMT5-HYENA

Question: How to set seq_len ?

What is a good number for seq_len?
What are the trade-offs of a shorter or longer seq_len?
Like, why can't seq_len == 1?
Infinite recurrence is infinite recurrence no matter what the value is, right?
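
As a rough way to think about the trade-off: attention cost per segment grows quadratically with seq_len, while the number of recurrent steps - and therefore how often everything has to squeeze through the 128 memory tokens, and how deep any backprop through time must reach - grows as total_len / seq_len. With seq_len = 1 there would be no within-segment attention at all, and every single token would have to survive the memory bottleneck. A small illustrative calculation:

total_len = 65536

for seq_len in (1, 128, 1024, 8192):
    num_segments = total_len // seq_len
    attn_ops = num_segments * seq_len ** 2   # ~ attention operations over the whole sequence, up to constants
    print(f'seq_len={seq_len:5d}  segments={num_segments:5d}  attn_ops~{attn_ops}')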

Question: configuring scaled_dot_product_attention

It looks like, from

if device_properties.major == 8 and device_properties.minor == 0:
    print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
    self.cuda_config = Config(True, False, False)
else:
    print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
    self.cuda_config = Config(False, True, True)

and

with torch.backends.cuda.sdp_kernel(**config._asdict()):
    out = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask = mask,
        dropout_p = self.dropout if self.training else 0.,
        is_causal = self.causal
    )

that we manually configure F.scaled_dot_product_attention().
The documentation says "All implementations are enabled by default. Scaled dot product attention attempts to automatically select the most optimal implementation based on the inputs."
Can't we just let PyTorch decide?
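
For comparison, a minimal sketch of "just letting PyTorch decide": call F.scaled_dot_product_attention with no sdp_kernel context manager, so all backends stay enabled and PyTorch picks one based on the inputs (on a suitable GPU with half-precision tensors it may select the flash kernel on its own):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 1024, 64)
k = torch.randn(1, 8, 1024, 64)
v = torch.randn(1, 8, 1024, 64)

# no backend is forced; PyTorch selects flash / mem-efficient / math automatically
out = F.scaled_dot_product_attention(q, k, v, is_causal = True)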

Feature request: make JIT and ONNX export work

net = RecurrentMemoryTransformer(
    seq_len=1024,
    num_tokens=256,
    num_memory_tokens=128,
    dim=512,
    depth=1,
    causal=True,
    heads=4,
    dim_head=128,
    use_flash_attn=True,
    rotary_pos_emb=True
).eval()

x = torch.randint(0, 256, (8, 1024))

jit = torch.jit.trace(net, (x,))

x = torch.randint(0, 256, (8, 1024))
l = torch.randint(100, x.shape[1], size=(x.shape[0],))
m = lengths_to_padding_mask(x.shape[1], l)   # helper defined in the masking question above

l1, mems, _ = net(x, mask=m)
l2, mems, _ = net(x, mems, mask=m)
l3, mems, _ = jit(x, mask=m)
l4, mems, _ = jit(x, mems, mask=m)

torch.testing.assert_close(l1, l3)
torch.testing.assert_close(l2, l4)

It would be great if the above worked.
