Comments (5)
I was thinking the same thing; we can view the returned value as an overlay over the original values.
There are a couple of annoying things here to be aware of:
- For state, updated states will still modify the original state member of StatePair to be the _foil_cse'd one. This is pretty annoying, though not the worst thing ever.
- The _foil_cse HLO does not get pruned when not used: we're left with a huge number of vestigial constant and rng HLO ops. Blake from XLA team speculates that this is because the ops could be side-effecting.
As a side note, it's not clear to me whether _foil_cse-ing params (which don't need to be _foil_cse'd, as they're constant!) is even a good idea. This is sufficiently murky that I think Haiku shouldn't take a stance on this.
from dm-haiku.
I like 1, we can definitely be smarter wrt params and state.
I think a general solution would be to return only updated values for both params/state and then merge them into the parent frame. I think we can do this by stashing the initial values (we only care about identity) that are inputs to the stateful_fun and only returning the ones that have a different identity.
from collections import defaultdict

def only_new_values(original, updated):
    output = defaultdict(dict)
    for mod_name, mod_state in updated.items():
        # Iterate over this module's entries, not over the top-level mapping.
        for name, value in mod_state.items():
            # Compare by identity: an entry counts as updated if it was rebound.
            if name in original[mod_name] and original[mod_name][name] is not value:
                output[mod_name][name] = value
    return output

def stateful_fun(*a, **k):
    hk_state = k.pop(..)
    # Stash the initial values so we can compare identities afterwards.
    orig_params, orig_state = copy_params_state(hk_state)
    with use_and_update(hk_state):
        out = f(*a, **k)
    # Only return updated values.
    params = only_new_values(orig_params, hk_state.params)
    state = only_new_values(orig_state, hk_state.state)
    return out, (params, state)
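To make the "return only updated values, then merge into the parent frame" idea concrete, here is a minimal, self-contained sketch using plain nested dicts in place of Haiku's params/state. The names `only_changed_values` and `merge_into_parent` are hypothetical and are not part of Haiku. Unlike the snippet above, this variant also assumes that entries missing from the original should count as updated.

```python
from collections import defaultdict


def only_changed_values(original, updated):
    """Return entries of `updated` whose object identity differs from `original`."""
    output = defaultdict(dict)
    for mod_name, mod_values in updated.items():
        original_values = original.get(mod_name, {})
        for name, value in mod_values.items():
            # Identity check (`is not`), not equality: a rebound entry counts
            # as updated even if it compares equal to the old one.
            # Assumption: entries absent from `original` are also returned.
            if name not in original_values or original_values[name] is not value:
                output[mod_name][name] = value
    return dict(output)


def merge_into_parent(parent, overlay):
    """Return a copy of `parent` with the entries in `overlay` laid on top."""
    merged = {mod_name: dict(values) for mod_name, values in parent.items()}
    for mod_name, values in overlay.items():
        merged.setdefault(mod_name, {}).update(values)
    return merged


# Hypothetical two-level mapping standing in for Haiku state.
counter, offset = [0], [1.0]
original = {"bn": {"counter": counter, "offset": offset}}
updated = {"bn": {"counter": [1], "offset": offset}}  # only `counter` was rebound

overlay = only_changed_values(original, updated)
merged = merge_into_parent(original, overlay)
```

Here `overlay` contains only the rebound `counter` entry, and `merged` reuses the untouched `offset` object from the parent, which is the overlay behaviour discussed earlier in the thread.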
I think I agree re PRNG keys too, but would want to triple check the details.
My changes in google/jax#2391 to the XLA CSE mechanism change the semantics of this bug.
Now, we're incurring unnecessary serialization (because parameters are now an output of the remat layer), rather than unnecessary layers of _foil_cse.
The underlying cause remains the same (unnecessarily feeding and updating the Haiku state).
I'll take a stab at not updating state if it doesn't change.
As a result of google/jax#2391, it looks like the problems get optimized away; we're not even incurring unnecessary serialization right now.
These changes may still be good to make defensively.