
Comments (3)

Cogitans commented on September 25, 2024

Hi! I wrote the current Haiku implementation. I'll answer the questions in reverse order :)

How do I call jax.vjp on module?

If you look at https://github.com/deepmind/dm-haiku/blob/master/haiku/_src/stateful.py you can see how we define a bunch of wrappers around Jax functions to make them work with Haiku. There's a lot of code, but the idea is simple: temporarily grab the global state and thread it through the Jax function's inputs, then make it global state again within the function (and reverse the process when returning). We don't have a wrapper around vjp right now (we have one for grad), but it shouldn't be too hard to add one.
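
A minimal sketch of that pattern, with a plain dict standing in for Haiku's real internal state (the names and structure here are illustrative, not Haiku's actual internals):

import jax

# Stand-in for Haiku's internal global state (illustrative only).
_state = {}

def wrap_grad(f):
  """Sketch: thread the global state through jax.grad as an explicit input/output."""
  def pure(x, state):
    global _state
    prev, _state = _state, state   # re-install the threaded state as "global"
    out = f(x)
    state, _state = _state, prev   # collect any updates, restore the outer state
    return out, state
  def wrapped(x):
    global _state
    grads, _state = jax.grad(pure, has_aux=True)(x, _state)
    return grads
  return wrapped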

how can I implement this in Haiku?

I definitely think it would be good to have this function somewhere, but I'm a little hesitant to suggest where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or parameters (unlike hk.SpectralNorm, which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, with examples of how you could use it with a Haiku module (along with Flax/etc. modules, since presumably they'd all work)?

As a second question, has this approach been used before, and do you know how well it works on GPUs/TPUs? The approximation we use is the one used by SNGAN (https://arxiv.org/pdf/1802.05957.pdf) and BigGAN, and we know it remains quite stable on accelerators; I'd be curious whether you have run any experiments checking the numerics of your exact approach.


shoyer commented on September 25, 2024

I definitely think it would be good to have this function somewhere, but I'm a little hesitant to suggest where. If I am imagining your implementation correctly, it doesn't actually require any Haiku state or parameters (unlike hk.SpectralNorm, which uses state to store a running estimate of the spectral values). Would it be better to have it be a pure function elsewhere, with examples of how you could use it with a Haiku module (along with Flax/etc. modules, since presumably they'd all work)?

For use in neural network training, I think you would still want to estimate the vector corresponding to the largest singular value in an online fashion.

Here's a clearer way to separate the logic:

import jax
import jax.numpy as jnp
from jax import lax

def _l2_normalize(x, eps=1e-4):
  return x * jax.lax.rsqrt((x ** 2).sum() + eps)

def _l2_norm(x):
  return jnp.sqrt((x ** 2).sum())

def _power_iteration(A, u, n_steps=10):
  """Update an estimate of the first right-singular vector of A()."""
  def fun(u, _):
    v, A_transpose = jax.vjp(A, u)  # v = A(u), plus the transpose map
    u, = A_transpose(v)             # u <- A^T(A(u))
    u = _l2_normalize(u)
    return u, None
  u, _ = lax.scan(fun, u, xs=None, length=n_steps)
  return u

def estimate_spectral_norm(f, x, seed=0, n_steps=10):
  """Estimate the spectral norm of f(x) linearized at x."""
  rng = jax.random.PRNGKey(seed)
  u0 = jax.random.normal(rng, x.shape)
  _, f_jvp = jax.linearize(f, x)
  u = _power_iteration(f_jvp, u0, n_steps)
  sigma = _l2_norm(f_jvp(u))
  return sigma
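
For example, on a plain linear map the estimate should closely match the largest singular value from jnp.linalg.svd (a quick sanity check, with arbitrary shapes):

# Sanity check: the spectral norm of x -> A @ x is the largest
# singular value of A.
A = jax.random.normal(jax.random.PRNGKey(1), (20, 10))
approx = estimate_spectral_norm(lambda x: A @ x, jnp.zeros(10), n_steps=20)
exact = jnp.linalg.svd(A, compute_uv=False)[0]
# approx and exact should agree to a few decimal places.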

I can imagine estimate_spectral_norm being a separately useful utility, but in a spectral normalization layer, you'd want to save the vector u0 as state on the layer and only use a handful of power iterations in each neural net evaluation.
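
A rough sketch of what that could look like as a Haiku module (illustrative only, building on the helpers above; hk.SpectralNorm remains the real implementation for weight matrices):

import haiku as hk

class SpectralNormEstimate(hk.Module):
  """Sketch: keep the power-iteration vector as Haiku state."""
  def __call__(self, f, x, n_steps=1):
    _, f_jvp = jax.linearize(f, x)
    u = hk.get_state("u", shape=x.shape, dtype=x.dtype,
                     init=hk.initializers.RandomNormal())
    u = _power_iteration(f_jvp, u, n_steps)  # a handful of steps per call
    hk.set_state("u", u)                     # carry u over to the next call
    return _l2_norm(f_jvp(u))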

As a second question, has this approach been used before/do you know how well it works on GPUs/TPUs?

The same approach (but written in a much more awkward/manual way) was used in this ICLR 2019 paper. Numerically, they should be identical. If you're using fully-connected layers, the calculation is exactly the same as the older method, just using autodiff instead of explicit matrix/vector products.
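
Concretely, for a dense layer f(x) = W @ x the vjp is just multiplication by W.T, so one step of _power_iteration above reduces to the familiar explicit update:

W = jax.random.normal(jax.random.PRNGKey(2), (20, 10))
u = _l2_normalize(jax.random.normal(jax.random.PRNGKey(3), (10,)))
# One autodiff step, as inside _power_iteration:
v, W_transpose = jax.vjp(lambda x: W @ x, u)
u_autodiff = _l2_normalize(W_transpose(v)[0])
# The explicit matrix/vector version of the same step:
u_explicit = _l2_normalize(W.T @ (W @ u))
# u_autodiff and u_explicit match up to floating point.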

From a fundamental perspective, I would guess this is quite efficient and numerically stable on accelerators, because the operations it uses are exactly the same as those at the core of neural net training:

  • forward evaluation of linear layers (e.g., "convolution")
  • gradient evaluation of linear layers (e.g., "convolution transpose")

The cost of doing a single power iteration is thus roughly equivalent to that of pushing a single additional example through the neural net.

(The version I wrote in this comment is slightly simpler than the version in the ICLR 2019 paper, because it uses norm(A(u)) rather than v @ A(u) to calculate the singular value and only normalizes once per iteration, but I doubt those choices make much of a difference, and they are not hard to change.)
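
In code, the difference is just the final reduction (a sketch; here v is recomputed fresh, whereas the paper carries v as state from the previous iteration):

v = f_jvp(u)                      # u from the power iteration above
sigma_norm = _l2_norm(v)          # this version: norm(A(u))
sigma_dot = _l2_normalize(v) @ v  # paper-style: v_hat @ A(u)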


chiamp commented on September 25, 2024

hi @Cogitans, I'm trying to add spectral normalization to Flax and am modeling it after the Haiku version. I have some questions:

  • How is this used in a typical training loop? Are the params spectral-normalized after the gradient update using SNParamsTree (as seen on page 6, Algorithm 1, line 5 of the original paper)? If so, why not just create a helper function that does the spectral normalization and then use jax.tree_map to spectral-normalize the params, e.g. params = jax.tree_map(lambda x: spectral_normalize(x), params)?
  • I'd like to understand the power iteration method used; it's different from what I've read in the Wikipedia article, and for some matrices it seems not to converge to the correct eigenvalue even after many iteration steps.
  • Why is lax.stop_gradient used?
  • Why do we need to keep track of the state u0 and sigma?

