GithubHelp home page GithubHelp logo

Comments (4)

yashk2810 avatar yashk2810 commented on June 13, 2024

Try specifying spmd_axis_name to vmap. That argument tells vmap along which axis of the mesh to shard the mapped-away (batch) dimension. Since your mesh has only one axis, try jax.vmap(jax.numpy.fft.fft, spmd_axis_name='i')(x), together with a with_sharding_constraint inside the vmapped function (i.e. around jnp.fft.fft).

See an example here:

jax/tests/pjit_test.py

Lines 753 to 771 in f5cc272

def testVMapShardingConstraintWithSpmdAxis(self):
    """Check that a sharding constraint inside vmap honours spmd_axis_name.

    The vmapped function constrains each unbatched row to ``P(None)``.
    With ``spmd_axis_name='x'``, the constraint recorded in the jaxpr must
    shard the mapped (leading) dimension along mesh axis 'x' rather than
    replicate it.

    NOTE(review): this is an excerpt (indentation restored from a flattened
    paste). The mesh providing axis 'x' is presumably set up outside these
    lines. The expected tile assignment [2, 1] over devices [0, 1] implies a
    2-device mesh; confirm against the original test file.
    """
    f = pjit(
        jax.vmap(
            lambda x: with_sharding_constraint(x, P(None)),
            spmd_axis_name='x',
        ),
        in_shardings=P('x'),
        out_shardings=P('x'),
    )
    x = jnp.arange(16 * 4).reshape((16, 4))
    jaxpr = jax.make_jaxpr(f)(x)
    # Single-element unpacking asserts there is exactly one top-level eqn
    # (the pjit call), whose inner jaxpr holds exactly one eqn (the constraint).
    pjit_eqn, = jaxpr.eqns
    constraint_eqn, = pjit_eqn.params['jaxpr'].eqns
    op = constraint_eqn.params['sharding']._to_xla_hlo_sharding(x.ndim)
    # Tiled [2, 1]: dimension 0 is split across devices 0 and 1, and
    # dimension 1 is not split; hence the sharding is not replicated.
    self.assertTrue(op.is_tiled())
    self.assertListEqual(op.tile_assignment_dimensions(), [2, 1])
    self.assertListEqual(op.tile_assignment_devices(), [0, 1])
    self.assertFalse(op_shardings.is_op_sharding_replicated(op))

from jax.

francois-rozet avatar francois-rozet commented on June 13, 2024

Hello @yashk2810, I tried your recommendation and there is still communication between devices (%all-gather). Besides, adding jax.lax.with_sharding_constraint everywhere the fft is used (which could be deep inside a codebase) might not always be possible.

import jax

# One-axis device mesh named 'i' over every available device. `spec` shards
# an array's leading dimension along that axis.
mesh = jax.sharding.Mesh(jax.devices(), 'i')
spec = jax.sharding.PartitionSpec('i')

distributed = jax.sharding.NamedSharding(mesh, spec)

# 16 rows of 1024 samples, split row-wise across the devices of `mesh`.
x = jax.numpy.ones((16, 1024))
x = jax.device_put(x, distributed)


def my_fft(x):
    """Return the FFT of one row, after constraining the row's sharding.

    Inside ``jax.vmap(..., spmd_axis_name='i')``, the ``PartitionSpec(None)``
    constraint describes only the unbatched row. vmap adds the mapped (row)
    dimension and shards it along mesh axis 'i'.
    """
    x = jax.lax.with_sharding_constraint(
        x, jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None))
    )
    x = jax.numpy.fft.fft(x)
    return x


# Map over rows. Input and output both keep the row-wise `distributed`
# sharding.
jit_vmap_fft = jax.jit(
    jax.vmap(my_fft, spmd_axis_name='i'),
    in_shardings=distributed,
    out_shardings=distributed,
)

# Print the compiled, SPMD-partitioned HLO so cross-device collectives
# (e.g. %all-gather) can be spotted.
print(jit_vmap_fft.lower(x).compile().as_text())
HloModule jit_my_fft, is_scheduled=true, entry_computation_layout={(f32[4,1024]{1,0})->c64[4,1024]{1,0}}, num_partitions=4

%fused_computation (param_1: c64[16,1024], param_1.3: u32[]) -> c64[4,1024] {
  %param_1 = c64[16,1024]{1,0} parameter(0)
  %param_1.3 = u32[] parameter(1)
  %convert.5 = s32[] convert(u32[] %param_1.3), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
  %constant_6 = s32[] constant(4), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
  %multiply.3 = s32[] multiply(s32[] %convert.5, s32[] %constant_6), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
  %constant_5 = s32[] constant(0), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
  ROOT %dynamic-slice.3 = c64[4,1024]{1,0} dynamic-slice(c64[16,1024]{1,0} %param_1, s32[] %multiply.3, s32[] %constant_5), dynamic_slice_sizes={4,1024}, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
}

%fused_computation.1 (param_0.1: f32[4,1024]) -> c64[4,1024] {
  %param_0.1 = f32[4,1024]{1,0} parameter(0)
  ROOT %convert.6 = c64[4,1024]{1,0} convert(f32[4,1024]{1,0} %param_0.1), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/convert_element_type[new_dtype=complex64 weak_type=False]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
}

