Comments (11)
And sorry, for context, the checkpointing is so that I can implement self-play where the baseline model that controls the enemy is updated to match the model we're training every once in a while.
from purejaxrl.
Never mind, figured out the main thing for me, checkpointing. You have to:
- remove the jitting of make_train when you call it
- @jax.jit the def _update_step
- (these two changes combined seem to make no measurable difference to performance, but my testing was not extensive)
- add a save of whatever you want at the end of train(rng)
- add a load from your save of specifically only network_params at the start
My one concern is that I doubt/cannot tell whether the learning rate scheduler is maintained, but I will worry about that later.
On less regular metrics, I will also worry about that later, but I've seen some likely jittable tools.
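The steps listed above could look roughly like the following. This is only a minimal sketch, not the purejaxrl code: the toy loss, the `config` keys, the `init_params` argument, the Python loop in place of a scan, and the use of `pickle` for the save file are all assumptions made for illustration.

```python
import pickle
import tempfile
from pathlib import Path

import jax
import jax.numpy as jnp


def make_train(config, init_params=None):
    """Build an un-jitted train(rng); only _update_step is compiled."""

    @jax.jit
    def _update_step(params, rng):
        # Stand-in for a real PPO update: one gradient step on a toy loss.
        target = jax.random.normal(rng, params["w"].shape)
        grads = jax.grad(lambda p: jnp.mean((p["w"] - target) ** 2))(params)
        return jax.tree_util.tree_map(
            lambda p, g: p - config["lr"] * g, params, grads
        )

    def train(rng):
        # Load only the network parameters at the start, if a save was given.
        params = init_params if init_params is not None else {"w": jnp.zeros(4)}
        for _ in range(config["num_updates"]):
            rng, step_rng = jax.random.split(rng)
            params = _update_step(params, step_rng)
        # train() itself is plain Python here, so ordinary file I/O works.
        with open(config["ckpt_path"], "wb") as f:
            pickle.dump(jax.device_get(params), f)
        return params

    return train


# Hypothetical usage: train once, reload the saved params, train again.
ckpt = Path(tempfile.mkdtemp()) / "network_params.pkl"
config = {"lr": 0.1, "num_updates": 5, "ckpt_path": ckpt}
first = make_train(config)(jax.random.PRNGKey(0))
with open(ckpt, "rb") as f:
    restored = pickle.load(f)
second = make_train(config, init_params=restored)(jax.random.PRNGKey(1))
```

For self-play, the reloaded parameters would be the ones handed to the baseline opponent; only the network parameters are restored in this sketch, so any optimizer state starts fresh.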
@jheagerty actually I think you can save checkpoints under jit easily with callbacks, such as jax.experimental.io_callback() (for example inside _update_step to save after each update).
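A minimal sketch of that idea, assuming a toy `_update_step` and a hypothetical `save_checkpoint` helper that writes `.npy` files to a temporary directory (a real setup would more likely save the full train state with a checkpointing library):

```python
import tempfile
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

CKPT_DIR = Path(tempfile.mkdtemp())  # hypothetical checkpoint directory


def save_checkpoint(params, step):
    # Runs on the host with concrete arrays, so normal file I/O works here.
    np.save(CKPT_DIR / f"params_{int(step)}.npy", np.asarray(params))


def _update_step(carry, _):
    params, step = carry
    params = params * 0.9  # stand-in for the real update
    step = step + 1
    # result_shape_dtypes=None because the callback returns nothing.
    io_callback(save_checkpoint, None, params, step, ordered=True)
    return (params, step), None


@jax.jit
def train(params):
    (params, _), _ = jax.lax.scan(
        _update_step, (params, jnp.int32(0)), None, length=3
    )
    return params


final = jax.block_until_ready(train(jnp.ones(4)))
jax.effects_barrier()  # wait until all pending callbacks have run
```

Only array arguments are passed to the callback; the directory is a Python global that the host function reads directly.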
yep! I do it the way @Howuhh is describing. If you look at the code, there is the debug callback. You can just replace the print function with your checkpointing and wandb logging.
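As a rough illustration of swapping the print function for something else, the sketch below uses `jax.debug.callback` with a hypothetical `log_callback` that appends to a Python list. The list stands in for wandb logging or a checkpoint writer, and the update step is a toy, not the repository's.

```python
import jax
import jax.numpy as jnp

logged = []  # stand-in for a wandb run or a checkpoint writer


def log_callback(metric):
    # Host-side function: replace the body with checkpointing or wandb logging.
    logged.append({k: float(v) for k, v in metric.items()})


def _update_step(runner_state, _):
    runner_state = runner_state + 1.0  # stand-in for the real update
    metric = {"step_value": runner_state}
    jax.debug.callback(log_callback, metric)
    return runner_state, metric


@jax.jit
def train(x):
    return jax.lax.scan(_update_step, x, None, length=3)


final_state, metrics = jax.block_until_ready(train(jnp.float32(0.0)))
jax.effects_barrier()  # make sure every callback has finished
```

`jax.debug.callback` is intended for side effects only and makes no ordering guarantee by default, so anything that must happen in sequence is probably better served by `io_callback` with `ordered=True`.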
Thanks so much! Will look into this
The way you do it should be fine too, and is arguably better (though it takes more code). The optimizer state (which includes the lr scheduling info) should be in the train_state.
@luchris429 Hi Chris! I was wondering, how did you implement restoring of checkpoints in the PureJaxRL end-to-end jitting? I'm able to save checkpoints pretty easily with a debug callback function, but I can't quite figure out how to restore. I attempted to put an experimental.io_callback function in the train function, but I can't actually do anything with the string checkpoint path because JAX can't handle strings.
You can try to load the runner state here!
Does it not work if you set the filename in the config?
So I tried something like the code below at exactly the line you pointed out (in a modified PPO script where I split the actor/critic):
def resuming_callback(path):
    checkpointer = ocp.PyTreeCheckpointer()
    raw_restored = checkpointer.restore(path)
    return raw_restored

runner_state = (actor_state, vf_state, time_state, env_state, obsv, train_key)
if args.resume:
    raw_restored = io_callback(
        resuming_callback, runner_state, args.resume_checkpoint_path
    )
runner_state, metric = jax.lax.scan(
    _update_step, runner_state, None, args.num_iterations
)
However, JAX errors out with the complaint that my args.resume_checkpoint_path is a string, which is not compatible. Hence my current conundrum. Perhaps I'm setting this up wrong or using the wrong JAX callback?
Sorry I didn't catch this message! I hope you've figured it out.
I think you need to make sure it's a static argument since you can't JIT a string as an argument.
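One way to keep the string out of the traced arguments is to bind it to the callback in Python, so that `io_callback` receives only a host function and a description of the arrays it returns. The sketch below is an assumption-laden illustration, not the thread's solution: it uses `pickle` in place of the Orbax checkpointer, a two-field dictionary in place of the real runner state, and a hypothetical `resume_checkpoint_path` argument to `make_train`.

```python
import functools
import pickle
import tempfile
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback


def resuming_callback(path):
    # Host-side restore; `path` is a plain Python string bound via partial.
    with open(path, "rb") as f:
        return pickle.load(f)


def make_train(resume_checkpoint_path=None):
    def train(rng):
        runner_state = {
            "params": jax.random.normal(rng, (3,)),
            "step": jnp.int32(0),
        }
        if resume_checkpoint_path is not None:
            # Describe the shapes and dtypes the callback must return.
            spec = jax.tree_util.tree_map(
                lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), runner_state
            )
            # The path never becomes a JAX argument: it lives in the closure.
            runner_state = io_callback(
                functools.partial(resuming_callback, resume_checkpoint_path),
                spec,
            )

        def _update_step(state, _):
            new_state = {
                "params": state["params"] + 1.0,  # stand-in for the update
                "step": state["step"] + 1,
            }
            return new_state, None

        runner_state, _ = jax.lax.scan(_update_step, runner_state, None, length=2)
        return runner_state

    return train


# Hypothetical usage: write a checkpoint, then resume from it under jit.
ckpt_path = str(Path(tempfile.mkdtemp()) / "runner_state.pkl")
saved = {"params": np.full(3, 5.0, np.float32), "step": np.asarray(7, np.int32)}
with open(ckpt_path, "wb") as f:
    pickle.dump(saved, f)

resumed = jax.jit(make_train(ckpt_path))(jax.random.PRNGKey(0))
```

The restored pytree has to match `spec` in structure, shape, and dtype. A simpler alternative is often to restore the checkpoint in ordinary Python before calling the jitted function and pass the restored arrays in as a normal argument, or to mark the path as static with `static_argnames` on `jax.jit`.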
Actually, I tried this with a fixed filename as well and it still failed. Restoring inside jit is proving quite difficult for me.