Comments (4)
Try specifying `spmd_axis_name` to `vmap`. That argument tells vmap which axis of the mesh the mapped-away dimension should be sharded over. Since you only have one mesh dimension, try `jax.vmap(jax.numpy.fft.fft, spmd_axis_name='i')(x)`,
combined with a `with_sharding_constraint` applied around the `jnp.fft.fft` call.
See an example here:
Lines 753 to 771 in f5cc272
from jax.
Hello @yashk2810, I tried your recommendation and there is still communication between devices (`%all-gather`). Besides, adding `jax.lax.with_sharding_constraint` everywhere `fft` is used (which could be deep inside a codebase) might not always be possible.
import jax

# Reproduction: vmap with ``spmd_axis_name`` plus a sharding constraint inside
# the mapped function. The compiled HLO printed below still contains an
# all-gather around the FFT.

# 1-D mesh over all available devices, with a single axis named 'i'.
mesh = jax.sharding.Mesh(jax.devices(), 'i')
# Shard the leading (batch) axis over mesh axis 'i'.
spec = jax.sharding.PartitionSpec('i')
distributed = jax.sharding.NamedSharding(mesh, spec)

# NOTE(review): 16 rows must divide evenly across the device count for this
# sharding to be valid.
x = jax.numpy.ones((16, 1024))
x = jax.device_put(x, distributed)


def my_fft(x):
    """FFT of a single vmapped example after constraining its sharding.

    Inside ``jax.vmap(..., spmd_axis_name='i')`` this function sees one
    example. ``PartitionSpec(None)`` leaves that example's axis unsharded.
    The mapped batch axis is the one vmap assigns to mesh axis 'i'.
    """
    x = jax.lax.with_sharding_constraint(
        x, jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))
    )
    x = jax.numpy.fft.fft(x)
    return x


# Jit the batched FFT with inputs and outputs both sharded along 'i'.
jit_vmap_fft = jax.jit(
    jax.vmap(my_fft, spmd_axis_name='i'),
    in_shardings=distributed,
    out_shardings=distributed,
)
print(jit_vmap_fft.lower(x).compile().as_text())
HloModule jit_my_fft, is_scheduled=true, entry_computation_layout={(f32[4,1024]{1,0})->c64[4,1024]{1,0}}, num_partitions=4
%fused_computation (param_1: c64[16,1024], param_1.3: u32[]) -> c64[4,1024] {
%param_1 = c64[16,1024]{1,0} parameter(0)
%param_1.3 = u32[] parameter(1)
%convert.5 = s32[] convert(u32[] %param_1.3), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
%constant_6 = s32[] constant(4), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
%multiply.3 = s32[] multiply(s32[] %convert.5, s32[] %constant_6), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
%constant_5 = s32[] constant(0), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
ROOT %dynamic-slice.3 = c64[4,1024]{1,0} dynamic-slice(c64[16,1024]{1,0} %param_1, s32[] %multiply.3, s32[] %constant_5), dynamic_slice_sizes={4,1024}, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
}
%fused_computation.1 (param_0.1: f32[4,1024]) -> c64[4,1024] {
%param_0.1 = f32[4,1024]{1,0} parameter(0)
ROOT %convert.6 = c64[4,1024]{1,0} convert(f32[4,1024]{1,0} %param_0.1), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/convert_element_type[new_dtype=complex64 weak_type=False]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
}
ENTRY %main.13_spmd (param: f32[4,1024]) -> c64[4,1024] {
%param = f32[4,1024]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}, metadata={op_name="jit(my_fft)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[4,1]<=[4]}) resource_env=ResourceEnv(mesh=Mesh(), ()) unconstrained_dims=set()]" source_file="/tmp/ipykernel_1632380/926359114.py" source_line=10}
%wrapped_convert.3 = c64[4,1024]{1,0} fusion(f32[4,1024]{1,0} %param), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/convert_element_type[new_dtype=complex64 weak_type=False]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
%all-gather-start = (c64[4,1024]{1,0}, c64[16,1024]{1,0}) all-gather-start(c64[4,1024]{1,0} %wrapped_convert.3), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}, backend_config={"is_sync":true,"no_parallel_custom_call":false}
%all-gather-done = c64[16,1024]{1,0} all-gather-done((c64[4,1024]{1,0}, c64[16,1024]{1,0}) %all-gather-start), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
%fft.2 = c64[16,1024]{1,0} fft(c64[16,1024]{1,0} %all-gather-done), fft_type=FFT, fft_length={1024}, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
%partition-id = u32[] partition-id(), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
ROOT %fusion = c64[4,1024]{1,0} fusion(c64[16,1024]{1,0} %fft.2, u32[] %partition-id), kind=kLoop, calls=%fused_computation, frontend_attributes={fingerprint_before_lhs="0ab18df49405e4dc6b846656f8bd58c2"}, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
}
from jax.
@yashk2810 In the following example, is there a way to modify the function `g` such that it does not lead to communication between devices, regardless of `f`?
import jax

