
Comments (10)

superbobry commented on June 13, 2024

There is a preallocated buffer on the XLA side, but it is not currently passed to the pure_callback.

@hawkinsp should we have an API for this, wdyt?

from jax.

superbobry commented on June 13, 2024

Hey @Joshuaalbert, sorry for the silence. Here is a quick update:

  • @yashk2810 is working on a proposal to allow "write once" mutation for jax.Arrays, which are currently immutable (and thus cannot be used as output buffers);
  • I'm exploring a parallel idea -- dlpack.callback, a new callback API using DLPack capsules for both inputs and outputs.


superbobry commented on June 13, 2024

Another quick update: I prototyped dlpack.callback, but after discussing it with a few JAX team members, I decided not to move forward with it as JAX has too many callback APIs already.

Instead, the plan is to change existing callback APIs to support mutable_results=. I am waiting on a few changes in XLA FFI, but once they land, it should be fairly straightforward to implement this.


superbobry commented on June 13, 2024

No news yet, in the sense that none of my attempts have landed. Hopefully in the coming weeks :)


dfm commented on June 13, 2024

@Joshuaalbert, JAX doesn't currently offer a public API for this (customizing multiple transforms for a single callable), but it's on our radar. I'd say that this is off topic for this issue thread, but feel free to open another issue with more info about your use cases for using both JVP and VJP for one callback!
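
For context, here is a minimal sketch of attaching a single custom rule (here a VJP via jax.custom_vjp) to a function backed by jax.pure_callback. It is not the multi-transform API discussed above. The names external_sin, _host_sin, _host_cos and _call are hypothetical and chosen only for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _host_sin(x):
    # Runs on the host; np.asarray handles either a jax.Array or an ndarray.
    return np.sin(np.asarray(x))


def _host_cos(x):
    return np.cos(np.asarray(x))


def _call(fn, x):
    # Hypothetical helper: the result has the same shape and dtype as the input.
    return jax.pure_callback(fn, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


@jax.custom_vjp
def external_sin(x):
    return _call(_host_sin, x)


def _fwd(x):
    # Save x as the residual needed by the backward pass.
    return external_sin(x), x


def _bwd(x, g):
    # d/dx sin(x) = cos(x), also evaluated through a host callback.
    return (g * _call(_host_cos, x),)


external_sin.defvjp(_fwd, _bwd)

grads = jax.grad(lambda x: external_sin(x).sum())(jnp.array([0.0, 1.0]))
print(grads)
```

This only covers reverse mode; a forward-mode rule would be registered separately with jax.custom_jvp on a different wrapper, which is the limitation being discussed.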


Joshuaalbert commented on June 13, 2024

Also, is it safe to reuse the buffer of input values, e.g. like

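import jax
import numpy as np

def square_in_place(x):  # wrapper name added so the snippet is self-contained
    def _fn(x):
        # Try to overwrite the input buffer with the result.
        np.square(x, out=x)
        return x

    result_shape_dtype = jax.ShapeDtypeStruct(
        shape=np.shape(x),
        dtype=x.dtype
    )

    return jax.pure_callback(_fn, result_shape_dtype, x, vectorized=True)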

EDIT: That would be a nope, as it raises jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: output array is read-only
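
A minimal sketch of a variant that avoids the read-only error, assuming an extra allocation is acceptable: let NumPy allocate a fresh output instead of writing into the input. This does not reuse any XLA buffer, so it does not save the memory asked about in this thread. The names _square_host and square_via_callback are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _square_host(x):
    # The callback argument may be a jax.Array or a read-only ndarray
    # depending on the JAX version, so convert it and never write into it.
    x = np.asarray(x)
    # No out=x here: np.square returns a newly allocated, writable array.
    return np.square(x)


@jax.jit
def square_via_callback(x):
    result_shape_dtype = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
    return jax.pure_callback(_square_host, result_shape_dtype, x)


print(square_via_callback(jnp.arange(4.0)))
```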


Joshuaalbert commented on June 13, 2024

Also, the main reason this would be helpful is that the output buffers for the science applications I'm working on are really large, so if XLA has already allocated one, reusing it would save a lot of memory.


Joshuaalbert commented on June 13, 2024

@superbobry @hawkinsp any update available for this?


Joshuaalbert commented on June 13, 2024

Hi @superbobry, any news?


Joshuaalbert commented on June 13, 2024

@superbobry can you check out JAXbind, which offers a way to specify both JVP and VJP for external callbacks? I'd argue that it should also be possible to specify both within JAX.

