
appdays / warp-rnnt


This project forked from 1ytic/warp-rnnt


CUDA-Warp RNN-Transducer

License: MIT License

C++ 16.87% Python 43.43% C 2.89% Cuda 35.65% CMake 1.15%

warp-rnnt's Introduction


CUDA-Warp RNN-Transducer

A GPU implementation of RNN Transducer (Graves 2012, 2013). This code is ported from the reference implementation (by Awni Hannun) and fully utilizes the CUDA warp mechanism.

The main bottleneck in the loss is the forward/backward pass, which is based on a dynamic programming algorithm. In particular, a nested loop populates a lattice of shape (T, U), and each value in this lattice depends on two previously computed cells, one along each dimension (e.g. in the forward pass).
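The recurrence described above can be written as a short pure-Python reference. This is a minimal sketch for illustration, not the code used in this repository: the function names `log_add` and `rnnt_forward_loss` are hypothetical, the inputs are assumed to be log-probabilities, and U here counts lattice columns (number of labels plus one).

```python
import math

NEG_INF = float("-inf")


def log_add(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == NEG_INF:
        return b
    if b == NEG_INF:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def rnnt_forward_loss(blank, label):
    """Negative log-likelihood from the forward (alpha) lattice.

    blank[t][u]: log-prob of emitting blank at lattice node (t, u)
    label[t][u]: log-prob of emitting the next label at node (t, u)
    Both are T x U nested lists (assumed layout for this sketch).
    """
    T, U = len(blank), len(blank[0])
    alpha = [[NEG_INF] * U for _ in range(T)]
    alpha[0][0] = 0.0
    for t in range(T):          # nested loop over the (T, U) lattice
        for u in range(U):
            if t > 0:           # arrive from (t-1, u) by emitting blank
                alpha[t][u] = log_add(alpha[t][u], alpha[t - 1][u] + blank[t - 1][u])
            if u > 0:           # arrive from (t, u-1) by emitting a label
                alpha[t][u] = log_add(alpha[t][u], alpha[t][u - 1] + label[t][u - 1])
    # The final blank emission closes every alignment path.
    return -(alpha[T - 1][U - 1] + blank[T - 1][U - 1])
```

As a sanity check, on a 2 x 2 lattice where every transition has probability 0.5 there are two paths of three emissions each, so the total probability is 2 * 0.125 = 0.25 and the loss is log 4.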

CUDA executes threads in groups of 32 parallel threads called warps. Full efficiency is realized when all 32 threads of a warp agree on their execution path, and this is exactly the property used to optimize the RNN Transducer. The lattice is split into warps along the T dimension. Within each warp, threads exchange variables using fast operations. As soon as the current warp fills its last value, the next two warps, (t+32, u) and (t, u+1), start running. A schematic procedure for the forward pass is shown in the figure below, where T is the number of frames, U is the number of labels, and W is the warp size. A similar procedure for the backward pass runs in parallel.
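The dependency pattern between warps can be illustrated with a small host-side simulation. This is a simplified model under stated assumptions, not the scheduler used by the CUDA kernels: `warp_waves` is a hypothetical helper, each tile (p, u) stands for the warp covering frames p*W to p*W + W - 1 in lattice column u, and a tile is treated as runnable once its predecessors (p-1, u) and (p, u-1) have finished.

```python
from math import ceil


def warp_waves(T, U, W=32):
    """Group warp tiles (p, u) into waves that could run concurrently."""
    P = ceil(T / W)  # number of warps along the T dimension
    done = set()
    pending = {(p, u) for p in range(P) for u in range(U)}
    waves = []
    while pending:
        # A tile is ready when both of its predecessors have completed.
        ready = sorted(
            (p, u)
            for (p, u) in pending
            if (p == 0 or (p - 1, u) in done) and (u == 0 or (p, u - 1) in done)
        )
        waves.append(ready)
        done.update(ready)
        pending.difference_update(ready)
    return waves


if __name__ == "__main__":
    for step, tiles in enumerate(warp_waves(T=64, U=3)):
        print(step, tiles)
```

In this model the tiles on each anti-diagonal (p + u constant) are independent, so the lattice is covered in ceil(T / W) + U - 1 waves rather than one step per cell.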

Performance

The NVIDIA Profiler shows the advantage of the warp implementation over the non-warp implementation.

This warp implementation:

Non-warp implementation warp-transducer:

Unfortunately, in practice this advantage disappears because the memory operations take much longer, especially if you synchronize memory on each iteration.

| | warp_rnnt (gather=False) | warp_rnnt (gather=True) | warprnnt_pytorch | transducer (CPU) |
| --- | --- | --- | --- | --- |
| T=150, U=40, V=28 | | | | |
| N=1 | 0.50 ms | 0.54 ms | 0.63 ms | 1.28 ms |
| N=16 | 1.79 ms | 1.72 ms | 1.85 ms | 6.15 ms |
| N=32 | 3.09 ms | 2.94 ms | 2.97 ms | 12.72 ms |
| N=64 | 5.83 ms | 5.54 ms | 5.23 ms | 23.73 ms |
| N=128 | 11.30 ms | 10.74 ms | 9.99 ms | 47.93 ms |
| T=150, U=20, V=5000 | | | | |
| N=1 | 0.95 ms | 0.80 ms | 1.74 ms | 21.18 ms |
| N=16 | 8.74 ms | 6.24 ms | 16.20 ms | 240.11 ms |
| N=32 | 17.26 ms | 12.35 ms | 31.64 ms | 490.66 ms |
| N=64 | out-of-memory | out-of-memory | out-of-memory | 944.73 ms |
| N=128 | out-of-memory | out-of-memory | out-of-memory | 1894.93 ms |
| T=1500, U=300, V=50 | | | | |
| N=1 | 5.89 ms | 4.99 ms | 10.02 ms | 121.82 ms |
| N=16 | 95.46 ms | 78.88 ms | 76.66 ms | 732.50 ms |
| N=32 | out-of-memory | 157.86 ms | 165.38 ms | 1448.54 ms |
| N=64 | out-of-memory | out-of-memory | out-of-memory | 2767.59 ms |

Benchmarked on a GeForce RTX 2070 Super GPU, Intel i7-10875H CPU @ 2.30GHz.
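A rough size estimate helps put the out-of-memory entries in context. The sketch below assumes float32 values and a tensor whose shape is exactly (N, T, U, V) with the values from the table. The helper `tensor_bytes` is hypothetical, and the real memory footprint of the benchmark (gradients, the alphas/betas arrays, framework overhead) is not modeled here.

```python
def tensor_bytes(*shape, bytes_per_element=4):
    """Size in bytes of a dense tensor with the given shape (float32 by default)."""
    size = bytes_per_element
    for dim in shape:
        size *= dim
    return size


# T=1500, U=300, V=50 at N=32, taken from the benchmark table.
full = tensor_bytes(32, 1500, 300, 50)     # full log_probs, shape (N, T, U, V)
gathered = tensor_bytes(32, 1500, 300, 2)  # blank and label values only, (N, T, U, 2)

print(f"full:     {full / 1e9:.2f} GB")
print(f"gathered: {gathered / 1e9:.2f} GB")
print(f"ratio:    {full // gathered}x")
```

Under these assumptions the gathered layout is V / 2 times smaller than the full tensor. That is in line with the row where gather=True still runs at N=32 while gather=False does not, although this estimate alone does not establish the cause.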

Note

  • This implementation assumes that the input has already been normalized with log_softmax.

  • In addition to the alphas/betas arrays, a counts array with shape (N, U * 2) is allocated and used as a scheduling mechanism.

  • core_gather.cu is a memory-efficient version that expects log_probs with shape (N, T, U, 2), containing only the blank and label values. It shows excellent performance with a large vocabulary.

  • Do not expect this implementation to greatly reduce the training time of an RNN Transducer model. The main bottleneck will probably be the trainable joint network with an output of shape (N, T, U, V).

  • There is also a restricted version, called Recurrent Neural Aligner, which assumes that the input sequence is at least as long as the target sequence.
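The first and third notes can be illustrated for a single utterance in pure Python. This is a sketch under stated assumptions rather than the binding's own code: `log_softmax` and `gather_blank_and_labels` are hypothetical helpers, the batch dimension N is omitted, U counts lattice columns (number of labels plus one), and the label slot of the last column, which has no next label, is filled with the blank value as a placeholder.

```python
import math


def log_softmax(logits):
    """Normalize a list of logits into log-probabilities."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def gather_blank_and_labels(log_probs, labels, blank=0):
    """Reduce log_probs of shape [T][U][V] to [T][U][2].

    Slot 0 holds the blank log-prob and slot 1 the log-prob of the
    next label for that lattice column.
    """
    T, U = len(log_probs), len(log_probs[0])
    assert len(labels) == U - 1, "expected one label per column except the last"
    gathered = []
    for t in range(T):
        row = []
        for u in range(U):
            # Placeholder for the last column: reuse the blank index (assumption).
            label = labels[u] if u < U - 1 else blank
            row.append((log_probs[t][u][blank], log_probs[t][u][label]))
        gathered.append(row)
    return gathered


# T=3 frames, U=2 lattice columns (one label), V=3 vocabulary entries.
logits = [[[0.0, 1.0, 2.0] for _ in range(2)] for _ in range(3)]
log_probs = [[log_softmax(v) for v in row] for row in logits]
gathered = gather_blank_and_labels(log_probs, labels=[2])
```

Only two values per lattice node survive the gather, which is why the memory cost no longer grows with the vocabulary size V.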

Install

There are two bindings for the core algorithm:

Reference

warp-rnnt's People

Contributors

1ytic, teapoly, iceychris, maxwellzh, appdays
