Comments (8)
Here's an example of how this might look using the array API:
```python
import jax.experimental.array_api  # side-effecting import required now, but won't be needed in the future
import jax.numpy as jnp

class DiagonalArrayNamespace:
    @staticmethod
    def exp(x):
        return DiagonalArray(jnp.exp(x.diagonal))

class DiagonalArray:
    def __init__(self, diagonal):
        self.diagonal = jnp.asarray(diagonal)

    def __array_namespace__(self):
        return DiagonalArrayNamespace()

def func(x):
    xp = x.__array_namespace__()
    return xp.exp(x)

x = jnp.arange(4)
print(func(x))
# [ 1. 2.7182817 7.389056 20.085537 ]

d = DiagonalArray(x)
print(func(d))
# <__main__.DiagonalArray object at 0x7a3b784592a0>
```
The benefit of this approach is that `func` will be compatible with any type that implements the Array API. Additionally, as long as you implement the relevant parts of the namespace for your type, it could be used directly with other packages that are array API aware, such as future versions of scipy and scikit-learn.
> We have a custom class that implements a fancy JAX array, and would like to keep compatibility with the rest of JAX/NumPy/etc., i.e. allow calling `jnp.any_function(my_custom_array)` which would return either a JAX array or a custom class array (both are fine).
We do not have any plans to support this kind of custom overloading of the `jax.numpy` API. Instead, you should write your code to use the Array API, and pass it types that provide `__array_namespace__`, such as JAX arrays, NumPy arrays, or your own custom arrays.
I'll offer one other option here, which is Quax. This allows you to define array-ish types and then perform dispatch on them. One of the examples we have is for LoRA.
This is basically a direct solution to (2): define how your custom type interacts with each JAX primitive, and in principle you have compatibility with arbitrary JAX code.
However, the downsides are (a) needing to define how this works with every primitive you want to interact with, and (b) it's JAX-only: no NumPy/PyTorch/etc. So no free lunch, I guess.
Thanks for the question!
We've talked about this, but have not implemented any support for `__array_ufunc__` or other custom dispatch mechanisms. The thinking is that it adds too much complexity and indirection to the API, and for the user makes it harder to reason about what a particular line of code might be doing.
Instead, I'd suggest going the Array API route – for example, you can currently do `from jax.experimental import array_api as xp`, and then get an array API namespace that works on JAX arrays. Similarly, you could create an `xp` namespace that implements the numpy namespace you need for your own types, and then write your code in a way that dispatches to the correct namespace.
This is the solution that the numpy community landed on after experimenting with `__array_ufunc__`, `__array_function__`, and other dispatch mechanisms, and finding them difficult to support. Libraries like `scipy` and `sklearn` are now working on array API support within their own APIs, so doing this would make your custom dispatch compatible with libraries beyond JAX as well.
What do you think?
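To make that namespace-dispatch idea concrete, here is a minimal, self-contained sketch. All names here are invented for illustration; in real code the fallback would be `numpy` or `jax.numpy` rather than a toy float namespace:

```python
import math

class FloatNamespace:
    """Hypothetical fallback namespace for plain Python floats."""
    @staticmethod
    def exp(x):
        return math.exp(x)

def get_namespace(x):
    # Prefer whatever namespace the argument advertises via the
    # Array API protocol; otherwise fall back to the toy namespace.
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return FloatNamespace()

def fancy_exp(x):
    # Library code never imports a specific array library; it calls
    # functions through whichever namespace the input provides.
    xp = get_namespace(x)
    return xp.exp(x)

print(fancy_exp(1.0))  # 2.718281828459045
```

Any object that implements `__array_namespace__` (a JAX array after the side-effecting import, or a custom type like `DiagonalArray` above) would flow through `fancy_exp` unchanged.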
Thanks for your quick answer! I have been looking at the Array API documentation you linked, but it's not so clear to me how to achieve the example above using this.
Is there any example of usage of `jax.experimental.array_api`? Or would you have a short implementation of my use case above?
Our general need is quite simple. We have a custom class that implements a fancy JAX array, and would like to keep compatibility with the rest of JAX/NumPy/etc., i.e. allow calling `jnp.any_function(my_custom_array)` which would return either a JAX array or a custom class array (both are fine).
Thanks for the example @jakevdp, it's very useful.
Indeed, it seems there are two different use cases:
1. writing custom functions (e.g. `fancy_exp`) that accept arrays from any library (JAX, NumPy, PyTorch, scipy, ...)
2. writing custom arrays that are compatible with functions from any library (`jnp.exp`, `np.exp`, `torch.exp`, ...), and that can override certain of these functions when useful (e.g. because of speedup)
Your solution indeed supports 1, but we'd rather be looking for 2. What is the reasoning behind not wanting to support 2? NumPy seems to be allowing it.
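For concreteness, the NumPy mechanism that permits (2) today is the `__array_ufunc__` hook. A minimal sketch, with an illustrative class invented for this example:

```python
import numpy as np

class DiagonalArray:
    def __init__(self, diagonal):
        self.diagonal = np.asarray(diagonal)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # Intercept unary numpy ufuncs (np.exp, np.sin, ...) and apply
        # them to the stored diagonal only.
        if method == "__call__" and len(inputs) == 1:
            return DiagonalArray(ufunc(self.diagonal))
        return NotImplemented

d = DiagonalArray(np.arange(3.0))
out = np.exp(d)  # numpy dispatches to DiagonalArray.__array_ufunc__
print(type(out).__name__)  # DiagonalArray
```

This covers ufuncs only; extending the same idea to the full `numpy` API is where the brittleness discussed in this thread comes in.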
(2) will never be feasible in the long term. The numpy team tried many approaches over the years to make this work (starting with Python operator dispatch & `__array__`, then dispatching `numpy` functions to object methods (this is why e.g. `np.reshape(x)` first tries `x.reshape()`), then `__array_function__`, then `__array_ufunc__`, etc.) and in the end found that none of these approaches were sufficient in practice. The partially-implemented mechanisms still exist in the codebase for backward compatibility, but are not a mechanism that the NumPy team recommends.
The problem is that the numpy API has no formal specification (there are many strange, underspecified implementation details that may be depended upon in the logic of any particular library that uses it), and this makes the above approaches too brittle to be used broadly and depended upon by the larger ecosystem. Further, it makes it much more difficult to reason about what a particular piece of code is doing when it may be opaquely dispatching to arbitrary implementations defined by the object you pass to it.
The solution in the end was to define a well-specified subset of the numpy API for which downstream libraries can implement the full semantics in a well-defined and transparent way: this is the Array API, and that is the solution that numpy developers will point you to if you ask them the same question today.
For JAX, we want to learn from the experience of the NumPy team and just implement the final, working solution rather than one of the many half-working false starts that eventually led to that working solution.
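The method-dispatch behavior mentioned above (that `np.reshape(x)` first tries `x.reshape()`) is easy to observe directly; the class below is a throwaway illustration:

```python
import numpy as np

class HasReshape:
    # Not an array at all; it just happens to define a reshape method.
    def reshape(self, *shape, **kwargs):
        return "custom reshape called"

# np.reshape tries the object's own .reshape() before coercing it
# to an ndarray, so the method above wins.
print(np.reshape(HasReshape(), (2, 2)))
```

This kind of implicit, per-function dispatch is exactly what makes the behavior of numpy code hard to predict when arbitrary objects are passed in.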
Ok that makes sense. Thanks a lot for your detailed answer.
I'd add that (1) is probably the best way forward for what you have in mind too: if the downstream tools that you use adapt their implementations to make use of the Array API standard, then you actually have a hope of supporting this set of functionality for your own custom type. Of course, this requires downstream implementations to change, but without that I don't think you'd ever land on a robust solution to implementing all the corner cases of these similar-but-different APIs for your own type.