ENTRY %main.13_spmd (param: f32[4,1024]) -> c64[4,1024] {
  %param = f32[4,1024]{1,0} parameter(0), sharding={devices=[4,1]<=[4]}, metadata={op_name="jit(my_fft)/jit(main)/sharding_constraint[sharding=GSPMDSharding({devices=[4,1]<=[4]}) resource_env=ResourceEnv(mesh=Mesh(), ()) unconstrained_dims=set()]" source_file="/tmp/ipykernel_1632380/926359114.py" source_line=10}
  %wrapped_convert.3 = c64[4,1024]{1,0} fusion(f32[4,1024]{1,0} %param), kind=kLoop, calls=%fused_computation.1, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/convert_element_type[new_dtype=complex64 weak_type=False]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
  %all-gather-start = (c64[4,1024]{1,0}, c64[16,1024]{1,0}) all-gather-start(c64[4,1024]{1,0} %wrapped_convert.3), channel_id=1, replica_groups={{0,1,2,3}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}, backend_config={"is_sync":true,"no_parallel_custom_call":false}
  %all-gather-done = c64[16,1024]{1,0} all-gather-done((c64[4,1024]{1,0}, c64[16,1024]{1,0}) %all-gather-start), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
  %fft.2 = c64[16,1024]{1,0} fft(c64[16,1024]{1,0} %all-gather-done), fft_type=FFT, fft_length={1024}, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
  %partition-id = u32[] partition-id(), metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
  ROOT %fusion = c64[4,1024]{1,0} fusion(c64[16,1024]{1,0} %fft.2, u32[] %partition-id), kind=kLoop, calls=%fused_computation, frontend_attributes={fingerprint_before_lhs="0ab18df49405e4dc6b846656f8bd58c2"}, metadata={op_name="jit(my_fft)/jit(main)/vmap(jit(fft))/jit(fft)/fft[fft_type=FftType.FFT fft_lengths=(1024,)]" source_file="/tmp/ipykernel_1632380/1374808900.py" source_line=9}
}

from jax.

francois-rozet avatar francois-rozet commented on June 13, 2024

@yashk2810 In the following example, is there a way to modify the function g so that it does not cause communication between devices, regardless of f?

# Question snippet: jit(vmap(f)) applied to an input sharded row-wise over a
# one-axis device mesh.
import jax

# Mesh axis 'i' spans every available device; `spec` shards an array's
# leading dimension along it.
mesh = jax.sharding.Mesh(jax.devices(), 'i')
spec = jax.sharding.PartitionSpec('i')

distributed = jax.sharding.NamedSharding(mesh, spec)

# 16 rows of 1024 samples; each device holds a slice of the rows.
x = jax.numpy.ones((16, 1024))
x = jax.device_put(x, distributed)

def f(x):  # arbitrary function (impossible to modify)
    """Stand-in for code that cannot be changed: an FFT of its input."""
    return jax.numpy.fft.fft(x)

@jax.jit
def g(x):  # my function (can modify)
    """Apply f to every row of x via vmap, under jit.

    Goal of the question: change only this function so that the compiled
    program needs no cross-device communication, whatever f does.
    """
    y = jax.vmap(f)(x)
    return y

print(g.lower(x).compile().as_text())  # should not contain %all-gather

from jax.

francois-rozet avatar francois-rozet commented on June 13, 2024

Using jax.experimental.shard_map.shard_map prevents device communication, but it requires g to use (and therefore know) the mesh and spec objects.

import jax

try:
    from jax.experimental.shard_map import shard_map
except ImportError:  # NOTE(review): newer JAX releases expose jax.shard_map instead
    from jax import shard_map

# One-axis device mesh named 'i'; `spec` shards the leading (row) dimension.
mesh = jax.sharding.Mesh(jax.devices(), 'i')
spec = jax.sharding.PartitionSpec('i')

distributed = jax.sharding.NamedSharding(mesh, spec)

x = jax.numpy.ones((16, 1024))
x = jax.device_put(x, distributed)

def f(x):  # arbitrary function (impossible to modify)
    return jax.numpy.fft.fft(x)

@jax.jit
def g(x):
    """Apply ``f`` independently to each device's block of rows.

    ``shard_map`` passes ``f`` the local shard described by ``spec`` (rows
    split along mesh axis 'i'). It reassembles the results with the same
    spec, so no cross-device collective is required. This equals ``f(x)``
    only because ``f`` treats each row independently (FFT along the last
    axis). Drawback: ``g`` must know ``mesh`` and ``spec``.
    """
    y = shard_map(
        f,  # passed positionally so it works with either import path
        mesh=mesh,
        in_specs=spec,
        out_specs=spec,
    )(x)  # `f` takes exactly one argument: the local block of `x`

    return y

print(g.lower(x).compile().as_text())  # does not contain %all-gather

from jax.

Related Issues (20)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Something interesting about the web. A new door to the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Something interesting about games, to make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.