Comments (11)
Thanks for the suggestion!
But actually we can't support that: the point of `io_callback` is that it supports arbitrary Python effects (other than exceptions...), and effects can't be automatically parallelized. That is, `vmap`-of-`io_callback` is intentionally unsupported, let alone `vmap`-of-`while_loop`-of-`io_callback`.
For example, take this function:
```python
import jax
from jax.experimental import io_callback

x_saved: float

def set_x(x):
    global x_saved
    x_saved = float(x)

def get_x():
    return x_saved

@jax.jit
def f(x):
    io_callback(set_x, None, x)
    x = io_callback(get_x, result_shape_dtypes=jax.ShapeDtypeStruct((), 'float32'))
    return x

print(f(3.14))
```
It works fine, and indeed it should, according to `io_callback`'s semantics! But we can't make it work with `vmap`: we have no way to save a batch of intermediates.
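To make the "no way to save a batch of intermediates" point concrete, here is a plain-Python sketch (not JAX's actual batching machinery) of what a naive batching rule for the example above would do: hand the callback a whole batch, which the scalar-shaped `x_saved` slot can't hold.

```python
# Hypothetical sketch: what a naive vmap rule for set_x would
# effectively do. x_saved was declared to hold a single float, so
# float() on a whole batch of values fails.
x_saved = None

def set_x(x):
    global x_saved
    x_saved = float(x)   # float() raises TypeError on a list/batch

batch = [1.0, 2.0, 3.0]
try:
    set_x(batch)         # the "batched" callback invocation
except TypeError as err:
    print("cannot store a batch in a scalar slot:", err)
```

The real failure mode in JAX is more subtle, but the shape mismatch is the same: the callback's contract was written for one element at a time.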
Here's another example:
```python
import jax
from jax.experimental import io_callback

x_lst = []

def append_x(x):
    global x_lst
    x_lst.append(x)

@jax.jit
def f(x):
    io_callback(append_x, None, x)
    io_callback(append_x, None, 2 * x)

f(3.14)
print(x_lst)
```
If we run `jax.vmap(f)(xs)`, in what order should the results be saved in `x_lst`? Even if we defined a `vmap` ordering for each separate `io_callback`, running `vmap(f)(xs)` wouldn't have the same side effect as `jnp.stack([f(x) for x in xs])`.
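For reference, here is a plain-Python sketch (no JAX) of the sequential semantics that comparison appeals to: in `jnp.stack([f(x) for x in xs])`, each `f(x)` runs to completion before the next, so the side effects interleave per element in one well-defined order.

```python
# Sequential reference semantics: per-element interleaving of the
# two "callbacks" inside f, element by element.
x_lst = []

def f(x):
    x_lst.append(x)      # first callback for this element
    x_lst.append(2 * x)  # second callback for this element
    return x

results = [f(x) for x in [0.0, 1.0, 2.0]]
print(x_lst)  # [0.0, 0.0, 1.0, 2.0, 2.0, 4.0]
```

A vectorized execution has no obligation to reproduce this interleaving, which is exactly the ambiguity the comment describes.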
These are toy examples, but you can imagine real applications where we want to understand the ordering, e.g. if we are using `io_callback` for offloading of unbounded-size residuals for autodiff.
So the promise of `io_callback`, to support arbitrary Python side effects, is too strong to allow `vmap`-style vectorization/parallelization. We need more structure from the user.
There is a bug here, which is that we raised a `NotImplementedError`: that kinda implies that we would someday implement it, but really we want to raise an exception saying this just isn't allowed (and never will be with `io_callback`).
If you say a bit more about the problem you're trying to solve, maybe we can come up with some alternative to `io_callback` which does support `vmap`.
What do you think?
This makes a lot of sense, thanks for the explanation!
In my case, however, it should work, because each row of the vmap uses `io_callback` to dump intermediate data to a different NumPy array. We do this because our result arrays have unpredictable sizes.
Do you see another way of achieving this? Would it make sense to allow `vmap`-of-`io_callback` with an `unsafe` flag, perhaps?
> Do you see another way of achieving this? Would it make sense to allow vmap-of-io_callback with an unsafe flag perhaps?
Yeah, I think that's roughly what we'd need: a way to promise extra structure, like "any side effects here are commutative, so the order doesn't matter." That could be exposed as a different callback API (`commutative_callback` or something) or an option. Since the space of extra structure may be combinatorial, it'd probably be better to have a single API where the user can express whatever structure promises they want, and that lets the callback work in more places.
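To illustrate the kind of promise a hypothetical `commutative_callback` would extract (this is a plain-Python model, not a proposed JAX API): if the side effect is order-insensitive, every execution order `vmap` might choose produces the same final state.

```python
import itertools

# A commutative side effect: accumulating into a running sum.
# Addition commutes, so the order of callback invocations is
# irrelevant to the final state.
def make_accumulator():
    state = {"total": 0.0}
    def commutative_effect(x):
        state["total"] += x
    return state, commutative_effect

for order in itertools.permutations([1.0, 2.0, 3.0]):
    state, effect = make_accumulator()
    for x in order:
        effect(x)
    assert state["total"] == 6.0  # same result for every ordering
print("all orderings agree")
```

Appending to a list, as in the examples above, is *not* commutative, which is why those examples can't be vectorized without extra promises.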
Actually, there's an `ordered` option on `io_callback` now... I wonder if that already does what you want?
Hrm this seems pretty busted:
```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

x_lst = []

def append_x(x):
    global x_lst
    x_lst.append(x)

@jax.jit
def f(x):
    io_callback(append_x, None, x, ordered=False)
    io_callback(append_x, None, 2 * x, ordered=False)

f(3.14)
print(x_lst)
jax.vmap(f)(jnp.arange(3.))
jax.effects_barrier()
print(x_lst)
```

which prints

```
[Array(3.14, dtype=float32), Array(6.28, dtype=float32)]
[Array(3.14, dtype=float32), Array(6.28, dtype=float32)]
```
Yeah wow, it's because the `io_callback` batching rule is calling the `pure_callback_batching_rule`, but that is getting DCE'd because it claims to be pure...
#20725 will fix that issue...
In your `vmap`-of-`while_loop`-of-`io_callback`, is the predicate of your loop batched? If so, there's another issue: when we batch a `while_loop`, we rely on being able to run, but mask off the results of, loop iterations for elements in the batch that have already finished. That is, when we `vmap` a `while_loop(cond_fun, body_fun, init_val)`, we generate a computation like:
```python
def vmapped_loop(batched_init_val):
    def batched_cond(carry):
        return vmap(cond_fun)(carry).any()  # if any batch elts need more iters, we keep going
    def batched_body(carry):
        run_step = vmap(cond_fun)(carry)
        new_carry = vmap(body_fun)(carry)
        return jnp.where(run_step, new_carry, carry)
    return while_loop(batched_cond, batched_body, batched_init_val)
```
That works fine so long as we are okay with running loop iterations for the batch elements that have already terminated (i.e. for which `cond_fun` produces a `False`), like when we know `body_fun` is functionally pure (indeed, noticing that your computer is spending extra time, or generating extra heat, is not a functionally pure thing to do...). But in the presence of side effects, we can't play this trick safely, even if you promise your side effects are commutative! (We'd need idempotence too, or something.)
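A pure-Python model (no JAX, hypothetical helper names) of the masking trick above may make the hazard clearer: finished elements have their new carry masked off, but `body_fun` still *runs* for them on every step.

```python
# Model of the batched while_loop: keep looping until every batch
# element's predicate is False, masking off carries for finished
# elements (the role jnp.where plays above).
def masked_batched_loop(cond_fun, body_fun, carries):
    while any(cond_fun(c) for c in carries):            # batched_cond
        run_step = [cond_fun(c) for c in carries]        # per-element predicate
        new = [body_fun(c) for c in carries]             # body runs for EVERY element
        carries = [n if r else c
                   for r, n, c in zip(run_step, new, carries)]
    return carries

# Each element counts up to its own limit; results are correct...
out = masked_batched_loop(lambda c: c[0] < c[1],
                          lambda c: (c[0] + 1, c[1]),
                          [(0, 1), (0, 3), (0, 5)])
print(out)  # [(1, 1), (3, 3), (5, 5)]
# ...but body_fun was still invoked for already-finished elements,
# which is exactly why side effects in body_fun are unsafe here.
```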
That's only an issue if the result of `cond_fun` ends up being batched, though.
WDYT?
(The `NotImplementedError` currently gets raised even if the predicate is not batched. That's something we could relax.)
If your predicate is unbatched, then #20726 should help! If it is batched, we have to figure something else out.
Our predicate is batched. But in our case I believe it shouldn't be an issue, because we can post-process our result dump to remove the results from iterations that ran too far.
To give you a bit more detail on my use case: I'm trying to vmap a batch of simulations that produce one sample result per iteration of the while loop. The while loop iterates through time, and the condition for stopping a simulation depends on each simulation's initial settings. So I have an io_callback that dumps a fixed-size JAX array buffer to a NumPy array inside the while_loop.
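The post-processing mentioned above might look like the following sketch (plain Python, hypothetical names such as `trim_dumps` and `stop_steps`): since the masked batched loop makes every simulation run as many iterations as the slowest one, samples dumped after a simulation's real stopping point can be trimmed afterwards.

```python
# Hypothetical post-processing: dumps[i] is the list of samples
# dumped by simulation i; stop_steps[i] is the iteration at which
# that simulation actually finished.
def trim_dumps(dumps, stop_steps):
    return [d[:n] for d, n in zip(dumps, stop_steps)]

dumps = [[10, 11, 12, 13], [20, 21, 22, 23], [30, 31, 32, 33]]
stop_steps = [2, 4, 1]  # masking made every simulation run 4 steps
print(trim_dumps(dumps, stop_steps))  # [[10, 11], [20, 21, 22, 23], [30]]
```

This only works if the callback's extra invocations are harmless apart from producing extra data, which matches the idempotence caveat raised earlier in the thread.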
Have you tried using `jax.debug.debug_callback`? It might already be what we need ('no error checks, just like a raw print statement').
This seems to do the trick. Thanks!