rosshemsley commented on August 15, 2024

(Just writing to reassure you that we have not forgotten about this, it's just a busy week over here :) )

pharringtonp19 commented on August 15, 2024

@NeilGirdhar I just thought that, as experienced users of optax, they might offer better insight than an end user like myself on the pros and cons of such a proposal.

rosshemsley commented on August 15, 2024

@mtthss is now back, and we are taking a look at this in more detail :)

NeilGirdhar commented on August 15, 2024

@cgarciae Just so you know, I created a shim in tjax.gradient that has all 35 of the optax transforms and solves both of your problems:

  • You can access all of the parameters using attribute access, and modify them using dataclasses.replace, and
  • you can pass them dynamically (since they're pytrees), and only the parameters that have to be static are static (preventing recompilation).
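A minimal sketch of the two bullets above, assuming a hand-written `SGD` class (hypothetical, not tjax's actual code): a frozen dataclass registered as a pytree, so `learning_rate` is a dynamic leaf. Attribute access and `dataclasses.replace` work as usual, and a jitted step that receives the transformation as an argument is traced only once for several learning rates.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class SGD:
    """Hypothetical gradient transformation whose hyperparameter is a pytree leaf."""
    learning_rate: float

    def init(self, params):
        return ()  # plain SGD keeps no state

    def update(self, grads, state, params=None):
        updates = jax.tree_util.tree_map(lambda g: -self.learning_rate * g, grads)
        return updates, state

    # learning_rate is a child (dynamic leaf); there is no static auxiliary data.
    def tree_flatten(self):
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


trace_count = 0  # incremented only while JAX traces (i.e. compiles) `step`


@jax.jit
def step(gt, params):
    global trace_count
    trace_count += 1
    grads = jax.grad(lambda p: jnp.square(p - 1.0))(params)
    updates, _ = gt.update(grads, gt.init(params), params)
    return params + updates


gt = SGD(learning_rate=0.05)
print(gt.learning_rate)  # plain attribute access
faster = dataclasses.replace(gt, learning_rate=0.5)  # a modified copy
for opt in (gt, faster, SGD(0.1)):
    print(step(opt, 0.0))
print(trace_count)
```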

rosshemsley commented on August 15, 2024

Hey @NeilGirdhar,

We have discussed these proposals extensively within the optax maintainers team, and gathered feedback from a number of users and developers. We believe that the approach outlined below is the best way forward.

The request here is for JAX to treat instances of optax.GradientTransformation as pytrees. In @NeilGirdhar’s proposal, this means combining optimizer hyperparameters with the init / update functions into a single dataclass. In @cgarciae’s proposal the state is additionally also included in the dataclass.

Making gradient transformations into pytree classes provides the ability to pass these objects to JAX jitted functions as dynamic rather than static arguments. On balance, the idea is reasonable and might be easier to understand for users who like OO patterns and the specific dataclass-based strand of OO programming that is being used in some JAX codebases. There are also downsides to this approach, though: notably, the current optax API has a very sharp boundary between “functions implementing user behavior” (the init / update methods) and “data containers” (i.e. the state). In the above proposals, hyperparameters and/or state are “folded into” the optimizer class, and the pure functions become methods on a mutable class instead. This is not a problem per se (it’s a common pattern in OO!), but it is different from the programming patterns of many of our users.

So what is the way forward?
As Optax developers, we do care a lot about users being able to extend and modify the optax components. The whole design focuses on composability and extensibility. We believe the best way forwards here is thus not to use dataclasses in each individual component, but rather to create wrappers that turn optax optimisers into dataclasses with the desired behavior.

This is the approach taken by @n2cholas in his inject_hyperparams (already upstreamed into optax) and the approach taken by @cgarciae (in Treex). Even when living in libraries outside of optax, these wrappers are lighter to maintain than the current approach in tjax: for instance, in @cgarciae’s Treex only one wrapper is maintained, and all optax transformations are exposed after wrapping them programmatically (instead of defining a new, separate class for each optax component as in tjax, where new code has to be written every time components are added to optax).

Our suggestion is thus to proceed in two stages:
  • In an initial stage, tjax could define its own wrapper (similarly to Treex).
  • Once a common dataclass implementation is provided by JAX itself, we would be happy to see extensions of this kind upstreamed into optax (as we were to see inject_hyperparams).

For example, we could imagine exposing a triplet of extensions:
  • inject_hyperparams (@n2cholas’s wrapper)
  • wrap_transformation_as_pytree (implementing @NeilGirdhar’s extension)
  • wrap_transformation_as_pytree_with_state (implementing @cgarciae’s extension)
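To make the wrapper idea concrete, here is a sketch of what something like wrap_transformation_as_pytree might look like. Everything in it is an assumption for illustration: `sgd` is a toy closure-style factory mimicking optax's (init, update) pair, and `PytreeTransformation` is a hypothetical wrapper that keeps the factory as static pytree metadata and the hyperparameters as dynamic leaves, rebuilding the closures on each call.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


class GradientTransformation(NamedTuple):
    """Closure-based (init, update) pair, in the style of optax."""
    init: Callable
    update: Callable


def sgd(learning_rate):
    """Toy closure-style factory: the hyperparameter is captured, not stored as data."""
    def init_fn(params):
        return ()

    def update_fn(grads, state, params=None):
        updates = jax.tree_util.tree_map(lambda g: -learning_rate * g, grads)
        return updates, state

    return GradientTransformation(init_fn, update_fn)


@register_pytree_node_class
class PytreeTransformation:
    """Hypothetical wrapper: the factory is static metadata, hyperparameters are leaves."""

    def __init__(self, factory, hyperparams):
        self.factory = factory
        self.hyperparams = dict(hyperparams)

    def _build(self):
        # Rebuild the closure-based transformation from the (possibly traced) leaves.
        return self.factory(**self.hyperparams)

    def init(self, params):
        return self._build().init(params)

    def update(self, grads, state, params=None):
        return self._build().update(grads, state, params)

    def tree_flatten(self):
        names = tuple(sorted(self.hyperparams))
        children = tuple(self.hyperparams[name] for name in names)
        return children, (self.factory, names)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        factory, names = aux_data
        return cls(factory, zip(names, children))


def wrap_transformation_as_pytree(factory, **hyperparams):
    """Hypothetical helper name taken from the proposal above."""
    return PytreeTransformation(factory, hyperparams)


trace_count = 0


@jax.jit
def step(gt, params):
    global trace_count
    trace_count += 1  # runs only when JAX traces a new specialization
    grads = jax.grad(lambda p: jnp.square(p - 1.0))(params)
    updates, _ = gt.update(grads, gt.init(params), params)
    return params + updates


results = [step(wrap_transformation_as_pytree(sgd, learning_rate=lr), 0.0)
           for lr in (0.05, 0.1, 0.2)]
print(results, trace_count)
```

Because only the factory and the hyperparameter names live in the static metadata, one wrapper class covers every factory, which is the maintenance advantage described above.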

We hope the above is acceptable, and while this might not be everybody’s favorite solution, we want to stress that we really appreciate everybody’s feedback and contributions, and we hope to continue seeing you all engaged in the development of optax and contributing to the library.

rosshemsley commented on August 15, 2024

Hey @NeilGirdhar, thanks for your message. dataclasses certainly do have nice properties for these kinds of cases, and by the way, chex.dataclass is a JAX/pytree-friendly implementation of this, which may work as a drop-in replacement for the dataclass implemented in your example.

With respect to recompilation, I didn't fully follow where recompilation would occur in your example. Would it be possible to write up a concrete example, e.g. in a colab (https://research.google.com/colaboratory/)?

Thanks!

NeilGirdhar commented on August 15, 2024

@rosshemsley
My mistake, it doesn't just cause recompilation. The gradient transform objects simply can't be passed to jitted functions. They would be able to be passed if they were re-implemented as dataclasses:

from typing import Any

import jax.numpy as jnp
from jax import grad, jit
from jax.lax import scan
from optax import adam
from tjax import dataclass
from tjax.gradient import adam as adam2


@dataclass
class State:
    parameters: float
    gt_state: Any

gt1 = adam(0.05)

def loss(parameters):
    return jnp.square(parameters - 1.0)

def inference(gt: Any):
    def do_one(state: State, _: None):
        gradient = grad(loss)(state.parameters)
        new_gradient, gt_state = gt.update(gradient, state.gt_state, state.parameters)
        parameters = state.parameters + new_gradient
        return State(parameters, gt_state), parameters

    state = State(0.0, gt.init(0.0))
    final_state, trajectory = scan(do_one, state, None, 100)
    return trajectory  #final_state.parameters

print(inference(gt1))

# print(jit(inference)(gt1))  # jax._src.traceback_util.UnfilteredStackTrace: TypeError: Argument '<function chain.<locals>.init_fn at 0x7fa5cc682790>' of type <class 'function'> is not a valid JAX type.

gt2 = adam2(0.05)
print(jit(inference)(gt2))  # Works!

> and by the way, chex.dataclass is a jax / pytree

Thanks, I'm happy to use chex.dataclass. However, I should point out that chex.dataclass doesn't support static fields (google-deepmind/chex#104). The tjax.dataclass also has a plugin for MyPy.

rosshemsley commented on August 15, 2024

Ok, I think I understand - it looks like the difference here is maybe not because it's a dataclass, but because the init function is a method rather than an attribute?

The reason for the jit error in your example is that JAX doesn't know how to jit a parameter that is a function type: jit is designed to work on functions whose arguments and outputs are pytrees of arrays.
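A self-contained reproduction of this behaviour, assuming a toy closure-based `sgd` factory in place of `optax.adam` (the names are illustrative only): passing the (init, update) NamedTuple into jit raises a `TypeError` because its pytree leaves are Python functions, while closing over it works.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class GradientTransformation(NamedTuple):
    init: Callable
    update: Callable


def sgd(learning_rate):
    """Toy closure-style optimizer: the learning rate is captured by the closures."""
    def init_fn(params):
        return ()

    def update_fn(grads, state, params=None):
        return -learning_rate * grads, state

    return GradientTransformation(init_fn, update_fn)


def step(gt, params):
    grads = jax.grad(lambda p: jnp.square(p - 1.0))(params)
    updates, _ = gt.update(grads, gt.init(params), params)
    return params + updates


gt = sgd(0.05)

# A NamedTuple is a pytree, so jit flattens it and finds plain Python
# functions as leaves, which are not valid JAX argument types.
try:
    jax.jit(step)(gt, 0.0)
    failed = False
except TypeError as err:
    failed = True
    print("TypeError:", err)

# Closing over the transformation instead: jit only sees `params`.
closed_step = jax.jit(lambda params: step(gt, params))
print(closed_step(0.0))
```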

The typical pattern for solving this is to use a closure, like this:

gt = optax.adam(1e-4)

@jax.jit
def inference():
    def do_one(state: State, _: None):
        gradient = grad(loss)(state.parameters)
        new_gradient, gt_state = gt.update(gradient, state.gt_state, state.parameters)
        parameters = state.parameters + new_gradient
        return State(parameters, gt_state), parameters

    state = State(0.0, gt.init(0.0))
    final_state, trajectory = scan(do_one, state, None, 100)
    return trajectory  #final_state.parameters

print(inference()) # works, and is jitted

The optax approach is designed to mirror the functional style of JAX - for instance, it means one can do this (which is another way this is often used):

opt = optax.adam(1e-4)
opt_state = jax.jit(opt.init)(...)

(I think this would fail for the class-based approach, due to the self parameter)

This style is perhaps not that common in the Python world (which tends to be quite OO-heavy), but it's considered quite idiomatic for JAX :)

I think I'll mark this issue as 'wontfix' for now - but do let us know if you think we're missing something else!

NeilGirdhar commented on August 15, 2024

@rosshemsley

Sorry, but I think you misunderstood.

The point is that you should be able to parametrize any algorithm on the gradient transformation. For example, you should be able to pass adam(...) or fromage(...), etc. to a single jitted function and have that function work. Your example basically shows that you are currently forced to fix the gradient transformation—it cannot be passed in as a parameter.

> it looks like the difference here is maybe not because it's a dataclass, but because the init function is a method rather than an attribute?

The init function is not a method at all. It's a function-valued attribute. (Worse even, it's a closure that closes over its parameters.) That's the problem.

> The reason for the jit error with your example is because JAX doesn't know how to jit a parameter that is a function type - jit is designed to work on functions that map to pytrees of arrays.

Right.

My proposal is to convert the optax closures into abstract methods on a pytree-like class. That way, jit will treat the class type as a static field and the class parameters as dynamic fields. This means that it recompiles when the class type changes, and doesn't recompile when the class parameters change. This is superior to the current behavior of crashing.
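A small sketch of this caching behaviour, under stated assumptions: `pytree_dataclass` is a hypothetical helper that turns every dataclass field into a dynamic leaf, and `SGD` and `SignSGD` are toy transformations. Because the class is part of the pytree structure, jit traces once per transformation type and reuses that trace when only the learning rate changes.

```python
import dataclasses

import jax
import jax.numpy as jnp


def pytree_dataclass(cls):
    """Hypothetical helper: a frozen dataclass whose fields are all dynamic leaves."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    names = [f.name for f in dataclasses.fields(cls)]
    jax.tree_util.register_pytree_node(
        cls,
        lambda obj: ([getattr(obj, n) for n in names], None),  # (children, aux_data)
        lambda aux_data, children: cls(*children),
    )
    return cls


@pytree_dataclass
class SGD:
    learning_rate: float

    def update(self, grads):
        return -self.learning_rate * grads


@pytree_dataclass
class SignSGD:
    learning_rate: float

    def update(self, grads):
        return -self.learning_rate * jnp.sign(grads)


traces = []  # records the transformation type each time `step` is traced


@jax.jit
def step(gt, params):
    traces.append(type(gt).__name__)
    grads = 2.0 * (params - 1.0)  # gradient of (params - 1)**2
    return params + gt.update(grads)


for gt in (SGD(0.1), SGD(0.2), SignSGD(0.1), SignSGD(0.3), SGD(0.4)):
    step(gt, 0.0)
print(traces)
```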

> (I think this would fail for the class-based approach, due to the self parameter)

You're mistaken. The self parameter works like any other parameter. jit doesn't treat it any differently. However, this is a non-issue. Usually, users shouldn't be jitting methods like this, but instead they should jit the entire training algorithm.

If you want to jit the optax methods (which isn't necessary), you should jit the unbound method:

class GradientTransformState:
  @jit
  def init(self, ...):
    ...

If the caller has correctly jitted the enclosing training function, this nested jit call is ignored by Jax.
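A sketch of both patterns side by side, assuming a hypothetical `Momentum` class written as a frozen dataclass and registered as a pytree (not code from optax or tjax). `jax.jit(opt.init)` on a bound method works because `opt` is simply captured, and jitting the unbound `Momentum.update` makes `self` an ordinary pytree argument whose hyperparameters are dynamic.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Momentum:
    """Hypothetical pytree transformation: both hyperparameters are dynamic leaves."""
    learning_rate: float
    decay: float

    def init(self, params):
        return jax.tree_util.tree_map(jnp.zeros_like, params)  # velocity

    def update(self, grads, state, params=None):
        velocity = jax.tree_util.tree_map(lambda v, g: self.decay * v + g, state, grads)
        updates = jax.tree_util.tree_map(lambda v: -self.learning_rate * v, velocity)
        return updates, velocity

    def tree_flatten(self):
        return (self.learning_rate, self.decay), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


opt = Momentum(learning_rate=0.1, decay=0.9)
params = {"w": jnp.ones(3)}
grads = {"w": jnp.full(3, 2.0)}

# Functional pattern: jit a bound method; `opt` is captured by the method.
state = jax.jit(opt.init)(params)

# Unbound-method pattern: `self` is passed like any other pytree argument.
update = jax.jit(Momentum.update)
updates, state = update(opt, grads, state)
print(updates["w"])
```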

> I think I'll mark this issue as 'wontfix' for now - but do let us know if you think we're missing something else!

I think that's unfortunate because this is a straightforward fix that has no downsides?

Ultimately, it means I may have to lift all of your code and re-implement it, and I would rather not have to maintain this going forward 😄

NeilGirdhar commented on August 15, 2024

I also proposed this about a year ago here #1 and had a great discussion with @mtthss. My motivation then was being able to take gradients with respect to the parameters of the gradient transformation. You cannot do this the way the transformations are currently implemented because these values are closed over. Jax cannot take gradients with respect to closed over values because that will leak tracers.
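As a sketch of what this enables, assuming a hypothetical `SGD` class whose learning rate is a pytree leaf: `jax.grad` can differentiate a one-step meta-objective with respect to the transformation itself. For the loss (p - 1)² starting at p = 0, one step gives p = 2·lr, so the meta-loss is (2·lr - 1)² and its derivative at lr = 0.1 is 4·(2·0.1 - 1) = -3.2.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class SGD:
    """Hypothetical transformation; learning_rate is a differentiable leaf."""
    learning_rate: float

    def update(self, grads):
        return -self.learning_rate * grads

    def tree_flatten(self):
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def loss(params):
    return jnp.square(params - 1.0)


def loss_after_one_step(gt, params):
    """Meta-objective: the loss after applying one update produced by `gt`."""
    new_params = params + gt.update(jax.grad(loss)(params))
    return loss(new_params)


# The gradient has the same pytree structure as the input: an SGD instance
# whose learning_rate field holds d(meta-loss)/d(learning_rate).
meta_grad = jax.grad(loss_after_one_step)(SGD(learning_rate=0.1), 0.0)
print(meta_grad.learning_rate)
```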

rosshemsley commented on August 15, 2024

@NeilGirdhar Thanks for linking this thread, that's very helpful. @mtthss is out at the moment, but will return!

I will take a look through this thread and try to understand a bit better 👍

One consideration is that optax already has a fairly large base of users, so API changes that could break existing user code need to be managed carefully - this sets the bar high for making changes to the core parts of the library (even in cases where there may be clear benefits :) ).

NeilGirdhar commented on August 15, 2024

@rosshemsley Thank you for taking a second look 😄 I believe what I'm proposing has no associated API changes.

pharringtonp19 commented on August 15, 2024

@NeilGirdhar This seems like a really nice/intuitive idea.

At a high level, your proposal would allow us to "parameterize" an update function in such a way that we could then vmap the update function across values of these parameters. Is that correct?

NeilGirdhar commented on August 15, 2024

@pharringtonp19 Yes!

Good thinking, I hadn't thought of that benefit.
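For illustration, a sketch assuming a hypothetical pytree `SGD` class: batching the `learning_rate` leaf and applying `jax.vmap` runs one training loop per learning rate in a single call. On the loss (p - 1)² starting from p = 0, each plain SGD step maps p - 1 to (1 - 2·lr)(p - 1), so after 20 steps p = 1 - (1 - 2·lr)^20, which gives a closed form to check the vectorized result against.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class SGD:
    """Hypothetical transformation whose learning rate is a pytree leaf."""
    learning_rate: jax.Array

    def update(self, grads):
        return -self.learning_rate * grads

    def tree_flatten(self):
        return (self.learning_rate,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def train(gt, num_steps=20):
    def body(params, _):
        grads = 2.0 * (params - 1.0)  # gradient of (params - 1)**2
        return params + gt.update(grads), None

    final_params, _ = jax.lax.scan(body, jnp.zeros(()), None, length=num_steps)
    return final_params


learning_rates = jnp.array([0.05, 0.1, 0.2])
# One vectorized call trains three models, one per learning rate.
final_params = jax.vmap(train)(SGD(learning_rate=learning_rates))
print(final_params)
```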

pharringtonp19 commented on August 15, 2024

@NeilGirdhar Based on this discussion in jaxopt, I wonder if the authors of that library might have any thoughts on this proposal.

Additionally, I think this proposal would allow us to implement some of the features that higher provides to PyTorch.

NeilGirdhar commented on August 15, 2024

> wonder if the authors of that library might have any thoughts on this proposal.

Are you suggesting implementing my proposal somewhere else?

In case you didn't know, I ended up implementing all of the optax transforms as dataclasses here: https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/gradient/transforms.py As you can see, this isn't any more complex than what optax currently does.

Of course, my dream is not to have to maintain this, but for optax to simply replace its classes (that use closures) with dataclasses.

> Additionally, I think this proposal would allow us to implement some of the features that higher provides to pytorch

Yes, exactly.

NeilGirdhar commented on August 15, 2024

Good idea

mtthss commented on August 15, 2024

We are still discussing/investigating the compilation issue you raised, and we hope to be able to discuss this with you soon.

In the meantime, please take a look at our preferred approach for the meta-learning issue
#1 (comment)

The two issues could be seen as closely related, as until recently we wouldn't have been able to expose the optimiser definition to learning without switching to dataclasses, but we think this is now better solved by using @n2cholas's inject_hyperparams. Please let us know what you think about the meta-learning side in that thread. We will continue discussing the recompilation issue on this thread as soon as we have performed more investigations.
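A toy sketch of the idea behind inject_hyperparams, written from scratch and not optax's implementation (`sgd_with_state_hyperparams` and `HyperparamState` are hypothetical names): the learning rate lives in the optimizer state rather than in a closure, so it can be changed between steps without rebuilding the optimizer and without retracing the jitted step.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class GradientTransformation(NamedTuple):
    init: Callable
    update: Callable


class HyperparamState(NamedTuple):
    hyperparams: dict  # dynamic hyperparameters, e.g. {"learning_rate": ...}
    inner_state: tuple


def sgd_with_state_hyperparams(learning_rate):
    """Toy sketch: the learning rate is read from the state, not from a closure."""
    def init_fn(params):
        lr = jnp.asarray(learning_rate, jnp.float32)
        return HyperparamState({"learning_rate": lr}, ())

    def update_fn(grads, state, params=None):
        lr = state.hyperparams["learning_rate"]
        updates = jax.tree_util.tree_map(lambda g: -lr * g, grads)
        return updates, state

    return GradientTransformation(init_fn, update_fn)


opt = sgd_with_state_hyperparams(0.05)  # closed over by `step`, never passed to jit
trace_count = 0


@jax.jit
def step(params, opt_state):
    global trace_count
    trace_count += 1
    grads = jax.grad(lambda p: jnp.square(p - 1.0))(params)
    updates, opt_state = opt.update(grads, opt_state, params)
    return params + updates, opt_state


params = jnp.zeros(())
opt_state = opt.init(params)
params, opt_state = step(params, opt_state)
# Change the learning rate by editing the state, not by building a new optimizer.
opt_state = opt_state._replace(hyperparams={"learning_rate": jnp.asarray(0.5, jnp.float32)})
params, opt_state = step(params, opt_state)
print(params, trace_count)
```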

mtthss commented on August 15, 2024

> The point is that you should be able to parametrize any algorithm on the gradient transformation. For example, you should be able to pass adam(...) or fromage(...), etc. to a single jitted function and have that function work. Your example basically shows that you are currently forced to fix the gradient transformation; it cannot be passed in as a parameter.

@NeilGirdhar I think you are correct that in some cases we do want to pass the optimiser as input. However you can pass the optimiser as input by doing:

@functools.partial(jax.jit, static_argnums=0)
def inference(gt: Any):
    def do_one(state: State, _: None):
        gradient = grad(loss)(state.parameters)
        new_gradient, gt_state = gt.update(gradient, state.gt_state, state.parameters)
        parameters = state.parameters + new_gradient
        return State(parameters, gt_state), parameters

    state = State(0.0, gt.init(0.0))
    final_state, trajectory = scan(do_one, state, None, 100)
    return trajectory  #final_state.parameters

Since in the optax design the dynamic state is in the opt_state and not in the opt itself this is safe to do.

mtthss commented on August 15, 2024

I think this solves the issue from this thread

It leaves open the question of whether we need the dataclass approach to expose opt hyperparams to learning. To consolidate the discussion, I will close this issue and leave open the sister issue (#1), where you suggested the use of dataclasses to expose metaparams to learning.

Let's continue discussing there.

NeilGirdhar commented on August 15, 2024

> I think this solves the issue from this thread

I'm happy to close this as a duplicate of #1 😄

However, I want to point out that passing the optimizer statically does not solve the problem in this issue. The issue is called "prevent unnecessary recompilation". If you pass the optimizer statically, you will find that the jit will recompile every single time inference is called. You can verify this by adding a print statement inside inference and passing different instantiations of adam.

The dataclass version (adam2) that I proposed does not suffer from this problem.

rosshemsley commented on August 15, 2024

Hey @NeilGirdhar, to help us ensure we are looking at the same thing, do you think you could write a short example of the recompilation issue in a google colab?

NeilGirdhar commented on August 15, 2024

@rosshemsley Sure, I'll just paste it here so that we have it:

from typing import Any

import jax.numpy as jnp
from jax import grad, jit
from jax.lax import scan
from optax import adam
from tjax import dataclass
from tjax.gradient import adam as adam2


@dataclass
class State:
    parameters: float
    gt_state: Any

def loss(parameters):
    return jnp.square(parameters - 1.0)

def inference(gt: Any):
    global compiled_times
    compiled_times += 1
    def do_one(state: State, _: None):
        gradient = grad(loss)(state.parameters)
        new_gradient, gt_state = gt.update(gradient, state.gt_state, state.parameters)
        parameters = state.parameters + new_gradient
        return State(parameters, gt_state), parameters

    state = State(0.0, gt.init(0.0))
    final_state, trajectory = scan(do_one, state, None, 100)
    return trajectory  #final_state.parameters

compiled_times = 0
static_jit_inference = jit(inference, static_argnums=0)
for _ in range(100):
    gt1 = adam(0.05)  # Even a fixed learning rate induces recompilation since the jit is hashing the id of the argument.
    static_jit_inference(gt1)
print(f"optax compiled {compiled_times} times")

compiled_times = 0
jit_inference = jit(inference)
for _ in range(100):
    gt2 = adam2(0.05)  # Even changing the learning rate does not induce a recompilation since the jit is only hashing the type.  The learning rate is passed dynamically.
    jit_inference(gt2)
print(f"tjax compiled {compiled_times} times")

rosshemsley commented on August 15, 2024

Thanks for the example!

Indeed, I think the optax recompilation happens because JAX sees that you are giving it arguments that are not pytrees and that have a different hash.

In JAX, this will always cause a recompilation (so this is very much working as intended).

I suspect the tjax example works because the hash of the objects is the same - which is due I think to your dataclass being carefully designed to allow this :) - you are effectively passing a pair of static functions, along with some numerical data (which will I think turn out to be equivalent to the below example).

Most JAX users (at DeepMind) use JAX to define pure functions that take pytrees and return pytrees, so it's not a pattern I've seen before to pass a general class as an argument to a JAX function like this.

Most users would use a closure (or static arg) as follows:

def build_inference(opt: Any):
  @jax.jit
  def inference(learning_rate):
    gt = opt(learning_rate)

    def do_one(state: State, _: None):
        print('Compiled')
        gradient = grad(loss)(state.parameters)
        new_gradient, gt_state = gt.update(gradient, state.gt_state, state.parameters)
        parameters = state.parameters + new_gradient
        return State(parameters, gt_state), parameters

    state = State(0.0, gt.init(0.0))
    final_state, trajectory = scan(do_one, state, None, 100)
    return trajectory
  
  return inference

Here's an example of that working in a colab.

Do you think you could use the pattern above to solve the recompilation issue?

[By the way, the use of a global (compiled_count) outside of the function looks fishy to me: JAX functions should avoid having side effects - so this might result in undefined behaviour :) ]

rosshemsley commented on August 15, 2024

(by the way, I tried to run your example in this colab, but the tjax code failed - I didn't dig too deep, but it would have been cool to try it!)

NeilGirdhar commented on August 15, 2024

> Indeed, I think the optax recompilation happens because JAX is seeing that you are giving it arguments, that are not pytrees, that have a different hash.

Exactly.

> I suspect the tjax example works because the hash of the objects is the same

Actually, they don't have hash values at all (since they're dataclasses).

The reason the tjax example works is because the dataclasses are pytrees, and JAX will only hash the type of the tree, the shapes of its dynamic fields (which are unchanged), and the values of its static fields (of which there are none).
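A sketch of these cache-key rules, assuming a hypothetical `MaybeClippedSGD` pytree with one dynamic field (`learning_rate`) and one static field (`clip`, stored in the treedef's auxiliary data). Changing the dynamic value reuses the compiled trace; changing the static value or the shape of an input triggers a new trace.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class MaybeClippedSGD:
    """Hypothetical transformation with one dynamic and one static field."""

    def __init__(self, learning_rate, clip):
        self.learning_rate = learning_rate  # dynamic: becomes a traced leaf
        self.clip = clip                    # static: stored in the treedef

    def update(self, grads):
        if self.clip:  # a Python `if` is fine because `clip` is static
            grads = jnp.clip(grads, -1.0, 1.0)
        return -self.learning_rate * grads

    def tree_flatten(self):
        return (self.learning_rate,), (self.clip,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


trace_count = 0


@jax.jit
def step(gt, params):
    global trace_count
    trace_count += 1
    return params + gt.update(2.0 * (params - 1.0))  # gradient of (params - 1)**2


x3 = jnp.zeros(3)
step(MaybeClippedSGD(0.1, clip=False), x3)           # first trace
step(MaybeClippedSGD(0.7, clip=False), x3)           # new dynamic value: cached
step(MaybeClippedSGD(0.1, clip=True), x3)            # new static value: retraced
step(MaybeClippedSGD(0.1, clip=True), jnp.zeros(5))  # new input shape: retraced
print(trace_count)
```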

> which is due I think to your dataclass being carefully designed to allow this :)

Nope. There's no such trickery at all. I could have used the flax.dataclass and achieved the same result.

> Most JAX users (at DeepMind) use JAX to define pure functions that take pytrees and return pytrees, so it's not a pattern I've seen before to pass a general class as an argument to a JAX function like this.

The functions in my example are pure functions (besides the compilation counter). There is no editing of global state.

> Most users would use a closure (or static arg) as follows:

Sorry, I think you're missing the point of this issue. This issue is about preventing recompilation of jitted functions that accept the optimizer. Your example's jitted function does not accept the optimizer. I tried to explain this more carefully in this previous comment: #197 (comment)

> [By the way, the use of a global (compiled_count) outside of the function looks fishy to me: JAX functions should avoid having side effects - so this might result in undefined behaviour :) ]

In general you're right, but the global is only there to track recompilation. Since every compilation re-enters the function, this is a fine way of tracking recompilation. I think there are JIT stats somewhere, but I've never used them. Feel free to link to them and we can use them instead if it makes you more comfortable.

NeilGirdhar commented on August 15, 2024

> (by the way, I tried to run your example in this colab, but the tjax code failed - I didn't dig too deep, but it would have been cool to try it!)

That looks to be because colab still uses Python 3.6, which tjax's type annotations don't support. tjax and JAX both dropped Python 3.6 as per NEP 29. I just pushed a change that might address the incompatibility.

rosshemsley commented on August 15, 2024

> Nope. There's no such trickery at all. I could have used the flax.dataclass and achieved the same result.

:) my meaning with "trickery" is that usually dataclasses don't have methods - they are structured data containers. By decorating your optimizer class as a dataclass, you make the JAX API treat it as data, even though it is "semantically" a regular class with methods and data attributes. I can't shake the feeling that this could break the programming model somehow, although I admit I haven't got a concrete reason why! Maybe it's just fine and I'm not used to seeing this kind of code :)

> The functions in my example are pure functions (besides the compilation counter). There is no editing of global state.

Agree they are pure - the key part is that the arguments aren't traditional pytrees here - this is the thing that's slightly unusual for me.

> Sorry, I think you're missing the point of this issue. This issue is about preventing recompilation of jitted functions that accept the optimizer. Your example's jitted function does not accept the optimizer. I tried to explain this more carefully in this previous comment: #197 (comment)

Where I'm coming from is that the current optax API is not designed to work in this way (passing a gradient transformation as an argument) - and if users try this, they will get a hard error. The example I wrote is to (hopefully) show that this API limitation doesn't necessarily block users from expressing what they would like with optax, though.

Do you have an example where passing the gradient transformation as an argument turns out to be important (and where you couldn't use the above solution)?

rosshemsley commented on August 15, 2024

> That looks to be because colab still uses Python 3.6, which tjax's type annotations don't support. tjax and JAX both dropped Python 3.6 as per NEP 29. I just pushed a change that might address the incompatibility.

I think colab is 3.7 now:

import sys
print(sys.version_info)

gives

sys.version_info(major=3, minor=7, micro=12, releaselevel='final', serial=0)

rosshemsley commented on August 15, 2024

I can see where you are coming from with this proposal, so perhaps I can give some context on why we are trying to dig deeper here.

Last week I made the following change to the optax API:

-- OptState = NamedTuple  # Transformation states are (possibly empty) namedtuples.
++ OptState = chex.ArrayTree  # States are arbitrary nests of `jnp.ndarrays`.

Which semantically doesn't look like it should change anything, however...
It took nearly three weeks to successfully roll this out to every user of Optax without breaking user code 🥲

Making even small changes to the base of the library is very time consuming, so we have to be very selective about which changes we take on. The most convincing reason to make a change would be if there are optimization problems that users couldn't express using optax without making the change. So users can help us the most by uncovering those cases, and showing us specifically how the API is breaking when trying to solve their optimization problem.

NeilGirdhar commented on August 15, 2024

> Do you have an example where passing the gradient transformation as an argument turns out to be important (and you couldn't use the above solution?)

Here's a concrete example: Training can be used as the application of an iterated function with a fixed point. Therefore, you can use the "two phase method" (see https://github.com/gehring/fax) to find higher-order derivatives around the fixed point.
To implement this, you either need to fix the gradient transformation with which you do the updates or else pass it as a parameter to the two-phase algorithm.

Why support passing it as a parameter? The big benefit is that you can then calculate higher-order gradients with respect to the learning parameters. If you make gradient transformations into closures then they cannot be passed dynamically. That's just unnecessarily limiting.

> :) my meaning with "trickery" is that usually dataclasses don't have methods

Oh! I thought you were suggesting that I'd hacked the hash values to match or something like that. Yes, I agree that a lot of dataclass examples don't have methods because they tend to be rapid prototypes.

If you have time, take a look at my efax library and its use of dataclasses to rapidly generate a variety of exponential family distributions. It generates a fairly large interface using very little code. I strongly believe that dataclasses are an excellent way of writing readable JAX code.

I love that you've moved from NamedTuple to chex.ArrayTree. However, dataclasses have some advantages over ArrayTree:

  • they support static fields (which come up quite often),
  • they have types, which means your type checker knows the difference between an OptState and Params and will report transpositions of the two, and
  • they support dynamic binding (something optax currently accomplishes by storing closures, but that will get ugly as soon as your gradient transformation type hierarchy is more than one level deep).
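A minimal sketch of dynamic binding in a pytree class hierarchy, with hypothetical classes `ScaleByLearningRate` and `ExponentialDecay`: the subclass overrides a single hook, and under jit the method is resolved from the pytree type (part of the static structure) at trace time. This is an illustration of the bullet above, not code from tjax or optax.

```python
import dataclasses

import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ScaleByLearningRate:
    """Hypothetical base transformation; all fields are dynamic leaves."""
    learning_rate: float

    def learning_rate_at(self, count):
        return self.learning_rate  # hook that subclasses may override

    def update(self, grads, count):
        return -self.learning_rate_at(count) * grads

    def tree_flatten(self):
        return tuple(getattr(self, f.name) for f in dataclasses.fields(self)), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@register_pytree_node_class  # pytree registration is per class, so subclasses register too
@dataclasses.dataclass(frozen=True)
class ExponentialDecay(ScaleByLearningRate):
    decay_rate: float

    def learning_rate_at(self, count):
        return self.learning_rate * self.decay_rate ** count


@jax.jit
def apply(gt, grads, count):
    return gt.update(grads, count)


grads = jnp.ones(2)
constant = apply(ScaleByLearningRate(0.1), grads, 3)
decayed = apply(ExponentialDecay(0.1, 0.5), grads, 3)
print(constant, decayed)
```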

> Making even small changes to the base of the library is very time consuming,

Yes, I empathize with that. Let's talk about what problems could be caused by this change. First of all, the interface (init and update) is unchanged, and the base class is unchanged. There are currently no subclasses. So all normal uses are unaffected.

The base class GradientTransformation is a NamedTuple. It would be changed to a dataclass, which means that anyone iterating over it for some extremely weird reason would have to use its members directly. My guess is that GradientTransformation was only made iterable for the historical reason that NamedTuple was the only structured pytree when optax was started.

As unlikely as it is, it may be possible that someone is statically passing a gradient transformation to a jitted function. After this change, it has to be done dynamically (by simply removing the parameter from the static_argnums). And such an affected user would benefit immediately from avoiding recompilation and the capacity to take gradients and vectorize over parameters of the gradient transformation.

Both problems are easy to find: the first would cause an error saying that the dataclass is not iterable, and the second would cause an error saying that it's not hashable. And they both have easy fixes.

I suggested this change in issue #1 because I felt strongly about it then. I ended up duplicating your library in tjax.gradient since I didn't have time for a long discussion. But now the amount of time it's taking to maintain a parallel library is growing, which is what's motivating me to revisit this. I understand your point about not taking on extra work. I believe that this will be much less work in the long run. And I'm also willing to help in any way that I can. Please let me know what I can do.

mtthss commented on August 15, 2024

Hello @NeilGirdhar !

I agree that we need to be able to pass an optimiser to a function without triggering recompilation.
But in your example that is easily done as follows:

from typing import Any, NamedTuple
import jax.numpy as jnp
from jax import grad, jit
from jax.lax import scan
from optax import adam


class State(NamedTuple):
    parameters: float
    gt_state: Any

def loss(parameters):
    return jnp.square(parameters - 1.0)

def inference(gt: Any):
    global compiled_times
    compiled_times += 1

    def do_one(state: State, _: None):
        gradient = grad(loss)(state.parameters)
        new_gradient, gt_state = gt.update(gradient, state.gt_state, state.parameters)
        parameters = state.parameters + new_gradient
        return State(parameters, gt_state), parameters

    state = State(0.0, gt.init(0.0))
    final_state, trajectory = scan(do_one, state, None, 100)
    return trajectory  #final_state.parameters

compiled_times = 0
static_jit_inference = jit(inference, static_argnums=0)

gt1 = adam(0.05)
for _ in range(100):
    static_jit_inference(gt1)
print(f"optax compiled {compiled_times} times")  # OUT: optax compiled 1 times

The issue was creating a new optax optimiser within each iteration of the loop. That does, and should, trigger recompilation, as you are changing the static definition of the optimiser. But you don't need to reinstantiate the optimiser to change dynamic state. For instance, if you want to change the hyper-params (e.g. to expose them to learning), that is supported by wrapping the optimiser with inject_hyperparams. The hyper-params are then fed as part of the dynamic state and can change within the training loop without triggering recompilation. See the example I shared in #1.
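A minimal sketch of this distinction, assuming a toy closure-based `sgd` factory rather than optax itself: with `static_argnums=0`, reusing the same transformation object hits the jit cache, while building a fresh one each iteration produces new closures that compare unequal, so jit retraces every time even though the learning rate is unchanged.

```python
import functools
from typing import Callable, NamedTuple

import jax


class GradientTransformation(NamedTuple):
    init: Callable
    update: Callable


def sgd(learning_rate):
    """Toy closure-style factory standing in for an optax optimizer."""
    def init_fn(params):
        return ()

    def update_fn(grads, state, params=None):
        return -learning_rate * grads, state

    return GradientTransformation(init_fn, update_fn)


trace_count = 0


@functools.partial(jax.jit, static_argnums=0)
def step(gt, params):
    global trace_count
    trace_count += 1
    grads = 2.0 * (params - 1.0)  # gradient of (params - 1)**2
    updates, _ = gt.update(grads, gt.init(params), params)
    return params + updates


gt = sgd(0.05)
for _ in range(5):
    step(gt, 0.0)  # same object: equal and same hash, so the trace is reused
reused = trace_count

for _ in range(5):
    step(sgd(0.05), 0.0)  # fresh closures compare unequal, so jit retraces
print(reused, trace_count)
```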

NeilGirdhar commented on August 15, 2024

Sorry @mtthss I guess I haven't been completely clear. The point of this issue is to prevent unnecessary recompilation. Yes, it's true that there's no recompilation if you pass the exact same optimizer object. The point is to support passing different optimizer objects with different parameters. Did you notice the comment in the code that says "Even changing the learning rate does not induce a recompilation since the jit is only hashing the type."?

Let me explain my actual design and problem: I have a user interface that lets me set parameters of the problem (e.g., the number of dimensions of the input), parameters of the solution (e.g., the size of the model, the gradient transformation), and various other parameters.

These parameters are passed to a single jitted training function. That function runs and produces an output, which is then visualized. After looking at the visualization, I adjust the parameters and run the simulation again. I do not want the training function to recompile! It already takes something like 30s to compile.

That's why I'm very careful to make sure that only the things that must be static are static. For example, the number of training iterations is currently static because I'm using a scan, whose length must be known at compile time. (A while loop cannot currently be optimized on the GPU, but I am considering a hybrid solution.)
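A rough sketch of that static/dynamic split, using a toy quadratic loss; the `train` signature and the `trace_count` counter are illustrative assumptions, not the poster's code. The scan length is marked static with `static_argnums`, so changing it forces a retrace, while the learning rate is an ordinary traced argument that can change without recompiling.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces (compiles) train


@partial(jax.jit, static_argnums=0)
def train(num_steps, learning_rate, x0):
    global trace_count
    trace_count += 1

    def step(x, _):
        grad = 2.0 * (x - 1.0)  # derivative of the toy loss (x - 1)^2
        return x - learning_rate * grad, None

    # scan needs a concrete length, which is why num_steps must be static.
    x_final, _ = jax.lax.scan(step, x0, None, length=num_steps)
    return x_final


result = train(100, 0.1, jnp.zeros(()))
train(100, 0.05, jnp.zeros(()))  # different learning rate: cached, no retrace
train(200, 0.05, jnp.zeros(()))  # different static length: retraces
print(trace_count)  # 2
```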

Your idea to use inject_hyperparams doesn't help me at all. I don't want to pass in the object it returns, because I'm not trying to set the dynamic parameters inside the jit. I'm setting everything outside the jit. I just want the jit not to recompile when I change the parameters.

The only way to work around this problem would be to construct the gradient transformation object inside the jitted training function. But this is bad design. By separation of concerns, the training function shouldn't know anything about optax gradient transformations except the interface (update and init). I don't want to pass in the class type and its parameters, but yes, I could do that. Which of these looks simpler to you?

# Using tjax.gradient.adam
gt = adam(learning_rate=1e-2)
train(..., gradient_transform=gt, ...)
# where
@jit
def train(..., gradient_transform: GradientTransformation, ...):
  # use gradient_transform

versus

# Using optax.adam
train(..., gradient_transform_callable=adam, gradient_transform_parameters={'learning_rate': 1e-2}, ...)
# where
@partial(jit, static_argnames=('gradient_transform_callable',))
def train(..., gradient_transform_callable: Callable[..., GradientTransformation], gradient_transform_parameters: Mapping[str, Any], ...):
  gradient_transform = gradient_transform_callable(**gradient_transform_parameters)

The first one is clearly much simpler. And the only reason it doesn't work with optax is that the gradient transformations are, mysteriously, not pytrees.
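To make the first option concrete, here is a minimal sketch assuming a hypothetical `SGD` transformation written as a frozen dataclass and registered with `jax.tree_util.register_pytree_node_class`; it is not tjax's or optax's actual code. Because the learning rate is a pytree leaf rather than part of a static argument, passing optimizers with different learning rates to the same jitted `train` should not trigger a retrace.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class SGD:
    """Hypothetical pytree gradient transformation (not optax's or tjax's API)."""

    learning_rate: Any  # a pytree leaf, so it is traced rather than static

    def init(self, params):
        return ()  # plain SGD keeps no state

    def update(self, grads, state, params=None):
        updates = jax.tree_util.tree_map(lambda g: -self.learning_rate * g, grads)
        return updates, state

    def tree_flatten(self):
        return (self.learning_rate,), None  # (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


trace_count = 0


@jax.jit
def train(gradient_transform, params):
    global trace_count
    trace_count += 1
    state = gradient_transform.init(params)

    def step(carry, _):
        params, state = carry
        grads = jax.grad(lambda p: jnp.square(p - 1.0))(params)
        updates, state = gradient_transform.update(grads, state, params)
        return (params + updates, state), None

    (params, _), _ = jax.lax.scan(step, (params, state), None, length=100)
    return params


results = [train(SGD(learning_rate=lr), jnp.zeros(())) for lr in (0.1, 0.05, 0.01)]
print(trace_count)  # 1
```

The training function only relies on the `init`/`update` interface, which is the separation of concerns argued for above.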

NeilGirdhar commented on August 15, 2024

@rosshemsley By the way, I was thinking about your concerns about breaking code. You said: "It took nearly three weeks to successfully roll this out to every user of Optax without breaking user code".

I don't want to break your users' code either. Here's an idea that provides pytree-like gradient transformation objects without breaking anyone's code: implement all of the gradient transformation objects as dataclasses in optax.gradient.*, then create shims in optax.* that are based on NamedTuple, just like they are now. Nothing changes for your users; the only cost is a little extra work when adding a transformation.
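A minimal pure-Python sketch of that shim idea. The hypothetical class `Scale` stands in for a dataclass implementation and `scale` stands in for an existing optax-style factory. Updates are plain lists of floats here instead of JAX pytrees, and the pytree registration a real version would need is omitted.

```python
import dataclasses
from typing import Any, Callable, List, NamedTuple


class GradientTransformation(NamedTuple):
    """Today's optax-style interface: a pair of functions."""
    init: Callable[..., Any]
    update: Callable[..., Any]


@dataclasses.dataclass(frozen=True)
class Scale:
    """Hypothetical dataclass implementation (would live in e.g. optax.gradient)."""
    step_size: float

    def init(self, params: List[float]) -> tuple:
        return ()  # scaling needs no state

    def update(self, updates: List[float], state: tuple, params=None):
        return [self.step_size * u for u in updates], state


def scale(step_size: float) -> GradientTransformation:
    """Backward-compatible shim that keeps returning the NamedTuple."""
    impl = Scale(step_size)
    return GradientTransformation(init=impl.init, update=impl.update)


init, update = scale(0.5)  # existing call sites, including tuple unpacking, still work
state = init([1.0, 2.0])
new_updates, state = update([2.0, 4.0], state)
print(new_updates)  # [1.0, 2.0]
```

Callers that unpack the result or use `.init`/`.update` keep working, while code that wants attribute access to hyperparameters can use `Scale` directly.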

cgarciae commented on August 15, 2024

Hey, I am very late to this party, but I want to support the idea of pytree optimizers:

  • Currently you can neither access nor modify the hyper-parameters of an optimizer, since they are captured by closures. This is inconvenient for code that, e.g., wants to figure out the base learning rate set by the user.
  • It's more convenient for users to pass the optimizers directly to the train_step instead of having to pass them by closure.

Treex has an Optimizer class that wraps Optax optimizers and turns them into pytrees for convenience. Since it's only a wrapper, it solves only the second point; solving the first requires a proper pytree implementation.
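The first point can be illustrated with a small pure-Python comparison; `sgd` and `SGD` are hypothetical stand-ins, not optax or Treex code. With the closure style, the learning rate is reachable only through the function's `__closure__` cells, while the dataclass exposes it as an ordinary attribute.

```python
import dataclasses


def sgd(learning_rate: float):
    """Closure style: learning_rate is captured, not stored as an attribute."""
    def update(grad: float) -> float:
        return -learning_rate * grad
    return update


@dataclasses.dataclass(frozen=True)
class SGD:
    """Hypothetical attribute style: learning_rate is a public field."""
    learning_rate: float

    def update(self, grad: float) -> float:
        return -self.learning_rate * grad


closure_update = sgd(0.1)
# Recovering the hyperparameter means poking at the function's closure cells.
hidden_lr = closure_update.__closure__[0].cell_contents

opt = SGD(0.1)
visible_lr = opt.learning_rate  # plain attribute access
print(hidden_lr, visible_lr)  # 0.1 0.1
```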

NeilGirdhar commented on August 15, 2024

@rosshemsley

Thank you very much for this comment, and actually for the whole discussion. We are all (the Optax team, @cgarciae, me) working towards improving the ecosystem of Jax tools, and we're all in the process of figuring out the best way to do things. I also want to thank you for considering this issue so thoroughly and for replying so quickly.

I completely agree with your proposal that waiting for the Jax team to add dataclasses is probably the best thing to do 😄

I just want to comment on a few things in your last message, if you don't mind:

notably, the current optax API has a very sharp boundary between “functions implementing user behavior” (the init / update methods) and “data containers” (i.e. the state). In the above proposals, hyperparameters and/or state are “folded into” the optimizer class

My proposal doesn't change this. The state is still the same class (I didn't touch the state objects, except to give them a base class). And the optimizer has exactly the same

  • functions, except my version exposes them as member functions rather than function-valued member variables, and
  • variables, except that my version exposes them as attributes rather than closing over them.

and the pure functions become methods on a mutable class instead

Actually, the dataclasses that I use are "frozen", which means they are immutable. The member functions are pure.

I understand the instinct to think that OO means impure, but in the case of frozen dataclasses, OO is pure.
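A short pure-Python illustration of that point, assuming a hypothetical frozen `Scale` transformation: assigning to a field raises `dataclasses.FrozenInstanceError`, and `dataclasses.replace` returns a new instance rather than modifying the old one, so methods that only read fields behave as pure functions.

```python
import dataclasses
from typing import List


@dataclasses.dataclass(frozen=True)
class Scale:
    """Hypothetical frozen gradient transformation."""
    step_size: float

    def update(self, updates: List[float]) -> List[float]:
        # Pure: depends only on self and the arguments, and mutates nothing.
        return [self.step_size * u for u in updates]


t = Scale(0.5)
try:
    t.step_size = 1.0  # frozen dataclasses reject attribute assignment
    mutated = True
except dataclasses.FrozenInstanceError:
    mutated = False

t2 = dataclasses.replace(t, step_size=2.0)  # a new object; t is unchanged
print(t.update([1.0, 2.0]), t2.update([1.0, 2.0]))  # [0.5, 1.0] [2.0, 4.0]
```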

but it is different from the programming patterns of many of our users.

My proposal does not change anything for your users. You should test it out: Just replace from optax import adam with from tjax.gradient import Adam. If you do test it, please let me know if it still works. It should be identical.

Can I make a much more modest suggestion, then? Since you plan on reconsidering this in the future, why not at least warn on those usages of your optimizers that would make upgrading difficult, specifically sequence access to the GradientTransformation? To that end, I suggest you add something like:

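import warnings
from typing import Any, Callable, NamedTuple

# Simplified stand-ins so this snippet runs on its own; optax defines its own
# TransformInitFn / TransformUpdateFn aliases alongside GradientTransformation.
TransformInitFn = Callable[..., Any]
TransformUpdateFn = Callable[..., Any]

def w():
  # stacklevel=3 attributes the warning to the caller that used sequence access.
  warnings.warn("Sequence access of GradientTransformation is deprecated.",
                DeprecationWarning, stacklevel=3)

class GradientTransformation(NamedTuple):
  """Optax transformations consist of a function pair: (initialise, update)."""
  init: TransformInitFn
  update: TransformUpdateFn

  # Zero-argument super() is not supported in typing.NamedTuple methods, so
  # each override calls tuple's implementation explicitly.
  def __getitem__(self, key):
    w()
    return tuple.__getitem__(self, key)

  def __len__(self):
    w()
    return tuple.__len__(self)

  def __contains__(self, item):
    w()
    return tuple.__contains__(self, item)

  def __iter__(self):
    w()
    return tuple.__iter__(self)

  def __reversed__(self):
    w()
    # tuple defines no __reversed__ of its own, so build the reversed iterator directly.
    return iter(tuple.__getitem__(self, slice(None, None, -1)))

  def index(self, item):
    w()
    return tuple.index(self, item)

  def count(self, item):
    w()
    return tuple.count(self, item)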
