
Comments (6)

tomhennigan commented on September 25, 2024

I'd suggest looking at method interceptors if you want to monkey patch module methods.

Most people doing this sort of patching only want to change module parameters, if your use case is similar then you may find custom getters are more convenient.
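
To make the interceptor option concrete, here is a minimal sketch (not from the thread; the interceptor name and the bfloat16 cast are illustrative assumptions) of patching the output of Linear.__call__ with hk.intercept_methods:

import haiku as hk
import jax
import jax.numpy as jnp

def cast_linear_outputs(next_fun, args, kwargs, context):
  # Hypothetical patch: run the original method, then cast Linear outputs.
  out = next_fun(*args, **kwargs)
  if isinstance(context.module, hk.Linear) and context.method_name == "__call__":
    out = out.astype(jnp.bfloat16)
  return out

def f(x):
  # The interceptor must be active while the module methods run,
  # i.e. inside the transformed function.
  with hk.intercept_methods(cast_linear_outputs):
    return hk.Linear(4)(x)

f = hk.transform(f)
params = f.init(jax.random.PRNGKey(0), jnp.ones([2, 3]))
out = f.apply(params, None, jnp.ones([2, 3]))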


tomhennigan commented on September 25, 2024

It sounds like what you want is w = jax.lax.stop_gradient(w) (which is basically what you describe: identity forward and zero backwards). If you put that in your custom getter it will cause gradients of those parameters to be zero.
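
A minimal sketch of such a getter (the module-name predicate is just an assumption for the example; anything available on hk.GetterContext could be used instead):

import haiku as hk
import jax

def freeze_non_linear(next_getter, value, context):
  # Hypothetical rule: stop gradients for every parameter whose full name
  # does not contain "linear".
  if "linear" not in context.full_name:
    value = jax.lax.stop_gradient(value)
  return next_getter(value)

def f(x):
  # Getters, like interceptors, must be installed inside the transform.
  with hk.custom_getter(freeze_non_linear):
    return hk.nets.MLP([32, 32, 1])(x)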

Something to watch out for is that in other frameworks (e.g. TF) stop_gradient causes a None to be returned as the gradient, which optimizers then skip. In JAX it causes zeros to be returned. Another way to say this is that other frameworks' AD systems let you tell the difference between "gradient disabled" and "zero gradient"; in JAX you can only do this by looking at the value of the gradient and conditionally updating the parameter and optimizer state based on that.

With some optimizers (e.g. ones with weight decay or momentum) this can cause a non-zero update to be applied to your parameters even when their gradients are zero. Usually this is not what people want when applying stop_gradient to parameters (you want to keep the value of those parameters fixed).

If you want to skip updating some params entirely, I would suggest not doing this with custom getters and stop_gradient, but rather partitioning your parameters into ones you want to update and ones you want to hold fixed:

my_f = hk.transform(my_f)

def my_loss_fn(train_params, non_train_params, ..):
  params = hk.data_structures.merge(train_params, non_train_params)
  out = my_f.apply(params, ..)
  ..
  return loss

grad_my_loss_fn = jax.grad(my_loss_fn)

def is_trainable(module_name, param_name, param_value):
  # Can be whatever you want..
  return 'linear' in module_name

params = my_f.init(..)
train_params, non_train_params = hk.data_structures.partition(is_trainable, params)
opt_state = opt.init(train_params)  # Only get opt_state for trainable params, saves some memory :)

for batch in dataset:
  grads = grad_my_loss_fn(train_params, non_train_params, ..)
  # NOTE: grads will only be defined for `train_params`.

  # Only updating train_params.
  updates, opt_state = opt.update(grads, opt_state, train_params)
  train_params = optax.apply_updates(train_params, updates)

You would probably want to rework the above so you could jit the train step but it would look basically the same. You could even close over the constant parameters (rather than pass them in each time) which would allow XLA to potentially do some constant folding and run your code even faster.
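
A rough sketch of what that jitted step could look like, reusing my_loss_fn, opt, non_train_params, etc. from the snippet above; batch stands in for the arguments elided with `..`:

import jax
import optax

@jax.jit
def train_step(train_params, opt_state, batch):
  # non_train_params is closed over, so jit/XLA treats it as a constant.
  grads = jax.grad(my_loss_fn)(train_params, non_train_params, batch)
  updates, opt_state = opt.update(grads, opt_state, train_params)
  train_params = optax.apply_updates(train_params, updates)
  return train_params, opt_state

for batch in dataset:
  train_params, opt_state = train_step(train_params, opt_state, batch)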


tomhennigan commented on September 25, 2024

That sounds like it would work 😄. Re: map functions, you probably want something like params = jax.tree_map(apply_mask, params).
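
For example, if masks is a pytree of 0/1 arrays with the same structure as params (an assumption here), applying the mask leaf-by-leaf could look like:

import jax

pruned_params = jax.tree_util.tree_map(lambda p, m: p * m, params, masks)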

We also have a pruning example which implements https://arxiv.org/abs/1710.01878. I suspect this could be a useful reference for you.


mil-ad commented on September 25, 2024

Thanks, custom getters are what I want! 👍

Are there equivalents of these methods for gradients in JAX/Haiku? Torch has backward hooks in addition to forward hooks (which are similar to interceptors and getters), but I'm not even sure a backward hook makes sense in a non-tape-style autograd engine. I suppose I could use the custom_jvp/custom_vjp decorators, but they seem too complicated for something as simple as setting a bunch of incoming gradients to zero.


mil-ad commented on September 25, 2024

Very eye-opening tips, Tom! What I'm actually trying to do is apply a pruning mask (which probably has the same form as the FlatMapping that init returns) to, let's say, an MLP. For that to work I'd need to (1) set the corresponding parameters to zero and (2) make sure they remain zero by stopping their gradients.

The problem is that the granularity of this is finer than the tensor level, i.e. I want to prune (disable) individual indices within each weight matrix. I think (1) is doable by just modifying the params using some map method (that I haven't found yet) before injecting them with apply. (2) would be easy if all I wanted to do was run forward/backward on this pruned model: I could do the same thing I did to params to what grad() returns before giving it to the optimiser. However, I'm trying to compute some higher-order gradient thing and I want to ensure the corresponding gradients remain zero during that process. stop_gradient is almost what I want, but it works at the whole-tensor level.

I know I have digressed from the original question so feel free to close this issue, but any other tips are very welcome.


mil-ad commented on September 25, 2024

Actually, I was just staring at the gradient clipping example in JAX's documentation; shouldn't something like the below work?

from functools import partial

import jax.numpy as jnp
from jax import custom_vjp


@partial(custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
    return x  # identity function


def clip_gradient_fwd(lo, hi, x):
    return x, None  # no residual values to save


def clip_gradient_bwd(lo, hi, _, g):
    return (jnp.clip(g, lo, hi),)


clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)


def forward(params, x, y):
    logits = net.apply(params, x)
    loss = cross_entropy(logits, y)
    return loss

def custom_forward(params, x, y):
    new_params = {}
    for module_name, param_dict in params.items():
        new_params[module_name] = {
            k: clip_gradient(-0.5, 0.5, v) for k, v in param_dict.items()
        }

    return forward(new_params, x, y)

I mean, obviously I'd have to modify clip_gradient_bwd to actually set some of the gradients to zero instead of clipping them, but isn't this what I'd want?
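
For what it's worth, a sketch of that zeroing variant (mask_gradient is a made-up name; the mask is assumed to be a float 0/1 array broadcastable against x):

import jax.numpy as jnp
from jax import custom_vjp


@custom_vjp
def mask_gradient(x, mask):
    return x  # identity on the forward pass


def mask_gradient_fwd(x, mask):
    return x, mask  # save the mask for the backward pass


def mask_gradient_bwd(mask, g):
    # Zero the incoming gradient at pruned positions; the mask itself
    # gets a zero cotangent since we never differentiate w.r.t. it.
    return (g * mask, jnp.zeros_like(mask))


mask_gradient.defvjp(mask_gradient_fwd, mask_gradient_bwd)

# Usage sketch: assuming `masks` mirrors the structure of `params`,
# masked = jax.tree_util.tree_map(mask_gradient, params, masks)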

