Comments (5)
That's true: tree operations won't necessarily fail outright, but they won't act as one might expect. For example, jax.tree_map(lambda x: jnp.array(x), mytree) does not produce a pytree of arrays (42 is still an int).
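A minimal sketch of the behavior described above. The original MWE is not part of this thread, so the CustomPytree below is a hypothetical reconstruction based on the later comments: a constructor that always sets dict["a"] = 42. Because JAX rebuilds the object through the registered unflatten function, which calls __init__ again, the cast is silently undone for "a":

```python
import jax
import jax.numpy as jnp


class CustomPytree:
    # Hypothetical reconstruction of the original MWE: __init__ always
    # overwrites "a", including when JAX calls it again during unflattening.
    def __init__(self, dict, attr, static):
        self.dict = dict
        self.dict["a"] = 42
        self.attr = attr
        self.static = static


jax.tree_util.register_pytree_node(
    CustomPytree,
    lambda self: ((self.dict, self.attr), self.static),  # (children, aux_data)
    lambda static, children: CustomPytree(*children, static),
)

mytree = CustomPytree(dict={"b": 2}, attr=1, static=1)
casted = jax.tree_util.tree_map(lambda x: jnp.array(x), mytree)

print(type(casted.dict["a"]))  # <class 'int'>: reset to 42 by __init__
print(type(casted.dict["b"]))  # a JAX array, as expected
```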
This is failing because the line pytree = jtu.tree_map(_LeafWrapper, pytree, is_leaf=is_leaf) in eqx.tree_at can't turn the 42 in {"a": 42} into a _LeafWrapper. More generally, a tree map can never change the value stored under "a": whenever the result is rebuilt by unflattening, __init__ resets it to 42 (e.g. if you tree-map a cast to arrays over everything, "a" will always come back as an int). To be honest, I'm not sure this is actually a valid pytree, because it seems to break some of the assumptions of jax.tree_util; but if it is, it seems hard to build around. No tree operation can meaningfully act on that leaf, so one option might be to compare the tree before and after a tree map to detect whether something differs.
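A sketch of the "before and after" idea, assuming what we want to detect is an unflatten that overwrites leaves. The helper leaves_survive_unflatten and the _Marker class are hypothetical names, not part of JAX or Equinox. Replacing every leaf with a unique marker is similar in spirit to the _LeafWrapper trick in eqx.tree_at; an identity tree map would not work, since 42 would simply come back as 42:

```python
import jax


class _Marker:
    """Hypothetical unique placeholder; unregistered classes are leaves to JAX."""

    def __init__(self, index):
        self.index = index


def leaves_survive_unflatten(tree) -> bool:
    """Return True if every leaf slot survives a flatten/unflatten round trip.

    Each leaf is replaced by a distinct marker object. If a node's unflatten
    (for example its __init__) overwrites a slot, that marker is missing
    from the rebuilt tree.
    """
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    markers = [_Marker(i) for i in range(len(leaves))]
    rebuilt = jax.tree_util.tree_unflatten(treedef, markers)
    new_leaves = jax.tree_util.tree_leaves(rebuilt)
    return len(new_leaves) == len(markers) and all(
        new is old for new, old in zip(new_leaves, markers)
    )


print(leaves_survive_unflatten({"x": 1, "y": [2, 3]}))  # True
```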
Conceptually, it makes more sense to make "a" static if it won't change, or to make it an optional flag in the constructor that defaults to 42, but that might just be me.
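One way to read the "make "a" static" suggestion, sketched with plain JAX: keep the value in the pytree's aux data rather than among its children, so tree maps never see it. CustomPytreeStaticA is a hypothetical name, and moving "a" out of the dict is an assumption about the intended design:

```python
import jax
import jax.numpy as jnp


class CustomPytreeStaticA:
    # Hypothetical variant: "a" is stored as static aux data, not in the dict.
    def __init__(self, dict, attr, static, a=42):
        self.dict = dict
        self.attr = attr
        self.static = static
        self.a = a  # never passed through tree_map


jax.tree_util.register_pytree_node(
    CustomPytreeStaticA,
    # children are mapped over; aux data (a hashable tuple) is passed through
    lambda self: ((self.dict, self.attr), (self.static, self.a)),
    lambda aux, children: CustomPytreeStaticA(*children, aux[0], a=aux[1]),
)

tree = CustomPytreeStaticA(dict={"b": 2}, attr=1, static=1)
casted = jax.tree_util.tree_map(jnp.asarray, tree)
print(casted.a)  # 42, untouched by design
```

The trade-off is that "a" is then no longer a dynamic leaf: it is not traced, differentiated, or updated by tree maps.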
As for the validity of the pytree, this is an interesting question. A tree_map over such a custom pytree won't fail; e.g. jax.tree_util.tree_map(lambda x: 3 if isinstance(x, int) else x, mytree) runs without error.
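When you mention an optional flag, do you mean something like the following?
import jax
import jax.numpy as jnp
import equinox as eqx

class CustomPytree:
    def __init__(self, dict, attr, static, a_exists=False):
        self.dict = dict
        # Preprocess the attribute: set key "a" to 42, but only when the
        # object is first built by the user. Unflattening passes
        # a_exists=True, so the (possibly transformed) "a" is kept.
        if not a_exists:
            self.dict["a"] = 42
        self.attr = attr
        self.static = static

jax.tree_util.register_pytree_node(
    CustomPytree,
    # flatten: (children, aux_data)
    lambda self: (
        (self.dict, self.attr),
        {"static": self.static},
    ),
    # unflatten: rebuild without re-running the preprocessing
    lambda aux, children: CustomPytree(*children, aux["static"], a_exists=True),
)

mytree = CustomPytree(dict={"b": 2}, attr=1, static=1)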
With this, eqx.tree_at(lambda t: t.attr, mytree, 3) and jax.tree_map(lambda x: jnp.array(x), mytree) now seem to work.
But does this seem hacky to you?
Thank you for your answer!
I took a closer look at the eqx.tree_at code. If I understand correctly, 42 is supposed to be a _LeafWrapper but ends up as a static int, because my assignment dict["a"] = 42 overwrites the _LeafWrapper object when __init__ runs during unflattening, right? Interestingly, this error is reminiscent of the warning in the JAX documentation about custom pytree initialization, except that Equinox adds its own _LeafWrapper class on top of it. For example, casting with self.dict["a"] = jax.numpy.array(self.dict["a"]) actually fails, because jnp.array cannot handle the _LeafWrapper type.
In the end, one just needs to be really thoughtful about any kind of processing of attributes in __init__ of custom pytree nodes (including eqx.Module).
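As an illustration of that point, here is a sketch of one option the JAX pytree documentation describes for this initialization gotcha: have the unflatten function bypass __init__, so that preprocessing such as casting only runs when user code constructs the object. The class is a hypothetical variant of the thread's CustomPytree, and the placeholder objects only loosely imitate the wrappers eqx.tree_at inserts:

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class CustomPytree:
    # Hypothetical variant: preprocessing lives in __init__, which only runs
    # when user code builds the object, never during unflattening.
    def __init__(self, dict, attr, static):
        self.dict = dict
        self.dict["a"] = jnp.asarray(42)  # casting is safe here
        self.attr = attr
        self.static = static

    def tree_flatten(self):
        # (children, aux_data): children are mapped over, aux_data is kept as-is
        return (self.dict, self.attr), self.static

    @classmethod
    def tree_unflatten(cls, static, children):
        # Bypass __init__: children may be arbitrary objects (tracers or
        # wrapper objects), which the preprocessing could not handle.
        obj = object.__new__(cls)
        obj.dict, obj.attr = children
        obj.static = static
        return obj


mytree = CustomPytree(dict={"b": 2}, attr=1, static=1)

# Rebuild with arbitrary placeholder objects instead of the original leaves.
leaves, treedef = jax.tree_util.tree_flatten(mytree)
placeholders = [object() for _ in leaves]
rebuilt = jax.tree_util.tree_unflatten(treedef, placeholders)
print(rebuilt.dict["a"] is placeholders[0])  # True: "a" was not overwritten
```

The cost of this design is that any invariant established in __init__ is not re-checked on rebuilt objects.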
I agree that making "a" static, or an optional constructor flag, is conceptually cleaner, but in the more complex code underlying this MWE I'd like to keep dict["a"] dynamic. Actually, the real code already had this optional-flag behavior: it only sets "a" to 42 depending on a static attribute, something like:
if self.static == 1:
    self.dict["a"] = 42
I did not include it in my original MWE, since eqx.tree_at fails in the same way if the first mytree is built with static=1. Now I realize this makes perfect sense and was an error on my part: this condition is always met in every subsequent call to __init__ during unflattening. I'll rewrite the check into something like:
if static == 1 and "a" not in self.dict:
    self.dict["a"] = 42
Thank you again! I'm closing this issue for now, as I think I have the answers I need.
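A runnable sketch of the rewritten check, assuming the rest of the class matches the earlier CustomPytree (here the aux data is just the static value, a simplification). Since unflattening hands __init__ a dict that already contains "a", the default is only applied on first construction, and a tree map can now cast "a" like any other leaf:

```python
import jax
import jax.numpy as jnp


class CustomPytree:
    def __init__(self, dict, attr, static):
        self.dict = dict
        # Only fill in the default on first construction: during unflattening
        # the dict already carries a (possibly transformed) "a".
        if static == 1 and "a" not in self.dict:
            self.dict["a"] = 42
        self.attr = attr
        self.static = static


jax.tree_util.register_pytree_node(
    CustomPytree,
    lambda self: ((self.dict, self.attr), self.static),  # (children, aux_data)
    lambda static, children: CustomPytree(*children, static),
)

mytree = CustomPytree(dict={"b": 2}, attr=1, static=1)
casted = jax.tree_util.tree_map(jnp.asarray, mytree)
print(type(casted.dict["a"]))  # now a JAX array rather than an int
```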