accelerated-pytorch-transformers-generation's People

Contributors

fxmarty, gante


accelerated-pytorch-transformers-generation's Issues

Cache preallocation in decoding

Motivation

We suspect that concatenating the key/value cache tensors at each generation step is expensive.

FasterTransformer does preallocate the KV cache. Preallocating may also help torch.compile, although I don't quite get why (there are still dynamic shapes in the model itself, so why care about the model I/O?).
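To make the suspected cost concrete, here is a minimal sketch (shapes and names are illustrative, not taken from this repo) contrasting per-step concatenation with writing into a preallocated buffer:

```python
import torch

batch, num_heads, head_dim, max_new_tokens = 1, 32, 128, 512

# Strategy A: concatenate at every step. Each torch.cat allocates a new
# tensor and copies the whole cache, so copy traffic grows quadratically
# with the number of generated tokens.
k_cache = torch.empty(batch, num_heads, 0, head_dim)
for _ in range(max_new_tokens):
    new_k = torch.randn(batch, num_heads, 1, head_dim)
    k_cache = torch.cat([k_cache, new_k], dim=2)

# Strategy B: preallocate once, then write in place. Each step touches
# only the slot for the new token.
k_cache = torch.empty(batch, num_heads, max_new_tokens, head_dim)
for step in range(max_new_tokens):
    new_k = torch.randn(batch, num_heads, 1, head_dim)
    k_cache[:, :, step : step + 1] = new_k
```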

Results

Some optimizations may still be missing. The results below were obtained on commit 1acc7e4ea234d7143700234e51877acc9859f9ed:

(Benchmark result images, not reproduced here.)

Todos

  • The past_key_values buffer should be initialized internally, instead of requiring the user to create it and pass generate(**inputs, past_key_values=past_key_values) (the current flow is sketched after this list).
  • The current implementation is likely to break with accelerate's naive pipeline parallelism, as the buffer is initialized on a single device.
  • The preallocated KV cache still does not help with small models / small batch sizes.
  • Support an iterative buffer, e.g. one that auto-extends every 512 tokens, instead of initializing a buffer of size max_new_tokens. This may help reduce memory usage (and speed? probably not).
  • Implement tests
  • Should this be in optimum or transformers?
  • Support all decoding strategies (if possible), instead of only greedy_search as is currently the case.
  • Have it work with cross-attention
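For reference, the first item in the list above corresponds to a user-facing flow roughly like the sketch below. The (batch, num_heads, seq_len, head_dim) buffer layout is an assumption on my part, and stock transformers would not accept an empty preallocated cache; this relies on the patched generate() in this repo.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")

max_new_tokens = 32
max_len = inputs["input_ids"].shape[1] + max_new_tokens
head_dim = model.config.n_embd // model.config.n_head

# One preallocated (key, value) pair per layer, sized for the full generation.
# The exact buffer format expected by this fork may differ.
past_key_values = tuple(
    (
        torch.empty(1, model.config.n_head, max_len, head_dim),
        torch.empty(1, model.config.n_head, max_len, head_dim),
    )
    for _ in range(model.config.n_layer)
)

# With the patched generate(), the buffers are filled in place:
out = model.generate(**inputs, past_key_values=past_key_values,
                     max_new_tokens=max_new_tokens)
```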

TODOs

  • Experiment with torch.compile (find torchdynamo graph breaks and remove them if possible?)
  • Experiment with an iteratively allocated KV cache (see the suggestion above); this would avoid aten::slice calls. A sketch of the idea follows this list.
  • Can we avoid aten::copy_ calls and aten::slice calls?
  • Test on CPU
  • Support cross-attention
  • Support encoder-decoder architectures
  • Support preallocated attention_mask
  • Support preallocated token_type_ids (llama does not use it though)
  • Would a single qkv_proj help, instead of the current q_proj separate from kv_proj? [yes, it helps slightly]
  • Make sure the implementation is validated not only on the argmax but on the logits directly.
  • Investigate whether no longer inheriting from nn.Module may help; see the __getattr__ overhead.
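As a rough illustration of the iteratively allocated cache mentioned above (a hypothetical helper, not this repo's code): grow the buffer by a fixed chunk only when it fills up, trading one copy per chunk for a lower initial footprint.

```python
import torch

CHUNK = 512  # grow the cache in chunks of 512 token slots

def append_kv(cache, new, valid_len):
    """Write `new` (shape [batch, heads, 1, head_dim]) at position `valid_len`,
    extending `cache` along the sequence dim when it is full."""
    if valid_len == cache.shape[2]:
        extension = cache.new_empty(
            (cache.shape[0], cache.shape[1], CHUNK, cache.shape[3])
        )
        cache = torch.cat([cache, extension], dim=2)  # one copy per CHUNK tokens
    cache[:, :, valid_len : valid_len + 1] = new
    return cache, valid_len + 1

# Usage: start with one chunk rather than the full max_new_tokens budget.
cache = torch.empty(1, 32, CHUNK, 128)
valid_len = 0
for _ in range(1000):
    cache, valid_len = append_kv(cache, torch.randn(1, 32, 1, 128), valid_len)
```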
