huyanxin / sru

This project forked from asappresearch/sru

Training RNNs as Fast as CNNs (https://arxiv.org/abs/1709.02755)

License: MIT License

Shell 20.79% Python 79.21%

sru's Introduction

About

SRU is a recurrent unit that can run over 10 times faster than cuDNN LSTM, without loss of accuracy on the many tasks tested.


Average processing time of LSTM, conv2d and SRU, tested on GTX 1070

For example, the figure above presents the processing time of a single mini-batch of 32 samples. SRU achieves a 10x to 16x speed-up over LSTM and runs as fast as (or faster than) word-level convolution using conv2d.

Reference:

Training RNNs as Fast as CNNs

@article{lei2017sru,
  title={Training RNNs as Fast as CNNs},
  author={Lei, Tao and Zhang, Yu},
  journal={arXiv preprint arXiv:1709.02755},
  year={2017}
}

Requirements

Install requirements via pip install -r requirements.txt. CuPy and pynvrtc are needed to compile the CUDA code into a callable function at runtime. Only single-GPU training is supported.
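
Before training, it can help to confirm that the runtime dependencies are importable. The sketch below is one minimal way to do that. The helper name `missing_modules` is hypothetical, and the import names `torch`, `cupy` and `pynvrtc` are assumptions based on the package names mentioned above; requirements.txt remains the authoritative list.

```python
import importlib.util


def missing_modules(names):
    """Return the names from `names` that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed top-level import names for the packages mentioned above.
    missing = missing_modules(["torch", "cupy", "pynvrtc"])
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules were found.")
```

This only checks that the modules can be located; it does not verify that a GPU is present or that the CUDA kernels compile.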


Examples

The usage of SRU is similar to nn.LSTM. SRU likely requires more stacked layers than LSTM. We recommend starting with 2 layers and using more if necessary (see our report for more experimental details).

import torch
from torch.autograd import Variable
from cuda_functional import SRU, SRUCell

# input has length 20, batch size 32 and dimension 128
x = Variable(torch.FloatTensor(20, 32, 128).cuda())

input_size, hidden_size = 128, 128

rnn = SRU(input_size, hidden_size,
    num_layers = 2,          # number of stacking RNN layers
    dropout = 0.0,           # dropout applied between RNN layers
    rnn_dropout = 0.0,       # variational dropout applied on linear transformation
    use_tanh = 1,            # use tanh?
    use_relu = 0,            # use ReLU?
    bidirectional = False    # bidirectional RNN ?
)
rnn.cuda()

output, hidden = rnn(x)      # forward pass

# output is (length, batch size, hidden size * number of directions)
# hidden is (layers, batch size, hidden size * number of directions)
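
The shape comments above can be turned into a small bookkeeping helper, which may be handy when sizing the layers that consume the SRU output. This is only a sketch of the shape arithmetic stated in those comments: the function name `sru_output_shapes` is hypothetical, and the code does not import or run SRU itself.

```python
def sru_output_shapes(length, batch_size, hidden_size,
                      num_layers=2, bidirectional=False):
    """Return (output_shape, hidden_shape) following the shape comments above."""
    # Assumption: a bidirectional model has two directions, otherwise one.
    num_directions = 2 if bidirectional else 1
    output_shape = (length, batch_size, hidden_size * num_directions)
    hidden_shape = (num_layers, batch_size, hidden_size * num_directions)
    return output_shape, hidden_shape


# The example above: length 20, batch size 32, hidden size 128, 2 layers.
print(sru_output_shapes(20, 32, 128))
# → ((20, 32, 128), (2, 32, 128))

# The same settings with bidirectional=True double the last dimension.
print(sru_output_shapes(20, 32, 128, bidirectional=True))
# → ((20, 32, 256), (2, 32, 256))
```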

Make sure cuda_functional.py and the shared library cuda/lib64 can be found by the system, e.g.

export LD_LIBRARY_PATH=/usr/local/cuda/lib64
export PYTHONPATH=path_to_repo/sru

Instead of using PYTHONPATH, the SRU module can now be installed as a regular package via python setup.py install or pip install. See this PR.
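
Whichever route is used, a short check can confirm that the environment is set up as described. The sketch below assumes the CUDA library directory from the export line above and the module name `cuda_functional` from the example; the helper names `dir_on_search_path` and `module_is_importable` are hypothetical.

```python
import importlib.util
import os


def dir_on_search_path(directory, search_path):
    """Check whether `directory` is one of the entries in a PATH-style string."""
    wanted = os.path.normpath(directory)
    entries = [entry for entry in search_path.split(os.pathsep) if entry]
    return any(os.path.normpath(entry) == wanted for entry in entries)


def module_is_importable(name):
    """Return True if a top-level module called `name` can be located."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    cuda_lib = "/usr/local/cuda/lib64"  # path taken from the export line above
    ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    print("CUDA lib64 on LD_LIBRARY_PATH:", dir_on_search_path(cuda_lib, ld_path))
    print("cuda_functional importable:", module_is_importable("cuda_functional"))
```

If the CUDA toolkit is installed elsewhere, adjust `cuda_lib` to match the local installation.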



Contributors

https://github.com/taolei87/sru/graphs/contributors

To-do

  • ReLU activation
  • support multi-GPU (context change)
  • Layer normalization + residual to compare with highway connection (current version)

