Waterkin commented on August 15, 2024

Thanks for your reply, Neil. I'll see if tjax helps.

from optax.

NeilGirdhar commented on August 15, 2024

Actually, I'll just use it as is, and if I run into this problem, I'll either reopen this issue or build a shim around optax with a more object oriented design.

mtthss commented on August 15, 2024

Hi Neil,

Meta-gradients are definitely something that is very close to my heart, as I have been using optax in a number of recent MG papers, e.g. https://arxiv.org/abs/2007.08794, https://arxiv.org/abs/2007.08433

You will find that many of the building blocks in optax are carefully crafted to ensure that they are suitable for meta-gradients, including subtle numerical issues. See, for instance, the eps_root argument in scale_by_adam (only important if you want to backprop through an update), or the specific implementation of clip_by_norm, which is also carefully written to support backpropagating through it.

Regarding your specific query, here is how I would implement meta-learning of a global learning rate
(written quickly just to give the idea)

import jax
import optax
from jax.tree_util import tree_map

# inner_loss(theta, batch) and outer_loss(theta, batch) are assumed to be
# user-defined scalar loss functions.
optim = optax.chain(optax.clip_by_global_norm(10.), optax.scale_by_adam())
# scale(-1.) turns the rescaled meta-gradient into a descent step on eta.
meta_optim = optax.chain(
    optax.clip_by_global_norm(0.1), optax.scale_by_adam(), optax.scale(-1.))

def inner_update(theta, theta_opt_state, eta, batch):
  """Perform a step of inner update to the agent."""
  grads = jax.grad(inner_loss)(theta, batch)
  updates, new_theta_opt_state = optim.update(grads, theta_opt_state)
  # scale_by_adam returns a rescaled gradient (an ascent direction), so scale by -eta.
  scaled_updates = tree_map(lambda t: -eta * t, updates)
  new_theta = optax.apply_updates(theta, scaled_updates)
  return new_theta, new_theta_opt_state

def inner_updates_and_outer_loss(eta, theta, theta_opt_state, batches):
  """Perform a few inner updates and then compute validation loss."""
  for batch in batches[:-1]:
    theta, theta_opt_state = inner_update(theta, theta_opt_state, eta, batch)
  loss = outer_loss(theta, batches[-1])
  return loss, (theta, theta_opt_state)

def update(theta, theta_opt_state, eta, eta_opt_state, batches):
  """Update params (theta) and meta-params (eta)."""
  # With has_aux=True, jax.grad returns (gradient, aux), in that order.
  meta_grads, (new_theta, new_theta_opt_state) = jax.grad(
      inner_updates_and_outer_loss, has_aux=True)(
          eta, theta, theta_opt_state, batches)
  eta_updates, new_eta_opt_state = meta_optim.update(meta_grads, eta_opt_state)
  new_eta = optax.apply_updates(eta, eta_updates)
  return new_theta, new_theta_opt_state, new_eta, new_eta_opt_state

It would be possible to abstract some of this away into a suitable meta-gradient / meta-learning library built on top of optax, and I would be excited to see that! The aim of optax is to provide the shared, low-level, well-tested, flexible building blocks that the community can use to construct such (and many other) higher-level libraries. By keeping optax fairly slim, we can reuse these building blocks across many higher-level libraries, thereby sharing effort and code.

NeilGirdhar commented on August 15, 2024

@mtthss Thank you, yes. The use case I was trying to explain is when you can't easily "break out" eta, as you did in your example where eta is just a learning rate. Suppose instead that you had added optax.scale(-eta) to the chain in optim, or that eta were the b1 parameter to scale_by_adam. (Incidentally, the problem with breaking out parts of the gradient transformation is that it's bug-prone. Did you forget a minus sign in scaled_updates?)

Then, when you go to calculate jax.grad(inner_updates_and_outer_loss, ...), this depends on inner_update, which depends on optim, which closes over parameters (like b1 in scale_by_adam). This means that your gradient computation can easily result in leaked tracers in JAX.

Leaked tracers have cost me weeks of debugging time, and now I'm extremely wary of closing over tracers. That's why I prefer the object oriented design I suggested so that the objects can be registered as pytrees and this problem is avoided.

mtthss commented on August 15, 2024

Did you forget a minus sign in scaled_updates?
Yes. In my defence, I did write that code in less than 5 minutes; it's meant more as pseudo-code than actual runnable Python :)

The approach above is not specific to eta being just a learning rate. For example, I have used it to backprop through optax updates to meta-learn all hyper-parameters of complex RL loss functions (eta = {gamma, lambda, baseline_cost, entropy_cost, ...}), or even the entire loss function in the inner update (eta = {parameters of a neural network}).

However, I now see your point better: while I can backprop through the gradient transformations into any hyper-parameter of the inner loss (my typical setting), the hyper-parameters of the gradient transform itself are not natively exposed to meta-gradients unless you break them out, as I did for the learning rate.

I would be quite interested in seeing an object-oriented design for a meta-optax extension library that is compatible with the optax API (e.g. each GradientTransformation offers the init/update public methods, and thus is compatible and chainable with standard optax transforms) but whose transforms expose the hyper-parameters of the transformations themselves to meta-learning. If it is API compatible, and I think you could make it so, it would be an amazing extension library for optax.

NeilGirdhar commented on August 15, 2024

I would be quite interested in seeing an object oriented design for a meta-optax extension library,

What do you think of this? I'm still working on the design to get it to fit with my project.

mtthss commented on August 15, 2024

That's a good start. A major thing that I would strongly suggest is that you switch the method names to the signature of GradientTransformation in optax. Then your transformations would be immediately interoperable with optax's native transformations. The init/update pair is also somewhat canonical across several other JAX libraries. Looking forward to seeing further iterations of this.

Few other smaller points:

  1. Instead of defining your own dataclass, why not use the dataclass defined in the chex library?
  2. Some of the type annotations you define in your codebase are also defined in the chex library (and a couple of others could maybe be upstreamed there?). In the Tensor annotation, by the way, the Union is superfluous; see for instance chex.Array.

mtthss commented on August 15, 2024

What I like about the design is how it manages to build support for meta-gradients on top of the existing optax low-level components. If, in addition, you make sure the transformations are interoperable, it would be really quite cool.

mtthss commented on August 15, 2024

One interesting thing is that the dataclass structure for Scale and ScaleByAdam is exactly the same (modulo which optax transformation is wrapped). I wonder if we could even implement this as a generic wrapper that takes an arbitrary optax transformation and allows injecting the arguments according to the pattern from your code.

NeilGirdhar commented on August 15, 2024

That's a good start. A major thing that I would strongly suggest is that you switch the method names to the signature of GradientTransformation in optax. Then your transformations would be immediately interoperable with optax's native transformations. The init/update pair is also somewhat canonical across several other JAX libraries.

Oh, I didn't realize that! Okay, that's pretty convincing.

Instead of defining your own dataclass, why not use the dataclass defined in the chex library?

The main reason is that I updated the dataclass mypy plugin so that it understands the replace method and its arguments: plugin. It also supports inheritance with regular mixins and classes, but I've never actually used that.

Some of the type annotation you define in your codebase are also defined in the chex library (and a couple others could maybe be upstreamed there?)

Nice! I'll just use your types where I can and update my codebases. Feel free to upstream whatever you like. I guess that one day JAX will be able to use Annotated to promise real (non-complex) Array types.

I actually wonder if we could even implement this as a generic wrapper, that takes an arbitrary optax transformation and allows to inject the arguments according to the pattern from your code.

Yes, I also considered that. I don't know how to lift out the parameters to scale_by_adam and create member variables out of them. I guess that it's possible with inspection, but I just wanted to see if I could fit this interface into my larger project.

Looking forward to seeing further iterations of this.

Perhaps you could help me with some design? I've been thinking about how to get the meta-learning working, and right now the way it's done is to take the gradient with respect to the objects of type, say, Scale, ScaleByAdam, etc. This works fine as long as I want the gradient with respect to every meta-parameter, and it works so long as these objects only contain real values. But what if there is some Boolean flag, or integer component? Then, this approach won't work. What are your thoughts on this?

I implemented my first thought, which is to pass in a MetaParameter object for the meta-parameters. These contain a name. Then, update accepts a mapping from names to values, and looks up the meta-parameters by name and fills them in. In practice, JAX could calculate gradients with respect to these values (I guess they would be JAX tracers). What do you think?
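The name-based lookup described above could look roughly like the following. This is a hedged sketch: MetaParameter, fill_meta_parameters, and the toy update rule are hypothetical illustrations, not Neil's actual code or an optax API. Only the looked-up values are traced, so a Boolean flag can stay an ordinary Python value.

```python
import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class MetaParameter:
  """Hypothetical placeholder: marks a hyper-parameter to be looked up by name."""
  name: str

def fill_meta_parameters(hyperparams, meta_values):
  """Replace each MetaParameter placeholder with meta_values[placeholder.name]."""
  return {key: meta_values[value.name] if isinstance(value, MetaParameter) else value
          for key, value in hyperparams.items()}

# Only the learning rate is meta-learned; the Boolean flag is never traced.
hyperparams = {'learning_rate': MetaParameter('lr'), 'negate': True}

def loss_after_step(meta_values, theta):
  hp = fill_meta_parameters(hyperparams, meta_values)
  grads = 2.0 * theta                   # gradient of sum(theta ** 2)
  sign = -1.0 if hp['negate'] else 1.0  # allowed: `negate` is a Python bool
  new_theta = theta + sign * hp['learning_rate'] * grads
  return jnp.sum(new_theta ** 2)

theta = jnp.array([1.0, 2.0])
meta_grads = jax.grad(loss_after_step)({'lr': jnp.array(0.1)}, theta)
```

For this toy loss the post-step loss is 5(1 - 2 lr)^2, so the meta-gradient at lr = 0.1 is -20(0.8) = -16.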

mtthss commented on August 15, 2024

Adding an optional meta_params argument is definitely one option. For the main library, I am probably still going to look into a generic wrapper that lifts the hyper-params of an arbitrary transform and replaces them with traceable arrays. Since meta-learning is still somewhat niche (although I am very into it myself), I want to make sure we support it without making it any harder for users to write their own transforms and mix and match them with the existing ones. I have some ideas on how this might be done, but I will be away for a couple of weeks, so your approach definitely seems like a reasonable way to get you going in the short term.

mtthss commented on August 15, 2024

For context, I am thinking of something along these lines. Take it with a grain of salt; I'm just sharing it now because I will be away for a couple of weeks, and it might serve as inspiration to you in the meantime.

from jax import tree_util

def lift_hypers(transform):
  """Wrap a functional optax transform so its positional hyper-params are pytree leaves."""

  class Wrapped:

    def __init__(self, *args):
      self.args = args

    def init(self, params):
      return transform(*self.args).init(params)

    def update(self, grads, state, params=None):
      return transform(*self.args).update(grads, state, params)

    def replace(self, args):
      # Return a new instance with new (e.g. meta-learned) hyper-parameters.
      return Wrapped(*args)

  # The hyper-parameters are the pytree children; there is no static aux data.
  tree_util.register_pytree_node(
      nodetype=Wrapped,
      flatten_func=lambda d: (d.args, None),
      unflatten_func=lambda _, xs: Wrapped(*xs))

  return Wrapped

This should allow you to wrap an arbitrary transform:

meta_optim = lift_hypers(optax.scale)

The wrapped class lets you update its hyper-parameters (using replace), jit it, etc. This is just a rough sketch (although it does pass some basic checks I ran in a colab) and would need to be cleaned up. But it has the big advantage of allowing us to extend both existing and arbitrary user-defined transforms.

EDIT: To avoid confusion I removed a leftover line from a previous iteration of the lift_hypers wrapper.

NeilGirdhar commented on August 15, 2024

I don't understand what you've done here. You get the args from transform, and then you shadow that variable everywhere so it's never used? And what's the point of lift_hypers?

It seems like lift_hypers tried to transform your functional design into an object oriented design, but why not just start with the object oriented design in the first place and do away with lift_hypers entirely?

Maybe it would just be easier to absorb my object-oriented interface into optax, since it implements your interface exactly, allows inspection and comparison, and, most importantly, doesn't have the pernicious closed-over tracers.

mtthss commented on August 15, 2024

Sorry, ignore the args = inspect.getfullargspec(transform).args line; that was a leftover. I have now edited the previous post to remove it to avoid confusion.

There are a few orthogonal points to that example:

  1. We can use the object-oriented pattern you suggested to create a simple wrapper around arbitrary functional transforms. This would allow replacing hyper-parameters with trainable meta-parameters and exposing them to gradients, but it would not impose any additional structure on users who just want to define custom transforms in the traditional way. In particular, it does not require having the x = self.replace_meta_parameters_with_defaults(meta_parameters) line in the implementation of each transform.

  2. We do not need to add to the API a special meta_parameters argument, which might be fairly obscure to users not initiated in meta-learning (typically the majority). We can instead have a method (in this case replace, but we could find a different name) that injects the meta-params by returning a new instance of the transform.

  3. We don't need to make it a dataclass; we can also register the class to ensure it is flattened and unflattened appropriately.

NeilGirdhar commented on August 15, 2024

We do not need to add to the API a special meta_parameters argument, which might be fairly obscure to users not initiated in meta-learning (typically the majority). We can instead have a method (in this case replace, but we could find a different name) that injects the meta-params by returning a new instance of the transform.

Good idea. Essentially, you're popping the meta-parameter machinery outside to keep the ordinary transform simpler. I agree that's nicer.

We don't need to make it a dataclass; we can also register the class to ensure it is flattened and unflattened appropriately.

That's true. In the thousands of lines of JAX code I've written, I no longer register any classes. I've found the dataclass approach simpler (less code, fewer bugs). I guess that's a matter of personal style.

we can use the object oriented pattern you suggested to create a simple wrapper around arbitrary functional transforms.

Why not just use the object oriented pattern in the internal optax definitions, and then I don't need wrappers at all? You can keep the same user interface, so it's totally transparent to users.

mtthss commented on August 15, 2024

I've found the dataclass approach simpler

I broadly agree. The only downside is that I have sometimes had problems with other libraries not dealing well with Python dataclasses, just because they are fairly new. Overall, though, I think I am on board with using dataclasses in the wrapper above.

Why not just use the object oriented pattern in the internal optax definitions

That is a possibility that I am considering, but one that I'd want to evaluate carefully. In the meantime, I will probably first focus on providing a well-tested implementation of the wrapper above (modified to use a dataclass, as you suggested). This will address the immediate concern of supporting meta-learning of the hyper-parameters of any existing or user-defined transform.

Thank you very much for the great discussion and suggestions, this has been very useful.

mtthss commented on August 15, 2024

Does the inject_hyperparams wrapper solve the optimisation of metaparameters?

Since this allows you to feed in updated hyper-parameters on every step, it should also allow taking gradients with respect to them (I haven't tested it yet). WDYT?

NeilGirdhar commented on August 15, 2024

@mtthss I don't see the wrapper, but I just rewrote your transforms as dataclasses: https://github.com/NeilGirdhar/tjax/blob/master/tjax/_src/gradient/transforms.py Since these are pytrees, I can differentiate with respect to their parameters.

mtthss commented on August 15, 2024

The current approach where hyper-parameters are captured as closure is designed to keep two things separate:

  • the definition of the optimiser, statically captured in closures
  • the dynamic state of the optimiser, captured in the opt_state

The crux is then: how do we enable users who want to make the optimiser's definition dynamic?

One possibility is to drop entirely the distinction between static optimiser definition (captured by closures) and dynamic optimiser state (captured by the opt_state).

However a second possibility is to preserve this distinction, and instead enable the user who wants to make parts of the optimiser definition dynamic to do so explicitly. How can we do this cleanly and easily?

Until recently we did not have the right tooling for this. However I think the inject_hyperparams wrapper contributed by @n2cholas is exactly the right tool for the job. It enables the user to cleanly and easily prescribe what parts of the optimiser def to make dynamic by just writing things like:

opt = optax.inject_hyperparams(optax.adam)(learning_rate=jnp.array(0.1))

This is sufficient to allow the user to

  1. compute gradients with respect to the hyper-parameter that was exposed to learning
  2. use these gradients to update the hyper-parameter

We put together a complete example for meta-learning an optimiser's hyperparams in the docs
https://optax.readthedocs.io/en/latest/meta_learning.html

Could you take a look and tell us what you think?

I think this is the right approach for meta-learning because it preserves the clear separation between what's static (in the closure) and what's dynamic (in the state); this is idiomatic in JAX and well suited to the functional nature of JAX programming.

Note that we have seen situations internally where mixing static and dynamic (often by using dataclasses as a form of JAX-friendly OO programming) could be initially appealing, but sometimes resulted in subtle bugs. So if we can preserve clear distinctions and be explicit about what is part of the state, I think it's better in the long run.

That said, let's continue the discussion, and we will continue looking into the compilation issue as well.

NeilGirdhar commented on August 15, 2024

One possibility is to drop entirely the distinction between static optimiser definition (captured by closures) and dynamic optimiser state (captured by the opt_state).

I think you've misunderstood. I'm not suggesting that you drop this distinction. The distinction is excellent design.

I'm merely suggesting that you store the optimizer definition using a pytree rather than a closure. Jax is horrible with closures. Pytrees allow vectorization, gradients, and prevent recompilation of jits.

How can we do this cleanly and easily?

I've already coded this. Is there any problem with lifting my code into optax?

I think this is the right approach for meta-learning because it preserves the clear separation
between what's static (in closure)

I agree that you should definitely preserve this clear separation. However, what is the benefit of hanging on to closures over dataclasses?

but sometimes resulted in subtle bugs

Would you mind elaborating? I've also written a tremendous amount of JAX code, and I haven't seen any drawbacks to using dataclasses. On the contrary, code that uses dataclasses is easier to read, and has fewer bugs because you always get the static-vs-dynamic question right since it's baked into the dataclass.

My guess is that the dataclasses you used (the ones in chex, for example) don't have the option of marking static fields. Then, treating fields that should be static as dynamic is going to cause bugs.

Could you take a look and tell us what you think?

Yeah, it's an interesting tool, which looks really useful! I'm not sure it solves the goal of this issue though. The main goal of this issue is in the bulleted list at the top: taking derivatives with respect to parameters of the optimizer.

mtthss commented on August 15, 2024

My guess is that the dataclasses you used (the ones in chex, for example) don't have the option of marking static fields. Then, treating fields that should be static as dynamic is going to cause bugs.

Indeed that is the most common source of bugs. Thanks for the clarification that your dataclasses allow for this.
I was indeed thinking of chex.dataclass :).

I am still a bit uncomfortable that the distinction is blurred, in the sense that the dynamic parts don't all live in the state. We now have a mixed entity that includes both static and dynamic state outside of the state. I could easily see a user forgetting to mark something as static. With the proposal in the example below, the optimiser definition is static by default, and you make it dynamic by injecting it into the state, where the rest of the dynamic state lives. I think this makes the distinction stronger while still allowing hyper-params like the step size to be exposed to learning via a simple call to inject_hyperparams.

I'm not sure it solves the goal of this issue though. The main goal of this issue is in the bulleted list at the top: taking derivatives with respect to parameters of the optimizer.

Let's delve into this. We really want to make sure we support your use case, so if the current tools don't solve the goal of the issue, it is important for us to understand why. Why do you think "taking derivatives with respect to parameters of the optimizer" isn't solved by inject_hyperparams?

Isn't this what I did in this example I shared above?
https://optax.readthedocs.io/en/latest/meta_learning.html

I understand there is also the other issue you raised, and stylistic preferences that might edge us one way or the other, but let's first be clear as to whether the optimisation of metaparameters itself is solved in my example.

I really appreciate your comments; please keep them coming.

NeilGirdhar commented on August 15, 2024

If you'll indulge me, I'd like to make sure we're on the same page about what static and dynamic means in the context of JAX. I don't know that it's really well explained anywhere, and it's something that's easy to misunderstand.

In JAX, various jit and gradient functions accept static and dynamic parameters. These always default to being dynamic unless marked static (using static_argnums, static_argnames, or, for historical reasons, nondiff_argnums).

Static parameters must be hashable. Dynamic parameters must be pytrees, which are either

  • leaves comprising scalars or jax.numpy.ndarray instances, or else
  • aggregate objects comprising dynamic fields (which act as dynamic parameters) and static fields (which act as static parameters despite being passed as part of a dynamic parameter).

The JIT looks up the compiled program using a dictionary that's keyed by

  • the tree structure (Python types of all its components),
  • the shapes and dtypes of its array-valued leaves, and
  • the values of its static parameters.

Consequently,

  • calling the jitted function with different values of the static parameters always induces recompilation, but
  • calling the jitted function with different values (but the same shape) of the dynamic parameters never induces recompilation. They are merely arguments to the compiled program.

Dynamic parameters are replaced with tracers within the JAX-decorated functions, so

  • they cannot be used as the limit of a jax.lax.scan,
  • they cannot be used in Python switches (like if and while), but
  • they can be vectorized by vmap, and be the differentiand of grad, vjp, jvp, etc.

Static parameters are passed to the jitted function unchanged, so

  • they can be used as the limit of a jax.lax.scan,
  • they can be used in Python switches, but
  • they cannot be vectorized or be the differentiand.
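The caching rules above can be observed directly. In this minimal sketch (power_sum and trace_count are hypothetical names), a Python side effect inside the function runs only when JAX traces it, so counting it reveals which calls trigger recompilation.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def power_sum(x, power):
  global trace_count
  trace_count += 1  # Python side effect: runs only when JAX (re)traces
  # `power` is static, so ordinary Python control flow on it is allowed.
  if power == 2:
    return jnp.sum(x ** 2)
  return jnp.sum(x)

f = jax.jit(power_sum, static_argnums=1)

f(jnp.array([1.0, 2.0]), 2)       # first call: trace and compile
f(jnp.array([3.0, 4.0]), 2)       # new dynamic values, same shape/dtype: cache hit
f(jnp.array([3.0, 4.0]), 1)       # new static value: retrace
f(jnp.array([1.0, 2.0, 3.0]), 2)  # new shape: retrace
```

After these four calls trace_count is 3: only the change of static value and the change of shape caused recompilation.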

We now have a mixed entity that includes both static and dynamic state outside of the state.

I don't see a problem with having aggregate types that are "a mixed entity that includes both static and dynamic state". This is exactly what pytrees are meant to handle.
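To make the "aggregate pytree with dynamic and static fields" idea concrete, here is a minimal sketch. The Step class and its fields are hypothetical, and registration is done by hand with jax.tree_util.register_pytree_node rather than with any particular dataclass library: learning_rate is a pytree child (traced, differentiable, vectorizable), while negate lives in the auxiliary data (static, usable in Python control flow).

```python
import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass(frozen=True)
class Step:
  learning_rate: jax.Array  # dynamic: a pytree child
  negate: bool = True       # static: stored in the treedef's aux data

  def apply(self, params, grads):
    sign = -1.0 if self.negate else 1.0  # Python `if` is fine: `negate` is static
    return params + sign * self.learning_rate * grads

jax.tree_util.register_pytree_node(
    Step,
    lambda s: ((s.learning_rate,), s.negate),
    lambda negate, children: Step(children[0], negate))

def loss_after_step(step, theta):
  grads = 2.0 * theta  # gradient of sum(theta ** 2)
  return jnp.sum(step.apply(theta, grads) ** 2)

theta = jnp.array([1.0, 2.0])

# Differentiate with respect to the dynamic field; the static field passes through.
g = jax.grad(loss_after_step)(Step(jnp.array(0.1)), theta)

# Vectorize over a batch of learning rates by mapping over the pytree.
losses = jax.vmap(loss_after_step, in_axes=(0, None))(
    Step(jnp.array([0.1, 0.2, 0.5])), theta)
```

The post-step loss is 5(1 - 2 lr)^2, so g.learning_rate is about -16 and the batched losses are about [3.2, 1.8, 0.0].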

Also, it seems that you're implying (forgive me if I'm misunderstanding you) that because the optimizer definition doesn't usually take on different values, it "should be static". This is not what static means in a JAX context. The main reasons to make something static are because it is not a number or array, or it's a number you need to switch on its value within the jitted function.

I could easily see a user forgetting to mark something as static.

The user isn't the person marking the fields as static and dynamic. That would be done within optax since that's where the classes are defined.

And if you're worried about getting it wrong, your tests should catch any mistakes: If a dynamic array-valued field is accidentally marked static, JAX will complain ("not a JAX type"). If a field you switch on is accidentally marked dynamic, JAX will complain ("concrete value is expected"). If you accidentally mark a numerical field that should be dynamic as static, it won't cause any bugs, but your users will eventually complain about recompilation or inability to vectorize over that field.
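The "field you switch on is accidentally marked dynamic" case can be reproduced as follows. This is a hedged sketch with a hypothetical relu_python_if; the exact exception subclass and message vary across JAX versions (recent versions raise TracerBoolConversionError, a subclass of jax.errors.ConcretizationTypeError).

```python
import jax
import jax.numpy as jnp

def relu_python_if(x):
  # Python control flow on a value: only valid if `x` is concrete (static).
  if x > 0:
    return x
  return jnp.zeros_like(x)

try:
  jax.jit(relu_python_if)(jnp.array(1.0))  # x is a tracer here
  raised = False
except jax.errors.ConcretizationTypeError:
  raised = True

# Marking the argument static makes the Python `if` legal (a float is hashable).
y = jax.jit(relu_python_if, static_argnums=0)(2.0)
```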

I understand there is also the other issue you raised, and stylistic preferences that might edge us one way or the other, but let's first be clear as to whether the optimisation of metaparameters itself is solved in my example.

Yes, it is definitely one approach. However, if you had just made the transformation a pytree, you wouldn't need inject_hyperparams at all:

from typing import Iterator, Tuple
import chex
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from tjax.gradient import rmsprop

def generator() -> Iterator[Tuple[chex.Array, chex.Array]]:
    rng = jax.random.PRNGKey(0)

    while True:
        rng, k1, k2 = jax.random.split(rng, num=3)
        x = jax.random.uniform(k1, minval=0.0, maxval=10.0)
        y = 10.0 * x + jax.random.normal(k2)
        yield x, y

g = generator()

def f(theta: chex.Array, x: chex.Array) -> chex.Array:
    return x * theta

theta = jax.random.normal(jax.random.PRNGKey(42))

init_learning_rate = jnp.array(0.1)
meta_learning_rate = jnp.array(0.03)

opt = rmsprop(learning_rate=init_learning_rate)
meta_opt = optax.adam(learning_rate=meta_learning_rate)

def loss(theta, x, y):
    return optax.l2_loss(y, f(theta, x))

def step(opt, theta, state, x, y):
    grad = jax.grad(loss)(theta, x, y)
    updates, state = opt.update(grad, state, theta)
    theta = optax.apply_updates(theta, updates)
    return theta, state

@jax.jit
def outer_loss(eta, theta, state, samples):
    # If opt is a pytree, it can just be reconstructed and passed to step.
    opt = rmsprop(learning_rate=jax.nn.sigmoid(eta))

    for x, y in samples[:-1]:
        theta, state = step(opt, theta, state, x, y)

    x, y = samples[-1]

    return loss(theta, x, y), (theta, state)

@jax.jit
def outer_step(eta, theta, meta_state, state, samples):
    grad, (theta, state) = jax.grad(
        outer_loss, has_aux=True)(eta, theta, state, samples)

    meta_updates, meta_state = meta_opt.update(grad, meta_state)
    eta = optax.apply_updates(eta, meta_updates)

    return eta, theta, meta_state, state

state = opt.init(theta)
# inverse sigmoid, to match the value we initialized the inner optimizer with.
eta = -np.log(1. / init_learning_rate - 1)
meta_state = meta_opt.init(eta)

N = 7
learning_rates = []
thetas = []

for i in range(2000):
    samples = [next(g) for i in range(N)]
    eta, theta, meta_state, state = outer_step(
        eta, theta, meta_state, state, samples)
    learning_rates.append(jax.nn.sigmoid(eta))
    thetas.append(theta)

fig, (ax1, ax2) = plt.subplots(2)
fig.suptitle('Meta-learning RMSProp\'s learning rate')
plt.xlabel('Step')

ax1.semilogy(range(len(learning_rates)), learning_rates)
ax1.set(ylabel='Learning rate')
ax1.label_outer()

plt.xlabel('Number of updates')
ax2.semilogy(range(len(thetas)), thetas)

ax2.label_outer()
ax2.set(ylabel='Theta')

plt.show()

mtthss commented on August 15, 2024

Please refer to the message #197 (comment) for our proposed resolution of the issue going forward.

Waterkin commented on August 15, 2024

Hi, I tried to use the code above, but got some errors:
  1. GradientTransformation has no attribute 'apply'
  2. TypeError: unsupported operand type(s) for *: 'dict' and 'DynamicJaxprTracer' (at scaled_updates = tree_map(lambda t: eta * t, updates))

NeilGirdhar commented on August 15, 2024

@Waterkin It appears your eta is a dict.

Anyway, I think the optax solution is over-complicated. It's a lot easier to simply use the transformations in tjax.gradient, which have all meta-parameters as ordinary dynamic values.
