
Hourglass Transformer - Pytorch

Implementation of Hourglass Transformer, in Pytorch, from Google and OpenAI.

Install

$ pip install hourglass-transformer-pytorch

Usage

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 256,               # number of tokens
    dim = 512,                      # feature dimension
    max_seq_len = 1024,             # maximum sequence length
    heads = 8,                      # attention heads
    dim_head = 64,                  # dimension per attention head
    shorten_factor = 2,             # shortening factor
    depth = (4, 2, 4),              # tuple of 3: pre-transformer layers, valley transformer layers (after downsampling), post-transformer layers (after upsampling)
                                    # the valley entry can itself be another such tuple, in which case it shortens again recursively
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x) # (1, 1024, 256)
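
To train autoregressively, the logits can be paired with one-token-shifted targets using standard cross entropy. A minimal sketch (this loss arrangement is an illustration, not the repository's training script):

import torch
import torch.nn.functional as F
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    shorten_factor = 2,
    depth = (4, 2, 4)
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)                    # (1, 1024, 256)

# positions 0 .. n-2 predict tokens 1 .. n-1
loss = F.cross_entropy(logits[:, :-1].transpose(1, 2), x[:, 1:])
loss.backward()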

For something more sophisticated, here are two hourglasses, with one nested within the other:

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    shorten_factor = (2, 4),     # 2x for first hourglass, 4x for second
    depth = (4, (2, 1, 2), 3),   # 4@1 -> 2@2 -> 1@4 -> 2@2 -> 3@1
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)

The Funnel Transformer could be approximated as:

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 1024,
    causal = False,
    attn_resampling = False,
    shorten_factor = 2,
    depth = (2, (2, (2, 2, 2), 2), 2)
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)

For images, the authors found that linear projections worked a lot better than average pooling (for downsampling) and repeating (for upsampling). You can use this by setting updown_sample_type = 'linear'

import torch
from hourglass_transformer_pytorch import HourglassTransformer

model = HourglassTransformer(
    dim = 512,
    shorten_factor = 2,
    depth = (4, 2, 4),
    updown_sample_type = 'linear'
)

img_tokens = torch.randn(1, 1024, 512)
model(img_tokens) # (1, 1024, 512)
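
If you are starting from raw images rather than precomputed tokens, one simple way to produce the (batch, seq, dim) tokens the model expects is a strided convolution acting as a patch embedding. A hedged sketch (the patch embedding below is an assumption for illustration, not part of this library):

import torch
import torch.nn as nn
from hourglass_transformer_pytorch import HourglassTransformer

# hypothetical patch embedding: 256x256 RGB image -> 32x32 grid of 8x8 patches -> 1024 tokens
patch_embed = nn.Conv2d(3, 512, kernel_size = 8, stride = 8)

model = HourglassTransformer(
    dim = 512,
    shorten_factor = 2,
    depth = (4, 2, 4),
    updown_sample_type = 'linear'
)

img = torch.randn(1, 3, 256, 256)
tokens = patch_embed(img).flatten(2).transpose(1, 2)  # (1, 1024, 512)
out = model(tokens)                                   # (1, 1024, 512)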

Although results were not presented in the paper, you can also use the Hourglass Transformer in this repository non-autoregressively.

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 1024,
    shorten_factor = 2,
    depth = (4, 2, 4),
    causal = False          # set this to False
)

x = torch.randint(0, 20000, (1, 1024))   # token values must lie within num_tokens
mask = torch.ones((1, 1024)).bool()

logits = model(x, mask = mask) # (1, 1024, 20000)
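
With padded batches of variable-length sequences, the mask can be built from the true lengths. A small sketch, assuming True marks positions to attend to (consistent with the all-ones mask above):

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 20000,
    dim = 512,
    max_seq_len = 1024,
    shorten_factor = 2,
    depth = (4, 2, 4),
    causal = False
)

x = torch.randint(0, 20000, (2, 1024))
lengths = torch.tensor([700, 1024])                    # true lengths before padding
mask = torch.arange(1024)[None, :] < lengths[:, None]  # (2, 1024) bool, True = keep

logits = model(x, mask = mask)                         # (2, 1024, 20000)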

Enwik8 autoregressive example

$ python train.py
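
Once trained, one simple way to sample from a fixed-context model is a sliding window that always feeds max_seq_len tokens and samples from the distribution at the last position. A hedged sketch (this sampling loop is an illustration, not built-in functionality of the training script):

import torch
from hourglass_transformer_pytorch import HourglassTransformerLM

model = HourglassTransformerLM(
    num_tokens = 256,
    dim = 512,
    max_seq_len = 1024,
    shorten_factor = 2,
    depth = (4, 2, 4)
)

seq = torch.randint(0, 256, (1, 1024))   # seed context, e.g. a prime from the dataset
generated = []

with torch.no_grad():
    for _ in range(64):
        logits = model(seq)                          # (1, 1024, 256)
        probs = logits[:, -1].softmax(dim = -1)      # distribution over the next token
        nxt = torch.multinomial(probs, 1)            # (1, 1)
        generated.append(nxt.item())
        seq = torch.cat((seq[:, 1:], nxt), dim = 1)  # slide the window forward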

Todo

  • support non-autoregressive usage, accounting for masking
  • account for masking in attention resampling
  • account for shift padding when naively downsampling

Citations

@misc{nawrot2021hierarchical,
    title   = {Hierarchical Transformers Are More Efficient Language Models}, 
    author  = {Piotr Nawrot and Szymon Tworkowski and Michał Tyrolski and Łukasz Kaiser and Yuhuai Wu and Christian Szegedy and Henryk Michalewski},
    year    = {2021},
    eprint  = {2110.13711},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}


Issues

Issue with an example

@lucidrains Just awesome to see this super fast launch!

The third example, under "For images...", throws the following error:

[error screenshot omitted]

Then, when I gave the inputs num_tokens = 20000 and max_seq_len = 1024, I got the following runtime error:

[error screenshot omitted]

No option for rotary positional embeddings

For the ImageNet64 generation experiments, the original authors used rotary positional embeddings. I think it would be worth adding this feature to the repository.

I can prepare a pull request that solves this~
