lucidrains / megabyte-pytorch

Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch

License: MIT License

Python 100.00%
artificial-intelligence deep-learning learned-tokenization attention-mechanisms long-context transformers

megabyte-pytorch's Introduction

MEGABYTE - Pytorch

Implementation of MEGABYTE, Predicting Million-byte Sequences with Multiscale Transformers, in Pytorch. I took the liberty of generalizing it even further, so one can have multiple local models.

Similar independent research that is a further generalization

Appreciation

  • Stability and 🤗 Huggingface for the generous sponsorship to work on and open source cutting edge artificial intelligence research

Install

$ pip install MEGABYTE-pytorch

Usage

import torch
from MEGABYTE_pytorch import MEGABYTE

model = MEGABYTE(
    num_tokens = 16000,             # number of tokens
    dim = (512, 256),               # transformer model dimension (512 for coarsest, 256 for fine in this example)
    max_seq_len = (1024, 4),        # sequence length for global and then local. this can be more than 2
    depth = (6, 4),                 # number of layers for global and then local. this can be more than 2, but length must match the max_seq_len's
    dim_head = 64,                  # dimension per head
    heads = 8,                      # number of attention heads
    flash_attn = True               # use flash attention
)

x = torch.randint(0, 16000, (1, 1024, 4))

loss = model(x, return_loss = True)
loss.backward()

# then after much training

logits = model(x)

# and sample from the logits accordingly
# or you can use the generate function

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
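As the comments above note, the configuration tuples can have more than two entries. Below is a sketch of a three-stage setup with one global and two local models. The shapes are assumptions, not taken from the repository: the sketch presumes each additional stage simply adds one more axis to the input tensor, mirroring the two-stage example above.

import torch
from MEGABYTE_pytorch import MEGABYTE

# hypothetical three-stage configuration: one global and two local transformers
model = MEGABYTE(
    num_tokens = 16000,
    dim = (768, 512, 256),          # coarsest to finest stage
    max_seq_len = (1024, 4, 4),     # one entry per stage
    depth = (6, 4, 2),
    dim_head = 64,
    heads = 8,
    flash_attn = True
)

# assumed input shape: one axis per stage, matching max_seq_len
x = torch.randint(0, 16000, (1, 1024, 4, 4))

loss = model(x, return_loss = True)
loss.backward()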

Test

Train on character-level enwik8 with patches of size 4 - length 8192

$ python train.py
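For reference, a byte-level configuration matching the "patches of size 4, length 8192" setup would look roughly like the following. This is only a sketch: train.py defines its own hyperparameters, and the dim/depth values here are placeholders.

import torch
from MEGABYTE_pytorch import MEGABYTE

# sketch of a byte-level model for 8192-byte sequences split into 2048 patches of 4 bytes
# (dim/depth values are placeholders, not necessarily those used by train.py)
model = MEGABYTE(
    num_tokens = 256,               # raw bytes, so 256 possible values
    dim = (512, 256),
    max_seq_len = (2048, 4),        # 2048 patches * 4 bytes per patch = 8192 bytes
    depth = (6, 4),
    dim_head = 64,
    heads = 8
)

x = torch.randint(0, 256, (1, 2048, 4))   # byte ids shaped (batch, patches, patch size)
loss = model(x, return_loss = True)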

Citations

@misc{yu2023megabyte,
    title   = {MEGABYTE: Predicting Million-byte Sequences with Multiscale Transformers}, 
    author  = {Lili Yu and Dániel Simig and Colin Flaherty and Armen Aghajanyan and Luke Zettlemoyer and Mike Lewis},
    year    = {2023},
    eprint  = {2305.07185},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
@article{Kazemnejad2023TheIO,
    title   = {The Impact of Positional Encoding on Length Generalization in Transformers},
    author  = {Amirhossein Kazemnejad and Inkit Padhi and Karthikeyan Natesan Ramamurthy and Payel Das and Siva Reddy},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2305.19466}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding},
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}

megabyte-pytorch's People

Contributors

lucidrains

megabyte-pytorch's Issues

the string is still divided into pieces

I also saw this paper today. The main purpose of the paper is to get rid of the tokenizer, but in fact the string is still divided into pieces; it just adds a pre-decoder. Is there a problem with my understanding?

No available kernel error

Thank you for your code.

I tried to train MEGABYTE using default settings, but I faced the following error.
How can I fix it?

.../MEGABYTE_pytorch/attend.py", line 111, in flash_attn
    out = F.scaled_dot_product_attention(
RuntimeError: No available kernel.  Aborting execution.
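Not an answer from the original thread, but one possible workaround, assuming the failure comes from the flash attention path (e.g. a GPU or dtype combination the fused kernel does not support), is to fall back to regular attention by constructing the model with flash_attn = False:

from MEGABYTE_pytorch import MEGABYTE

# fall back to the non-fused attention path if F.scaled_dot_product_attention
# has no flash kernel available for your GPU / dtype combination
model = MEGABYTE(
    num_tokens = 16000,
    dim = (512, 256),
    max_seq_len = (1024, 4),
    depth = (6, 4),
    dim_head = 64,
    heads = 8,
    flash_attn = False    # disable flash attention
)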

some implementations are different from the original paper

Hi, thanks for your code; it helped me understand the original paper!

Considering only the global_transformer and local_transformer, I found that your code differs from the paper at a few points.

Let's set D_G = 768, D_L = 256, P = 4, seq = 40, batch_size = 1.

  1. The h^global_out shape is (1, 1+10, 768) in your code, but it should be (1, 1+10, 4 * 768) in the paper (1+10 means the first pad token plus the following patch tokens). The original paper has two token_emb functions, one for h^global_in and one for h^local_in; your code only has a token_emb for h^local_in.
  2. The proj between the global and local transformer projects dim 768 to 256. Your code does this (see the shape sketch below):
  • h^global_out.shape = (1, 11, 768); choose the top-10 tokens, so h^global_out.shape = (1, 10, 768)

  • h^global_out = proj(h^global_out), so h^global_out.shape = (1, 10, 256)

  • h^local_in.shape = (10, 4, 256); concat with h^global_out, so the final h^local_in has shape (10, 5, 256). At the end, drop the first token.

    The original paper instead does:

  • h^global_out.shape = (1, 11, 4 * 768); choose the top-10 tokens, so h^global_out.shape = (1, 10, 4 * 768)

  • h^global_out = proj(h^global_out), so h^global_out.shape = (1, 10, 4 * 256)

  • h^local_in = concat([pad, top-3 tokens of h^local_in]) + h^global_out.reshape(10, 4, 256), which drops the last token of h^local_in.
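A shape-only sketch of the two layouts described above, using throwaway tensors and plain nn.Linear projections purely to illustrate the dimensions (this is not code from either the paper or the repo):

import torch
import torch.nn as nn

D_G, D_L, P, N = 768, 256, 4, 10    # global dim, local dim, patch size, number of patches

# layout described for this repo: one D_G-dim vector per patch, projected 768 -> 256
h_global_out = torch.randn(1, N, D_G)                 # (1, 10, 768) after dropping the last position
proj_repo = nn.Linear(D_G, D_L)
print(proj_repo(h_global_out).shape)                  # torch.Size([1, 10, 256])

# layout described for the paper: P positions kept per patch, projected 4*768 -> 4*256
h_global_out_paper = torch.randn(1, N, P * D_G)       # (1, 10, 3072)
proj_paper = nn.Linear(P * D_G, P * D_L)
print(proj_paper(h_global_out_paper).reshape(N, P, D_L).shape)   # torch.Size([10, 4, 256])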


Please take a look and see if I have misunderstood something. Looking forward to your reply!

Minor shape error

Flagging in case anyone else ran into this:
train.py errored for me initially on line 400 of megabyte.py:
start_tokens, logits = logits[:, 0, :1, :], logits[..., 1:, :]
I reshaped the start_tokens so they're shaped (4, 1, 256) instead of (4, 1, 5, 256) and the code runs fine.

GPU used for original paper experiments

Quick question: what hardware did you use for training (and for controlling for computation time) in your original paper?

Do you think the hardware (degree of sharding, other optimizations) would affect the relative wall time a lot?

the patch embedder implementation is different from the original paper

Thank you so much for taking the time to share your code with me! I appreciate your generosity in helping me better understand the paper.

I noticed that your code takes a slightly different approach to implementing the Patch Embedder than the original.

The original paper's Patch Embedder uses separate global and local byte embeddings, followed by concatenation in the global model. Your implementation, however, uses a single byte embedding and linear transformations between the different transformers to construct the patch embedding.

The paper's patch embedder is shown in its figure (omitted here); your code does:

self.patch_embedders = nn.ModuleList([nn.Sequential(
    Rearrange('... r d -> ... (r d)'),
    nn.LayerNorm(seq_len * dim_in),
    nn.Linear(seq_len * dim_in, dim_out),  # linear transformations here
    nn.LayerNorm(dim_out)
) for dim_in, dim_out, seq_len in zip(dim[1:], dim[:-1], max_seq_len[1:])])

Training Results and Scaling

Hi there.

I've run the training code in this repository for 25k of the 100k batches and reached a validation loss of around 1.28, or a perplexity of 3.59. After this, the training loss continues to drop but the validation loss either plateaus or slowly starts going back up. I was curious whether you found the same (however, I stopped at 25k and restarted training; I reloaded the model and optimiser checkpoints but didn't preserve the train/val shuffling, so I'm not sure whether that confounds it). I also tried running the training on an H100 with 80 GB of VRAM and a batch size of 60 instead of 4, and found very slow convergence and an earlier plateau of the validation loss (around 2.5). Do other hyperparameters need to be adjusted to scale training on larger devices? I originally tested on an RTX 3060 Ti with 8 GB of VRAM.

Thanks in advance.
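Not an answer from the maintainer, but one common heuristic when increasing the batch size (an assumption here, not something train.py does) is to scale the learning rate roughly linearly with the batch size. A tiny sketch with placeholder values:

# hypothetical learning-rate adjustment when scaling batch size (linear scaling rule);
# BASE_LR and BASE_BATCH are placeholders, not values taken from train.py
BASE_LR = 1e-4
BASE_BATCH = 4

def scaled_lr(batch_size, base_lr = BASE_LR, base_batch = BASE_BATCH):
    # scale the learning rate proportionally to the batch size
    return base_lr * batch_size / base_batch

print(scaled_lr(60))   # 1.5e-3 for a batch size of 60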

Some questions about MEGABYTE

First of all, thank you for your contribution. Is MEGABYTE only suitable for ASCII encoding? If you use Unicode, will it go wrong or easily blow up memory, and how would you then achieve character-level segmentation?
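For what it's worth, UTF-8 already maps any Unicode string to a sequence of byte values in [0, 255], so a byte-level model with num_tokens = 256 covers Unicode text without a larger vocabulary (at the cost of more bytes per character). A small illustration:

text = "héllo, 世界"                     # mixed ASCII and multi-byte Unicode characters
byte_ids = list(text.encode("utf-8"))
print(byte_ids)                          # every value is in [0, 255]
print(max(byte_ids) < 256)               # True
print(bytes(byte_ids).decode("utf-8"))   # round-trips back to the original string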

Why does it expect tokens?

The input to this implementation is tokens (0-16000). Isn't the whole point of the original paper that the input is bytes (0-255)? Am I missing something about the patch embedding?
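The num_tokens = 16000 in the usage example is just that example's vocabulary size; for the byte-level setting of the paper one would set num_tokens = 256 and feed raw byte ids. A sketch, padding to the full (1, 1024, 4) shape from the usage example (whether shorter inputs are accepted is not shown here):

import torch
from MEGABYTE_pytorch import MEGABYTE

model = MEGABYTE(
    num_tokens = 256,        # byte values 0-255 instead of a learned subword vocabulary
    dim = (512, 256),
    max_seq_len = (1024, 4),
    depth = (6, 4),
    dim_head = 64,
    heads = 8
)

text = "MEGABYTE operates directly on raw bytes."
byte_ids = torch.tensor(list(text.encode("utf-8")), dtype = torch.long)

# zero-pad to the full 1024 * 4 = 4096 bytes and reshape to (batch, patches, patch size)
x = torch.zeros(1024 * 4, dtype = torch.long)
x[:len(byte_ids)] = byte_ids
x = x.reshape(1, 1024, 4)

loss = model(x, return_loss = True)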

Evaluation metric bits-per-byte

Hi there,

The MEGABYTE paper uses bits-per-byte as its evaluation metric in Table 2. It seems to be different from byte-level perplexity, since the numbers in the arXiv paper and the code are < 1, so it cannot be perplexity. This repo uses the cross-entropy loss, from which byte-level perplexity is easy to calculate. May I ask how to compute the bits-per-byte metric?

Thanks a lot.
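For anyone else looking for this: with a cross-entropy loss measured in nats per byte (assuming the loss returned by model(x, return_loss = True) averages over bytes), the conversion is just a change of logarithm base:

import math

loss_nats_per_byte = 1.28                          # example value from the training discussion above
bits_per_byte = loss_nats_per_byte / math.log(2)   # nats -> bits
perplexity = math.exp(loss_nats_per_byte)          # byte-level perplexity, for comparison

print(round(bits_per_byte, 3))   # ~1.847
print(round(perplexity, 3))      # ~3.597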
