
sgconv's Introduction

Yuhong Li*, Tianle Cai*, Yi Zhang, Deming Chen, Debadeepta Dey

Update log

  • 10/17/2022: Core code released.
  • 11/3/2022: Standalone code released. Easier to use if you want to try SGConv on your own model!
  • Upcoming: Full code release.

Overview

Convolutional models have been widely used in multiple domains. However, most existing models only use local convolution, which makes them unable to handle long-range dependencies efficiently. Attention overcomes this problem by aggregating global information, but it also makes the computational complexity quadratic in the sequence length.

Recently, Gu et al. [2021] proposed a model called S4, inspired by the state space model. S4 can be efficiently implemented as a global convolutional model whose kernel size equals the input sequence length. S4 can model much longer sequences than Transformers and achieves significant gains over the SoTA on several long-range tasks. Despite its empirical success, S4 is complex: it requires sophisticated parameterization and initialization schemes. As a result, S4 is less intuitive and hard to use.
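
A global convolution whose kernel is as long as the input is normally computed with the FFT rather than directly, which brings the cost down from O(L^2) to O(L log L) per channel. The sketch below shows that generic technique for a causal, depthwise kernel; the function name `fft_causal_conv` and the (batch, channels, L) layout are illustrative assumptions, not code from S4 or SGConv.

```python
import torch


def fft_causal_conv(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Causal depthwise convolution y[..., t] = sum_{s <= t} k[..., s] * x[..., t - s].

    x: (batch, channels, L); k: (channels, L), i.e. the kernel is as long as the input.
    Runs in O(L log L) per channel instead of O(L^2).
    """
    L = x.shape[-1]
    n = 2 * L  # zero-pad to 2L so circular convolution equals linear convolution
    X = torch.fft.rfft(x, n=n)
    K = torch.fft.rfft(k, n=n)  # (channels, n // 2 + 1), broadcasts over the batch dim
    y = torch.fft.irfft(X * K, n=n)
    return y[..., :L]  # keep only the causal part
```

For example, `fft_causal_conv(torch.randn(1, 256, 4096), torch.randn(256, 4096))` returns a tensor of shape (1, 256, 4096).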

Here we aim to demystify S4 and extract basic principles that contribute to the success of S4 as a global convolutional model. We focus on the structure of the convolution kernel and identify two critical but intuitive principles enjoyed by S4 that are sufficient to make up an effective global convolutional model:

  1. The parameterization of the convolutional kernel needs to be efficient, in the sense that the number of parameters should scale sub-linearly with the sequence length.
  2. The kernel needs to have a decaying structure: the weights for convolving with closer neighbors are larger than those for more distant ones.
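
The two principles can be illustrated with a toy kernel builder. This is a simplified sketch under stated assumptions, not the implementation in gconv_standalone.py: it assumes a few short sub-kernels of length `kernel_dim`, stretches the i-th one by a factor of 2**max(i-1, 0) with linear interpolation, multiplies it by `decay**i`, and concatenates the pieces. The names `num_scales` and `multiscale_kernel` are hypothetical. The parameter count grows with the number of scales, i.e. logarithmically in `l_max` (principle 1), and with `decay < 1` the weights on more distant taps shrink (principle 2).

```python
import math

import torch
import torch.nn.functional as F


def num_scales(l_max: int, kernel_dim: int) -> int:
    # Scales have lengths d, d, 2d, 4d, ..., so n scales cover d * 2**(n - 1) taps.
    return 1 + math.ceil(math.log2(l_max / kernel_dim))


def multiscale_kernel(sub_kernels: torch.Tensor, decay: float, l_max: int) -> torch.Tensor:
    """Build a (channels, l_max) global kernel from short sub-kernels.

    sub_kernels: (n_scales, channels, kernel_dim), the only learnable numbers.
    """
    n_scales, channels, kernel_dim = sub_kernels.shape
    pieces = []
    for i in range(n_scales):
        factor = 2 ** max(i - 1, 0)
        # Stretch scale i to kernel_dim * factor taps: longer range, same parameter count.
        stretched = F.interpolate(
            sub_kernels[i].unsqueeze(0),  # (1, channels, kernel_dim)
            size=kernel_dim * factor,
            mode="linear",
            align_corners=False,
        ).squeeze(0)
        # Geometric decay: pieces covering more distant taps get smaller weights.
        pieces.append(stretched * decay ** i)
    return torch.cat(pieces, dim=-1)[..., :l_max]
```

For `l_max=1024` and `kernel_dim=32` this gives 6 scales, so each channel's 1024-tap kernel is described by 6 × 32 = 192 numbers.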

Based on the two principles, we propose a simple yet effective convolutional model called Structured Global Convolution (SGConv).


SGConv exhibits strong empirical performance over several tasks:

  1. SGConv surpasses S4 on the Long Range Arena and Speech Commands datasets while running faster.
  2. When plugging SGConv into standard language and vision models, it shows the potential to improve both efficiency and performance.

Code

The code is based on the excellent codebase by HazyResearch. Please refer to that repo to install the dependencies.

Standalone

In gconv_standalone.py, we provide a standalone implementation of SGConv that you can use as a drop-in replacement in your existing models. An example of how to use it is shown in test.ipynb, where we ran it on a sequence of 1M tokens; this cost ~20GB of GPU memory per layer.

import torch
from gconv_standalone import GConv

layer = GConv(
    d_model=256,
    d_state=64,
    l_max=1_000_000,
    bidirectional=True,
    kernel_dim=32,
    n_scales=None,
    decay_min=2,
    decay_max=2,
)

x = torch.randn(1, 256, 1_000_000)
y, k = layer(x, return_kernel=True)

Citation

@misc{li2022makes,
      title={What Makes Convolutional Models Great on Long Sequence Modeling?}, 
      author={Yuhong Li and Tianle Cai and Yi Zhang and Deming Chen and Debadeepta Dey},
      year={2022},
      eprint={2210.09298},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

sgconv's People

Contributors

ctlllll


sgconv's Issues

Complex Tensors

Greetings,

May I request a version of the code that performs the convolution without the FFT? The main reason is that torch DDP does not support complex tensors well (see pytorch/pytorch#80080), which prevents using the model for distributed training (related to issue #3).

Any updates, pointers, or possible workarounds would be greatly appreciated!
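
For reference, a long causal convolution can be written without the FFT, and therefore without complex intermediates, using a depthwise `conv1d`. This is a minimal sketch assuming a kernel `k` of shape (channels, L) and input of shape (batch, channels, L); `causal_conv_direct` is a hypothetical helper. Because it does O(L^2) work, it is only practical for moderate sequence lengths.

```python
import torch
import torch.nn.functional as F


def causal_conv_direct(x: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Depthwise causal convolution y[..., t] = sum_{s <= t} k[..., s] * x[..., t - s].

    x: (batch, channels, L); k: (channels, L). Real-valued throughout (no FFT).
    """
    channels, L = k.shape
    # conv1d computes cross-correlation, so flip the kernel to get a true convolution.
    weight = k.flip(-1).unsqueeze(1)  # (channels, 1, L): one filter per channel
    # Left-pad by L - 1 so output position t only sees inputs 0..t (causality).
    return F.conv1d(F.pad(x, (L - 1, 0)), weight, groups=channels)
```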

Update:

I found the cause. It is not related to complex tensors, since all SGConv parameters are either torch.float32 or torch.int64. The issue is related to self.kernel_norm_initialized, which is registered as a torch.bool buffer. Although the NCCL backend supports torch.bool (pytorch/pytorch#41959), this buffer appears to have caused the failure. Changing its dtype from torch.bool to torch.float32 resolved the problem.
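
The workaround described in the update can be sketched as follows. The module below is a hypothetical stand-in, not the actual GConv class: it keeps the one-time-initialization flag as a float32 buffer holding 0.0 or 1.0 instead of a torch.bool tensor, so every buffer DDP broadcasts is a floating-point tensor.

```python
import torch
import torch.nn as nn


class FlagAsFloat(nn.Module):
    """Hypothetical module with a one-time-initialization flag stored as float32."""

    def __init__(self, d_model: int = 4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_model))
        # Instead of torch.tensor(False) (dtype torch.bool), store 0.0 / 1.0 in float32.
        self.register_buffer("kernel_norm_initialized", torch.tensor(0.0))
        self.register_buffer("kernel_norm", torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.kernel_norm_initialized.item() == 0.0:
            # Compute the normalizer once, then flip the flag in place.
            self.kernel_norm.copy_(self.weight.detach().abs() + 1e-6)
            self.kernel_norm_initialized.fill_(1.0)
        return x * self.weight / self.kernel_norm
```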

Can't get it to run on multiple GPUs

Here is my code:

import os
import time
import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

import argparse
from tensorboardX import SummaryWriter

from gconv_standalone import GConv

gpu_devices = '0,1,2,3'
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices


device = 'cuda' if torch.cuda.is_available() else 'cpu'

net = GConv(
    d_model=256,
    d_state=64,
    l_max=1_000_000,
    bidirectional=True,
    kernel_dim=32,
    n_scales=None,
    decay_min=2,
    decay_max=2,
)

net = nn.DataParallel(net)
net = net.to(device)
num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print('The number of parameters of model is', num_params)
                
x = torch.randn(1, 256, 1_000_000)
x = x.to(device)

y, k = net(x, return_kernel=True)

And here is the error I am getting:

IndexError: Caught IndexError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/SageMaker/SGConv/gconv_standalone.py", line 416, in forward
    self.kernel_list[i],
  File "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/container.py", line 462, in __getitem__
    idx = self._get_abs_string_index(idx)
  File "/home/ec2-user/anaconda3/envs/pytorch_p38/lib/python3.8/site-packages/torch/nn/modules/container.py", line 445, in _get_abs_string_index
    raise IndexError('index {} is out of range'.format(idx))
IndexError: index 0 is out of range
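
The traceback shows the replica on device 0 seeing an empty `kernel_list`. One plausible cause (an assumption, not confirmed in the thread) is that `kernel_list` is an `nn.ParameterList`: several older PyTorch releases warn that ParameterList is not supported by `nn.DataParallel` and appears empty in the replicas. A commonly used alternative is `DistributedDataParallel`, which keeps one full copy of the model per process instead of replicating it on every forward pass. The sketch below uses a hypothetical toy module and a single CPU process with the gloo backend so that it runs anywhere; on real hardware you would typically launch one process per GPU (e.g. with torchrun), use the nccl backend, and pass `device_ids=[local_rank]`.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyKernelList(nn.Module):
    """Hypothetical stand-in for a layer that stores its kernels in an nn.ParameterList."""

    def __init__(self, d_model: int = 8, n_kernels: int = 3):
        super().__init__()
        self.kernel_list = nn.ParameterList(
            [nn.Parameter(torch.randn(d_model)) for _ in range(n_kernels)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, d_model, length); index the list as the traceback does.
        out = torch.zeros_like(x)
        for i in range(len(self.kernel_list)):
            out = out + x * self.kernel_list[i][:, None]
        return out


def run_single_process_ddp() -> torch.Tensor:
    # One process on CPU with the gloo backend, so this runs without any GPU.
    store = os.path.join(tempfile.mkdtemp(), "ddp_store")
    dist.init_process_group("gloo", init_method=f"file://{store}", rank=0, world_size=1)
    try:
        model = DDP(ToyKernelList())  # each process holds a full copy of the model
        y = model(torch.randn(2, 8, 16))
        y.sum().backward()  # DDP all-reduces gradients across processes here
        return y.detach()
    finally:
        dist.destroy_process_group()
```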

Seeking details of the final SGConv model used for LRA results

Hi,

Thanks a lot for this wonderful work and for sharing it with the community.

I want to reproduce SGConv's results on LRA and have the following questions regarding that. Please help me with them.

  1. Could you please direct me to some references that you used for the final evaluation code (data processing, the evaluation metric computation, etc) so that I can replicate the complete setup?

  2. What values of the different hyperparameters were used for the LRA tasks?
    (a) What values were passed to the __init__ method? Some are in the Appendix of the paper and you mention some values in this notebook, but I am not sure whether these were the ones used for the LRA tasks. Could you please provide task-specific values of these hyperparameters for LRA?
    (b) How many GConv layers are there in the final model?
    (c) The __init__ method's parameter list defines mode as a string with the single value "cat_randn" (line 277). Is this the value used for all experiments?

Thanks!

Have you tried even longer sequences? Like billions of tokens?

Hello. I ended up giving a presentation on this paper because I found it so fascinating. I know the model performs very well on long-range tasks in the 200k range, but have you tried it on even longer-range tasks, for example with dependencies billions of tokens away?
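
For a sense of scale, the sketch below counts the bytes of a single float32 activation of shape (batch, d_model, L), using d_model=256 from the README example. It deliberately ignores everything else a layer holds in memory (kernels, intermediate results, activations saved for backpropagation), so it is a lower bound rather than an estimate of actual usage; `activation_bytes` is a hypothetical helper.

```python
def activation_bytes(batch: int, d_model: int, seq_len: int, bytes_per_elem: int = 4) -> int:
    """Size of one dense (batch, d_model, seq_len) tensor; float32 uses 4 bytes per element."""
    return batch * d_model * seq_len * bytes_per_elem


for seq_len in (1_000_000, 1_000_000_000):
    gb = activation_bytes(1, 256, seq_len) / 1e9
    print(f"L = {seq_len:,}: {gb:,.3f} GB per tensor")
```

At 1M tokens a single such tensor is about 1 GB, compared with the ~20GB per layer reported for the standalone example; at 1B tokens one tensor alone would need about 1 TB.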

Issue finding where GConv was implemented in the repo

Hello,

I might just not be looking in the right place, but has code been released (in the preview branch) that applies the SGConv layer to the tasks described in the paper? I cannot seem to find it.

Thanks!

2d filters

Hey, nice work!

It seems that for image tasks you flatten the features to 1D and then apply the filter. Would it be possible to create 2D filters using the same idea? Did you try that?

Reproduce results of LRA benchmark

Hi @ctlllll

How do I reproduce the results of the LRA benchmark presented in the paper? The GitHub repo only contains code for the SGConv block; please share the full network architecture and the training code needed to reproduce the results.

I have already tried the state-space-model code with the S4 block replaced by the SGConv block, but I am having a hard time reproducing the results.

Missing the relevant code for experiment reproduction

Hello!

Great work on the GConv modules.

Previously, you and your team said that the rest of the codebase would be released upon acceptance. However, unless I am mistaken, the code to replicate the experiments outlined in the paper has still not been released.

Would it be possible to release that code? It would provide very valuable information.

Thank you!

Questions about gconv.py

Hello. I read the code and have two questions.

  1. It seems that d_state is not used in the code? I am curious about this parameter.
  2. Why is multiplier between 1 and 4 by default? I think it needs to be smaller than 1, like 1/2 in the paper.

checkpoint loading issue

Greetings,

I ran into some trouble loading an SGConv network from checkpoint. Particularly, I encountered the following problem:

size mismatch for kernel_norm: copying a param with shape torch.Size([2, 256, 1]) from checkpoint, the shape in current model is torch.Size([256, 1]).

A few things I have noticed:

  • The above error can be replicated by running the following code in a Jupyter notebook.
  • In the code below, the error goes away if I forward the layer before loading.

Any pointers on what caused this error and how it can be solved are greatly welcome. Thanks!


import torch
from gconv_standalone import GConv

layer = GConv(
    d_model=256,
    d_state=64,
    l_max=1_000_000,
    bidirectional=True,
    kernel_dim=32,
    n_scales=None,
    decay_min=2,
    decay_max=2,
)

x = torch.randn(1, 256, 1000)
x = x.cuda()
layer.cuda()
y, k = layer(x, return_kernel=True)

path = './dummy_ckpt'
torch.save({
    'state_dict': layer.state_dict()
}, path)

shell = GConv(
    d_model=256,
    d_state=64,
    l_max=1_000_000,
    bidirectional=True,
    kernel_dim=32,
    n_scales=None,
    decay_min=2,
    decay_max=2,
).cuda()

ckpt = torch.load(path)
shell.load_state_dict(ckpt['state_dict'])
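
The error message, together with the observation in the issue that a forward pass before loading makes it go away, suggests that kernel_norm changes shape on the first forward call. The sketch below reproduces that pattern with a hypothetical toy module (`LazyNormToy`, not the real GConv) and shows one workaround: resize the buffer in the freshly constructed model to the checkpoint's shape before calling `load_state_dict`. Running a dummy forward pass before loading is the other option.

```python
import torch
import torch.nn as nn


class LazyNormToy(nn.Module):
    """Hypothetical stand-in: kernel_norm only gets its final shape on the first forward."""

    def __init__(self, d_model: int = 4, n_dirs: int = 2):
        super().__init__()
        self.kernel = nn.Parameter(torch.randn(n_dirs, d_model, 1))
        # Placeholder of shape (d_model, 1), like the "current model" shape in the error.
        self.register_buffer("kernel_norm", torch.ones(d_model, 1))
        self.register_buffer("kernel_norm_initialized", torch.tensor(0.0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.kernel_norm_initialized.item() == 0.0:
            # First call replaces the buffer with an (n_dirs, d_model, 1) tensor.
            self.kernel_norm = self.kernel.detach().abs() + 1.0
            self.kernel_norm_initialized.fill_(1.0)
        return x * (self.kernel / self.kernel_norm).sum(0)


def load_into_fresh(state: dict) -> LazyNormToy:
    shell = LazyNormToy()
    # Workaround: give the buffer the checkpoint's shape before loading.
    shell.kernel_norm = torch.empty_like(state["kernel_norm"])
    shell.load_state_dict(state)
    return shell
```

Assigning a tensor to an attribute that is already a registered buffer replaces that buffer, so the `state_dict` keys stay the same and `load_state_dict` then copies the checkpoint values in.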
