RQ-Transformer

Implementation of RQ Transformer, which proposes a more efficient way of training multi-dimensional sequences autoregressively. This repository will only contain the transformer for now; the residual VQ itself can come from a separate vector quantization library, as in the sketch below.
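
For the residual VQ side, a minimal sketch follows. It assumes the ResidualVQ interface from the lucidrains/vector-quantize-pytorch library (returning quantized vectors, per-quantizer code indices, and a commitment loss); the indices are what RQTransformer consumes in the usage example further down.

import torch
from vector_quantize_pytorch import ResidualVQ

# assumed interface: calling ResidualVQ returns (quantized, indices, commit_loss),
# where indices has shape (batch, seq, num_quantizers)
residual_vq = ResidualVQ(
    dim = 512,              # feature dimension
    num_quantizers = 4,     # number of residual quantizations per position
    codebook_size = 16000   # matches num_tokens in the usage example below
)

features = torch.randn(1, 1024, 512)   # one feature vector per spatial position
quantized, indices, commit_loss = residual_vq(features)

# indices - (1, 1024, 4) - can be fed to RQTransformer as in the usage example below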

This type of axial autoregressive transformer should be compatible with memcodes, proposed in NWT. It would likely also work well with multi-headed VQ.

Install

$ pip install RQ-transformer

Usage

import torch
from rq_transformer import RQTransformer

model = RQTransformer(
    num_tokens = 16000,             # number of tokens, in the paper they had a codebook size of 16k
    dim = 512,                      # transformer model dimension
    max_spatial_seq_len = 1024,     # maximum positions along space
    depth_seq_len = 4,              # number of positions along depth (residual quantizations in paper)
    spatial_layers = 8,             # number of layers for space
    depth_layers = 4,               # number of layers for depth
    dim_head = 64,                  # dimension per head
    heads = 8,                      # number of attention heads
)

x = torch.randint(0, 16000, (1, 1024, 4))

loss = model(x, return_loss = True)
loss.backward()

# then after much training

logits = model(x)

# and sample from the logits accordingly (a manual top-k sampling sketch follows after this example)
# or you can use the generate function

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 1024, 4)
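
If you want to sample from the raw logits yourself instead of calling generate, a minimal temperature plus top-k sketch is shown below. It assumes the logits carry the num_tokens vocabulary along the last dimension; the helper name is purely illustrative and not part of the library.

import torch
import torch.nn.functional as F

# illustrative helper (not part of the library): temperature + top-k sampling over the last dimension
def sample_from_logits(logits, temperature = 0.9, top_k = 50):
    logits = logits / temperature
    kth = logits.topk(top_k, dim = -1).values[..., -1:]
    logits = logits.masked_fill(logits < kth, float('-inf'))   # mask everything outside the top-k
    probs = F.softmax(logits, dim = -1)
    flat = torch.multinomial(probs.reshape(-1, probs.shape[-1]), 1)
    return flat.reshape(probs.shape[:-1])

tokens = sample_from_logits(model(x))   # (1, 1024, 4), assuming logits are (1, 1024, 4, 16000)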

I also think there is something deeper going on, and have generalized this to any number of dimensions. You can use it by importing the HierarchicalCausalTransformer.

import torch
from rq_transformer import HierarchicalCausalTransformer

model = HierarchicalCausalTransformer(
    num_tokens = 16000,                   # number of tokens
    dim = 512,                            # feature dimension
    dim_head = 64,                        # dimension of attention heads
    heads = 8,                            # number of attention heads
    depth = (4, 4, 2),                    # 3 stages (but can be any number) - transformer of depths 4, 4, 2
    max_seq_len = (16, 4, 5)              # the maximum sequence length of the first stage, then the fixed sequence lengths of all subsequent stages
).cuda()

x = torch.randint(0, 16000, (1, 10, 4, 5)).cuda()

loss = model(x, return_loss = True)
loss.backward()

# after a lot of training

sampled = model.generate(temperature = 0.9, filter_thres = 0.9) # (1, 16, 4, 5)
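
As a sanity check on the generalization, the two-stage case of HierarchicalCausalTransformer should correspond to the RQTransformer configuration from the first example. The mapping below (spatial stage first, then depth stage) is an assumption based on the argument names, not something stated explicitly in this README.

import torch
from rq_transformer import HierarchicalCausalTransformer

# assumed two-stage equivalent of the RQTransformer example above:
# first stage ~ spatial (8 layers, up to 1024 positions), second stage ~ depth (4 layers, 4 residual codes)
model = HierarchicalCausalTransformer(
    num_tokens = 16000,
    dim = 512,
    dim_head = 64,
    heads = 8,
    depth = (8, 4),
    max_seq_len = (1024, 4)
)

x = torch.randint(0, 16000, (1, 1024, 4))

loss = model(x, return_loss = True)
loss.backward()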

Todo

  • move hierarchical causal transformer to a separate repository; it seems to be working

Citations

@misc{lee2022autoregressive,
    author  = {Lee, Doyup and Kim, Chiheon and Kim, Saehoon and Cho, Minsu and Han, Wook-Shin},
    year    = {2022},
    month   = {03},
    title   = {Autoregressive Image Generation using Residual Quantization}
}
@misc{press2021ALiBi,
    title   = {Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation},
    author  = {Ofir Press and Noah A. Smith and Mike Lewis},
    year    = {2021},
    url     = {https://ofir.io/train_short_test_long.pdf}
}
