llama-x

Environment

conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install transformers==4.36.2
pip install -r requirements.txt
pip install trl
pip install flash-attn==2.3.6 --no-build-isolation
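
After installing, a quick sanity check of the pinned versions can catch mismatches early. This helper is not part of the repo, just a minimal sketch using only the standard library:

```python
# Sanity-check the pinned dependencies from the install commands above.
# (Convenience helper; not part of the llama-x repo.)
from importlib import metadata

def installed_versions(packages):
    """Return {distribution name: installed version, or None if missing}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

# Expect torch 2.1.0, transformers 4.36.2, flash-attn 2.3.6 per the pins above.
print(installed_versions(["torch", "transformers", "trl", "flash-attn"]))
```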

Train SFT

Both SFT and DPO default to the Vicuna-1.1 conversation template. model_name_or_path is the path to the base model, data_path is the path to the training file, and output_dir is where checkpoints are written. DeepSpeed defaults to ZeRO-3 with CPU offloading (configs/stage3_offloading_accelerate.json).

  deepspeed train_freeform_multiturn.py \
    --model_name_or_path mistralai/Mistral-7B-v0.1 \
    --data_path data/sample_data_sft.json \
    --model_max_length 2048 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 8 \
    --gradient_accumulation_steps 4 \
    --save_strategy steps \
    --save_steps 11 \
    --save_total_limit 10 \
    --learning_rate 5e-6 \
    --weight_decay 0.0 \
    --warmup_steps 30 \
    --logging_steps 2 \
    --lr_scheduler_type cosine \
    --gradient_checkpointing True \
    --deepspeed configs/stage3_offloading_accelerate.json \
    --output_dir save_dir/llamax/auto_gsm8k_stage1_llama3_70b_dialogue_clean \
    --bf16 True
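
Since both stages assume the Vicuna-1.1 template, the prompt layout looks roughly like the sketch below (following FastChat's vicuna_v1.1 conversation definition; the system prompt shown is illustrative, and the training script may configure it differently):

```python
# Rough sketch of the Vicuna-1.1 multi-turn prompt layout (per FastChat's
# vicuna_v1.1 template). The system prompt here is illustrative only.
DEFAULT_SYSTEM = ("A chat between a curious user and an artificial "
                  "intelligence assistant.")

def vicuna_v11_prompt(turns, system=DEFAULT_SYSTEM):
    """turns: list of (user_message, assistant_message) pairs."""
    parts = [system]
    for user_msg, assistant_msg in turns:
        # Roles are separated by a single space; each assistant turn
        # ends with the EOS token </s>.
        parts.append(f"USER: {user_msg} ASSISTANT: {assistant_msg}</s>")
    return " ".join(parts)

print(vicuna_v11_prompt([("What is DPO?", "Direct Preference Optimization.")]))
```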

Train DPO

deepspeed dpo_train.py \
    --model_name_or_path /blob/caxu/outputmodel/7b_lmsys10w_5wevolmix_instag1w_1800step_e3_4096/tmp-checkpoint-1700/ \
    --json_path data/sample_data.json \
    --data_split train \
    --output_dir /share/project/weihao/save_dir/checkpoints/train_ppo_1to5_reward_sppo_hard_nll_fix_6pair_no_duplicate_beta_0.03_hf_trl  \
    --num_train_epochs 1 \
    --beta 0.03 \
    --model_max_length 2048  \
    --per_device_train_batch_size 4  \
    --per_device_eval_batch_size 1  \
    --gradient_accumulation_steps 4  \
    --save_global_steps False \
    --eval_steps 50 \
    --save_strategy "steps"  \
    --save_steps 100  \
    --save_total_limit 25  \
    --learning_rate 5e-7  \
    --warmup_ratio 0.1 \
    --logging_steps 1  \
    --lr_scheduler_type "linear"  \
    --do_eval False \
    --evaluation_strategy "no"  \
    --conv_template "vicuna_v1.1" \
    --run_name "Deita-7b" \
    --seed 46 \
    --gradient_checkpointing True \
    --deepspeed configs/stage3_offloading_accelerate.json \
    --bf16 True \
    --report_to wandb
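
For context on the --beta flag: in the standard DPO objective, beta scales the log-probability margin between the chosen and rejected responses relative to the reference model. A minimal sketch in plain Python (no torch); note that the run name above suggests a variant with an added NLL term, which this vanilla form does not model:

```python
import math

def dpo_loss(policy_chosen_logp, policy_rejected_logp,
             ref_chosen_logp, ref_rejected_logp, beta=0.03):
    """Vanilla DPO loss: -log(sigmoid(beta * log-ratio margin))."""
    margin = ((policy_chosen_logp - ref_chosen_logp)
              - (policy_rejected_logp - ref_rejected_logp))
    return -math.log(1.0 / (1.0 + math.exp(-beta * margin)))

# With a zero margin the loss is log(2); it shrinks as the policy
# prefers the chosen response more strongly than the reference does.
print(dpo_loss(0.0, 0.0, 0.0, 0.0))  # log(2) ~= 0.6931
```

A small beta (0.03 here) tolerates larger deviations from the reference model before the loss saturates.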

llama-x's People

Contributors

zeng-wh
