Style 1

```python
@parallelize
def train_step(data, params):  # -> grad
    ...

acc_grad = 0
for num_acc in range(5):
    grad = train_step(data, params)  # calls all-reduce
    acc_grad += grad
```
We need to support a new sharding specification, PartialResult, in both XLA and JAX.
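One plausible reading of PartialResult, as a toy JAX sketch (not alpa's actual API; the loss, `local_grad`, and the axis name are illustrative assumptions): each step returns an unreduced, device-local partial gradient, and the single all-reduce is deferred until accumulation finishes.

```python
from functools import partial
import jax
import jax.numpy as jnp

def local_grad(params, x):
    # Illustrative device-local gradient of a toy loss.
    return jax.grad(lambda p: jnp.sum((p * x) ** 2))(params)

@partial(jax.pmap, axis_name="devices")
def accumulate_then_sync(params, micro_batches):
    acc = jnp.zeros_like(params)
    for i in range(micro_batches.shape[0]):
        acc = acc + local_grad(params, micro_batches[i])  # partial results, no all-reduce yet
    return jax.lax.psum(acc, "devices")  # a single deferred all-reduce
```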
Style 2

```python
@parallelize
def train_step(data, params):  # -> grad
    ...

acc_grad = 0
for num_acc in range(5):
    grad = train_step(data, params)
    acc_grad += grad
```

XLA needs to simplify allreduce(x) + allreduce(y) to allreduce(x + y).
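This rewrite is sound because the all-reduce (sum) is linear: psum(x) + psum(y) equals psum(x + y). A small runnable check (the axis name is illustrative):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name="i")
def two_reduces(x, y):
    return jax.lax.psum(x, "i") + jax.lax.psum(y, "i")

@partial(jax.pmap, axis_name="i")
def one_reduce(x, y):
    return jax.lax.psum(x + y, "i")

n = jax.local_device_count()
x = jnp.arange(float(n))
y = jnp.ones(n)
assert jnp.allclose(two_reduces(x, y), one_reduce(x, y))
```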
Style 3

```python
@parallelize
def forward_backward_step(data, params):  # -> grad
    pass

@parallelize
def gradient_update_step(grad, params):  # -> new_params
    pass

acc_grad = 0
for num_acc in range(5):
    grad = forward_backward_step(data, params)  # calls all-reduce
    acc_grad += grad
params = gradient_update_step(acc_grad, params)
```

We need to change our interface: the single train_step is split into two parallelized functions.
Background
To enable efficient gradient accumulation in both SPMD and pipeline parallelism, we must generate at least two XLA executables: `accumulate_grad(state, micro_batch, old_grad, sync) -> new_grad` and `apply_grad(state, grad) -> new_state`.
In SPMD mode, they are used as follows:

```python
micro_batches = split(batch, num_micro_batches)
acc_grad = 0
for i in range(num_micro_batches):
    sync = (i == num_micro_batches - 1)  # whether to use all-reduce to sync the local gradients
    acc_grad = accumulate_grad(state, micro_batches[i], acc_grad, sync)
apply_grad(state, acc_grad)
```
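For reference, a minimal sketch (assumed names, not alpa's actual implementation) of deriving such an `accumulate_grad` from a user-provided `compute_grad(state, batch) -> grad`, with `sync` treated as a compile-time constant (cf. question C below):

```python
import jax
import jax.numpy as jnp

def make_accumulate_grad(compute_grad, axis_name="devices"):
    # `sync` is a static python bool here, so one executable is traced
    # per value; see question C below for the alternatives.
    def accumulate_grad(state, micro_batch, old_grad, sync):
        grad = compute_grad(state, micro_batch)
        new_grad = jax.tree_util.tree_map(jnp.add, old_grad, grad)
        if sync:
            new_grad = jax.tree_util.tree_map(
                lambda g: jax.lax.psum(g, axis_name), new_grad)
        return new_grad
    return accumulate_grad
```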
In pipeline mode, they are used as follows:

```python
# For each worker
acc_grad = 0
for i in range(num_micro_batches):
    cur_micro_batch = recv_micro_batch_from_previous_stage()
    sync = (i == num_micro_batches - 1)  # whether to use all-reduce to sync the local gradients
    acc_grad = accumulate_grad(state, cur_micro_batch, acc_grad, sync)
apply_grad(state, acc_grad)
```
Discussions
How to get these two functions?
- A1: Force users to provide two functions.
  - A1.1: Users define `compute_grad(state, batch) -> grad` and `apply_grad(state, grad) -> new_state`. We then derive `accumulate_grad` from `compute_grad`.
  - A1.2: Users define `forward(state, batch) -> loss` and `apply_grad(state, grad) -> new_state`. We then derive `accumulate_grad` from `forward`. This method is less general but works better with pipeline_marker.
  - Pros: By forcing users to provide these functions, we can make fewer assumptions and guesses.
  - Cons: Not compatible with existing jax/flax programs.
- A2: Use `parax.grad` to replace `jax.grad`:

  ```python
  @parallelize
  def func(optimizer, batch):
      def loss_func(params):
          return ...
      grads = parax.grad(loss_func)(optimizer.target)
      new_optimizer = optimizer.apply_gradient(grads)
      return new_optimizer
  ```

  `parax.grad` inserts a separator after the gradient computation. We can reuse pipeline_marker for this separator. The separator partitions the original jaxpr into `compute_grad` and `apply_grad`; we then derive `accumulate_grad` from `compute_grad`.
- A3: Use static analysis (batch-dimension propagation) to separate the whole computational graph into `compute_grad` and `apply_grad`.
  - Pros: Compatible with existing jax/flax programs.
  - Cons: Does not work with pipeline_marker, and the static analysis is hard to make robust.
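As a sketch of the A2 mechanism, splitting a jaxpr at a marker equation could look roughly like this (the marker primitive name is an assumption, and alpa's real pass must also thread the intermediate values across the cut):

```python
def split_eqns_at_marker(closed_jaxpr, marker_name="pipeline_marker"):
    # Partition the jaxpr's equations at the first marker equation:
    # everything up to the marker forms compute_grad, the rest apply_grad.
    eqns = closed_jaxpr.jaxpr.eqns
    idx = next(i for i, eqn in enumerate(eqns)
               if eqn.primitive.name == marker_name)
    return eqns[:idx + 1], eqns[idx + 1:]
```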
Where should the accumulation loop be?
- B1: Put the loop in HLO.
  - Pros:
    - Less python runtime overhead.
  - Cons:
    - Does not work in pipeline mode due to our current implementation.
    - Has to deal with the while loop in HLO passes.
- B2: Put the loop in our python runtime. The pros and cons are the opposites of those above.
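A minimal sketch of B1, assuming a per-micro-batch `compute_grad` and a gradient pytree with the same structure as `state`: writing the loop with `lax.fori_loop` makes it lower to an HLO while loop inside a single executable.

```python
import jax
import jax.numpy as jnp
from jax import lax

def accumulate_in_hlo(compute_grad, state, micro_batches, num_micro_batches):
    # The accumulation loop lives inside the compiled program as an HLO
    # while loop, so there is no python dispatch per micro-batch.
    def body(i, acc_grad):
        grad = compute_grad(state, micro_batches[i])
        return jax.tree_util.tree_map(jnp.add, acc_grad, grad)

    zero_grad = jax.tree_util.tree_map(jnp.zeros_like, state)
    return lax.fori_loop(0, num_micro_batches, body, zero_grad)
```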
How to enable in-place updates?
The two functions `accumulate_grad(state, micro_batch, old_grad, sync) -> new_grad` and `apply_grad(state, grad) -> new_state` should be compiled with the following constraints:
- `old_grad`, `new_grad`, and their corresponding parameters in `state` should share the same sharding spec.
- `old_grad` and `new_grad` should share the same memory location.
- `new_state` and `state` should share the same memory location and the same sharding spec.

To share the same memory location, we have to set donate_invars (or aliasing) in XLA. To share the same sharding spec, we have to pass these constraints to the ILP solver.
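At the JAX level, buffer donation is exposed through `donate_argnums`; a minimal sketch (toy loss, array-valued `params`) of donating `old_grad` so XLA can alias its buffers for `new_grad`:

```python
from functools import partial
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Illustrative toy loss.
    return jnp.sum((params * batch) ** 2)

@partial(jax.jit, donate_argnums=(2,))  # donate old_grad's buffers
def accumulate_grad(params, micro_batch, old_grad):
    grad = jax.grad(loss_fn)(params, micro_batch)
    return jax.tree_util.tree_map(jnp.add, old_grad, grad)
```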
How to handle the `sync` argument of `accumulate_grad`?
- C1: Compile two versions of `accumulate_grad`: one does the sync and the other does not.
- C2: Compile one executable and use two branches in XLA.
- C3: Compile two executables, `accumulate_grad` and `sync_grad`, and dispatch them in our python runtime loop. However, this makes it impossible to overlap all-reduce and computation.
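For C2, the branch can be expressed with `lax.cond`, which lowers to an HLO conditional; a sketch assuming execution under a mapped axis named "devices":

```python
import jax
from jax import lax

def maybe_sync(grad, sync):
    # One executable: `sync` selects at run time whether the all-reduce runs.
    return lax.cond(sync,
                    lambda g: lax.psum(g, "devices"),
                    lambda g: g,
                    grad)
```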
How to handle other optimizations?
We want the memory of the gradients to be contiguous, which benefits all-reduce.
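One way to get contiguous gradient memory in JAX is to flatten the gradient pytree into a single buffer before the all-reduce; a sketch (the axis name is illustrative):

```python
import jax
from jax.flatten_util import ravel_pytree

def fused_all_reduce(grads, axis_name="devices"):
    # Concatenate all gradient leaves into one contiguous buffer so a single
    # all-reduce covers them, then restore the original pytree structure.
    flat, unravel = ravel_pytree(grads)
    return unravel(jax.lax.psum(flat, axis_name))
```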