Comments (6)
I'd suggest looking at method interceptors if you want to monkey-patch module methods.
Most people doing this sort of patching only want to change module parameters; if your use case is similar, you may find custom getters more convenient.
from dm-haiku.
It sounds like what you want is `w = jax.lax.stop_gradient(w)` (which is basically what you describe: identity forward and zero backward). If you put that in your custom getter, it will cause the gradients of those parameters to be zero.
Something to watch out for: in other frameworks (e.g. TF), `stop_gradient` causes a `None` to be returned as the gradient, which optimizers then skip. In JAX it causes zeros to be returned. Another way to say this is that other frameworks' AD systems let you tell the difference between "gradient disabled" and "zero gradient"; in JAX you can only do this by looking at the value of the gradient and conditionally updating the parameter and optimizer state based on it.
With some optimizers this can cause a non-zero update to be applied to your parameters even when their gradients are zero (e.g. via momentum or weight decay). Usually this is not what people want when applying stop_gradient to parameters (you want to keep the value of those parameters fixed).
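To see why, here is a toy momentum-style update (plain Python, illustrative only, not any particular optax optimizer): once the momentum buffer is non-zero, a zero gradient still produces a non-zero parameter update.

```python
def momentum_update(grad, buf, lr=0.1, decay=0.9):
    # Classic momentum: the buffer remembers past gradients.
    buf = decay * buf + grad
    return -lr * buf, buf

update, buf = momentum_update(1.0, 0.0)   # step with a real gradient
update, buf = momentum_update(0.0, buf)   # zero gradient (as from stop_gradient)
# update is approximately -0.09, not 0: the parameter still moves
```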
If you want to skip updating some params entirely, I would suggest not doing this with custom getters and stop_gradient, but rather partitioning your parameters into ones you want to update and ones you want to hold fixed:
```python
my_f = hk.transform(my_f)

def my_loss_fn(train_params, non_train_params, ..):
  # Recombine the two collections before calling apply.
  params = hk.data_structures.merge(train_params, non_train_params)
  out = my_f.apply(params, ..)
  ..
  return loss

# Differentiates with respect to the first argument (train_params) only.
grad_my_loss_fn = jax.grad(my_loss_fn)

def is_trainable(module_name, param_name, param_value):
  # Can be whatever you want..
  return 'linear' in module_name

params = my_f.init(..)
train_params, non_train_params = hk.data_structures.partition(is_trainable, params)

# Only get opt_state for trainable params, saves some memory :)
opt_state = opt.init(train_params)

for batch in dataset:
  # NOTE: grads will only be defined for `train_params`.
  grads = grad_my_loss_fn(train_params, non_train_params, ..)

  # Only update train_params.
  updates, opt_state = opt.update(grads, opt_state, train_params)
  train_params = optax.apply_updates(train_params, updates)
```
You would probably want to rework the above so you could `jit` the train step, but it would look basically the same. You could even close over the constant parameters (rather than pass them in each time), which would allow XLA to potentially do some constant folding and run your code even faster.
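For intuition, `hk.data_structures.partition` behaves roughly like this plain-Python sketch over a two-level params dict (illustrative only; the real version operates on Haiku's own data structures):

```python
def partition(predicate, params):
    # Split a {module_name: {param_name: value}} dict by a predicate.
    first, second = {}, {}
    for module_name, bundle in params.items():
        for param_name, value in bundle.items():
            target = first if predicate(module_name, param_name, value) else second
            target.setdefault(module_name, {})[param_name] = value
    return first, second

params = {'linear': {'w': 1.0}, 'conv': {'w': 2.0}}
train, fixed = partition(lambda m, p, v: 'linear' in m, params)
# train == {'linear': {'w': 1.0}}, fixed == {'conv': {'w': 2.0}}
```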
from dm-haiku.
That sounds like it would work 😄. Re map functions, you probably want something like `params = jax.tree_map(apply_mask, params)`.
We also have a pruning example which implements https://arxiv.org/abs/1710.01878. I suspect this could be a useful reference for you.
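For example, masking every parameter with a matching tree of masks (a small sketch; the parameter shapes and mask values here are made up):

```python
import jax
import jax.numpy as jnp

params = {'mlp': {'w': jnp.ones((2, 2)), 'b': jnp.ones((2,))}}
masks  = {'mlp': {'w': jnp.array([[1., 0.], [0., 1.]]), 'b': jnp.ones((2,))}}

# tree_map over two trees with the same structure applies the mask leaf-wise.
pruned = jax.tree_util.tree_map(lambda p, m: p * m, params, masks)
```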
from dm-haiku.
Thanks, custom getters are what I want! 👍
Are there equivalents of these methods for gradients in JAX/Haiku? Torch has backward hooks in addition to forward hooks (which are similar to interceptors and getters), but I'm not even sure a backward hook makes sense in a non-tape-style autograd engine. I suppose I could use the `custom_jvp`/`custom_vjp` decorators, but they seem too complicated for something simple like setting a bunch of incoming gradients to zero.
from dm-haiku.
Very eye-opening tips Tom! What I'm actually trying to do is apply a pruning mask (which probably has the same form as the `FlatMapping` that `init` returns) to, let's say, an MLP. For that to work I'd need to (1) set the corresponding parameters to zero and (2) make sure they remain zero by stopping their gradients.
The problem is that I need this at the element level, i.e. I want to prune (disable) individual indices within each weight matrix. I think (1) is doable by modifying the params using some `map` method (that I haven't found yet) before injecting them with `apply`. Doing (2) would be easy if all I wanted was a forward/backward pass on this pruned model: I could apply the same mask I applied to `params` to what `grad()` returns before giving it to the optimiser. However, I'm trying to compute some higher-order gradient thing, and I want to ensure the corresponding gradients remain zero throughout that process. `stop_gradient` is almost what I want, but it works at the whole-tensor level.
I know I have digressed from the original question, so feel free to close this issue, but any other tips are very welcome.
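One way to get (1) and (2) together at the element level, assuming a 0/1 mask of the same shape as the parameter: multiply the parameter by the mask before use. Pruned entries are zero in the forward pass, and because the mask also scales the cotangent, their gradients are zero as well, even under higher-order differentiation. A minimal sketch:

```python
import jax
import jax.numpy as jnp

mask = jnp.array([1.0, 0.0, 1.0])  # 0 = pruned (made-up mask)

def loss(w):
  w = w * mask              # pruned entries are zero in the forward pass...
  return jnp.sum(w ** 2)

g = jax.grad(loss)(jnp.ones(3))   # ...and their gradients are zero too
# Second-order: gradients of pruned entries stay zero.
h = jax.grad(lambda w: jnp.sum(jax.grad(loss)(w)))(jnp.ones(3))
```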
from dm-haiku.
Actually, I was just staring at the gradient-clipping example in JAX's documentation; shouldn't something like the below work?
```python
from functools import partial

import jax.numpy as jnp
from jax import custom_vjp

@partial(custom_vjp, nondiff_argnums=(0, 1))
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, None  # no residual values to save

def clip_gradient_bwd(lo, hi, _, g):
  return (jnp.clip(g, lo, hi),)

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

def forward(params, x, y):
  logits = net.apply(params, x)
  loss = cross_entropy(logits, y)
  return loss

def custom_forward(params, x, y):
  new_params = {}
  for module_name, param_dict in params.items():
    new_params[module_name] = {
      k: clip_gradient(-0.5, 0.5, v) for k, v in param_dict.items()
    }
  return forward(new_params, x, y)
```
I mean, obviously I'd have to modify `clip_gradient_bwd` to actually set some of the gradients to zero instead of clipping them, but isn't this what I'd want?
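A masked variant could look like this sketch (the mask is passed as a regular argument, with a zero cotangent returned for it; the names are made up):

```python
import jax
import jax.numpy as jnp
from jax import custom_vjp

@custom_vjp
def mask_gradient(x, mask):
  return x  # identity forward: values are untouched

def mask_gradient_fwd(x, mask):
  return x, mask  # save the mask as a residual for the backward pass

def mask_gradient_bwd(mask, g):
  # Zero the incoming gradient wherever the mask is zero;
  # the mask itself gets a zero cotangent.
  return (g * mask, jnp.zeros_like(mask))

mask_gradient.defvjp(mask_gradient_fwd, mask_gradient_bwd)

mask = jnp.array([1.0, 0.0, 1.0])
g = jax.grad(lambda x: jnp.sum(mask_gradient(x, mask) ** 2))(jnp.ones(3))
# gradients at masked positions are zero
```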
from dm-haiku.