Comments (10)
There is a preallocated buffer on the XLA side, but it is not currently passed to the `pure_callback`.
@hawkinsp should we have an API for this, wdyt?
Hey @Joshuaalbert, sorry for the silence. Here is a quick update:
- @yashk2810 is working on a proposal to allow "write once" mutation for `jax.Array`s, which are currently immutable (and thus cannot be used as output buffers);
- I'm exploring a parallel idea: `dlpack.callback`, a new callback API using DLPack capsules for both inputs and outputs.
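For context on why `jax.Array`s cannot currently serve as output buffers: they reject in-place writes, and updates go through the functional `.at[...]` API, which returns a new array instead. A minimal sketch using plain JAX (no callbacks involved):

```python
import jax.numpy as jnp

x = jnp.zeros(3)

# jax.Array is immutable: item assignment raises a TypeError.
try:
    x[0] = 1.0
    mutated = True
except TypeError:
    mutated = False

# The functional update returns a fresh array and leaves `x` untouched.
y = x.at[0].set(1.0)

print(mutated)
print(x)
print(y)
```

This is why a callback cannot simply be handed a `jax.Array` to fill in today.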
Another quick update: I prototyped `dlpack.callback`, but after discussing it with a few JAX team members, I decided not to move forward with it, as JAX already has too many callback APIs.
Instead, the plan is to change the existing callback APIs to support `mutable_results=`. I am waiting on a few changes in XLA FFI, but once they land, it should be fairly straightforward to implement this.
No news yet, in the sense that none of my attempts have landed. Hopefully in the coming weeks :)
@Joshuaalbert, JAX doesn't currently offer a public API for this (customizing multiple transforms for a single callable), but it's on our radar. I'd say this is off topic for this issue thread, but feel free to open another issue with more info about your use cases for defining both a JVP and a VJP for one callback!
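For reference, what JAX does support today is attaching a `jax.custom_jvp` rule to a callback-based function. Reverse mode is then derived by transposing that JVP, which works as long as the output tangent is a linear JAX expression in the input tangent; it does not let you supply a separate hand-written VJP, which is the gap being discussed. A minimal sketch; `host_sin`, `_host_sin`, `_host_cos` and `_spec` are hypothetical helpers backed by NumPy:

```python
import jax
import jax.numpy as jnp
import numpy as np


def _spec(x):
    # Output has the same shape and dtype as the input.
    return jax.ShapeDtypeStruct(x.shape, x.dtype)


def _host_sin(x):
    return np.sin(np.asarray(x))


def _host_cos(x):
    return np.cos(np.asarray(x))


@jax.custom_jvp
def host_sin(x):
    # Primal value computed on the host via NumPy.
    return jax.pure_callback(_host_sin, _spec(x), x)


@host_sin.defjvp
def host_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = host_sin(x)
    # The derivative cos(x) is also computed on the host...
    dydx = jax.pure_callback(_host_cos, _spec(x), x)
    # ...but the tangent is linear in `t` and built from JAX ops, so JAX
    # can transpose it to get reverse-mode (VJP) derivatives.
    return y, dydx * t


x = jnp.array([0.0, 1.0, 2.0])
fwd = jax.jvp(host_sin, (x,), (jnp.ones_like(x),))[1]
rev = jax.grad(lambda v: host_sin(v).sum())(x)
print(fwd, rev)
```

Both `fwd` and `rev` should approximate `cos(x)`, with the VJP coming from transposition rather than a user-defined rule.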
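Also, is it safe to reuse the buffer of the input values, e.g. like this?

```python
import jax
import numpy as np


def f(x):  # enclosing function; the name `f` is a placeholder
    def _fn(x):
        # Try to square the input in place, reusing its buffer as the output.
        np.square(x, out=x)
        return x

    result_shape_dtype = jax.ShapeDtypeStruct(
        shape=np.shape(x),
        dtype=x.dtype
    )
    return jax.pure_callback(_fn, result_shape_dtype, x, vectorized=True)
```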
EDIT: That would be a nope, as it fails with `jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: output array is read-only`.
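A workaround consistent with that error is to leave the input untouched and have the callback return a host array it allocates itself. This still costs an extra output-sized allocation on the host, which is the cost that passing XLA's preallocated buffer to the callback would avoid. A minimal sketch, assuming the callback's argument behaves like a read-only NumPy array as the error suggests; `_host_square` and `square` are illustrative names:

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_square(x):
    # The input is read-only, so write the result into a freshly
    # allocated host array instead of into `x`.
    x = np.asarray(x)
    out = np.empty_like(x)
    np.square(x, out=out)
    return out


@jax.jit
def square(x):
    result_shape_dtype = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(_host_square, result_shape_dtype, x)


y = square(jnp.arange(4.0))
print(y)
```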
Also, the main reason this would help is that the output buffers for the science applications I'm working on are very large, so reusing the one XLA has already allocated would save a lot of memory.
@superbobry @hawkinsp any updates on this?
Hi @superbobry, any news?
@superbobry, could you check out JAXbind, which offers a way to specify both JVP and VJP for external callbacks? I'd argue it should also be possible to specify both within JAX.