Comments (8)

jakevdp commented on July 17, 2024

Here's an example of how this might look using the array API:

import jax.experimental.array_api # side-effecting import required now, but won't be needed in the future

import jax.numpy as jnp


class DiagonalArrayNamespace:
  @staticmethod
  def exp(x):
    return DiagonalArray(jnp.exp(x.diagonal))


class DiagonalArray:
  def __init__(self, diagonal):
    self.diagonal = jnp.asarray(diagonal)

  def __array_namespace__(self):
    return DiagonalArrayNamespace()

def func(x):
  xp = x.__array_namespace__()
  return xp.exp(x)

x = jnp.arange(4)
print(func(x))
# [ 1.         2.7182817  7.389056  20.085537 ]

d = DiagonalArray(x)
print(func(d))
# <__main__.DiagonalArray object at 0x7a3b784592a0>

The benefit of this approach is that func will be compatible with any type that implements the array API. Additionally, as long as you implement the relevant parts of the namespace for your type, it could be used directly with other packages that are array API aware, such as future versions of scipy and scikit-learn.

> We have a custom class that implements a fancy JAX array, and would like to keep compatibility with the rest of JAX/NumPy/etc., i.e. allow calling jnp.any_function(my_custom_array) which would return either a JAX array or a custom class array (both are fine).

We do not have any plans to support this kind of custom overloading of the jax.numpy API. Instead, you should write your code against the Array API and pass it types that provide __array_namespace__, such as JAX arrays, NumPy arrays (NumPy 2.0 and later), or your own custom arrays.

from jax.

patrick-kidger commented on July 17, 2024

I'll offer one other option here, which is Quax. This allows you to define array-ish types and then perform dispatch on them. One of the examples we have is for LoRA.

This is basically a direct solution to (2): define how your custom type interacts with each JAX primitive, and in principle you have compatibility with arbitrary JAX code.
However, the downsides are (a) needing to define how this works with every primitive you want to interact with, and (b) it's JAX-only: no NumPy/PyTorch/etc. So no free lunch, I guess.
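
Downside (a) is easier to see in a toy model. Below is a per-primitive dispatch table in plain Python. The names register, bind and Scaled are hypothetical, and this is not Quax's actual API. It only mimics the idea that each primitive needs its own rule for the custom type, and that any primitive without a rule fails.

```python
from typing import Callable, Dict, Tuple

# Hypothetical registry keyed by (primitive name, custom type).
_RULES: Dict[Tuple[str, type], Callable] = {}


def register(primitive: str, cls: type):
    """Record how `primitive` should behave for inputs of type `cls`."""
    def decorator(fn):
        _RULES[(primitive, cls)] = fn
        return fn
    return decorator


def bind(primitive: str, x):
    """Apply `primitive` to `x`, failing loudly if no rule was registered."""
    rule = _RULES.get((primitive, type(x)))
    if rule is None:
        raise NotImplementedError(f"no rule for {primitive!r} on {type(x).__name__}")
    return rule(x)


class Scaled:
    """Hypothetical array-ish type representing scale * data."""

    def __init__(self, scale, data):
        self.scale = scale
        self.data = list(data)


@register("neg", Scaled)
def _neg_scaled(x):
    # Negation only flips the scale; the underlying data is untouched.
    return Scaled(-x.scale, x.data)


y = bind("neg", Scaled(2.0, [1.0, 2.0]))
print(y.scale, y.data)  # -2.0 [1.0, 2.0]

try:
    bind("exp", Scaled(2.0, [1.0]))  # no rule registered for "exp"
except NotImplementedError as err:
    message = str(err)
print(message)  # no rule for 'exp' on Scaled
```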

jakevdp commented on July 17, 2024

Thanks for the question!

We've talked about this, but have not implemented any support for __array_ufunc__ or other custom dispatch mechanisms. The thinking is that it adds too much complexity and indirection to the API, and makes it harder for users to reason about what a particular line of code might be doing.

Instead, I'd suggest going the Array API route – for example, you can currently do from jax.experimental import array_api as xp, and then get an array API namespace that works on JAX arrays. Similarly, you could create an xp namespace that implements the parts of the numpy namespace you need for your own types, and then write your code so that it dispatches to the correct namespace.
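
Here is a sketch of the "your own xp namespace" idea. It is a namespace that implements only matrix_transpose and matmul for a diagonal-matrix type, which is also where a custom type can add a speedup. DiagonalNamespace, DiagonalArray and gram are hypothetical names. NumPy is used as the backing store here as an assumption, and jax.numpy should work the same way. The generic gram function asks its input for the namespace and never names the custom type.

```python
import numpy as np


class DiagonalNamespace:
    """Hypothetical namespace implementing only the functions our code needs."""

    @staticmethod
    def matrix_transpose(x):
        # A diagonal matrix is symmetric, so its transpose is itself.
        return x

    @staticmethod
    def matmul(a, b):
        # diag(a) @ diag(b) == diag(a * b): O(n) work instead of a dense matmul.
        return DiagonalArray(a.diagonal * b.diagonal)


_DIAGONAL_NS = DiagonalNamespace()


class DiagonalArray:
    def __init__(self, diagonal):
        self.diagonal = np.asarray(diagonal)

    def __array_namespace__(self, *, api_version=None):
        return _DIAGONAL_NS

    def to_dense(self):
        # Materialize the full matrix, e.g. for checking results.
        return np.diag(self.diagonal)


def gram(x):
    """Generic code: computes x^T x using only the namespace of its input."""
    xp = x.__array_namespace__()
    return xp.matmul(xp.matrix_transpose(x), x)


d = DiagonalArray([1.0, 2.0, 3.0])
g = gram(d)
print(type(g).__name__, g.diagonal)  # DiagonalArray [1. 4. 9.]
```

The same gram body should also accept any array whose namespace provides matmul and matrix_transpose. Mixing DiagonalArray with dense arrays in one call would need extra rules that this sketch does not attempt.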

This is the solution that the numpy community landed on after experimenting with __array_ufunc__, __array_function__, and other dispatch mechanisms, and finding them difficult to support. Libraries like scipy and sklearn are now working on array API support within their own APIs, so doing this would make your custom dispatch compatible with libraries beyond JAX as well.

What do you think?

gautierronan commented on July 17, 2024

Thanks for your quick answer! I have been looking at the Array API documentation you linked, but it's not clear to me how to achieve the example above using it.

Is there any example of usage of jax.experimental.array_api? Or would you have a short implementation of my use case above?

Our general need is quite simple. We have a custom class that implements a fancy JAX array, and would like to keep compatibility with the rest of JAX/NumPy/etc., i.e. allow calling jnp.any_function(my_custom_array) which would return either a JAX array or a custom class array (both are fine).

gautierronan commented on July 17, 2024

Thanks for the example @jakevdp, it's very useful.

Indeed, it seems there are two different use cases:

  1. writing custom functions (e.g. fancy_exp) that accept arrays from any library (JAX, NumPy, PyTorch, scipy, ...)
  2. writing custom arrays that are compatible with functions from any library (jnp.exp, np.exp, torch.exp, ...), and that can override certain of these functions when useful (e.g. for a speedup)

Your solution indeed supports 1, but we're looking for 2. What is the reasoning behind not wanting to support 2? NumPy seems to allow it.
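
For reference, use case 2 looks roughly like this with NumPy's __array_ufunc__ hook (NEP 13). np.exp on the custom object is routed to the object itself, which unwraps, computes and re-wraps the result. TaggedArray is a hypothetical class. jax.numpy functions do not consult this hook.

```python
import numpy as np


class TaggedArray:
    """Hypothetical wrapper: a NumPy array plus a string tag."""

    def __init__(self, data, tag):
        self.data = np.asarray(data)
        self.tag = tag

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Only handle plain single-output calls like np.exp(t) or np.add(t, 1);
        # decline reductions, out= arguments, multi-output ufuncs, etc.
        if method != "__call__" or "out" in kwargs or ufunc.nout != 1:
            return NotImplemented
        # Unwrap TaggedArray inputs, run the real ufunc, then re-wrap the result.
        raw = [x.data if isinstance(x, TaggedArray) else x for x in inputs]
        return TaggedArray(ufunc(*raw, **kwargs), self.tag)


t = TaggedArray([0.0, 1.0, 2.0], tag="meters")
e = np.exp(t)       # NumPy calls t.__array_ufunc__(np.exp, "__call__", t)
s = np.add(t, 1.0)  # mixed inputs also reach TaggedArray.__array_ufunc__
print(type(e).__name__, e.tag, e.data)
print(type(s).__name__, s.data)
```

This is the kind of opaque dispatch discussed in this thread. The same np.exp line now returns a different type, depending on what the caller passed in.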

jakevdp commented on July 17, 2024

(2) will never be feasible in the long term. The numpy team tried many approaches over the years to make this work and in the end found that none of them were sufficient in practice. They started with Python operator dispatch and __array__, then dispatched numpy functions to object methods (this is why e.g. np.reshape(x) first tries x.reshape()), then tried __array_ufunc__, then __array_function__, etc. The partially implemented mechanisms still exist in the codebase for backward compatibility, but they are not mechanisms that the NumPy team recommends.
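
The "dispatch to object methods" behaviour is easy to observe. np.reshape looks up a reshape attribute on its argument and calls it, so a non-array object can intercept the call. NotAnArray is a hypothetical toy class. It shows a legacy behaviour for illustration, not a recommended extension point.

```python
import numpy as np


class NotAnArray:
    """Hypothetical object that is not an ndarray but defines .reshape."""

    def __init__(self):
        self.calls = []

    def reshape(self, *args, **kwargs):
        # np.reshape forwards its arguments here instead of converting to ndarray.
        self.calls.append((args, kwargs))
        return "NotAnArray.reshape was called"


obj = NotAnArray()
result = np.reshape(obj, (2, 2))
print(result)     # NotAnArray.reshape was called
print(obj.calls)  # the arguments NumPy forwarded to the method
```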

The problem is that the numpy API has no formal specification. There are many strange, underspecified implementation details that the logic of any particular library using it may depend upon, and this makes the above approaches too brittle to be used broadly and depended upon by the larger ecosystem. Further, it makes it much more difficult to reason about what a particular piece of code is doing when it may be opaquely dispatching to arbitrary implementations defined by the object you pass to it.

The solution in the end was to define a well-specified subset of the numpy API for which downstream libraries can implement the full semantics in a well-defined and transparent way: this is the Array API, and that is the solution that numpy developers will point you to if you ask them the same question today.

For JAX, we want to learn from the experience of the NumPy team and just implement the final, working solution rather than one of the many half-working false starts that eventually led to that working solution.

gautierronan commented on July 17, 2024

Ok that makes sense. Thanks a lot for your detailed answer.

jakevdp commented on July 17, 2024

I'd add that (1) is probably the best way forward for what you have in mind too: if the downstream tools that you use adapt their implementations to make use of the Array API standard, then you actually have a hope of supporting this set of functionality for your own custom type. Of course, this requires downstream implementations to change, but without that I don't think you'd ever land on a robust solution to implementing all the corner cases of these similar-but-different APIs for your own type.
