Comments (2)
My understanding of filter_jit is not 100%, but a useful intuition for problems like this is that filter_jit behaves roughly like this: split the inputs with is_array, then jax.jit a function f(arr, static) with static listed in static_argnames. This means JAX uses the hash (and equality) of every static input, including any functions stored on the model, as part of its compilation cache key. In your second example, the second model.forward does not match the first: it is a separate object created by a separate class initialization (I assume that is why JAX gives it a different cache entry). So JAX thinks it has received a different static value and recompiles, even though we know the function is the same. If you take the jax.jit out of the class and jit the function once outside it, so that every model holds the exact same cached function object, it works without error:
import jax
import jax.numpy as jnp
import equinox as eqx
import optax
from typing import Callable
from jaxtyping import Array


class ForwardKinematicTrajectory(eqx.Module):
    trajectory: eqx.Module
    scale_mixer: Array
    forward: Callable

    def __init__(self, fk_fn: Callable):
        self.trajectory = eqx.nn.Linear(1, 3, key=jax.random.PRNGKey(0))
        self.scale_mixer = jnp.ones((3, 1))
        # store the function that was passed in (no new jax.jit wrapper here),
        # so every instance holds the same object
        self.forward = fk_fn

    def __call__(self, data):
        """Meant to be vmapped"""
        return self.forward(self.trajectory(data) * self.scale_mixer)


def loss(model: ForwardKinematicTrajectory, X, y):
    return jnp.mean(jnp.square(jax.vmap(model)(X) - y))


@eqx.filter_jit
@eqx.debug.assert_max_traces(max_traces=1)
def step(model, opt_state, X, y, loss_grad, opt):
    val, grads = loss_grad(model, X, y)
    updates, opt_state = opt.update(grads, opt_state, model)
    model = eqx.apply_updates(model, updates)
    return val, model, opt_state


# generate data
X = jnp.linspace(-1, 1, 50).reshape(-1, 1)
y = jnp.zeros((50, 3, 3))

# init model: jit the forward function once, outside the class
fn = jax.jit(jnp.sin)
model = ForwardKinematicTrajectory(fn)

# optimizer
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

# loss
loss_grad = eqx.filter_value_and_grad(loss)
loss_grad = eqx.filter_jit(loss_grad)

# this will run
for i in range(100):
    val, model, opt_state = step(model, opt_state, X, y, loss_grad, optimizer)
    if i % 10 == 0:
        print(val)

# reinitialize model (and its optimizer state)
print("reinit")
model = ForwardKinematicTrajectory(fn)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))
for i in range(100):
    # same `fn` object as before, so `step` is not retraced and
    # assert_max_traces does not raise
    val, model, opt_state = step(model, opt_state, X, y, loss_grad, optimizer)
    if i % 10 == 0:
        print(val)
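To see the caching rule this explanation relies on without Equinox in the way, here is a minimal sketch using plain jax.jit with a static argument. It assumes only what the jax.jit documentation states: static arguments must be hashable and are matched against earlier calls by hash and equality. The names `apply`, `make_forward`, and `trace_count` are hypothetical, and a freshly built lambda stands in for a per-instance jax.jit wrapper, since plain Python functions compare by identity.

```python
import jax
import jax.numpy as jnp

# Python-side counter. The body of `apply` only executes while JAX traces it,
# so this counts traces (i.e. compilations), not calls.
trace_count = 0


def apply(x, fn):
    global trace_count
    trace_count += 1
    return fn(x)


# `fn` is not an array, so mark it static. JAX then uses hash(fn) and
# equality with previously seen values to decide whether a cached
# compilation can be reused.
apply_jit = jax.jit(apply, static_argnames="fn")


def make_forward():
    # Hypothetical helper: returns a brand-new function object on every call,
    # standing in for a wrapper created inside a Module's __init__.
    return lambda x: jnp.sin(x)


x = jnp.linspace(-1.0, 1.0, 5)

shared = make_forward()
apply_jit(x, fn=shared)          # first call: traces and compiles
apply_jit(x, fn=shared)          # same object, so same hash: cache hit
traces_shared = trace_count

apply_jit(x, fn=make_forward())  # new object, compares unequal: retraces
traces_fresh = trace_count

print(traces_shared, traces_fresh)  # 1 2
```

Reusing `shared` keeps the count at one trace, while a freshly built function forces a second trace even though it computes the same thing. That is the same reason passing one pre-jitted `fn` into every `ForwardKinematicTrajectory` avoids the recompile.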
Thanks for the explanation! Makes sense and fixed my problem!
Related Issues
- Sharding - shard `eqx.Module` as well as inputs?
- Type annotations for "struct of arrays" ?
- Optimally training ensembles.
- Elegant way to index pytrees at arbitrary leaves using equinox?
- Incorrect error handling in error_if
- Broken link in documentation
- How to print HLO for jit accelated function?
- Graph Neural Networks
- `UnexpectedTracerError` when vmapping `SpectralNorm`
- gradient with sequential and jax.nn.<activation functions>
- Weird `jax` error when trying vmap twice while using batchnorm
- lax.scan for equinox Modules ?
- Regarding trainable parameters in equinox
- `eqx.tree_at` fails with `TypeError: Operation undefined, {x} is not a leaf of the pytree` when custom pytree attributes are modified within `__init__`
- Getting a type error when initializing Linear module (unexpected keyword argument 'key')
- merging multiple eqx.Module classes
- How to serialize a model + state pair?
- Strange error with multiple tuples as attributes and `equinox.tree_at`?
- Jax 0.4.27: ValueError: safe_map() argument 2 is shorter than argument 1