
patrick-kidger / equinox


Elegant easy-to-use neural networks + scientific computing in JAX. https://docs.kidger.site/equinox/

License: Apache License 2.0

Python 100.00%
deep-learning jax neural-networks equinox

equinox's Introduction

Equinox

Equinox is your one-stop JAX library, for everything you need that isn't already in core JAX:

  • neural networks (or more generally any model), with easy-to-use PyTorch-like syntax;
  • filtered APIs for transformations;
  • useful PyTree manipulation routines;
  • advanced features like runtime errors;

and best of all, Equinox isn't a framework: everything you write in Equinox is compatible with anything else in JAX or the ecosystem.

If you're completely new to JAX, then start with this CNN on MNIST example.

Coming from Flax or Haiku? The main difference is that Equinox (a) offers a lot of advanced features not found in these libraries, like PyTree manipulation or runtime errors; (b) has a simpler way of building models: they're just PyTrees, so they can pass across JIT/grad/etc. boundaries smoothly.

Installation

pip install equinox

Requires Python 3.9+ and JAX 0.4.13+.

Documentation

Available at https://docs.kidger.site/equinox.

Quick example

Models are defined using PyTorch-like syntax:

import equinox as eqx
import jax

class Linear(eqx.Module):
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

and fully compatible with normal JAX operations:

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
grads = loss_fn(model, x, y)

Finally, there's no magic behind the scenes. All eqx.Module does is register your class as a PyTree. From that point onwards, JAX already knows how to work with PyTrees.
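To make the "no magic" point concrete, here is a minimal sketch of registering a plain frozen dataclass as a PyTree by hand, using only jax.tree_util. TinyLinear is a hypothetical name, and this is not Equinox's actual implementation; it only illustrates the kind of registration that eqx.Module takes care of for you.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical stand-in for an eqx.Module: a frozen dataclass registered as a PyTree.
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class TinyLinear:
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __call__(self, x):
        return self.weight @ x + self.bias

    def tree_flatten(self):
        # Children are the array leaves; there is no static auxiliary data here.
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


model = TinyLinear(weight=jnp.ones((3, 2)), bias=jnp.zeros(3))

# JAX now sees two array leaves and can differentiate with respect to the object.
leaves = jax.tree_util.tree_leaves(model)
grads = jax.grad(lambda m, x: jnp.sum(m(x)))(model, jnp.ones(2))
```

The gradient comes back as another TinyLinear with the same structure, which is why such objects can cross jit/grad/vmap boundaries.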

Citation

If you found this library to be useful in academic work, then please cite: (arXiv link)

@article{kidger2021equinox,
    author={Patrick Kidger and Cristian Garcia},
    title={{E}quinox: neural networks in {JAX} via callable {P}y{T}rees and filtered transformations},
    year={2021},
    journal={Differentiable Programming workshop at Neural Information Processing Systems 2021}
}

(Also consider starring the project on GitHub.)

See also: other libraries in the JAX ecosystem

Always useful
jaxtyping: type annotations for shape/dtype of arrays.

Deep learning
Optax: first-order gradient (SGD, Adam, ...) optimisers.
Orbax: checkpointing (async/multi-host/multi-device).
Levanter: scalable+reliable training of foundation models (e.g. LLMs).

Scientific computing
Diffrax: numerical differential equation solvers.
Optimistix: root finding, minimisation, fixed points, and least squares.
Lineax: linear solvers.
BlackJAX: probabilistic+Bayesian sampling.
sympy2jax: SymPy<->JAX conversion; train symbolic expressions via gradient descent.
PySR: symbolic regression. (Non-JAX honourable mention!)

Awesome JAX
Awesome JAX: a longer list of other JAX projects.

equinox's People

Contributors

ahmed-alllam, andyehrenberg, anh-tong, artur-galstyan, as3895, benjamin-walker, boris-kuz, bowlingmh, colehaus, gurvan, hawkinsp, homerjed, imilas, j5b, jatentaki, jenkspt, jondeaton, jvmncs, knyazer, lockwo, louisdesdoigts, marcelroed, mcbal, mk-0, nasyxx, packquickly, paganpasta, patrick-kidger, uuirs, vidhanio


equinox's Issues

Error at `optax` model init

JAX newbie here. Trying to port the PyTorch MNIST example but I'm getting an error during model init.

model = Net(jax.random.PRNGKey(seed))
optim = optax.adam(lr)
opt_state = optim.init(model) # error at this line
TypeError: zeros_like requires ndarray or scalar arguments, got <class 'jax._src.custom_derivatives.custom_jvp'> at position 0.

My best guess is that this might be due to filtering, but I wasn't able to figure out how to make this work with optax.

It throws the error at the first jax.nn.relu layer.

class Net(eqx.Module):
    layers: list

    def __init__(self, key):
        keys = jax.random.split(key, 4)
        self.layers = [
            eqx.nn.Conv2d(1, 32, 3, 1, key=keys[0]),
            jax.nn.relu, # fails here
            eqx.nn.Conv2d(32, 64, 3, 1, key=keys[1]),
            jax.nn.relu,
            eqx.nn.MaxPool2D(2, 1),
            eqx.nn.Dropout(0.25),
            jnp.ravel,
            eqx.nn.Linear(9216, 128, key=keys[2]),
            jax.nn.relu,
            eqx.nn.Dropout(0.5),
            eqx.nn.Linear(128, 10, key=keys[3])
        ]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = layer(x)
        output = jax.nn.log_softmax(x, axis=1)
        return output

What am I doing wrong?

Frozen field

Is there a particular reason why the fields in equinox.Module are frozen? It makes it harder to implement modules with running state, like BatchNorm.

Feature request: Equinox version of `optax.incremental_update(...)`

Hey again @patrick-kidger :)

Would you be willing to introduce an Equinox version of optax.incremental_update (-> equinox.update.incremental_update), in the same way you have done for optax.apply_updates (-> equinox.update.apply_updates)? i.e. Use the same process of applying a tree_map, but checking for _is_none, so that we can pass in a filtered PyTree.

I have already written code for this (which I can PR) but I understand if you'd prefer to minimise code bloat. If it's not in core Equinox, I'll just implement it elsewhere for my use-case.

Thanks!

Initialize weights of a module

Hello,
is there a way to initialize the weights of a module? For instance, I have the code

class Update(eqx.Module):
    l1: eqx.Module
    l2: eqx.Module

    def __init__(self, key):
        k1, k2 = jrandom.split(key, 2)

        self.l1 = eqx.nn.Conv2d(48, 128, kernel_size=3, padding=1, use_bias=False, key=k1)
        self.l2 = eqx.nn.Conv2d(128, 16, kernel_size=3, padding=1, use_bias=False, key=k2)
        self.l2.weight = self.l2.weight.at[:].set(0.0)

    def __call__(self, state_grid):
        state_grid = jax.nn.relu(self.l1(state_grid))
        return self.l2(state_grid)

but this throws the FrozenInstanceError for modifying the weight value.
Is there a way to do this?
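As a rough illustration of why the assignment fails and what an out-of-place update looks like, here is a sketch using a plain frozen dataclass. FrozenConv is a hypothetical stand-in, not an Equinox class; in Equinox itself the analogous out-of-place tool is tree_at, which appears further down this page.

```python
import dataclasses

import jax.numpy as jnp


# Hypothetical stand-in for a frozen module holding a single weight array.
@dataclasses.dataclass(frozen=True)
class FrozenConv:
    weight: jnp.ndarray


layer = FrozenConv(weight=jnp.ones((4, 2, 3, 3)))

# In-place assignment on a frozen dataclass raises FrozenInstanceError.
try:
    layer.weight = jnp.zeros_like(layer.weight)
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

# Out-of-place update: build a new object with the field swapped, leaving the original intact.
zeroed = dataclasses.replace(layer, weight=jnp.zeros_like(layer.weight))
```

The original layer is untouched, which is what keeps such objects safe to pass through JAX transformations.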

Add FAQ

  • Optax compatibility: initialise the optimiser on the array leaves only, e.g. optim.init(eqx.filter(model, eqx.is_array))
  • Make sure your model is a tree, not a DAG.
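The Optax bullet can be illustrated without either library. The sketch below assumes that an optimiser's init maps jnp.zeros_like over every leaf (as the stack trace further down this page suggests), and uses a hand-rolled keep_arrays helper as a stand-in for eqx.filter(model, eqx.is_array). Both helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def is_array(leaf):
    # Simplified check; Equinox's own is_array may differ in edge cases.
    return isinstance(leaf, jnp.ndarray)


def keep_arrays(tree):
    # Replace every non-array leaf with None, which JAX treats as an empty subtree.
    return jax.tree_util.tree_map(lambda leaf: leaf if is_array(leaf) else None, tree)


# A "model" that mixes parameters with an activation function, as in the issues on this page.
layers = [jnp.ones((4, 2)), jax.nn.relu, jnp.ones((1, 4))]

params = keep_arrays(layers)
param_leaves = jax.tree_util.tree_leaves(params)

# What an optimiser init is assumed to do: allocate one zero buffer per remaining leaf.
moments = jax.tree_util.tree_map(jnp.zeros_like, params)
```

Mapping jnp.zeros_like over the unfiltered list would hit the jax.nn.relu leaf and fail, which is the TypeError reported in the Optax issues here.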

Pull Requests - Activation Layers

Hi Patrick,

I wanted to get your thoughts on adding activation layers to Equinox's API, similar to how PyTorch has torch.nn.ReLU().

For example, a basic GELU class:

import jax.numpy as jnp
from jax import lax
from equinox import Module, static_field

class GELU(Module):
    approximate: bool = static_field()

    def __init__(self, approximate: bool = False):
        super().__init__()
        self.approximate = approximate

    def __call__(self, x, *, key = None):
        if self.approximate:
            sqrt_2_over_pi = jnp.sqrt(2 / jnp.pi).astype(x.dtype)
            cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3))))
            return x * cdf
        else: 
            return jnp.array(x * (lax.erf(x / jnp.sqrt(2)) + 1) / 2, dtype=x.dtype)

If this is something you think would be beneficial I can open up a pull request with added documentation.

Thank you,

Enrico
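For reference, the two branches of the proposed GELU can be written as plain functions and compared numerically. This is only a sketch to show that the tanh form is an approximation of the erf form; the function names and the test grid are choices made for this example.

```python
import jax
import jax.numpy as jnp


def gelu_exact(x):
    # x * Phi(x), with the Gaussian CDF written via erf.
    return x * (jax.lax.erf(x / jnp.sqrt(2.0)) + 1) / 2


def gelu_tanh(x):
    # The tanh-based approximation from the `approximate=True` branch.
    sqrt_2_over_pi = jnp.sqrt(2 / jnp.pi)
    cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * x**3)))
    return x * cdf


x = jnp.linspace(-4.0, 4.0, 401)
max_gap = float(jnp.max(jnp.abs(gelu_exact(x) - gelu_tanh(x))))
```

Since neither branch holds any parameters, a plain function is enough to use either one as a layer in a list of callables.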

Module with subscription?

Hi, first of all, great library, super interesting!

Quick question: sometimes it is convenient to access a submodule by subscripting, like:

print(model.linear1)
print(model["linear1"])

but this is currently not supported by Equinox's Module. Have you considered making modules subscriptable? If not, would you accept a small PR with this implementation?

The change would be minimal: only the __getitem__() method needs to be implemented. The only thing to check is whether this breaks with static fields (from a quick look it shouldn't, but I can check).

Thanks!
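A minimal sketch of the proposed __getitem__, written against plain frozen dataclasses rather than Equinox itself. SubscriptableModule and TwoLayer are hypothetical names, and the placeholder field values stand in for real submodules.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class SubscriptableModule:
    def __getitem__(self, name):
        # Only declared fields are reachable, so methods and dunders stay hidden.
        field_names = {field.name for field in dataclasses.fields(self)}
        if name not in field_names:
            raise KeyError(name)
        return getattr(self, name)


@dataclasses.dataclass(frozen=True)
class TwoLayer(SubscriptableModule):
    linear1: tuple
    linear2: tuple


model = TwoLayer(linear1=("w1", "b1"), linear2=("w2", "b2"))
first = model["linear1"]
```

Restricting lookups to dataclass fields means model["linear1"] and model.linear1 return the same object, while unknown names raise KeyError as they would for a dict.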

how to serialize equinox models?

None of the tutorials include information on saving checkpoints. Right now I'm using pickle, but it would be nice to be able to use something with bfloat16 compatibility, e.g. flax.serialization.

any suggestions?

Tweak filtering to match JAX?

At the moment we mostly do isinstance checks to determine what is an array, or array-like.

However see https://github.com/google/jax/blob/4a17c78605e7fc69a69a999e2f6298db79d3837a/jax/_src/numpy/lax_numpy.py#L542 which is JAX's internal definition of what is array-like.

Mostly it looks like these overlap, but there's a couple of edge cases. (Such as __jax_array__.)

Besides this, JAX now supports meaningful isinstance(..., jnp.ndarray), so we should bump the minimum version required and switch to using that.

Better support for broadcasting in linear layer

In equinox's implementation, for a weight matrix w of shape (d_out, d_in), the computational rule for the linear layer update is (https://github.com/patrick-kidger/equinox/blob/main/equinox/nn/linear.py#L65)

y = w x + b

This is problematic if x is a higher-dimensional tensor with shape (d_in, d_1, ..., d_n).
The outcome of w x is a tensor of shape (d_out, d_1, ..., d_n), while b is a tensor of shape (d_out,).
Since numpy's broadcasting rule works from right-to-left (https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules), broadcasting the two won't work.

Example:

import equinox as eqx
import equinox.experimental as eqxe
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jrandom
seed = 0

key = jrandom.PRNGKey(seed)
key, key_linear, key_input = jrandom.split(key, 3)
#print("key=", key)
d_in = 3
d_out = 4
x = jrandom.normal(key_input, shape=(d_in, 5))
linear = eqx.nn.Linear(d_in, d_out, use_bias=True, key=key_linear)
# fails with ValueError due to impossibility of broadcasting
res = linear(x)

To remedy this, I would propose to change the implementation to PyTorch style (https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html). This would allow arbitrary left-most dimensions.
If we could add a reshape to make it work with tensors with arbitrary number of trailing dimensions, that would be even better :)
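The shape mismatch described above can be reproduced with bare arrays. The sketch below uses hypothetical all-ones weights and shows two possible workarounds that need no change to the layer: mapping over the trailing axis with jax.vmap, or giving the bias an explicit trailing axis.

```python
import jax
import jax.numpy as jnp

d_in, d_out = 3, 4
w = jnp.ones((d_out, d_in))
b = jnp.ones((d_out,))
x = jnp.ones((d_in, 5))

# (d_out, 5) + (d_out,) cannot broadcast, because alignment is from the right.
try:
    w @ x + b
    broadcast_failed = False
except (TypeError, ValueError):
    broadcast_failed = True

# Workaround 1: apply the single-vector rule to each column of x.
y_vmap = jax.vmap(lambda col: w @ col + b, in_axes=1, out_axes=1)(x)

# Workaround 2: add a trailing axis to the bias so it broadcasts along the columns.
y_manual = w @ x + b[:, None]
```

Both workarounds give an output of shape (d_out, 5); the vmap version is the one that generalises to wrapping an existing layer unchanged.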

Have `BatchNorm` handle multiple vmaps when they all have the same `axis_name`.

At the moment

bn = eqx.nn.BatchNorm("batch", 5)
x = jr.normal(key, (5,))
jax.vmap(jax.vmap(bn, axis_name="batch"), axis_name="batch")(x)

will raise weird and wonderful errors.

In practice this should probably still be an error because I don't think this is a concept that makes sense in JAX, but ideally we'd be a bit clearer about what the problem is.

Impact of implicit stateful PyTrees?

I found the following warning in the documentation:

equinox.experimental.BatchNorm updates its running statistics as a side effect of its forward pass. Side effects are quite unusual in JAX; as such BatchNorm is considered experimental. Let us know how you find it!

Could you elaborate on what could go wrong with such a side effect? I think it might be the same catch documented in treex, but I'm not sure whether we need to return the Module to update its state in Equinox: Python passes objects by reference, so the update should be reflected in the passed-in Module object without returning anything.

By the way, elegy/treex is planning on implementing immutable PyTrees. Looks like you are taking different design paths and evolving separately!
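One general reason side effects are delicate in JAX, independent of how the experimental BatchNorm is implemented, is that jax.jit runs the Python body only while tracing. The sketch below uses a Python list as a hypothetical piece of "state" to show that a mutation inside a jitted function does not happen on every call.

```python
import jax

trace_log = []


@jax.jit
def add_one(x):
    # This Python-level side effect runs during tracing, not on every call.
    trace_log.append("traced")
    return x + 1


first = add_one(1.0)
second = add_one(2.0)  # same input type and shape, so the cached trace is reused
```

After two calls the log holds a single entry, so any state updated this way would silently stop changing after the first call.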

Any plan to abstract this library out to provide OOP with jax?

As per title, do you have any plan to abstract this library out to allow OOP in jax?
I am thinking about something like a jaxclass decorator, or a JaxObject base class.

I understand this is already possible by dealing with JAX PyTrees, but the question is about whether there is any plan to bake this into an easier API.

Observer patterns in JAX

Hello,

First, thank you for providing Equinox, it is very informative and there is a lot to learn from it ! :)

Second, I come from the PyTorch ecosystem, where I implemented the observer pattern in a very light framework for my research, like FastAI or PyTorch Lightning. For the past few weeks I have been learning JAX/Flax and wanted to do the same, but this turns out to be harder as we are constrained to functional programming, hence no state.
However, I came across Equinox and one of your experimental features, namely stateful operations. Knowing this, do you think it is possible to implement callbacks for JAX, and if so, what is your opinion on this?

Apologies if my issue does not fit here or if the answer to it is obvious.

Thank you very much

Implement `filter_vmap`, `filter_pmap`

These are a little more niche, but would be nice to have. For example filter_vmap could be used to construct ensembles of models easily.

At a technical level these look like they're going to be pretty fiddly because of out_axes, however. filter_vmap is doable by monkey-patching the JAX interpreter. (Not great but liveable.) filter_pmap doesn't have a clear route forward, though. Moreover jax.pmap doesn't currently support None in out_axes, which might throw a wrench in things.
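For context, the ensemble use case can already be sketched with plain jax.vmap when the parameters are a dictionary of arrays. The init and apply functions below are hypothetical, and this does not show how filter_vmap would handle non-array leaves, which is the fiddly part discussed above.

```python
import jax
import jax.numpy as jnp


def init(key):
    wkey, bkey = jax.random.split(key)
    return {
        "w": jax.random.normal(wkey, (3, 2)),
        "b": jax.random.normal(bkey, (3,)),
    }


def apply(params, x):
    return params["w"] @ x + params["b"]


# Build five models at once: every leaf gains a leading ensemble axis.
keys = jax.random.split(jax.random.PRNGKey(0), 5)
ensemble = jax.vmap(init)(keys)

# Evaluate all members on the same input: map over params, broadcast x.
outputs = jax.vmap(apply, in_axes=(0, None))(ensemble, jnp.ones(2))
```

Each ensemble member sees the same x but its own slice of the stacked parameters.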

Missing nn submodule

I installed equinox with the command:

pip install git+https://github.com/patrick-kidger/equinox.git

However, I got an error when running

import equinox
---------------------------------------------------------------------------
ImportError                               Traceback (most recent call last)
<ipython-input-2-d7e322bf3be2> in <module>()
----> 1 import equinox

/usr/local/lib/python3.7/dist-packages/equinox/__init__.py in <module>()
----> 1 from . import nn
      2 from .filters import is_array_like, is_inexact_array
      3 from .gradf import gradf, value_and_grad_f
      4 from .jitf import jitf
      5 from .module import Module

ImportError: cannot import name 'nn' from 'equinox' (/usr/local/lib/python3.7/dist-packages/equinox/__init__.py)

---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.
---------------------------------------------------------------------------

It seems to me that you missed nn module in setup.py. When I add equinox.nn to packages list, it works!

Best way of implementing spectral norm

Hi,
I'm coming to Equinox as a user of Diffrax. There I'd like to compare a Neural ODE use-case I have implemented in DiffEqFlux.jl to a Diffrax implementation. In this use-case, it turned out to be profitable to use spectral normalization. So I am wondering how to best implement spectral normalization as an Equinox layer.

The challenge is that spectral normalization has two stateful variables u0, v0 which are approximations of the left- and right-singular vectors of the weight matrix. The difference to batch normalization's statefulness is that their update rules only depend on the weights of the layer, not the inputs to the layer. This means that the forward call can remain pure with merely read-access to u0 and v0. The update of u0 and v0 can happen outside of the forward call after the gradient descent step in the training loop.

for step in range(n_steps):
    # compute loss for model
    loss = compute_loss(model)
    # update spectral norm
    update_spectral_norm(model)

So my current idea of how to implement this looks something like the following:

class SpectralNormalization(eqx.Module):
    # is static_field() enough to prevent back-propagation?
    layer: eqx.Module
    layer_var: str = static_field()
    u0: Array = static_field()
    v0: Array = static_field()

    def __init__(self, layer: eqx.Module, layer_var="weight"):
        self.layer = layer
        self.layer_var = layer_var
        # TODO: init u0 and v0 randomly with shape appropriate for layer's layer_var

    def __call__(self, x):
        # read the weight matrix from the wrapped layer
        W = getattr(self.layer, self.layer_var)
        v0 = self.v0  # does this need a stop_gradient()?
        u0 = self.u0  # does this need a stop_gradient()?
        sigma = jnp.matmul(jnp.matmul(v0, W), jnp.transpose(u0))[0, 0]
        # for a layer of the form W*x + b, i.e., linear layers and conv layers, this results in computing
        # W/sigma * x + b, which is precisely the spectral normalization
        return self.layer(x / sigma)

The update_spectral_norm method would look something like

def update_spectral_norm(model):
    # poor man's multiple dispatch
    if isinstance(model, SpectralNormalization):
        ...  # update u0 and v0 with the power method
    else:
        pass

This way of updating u0 and v0 outside the forward call should keep everything pure.

Could you comment on whether static_field would be enough to prevent u0 and v0 from being optimized or would I also need a stop_gradient in the forward call?
Also, I'd be interested in your opinion on the poor man's multiple dispatch solution to updating the spectral normalization. I have seen that you have added some experimental way of handling statefulness. Using this could be an option as well, but given that it's explicitly marked as experimental and dangerous I'm a bit hesitant about this.
Thanks for your help!
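For readers unfamiliar with the power method mentioned above, here is a small self-contained sketch on a bare matrix. The function name, the diagonal test matrix and the number of steps are assumptions made for this example; it does not show how the vectors would be stored on a module.

```python
import jax
import jax.numpy as jnp


def power_iteration(W, u, n_steps=1):
    # Alternately refine the right (v) and left (u) singular-vector estimates.
    for _ in range(n_steps):
        v = W.T @ u
        v = v / jnp.linalg.norm(v)
        u = W @ v
        u = u / jnp.linalg.norm(u)
    return u, v


# Test matrix with known singular values 3, 1 and 0.5.
W = jnp.diag(jnp.array([3.0, 1.0, 0.5]))
u0 = jnp.ones(3) / jnp.sqrt(3.0)

u, v = power_iteration(W, u0, n_steps=50)

# Treat u and v as constants so no gradient flows through the estimate itself.
u = jax.lax.stop_gradient(u)
v = jax.lax.stop_gradient(v)
sigma = float(u @ W @ v)
W_normalised = W / sigma
```

Because the update only reads W, it can run between optimiser steps, and the forward pass only needs read access to u and v.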

Improve `tree_at`.

At the moment this can be a bit finickity, most noticeably when doing type-dependent things like

tree_at(lambda m: [x for x in jax.tree_leaves(m) if eqx.is_array(x)], model, ...)

It's super nonobvious how best to change this though. Needs some thought!

As an aside, even with the current implementation it should be possible to substitute sub-PyTrees just by wrapping a jax.tree_leaves around substitution and substitutee. This would be good to have by default as well. (And to change between sub-pytrees of different structures, really.)
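For comparison, a type-dependent replacement can be done on a plain PyTree by flattening, editing the leaves, and unflattening. This is only a sketch of the underlying idea with a hypothetical nested dictionary; it is not how tree_at is implemented.

```python
import jax
import jax.numpy as jnp

tree = {
    "linear": {"weight": jnp.ones((2, 3)), "bias": jnp.ones(2)},
    "scale": jnp.ones(()),
}

# Flatten to a list of leaves plus a structure description.
leaves, treedef = jax.tree_util.tree_flatten(tree)

# Leaf-dependent rule: zero out every matrix, keep everything else.
new_leaves = [jnp.zeros_like(leaf) if leaf.ndim == 2 else leaf for leaf in leaves]

# Rebuild a tree with the same structure from the edited leaves.
new_tree = jax.tree_util.tree_unflatten(treedef, new_leaves)
```

This works only when the replacement keeps the tree structure fixed, which is the limitation the aside above would like to lift.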

filter_vmap doesn't allow in_axes as a vmapkwarg

Hi, love the idea for the library! I was recently converting a canonical example from stax to equinox and found a discrepancy with what I expected from the **vmapkwargs parameter in eqx.filter_vmap. In particular, the wrapper from filter_vmap is not respecting that the user might pass in_axes as part of **vmapkwargs, as evidenced by this minimally-reproducing gist. After a quick look through the code, the same might be true of out_axes and also for filter_pmap.

I've traced the TypeError to how _VmapWrapper constructs the in_axes argument to jax.vmap; one fix would be a matter of adding something like

in_axes = __self._vmapkwargs.pop("in_axes", None)
in_axes = in_axes or _resolve_axes(
    [...]
)

here before passing **self._vmapkwargs to jax.vmap later on.

I honestly haven't fully grokked the meaning of some of filter_vmap's kwargs; maybe this is its intended behavior? If so, would appreciate a clarification on what I'm doing wrong by trying to supply a value for in_axes to it.

Thanks!
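One detail worth noting about the proposed `in_axes or ...` fallback: in_axes=0 is a valid value but is falsy in Python, so it would be silently replaced by the default. The sketch below shows a sentinel-based alternative; resolve_in_axes and the default value are hypothetical and do not reflect the real _resolve_axes logic.

```python
_MISSING = object()  # sentinel that cannot collide with any user-supplied value


def resolve_in_axes(vmapkwargs, default):
    # Pop so that in_axes is not passed to jax.vmap a second time via **vmapkwargs.
    in_axes = vmapkwargs.pop("in_axes", _MISSING)
    return default if in_axes is _MISSING else in_axes


default_axes = (None,)

# Using `or` discards a user-supplied 0.
with_or = 0 or default_axes

# The sentinel version keeps it.
user_kwargs = {"in_axes": 0, "axis_name": "batch"}
with_sentinel = resolve_in_axes(user_kwargs, default_axes)
not_supplied = resolve_in_axes({}, default_axes)
```

After the call, user_kwargs no longer contains in_axes, so the remaining entries can be forwarded unchanged.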

Add Twitter cards

By adding a few meta tags it's possible to get Twitter to display a custom image etc. when linking to the documentation.

Adding this in to the docs sounds like a good way to procrastinate from more serious work!

[Bug?] Possible incompatible interaction between optax and equinox?

Hi @patrick-kidger, I'm reporting what I think might be a bug in the interaction between equinox and optax (though I am a bit unsure as to which library to report it in).

Specifically, there is a weird error that shows up whenever I use the MLP module with a subset of optax optimizers.

Here is a minimal reproducible example. With the following code block:

import optax 
import equinox as eqx 
from jax.example_libraries import stax
from jax import random, nn

optimizer = optax.adabelief(learning_rate=1e-3) # errors out on adam, adamw; no errors with sgd
model = eqx.nn.Sequential(
    [
        eqx.nn.Linear(in_features=1, out_features=1024, key=random.PRNGKey(45)),
        nn.relu,  # no problem when commented out
        eqx.nn.Linear(in_features=1024, out_features=1, key=random.PRNGKey(39)),
    ]
)
opt_state = optimizer.init(model)

I get the following stack trace:


---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/home/ericmjl/github/incubator/score-models/notebooks/scratch.ipynb Cell 11' in 
      7 # model = GaussianModel()
      8 model = eqx.nn.Sequential(
      9     [
     10         eqx.nn.Linear(in_features=1, out_features=1024, key=random.PRNGKey(45)),
   (...)
     13     ]
     14 )
---> 15 opt_state = optimizer.init(model)

File ~/anaconda/envs/score-models/lib/python3.9/site-packages/optax/_src/combine.py:45, in chain.<locals>.init_fn(params)
44 def init_fn(params):
---> 45 return tuple(fn(params) for fn in init_fns)

File ~/anaconda/envs/score-models/lib/python3.9/site-packages/optax/_src/combine.py:45, in <genexpr>(.0)
44 def init_fn(params):
---> 45 return tuple(fn(params) for fn in init_fns)

File ~/anaconda/envs/score-models/lib/python3.9/site-packages/optax/_src/transform.py:457, in scale_by_belief.<locals>.init_fn(params)
456 def init_fn(params):
--> 457 mu = jax.tree_map(jnp.zeros_like, params) # First moment
458 s = jax.tree_map(jnp.zeros_like, params) # Second Central moment
459 return ScaleByBeliefState(count=jnp.zeros([], jnp.int32), mu=mu, nu=s)

File ~/anaconda/envs/score-models/lib/python3.9/site-packages/jax/_src/tree_util.py:184, in tree_map(f, tree, is_leaf, *rest)
182 leaves, treedef = tree_flatten(tree, is_leaf)
183 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 184 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/anaconda/envs/score-models/lib/python3.9/site-packages/jax/_src/tree_util.py:184, in <genexpr>(.0)
182 leaves, treedef = tree_flatten(tree, is_leaf)
183 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 184 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/anaconda/envs/score-models/lib/python3.9/site-packages/jax/_src/numpy/lax_numpy.py:1889, in zeros_like(a, dtype, shape)
1887 @_wraps(np.zeros_like)
1888 def zeros_like(a, dtype=None, shape=None):
-> 1889 _check_arraylike("zeros_like", a)
1890 lax_internal._check_user_dtype_supported(dtype, "zeros_like")
1891 if np.isscalar(shape):

File ~/anaconda/envs/score-models/lib/python3.9/site-packages/jax/_src/numpy/util.py:298, in _check_arraylike(fun_name, *args)
295 pos, arg = next((i, arg) for i, arg in enumerate(args)
296 if not _arraylike(arg))
297 msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 298 raise TypeError(msg.format(fun_name, type(arg), pos))

TypeError: zeros_like requires ndarray or scalar arguments, got <class 'jax._src.custom_derivatives.custom_jvp'> at position 0.

On the other hand, if one were to comment out the nn.relu line above, we would have no issues running the code above.

Not knowing the internals of equinox well enough (I only started playing with it today), I'm not exactly sure how to pinpoint what the exact issue is, so I decided to report the matter here first.

Help on what went wrong in a simple MLP implemented with equinox

Hi the equinox community,

(This is actually a question rather than an issue. I did not find a better place to post questions; apologies for that!)

I recently started learning JAX and Equinox as a PyTorch user. After reading about the basics, I started to practice by implementing a simple MLP, where the Linear module can accept an extra parameter controlling whether or not to add an activation (relu in this case). The gist can be found here: https://gist.github.com/breakds/7330f67619fbf82d1daad9767353b217

After getting it to run, the obtained model is actually quite "linear", which is different from what I expected. I'm pretty sure I made a mistake somewhere, probably due to my misunderstanding of JAX/Equinox. Can someone with more experience help point it out?

Thanks!

[Feature request] Easier way to specify parameters to freeze

Thank you for this library! This is the library that makes most sense among JAX's nn libraries (I don't have to learn additional transformations on those libs).
What I found quite tedious at the moment is specifying parameters to freeze.
In your example, the filter has to be specified based on the structure of the model.
This becomes very tedious if we want to try out different architectures as we have to write the filter for each model.
In PyTorch, we can use a buffer in this case. In Equinox, there is no easy way to have something like a buffer.

Option 1
One option is to use eqx.static_field, but this removes the field from the PyTree, which is not always what we want.

Option 2
Another option is to use jax.lax.stop_gradient, which works as below:

import equinox as eqx
import jax
import jax.numpy as jnp

class Module(eqx.Module):
    mult: jnp.ndarray
    mult_b: jnp.ndarray  # let's say this is the buffer

    def __init__(self, mult: jnp.ndarray, mult_b: jnp.ndarray) -> None:
        self.mult = mult
        self.mult_b = mult_b  # using jax.lax.stop_gradient(mult_b) doesn't work here

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = self.mult * jax.lax.stop_gradient(self.mult_b) * x
        return x

if __name__ == "__main__":
    key = jax.random.PRNGKey(123)
    key1, key2, keyx = jax.random.split(key, 3)
    mult = jax.random.normal(key1, (1,))
    mult_b = jax.random.normal(key2, (1,))
    x = jax.random.normal(keyx, (1,))
    module = Module(mult, mult_b)
    out = module(x)

    grad = eqx.filter_grad(lambda mod, x: jnp.sum(module.__class__.__call__(mod, x)))
    g = grad(module, x)
    print(g)  # we want the gradient of mult_b to be 0 or None

The example above works, but it is quite cumbersome to write jax.lax.stop_gradient in every method in the class.

Option 3
Another option is to filter by field name (maybe something like eqx.filter_by_field_name). I imagine that it would go something like this:

class Module(eqx.Module):
    mult: jnp.ndarray
    mult_b: jnp.ndarray  # let's say this is the buffer

    def __init__(self, mult: jnp.ndarray, mult_b: jnp.ndarray) -> None:
        self.mult = mult
        self.mult_b = mult_b

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = self.mult * self.mult_b * x
        return x

if __name__ == "__main__":
    key = jax.random.PRNGKey(123)
    key1, key2, keyx = jax.random.split(key, 3)
    mult = jax.random.normal(key1, (1,))
    mult_b = jax.random.normal(key2, (1,))
    x = jax.random.normal(keyx, (1,))
    module = Module(mult, mult_b)
    out = module(x)

    name_regex = r"\w+_b"
    filter_spec = jax.tree_map(lambda _: False, (module, x))
    filter_spec = eqx.filter_by_field_name(filter_spec, name_regex, replace=True)
    grad = eqx.filter_grad(lambda mod, x: jnp.sum(module.__class__.__call__(mod, x)),
                           filter_spec=filter_spec)
    g = grad(module, x)
    print(g)  # we want the gradient of mult_b to be 0 or None

I would vote for the 3rd option, because it is the most convenient, but it would need an additional function: eqx.filter_by_field_name
What do you think?
I'm happy to create a PR if you agree with this.
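The name-based idea in Option 3 can be sketched today with plain dictionaries: split the parameters by a regular expression on their names and differentiate only with respect to the trainable part. The dictionary layout and the split below are assumptions for this example, and eqx.filter_by_field_name remains a proposed function that is not used here.

```python
import re

import jax
import jax.numpy as jnp

params = {"mult": jnp.array([2.0]), "mult_b": jnp.array([3.0])}

# Names ending in "_b" are treated as buffers, following the regex in Option 3.
buffer_pattern = re.compile(r"\w+_b")
trainable = {k: v for k, v in params.items() if not buffer_pattern.fullmatch(k)}
frozen = {k: v for k, v in params.items() if buffer_pattern.fullmatch(k)}


def loss(trainable, frozen, x):
    p = {**trainable, **frozen}
    return jnp.sum(p["mult"] * p["mult_b"] * x)


# jax.grad differentiates with respect to the first argument only.
grads = jax.grad(loss)(trainable, frozen, jnp.array([5.0]))
```

The gradient has an entry for mult only, so an optimiser built on it never touches the buffer.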

Implementing Batch Normalization

In Flax, Batch Normalization is a bit finicky since each call to apply requires marking batch_stats as mutable and updating the batch_stats afterward.

import flax.linen
import jax.numpy as jnp
from jax import random

bn = flax.linen.BatchNorm(use_running_average=True)

x = jnp.arange(24).reshape(4, 6)

vars = bn.init(random.PRNGKey(0), x)

# Mark the batch stats as mutable so we can update them in the variable dictionary
x_normed, mutated_vars = bn.apply(vars, x, mutable=['batch_stats'])

vars = {**vars, **mutated_vars}  # Update the variables with our diff

x_normed2, mutated_vars2 = bn.apply(vars, x, mutable=['batch_stats'])

How could this be implemented as a Module in Equinox? I'm happy to submit an implementation given some guidance.
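One way to think about the question is the purely functional version of batch normalisation, where the running statistics are passed in and returned explicitly. The sketch below is plain JAX with hypothetical names and a hypothetical momentum value; it is not an Equinox Module and not the library's implementation.

```python
import jax.numpy as jnp


def batch_norm(x, state, momentum=0.9, eps=1e-5):
    # Normalise with the statistics of the current batch (axis 0 is the batch axis).
    mean = jnp.mean(x, axis=0)
    var = jnp.var(x, axis=0)
    y = (x - mean) / jnp.sqrt(var + eps)
    # Return the updated running statistics instead of mutating anything.
    new_state = {
        "mean": momentum * state["mean"] + (1 - momentum) * mean,
        "var": momentum * state["var"] + (1 - momentum) * var,
    }
    return y, new_state


x = jnp.arange(24.0).reshape(4, 6)
state = {"mean": jnp.zeros(6), "var": jnp.ones(6)}

y, state = batch_norm(x, state)
```

The caller has to thread state through every call, which is the bookkeeping that the Flax snippet above performs with mutable=['batch_stats'].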

Slightly unintuitive `eqx.nn.Sequential`

My first intuition when working with eqx.nn.Sequential (e.g. for a simple MLP) is like below:

import jax
import equinox as eqx

key = jax.random.PRNGKey(0)
key1, key2, key = jax.random.split(key, 3)
module = eqx.nn.Sequential([
    eqx.nn.Linear(5, 10, key=key1),
    jax.nn.log_sigmoid,
    eqx.nn.Linear(10, 2, key=key2),
])
keyx, key = jax.random.split(key, 2)
x = jax.random.normal(keyx, (5,))
y = module(x)

However, the code above does not work because of the error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/anaconda3/envs/py39/lib/python3.9/site-packages/equinox/nn/composed.py", line 124, in __call__
    x = layer(x, key=key)
TypeError: log_sigmoid() got an unexpected keyword argument 'key'

So the module has to be slightly modified into:

module = eqx.nn.Sequential([
    eqx.nn.Linear(5, 10, key=key1),
    lambda x, key: jax.nn.log_sigmoid(x),  # adding a thin wrapper
    eqx.nn.Linear(10, 2, key=key2),
])

Is this a deliberate choice?
Why not detect whether a module accepts key in its __call__ method and then call it appropriately (with or without key)?
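The detection suggested here could look roughly like the sketch below, which inspects a callable's signature before deciding whether to pass key. The helper names are hypothetical, and signature inspection can fail or mislead for some wrapped or built-in callables, which may be one reason to prefer an explicit wrapper.

```python
import inspect


def accepts_key(fn):
    # True if fn has a parameter named "key" or swallows arbitrary keyword arguments.
    try:
        parameters = inspect.signature(fn).parameters
    except (TypeError, ValueError):
        return False  # signature not introspectable; assume no key
    if "key" in parameters:
        return True
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values())


def call_layer(layer, x, key=None):
    return layer(x, key=key) if accepts_key(layer) else layer(x)


def with_key(x, *, key=None):
    return x + 1


def without_key(x):
    return x * 2


out_with = call_layer(with_key, 3, key="some-key")
out_without = call_layer(without_key, 3, key="some-key")
```

With this dispatch, a plain activation function could sit in the layer list without the thin lambda wrapper.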

Performance vs `jax.tree_util.Partial`

Hi @patrick-kidger,
Following on from this issue (https://github.com/patrick-kidger/diffrax/issues/76), I tried to determine what was causing the jit to bottleneck for diffrax. I believe that it may have something to do with equinox. Here is an example of my reasoning, based on the canonical example for equinox:

import equinox 
import jax
jax.config.update("jax_enable_x64", True)

class Linear(equinox.Module):
    weight: jax.numpy.ndarray
    bias: jax.numpy.ndarray

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        self.weight = jax.random.normal(wkey, (out_size, in_size))
        self.bias = jax.random.normal(bkey, (out_size,))

    def __call__(self, x):
        return self.weight @ x + self.bias

@jax.jit
@jax.grad
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    return jax.numpy.mean((y - pred_y) ** 2)

batch_size, in_size, out_size = 32, 2, 3
model = Linear(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))

%%timeit
grads = loss_fn(model, x, y)

So the idea is to repeat the equinox implementation functionally and see what happens

def init_linear(in_size, out_size, key):
    wkey, bkey = jax.random.split(key)
    weight = jax.random.normal(wkey, (out_size, in_size))
    biases = jax.random.normal(bkey, (out_size,))
    return weight, biases

batch_size, in_size, out_size = 32, 2, 3
weights, biases = init_linear(in_size, out_size, key=jax.random.PRNGKey(0))

@jax.tree_util.Partial
@jax.jit
def run_model(x, /, weights=weights, biases=biases):
    return weights @ x + biases

x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))

%%timeit
grads = loss_fn(run_model, x, y)

Sorry for being a thorn, and thanks in advance for chasing this up.
Jordan
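When comparing the two variants above, it can help to separate compilation from steady-state call time, and to wait for JAX's asynchronous dispatch to finish. The following is a minimal timing sketch under those assumptions; the function affine and the sizes are illustrative, not taken from diffrax or Equinox.

```python
import timeit

import jax
import jax.numpy as jnp


@jax.jit
def affine(w, b, x):
    return w @ x + b


w = jnp.ones((3, 2))
b = jnp.zeros(3)
x = jnp.ones(2)

# Warm-up call: triggers tracing and compilation, so it is excluded from the timing.
affine(w, b, x).block_until_ready()

# Time only already-compiled calls, blocking until each result is actually computed.
n = 100
seconds = timeit.timeit(lambda: affine(w, b, x).block_until_ready(), number=n)
per_call_us = 1e6 * seconds / n
print(f"{per_call_us:.1f} microseconds per call")
```

For a model this small, the measured time is likely dominated by per-call overhead (argument flattening and dispatch) rather than by the computation itself, so differences between the two variants may reflect that overhead.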

Have `get_state`/`set_state` support setting values with compatible batch axes

At the moment, get_state and set_state demand that their arguments have precisely the same shape, dtype, and choice of batch axes.

It would make sense to allow some compatibility between different kinds of batch axes, for example because it can take a few fixed-point iterations for all of the vmap'd batch axes to flow through a model.

Issues when applying multiple optax optimizers within nested Equinox Modules

Hi Patrick,

We are currently developing an auto-diff optical modelling package that makes heavy use of Equinox (dLux, up on PyPI in the next few days)! We are, however, running into some issues when trying to optimize our models using optax, and I believe the problem stems from the way Equinox interacts with optax. We are also hoping that our use case can help build functionality into Equinox for optimizing more complex nested Equinox Modules with different optax optimizers/schedules.

The basic structure of the package is an Equinox.Module() in which some parameters are arrays, some are lists of other Equinox.Module() objects and others are Equinox.Modules() which contain lists of other Equinox.Modules. Like this:

Model(eqx.Module()):
 -> param1 = np.array([val0, val1, ....])
 -> param2 =  np.array([val0, val1, ....])
 -> param3 = [Layer1(eqx.Module()),
              Layer2(eqx.Module()),
              Layer3(eqx.Module())]

 -> param4 = SubModel(eqx.Module()):
      -> param1 = [Layer1(eqx.Module()),
                   Layer2(eqx.Module()),
                   Layer3(eqx.Module())]

The nature of these models is that each parameter needs a different learning rate/optimization function. Optax has optax.multi_transform() to handle stitching together multiple optimizers, but this method does not work with our models because our parameters are stored in the Equinox objects, not nested dictionaries. I have attempted to manually decompose our models into nested dictionaries, optimise the params in the dictionary and then rebuild a model from the updated parameters, but it throws a very strange error:

model_dict = model_to_dict(model)
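optimizer = optax.multi_transform(
    {'param1': optax.adam(1e-6),
     'param2': optax.adam(1e3)},
    param_labels)  # hypothetical name: label PyTree mapping each leaf of model_dict to 'param1' or 'param2' (not shown in the original report)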

opt_state = optimizer.init(model_dict)

# Still pass the scene object to the loss function
loss, grads = loss_func(model, data)

grads_dict = model_to_dict(grads)
updates, opt_state = optimizer.update(grads_dict, opt_state)
model_dict = eqx.apply_updates(model_dict, updates)

new_model = dict_to_model(model_dict)

# Error thrown here
loss, grads = loss_func(new_model, data)

This method throws a very strange error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [29], in <cell line: 11>()
     10 errors, grads_out, models = [], [], []
     11 for i in tqdm(range(5)):
---> 12     loss, grads = loss_func(opt_model, images)
     14     grads_dict = scene_to_dict(grads)
     16     # updates, state = tx.update(grads_dict, state, opt_dict)

File ~/mambaforge/envs/morph/lib/python3.8/functools.py:399, in partialmethod._make_unbound_method.<locals>._method(cls_or_self, *args, **keywords)
    397 def _method(cls_or_self, /, *args, **keywords):
    398     keywords = {**self.keywords, **keywords}
--> 399     return self.func(cls_or_self, *self.args, *args, **keywords)

File ~/mambaforge/envs/morph/lib/python3.8/site-packages/jax/_src/device_array.py:41, in _forward_method(attrname, self, fun, *args)
     40 def _forward_method(attrname, self, fun, *args):
---> 41   return fun(getattr(self, attrname), *args)

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

I cannot make any sense of this. Interestingly, I was able to recreate this error by calling eqx.tree_equal() on the model and new_model objects: eqx.tree_equal(new_model, model).

So I am hoping you can help shed some light, or help develop some functionality to support the optimization of nested Equinox Modules with different learning rates! I am completely new to optax and Equinox, and only familiar with JAX at a high level, so it's entirely possible I am misunderstanding something fundamental about the way these packages interact.

Finally, these are the versions of all the packages that I'm using:

jax: 0.3.4
eqx: 0.3.2
optax: 0.1.1

and it is all being run on an Apple M1 machine. Please let me know if you need any more information!
Cheers
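The ValueError in the traceback above is the one NumPy raises whenever Python needs a single True/False from an elementwise array comparison. One plausible source, offered as an assumption rather than a diagnosis, is an array ending up somewhere that gets compared with ==, such as the non-leaf (static) part of a PyTree. The message itself can be reproduced with NumPy alone:

```python
import numpy as np

a = np.array([1.0, 2.0])
b = np.array([1.0, 2.0])

# `==` on arrays is elementwise: this is an array of booleans, not one boolean.
elementwise = a == b

try:
    bool(elementwise)  # what `if a == b:` does implicitly
    raised = False
    message = ""
except ValueError as err:
    raised = True
    message = str(err)

# An explicit whole-array comparison returns a single boolean instead.
same = np.array_equal(a, b)

print(raised, message)
print(same)
```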

Compilation logging shows generic function name for every JIT'ed function for `filter_jit`, `filter_grad`, `filter_value_and_grad`

I've recently been debugging long compile times in my neural networks and stumbled upon a JAX configuration flag that can be set using

jax.config.update('jax_log_compiles', True)

This setting prints out a message when each function starts and ends its compilation, including the amount of time spent. However, the log messages currently show a generic function name from within the implementation of filter_jit, filter_grad and filter_value_and_grad, e.g.

WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.00043702125549316406 sec
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0002541542053222656 sec
WARNING:absl:Compiling prim_fun (4766936320 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of broadcast_in_dim in 0.06613397598266602 sec
WARNING:absl:Finished tracing + transforming f_wrapped for jit in 0.0012121200561523438 sec
WARNING:absl:Compiling f_wrapped (5095206976 for args (ShapedArray(float32[1]),).
WARNING:absl:Finished XLA compilation of f_wrapped in 0.00688481330871582 sec

where f_wrapped comes from here

def f_wrapped(dynamic, static_treedef, static_leaves, filter_spec_return):

When looking for the culprit of excessive compile times, having the names of the user's own functions available in these messages is important.
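Assuming the compile log takes its label from the wrapped callable's __name__, one common way for a wrapper to keep the user's function name is functools.wraps. The sketch below shows only that mechanism in plain Python; the decorator names are hypothetical and this is not a description of Equinox's implementation.

```python
import functools


def passthrough_without_wraps(fn):
    def f_wrapped(*args, **kwargs):
        return fn(*args, **kwargs)

    return f_wrapped


def passthrough_with_wraps(fn):
    @functools.wraps(fn)  # copies __name__, __qualname__, __doc__, ... from fn
    def f_wrapped(*args, **kwargs):
        return fn(*args, **kwargs)

    return f_wrapped


def train_step(x):
    return x + 1


anonymous = passthrough_without_wraps(train_step)
named = passthrough_with_wraps(train_step)

print(anonymous.__name__)  # f_wrapped
print(named.__name__)      # train_step
```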

TypeError: unhashable type: 'numpy.ndarray'

Hello,

I was following the RNN documentation while attempting to implement the PaLM model from this library. I have previously contributed to some of lucidrains' PyTorch implementations and was working on an example for the JAX version of PaLM. When using eqx.filter_jit on the train_step function I receive the error:

ValueError: Non-hashable static arguments are not supported. An error occured during a call to 'train_step' while trying to hash an object of type <class 'tuple'>... The error was: TypeError: unhashable type: 'numpy.ndarray'

Do you have any idea why this may occur?

I greatly appreciate your time and help.

Thank you,

Enrico

The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/

Code:

from random import randrange
import tqdm
import gzip
import numpy as np

from torch.utils.data import DataLoader, Dataset

import jax
from jax import nn
from optax import adam, clip_by_global_norm, chain, apply_every

# https://github.com/lucidrains/PaLM-jax
from palm_jax import PaLM

import equinox as eqx

# constants

NUM_BATCHES = int(1e5)
BATCH_SIZE = 4
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4
MAX_GRAD_NORM = 0.5
VALIDATE_EVERY  = 100
SAMPLE_EVERY  = 500
SEQ_LEN = 1024

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

def decode_token(token):
    return str(chr(max(32, token)))

def decode_tokens(tokens):
    return ''.join(list(map(decode_token, tokens)))

# prepare enwik8 data

with gzip.open('../datasets/enwik8.gz') as file:
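    # np.fromstring is deprecated for binary data; frombuffer returns a read-only view, hence the copy
    X = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()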
    data_train, data_val = np.split(X, [int(90e6)])

class TextSamplerDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        rand_start = randrange(0, self.data.shape[0] - self.seq_len - 1)
        return self.data[rand_start: rand_start + self.seq_len + 1]

    def __len__(self):
        return self.data.shape[0] // self.seq_len

train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))

# setup model and params

key = jax.random.PRNGKey(0)

model = PaLM(
    num_tokens = 256,
    dim = 512,
    depth=12,
    heads=8,
    dim_head=64,
    key = key
)

# cross entropy loss

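def cross_entropy(logits, targets, axis = -1):
    # use jax.numpy rather than NumPy (np) so this works on traced values under jit/grad
    logprobs = nn.log_softmax(logits, axis = axis)
    nll = jax.numpy.take_along_axis(logprobs, jax.numpy.expand_dims(targets, axis = axis), axis = axis)
    ce = -jax.numpy.mean(nll)
    return ce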

@eqx.filter_value_and_grad
def loss_fn(model, data):
    inp, labels = data[:, :-1], data[:, 1:]
    logits = jax.vmap(model)(inp)
    return cross_entropy(logits, labels, axis = -1)

# optimizer

optim = chain(
    clip_by_global_norm(MAX_GRAD_NORM),
    adam(LEARNING_RATE),
    apply_every(GRADIENT_ACCUMULATE_EVERY)
)

optim_state = optim.init(model)

# train step

@eqx.filter_jit
def train_step(model, data, optim_state):
    loss, grads = loss_fn(model, data)
    updates, optim_state = optim.update(grads, optim_state)
    model = eqx.apply_updates(model, updates)
    return model, optim_state, loss

# training

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    data = next(train_loader).numpy()
    model, optim_state, loss = train_step(model, data, optim_state)
    if i % GRADIENT_ACCUMULATE_EVERY == 0:
        print(f'loss: {loss.item()}')
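The inner TypeError in the message above can be reproduced without JAX: NumPy arrays are not hashable, so any tuple containing one is not hashable either. In the training loop above, data is a NumPy array (the result of .numpy()). If the installed Equinox version treats NumPy arrays as static arguments to filter_jit, which is an assumption about the version in use, then converting the batch with jax.numpy.asarray before calling train_step would be one thing to try.

```python
import numpy as np

batch = np.zeros((4, 8), dtype=np.uint8)

try:
    hash((batch,))  # hashing a tuple hashes each of its elements
    hashable = True
    message = ""
except TypeError as err:
    hashable = False
    message = str(err)

print(hashable, message)
```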

Neater interface to filtered transformations?

I'd need to think about exactly what this interface might be, but it'd be nice to improve the current filter_spec stuff so that we can more easily specify just a few elements, without needing to add lots of default is_arrays for all the elements we don't specify.

(Something like this is probably a prerequisite for ever implementing filter_vmap.)

This probably includes some way of doing deferred construction of the filter_spec until runtime.

Doubt about formulating ensemble / MoE networks

Hi

Firstly, thanks for this great library! I've been exploring equinox and I really like the design and abstraction. That said, I'm rather new to the library and I had a question about an efficient way to compute over an ensemble of networks. Usually, the way I've been doing this for mixture of experts models is

init_e, apply_e = MLP(...)
# assuming that final layer has softmax
init_g, apply_g = MLP([..., num_experts])

# initializing parameters:
keys = random.split(..., num_experts + 1)  # PRNG key elided in the original
params_e = vmap(init_e)(keys[:-1])
params_g = vmap(init_g)(keys[-1])

# applying for a single example:
experts_out = vmap(apply_e, (0, None))(params_e, input)
gate_out = apply_g(params_g, input)

out = einsum('nd, n -> d', experts_out, gate_out)

I'm curious what's the best way to do this in equinox. A naive solution that comes to my mind is sequentially computing the individual experts inside __call__ with a for loop and then combining with the gating network. As for the initialization, the way I've tried doing it is using an attribute of the class to be a list of MLPs (our experts) which are sequentially initialized in a for loop. Any thoughts?

List a few projects using Equinox

People are starting to build cool stuff with Equinox. It might be worth adding links to these in the documentation for examples.

  • Diffrax
  • sympy2jax
  • Lucidrains' work
  • Winnie's inception work
  • Ben's UNet work
  • dLux
  • All the stuff appearing in the "used by" list.

Add some kind of de/serialisation?

Equinox models are just PyTrees so they should be very easy to serialise/deserialise; just save the PyTree to disk in whatever way is desired. It might be worth adding some library functions for this just for convenience. Perhaps checking the device of JAX arrays etc?

This should respect the get_state/set_state stuff that's being put together.

In addition, there should be a version of get_state which inlines its state in the call graph, for faster inference.
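As an illustration of "just save the PyTree", the leaves can be written with NumPy and restored into the structure of an existing PyTree. This is a minimal sketch that assumes every leaf is an array; save_leaves and load_leaves are hypothetical helper names, not Equinox functions, and a plain dict stands in for a model.

```python
import io

import jax
import jax.numpy as jnp
import numpy as np


def save_leaves(f, tree):
    """Write every leaf of `tree` to `f` (a path or binary file object)."""
    leaves = jax.tree_util.tree_leaves(tree)
    # Positional arrays are stored under the names arr_0, arr_1, ...
    np.savez(f, *[np.asarray(leaf) for leaf in leaves])


def load_leaves(f, like):
    """Read leaves from `f` and arrange them in the structure of `like`."""
    treedef = jax.tree_util.tree_structure(like)
    with np.load(f) as data:
        leaves = [jnp.asarray(data[f"arr_{i}"]) for i in range(treedef.num_leaves)]
    return jax.tree_util.tree_unflatten(treedef, leaves)


params = {"weight": jnp.arange(6.0).reshape(2, 3), "bias": jnp.zeros(2)}

buffer = io.BytesIO()  # stands in for a file on disk
save_leaves(buffer, params)
buffer.seek(0)
restored = load_leaves(buffer, like=params)
```

The like argument supplies the tree structure, so only array data needs to be stored; non-array leaves would need separate handling.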

Equinox and Einops Layer Integration

Hi Patrick,

I was wondering if you had any input on integrating Einops Rearrange Layers with Equinox Module composites such as eqx.nn.Sequential.

For example:

self.to_patch_embedding = eqx.nn.Sequential([
        Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
        eqx.nn.Linear(patch_dim, dim, key = key_1)
        ])

I have been working on a full ViT transformer example in Equinox (it is about 300 lines so I did not want to post the full code snippet here) and have been trying to find a way to properly implement the layer. Here is a subset of the code:

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class ViT(Module):
    to_patch_embedding: eqx.nn.Sequential
    pos_embedding: jnp.ndarray
    cls_token: jnp.ndarray
    dropout: eqx.nn.Dropout
    transformer: Module
    pool: str = static_field()
    to_latent: eqx.nn.Identity
    mlp_head: eqx.nn.Sequential

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, key, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()

        key_1, key_2, key_3, key_4, key_5 = random.split(key, 5)

        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 
        assert image_width % patch_width == 0

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}

        self.to_patch_embedding = eqx.nn.Sequential([
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            eqx.nn.Linear(patch_dim, dim, key = key_1)
        ])
        print(self.to_patch_embedding)

        self.pos_embedding = random.normal(key_2, shape = [1, num_patches + 1, dim])
        self.cls_token = random.normal(key_3, shape = [1, 1, dim])
        self.dropout = eqx.nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim = dim, depth = depth, heads = heads, dim_head = dim_head, mlp_dim = mlp_dim, key = key_4, dropout = dropout)

        self.pool = pool
        self.to_latent = eqx.nn.Identity()

        self.mlp_head = eqx.nn.Sequential([
            eqx.nn.LayerNorm(dim),
            eqx.nn.Linear(dim, num_classes, key = key_5),
        ])

    def __call__(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = jnp.concatenate([cls_tokens, x], axis = 1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        if self.pool == 'mean':
            x = jnp.mean(x, axis = 1)
        else:
            x = x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

I greatly appreciate your help.

Thank you,

Enrico Shippole
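Without settling whether einops layers can be placed inside eqx.nn.Sequential, the same rearrangement can be written as a plain function. Equinox layers such as eqx.nn.Linear act on a single unbatched example, so the sketch below drops the batch axis b from the pattern and assumes batching would be added with jax.vmap. The function name to_patches is hypothetical.

```python
import jax.numpy as jnp


def to_patches(img, patch_height, patch_width):
    """Unbatched equivalent of 'c (h p1) (w p2) -> (h w) (p1 p2 c)'."""
    c, height, width = img.shape
    h, w = height // patch_height, width // patch_width
    x = img.reshape(c, h, patch_height, w, patch_width)
    x = x.transpose(1, 3, 2, 4, 0)  # axes are now (h, w, p1, p2, c)
    return x.reshape(h * w, patch_height * patch_width * c)


img = jnp.arange(3 * 4 * 4, dtype=jnp.float32).reshape(3, 4, 4)
patches = to_patches(img, 2, 2)  # 4 patches, each of length 2 * 2 * 3 = 12
```

A linear layer would then be applied per patch, for example by mapping it over the first axis of patches with jax.vmap.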

Pooling layer

Hi Patrick,

Just wondering if pooling layers in CNN are available in Equinox?

Thanks,
JW
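Recent Equinox releases provide pooling layers such as eqx.nn.MaxPool2d. Independently of which release is installed, pooling can also be written directly with jax.lax.reduce_window. The sketch below is a minimal max-pool over a single (channels, height, width) array; the function name max_pool_2d is hypothetical.

```python
import jax
import jax.numpy as jnp


def max_pool_2d(x, window=2, stride=2):
    """Max-pool a (channels, height, width) array over its spatial axes."""
    return jax.lax.reduce_window(
        x,
        -jnp.inf,     # identity element for max
        jax.lax.max,
        window_dimensions=(1, window, window),
        window_strides=(1, stride, stride),
        padding="VALID",
    )


x = jnp.arange(16.0).reshape(1, 4, 4)
pooled = max_pool_2d(x)
print(pooled)
```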

Hashable `eqx.Module`

Thank you for this fantastic library -- I love the transparency of the systems you can build when everything is just a python datastructure.

As it stands, eqx.Modules do not have a way of being hashed. This makes decorating methods with @jax.jit impossible, and it complains:

ValueError: Non-hashable static arguments are not supported. An error occured while trying to hash an object of type <class '__main__.SimpleModule'>, SimpleModule(layers=[NeuronLayer(lagrangian=<function SimpleModule.__init__.<locals>.<lambda> at 0x7f818447b160>, shape=(32,), tau=1.0), NeuronLayer(lagrangian=Softmax(beta=1), shape=(500,), tau=0.0)], synapses=[DenseSynapse(W=DeviceArray([[ 1.1177216 , -0.85337347,  0.62181216, ...,  0.47079483,
               2.1223438 ,  0.11009306],
             [-0.19882406, -0.58834153,  0.27299452, ...,  0.3560427 ,
              -1.3553874 , -0.64876467],
             [-0.27144387, -0.24926849,  1.7452046 , ...,  0.90979475,
              -0.37605172, -0.30193415],
             ...,
             [ 0.25044632, -0.96083486,  0.18758047, ...,  0.82221097,
               0.73630685, -0.7022106 ],
             [-1.0072463 , -0.30353332,  1.412912  , ..., -1.0664191 ,
              -1.1468811 ,  0.5145061 ],
             [-0.7490296 ,  2.1055915 , -1.5650765 , ...,  1.9068209 ,
               0.04990109,  0.2296551 ]], dtype=float32))], connections=[(0, 1, 0)]). The error was:
TypeError: unhashable type: 'SimpleModule'

I was curious if there is anything preventing eqx.Modules from providing a default hash function. They are, after all, just data containers with a few methods attached to them.
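One obstacle to a default field-based hash can be shown with ordinary dataclasses: a hash derived from the fields only works when every field is itself hashable, and lists and arrays are not. The classes below are hypothetical stand-ins, not Equinox modules.

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen + eq generates a __hash__ from the fields
class WithTuple:
    shape: tuple


@dataclass(frozen=True)
class WithList:
    layers: list


tuple_hash = hash(WithTuple(shape=(32,)))  # fine: tuples of ints are hashable

try:
    hash(WithList(layers=[1, 2]))  # the generated hash must hash the list
    list_hashable = True
    message = ""
except TypeError as err:
    list_hashable = False
    message = str(err)

print(list_hashable, message)
```

A module whose fields hold arrays runs into the same problem, which suggests passing the module as an ordinary PyTree argument rather than as a static one.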

Tracer error with eqx.jit, jax.grad and dfx.diffeqsolve

I'm experiencing an exception in code that tries to backpropagate through diffeqsolve, with the whole thing wrapped in eqx.filter_jit. The context is essentially an extension of your score-based diffusion example to calculate likelihoods as shown here. This works correctly; however, when I try to take jax.grad of the log-likelihood w.r.t. the sample (image), I get an error about a leaked tracer. Indeed, the leak location does look like it stores something in a struct. The reproducing code along with the error message are here. I hope this is not too difficult to understand, it's mostly refactored code of yours :)

Improve/remove caching for `filter_jit`

At the moment we have an LRU cache here:

@ft.lru_cache(maxsize=4096)

This shouldn't be necessary, as we can implement the wrapped object as an eqx.Module.

class WrappedFunc(eqx.Module):
    func: Callable
    jitkwargs: Dict[str, Any]

    @ft.partial(jax.jit, static_argnums=(0, 1, 2, 5))
    def __call__(self, static_leaves, static_treedef, dynamic_args, dynamic_kwargs, filter_spec_return):
        ...

That is, use the existing jax.jit cache to handle self.func and self.jitkwargs, rather than our own cache.

Standard MLP (equinox.nn.MLP) does not work with `apply_updates` function

From what I can see, the standard MLP included in equinox.nn.MLP breaks when trying to apply updates.

Simple demo:

import equinox as eqx
import jax.random as jrand

# Define a simple MLP
mlp = eqx.nn.MLP(in_size=1,width_size=1,out_size=1,depth=1,key=jrand.PRNGKey(0))

# Should be able to apply updates with itself
eqx.apply_updates(mlp,mlp)

Error thrown:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [4], in <cell line: 1>()
----> 1 eqx.apply_updates(mlp,mlp)

File ~/Desktop/equinox/equinox/update.py:37, in apply_updates(model, updates)
     18 """A `jax.tree_map`-broadcasted version of
     19 ```python
     20 model = model if update is None else model + update
   (...)
     34 The updated model.
     35 """
     36 # Assumes that updates is a prefix of model
---> 37 return jax.tree_map(_apply_update, updates, model, is_leaf=_is_none)

File ~/miniconda3/envs/eqx-fork/lib/python3.10/site-packages/jax/_src/tree_util.py:184, in tree_map(f, tree, is_leaf, *rest)
    182 leaves, treedef = tree_flatten(tree, is_leaf)
    183 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 184 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/miniconda3/envs/eqx-fork/lib/python3.10/site-packages/jax/_src/tree_util.py:184, in <genexpr>(.0)
    182 leaves, treedef = tree_flatten(tree, is_leaf)
    183 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 184 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))

File ~/Desktop/equinox/equinox/update.py:10, in _apply_update(u, p)
      8     return p
      9 else:
---> 10     return p + u

TypeError: unsupported operand type(s) for +: 'custom_jvp' and 'custom_jvp'

The issue seems to arise in tree_map, when we flatten the tree:

leaves, treedef = tree_flatten(tree, is_leaf)

Inspecting the resulting leaves:

[DeviceArray([[0.58076453]], dtype=float32),
DeviceArray([-0.44256163], dtype=float32),
DeviceArray([[0.882236]], dtype=float32),
DeviceArray([0.79829645], dtype=float32),
<jax._src.custom_derivatives.custom_jvp object at 0x7fc6f0e88460>,
<function MLP._identity at 0x7fc6f1db8670>]

We see that the flattened leaves also include the custom_jvp object (i.e. the ReLU) and the identity function, i.e. the activation fields in the MLP class.

To fix this, I believe we should simply mark the activation and final_activation fields as static (i.e. should not be treated as leaves of the PyTree):

class MLP(eqx.Module):
    """Standard Multi-Layer Perceptron; also known as a feed-forward network."""

    layers: List[Linear]
    activation: Callable = static_field()
    final_activation: Callable = static_field()

    ...

This fixes the issue for me.
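The failure mode, and an alternative to marking the fields static, can be illustrated with plain JAX: functions stored in a PyTree are leaves, so an update rule can skip any leaf that is not an array. This is a sketch of that idea using a dict as a stand-in model; it does not describe what eqx.apply_updates does.

```python
import jax
import jax.numpy as jnp

# A stand-in "model": one array parameter and one activation function.
params = {"weight": jnp.ones((2, 2)), "activation": jax.nn.relu}

# Both entries are leaves, including the function.
leaves = jax.tree_util.tree_leaves(params)


def apply_update(u, p):
    # Only array leaves can be added; anything else is passed through unchanged.
    return p + u if isinstance(p, jax.Array) else p


# Use the model as its own "update", as in the demo above.
updated = jax.tree_util.tree_map(apply_update, params, params)
```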

Maybe remove/document difference between eqx.filter_grad and jax.grad

When has_aux=True is used, the order of the return values differs between equinox.filter_grad and jax.grad.
However, this difference is not mentioned in the documentation: https://docs.kidger.site/equinox/api/filtering/filtered-transformations/#equinox.filter_grad

Minimal reproducible example:
import jax
import jax.numpy as jnp
import equinox as eqx

@jax.jit
def some_func(x):
    return jnp.sum(2*x*jnp.sin(x)), (x, jnp.sin(x))

some_grad1 = jax.grad(some_func, has_aux=True)
some_grad2 = eqx.filter_grad(some_func, has_aux=True)
print(jax.vmap(some_grad1)(jnp.ones((5,2))))
print(jax.vmap(some_grad2)(jnp.ones((5,2))))

Output:

(DeviceArray([[2.7635465, 2.7635465],
              [2.7635465, 2.7635465],
              [2.7635465, 2.7635465],
              [2.7635465, 2.7635465],
              [2.7635465, 2.7635465]], dtype=float32), (DeviceArray([[1., 1.],
              [1., 1.],
              [1., 1.],
              [1., 1.],
              [1., 1.]], dtype=float32), DeviceArray([[0.84147096, 0.84147096],
              [0.84147096, 0.84147096],
              [0.84147096, 0.84147096],
              [0.84147096, 0.84147096],
              [0.84147096, 0.84147096]], dtype=float32)))

((DeviceArray([[1., 1.],
               [1., 1.],
               [1., 1.],
               [1., 1.],
               [1., 1.]], dtype=float32), DeviceArray([[0.84147096, 0.84147096],
               [0.84147096, 0.84147096],
               [0.84147096, 0.84147096],
               [0.84147096, 0.84147096],
               [0.84147096, 0.84147096]], dtype=float32)), DeviceArray([[2.7635465, 2.7635465],
               [2.7635465, 2.7635465],
               [2.7635465, 2.7635465],
               [2.7635465, 2.7635465],
               [2.7635465, 2.7635465]], dtype=float32))
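For reference, the ordering conventions of the plain JAX transformations with has_aux=True can be checked directly. The sketch below uses only jax.grad and jax.value_and_grad, on the same function as the report, and makes no claim about the current behaviour of eqx.filter_grad.

```python
import jax
import jax.numpy as jnp


def some_func(x):
    return jnp.sum(2 * x * jnp.sin(x)), (x, jnp.sin(x))


x = jnp.ones(2)

# jax.grad with has_aux=True returns (gradient, aux).
grads, aux = jax.grad(some_func, has_aux=True)(x)

# jax.value_and_grad with has_aux=True returns ((value, aux), gradient).
(value, aux2), grads2 = jax.value_and_grad(some_func, has_aux=True)(x)
```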
