Support IO effect in vmap-of-while (jax, OPEN)

achsvg commented on June 14, 2024
Support IO effect in vmap-of-while.

from jax.

Comments (11)

mattjj commented on June 14, 2024

Thanks for the suggestion!

But actually we can't support that: the point of io_callback is that it supports arbitrary Python effects (other than exceptions...), and effects can't be automatically parallelized. That is, vmap-of-io_callback is intentionally unsupported, let alone vmap-of-while_loop-of-io_callback.

For example, take this function:

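import jax
import numpy as np
from jax.experimental import io_callback

x_saved: float  # host-side "register" shared by the two callbacks

def set_x(x):
  global x_saved
  x_saved = float(x)

def get_x():
  # Return a float32 scalar so it matches result_shape_dtypes below.
  return np.float32(x_saved)

@jax.jit
def f(x):
  io_callback(set_x, None, x)  # stash x on the host
  x = io_callback(get_x,       # read it back
                  result_shape_dtypes=jax.ShapeDtypeStruct((), 'float32'))
  return x

print(f(3.14))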

It works fine, and indeed it should, according to io_callback's semantics! But we can't make it work with vmap: we have no way to save a batch of intermediates.

Here's another example:

import jax
from jax.experimental import io_callback

x_lst = []
def append_x(x):
  global x_lst
  x_lst.append(x)

@jax.jit
def f(x):
  io_callback(append_x, None, x)
  io_callback(append_x, None, 2 * x)

f(3.14)
jax.effects_barrier()  # wait for the asynchronous callbacks to finish
print(x_lst)

If we run jax.vmap(f)(xs), in what order should the results be saved in x_lst? Even if we defined a vmap ordering for each separate io_callback, running vmap(f)(xs) wouldn't have the same side effects as running [f(x) for x in xs], which appends x and then 2 * x for each element in turn.
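To make the ordering question concrete, here is a plain-Python sketch (no JAX involved) of two candidate orders for the side effects of f over xs = [1.0, 2.0, 3.0]. The helper names are hypothetical, and per_callsite_order assumes a batching rule that runs each io_callback call site once over the whole batch before moving on to the next one.

```python
def sequential_order(xs):
    """Side effects of running f(x) once per element, as in [f(x) for x in xs]."""
    log = []
    for x in xs:
        log.append(x)        # first io_callback in f
        log.append(2 * x)    # second io_callback in f
    return log


def per_callsite_order(xs):
    """Hypothetical batched rule: each call site processes the whole batch in turn."""
    log = []
    log.extend(xs)                  # first call site, every batch element
    log.extend(2 * x for x in xs)   # second call site, every batch element
    return log


xs = [1.0, 2.0, 3.0]
print(sequential_order(xs))    # [1.0, 2.0, 2.0, 4.0, 3.0, 6.0]
print(per_callsite_order(xs))  # [1.0, 2.0, 3.0, 2.0, 4.0, 6.0]
```

Both logs contain the same values, but in a different order, so a program that cares about the order of its effects would observe a difference.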

These are toy examples, but you can imagine real applications where we want to understand the ordering, e.g. if we are using io_callback to offload unbounded-size residuals for autodiff.

So the promise of io_callback, to support arbitrary Python side effects, is too strong to allow vmap-style vectorization/parallelization. We need more structure from the user.

There is a bug here, which is that we raised a NotImplementedError: that kinda implies that we would someday implement it, but really we want to raise an exception that this just isn't allowed (and never will be with io_callback).

If you say a bit more about the problem you're trying to solve, maybe we can come up with some alternative to io_callback which does support vmap.

What do you think?

achsvg commented on June 14, 2024

This makes a lot of sense, thanks for the explanation!

In my case, however, it should work, because each row of the vmap uses io_callback to dump intermediate data to a different NumPy array. We do this because our result arrays have unpredictable sizes.

Do you see another way of achieving this? Would it make sense to allow vmap-of-io_callback with an unsafe flag perhaps?

mattjj commented on June 14, 2024

> Do you see another way of achieving this? Would it make sense to allow vmap-of-io_callback with an unsafe flag perhaps?

Yeah, I think that's roughly what we'd need: a way to promise extra structure, like "any side effects here are commutative, so the order doesn't matter." That could be exposed as a different callback API (commutative_callback or something) or as an option. Since the space of extra structure may be combinatorial, it'd probably be better to have a single API where the user can express whatever structure-promises they want, and that lets the callback work in more places.
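A plain-Python sketch of what a "commutative" promise would buy, assuming the per-row setup achsvg describes (one buffer per batch row, one call per row). Every ordering of the calls leaves the per-row buffers in the same final state, while appending to a single shared list does not. The call list and helper names are hypothetical.

```python
from itertools import permutations

# Hypothetical callback invocations, one per batch row: (row, value).
calls = [(0, 1.0), (1, 10.0), (2, 100.0)]


def per_row_state(order):
    """Each row appends only to its own buffer, so cross-row order is irrelevant."""
    buffers = {}
    for row, value in order:
        buffers.setdefault(row, []).append(value)
    return buffers


def shared_state(order):
    """All rows append to one list, so the final state depends on the order."""
    log = []
    for _, value in order:
        log.append(value)
    return log


orders = list(permutations(calls))
per_row_results = [per_row_state(o) for o in orders]
shared_results = [shared_state(o) for o in orders]
print(all(r == per_row_results[0] for r in per_row_results))  # True
print(len({tuple(r) for r in shared_results}))                # 6
```

Dict equality ignores insertion order, which is exactly the "order doesn't matter" property a commutative promise would assert.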

Actually, there's an ordered option on io_callback now... I wonder if that already does what you want?
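For reference, a minimal sketch of io_callback's ordered=True option under jax.jit alone, with no vmap involved: ordering asks JAX to run the two callbacks in program order. This only exercises the unbatched case; how ordering should interact with batching is exactly the open question in this thread.

```python
import jax
from jax.experimental import io_callback

x_lst = []

def append_x(x):
    # Runs on the host with a concrete value; store it as a Python float.
    x_lst.append(float(x))

@jax.jit
def f(x):
    # ordered=True sequences these two side effects in program order.
    io_callback(append_x, None, x, ordered=True)
    io_callback(append_x, None, 2 * x, ordered=True)

f(3.14)
jax.effects_barrier()  # wait for pending callbacks before reading x_lst
print(x_lst)
```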

mattjj commented on June 14, 2024

Hrm, this seems pretty busted:

import jax
import jax.numpy as jnp
from jax.experimental import io_callback

x_lst = []
def append_x(x):
  global x_lst
  x_lst.append(x)

@jax.jit
def f(x):
  io_callback(append_x, None, x, ordered=False)
  io_callback(append_x, None, 2 * x, ordered=False)

f(3.14)
jax.effects_barrier()
print(x_lst)
jax.vmap(f)(jnp.arange(3.))
jax.effects_barrier()
print(x_lst)
[Array(3.14, dtype=float32), Array(6.28, dtype=float32)]
[Array(3.14, dtype=float32), Array(6.28, dtype=float32)]

mattjj commented on June 14, 2024

Yeah wow, it's because the io_callback batching rule is calling the pure_callback_batching_rule, but the resulting call is getting DCE'd (dead-code eliminated) because it claims to be pure...

mattjj commented on June 14, 2024

#20725 will fix that issue...

In your vmap-of-while_loop-of-io_callback, is the predicate of your loop batched? If so, there's another issue: when we batch a while_loop we rely on being able to run, but mask off the results of, loop iterations for elements in the batch that have already finished. That is, when we vmap a while_loop(cond_fun, body_fun, init_val), we generate a computation like

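import jax.numpy as jnp
from jax import lax, vmap

# cond_fun, body_fun: the user's unbatched loop functions.
# Sketch only: assumes a single-array carry of shape (batch,).
def vmapped_loop(batched_init_val):
  def batched_cond(carry):
    return vmap(cond_fun)(carry).any()  # if any batch elts need more iters, we keep going

  def batched_body(carry):
    run_step = vmap(cond_fun)(carry)    # which batch elements are still running
    new_carry = vmap(body_fun)(carry)   # body runs for EVERY batch element
    return jnp.where(run_step, new_carry, carry)  # finished elements keep their carry

  return lax.while_loop(batched_cond, batched_body, batched_init_val)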

That works fine so long as we are okay with running loop iterations for the batch elements that have already terminated (i.e. those for which cond_fun returns False), as when we know body_fun is functionally pure (indeed, noticing that your computer is spending extra time, or generating extra heat, is not a functionally pure thing to do...). But in the presence of side effects, we can't play this trick safely, even if you promise your side effects are commutative! (We'd need idempotence too, or something.)
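The hazard can be emulated in plain Python. This sketch assumes a hypothetical loop body whose side effect is logging (element, step) and whose state update is step + 1, with element i stopping once its step reaches limits[i]. It follows the batched scheme above: keep looping while any element is unfinished, run the body for every element, and mask only the state update.

```python
def run_masked(limits):
    """Emulate vmap-of-while_loop with a batched predicate."""
    log = []                     # host-side effect sink
    state = [0] * len(limits)    # per-element loop counter (the carry)
    while any(s < lim for s, lim in zip(state, limits)):    # batched_cond
        run_step = [s < lim for s, lim in zip(state, limits)]
        for i, s in enumerate(state):
            log.append((i, s))   # the body's side effect fires for EVERY element
        # Equivalent of jnp.where(run_step, new_carry, carry).
        state = [s + 1 if r else s for s, r in zip(state, run_step)]
    return state, log


def run_sequential(limits):
    """Reference: run each element's loop on its own."""
    log = []
    for i, lim in enumerate(limits):
        s = 0
        while s < lim:
            log.append((i, s))
            s += 1
    return log


state, masked_log = run_masked([1, 3])
print(state)                   # [1, 3]
print(masked_log)              # [(0, 0), (1, 0), (0, 1), (1, 1), (0, 1), (1, 2)]
print(run_sequential([1, 3]))  # [(0, 0), (1, 0), (1, 1), (1, 2)]
```

The final state matches the sequential run, but element 0's side effect fired twice more after it had finished, and with the same frozen input both times. Reordering those calls cannot remove them, which is why commutativity alone is not enough.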

That's only an issue if the result of cond_fun ends up being batched though.

WDYT?

mattjj commented on June 14, 2024

(The NotImplementedError currently gets raised even if the predicate is not batched. That's something we could relax.)

mattjj commented on June 14, 2024

If your predicate is unbatched, then #20726 should help! If it is batched, we have to figure something else out.

achsvg commented on June 14, 2024

Our predicate is batched. But in our case I believe it shouldn't be an issue, because we can post-process our result dump to remove the results that have run too far.
To give you a bit more detail on my use case: I'm trying to vmap a batch of simulations that each produce one sample result per iteration of the while loop. The while loop iterates through time, and the condition for stopping a simulation depends on that simulation's initial settings. So I have an io_callback in the while_loop that dumps a fixed-size JAX array buffer to a NumPy array.
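A plain-Python sketch of the post-processing described above, assuming each dumped record carries its simulation index and loop step, and that each simulation's stopping step is known afterwards. Masked iterations re-run the body on a frozen carry, so their records show up at the stopping step, possibly several times. The record layout, clean_dump, and stop_steps are hypothetical.

```python
def clean_dump(records, stop_steps):
    """Drop records from masked iterations and group samples per simulation.

    records: iterable of (sim_id, step, sample) tuples, in arrival order.
    stop_steps: stop_steps[sim_id] is the step at which that simulation's
        while_loop condition first became False.
    """
    per_sim = {}
    for sim_id, step, sample in records:
        if step < stop_steps[sim_id]:  # a real iteration: keep it
            # Keying by step also collapses any duplicate deliveries.
            per_sim.setdefault(sim_id, {})[step] = sample
    # Return each simulation's samples ordered by step.
    return {sim: [steps[k] for k in sorted(steps)] for sim, steps in per_sim.items()}


records = [(0, 0, "a0"), (1, 0, "b0"), (0, 1, "junk"),
           (1, 1, "b1"), (0, 1, "junk"), (1, 2, "b2")]
print(clean_dump(records, {0: 1, 1: 3}))  # {0: ['a0'], 1: ['b0', 'b1', 'b2']}
```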

mattjj commented on June 14, 2024

Have you tried using jax.debug.callback? It might already be what we need ('no error checks, just like a raw print statement').
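A minimal sketch of that approach for a batch of simulations, using JAX's public debug-callback API, jax.debug.callback. It relies on debug callbacks being allowed inside a while_loop with a batched predicate, which matches achsvg's report in this thread. Because masked iterations still run the body, the sketch passes the loop condition into the callback as an "active" flag so the host can ignore those calls. The names record, simulate, and records, and the toy accumulation, are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

records = set()  # host-side sink: (sim_id, step) for every real iteration

def record(sim_id, step, active):
    # Masked iterations of the batched loop arrive with active == False.
    if bool(active):
        records.add((int(sim_id), int(step)))

def simulate(sim_id, limit):
    def cond_fun(carry):
        step, _ = carry
        return step < limit

    def body_fun(carry):
        step, acc = carry
        jax.debug.callback(record, sim_id, step, step < limit)
        return step + 1, acc + step

    _, acc = lax.while_loop(cond_fun, body_fun, (jnp.int32(0), jnp.int32(0)))
    return acc

out = jax.vmap(simulate)(jnp.arange(3), jnp.array([1, 3, 2]))
jax.effects_barrier()  # make sure every host callback has run
print(out)             # [0 3 1]
```

Unlike io_callback, a debug callback comes with no ordering or exactly-once guarantees, so the host side should tolerate extra or reordered calls, as the active flag and the set do here.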

achsvg commented on June 14, 2024

This seems to do the trick. Thanks!
