Comments (12)

wdphy16 commented on August 15, 2024

Hi @rosshemsley, sorry for the late reply. I've written the proposal document and edited the first comment of this issue.

rosshemsley commented on August 15, 2024

Hey @wdphy16, just to let you know - we linked your design doc into the documentation (https://optax.readthedocs.io/en/latest/design_docs.html) and added you as a contributor to https://optax.readthedocs.io/en/latest/contributors.html.

We appreciate the work you have put into supporting optax! 🎉

rosshemsley commented on August 15, 2024

Hello @wdphy16, thanks for your comment, this initiative is interesting and welcome!

One consideration is that optax is heavily used by models that are 'internal', and so any changes that we pull upstream have to be carefully managed to avoid performance penalties, semantic changes, and drastic API changes that could cause problems for existing optax users.

We do believe that these ideas can be integrated whilst managing this complexity though, and think this could be an important set of features for optax. Perhaps a good way to start would be to sketch out a proposal document suggesting a concrete approach - effectively a more detailed version of your comment, where we can add suggestions. That way we can help to avoid some of the pitfalls mentioned above up front :)

wdphy16 commented on August 15, 2024

Hi @rosshemsley, thanks for your feedback! Should I write the proposal document here, in a PR, or in Discussions?

rosshemsley commented on August 15, 2024

Hey @wdphy16: perhaps an issue thread, or a github gist?

Any approach that is publicly readable and commentable would be fine.

One example of a proposal with public comments is here: golang/go#43651 (consider this just an example; it is probably far more detailed than would be needed in this case).

mtthss commented on August 15, 2024

Thanks @wdphy16 for sharing such a thoughtful and detailed proposal.

I think we should certainly implement the split real norm approach described in your doc.
It is an enhancement of the current capabilities of optax with essentially no downsides,
since those who are not interested in complex numbers can simply not use the split_complex wrapper.

Could you put together a PR for this?
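For anyone following along, the idea behind such a wrapper can be sketched in a few lines. This is hypothetical code, not the actual optax implementation: complex leaves are turned into real arrays before the wrapped transformation runs and are recombined afterwards, so existing real-valued optimizers work unchanged.

import jax
import jax.numpy as jnp
import optax

def split_complex(inner: optax.GradientTransformation) -> optax.GradientTransformation:
    # Make `inner` see only real arrays by splitting complex leaves.

    def _split(x):
        # A complex leaf becomes a stacked (real, imag) array; real leaves pass through.
        return jnp.stack([x.real, x.imag]) if jnp.iscomplexobj(x) else x

    def _merge(x, ref):
        # Recombine only the leaves that were originally complex.
        return x[0] + 1j * x[1] if jnp.iscomplexobj(ref) else x

    def init_fn(params):
        return inner.init(jax.tree_util.tree_map(_split, params))

    def update_fn(updates, state, params=None):
        split_updates = jax.tree_util.tree_map(_split, updates)
        split_params = None if params is None else jax.tree_util.tree_map(_split, params)
        new_updates, new_state = inner.update(split_updates, state, split_params)
        return jax.tree_util.tree_map(_merge, new_updates, updates), new_state

    return optax.GradientTransformation(init_fn, update_fn)

Wrapping e.g. optax.scale_by_adam() this way lets the moments of the real and imaginary parts evolve independently, which is exactly the split real norm behaviour.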

As for the other proposal, I think it would be nice to support that as well,
but we need to think carefully about performance and readability;
for instance, it would be good to verify exactly what XLA is and isn't able to do.

E.g. the point from your doc:

When g is real, it should be possible for jax.jit to eliminate the dispatch overhead of conj and real, and optimize (g * g) ** (order / 2) into g ** order (TODO: I'm not sure if currently jax.jit is smart enough to do this)

wdphy16 commented on August 15, 2024

It turns out that jax.jit currently cannot optimize a multiplication followed by a power into a single power. I did some quick experiments with a make_hlo debugging helper; the helper and the resulting HLO dumps are below:
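(make_hlo is not part of the public JAX API; a helper along these lines, sketched here against the legacy jax.xla_computation interface that was current at the time, can produce the unoptimized and optimized dumps shown in the transcript. Treat it as an assumption rather than the exact code used.)

import jax
from jax.lib import xla_bridge

def make_hlo(f, optimize=False):
    # Returns a function that traces `f` on the given arguments and returns
    # its HLO text, optionally after running the XLA compiler passes.
    def wrapped(*args, **kwargs):
        c = jax.xla_computation(f)(*args, **kwargs)
        if not optimize:
            return c.as_hlo_text()
        compiled = xla_bridge.get_backend().compile(c)
        return compiled.hlo_modules()[0].to_string()
    return wrapped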

In [1]: def norm(g, order):
   ...:     return (g.conj() * g).real ** (order / 2)
   ...:

In [2]: g = jnp.array([1., 2., 3.])

In [3]: print(make_hlo(norm)(g, 2))
HloModule xla_computation_norm.11

ENTRY xla_computation_norm.11 {
  constant.3 = pred[] constant(false)
  parameter.1 = f32[3]{0} parameter(0)
  multiply.4 = f32[3]{0} multiply(parameter.1, parameter.1)
  parameter.2 = s32[] parameter(1)
  convert.5 = f32[] convert(parameter.2)
  constant.6 = f32[] constant(2)
  divide.7 = f32[] divide(convert.5, constant.6)
  broadcast.8 = f32[3]{0} broadcast(divide.7), dimensions={}
  power.9 = f32[3]{0} power(multiply.4, broadcast.8)
  ROOT tuple.10 = (f32[3]{0}) tuple(power.9)
}

In [4]: print(make_hlo(norm, optimize=True)(g, 2))
HloModule xla_computation_norm__1.11

fused_computation {
  param_1.4 = f32[3]{0} parameter(1)
  multiply.2 = f32[3]{0} multiply(param_1.4, param_1.4)
  param_0.1 = s32[] parameter(0)
  convert.0 = f32[] convert(param_0.1)
  constant.1 = f32[] constant(0.5)
  multiply.1 = f32[] multiply(convert.0, constant.1)
  broadcast.0 = f32[3]{0} broadcast(multiply.1), dimensions={}
  ROOT power.0 = f32[3]{0} power(multiply.2, broadcast.0)
}

ENTRY xla_computation_norm__1.11 {
  parameter.2 = s32[] parameter(1)
  parameter.1 = f32[3]{0} parameter(0)
  fusion = f32[3]{0} fusion(parameter.2, parameter.1), kind=kLoop, calls=fused_computation
  ROOT tuple.10 = (f32[3]{0}) tuple(fusion)
}

We can see that in the optimized HLO, conj and real are indeed eliminated for the real g, but the multiplication and the power remain. In fact, such an optimization does not exactly preserve floating-point results, so it could only be enabled behind an unsafe flag like 'fast math', and I don't know whether the JAX developers have any plans for that.

Anyway, we can still use an if-else statement to eliminate that overhead when g is real. I've updated the proposal.
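Concretely, since the dtype is known at trace time, a plain Python if-else is enough; for real g nothing related to conj or real ends up in the traced graph. A minimal sketch (the name orderth_norm is illustrative, not necessarily what the proposal or optax uses):

import jax.numpy as jnp

def orderth_norm(g, order):
    # The branch is resolved while tracing, based on the dtype of g, so the
    # real case traces to a single power with no conj/real in the graph.
    if jnp.iscomplexobj(g):
        return (g.conj() * g).real ** (order / 2)
    return g ** order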

PhilipVinc commented on August 15, 2024

@wdphy16 can you update the PR to use the if-else statement and fix ADAM? I think this should be good for @mtthss

wdphy16 commented on August 15, 2024

PR #241 is only for split_norm. To implement the complex norm in the optimizers, I need to take some time to review everything (as best I can) in transforms.py and alias.py; I'll finish that today or tomorrow.
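For reference, the core of that change is small: wherever an optimizer accumulates an element-wise gradient norm (e.g. Adam's second moment), the real-only g ** order is replaced by the complex-aware magnitude from the sketch above. A hedged sketch of such a moment update, with illustrative names rather than the final optax code:

import jax
import jax.numpy as jnp

def update_moment_per_elem_norm(updates, moments, decay, order):
    # Exponential moving average of the element-wise order-th norm of the
    # (possibly complex) gradients; real gradients take the cheap branch.
    def _norm(g):
        if jnp.iscomplexobj(g):
            return (g.conj() * g).real ** (order / 2)
        return g ** order
    return jax.tree_util.tree_map(
        lambda g, t: (1 - decay) * _norm(g) + decay * t, updates, moments)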

mkunesch commented on August 15, 2024

Hi all!

Do I understand correctly that with split_complex merged (thanks a lot @wdphy16!!) the "Split real norm" section of the proposal is fully implemented, or is there anything still outstanding?

If so, would the next step be to discuss whether to add the complex norm in addition to the split real norm?

Thanks a lot!

wdphy16 commented on August 15, 2024

Yes, the split real norm is fully implemented. I've also implemented the complex norm for some relatively simple optimizers and I'll make another PR.

mkunesch commented on August 15, 2024

Ah that's great, thanks a lot!
