
Gradient Accumulation · alpa · CLOSED

alpa-projects commented on July 26, 2024
Gradient Accumulation


Comments (3)

merrymercy commented on July 26, 2024

Style 1

@parallelize
def train_step(data, params):  # returns grad
    ...

acc_grad = 0
for num_acc in range(5):
    grad = train_step(data, params)   # calls an all-reduce internally
    acc_grad += grad

We need to support a new sharding specification, PartialResult, in both XLA and JAX.

Style 2

@parallelize
def train_step(data, params):  # returns acc_grad
    acc_grad = 0
    for num_acc in range(5):
        grad = compute_grad(data, params)   # per-micro-batch gradient
        acc_grad += grad
    return acc_grad

XLA needs to simplify allreduce(x) + allreduce(y) to allreduce(x + y).
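
As a sanity check for this rewrite (only an illustration of the algebra, not the XLA pass itself), the identity can be reproduced in JAX with lax.psum:

import jax
import jax.numpy as jnp
from jax import lax

def two_reduces(x, y):
    return lax.psum(x, "i") + lax.psum(y, "i")   # allreduce(x) + allreduce(y)

def one_reduce(x, y):
    return lax.psum(x + y, "i")                  # allreduce(x + y)

n = jax.local_device_count()
x = jnp.arange(n, dtype=jnp.float32)
y = jnp.ones(n, dtype=jnp.float32)

a = jax.pmap(two_reduces, axis_name="i")(x, y)
b = jax.pmap(one_reduce, axis_name="i")(x, y)
assert jnp.allclose(a, b)   # same result, but only one all-reduce in the second version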

Style 3

@parallelize
def forward_backward_step(data, params):  # returns grad
    pass

@parallelize
def gradient_update_step(grad, params):  # returns new_params
    pass

acc_grad = 0
for num_acc in range(5):
    grad = forward_backward_step(data, params)   # calls an all-reduce internally
    acc_grad += grad
params = gradient_update_step(acc_grad, params)

We need to change our interface.


merrymercy commented on July 26, 2024

Background

To enable efficient gradient accumulation in both SPMD and pipeline parallelism, we must generate at least two XLA executables: accumulate_grad(state, micro_batch, old_grad, sync) -> new_grad and apply_grad(state, grad) -> new_state.

In SPMD mode, they are used as follows

micro_batches = split(batch, num_micro_batches)
acc_grad = 0

for i in range(num_micro_batches):
    sync = (i == num_micro_batches - 1)  # Whether to use all-reduce to sync the local gradients
    acc_grad = accumulate_grad(state, micro_batches[i], acc_grad, sync)

apply_grad(state, acc_grad)
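
For reference, here is a runnable single-device sketch of the same loop, using a toy linear model. accumulate_grad and apply_grad below are hypothetical stand-ins for the compiled executables described above, and the sync flag is omitted because there is nothing to all-reduce on one device.

import jax
import jax.numpy as jnp

num_micro_batches = 4

def loss_fn(params, batch):
    x, y = batch
    return jnp.mean((x @ params - y) ** 2)

@jax.jit
def accumulate_grad(state, micro_batch, old_grad):
    g = jax.grad(loss_fn)(state, micro_batch)
    return jax.tree_util.tree_map(jnp.add, old_grad, g)

@jax.jit
def apply_grad(state, grad, lr=0.1):
    # Average the summed gradients over micro-batches, then take one SGD step.
    return jax.tree_util.tree_map(
        lambda p, g: p - lr * g / num_micro_batches, state, grad)

state = jnp.zeros((4,))
xs, ys = jnp.ones((8, 4)), jnp.ones((8,))
micro_xs = jnp.split(xs, num_micro_batches)
micro_ys = jnp.split(ys, num_micro_batches)

acc_grad = jax.tree_util.tree_map(jnp.zeros_like, state)
for i in range(num_micro_batches):
    acc_grad = accumulate_grad(state, (micro_xs[i], micro_ys[i]), acc_grad)
state = apply_grad(state, acc_grad)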

In Pipeline mode, they are used as follows

# For each worker
acc_grad = 0

for i in range(num_micro_batches):
    cur_micro_batch = recv_micro_batch_from_previous_stage()
    sync = (i == num_micro_batches - 1)  # Whether to use all-reduce to sync the local gradients
    acc_grad = accumulate_grad(state, cur_micro_batch, acc_grad, sync)

apply_grad(state, acc_grad)

Discussions

How to get these two functions?

  • A1: Force users to provide two functions.
    • A1.1: Users define
      compute_grad(state, batch) -> grad
      apply_grad(state, grad) -> new_state
      We then derive accumulate_grad from compute_grad (see the sketch after this list).
    • A1.2: Users define
      forward(state, batch) -> loss
      apply_grad(state, grad) -> new_state
      We then derive accumulate_grad from forward. This method is less general but works better for pipeline_marker.
    • Pros: By forcing users to provide these functions, we can make fewer assumptions and guesses.
    • Cons: Not compatible with existing jax/flax programs
  • A2: Use parax.grad to replace jax.grad:

    @parallelize
    def func(optimizer, batch):
        def loss_func(params):
            return ...

        grads = parax.grad(loss_func)(optimizer.target)
        new_optimizer = optimizer.apply_gradient(grads)
        return new_optimizer

    parax.grad inserts a separator after gradient computation. We can reuse pipeline_marker for this separator. This separator partitions the original jaxpr into compute_grad and apply_grad. We then derive accumulate_grad from compute_grad.
  • A3: We use static analysis (batch dimension propagation) to separate a whole computational graph into compute_grad, apply_grad
    • Pros: Compatible with existing jax/flax programs
    • Cons: Does not work with pipeline_marker, and it is hard to make the static analysis robust.
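
A minimal sketch of the derivation mentioned in A1.1 and A2, assuming a user-supplied compute_grad (the loss function here is a hypothetical placeholder; the real implementation would operate on jaxprs rather than Python closures): accumulate_grad simply adds the fresh gradient onto the running sum.

import jax
import jax.numpy as jnp

def user_compute_grad(state, batch):      # provided by the user (A1.1)
    x, y = batch
    return jax.grad(lambda p: jnp.mean((x @ p - y) ** 2))(state)

def derive_accumulate_grad(compute_grad):
    def accumulate_grad(state, micro_batch, old_grad):
        g = compute_grad(state, micro_batch)
        return jax.tree_util.tree_map(jnp.add, old_grad, g)
    return accumulate_grad

accumulate_grad = jax.jit(derive_accumulate_grad(user_compute_grad))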

Where should the accumulation loop be?

  • B1: Put the loop in HLO (see the sketch after this list)
    • Pros
      • Less python runtime overhead
    • Cons
      • Does not work in pipeline mode with our current implementation
      • HLO passes have to handle the resulting while loop
  • B2: Put the loop in our python runtime
    The pros and cons are the opposite of those above.
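
A sketch of B1, assuming a hypothetical per-micro-batch grad_fn: expressing the loop with lax.fori_loop keeps it inside a single compiled executable (it lowers to a While op in HLO), whereas B2 keeps the Python for-loop shown earlier and launches one executable per micro-batch.

import jax
import jax.numpy as jnp
from jax import lax

def accumulate_in_hlo(grad_fn, state, micro_batches):
    # micro_batches: a pytree of arrays with a leading num_micro_batches axis.
    num_micro_batches = jax.tree_util.tree_leaves(micro_batches)[0].shape[0]

    def body(i, acc):
        batch = jax.tree_util.tree_map(lambda x: x[i], micro_batches)
        g = grad_fn(state, batch)
        return jax.tree_util.tree_map(jnp.add, acc, g)

    init = jax.tree_util.tree_map(jnp.zeros_like, state)
    return lax.fori_loop(0, num_micro_batches, body, init)   # becomes an HLO While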

How to enable in-place updates?

The two functions accumulate_grad(state, micro_batch, old_grad, sync) -> new_grad and apply_grad(state, grad) -> new_state should be compiled with the following constraints:

  1. old_grad, new_grad and their corresponding parameters in state should share the same sharding spec.
  2. old_grad and new_grad should share the same memory location.
  3. new_state and state should share the same memory location and the same sharding spec.

To share the same memory location, we have to set donate_invars (or alias) in XLA.
To share the same sharding spec, we have to pass these constraints to the ILP solver.
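
At the JAX level, constraint 2 maps onto buffer donation. A minimal sketch (the gradient computation here is a placeholder):

import functools
import jax
import jax.numpy as jnp

# Donating old_grad (argument index 2) lets XLA alias new_grad to the same
# device buffers; this corresponds to setting donate_invars / input-output aliasing.
@functools.partial(jax.jit, donate_argnums=(2,))
def accumulate_grad(state, micro_batch, old_grad):
    g = jax.grad(lambda p: jnp.mean((micro_batch @ p) ** 2))(state)
    return jax.tree_util.tree_map(jnp.add, old_grad, g)

state = jnp.ones((4,))
micro_batch = jnp.ones((8, 4))
acc_grad = jnp.zeros_like(state)
acc_grad = accumulate_grad(state, micro_batch, acc_grad)   # old acc_grad buffer is donated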

How to handle the sync argument of accumulate_grad?

  • C1: Compile two versions of accumulate_grad, one that syncs and one that does not (see the sketch after this list)
  • C2: Compile one executable and use two branches in XLA.
  • C3: Compile two executables, accumulate_grad and sync_grad, and dispatch them from our Python runtime loop. However, this makes it impossible to overlap all-reduce and computation.
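
A sketch of C1 using pmap's static_broadcasted_argnums (the gradient computation is a placeholder, and the axis name "dp" is assumed): because sync is a static argument, tracing with sync=True and sync=False produces two separate executables, one containing the all-reduce and one without it.

import functools
import jax
import jax.numpy as jnp
from jax import lax

@functools.partial(jax.pmap, axis_name="dp", static_broadcasted_argnums=(3,))
def accumulate_grad(state, micro_batch, old_grad, sync):
    # All array arguments carry a leading device axis when this is called.
    g = jax.grad(lambda p: jnp.mean((micro_batch @ p) ** 2))(state)
    new_grad = old_grad + g
    if sync:   # resolved at trace time; True and False yield two executables
        new_grad = lax.psum(new_grad, "dp")
    return new_grad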

How to handle other optimizations?

We want the gradient buffers to be contiguous in memory. This benefits all-reduce.
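
One way to approximate this at the JAX level (a sketch of the idea, not the proposed XLA-level layout change) is to pack the gradient pytree into a single contiguous buffer with ravel_pytree before the all-reduce and unpack it afterwards:

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

grads = {"w": jnp.ones((4, 4)), "b": jnp.ones((4,))}

# Pack all gradient leaves into one contiguous 1-D buffer so a single
# all-reduce can cover them, then restore the original pytree structure.
flat_grads, unravel = ravel_pytree(grads)
# flat_grads = lax.psum(flat_grads, "dp")   # would run inside a pmapped function
grads = unravel(flat_grads)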


zhuohan123 commented on July 26, 2024

#87 #90

