Comments (12)
Hi @rosshemsley, sorry for the late reply. I've written the proposal document and edited the first comment of this issue.
from optax.
Hey @wdphy16, just to let you know - we linked your design doc into the documentation (https://optax.readthedocs.io/en/latest/design_docs.html) and added you as a contributor to https://optax.readthedocs.io/en/latest/contributors.html.
We appreciate the work you have put into supporting optax! 🎉
Hello @wdphy16, thanks for your comment, this initiative is interesting and welcome!
One consideration is that optax is heavily used by models that are 'internal', and so any changes that we pull upstream have to be carefully managed to avoid performance penalties, semantics changes, and drastic API changes that could cause problems to existing optax users.
We do believe that these ideas can be integrated whilst managing this complexity though, and think this could be an important set of features for optax. Perhaps a good way to start would be to sketch out a proposal document suggesting a concrete approach, effectively a more detailed version of your comment, to which we can add suggestions. That way we can help you avoid some of the pitfalls mentioned above up front :)
Hi @rosshemsley, thanks for your feedback! Should I write the proposal document here, in a PR, or in Discussions?
Hey @wdphy16: perhaps an issue thread, or a GitHub gist? Any approach that is publicly readable and commentable would be fine.
One example of a proposal with public comments is golang/go#43651 (consider it just an example; it is probably far more detailed than would be needed in this case).
Thanks @wdphy16 for sharing such a thoughtful and detailed proposal.
I think we should certainly implement the split real norm approach described in your doc.
It is an enhancement of the current capabilities of optax with literally no downsides, since those who are not interested in complex numbers can simply continue not to use the `split_complex` wrapper.
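To make the discussion concrete, here is a minimal sketch of the split-real-norm idea being described. This is illustrative only and not the actual optax API: the helper names `split_complex` and `merge_complex` are hypothetical, and the point is simply that a complex parameter can be viewed as two real arrays, so an unmodified real-valued update rule applies to each part independently.

```python
import jax.numpy as jnp

def split_complex(x):
    """Hypothetical helper: split a complex array into real and imaginary parts."""
    return jnp.real(x), jnp.imag(x)

def merge_complex(re, im):
    """Hypothetical helper: recombine the two real parts into a complex array."""
    return re + 1j * im

# Example: an RMSProp-style second moment accumulated per real component,
# exactly as if the two parts were independent real parameters.
g = jnp.array([1.0 + 2.0j, 3.0 - 4.0j])
g_re, g_im = split_complex(g)
nu_re, nu_im = g_re ** 2, g_im ** 2
```

Splitting and merging round-trips losslessly, which is why a wrapper built this way cannot change the behaviour of purely real parameters.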
Could you put together a PR for this?
About the other proposal, I think it would be nice to support that as well, but we need to think carefully about performance and readability. For instance, it would be good to verify exactly what XLA is and isn't able to do, e.g. this point from your doc:

> When `g` is real, it should be possible for `jax.jit` to eliminate the dispatch overhead of `conj` and `real`, and optimize `(g * g) ** (order / 2)` into `g ** order` (TODO: I'm not sure if currently `jax.jit` is smart enough to do this)
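For reference, the algebraic identity the quoted TODO relies on can be checked numerically. This is a sketch of the mathematics only (function names are mine, not from the proposal); it says nothing about whether XLA performs the rewrite, only that for real `g` and even integer `order` the two forms agree up to rounding.

```python
import jax.numpy as jnp

def norm_generic(g, order):
    # Complex-safe form: (conj(g) * g).real is |g|^2, a real array.
    return (g.conj() * g).real ** (order / 2)

def norm_real(g, order):
    # Specialized form XLA would ideally rewrite to for real g, even order.
    return g ** order

g = jnp.array([1.0, 2.0, 3.0])
```

The question in the thread is whether the compiler can discover this equivalence on its own.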
It turns out that currently `jax.jit` cannot optimize the composition of a multiplication and a power. I did some quick experiments with the debugging function `make_hlo`:

```python
In [1]: def norm(g, order):
   ...:     return (g.conj() * g).real ** (order / 2)
   ...:

In [2]: g = jnp.array([1., 2., 3.])

In [3]: print(make_hlo(norm)(g, 2))
HloModule xla_computation_norm.11

ENTRY xla_computation_norm.11 {
  constant.3 = pred[] constant(false)
  parameter.1 = f32[3]{0} parameter(0)
  multiply.4 = f32[3]{0} multiply(parameter.1, parameter.1)
  parameter.2 = s32[] parameter(1)
  convert.5 = f32[] convert(parameter.2)
  constant.6 = f32[] constant(2)
  divide.7 = f32[] divide(convert.5, constant.6)
  broadcast.8 = f32[3]{0} broadcast(divide.7), dimensions={}
  power.9 = f32[3]{0} power(multiply.4, broadcast.8)
  ROOT tuple.10 = (f32[3]{0}) tuple(power.9)
}

In [4]: print(make_hlo(norm, optimize=True)(g, 2))
HloModule xla_computation_norm__1.11

fused_computation {
  param_1.4 = f32[3]{0} parameter(1)
  multiply.2 = f32[3]{0} multiply(param_1.4, param_1.4)
  param_0.1 = s32[] parameter(0)
  convert.0 = f32[] convert(param_0.1)
  constant.1 = f32[] constant(0.5)
  multiply.1 = f32[] multiply(convert.0, constant.1)
  broadcast.0 = f32[3]{0} broadcast(multiply.1), dimensions={}
  ROOT power.0 = f32[3]{0} power(multiply.2, broadcast.0)
}

ENTRY xla_computation_norm__1.11 {
  parameter.2 = s32[] parameter(1)
  parameter.1 = f32[3]{0} parameter(0)
  fusion = f32[3]{0} fusion(parameter.2, parameter.1), kind=kLoop, calls=fused_computation
  ROOT tuple.10 = (f32[3]{0}) tuple(fusion)
}
```
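As a side note, `make_hlo` is not a public JAX function (it circulates as a debugging snippet). Assuming a reasonably recent JAX, a similar before/after dump can be obtained with the public ahead-of-time lowering API, which is what the sketch below uses:

```python
import jax
import jax.numpy as jnp

def norm(g, order):
    return (g.conj() * g).real ** (order / 2)

g = jnp.array([1.0, 2.0, 3.0])
# Mark `order` static so it is baked into the lowered computation.
lowered = jax.jit(norm, static_argnums=1).lower(g, 2)
print(lowered.as_text())            # IR before backend optimization
print(lowered.compile().as_text())  # backend-optimized HLO
```

The exact textual form of the output depends on the JAX version and backend, but the same multiply-then-power structure should be visible.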
We can see that in the optimized HLO, `conj` and `real` are indeed eliminated for the real `g`, but there are still a multiplication and a power. In fact, such an optimization does not exactly preserve the floating-point result, so it could only be enabled behind some dangerous flag like 'fast math', and I don't know whether the JAX developers have such a plan for the future.
Anyway, we can still use an if-else statement to eliminate that overhead when `g` is real. I've updated the proposal.
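The if-else being referred to can be sketched as follows (function name is mine, for illustration). The key point is that `jnp.iscomplexobj` is a static, trace-time check on the dtype, so under `jit` the real branch emits no `conj`/`real` ops at all and no runtime branch is compiled in:

```python
import jax.numpy as jnp

def abs_sq(g):
    """Return |g|^2 as a real array, with a cheap path for real inputs."""
    if jnp.iscomplexobj(g):
        # Complex case: conj(g) * g has zero imaginary part; take .real.
        return (g.conj() * g).real
    # Real case: plain elementwise square, no complex ops traced.
    return g * g
```

Because the branch is resolved at trace time, both paths produce identical results for real inputs.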
@wdphy16 can you update the PR to use the if-else statement and fix ADAM? I think this should be good for @mtthss
The PR #241 is only for `split_norm`. To implement the complex norm in the optimizers, I need to take some time to review everything (to the best of my ability) in `transforms.py` and `alias.py`, and I'll finish that today or tomorrow.
Hi all!
Do I understand correctly that with `split_complex` merged (thanks a lot @wdphy16!!) the "Split real norm" section of the proposal is fully implemented, or is there anything outstanding?
If so, would the next step be to discuss whether to add the complex norm in addition to the split real norm?
Thanks a lot!
Yes, the split real norm is fully implemented. I've also implemented the complex norm for some relatively simple optimizers and I'll make another PR.
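For readers following along, the "complex norm" variant being contrasted with the split real norm can be sketched like this. This is a hypothetical illustration of the idea, not the actual PR: an Adam-style second moment accumulates `|g|^2 = (conj(g) * g).real`, which is real-valued, so the real and imaginary parts of each parameter share a single adaptive scale instead of being normalized independently.

```python
import jax.numpy as jnp

def update_nu(nu, g, b2=0.999):
    """Hypothetical complex-norm second-moment update (Adam-style EMA)."""
    # (conj(g) * g).real is |g|^2, real-valued even for complex g.
    return b2 * nu + (1 - b2) * (g.conj() * g).real

nu = update_nu(0.0, jnp.array([1.0 + 1.0j]))
```

Under the split real norm, the same gradient would instead contribute `1.0` to each of two separate accumulators; the complex norm couples the two parts.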
Ah that's great, thanks a lot!