
This project forked from kyegomez/longnet


Fork of "LongNet: Scaling Transformers to 1,000,000,000 Tokens"

Home Page: https://discord.gg/qUtxnK2NMf

License: Apache License 2.0


longnet's Introduction

Agora

This implementation of LongNet is brought to you by Agora, an open-source AI research organization with 1,500+ AI researchers striving to advance humanity!


Join us to contribute to LongNet and/or receive support in the Agora Discord!

LongNet: Scaling Transformers to 1,000,000,000 Tokens

This is an implementation of the paper LongNet: Scaling Transformers to 1,000,000,000 Tokens by Jiayu Ding, Shuming Ma, Li Dong, Xingxing Zhang, Shaohan Huang, Wenhui Wang, and Furu Wei. LongNet is a Transformer variant designed to scale sequence length to more than 1 billion tokens without sacrificing performance on shorter sequences.

Introduction

Scaling sequence length has become a critical bottleneck in the era of large language models. However, existing methods struggle with either computational complexity or model expressivity, leaving the maximum sequence length restricted. The paper introduces LongNet, a Transformer variant that can scale sequence length to more than 1 billion tokens without sacrificing performance on shorter sequences. Specifically, the authors propose dilated attention, which expands the attentive field exponentially as the distance between tokens grows.

Features

LongNet has significant advantages:

  1. It has linear computational complexity and a logarithmic dependency between tokens.
  2. It can serve as a distributed trainer for extremely long sequences.
  3. Its dilated attention is a drop-in replacement for standard attention and can be seamlessly integrated with existing Transformer-based optimizations.

Experimental results demonstrate that LongNet yields strong performance on both long-sequence modeling and general language tasks. This work opens up new possibilities for modeling very long sequences, e.g., treating a whole corpus or even the entire Internet as a sequence.
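
To make the dilated-attention idea concrete, here is a minimal single-head PyTorch sketch of the sparsification pattern described above. It is an illustration only, not the implementation shipped in this repository (the paper mixes several segment sizes, dilation rates, and per-head offsets); the function name and the zero-filling of dropped positions are assumptions of this toy example.

import torch

def dilated_attention_sketch(q, k, v, segment_size, dilation_rate):
    # Toy dilated attention: split the sequence into segments, keep every
    # `dilation_rate`-th position in each segment, and run ordinary softmax
    # attention only among the kept positions of that segment.
    b, n, d = q.shape  # seq_len is assumed to be a multiple of segment_size
    out = torch.zeros_like(q)  # positions dropped by sparsification stay zero here
    for start in range(0, n, segment_size):
        idx = torch.arange(start, start + segment_size, dilation_rate, device=q.device)
        qi, ki, vi = q[:, idx], k[:, idx], v[:, idx]
        scores = (qi @ ki.transpose(-2, -1)) / d ** 0.5
        out[:, idx] = scores.softmax(dim=-1) @ vi
    return out

x = torch.randn(2, 128, 64)  # (batch, seq_len, dim)
y = dilated_attention_sketch(x, x, x, segment_size=32, dilation_rate=4)
print(y.shape)  # torch.Size([2, 128, 64])

Because each segment attends only within itself over segment_size / dilation_rate positions, the cost grows linearly with sequence length rather than quadratically.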


Installation

You can install LongNet using one of the following methods:

Method 1: Git Clone

  1. Clone the LongNet repository from GitHub:
git clone https://github.com/kyegomez/LongNet.git
  2. Navigate to the cloned directory:
cd LongNet
  3. Install the required dependencies:
pip install -r requirements.txt

Method 2: Pip Install

  1. Install LongNet directly from PyPI using pip:
pip install LongNet

Please note that LongNet requires a compatible Python version (tested with Python 3.7).

Usage

Once you have installed LongNet, you can use the DilatedAttention class as follows:

import torch
from LongNet import DilatedAttention

# Replace this with your correct GPU device
device = "cuda:0"
dtype = torch.float16

# Create an instance of DilatedAttention
d_model = 512
num_heads = 8
dilation_rate = 2
segment_size = 64
dropout = 0.2  # Specify the dropout rate
attention = DilatedAttention(
    d_model=d_model,
    num_heads=num_heads,
    dilation_rate=dilation_rate,
    segment_size=segment_size,
    dropout=dropout,
).to(device, dtype=dtype)

# Create some dummy input data
batch_size = 16
seq_len = 128
input_dim = d_model
inputs = torch.randn(batch_size, seq_len, input_dim, device=device, dtype=dtype)

# Forward pass
outputs = attention(inputs)

# Print the output shape
print(outputs.shape)  # Expected: [batch_size, seq_len, d_model]

In the example above, we create an instance of the DilatedAttention class with the specified hyperparameters. We then generate some dummy input data and pass it through the attention mechanism to obtain the outputs. Finally, we print the shape of the output tensor.
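
A note on the dropout argument: as with any torch.nn.Module, dropout is only active in training mode, so switch the module to eval mode for deterministic inference. Continuing from the example above:

attention.eval()               # disables dropout
with torch.no_grad():          # no gradient tracking needed for inference
    outputs = attention(inputs)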

DilatedAttention Documentation

The DilatedAttention class implements dilated attention, which expands the attentive field exponentially as the distance between tokens grows. It inherits from torch.nn.Module and can be used as a drop-in replacement for standard attention mechanisms in Transformer models.

Parameters

  • d_model (int): The dimensionality of the input and output embeddings.
  • num_heads (int): The number of attention heads.
  • dilation_rate (int): The dilation rate for sparsifying the input sequence.
  • segment_size (int): The size of each segment after sparsification.
  • dropout (float, optional): The dropout probability to apply to the attention output. Default: 0.0 (no dropout).

Inputs

  • x (Tensor): The input tensor of shape (batch_size, seq_len, d_model).

Outputs

  • output (Tensor): The output tensor of shape (batch_size, seq_len, d_model).

Please note that the input tensor should be on the correct device (e.g., GPU) and have the appropriate data type (dtype).
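
For example, if your data starts on the CPU in float32, you can convert it to match the module before the forward pass (continuing from the usage example above; the device string is an assumption for a single-GPU setup):

# Align the inputs with the module's device and dtype before calling it
inputs = inputs.to(device="cuda:0", dtype=torch.float16)
outputs = attention(inputs)
assert outputs.shape == inputs.shape  # (batch_size, seq_len, d_model)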

Citation

@article{ding2023longnet,
  title={LongNet: Scaling Transformers to 1,000,000,000 Tokens},
  author={Ding, Jiayu and Ma, Shuming and Dong, Li and Zhang, Xingxing and Huang, Shaohan and Wang, Wenhui and Wei, Furu},
  journal={arXiv preprint arXiv:2307.02486},
  year={2023}
}

Share with Friends

Share LongNet with your friends and colleagues who might find it useful.

Thank you for sharing!


Roadmap

  • Integrate ALiBi and xPos for even further ridiculous length extrapolation

  • Create a multi-modality version with sub-layer norm

  • Integrate QK LayerNorm

  • Integrate one-write-head (multi-query) attention, maybe

  • Recreate in Triton or JAX for an ultra mega speed boost
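
Of these roadmap items, QK LayerNorm is the simplest to sketch. The snippet below is a generic, hypothetical illustration of the technique (normalizing the projected queries and keys before the attention scores, which tends to stabilize training); it is not code from this repository, and the class name is made up for the example.

import torch
import torch.nn as nn

class QKNormAttention(nn.Module):
    # Hypothetical single-head attention with LayerNorm applied to the
    # projected queries and keys before the dot product (QK LayerNorm).
    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.q_norm = nn.LayerNorm(d_model)
        self.k_norm = nn.LayerNorm(d_model)
        self.scale = d_model ** -0.5

    def forward(self, x):
        q = self.q_norm(self.q_proj(x))
        k = self.k_norm(self.k_proj(x))
        v = self.v_proj(x)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        return attn @ v

x = torch.randn(2, 16, 512)
print(QKNormAttention(512)(x).shape)  # torch.Size([2, 16, 512])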