# One-axis mesh named 'i' spanning every device; rows of x are split over it.
mesh = jax.sharding.Mesh(jax.devices(), 'i')
spec = jax.sharding.PartitionSpec('i')
distributed = jax.sharding.NamedSharding(mesh, spec)
x = jax.device_put(jax.numpy.ones((16, 1024)), distributed)


def f(x):  # arbitrary function (impossible to modify)
    """Third-party code: a plain FFT with no sharding annotations."""
    return jax.numpy.fft.fft(x)


@jax.jit
def g(x):  # my function (can modify)
    """Map the unmodifiable ``f`` over the leading axis of ``x``."""
    return jax.vmap(f)(x)


compiled = g.lower(x).compile()
print(compiled.as_text())  # should not contain %all-gather
from jax.
Using `jax.experimental.shard_map.shard_map` prevents device communication, but it requires using (and knowing) the `mesh` and `spec` objects within `g`.
import jax

try:
    # Newer JAX releases expose shard_map at the top level.
    from jax import shard_map
except ImportError:
    # Older releases only ship the experimental module.
    from jax.experimental.shard_map import shard_map

# One-axis mesh named 'i' spanning every device; rows of x are split over it.
mesh = jax.sharding.Mesh(jax.devices(), 'i')
spec = jax.sharding.PartitionSpec('i')
distributed = jax.sharding.NamedSharding(mesh, spec)
x = jax.numpy.ones((16, 1024))
x = jax.device_put(x, distributed)


def f(x):  # arbitrary function (impossible to modify)
    """Third-party code: a plain FFT with no sharding annotations."""
    return jax.numpy.fft.fft(x)


@jax.jit
def g(x):
    """Apply ``f`` independently to each device's shard of ``x``.

    ``shard_map`` hands ``f`` only the local block of rows (split along mesh
    axis 'i' per ``spec``), and reassembles the outputs with the same spec.
    No cross-device collective is needed, because the FFT acts on the
    unsharded last axis.
    """
    return shard_map(
        f,  # positional: newer shard_map makes ``f`` positional-only
        mesh=mesh,
        in_specs=spec,
        out_specs=spec,
    )(x)  # was ``(A, x)``: ``A`` is undefined and ``f`` takes a single argument


print(g.lower(x).compile().as_text())  # does not contain %all-gather
from jax.
Related Issues (20)
- ⚠️ Nightly upstream-dev CI failed ⚠️ HOT 1
- jnp.fft.ifft imprecision for GPU
- psum_scatter does not allow scatter_dimension to be negative HOT 3
- spsolve exits with error when inverting matrix sum HOT 4
- jax.random seems to have unnecessary buffer allocations on stack HOT 6
- buggy interaction: remat, automatic partitioning, and unsafe `rbg`-based RNGs
- Seeking guidance for landing spot of `scipy.stats.levy_stable` in Jax
- dynamic config scope under `jit` doesn't change partitionable threefry behavior
- Unexpected speedup from wrapping function call in trivial jax.lax.cond statement
- Persistent compilation cache does not work HOT 2
- ROCm 6.1, 7900 xtx: bfloat16 support not enabled? HOT 1
- Remaining deprecations for array API compliance
- Crash in `eval_jaxpr` with 0.4.27 HOT 17
- '+ptx84' is not a recognized feature for this target (ignoring feature) HOT 13
- linalg.solve produces NaNs on GPU, but not on CPU
- [pallas] Interpreter mismatch for masked OOB indexing
- Strided indexing turns into gather HOT 3
- Compilation cache does not work with custom partitioning HOT 2
- Cannot pass token to custom primitive when using explicit device placement HOT 2
- Frequent Segfault crashes with v0.4.28 HOT 3
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from jax.