Comments (11)

jheagerty commented on May 9, 2024

And sorry, for context: I need checkpointing so that I can implement self-play, where the baseline model that controls the enemy is periodically updated to match the model we're training.

from purejaxrl.

jheagerty commented on May 9, 2024

Never mind, I figured out the main thing for me, which is checkpointing. You have to:

  • remove the jax.jit around make_train where you call it
  • decorate _update_step with @jax.jit instead
  • (these two changes combined seem to make no measurable difference to performance, but my testing was not extensive)
  • add a save of whatever you want at the end of train(rng)
  • add a load from your save, restoring only network_params, at the start
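
The steps above can be sketched roughly as follows. This is a minimal toy, not the PureJaxRL code: the pickle-based save_params/load_params helpers, the ckpt_path argument, the single-weight "network", and the plain Python loop over updates are all assumptions made for illustration. The point is only that train is left un-jitted so ordinary file I/O is allowed, while the per-update work is still compiled.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np


def save_params(params, path):
    """Copy a pytree of arrays to host memory and pickle it (hypothetical helper)."""
    host_params = jax.tree_util.tree_map(np.asarray, params)
    with open(path, "wb") as f:
        pickle.dump(host_params, f)


def load_params(path):
    """Load a pickled pytree and convert its leaves back to JAX arrays."""
    with open(path, "rb") as f:
        host_params = pickle.load(f)
    return jax.tree_util.tree_map(jnp.asarray, host_params)


def make_train(num_updates, ckpt_path=None):
    def train(rng):
        # train itself is NOT jitted, so plain Python file I/O works here.
        if ckpt_path is not None and os.path.exists(ckpt_path):
            params = load_params(ckpt_path)  # load only the network params
        else:
            params = {"w": jax.random.normal(rng, (3,))}

        @jax.jit
        def _update_step(params):
            # Stand-in for a real update: one gradient step on a toy loss.
            grads = jax.grad(lambda p: jnp.sum(p["w"] ** 2))(params)
            return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

        for _ in range(num_updates):
            params = _update_step(params)

        if ckpt_path is not None:
            save_params(params, ckpt_path)  # save at the end of train(rng)
        return params

    return train


ckpt = os.path.join(tempfile.mkdtemp(), "network_params.pkl")
first = make_train(num_updates=5, ckpt_path=ckpt)(jax.random.PRNGKey(0))
# A second run finds the file and starts from the saved params.
resumed = make_train(num_updates=0, ckpt_path=ckpt)(jax.random.PRNGKey(1))
```

For self-play, the same load_params call could be used to refresh the enemy's baseline model from the latest save between training runs.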

My one concern is that I doubt/cannot tell whether the learning rate scheduler is maintained, but I will worry about that later.

As for less regular metrics, I will also worry about that later, but I've seen some tools that are likely jittable.

Howuhh commented on May 9, 2024

@jheagerty Actually, I think you can save checkpoints under jit easily with callbacks, such as jax.experimental.io_callback() (for example, inside _update_step to save after each update).
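
A minimal sketch of that idea, assuming a toy _update_step and a hypothetical save_checkpoint helper that writes .npz files into a temporary directory. It uses jax.debug.callback, which needs no result shapes; jax.experimental.io_callback could be used in a similar way when the host function has to return something.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

CKPT_DIR = tempfile.mkdtemp()  # hypothetical checkpoint location


def save_checkpoint(step, params):
    # Runs on the host with concrete values, so ordinary file I/O is fine.
    leaves = jax.tree_util.tree_leaves(params)
    flat = {f"leaf_{i}": np.asarray(x) for i, x in enumerate(leaves)}
    np.savez(os.path.join(CKPT_DIR, f"step_{int(step)}.npz"), **flat)


def _update_step(carry, _):
    step, params = carry
    # Stand-in for a real update.
    params = jax.tree_util.tree_map(lambda p: p * 0.9, params)
    step = step + 1
    # Hand the traced values to the host function after each update.
    jax.debug.callback(save_checkpoint, step, params)
    return (step, params), None


@jax.jit
def train(params):
    init = (jnp.int32(0), params)
    (step, params), _ = jax.lax.scan(_update_step, init, None, length=3)
    return step, params


step, params = train({"w": jnp.ones(4)})
jax.effects_barrier()  # wait until all pending callbacks have run
saved = sorted(os.listdir(CKPT_DIR))
```

Saving on every update is probably too frequent in practice; the host function can check the step and return early on most calls.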

luchris429 commented on May 9, 2024

Yep! I do it the way @Howuhh describes. If you look at the code, there is a debug callback. You can just replace the print function with your checkpointing and wandb logging.

jheagerty commented on May 9, 2024

Thanks so much! Will look into this

luchris429 commented on May 9, 2024

The way you do it should be fine too, and is arguably better (though it takes more code). The optimizer state (which includes the lr scheduling info) should be in the train_state.
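
To illustrate why that matters when resuming, here is a hand-rolled sketch in plain JAX. It is not the flax/optax train_state itself; it assumes, as a simplification, that the schedule is a pure function of a step count stored alongside the params. The linear_schedule, init_state, and sgd_step names are hypothetical.

```python
import jax
import jax.numpy as jnp


def linear_schedule(count, base_lr=1e-2, total_steps=10):
    # Learning rate decays linearly with the number of updates taken so far.
    return base_lr * (1.0 - count / total_steps)


def init_state(params):
    # Toy "train state": params plus an optimizer state holding the step count.
    return {"params": params, "opt_state": {"count": jnp.zeros((), jnp.int32)}}


@jax.jit
def sgd_step(state, grads):
    count = state["opt_state"]["count"]
    lr = linear_schedule(count)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - lr * g, state["params"], grads
    )
    return {"params": new_params, "opt_state": {"count": count + 1}}


grads = {"w": jnp.ones(2)}
state = init_state({"w": jnp.zeros(2)})
for _ in range(4):
    state = sgd_step(state, grads)

# Resuming from the whole state keeps the schedule position.
full_resume = state
# Resuming from params only rebuilds the optimizer state, so the count resets.
params_only_resume = init_state(state["params"])

lr_full = float(linear_schedule(full_resume["opt_state"]["count"]))
lr_params_only = float(linear_schedule(params_only_resume["opt_state"]["count"]))
```

Under these assumptions, restoring only network_params restarts the schedule from its initial learning rate, whereas saving and restoring the full state continues where training left off.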

Chulabhaya commented on May 9, 2024

@luchris429 Hi Chris! I was wondering, how did you implement restoring of checkpoints with the PureJaxRL end-to-end jitting? I'm able to save checkpoints pretty easily with a debug callback function, but I can't quite figure out how to restore. I attempted to put an experimental.io_callback call in the train function, but I can't actually do anything with the checkpoint path string because JAX can't handle strings.

luchris429 commented on May 9, 2024

You can try to load the runner state here!

Does it not work if you set the filename in the config?

Chulabhaya commented on May 9, 2024

So I tried something like the code below at exactly the line you pointed out (in a modified PPO script where I split the actor/critic):

def resuming_callback(path):
    checkpointer = ocp.PyTreeCheckpointer()
    raw_restored = checkpointer.restore(path)
    return raw_restored

runner_state = (actor_state, vf_state, time_state, env_state, obsv, train_key)
if args.resume:
    raw_restored = io_callback(
        resuming_callback, runner_state, args.resume_checkpoint_path
    )

runner_state, metric = jax.lax.scan(
    _update_step, runner_state, None, args.num_iterations
)

However, JAX errors out, complaining that my args.resume_checkpoint_path is a string, which is not a valid JAX type. Hence my current conundrum. Perhaps I'm setting this up wrong or using the wrong JAX callback?

luchris429 commented on May 9, 2024

Sorry I didn't catch this message! I hope you've figured it out.

I think you need to make sure it's a static argument, since you can't pass a string as an argument to a jitted function.
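
One way this might look, as a sketch under stated assumptions rather than a tested fix for the script above: the path is captured by a Python closure instead of being passed through io_callback, which has a similar effect to making it static, and the second argument to io_callback describes the shapes and dtypes of the result rather than being the runner state itself. The pickle format, the dict-shaped runner_state, and the resume_path name are hypothetical.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

# Write a toy checkpoint so the example is self-contained (hypothetical path).
ckpt_path = os.path.join(tempfile.mkdtemp(), "runner_state.pkl")
saved_state = {"params": np.arange(3, dtype=np.float32), "step": np.int32(7)}
with open(ckpt_path, "wb") as f:
    pickle.dump(saved_state, f)


def make_train(resume_path=None):
    def restore_callback():
        # resume_path comes from the enclosing scope, so the string is never
        # passed through JAX as a traced argument.
        with open(resume_path, "rb") as f:
            return pickle.load(f)

    def train(rng):
        runner_state = {"params": jnp.zeros(3, jnp.float32), "step": jnp.int32(0)}
        if resume_path is not None:  # plain Python branch, resolved at trace time
            # Describe the result as shapes and dtypes matching runner_state.
            result_shapes = jax.tree_util.tree_map(
                lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype), runner_state
            )
            runner_state = io_callback(restore_callback, result_shapes)

        def _update_step(state, _):
            # Stand-in for a real update.
            new_state = {"params": state["params"] + 1.0, "step": state["step"] + 1}
            return new_state, None

        runner_state, _ = jax.lax.scan(_update_step, runner_state, None, length=2)
        return runner_state

    return train


out = jax.jit(make_train(resume_path=ckpt_path))(jax.random.PRNGKey(0))
```

A simpler alternative is to restore the checkpoint in ordinary Python before calling the jitted function and pass the restored pytree in as a regular argument, which avoids callbacks altogether.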

gzadigo commented on May 9, 2024

Actually, I have tried this and it fails even with a fixed filename. Restoring inside jit is quite difficult for me.
