
foreveract / chatglm_mutli_gpu_tuning


This project was forked from cshaitao/chatglm_mutli_gpu_tuning.


Simple and efficient multi-GPU fine-tuning of large models with DeepSpeed + Trainer.

License: MIT License

Languages: Python 99.06%, Shell 0.94%


ChatGLM_multi_gpu_zero_Tuning

Implements multi-GPU + ZeRO fine-tuning of ChatGLM. Three tuning methods are currently supported: LoRA, P-tuning-v2, and Freeze. Related debugging notes are available on Zhihu.

The legal large language model LexiLaw was trained with this code; see LexiLaw.

Data Structure

data.py uses InstrutionDataset and InstrutionCollator to process the data.

Each example contains an instruction, an input, and an answer, for example:

{"instruction": "", "input": "工作时把右小指挤段又接上的,XXXX左右的工资,在说我们感觉太不公平了,刚手术做了XX个小时", "answer": "这种情况,是属于工伤的。按照法律规定,需要先行认定工伤,认定工伤后,再进行劳动能力鉴定,然后再具体索赔。涉及赔偿项目比较多,一般包括医疗费、护理费、伙食费、停工留薪期内的工资,构成伤残的,还可以主张一次性伤残补助金、一次性伤残就业补助金、一次性工伤医疗补助金等,每一项赔偿的标准不一样,需要根据实际情况来具体计算。由于工伤案件,涉及程序较多、法律关系复杂,建议委托专业律师代理处理。我们经常代理工伤赔偿案件,如果愿意,我们可以协助处理。详情可以电联,或者面谈。"}
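Judging by the single-object example above, the training file (instrution_data.json in the scripts below) appears to hold one JSON object per line. A minimal loader under that assumption (a sketch, not the repo's actual code):

```python
import json

def load_instruction_data(path):
    """Load a JSON-lines instruction file into a list of dicts,
    skipping blank lines."""
    examples = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                examples.append(json.loads(line))
    return examples
```

If the file is instead a single JSON array, replace the loop with `json.load(f)`.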

The prefix in InstrutionDataset is only used for P-tuning; max_len and max_input_len in InstrutionCollator control the input length.
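The division of labor between the dataset and the collator can be sketched roughly as follows. This is a simplified illustration under the field names from the example above, not the repo's actual implementation; the class and function names here are hypothetical:

```python
class InstructionDatasetSketch:
    """Simplified stand-in for InstrutionDataset: holds raw examples
    and builds the prompt string per item."""
    def __init__(self, examples, prefix=""):
        self.examples = examples
        self.prefix = prefix  # only meaningful for P-tuning

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        ex = self.examples[idx]
        return {
            "prompt": self.prefix + ex["instruction"] + ex["input"],
            "answer": ex["answer"],
        }

def collate_sketch(batch, tokenize, max_len=768, max_input_len=512):
    """Simplified stand-in for InstrutionCollator: truncate the prompt
    to max_input_len tokens, then prompt+answer to max_len tokens."""
    sequences = []
    for item in batch:
        prompt_ids = tokenize(item["prompt"])[:max_input_len]
        answer_ids = tokenize(item["answer"])
        sequences.append((prompt_ids + answer_ids)[:max_len])
    return sequences
```

The two length limits matter independently: max_input_len keeps long prompts from crowding out the answer, while max_len bounds the total sequence fed to the model.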

Train

LoRA

1. Edit the model and data paths in lora.sh

2. Run sh lora.sh

    CUDA_VISIBLE_DEVICES=${TOT_CUDA} deepspeed --master_port=$PORT --num_gpus=3 lora.py \
        --train_path ./instrution_data.json \
        --max_len 768 \
        --max_input_len 512 \
        --model_name_or_path ./chatGLM-6B \
        --tokenizer_name ./chatGLM-6B \
        --lora_rank 8 \
        --per_device_train_batch_size 16 \
        --gradient_accumulation_steps 4 \
        --num_train_epochs 10 \
        --save_steps 900 \
        --learning_rate 1e-5 \
        --fp16 \
        --remove_unused_columns false \
        --logging_steps 50 \
        --output_dir /output \
        --deepspeed /ds_config.json

For single-GPU training, change --num_gpus to 1.
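Note that the effective (global) batch size is the product of the per-device batch size, the number of GPUs, and the gradient accumulation steps, so changing --num_gpus also changes it:

```python
def effective_batch_size(per_device, num_gpus, grad_accum):
    """Global batch size seen by the optimizer per update step."""
    return per_device * num_gpus * grad_accum

# the lora.sh settings above: 16 per device, 3 GPUs, 4 accumulation steps
print(effective_batch_size(16, 3, 4))  # 192
```

To keep the effective batch size constant on a single GPU, raise gradient_accumulation_steps accordingly.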

The LoRA parameters are as follows and can be adjusted as needed:

    peft_config = LoraConfig(
        task_type="CAUSAL_LM",
        lora_alpha=32,  
        target_modules=["query_key_value"],
        inference_mode=False,
        r=training_args.lora_rank,
        lora_dropout=0.1,
        bias="none",
        fan_in_fan_out = False
    )
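For intuition on what r and lora_alpha control: LoRA freezes the base weight W and trains a low-rank pair A (r x d_in) and B (d_out x r), adding (lora_alpha / r) * B A x to the frozen output. A plain-Python sketch of that forward pass (illustration only, not the peft implementation):

```python
def lora_forward(x, W, A, B, lora_alpha=32, r=8):
    """y = W x + (lora_alpha / r) * B (A x), using nested lists as matrices."""
    def matvec(M, v):
        return [sum(m * vv for m, vv in zip(row, v)) for row in M]

    base = matvec(W, x)             # frozen pretrained path
    delta = matvec(B, matvec(A, x)) # trainable low-rank update
    scale = lora_alpha / r
    return [b + scale * d for b, d in zip(base, delta)]
```

Raising r increases the update's capacity (and parameter count); lora_alpha / r rescales the update so alpha can be tuned somewhat independently of r.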

P-tuning-v2

Adapted from the official ChatGLM-6B P-tuning code.

1. Edit the model and data paths in ptuning.sh

2. Run sh ptuning.sh

    CUDA_VISIBLE_DEVICES=${TOT_CUDA} deepspeed --master_port=$PORT --num_gpus=2 finetune_ptuning.py \
        --train_path ./instrution_data.json \
        --max_len 768 \
        --max_input_len 512 \
        --model_name_or_path /chatGLM-6B \
        --tokenizer_name /chatGLM-6B \
        --per_device_train_batch_size 8 \
        --gradient_accumulation_steps 4 \
        --num_train_epochs 10 \
        --save_steps 2000 \
        --learning_rate 1e-5 \
        --fp16 \
        --logging_steps 50 \
        --prefix_projection True \
        --pre_seq_len $PRE_SEQ_LEN \
        --output_dir /output \
        --deepspeed ds_config.json

$PRE_SEQ_LEN is the length of the soft prompt and can be adjusted as needed.
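To estimate how many parameters the prefix adds: without prefix projection, P-tuning-v2 learns one key and one value vector per layer per virtual token. Assuming ChatGLM-6B's 28 layers and hidden size 4096, a quick back-of-the-envelope count:

```python
def prefix_param_count(pre_seq_len, num_layers=28, hidden_size=4096):
    """Rough parameter count of a P-tuning-v2 prefix without projection:
    a key and a value vector per layer per virtual token."""
    return pre_seq_len * num_layers * 2 * hidden_size

# e.g. a 128-token soft prompt is about 29M trainable parameters
print(prefix_param_count(128))
```

With --prefix_projection True (as in ptuning.sh), the prefix is produced by a small MLP, which adds further trainable parameters on top of this estimate.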

Freeze

1. Edit the model and data paths in freeze.sh

2. Run sh freeze.sh

    CUDA_VISIBLE_DEVICES=${TOT_CUDA} deepspeed --master_port=$PORT --num_gpus=3 finetune_freeze.py \
        --train_path  \
        --max_len 768 \
        --max_input_len 512 \
        --model_name_or_path /chatGLM-6B \
        --tokenizer_name /chatGLM-6B \
        --lora_rank 8 \
        --per_device_train_batch_size 16 \
        --gradient_accumulation_steps 4 \
        --num_train_epochs 10 \
        --save_steps 900 \
        --learning_rate 1e-5 \
        --fp16 \
        --remove_unused_columns false \
        --logging_steps 50 \
        --output_dir output_freeze \
        --deepspeed ds_config.json

The number of trainable layers can be changed with the following code:

    # freeze every parameter except those in the last five transformer layers
    for name, param in model.named_parameters():
        if not any(nd in name for nd in ["layers.27", "layers.26", "layers.25", "layers.24", "layers.23"]):
            param.requires_grad = False
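After freezing, it is worth double-checking how many parameters actually remain trainable. A small framework-agnostic helper (hypothetical, not part of the repo) that works on anything shaped like `model.named_parameters()`:

```python
def count_trainable(named_params):
    """Split parameter counts into (trainable, frozen) given (name, param)
    pairs, where each param exposes .requires_grad and .numel()
    as torch tensors do."""
    trainable = frozen = 0
    for _, p in named_params:
        n = p.numel()
        if p.requires_grad:
            trainable += n
        else:
            frozen += n
    return trainable, frozen
```

Calling `count_trainable(model.named_parameters())` right after the freezing loop confirms that only the intended layers will receive gradients.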

Requirements

python=3.9
transformers==4.28.1
tqdm==4.64.1
datasets==2.8.0
pytorch==1.12.1
deepspeed==0.9.1
peft==0.3.0 

Be sure to use peft==0.3.0.

Todo

  • Add model parallelism and multi-GPU inference

Contact

If you find our work useful, please give the repo a star!

If you have any questions, please email [email protected]

Contributors

cshaitao
