
xxxxhong / remax


This project forked from liziniu/remax


Code for Paper (ReMax: A Simple, Efficient and Effective Reinforcement Learning Method for Aligning Large Language Models)



ReMax: A Simple, Effective, and Efficient Method for Aligning Large Language Models

Overview

ReMax is a reinforcement learning method tailored for reward maximization in RLHF (reinforcement learning from human feedback).
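As we understand the accompanying paper (arXiv:2310.10505), ReMax is a REINFORCE-style policy-gradient method whose baseline is the reward of the greedily decoded response, so no learned value model is needed. A sketch of the gradient estimator, in our own notation (π_θ is the LLM policy, r the reward model, y^i a response sampled for prompt x^i, and ȳ^i the greedy response for the same prompt); please check the paper for the exact statement:

```latex
\widehat{\nabla_\theta J}(\theta)
  = \frac{1}{N}\sum_{i=1}^{N}\sum_{t=1}^{T}
    \nabla_\theta \log \pi_\theta\!\left(y^i_t \mid x^i, y^i_{1:t-1}\right)
    \Bigl(r\bigl(x^i, y^i_{1:T}\bigr) - r\bigl(x^i, \bar{y}^i_{1:T}\bigr)\Bigr)
```

Because the baseline r(x^i, ȳ^i) does not depend on the sampled response y^i, subtracting it leaves the estimator unbiased while reducing its variance.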

Simple Implementation

ReMax is easy to implement: it takes about 6 lines of code. This repository provides an implementation based on the DeepSpeed framework.
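To illustrate why the update is short, here is a minimal, framework-free sketch of a ReMax-style loss as we read the paper. It is not the repository's DeepSpeed code: the function names are hypothetical, any KL penalty is omitted, and plain floats stand in for autograd tensors.

```python
from typing import Iterable, Sequence, Tuple


def remax_loss(
    sampled_logprobs: Sequence[float],
    sampled_reward: float,
    greedy_reward: float,
) -> float:
    """Surrogate loss for one prompt in a ReMax-style update (illustrative sketch).

    sampled_logprobs: log pi_theta(y_t | x, y_<t) for each token of the sampled response.
    sampled_reward:   reward-model score of the sampled response.
    greedy_reward:    reward-model score of the greedy response, used as the baseline.
    """
    # One scalar advantage is shared by every token of the sampled response.
    advantage = sampled_reward - greedy_reward
    # Minimizing -(advantage * sum of log-probs) performs gradient ascent on the
    # REINFORCE objective. With autograd tensors, the advantage would be detached
    # so that gradients flow only through the log-probabilities.
    return -advantage * sum(sampled_logprobs)


def remax_batch_loss(batch: Iterable[Tuple[Sequence[float], float, float]]) -> float:
    """Average remax_loss over (logprobs, sampled_reward, greedy_reward) triples."""
    losses = [remax_loss(lp, rs, rg) for lp, rs, rg in batch]
    return sum(losses) / len(losses)
```

For example, `remax_loss([-0.5, -1.0, -0.25], 2.0, 1.5)` gives 0.875: the sampled response beat the greedy one, so minimizing the loss raises its log-probability. The design choice is that a second (greedy) generation per prompt replaces the value model that PPO would train.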

Memory Efficient

ReMax is memory-efficient. Compared with PPO, ReMax reduces GPU memory consumption by about 50%; the freed memory can accommodate a roughly 1.3x larger batch size.

Results of tuning Llama2-7B with A100-80GB GPUs
GPUs  Offload  Method  Maximum Batch Size (GPUs x per-GPU batch)
4     False    PPO     ❌ (OOM)
4     False    ReMax   4x26=104
4     True     PPO     4x30=120
4     True     ReMax   4x40=160
1     True     PPO     1x32=32
1     True     ReMax   1x42=42

*: Gradient checkpointing and ZeRO-2 are used for the LLM.

*: ZeRO-3 and offload are used for the reward model and the reference model.
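The "1.3x larger batch size" figure can be checked against the table. This short script copies the maximum batch sizes from the table (the dictionary layout and function name are ours) and computes ReMax's maximum batch size relative to PPO's:

```python
from typing import Optional

# Maximum total batch sizes from the table (GPUs x per-GPU batch); None marks PPO OOM.
MAX_BATCH = {
    # (num_gpus, offload): {"PPO": ..., "ReMax": ...}
    (4, False): {"PPO": None, "ReMax": 4 * 26},
    (4, True): {"PPO": 4 * 30, "ReMax": 4 * 40},
    (1, True): {"PPO": 1 * 32, "ReMax": 1 * 42},
}


def batch_ratio(gpus: int, offload: bool) -> Optional[float]:
    """ReMax's maximum batch size divided by PPO's, or None if PPO ran out of memory."""
    row = MAX_BATCH[(gpus, offload)]
    if row["PPO"] is None:
        return None
    return row["ReMax"] / row["PPO"]


for gpus, offload in MAX_BATCH:
    ratio = batch_ratio(gpus, offload)
    label = "PPO OOM" if ratio is None else f"{ratio:.2f}x"
    print(f"{gpus} GPU(s), offload={offload}: {label}")
```

The two configurations where PPO fits give 160/120 ≈ 1.33x and 42/32 ≈ 1.31x, consistent with the roughly 1.3x claim.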

Fast Training

ReMax trains fast: it does not need to train a value model and requires less computation. Typically, it achieves about a 2x training speed-up over PPO.

Results of tuning Llama2-7B with A100-80GB GPUs
GPUs  Offload  Method  Total Training Time
4     False    PPO     ❌ (OOM)
4     False    ReMax   2.4h
4     True     PPO     6.0h
4     True     ReMax   2.8h
1     True     PPO     22.0h
1     True     ReMax   10.2h

*: Gradient checkpointing and ZeRO-2 are used for the LLM.

*: ZeRO-3 and offload are used for the reward model and the reference model.

*: Measurements are based on 45k training samples (1 epoch) from the full-hh-rlhf dataset.
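The "about 2x" speed-up can likewise be recomputed from the timing table. The sketch below copies the training times in hours (layout and names are ours) and divides PPO's time by ReMax's:

```python
from typing import Optional

# Total training time in hours from the table; None marks PPO running out of memory.
HOURS = {
    (4, False): {"PPO": None, "ReMax": 2.4},
    (4, True): {"PPO": 6.0, "ReMax": 2.8},
    (1, True): {"PPO": 22.0, "ReMax": 10.2},
}


def speedup(gpus: int, offload: bool) -> Optional[float]:
    """PPO wall-clock time divided by ReMax wall-clock time (None if PPO did not fit)."""
    row = HOURS[(gpus, offload)]
    return None if row["PPO"] is None else row["PPO"] / row["ReMax"]


for gpus, offload in HOURS:
    s = speedup(gpus, offload)
    print(f"{gpus} GPU(s), offload={offload}:", "PPO OOM" if s is None else f"{s:.2f}x")
```

Both comparable settings come out at about 2.1 to 2.2x (6.0/2.8 and 22.0/10.2).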

Easy to Tune

ReMax is easy to tune for good performance. On the AlpacaEval benchmark, when judged by GPT-4, ReMax achieves win rates of 84.22%, 75.28%, and 63.60% over SFT, DPO, and PPO, respectively.

Change Log

  • [2023-12-16] Add response samples of trained models and evaluation results of training speed.
  • [2023-10-18] Release the initial code.

How to use

Prepare

The Python environment can be set up using Anaconda with the provided environment.yml file.

conda env create -f environment.yml
conda activate llm

Step 1 SFT

cd step1_supervised_finetuning

# OPT(1.3B)
bash training_scripts/opt/run_opt_1.3b.sh

# Llama2(7B)
bash training_scripts/llama2/run_llama2_1.3b.sh

Step 2 Reward Learning

cd step2_reward_model_finetuning

# OPT(1.3B)
bash training_scripts/opt/run_opt_1.3b.sh

# Llama2(7B)
bash training_scripts/llama2/run_llama2_1.3b.sh
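Step 2 trains the reward model on preference pairs. Assuming this code keeps DeepSpeed-Chat's pairwise ranking objective (an assumption; check the step 2 sources), the per-pair loss is -log σ(r_chosen - r_rejected). A standard-library sketch with a hypothetical function name:

```python
import math


def pairwise_reward_loss(chosen_score: float, rejected_score: float) -> float:
    """-log(sigmoid(chosen_score - rejected_score)), computed without overflow."""
    margin = chosen_score - rejected_score
    # -log(sigmoid(m)) == log(1 + exp(-m)). Branch on the sign of m so that
    # exp() is only ever called on a non-positive argument.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))
```

The loss is log 2 when the two scores tie and shrinks toward zero as the chosen response's score pulls ahead of the rejected one.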

Step 3 RLHF

cd step3_rlhf_finetuning

# OPT(1.3B)
bash training_scripts/opt/run_opt_1.3b.sh

# Llama2(7B)
bash training_scripts/llama2/run_llama2_1.3b.sh
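The three steps can be chained from the repository root with a small driver. This is a sketch under the assumption that each script runs without arguments, exactly as shown above; if a script expects checkpoint paths from earlier steps, extend the command accordingly. The helper names are hypothetical; the directory and script paths are copied from this README.

```python
import subprocess
from pathlib import Path
from typing import List, Tuple

# Step directories and per-model scripts, as listed in this README.
STEPS = [
    "step1_supervised_finetuning",
    "step2_reward_model_finetuning",
    "step3_rlhf_finetuning",
]
SCRIPTS = {
    "opt": "training_scripts/opt/run_opt_1.3b.sh",
    "llama2": "training_scripts/llama2/run_llama2_1.3b.sh",
}


def pipeline_commands(model: str, repo_root: str = ".") -> List[Tuple[Path, List[str]]]:
    """Return (working_directory, command) pairs for steps 1-3, in order."""
    script = SCRIPTS[model]
    return [(Path(repo_root) / step, ["bash", script]) for step in STEPS]


def run_pipeline(model: str, repo_root: str = ".") -> None:
    """Run each step in turn; check=True stops at the first step that fails."""
    for cwd, cmd in pipeline_commands(model, repo_root):
        print(f"[{cwd}] {' '.join(cmd)}")
        subprocess.run(cmd, cwd=cwd, check=True)
```

For example, calling `run_pipeline("opt")` from the repository root runs the OPT scripts for SFT, reward learning, and RLHF in sequence.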

Acknowledgements

Our code is heavily based on DeepSpeed-Chat. Please follow the detailed instructions in DeepSpeed-Chat.

Bibtex

If you find this code helpful, please cite our paper in the following format:

@article{li2023remax,
  title     = {ReMax: A Simple, Effective, and Efficient Method for Aligning Large Language Models},
  author    = {Li, Ziniu and Xu, Tian and Zhang, Yushun and Yu, Yang and Sun, Ruoyu and Luo, Zhi-Quan},
  journal   = {arXiv preprint arXiv:2310.10505},
  year      = {2023},
}
