This project forked from openbmb/bmtrain


Efficient Training (including pre-training and fine-tuning) for Big Models

License: Apache License 2.0



🚄 BMTrain



1. Installation

From PyPI (recommended)

$ pip install bmtrain

From source

$ git clone https://github.com/OpenBMB/BMTrain.git
$ cd BMTrain
$ python setup.py install

2. Usage

Step 1: Initialize BMTrain

To use BMTrain, import the bmtrain package and call bmtrain.init_distributed at the beginning of your code:

import bmtrain as bmt
bmt.init_distributed(
    seed=0,
    # ...
)

Note: when using BMTrain, do not use PyTorch's own distributed module, including torch.distributed.init_process_group and the related communication functions.
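
Launchers such as torchrun (see Step 4) hand each worker its identity through environment variables, which the initialization call reads during setup. A minimal sketch of that handshake (illustrative only, not BMTrain's actual code; the defaults model a single-process run):

```python
import os

# Distributed launchers export these variables for every worker process;
# an init routine like bmt.init_distributed can read them to learn the
# worker's place in the job.
def read_dist_env():
    return {
        "rank": int(os.environ.get("RANK", "0")),
        "world_size": int(os.environ.get("WORLD_SIZE", "1")),
        "local_rank": int(os.environ.get("LOCAL_RANK", "0")),
        "master_addr": os.environ.get("MASTER_ADDR", "localhost"),
        "master_port": int(os.environ.get("MASTER_PORT", "29500")),
    }

env = read_dist_env()
```

Running the same script under a launcher with different `RANK` values is what makes each process behave as a distinct worker.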

Step 2: Enable ZeRO-3 Optimization

Enabling the ZeRO-3 optimization requires only simple substitutions in the model code:

  • torch.nn.Module -> bmtrain.DistributedModule
  • torch.nn.Parameter -> bmtrain.DistributedParameter

then apply checkpointing to the appropriate modules.

Original code:

import torch
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            SomeTransformerBlock(),
            SomeTransformerBlock(),
            SomeTransformerBlock()
        ])
    
    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x

Modified code:

import torch
import bmtrain as bmt
class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])
    
    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
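
Wrapping a block in bmt.CheckpointBlock trades compute for memory: intermediate activations are not kept for the backward pass but recomputed when needed. A conceptual sketch of that idea in plain Python (not BMTrain's implementation; names are illustrative):

```python
# Conceptual sketch of activation checkpointing: instead of caching every
# intermediate activation, keep only the block's input and rerun the
# forward pass when the intermediates are needed again.

def block(x):
    # stand-in for a transformer block: two "layers" of arithmetic
    h = x * 2          # intermediate activation, normally cached
    return h + 1

def checkpoint(fn, x):
    """Run fn(x) but save only x; return the output and a recompute closure."""
    saved_input = x
    out = fn(x)
    def recompute():
        # rerun the forward pass from the saved input, e.g. during backward
        return fn(saved_input)
    return out, recompute

out, recompute = checkpoint(block, 3)
# recompute() reproduces the forward result without stored intermediates
```

The memory saved grows with the depth of each wrapped block, which is why checkpointing is typically applied per transformer block.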
    

Step 3: Communication Optimization

To further reduce communication overhead by overlapping communication with computation, use TransformerBlockList. This requires two more simple substitutions:

  • torch.nn.ModuleList -> bmtrain.TransformerBlockList
  • for module in self.module_list: x = module(x, ...) -> x = self.module_list(x, ...)

Original code:

import torch
import bmtrain as bmt
class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = torch.nn.ModuleList([
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])
    
    def forward(self):
        x = self.param
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x
    

Modified code:

import torch
import bmtrain as bmt
class MyModule(bmt.DistributedModule):
    def __init__(self):
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024))
        self.module_list = bmt.TransformerBlockList([
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock()),
            bmt.CheckpointBlock(SomeTransformerBlock())
        ])
    
    def forward(self):
        x = self.param
        x = self.module_list(x, 1, 2, 3)
        return x
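
Handing the whole list to TransformerBlockList lets the runtime fetch the partitioned parameters of block i+1 while block i is still computing, hiding communication latency. A toy model of that schedule using threads (BMTrain's real implementation overlaps GPU communication with compute; everything here is illustrative):

```python
import threading

# Toy model of compute/communication overlap: while block i runs,
# "fetch" the parameters of block i+1 on a background thread.

def fetch_params(i, params):
    params[i] = f"params_{i}"    # stands in for gathering block i's weights

def run_blocks(n):
    log = []
    params = {}
    fetch_params(0, params)                  # fetch the first block up front
    for i in range(n):
        t = None
        if i + 1 < n:                        # prefetch the next block...
            t = threading.Thread(target=fetch_params, args=(i + 1, params))
            t.start()
        log.append(f"compute_{i}_with_{params[i]}")  # ...while computing this one
        if t is not None:
            t.join()                         # prefetch done before next iteration
    return log

log = run_blocks(3)
```

Because the per-iteration call `x = self.module_list(x, ...)` gives the library the full sequence of blocks, it can pipeline the fetches this way instead of fetching lazily inside a Python for-loop.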
    

Step 4: Launch Distributed Training

BMTrain supports PyTorch's native distributed training launchers:

torch.distributed.launch (deprecated in recent PyTorch releases in favor of torchrun)

$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py

torchrun

$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py

For more information, refer to the official PyTorch documentation: Launch utility.

3. Other Notes

The BMTrain toolkit makes low-level modifications to PyTorch. If your program produces unexpected results, please file an issue with the relevant information.

For more examples, see the examples folder.

Contributors

a710128, achazwl, shengdinghu, koshinryuu
