Comments (3)
Hi! I wrote the current Haiku implementation. I'll answer the questions in reverse order :)
> How do I call `jax.vjp` on a module?
If you look at https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/stateful.py you can see how we define a bunch of wrappers around JAX functions to work with Haiku. There's a lot of code, but the idea is simple: temporarily grab the global state and thread it through the JAX function inputs, then make it global state again within the function (and reverse when returning). We don't have a wrapper around `vjp` right now (we have one for `grad`), but it shouldn't be too hard to do.
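Until such a wrapper exists, one workaround is to take the vjp of the pure `apply` function returned by `hk.transform`, closing over the params. This is a minimal sketch under assumptions (the `forward` network, shapes, and seed below are placeholders, not anything from Haiku's API), and it only works *outside* the transform; a stateful.py-style wrapper is what would let you call `jax.vjp` inside a transformed function:

```python
import jax
import jax.numpy as jnp
import haiku as hk

def forward(x):
  return hk.Linear(8)(x)

forward_t = hk.transform(forward)
x = jnp.ones((1, 16))
params = forward_t.init(jax.random.PRNGKey(42), x)

# vjp of the network output w.r.t. the input, with params held fixed.
y, f_vjp = jax.vjp(lambda x: forward_t.apply(params, None, x), x)
x_bar, = f_vjp(jnp.ones_like(y))
```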
> How can I implement this in Haiku?
I definitely think it would be good to have this function somewhere, but I'm a little hesitant to suggest where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or parameters (unlike hk.SpectralNorm, which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, and have examples of how you could use it with a Haiku module (along with Flax/etc. modules, since presumably they'd all work)?
> As a second question, has this approach been used before / do you know how well it works on GPUs/TPUs?

The approximation we use is the one used by SNGAN (https://arxiv.org/pdf/1802.05957.pdf) and BigGAN, and we know it remains quite stable on accelerators. I'd be curious whether you have run any experiments checking this exact approach's numerics.
> I definitely think it would be good to have this function somewhere, but I'm a little hesitant to suggest where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or parameters (unlike hk.SpectralNorm, which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, and have examples of how you could use it with a Haiku module (along with Flax/etc. modules, since presumably they'd all work)?
For use in neural network training, I think you would still want to estimate the vector corresponding to the largest singular value in an online fashion.
Here's a clearer way to separate the logic:
```python
import jax
import jax.numpy as jnp
from jax import lax


def _l2_normalize(x, eps=1e-4):
  return x * jax.lax.rsqrt((x ** 2).sum() + eps)


def _l2_norm(x):
  return jnp.sqrt((x ** 2).sum())


def _power_iteration(A, u, n_steps=10):
  """Update an estimate of the first right-singular vector of A()."""
  def fun(u, _):
    v, A_transpose = jax.vjp(A, u)  # v = A(u); A_transpose applies A^T
    u, = A_transpose(v)             # u <- A^T A u
    u = _l2_normalize(u)
    return u, None
  u, _ = lax.scan(fun, u, xs=None, length=n_steps)
  return u


def estimate_spectral_norm(f, x, seed=0, n_steps=10):
  """Estimate the spectral norm of f(x) linearized at x."""
  rng = jax.random.PRNGKey(seed)
  u0 = jax.random.normal(rng, x.shape)
  _, f_jvp = jax.linearize(f, x)
  u = _power_iteration(f_jvp, u0, n_steps)
  sigma = _l2_norm(f_jvp(u))
  return sigma
```
I can imagine `estimate_spectral_norm` being a separately useful utility, but in a spectral normalization layer, you'd want to save the vector `u0` as state on the layer and only use a handful of power iterations in each neural net evaluation.
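As a rough illustration of that design (a sketch under assumptions, not hk.SpectralNorm and not an official API: the module name, the one-step default, and the use of `hk.get_state`/`hk.set_state` to persist `u` are all mine):

```python
import jax
import jax.numpy as jnp
import haiku as hk


class SpectralNormEstimate(hk.Module):
  """Hypothetical layer keeping an online estimate of the top right-singular vector."""

  def __init__(self, n_steps=1, eps=1e-4, name=None):
    super().__init__(name=name)
    self.n_steps = n_steps
    self.eps = eps

  def __call__(self, f, x):
    # Persist the singular-vector estimate across calls as module state.
    u = hk.get_state("u", shape=x.shape, dtype=x.dtype,
                     init=hk.initializers.RandomNormal())
    _, f_jvp = jax.linearize(f, x)
    for _ in range(self.n_steps):
      v, f_vjp = jax.vjp(f_jvp, u)  # v = A(u); f_vjp applies A^T
      u, = f_vjp(v)                 # u <- A^T A u
      u = u * jax.lax.rsqrt((u ** 2).sum() + self.eps)
    # Don't backpropagate through the state update.
    hk.set_state("u", jax.lax.stop_gradient(u))
    return jnp.sqrt((f_jvp(u) ** 2).sum())  # sigma ~= ||A u||
```

This would live inside `hk.transform_with_state`, so the running `u` is carried in the state dict between evaluations.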
> As a second question, has this approach been used before / do you know how well it works on GPUs/TPUs?
The same approach (but written in a much more awkward/manual way) was used in this ICLR 2019 paper. Numerically, they should be identical. If you're using fully-connected layers, the calculation is exactly the same as the older method, just using autodiff instead of explicit matrix/vector products.
From a fundamental perspective I would guess this is quite efficient and numerically stable on accelerators, because the operations it uses are exactly the same as those used at the core of neural net training:
- forward evaluation of linear layers (e.g., "convolution")
- gradient evaluation of linear layers (e.g., "convolution transpose")
The cost of doing a single power iteration is thus roughly equivalent to that of pushing a single additional example through the neural net.
(The version I wrote in this comment is slightly simpler than the version in the ICLR 2019 paper, because it uses `norm(A(u))` rather than `v @ A(u)` to calculate the singular value and only normalizes once per iteration, but I doubt those choices make much of a difference, and they are not hard to change.)
Hi @Cogitans, I'm trying to add spectral normalization to Flax and am modeling it after the Haiku version. I have some questions:
- How is this used in a typical training loop? Are the params spectrally normalized after the gradient update using `SNParamsTree` (as seen on page 6, Algorithm 1, line 5 of the original paper)? If so, why not just create a helper function that does the spectral normalization and then use `jax.tree_map` to spectrally normalize the params, e.g. `params = jax.tree_map(lambda x: spectral_normalize(x), params)` (see the sketch after this list)?
- I'd like to understand the power iteration method used; it's different from what I've read in the Wikipedia article, and it seems not to converge to the right eigenvalue even after many iteration steps for some matrices.
- Why is `lax.stop_gradient` used?
- Why do we need to keep track of the state `u0` and `sigma`?
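For reference, here is a minimal sketch of the helper-function idea from the first question (the `spectral_normalize` name, the fixed iteration count, and normalizing only 2D leaves are assumptions for illustration, not how `SNParamsTree` works):

```python
import jax
import jax.numpy as jnp

def spectral_normalize(w, n_steps=10, eps=1e-12):
  """Divide a 2D weight matrix by a power-iteration estimate of its spectral norm."""
  if w.ndim != 2:
    return w  # leave biases and other non-matrix leaves untouched
  u = jnp.ones((w.shape[0],))
  for _ in range(n_steps):
    v = w.T @ u
    v = v / (jnp.linalg.norm(v) + eps)
    u = w @ v
    u = u / (jnp.linalg.norm(u) + eps)
  sigma = u @ w @ v  # estimate of the largest singular value
  return w / sigma

# Applied to the whole parameter tree after each gradient update:
# params = jax.tree_map(spectral_normalize, params)
```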
Related Issues (20)
- Reservoir Computing with Haiku
- Efficiency difference in using jax.lax.fori_loop vs looping over identical layers? HOT 2
- Please publish requirements.txt fix to pip
- How to use `apply` with additional parameters? HOT 1
- hk.Conv2DTranspose takes FOREVER to initialize and compile HOT 1
- 0.4.16 timeline HOT 2
- How to export haiku network parameters into Pytorch network?
- Modules got silently "reused" with `hk.vmap` HOT 2
- Wrong gradients in a Haiku network
- Direct Feedback Alignment
- Issue with wheels including docs and examples folder
- `haiku.experimental.flax` is not part of newest pip release HOT 1
- Train multiple hk.nets.MLP with one optimizer HOT 2
- TypeError: 'type' object is not subscriptable HOT 4
- Wrapping the ```init``` function inside ```jax.jit``` HOT 1
- Consider make flax an optional dependency HOT 1
- hk.switch does not work inside a hk.vmap function when hk.set_state is used HOT 1
- hk.BatchNorm with jax.vmap
- Integrating vmap with BatchNorm