
`eqx.tree_at` fails with `TypeError: Operation undefined, {x} is not a leaf of the pytree` when custom pytree attributes are modified within `__init__` (equinox, closed)

nicolasJouvin commented on June 9, 2024

from equinox.

Comments (5)

lockwo commented on June 9, 2024

That's true: tree operations won't necessarily fail outright, but they won't act as one might expect. For example, `jax.tree_map(lambda x: jnp.array(x), mytree)` does not produce a pytree of arrays (`42` is still an `int`).
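The original MWE is not reproduced on this page, so the sketch below is a hypothetical reconstruction based on the discussion: it assumes `__init__` unconditionally sets `"a"` to 42, and borrows the `dict`/`attr`/`static` names from a later comment in this thread. It uses `jax.tree_util.tree_map`, the same function as the `jax.tree_map` used in the thread. The cast reaches every leaf, but unflattening calls `__init__` again, which resets `"a"`.

```python
import jax
import jax.numpy as jnp


class CustomPytree:
    # Hypothetical reconstruction of the MWE: __init__ always overwrites "a".
    def __init__(self, dict, attr, static):
        self.dict = dict
        self.dict["a"] = 42  # also runs on every unflatten
        self.attr = attr
        self.static = static


jax.tree_util.register_pytree_node(
    CustomPytree,
    lambda self: ((self.dict, self.attr), self.static),  # (children, aux_data)
    lambda static, children: CustomPytree(*children, static),
)

mytree = CustomPytree(dict={"b": 2}, attr=1, static=1)
out = jax.tree_util.tree_map(jnp.asarray, mytree)

print(type(out.attr))       # a JAX array: the cast was applied
print(type(out.dict["b"]))  # a JAX array as well
print(type(out.dict["a"]))  # <class 'int'>: __init__ reset it to 42
```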

lockwo commented on June 9, 2024

This is failing because `pytree = jtu.tree_map(_LeafWrapper, pytree, is_leaf=is_leaf)` in `eqx.tree_at` can't turn the `42` in `{"a": 42}` into a `_LeafWrapper`. More generally, a tree map never modifies `"a"`: if you tree-map to cast everything to an array, for example, `a` will always be an `int`, because unflattening re-runs `__init__`, which sets it back to 42. To be honest, I'm not sure this is actually a valid pytree, since it seems to break some of the assumptions of `jax.tree_util`; but if it is, it seems hard to build around. You can't apply any tree operations to it, so maybe you could compare the tree before and after a tree map to detect whether something has changed.
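The before/after comparison suggested above can be sketched as a round-trip check: replace every leaf with a fresh marker object, rebuild the tree, and report any leaf that is not a marker afterwards. Such a leaf was recreated by the unflatten rule, which is the situation in which `eqx.tree_at` loses its `_LeafWrapper`. The helper `leaves_not_preserved` and the `_Marker` class are hypothetical names, and `CustomPytree` is again an assumed reconstruction of the MWE.

```python
import jax.tree_util as jtu


class _Marker:
    """Opaque object; JAX treats unregistered objects as leaves."""


def leaves_not_preserved(tree):
    # Swap every leaf for a marker; tree_map rebuilds nodes via their unflatten rule.
    marked = jtu.tree_map(lambda _: _Marker(), tree)
    # Any leaf that is not a marker was (re)created during unflattening.
    return [leaf for leaf in jtu.tree_leaves(marked) if not isinstance(leaf, _Marker)]


class CustomPytree:
    # Assumed MWE: __init__ always overwrites "a", including during unflattening.
    def __init__(self, dict, attr, static):
        self.dict = dict
        self.dict["a"] = 42
        self.attr = attr
        self.static = static


jtu.register_pytree_node(
    CustomPytree,
    lambda self: ((self.dict, self.attr), self.static),
    lambda static, children: CustomPytree(*children, static),
)

print(leaves_not_preserved(CustomPytree({"b": 2}, 1, static=1)))  # [42]
print(leaves_not_preserved({"a": 42, "b": 2}))                   # []
```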

Conceptually, it makes more sense to have `"a"` be static if it won't change, or to make `a` an optional flag in the constructor that defaults to 42, but that might just be me.
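A minimal sketch of the first suggestion, keeping `a` static: the value travels in the pytree's auxiliary data rather than among its children, so tree operations never see it as a leaf and there is no leaf for unflattening to overwrite. The class name `StaticAPytree`, and storing `a` as its own attribute rather than inside `dict`, are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


class StaticAPytree:
    def __init__(self, dict, attr, static, a=42):
        self.dict = dict      # dynamic children
        self.attr = attr
        self.static = static  # static, stored in aux_data
        self.a = a            # static, stored in aux_data; defaults to 42


jax.tree_util.register_pytree_node(
    StaticAPytree,
    # aux_data is a hashable tuple; `a` never appears among the leaves.
    lambda self: ((self.dict, self.attr), (self.static, self.a)),
    lambda aux, children: StaticAPytree(*children, *aux),
)

tree = StaticAPytree({"b": 2}, attr=1, static=1)
print(jax.tree_util.tree_leaves(tree))  # [2, 1]
out = jax.tree_util.tree_map(jnp.asarray, tree)
print(out.a)  # 42: still a plain int, as intended for a static value
```

The cost of this design is that `a` can no longer be traced or differentiated, which is why it only fits values that never change.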

HGangloff commented on June 9, 2024

As for the validity of the pytree, this is an interesting question. A `tree_map` over such a custom pytree won't fail, e.g.

```python
jax.tree_util.tree_map(lambda x: 3 if isinstance(x, int) else x, mytree)
```

HGangloff commented on June 9, 2024

When you mention an optional flag, do you mean something like:

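```python
import jax


class CustomPytree:
    def __init__(self, dict, attr, static, a_exists=False):
        self.dict = dict
        # Preprocess the attribute: set key "a" to 42, but only on user
        # construction; unflattening passes a_exists=True and keeps "a".
        if not a_exists:
            self.dict["a"] = 42
        self.attr = attr
        self.static = static


jax.tree_util.register_pytree_node(
    CustomPytree,
    # flatten -> (children, aux_data); aux_data should be hashable, so
    # `static` is stored directly rather than inside a dict.
    lambda self: ((self.dict, self.attr), self.static),
    # unflatten: rebuild without re-running the "a" preprocessing.
    lambda static, children: CustomPytree(*children, static, a_exists=True),
)

mytree = CustomPytree(dict={"b": 2}, attr=1, static=1)
```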

With this, `eqx.tree_at(lambda t: t.attr, mytree, 3)` and `jax.tree_map(lambda x: jnp.array(x), mytree)` now seem to work.
But does this seem hacky to you?

nicolasJouvin commented on June 9, 2024

Thank you for your answer!

> This is failing because `pytree = jtu.tree_map(_LeafWrapper, pytree, is_leaf=is_leaf)` in `eqx.tree_at` can't make `{"a": 42}` a `_LeafWrapper`. [...]

I took a closer look at the `eqx.tree_at` code. If I understand correctly, `42` is supposed to be a `_LeafWrapper` but is actually a fixed `int`, because my assignment `dict["a"] = 42` erases the `_LeafWrapper` object, right? Interestingly, this error is reminiscent of the warning in JAX's documentation about custom pytree initialization, except that Equinox adds its custom `_LeafWrapper` class on top of it. For instance, casting with `self.dict["a"] = jax.numpy.array(self.dict["a"])` actually fails, because `jnp.array` cannot handle the `_LeafWrapper` type.

In the end, one just needs to be really careful with any kind of attribute processing in the `__init__` of custom pytree nodes (including `eqx.Module`).
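The JAX documentation on custom pytrees warns that unflattening may call a class's constructor with values other than the ones you expect (arbitrary sentinel objects, tracers, or, in this thread, Equinox's wrapper objects). A minimal sketch of one workaround, assuming the same `dict`/`attr`/`static` layout and a made-up class name `CastingPytree`: keep the preprocessing in `__init__`, but have `tree_unflatten` build the object with `object.__new__`, so `__init__` only runs when user code calls it.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class CastingPytree:
    def __init__(self, dict, attr, static):
        # User-facing preprocessing: only runs when *you* build the object.
        self.dict = {k: jnp.asarray(v) for k, v in dict.items()}
        self.dict.setdefault("a", jnp.asarray(42))
        self.attr = attr
        self.static = static

    def tree_flatten(self):
        return (self.dict, self.attr), self.static

    @classmethod
    def tree_unflatten(cls, static, children):
        # Bypass __init__ so whatever leaf objects JAX (or a library built on
        # it) passes in are stored unchanged.
        obj = object.__new__(cls)
        obj.dict, obj.attr = children
        obj.static = static
        return obj


tree = CastingPytree({"b": 2}, attr=1, static=1)
# Calling jnp.asarray on a bare object() inside __init__ would raise; with the
# __new__-based unflatten the sentinels round-trip untouched.
marked = jax.tree_util.tree_map(lambda _: object(), tree)
print(all(type(leaf) is object for leaf in jax.tree_util.tree_leaves(marked)))  # True
```

The trade-off is that `tree_unflatten` must set every attribute itself, so it has to be kept in sync with `__init__` by hand.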

> Conceptually, it makes more sense to have "a" be static if it won't change, or have a be an optional flag in the constructor which defaults to 42, but that might just be me

I agree with you, but in the more complex code underlying this MWE I'd like to keep `dict["a"]` dynamic. Actually, the real code already had this optional-flag behavior, only setting `"a"` to 42 depending on a static attribute. Something like:

```python
if self.static == 1:
    self.dict["a"] = 42
```

I did not include it in my original MWE since `eqx.tree_at` fails in the same way if the first `mytree` is built with `static=1`. Now I realize this makes perfect sense and was an error on my part, since this condition is always met in every subsequent call to `__init__` during unflattening. I shall rewrite the test into something like:

```python
if static == 1 and "a" not in self.dict:
    self.dict["a"] = 42
```
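Put together, the guarded `__init__` described above might look like the following sketch, which assumes the same `dict`/`attr`/`static` layout as the earlier `CustomPytree` snippet in this thread. Because the default is only written when `"a"` is absent, whatever value unflattening passes back in for `"a"` is kept, so a tree map now reaches it.

```python
import jax
import jax.numpy as jnp


class CustomPytree:
    def __init__(self, dict, attr, static):
        self.dict = dict
        # Only fill in the default when "a" is missing, so an existing "a"
        # (e.g. a transformed leaf passed back in during unflattening) survives.
        if static == 1 and "a" not in self.dict:
            self.dict["a"] = 42
        self.attr = attr
        self.static = static


jax.tree_util.register_pytree_node(
    CustomPytree,
    lambda self: ((self.dict, self.attr), self.static),
    lambda static, children: CustomPytree(*children, static),
)

tree = CustomPytree({"b": 2}, attr=1, static=1)
out = jax.tree_util.tree_map(jnp.asarray, tree)
print(isinstance(out.dict["a"], jax.Array))                  # True: the cast now reaches "a"
print("a" in CustomPytree({"b": 2}, attr=1, static=0).dict)  # False
```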

Thank you again! I'm closing this issue for now, as I think I have the answers I need.
