
Parallel Prompt Decoding: Accelerate LLM Inference with Parallel Prompting

PPD (Parallel Prompt Decoding) is a cost-efficient method to accelerate LLMs with trained appended prompt tokens. Our technique stands out for three key features:

  • Orthogonal Optimization: PPD is orthogonal to speculative decoding, offering the potential for synergistic integration.
  • Memory Efficiency: With a minimal runtime memory overhead of just 0.0004%, PPD is highly suitable for edge and mobile settings.
  • Training Efficiency: The training process is efficient, requiring only 16 hours on a single A100-40GB GPU.

PPD on Vicuna-7b.

The key intuition of PPD lies in the observation that if trained properly, prompt tokens appended to the input can approximate tokens generated at future timesteps, thereby partially recovering the missing conditional dependency information for multi-token generation.


Inspired by the human natural language generation process where continuous words like common expressions and phrases are produced simultaneously, PPD introduces the use of prompt tokens for multi-token prediction. Specifically, these trained prompt tokens are appended to the original input sequence, enabling the concurrent generation of multiple output tokens in a single forward pass.
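
The idea in code form: a minimal sketch, assuming a Hugging Face-style decoder that accepts inputs_embeds (the function and variable names are illustrative, not the repository's actual API). The trained prompt embeddings are appended to the input embeddings, and the logits at the appended positions are read out as guesses for future tokens.

import torch

@torch.no_grad()
def guess_future_tokens(model, input_ids, prompt_embeds):
    """Illustrative sketch only.
    model:         decoder-only LM accepting `inputs_embeds` (e.g. transformers)
    input_ids:     (1, seq) tokens generated so far
    prompt_embeds: (k, hidden) trained prompt-token embeddings
    """
    inputs_embeds = model.get_input_embeddings()(input_ids)   # (1, seq, hidden)
    k = prompt_embeds.size(0)
    # Append the trained prompt embeddings after the current sequence.
    extended = torch.cat([inputs_embeds, prompt_embeds.unsqueeze(0)], dim=1)
    logits = model(inputs_embeds=extended).logits              # (1, seq+k, vocab)
    # The last real position predicts the next token as usual; the k prompt
    # positions approximate the k tokens after that, all in one forward pass.
    next_token = logits[0, -k - 1].argmax(-1)
    future_guesses = logits[0, -k:].argmax(-1)
    return next_token, future_guesses

In the full method these guesses are verified rather than accepted blindly (see Inference below).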



PPD is designed to address the following challenges faced by the current speculative decoding methods:

  • Low step compression rate due to conditional independence assumption.
  • High complexity and cost of training/maintaining draft models.
  • Limited applicability in edge and mobile environments due to memory overhead.

Through extensive experiments across LLMs ranging from MobileLlama to Vicuna-13B on a wide range of benchmarks, our approach demonstrates up to a 2.49× speedup while maintaining a minimal runtime memory overhead of just 0.0004%.


Evaluated on a single A100 with MT-Bench.

Our paper is available at arXiv:2405.18628. If you find it helpful, please cite us:

@article{hao2024ppd,
  title={Hardware-aware parallel prompt decoding for memory-efficient acceleration of LLM inference},
  author={Chen, Hao (Mark) and Luk, Wayne and Yiu, Ka Fai Cedric and Li, Rui and Mishchenko, Konstantin and Venieris, Stylianos I and Fan, Hongxiang},
  journal={arXiv preprint arXiv:2405.18628},
  year={2024}
}


Installation

git clone https://github.com/hmarkc/prompt-decoding.git
cd prompt-decoding
pip install -e .

Model Weights

Original Model                PPD embedding weights
lmsys/vicuna-7b-v1.3          hmarkc/ppd-vicuna-7b-v1.3
lmsys/vicuna-13b-v1.3         hmarkc/ppd-vicuna-13b-v1.3
mtgv/MobileLLaMA-1.4B-Chat    hmarkc/ppd-MobileLLaMA-1.4B-Chat

Dataset Generation

Truncated Dataset

Given a dataset, random truncation is first applied to reduce the contextual bias in the training of the special tokens. A distillation dataset is then generated from the truncated dataset.
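
As a rough illustration of the truncation step (the function below is a sketch, not the repository's implementation), each tokenized training sample is cut at a random position so that the special tokens see prefixes of varying length instead of only complete conversations:

import random

def randomly_truncate(token_ids, min_len=16):
    """Illustrative only: keep a random-length prefix of one tokenized sample."""
    if len(token_ids) <= min_len:
        return token_ids
    cut = random.randint(min_len, len(token_ids))
    return token_ids[:cut]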

The truncated datasets need to be generated first. Here is how a dataset for 3 special tokens can be generated for the ShareGPT dataset.

python generate_dataset.py --dataset_type finetune --num_special_tokens 3 --data_path ShareGPT_V4.3_unfiltered_cleaned_split.json --model_max_length 2048

Distillation Dataset

Then, we can generate the distillation dataset from the truncated dataset. --data_path is the path to the previously generated truncated dataset, and --model_name_or_path is the model whose output distribution we want to distill.

python generate_dataset.py --dataset_type distillation  --num_special_tokens 3 --data_path ShareGPT_training_dataset_3_finetune_2048.pt --model_max_length 2048 --model_name_or_path lmsys/vicuna-7b-v1.3
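
Conceptually, the distillation dataset pairs each truncated prefix with the teacher model's next-token distributions, which the prompt tokens are later trained to imitate. A rough sketch of this pairing, with illustrative names and data layout (the repository's actual format may differ):

import torch

@torch.no_grad()
def build_distillation_example(teacher, input_ids, num_special_tokens=3):
    """Illustrative only: record the teacher's distributions for the next few steps."""
    targets = []
    ids = input_ids.clone()
    for _ in range(num_special_tokens + 1):
        logits = teacher(input_ids=ids).logits[:, -1, :]
        targets.append(torch.log_softmax(logits, dim=-1))          # teacher soft labels
        ids = torch.cat([ids, logits.argmax(-1, keepdim=True)], dim=-1)
    return {"input_ids": input_ids, "teacher_log_probs": torch.stack(targets, dim=1)}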

Training Special Tokens

Example script to train Vicuna-7b with a distillation dataset named "ShareGPT_training_dataset_2_distillation.pt":

accelerate launch --num_processes 4 prompt/train/train.py --model_name_or_path lmsys/vicuna-7b-v1.3 \
    --dataset_path "./ShareGPT_training_dataset_2_distillation.pt" \
    --output_dir test/ \
    --num_train_epochs 1 \
    --save_steps 500 \
    --model_max_length 2048 \
    --num_special_tokens 3 \
    --virtual_tokens_per_special_token 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy "no" \
    --learning_rate 1e-2 \
    --weight_decay 0.0 \
    --warmup_ratio 0.0 \
    --lr_scheduler_type "cosine" \
    --logging_steps 10 \
    --load_in_4bit \
    --vt_attention_type "ensemble" \
    --trainer_type "distillation_trainer"

Change --dataset_path to the location of the distillation dataset, and set --trainer_type to "distillation_trainer" to train with knowledge distillation. --num_special_tokens specifies the number of special tokens to train. --virtual_tokens_per_special_token is the number of virtual tokens used for each special token; set it to 1 for the lowest latency.
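
Because only the prompt embeddings are trained, the added footprint is easy to bound with back-of-the-envelope arithmetic (hidden size 4096 is assumed for Vicuna-7B; the reported 0.0004% overhead may also count additional bookkeeping):

# Rough parameter count for the trained prompt embeddings (illustrative).
num_special_tokens = 3
virtual_tokens_per_special_token = 1
hidden_size = 4096            # assumed for Vicuna-7B
base_params = 7e9             # approximate size of the base model

prompt_params = num_special_tokens * virtual_tokens_per_special_token * hidden_size
print(prompt_params)                          # 12288 extra parameters
print(f"{prompt_params / base_params:.6%}")   # on the order of 1e-4 percent of the base model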

Inference

We employ dynamically extended tree attention and top-K candidates for inference. The supported evaluation datasets currently include Alpaca Eval, MT Bench, and HumanEval.

Refer to this README.md for instructions on installing the libraries required by these datasets.
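
As a hedged sketch of how the top-K candidates are used (helper names are illustrative; the repository's dynamic tree construction and attention masking are more involved): each prompt-token position contributes its top-K guesses, candidate continuations are formed from them, and the longest candidate whose leading tokens match what the base model itself would produce is accepted.

import itertools

def build_candidates(prompt_logits, top_k=3):
    """Illustrative only: candidates from the top-k guesses at each prompt position."""
    per_position = [logits.topk(top_k).indices.tolist() for logits in prompt_logits]
    return [list(c) for c in itertools.product(*per_position)]

def accepted_length(candidate, verified_tokens):
    """Count how many leading guesses agree with the model's own (verified) tokens."""
    n = 0
    for guess, true_token in zip(candidate, verified_tokens):
        if guess != true_token:
            break
        n += 1
    return n

In practice the candidates are laid out as a token tree and checked in a single forward pass with a tree attention mask, so verification adds only one extra pass per decoding step.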

Chat Application

We implemented a simple chat application using Gradio. To start a server for the chat application, run:

python application/webui.py --ppd-path <ppd-model-path>

MT Bench

  • To obtain the latency of the baseline model (without the use of special tokens), run

python3 gen_model_answer_baseline.py \
  --model-path <model-path> \
  --model-id <model-id> \
  --answer-file <output-file> \
  --bench-name mt_bench \
  --temperature <temperature>

  • To obtain the latency of the model with special tokens, run

python3 gen_model_answer_prompt_decoding.py \
  --model-path <model-path> \
  --model-id <model-id> \
  --answer-file <output-file> \
  --tree-length 105 \
  --bench-name mt_bench \
  --temperature <temperature>

--model-path is the path to the trained special tokens and --tree-length is the length of the sparse tree used.

  • To view the latency results of a generated .jsonl file, run
python get_throughput_results.py data/mt_bench/experiments/vicuna-7b-faster1.jsonl --n 3

--n specifies the number of experiment runs to average over.

Alpaca Eval

We use the Alpaca Eval dataset for evaluation. The latency results can be obtained with the same script as MT Bench by adding --bench-name alpaca_eval.

  • To compare the latencies and accept lengths between sparse trees with different sizes, run:
python accept_length.py \
  --dir-path <output-dir> \
  --file-name <file-name> \
  --model-name <model-path> \
  --eval-file-name gen_model_answer_prompt_decoding.py \
  --n 1 \
  --max-length 120 \
  --min-length 60 \
  --length-interval 9 \
  --choices "[75, 105, 135, 165, 195, 225, 255, 285]"

python3 tree_latency.py \
  --model-path <model-path> \
  --model-id <model-id> \
  --answer-file <output-file> \
  --bench-name alpaca_eval \
  --min-tree-length 60 \
  --max-tree-length 120 \
  --length-interval 3 \
  --max-new-token 1024

This script runs the latency tests on a range of sparse trees.

HumanEval

The latency results of HumanEval can be obtained using the same script as MT Bench and adding --bench-name humaneval.

GSM8K

The latency results of GSM8K can be obtained using the same script as MT Bench and adding --bench-name gsm8k.
