
tdmpc2-jax's Introduction

Fun with all things autonomous!


tdmpc2-jax's People

Contributors

shaneflandermeyer


Forkers

edwhu edoust

tdmpc2-jax's Issues

Nice work! Reproducing original results?

Hi Shane,

Great work on this. I'll definitely have a closer look soon. Do you have any preliminary benchmark results that compare the jax and pytorch versions? Both in terms of env steps and wall-time. Super curious to see how they compare!

Environment support

Hey, great work on this one!

Using your codebase, I trained on a few jittable environments (mainly gymnax and self-written ones), which works quite well. However, it requires some changes to the training script, as well as additional environment wrappers for the (TensorBoard) logging to work.

Do you plan to have full support for a jax-based framework like gymnax or brax? If so, any thoughts on which library to support?

Having support for jittable environments (without gym wrappers) might make it possible to scan over the training loop.
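As a rough illustration of that idea, here is a minimal sketch of scanning a training loop over a jittable environment. `toy_step` and the random-action "policy" are illustrative stand-ins, not part of the repo; a gymnax env's `step` would slot in where `toy_step` is:

```python
import jax
import jax.numpy as jnp

def toy_step(state, action):
    # Trivial stand-in dynamics so the sketch is self-contained.
    return state + action

def train_step(carry, key):
    state = carry
    action = 0.1 * jax.random.normal(key, state.shape)  # placeholder policy
    next_state = toy_step(state, action)
    # An agent update would also live inside the scanned function.
    return next_state, next_state.sum()

keys = jax.random.split(jax.random.PRNGKey(0), 100)
final_state, returns = jax.lax.scan(train_step, jnp.zeros(4), keys)
```

Because everything inside `train_step` is jittable, the whole loop compiles as one XLA program instead of dispatching one `step` per environment interaction.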

Environment Setup

Hello! Amazing work on the JAX implementation and vectorization. Could you give me some pointers on how to use this implementation with a custom gym/MuJoCo gym environment? Assume continuous control with a simple state vector.

Help with CrossQ implementation

I'm interested in implementing the CrossQ critic update, which only requires two Q networks, and no target networks. This could speed up TDMPC2 a decent amount. A key part of the method is using BatchNorm correctly.

I'm rather new to Flax, and implementing BatchNorm is a bit annoying. We have to carry the batch statistics as part of the training state. I made some preliminary progress towards this, but got a bit stuck with the batch normalization details.

A few requests and questions:

  1. Could you make a cross-q branch in the official repo so I can open a PR? We can continue the discussion from there. I don't think CrossQ ever needs to be merged into main, but having it as a variant could be useful to others.

  2. Onto the actual bug:
    First, I made a batch normalized version of the Q function in mlp.py that can take in a training boolean to specify train or evaluation mode.

However, it seems that once that Q function is initialized and traced, passing in the boolean later raises an error.

  @jax.jit
  def Q(self, z: jax.Array, a: jax.Array, params: Dict, key: PRNGKeyArray, train: bool
        ) -> Tuple[jax.Array, jax.Array, Dict]:
    z = jnp.concatenate([z, a], axis=-1)
    # TODO: figure out why including train as an argument breaks things.
    # logits, updates = self.value_model.apply_fn(
    #     {'params': params, 'batch_stats': self.value_model.batch_stats},
    #     z, rngs={'dropout': key}, mutable=['batch_stats'])
    logits, updates = self.value_model.apply_fn(
        {'params': params, 'batch_stats': self.value_model.batch_stats},
        z, train, rngs={'dropout': key}, mutable=['batch_stats'])

    Q = two_hot_inv(logits, self.symlog_min, self.symlog_max, self.num_bins)
    return Q, logits, updates

So I get this error:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/train.py", line 217, in train
    agent, train_info = agent.update(
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/tdmpc2.py", line 318, in update
    (encoder_grads, dynamics_grads, value_grads, reward_grads, continue_grads), model_info = jax.grad(
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/tdmpc2.py", line 264, in world_model_loss_fn
    all_q, all_q_logits, updates = self.model.Q(jnp.concat([zs[:-1], next_z]), jnp.concat([actions, next_action]), value_params, key=Q_key, train=True)
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/world_model.py", line 291, in Q
    logits, updates = self.value_model.apply_fn(
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/networks/ensemble.py", line 19, in __call__
    return ensemble()(*args, **kwargs)
  File "/home/edward/miniconda3/envs/tdmpc2jax/lib/python3.10/site-packages/flax/linen/combinators.py", line 105, in __call__
    outputs = self.layers[0](*args, **kwargs)
  File "/home/edward/projects/tdmpc2-jax/tdmpc2_jax/networks/mlp.py", line 215, in __call__
    x = self.norm(dtype=self.dtype, use_running_average = not train)(x)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function Q at /home/edward/projects/tdmpc2-jax/tdmpc2_jax/world_model.py:286 for jit. This concrete value was not available in Python because it depends on the value of the argument train.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

So currently, just to get the code to run, I don't pass in the training boolean, meaning it always runs with train=True by default. But in some cases I want train=False, e.g. when using the Q function for planning. Either way, passing train=... as an argument throws the error, which suggests JAX is tracing train as a dynamic value rather than treating it as a static argument. You can see the hack / hotfix here:
https://github.com/edwhu/tdmpc2-jax/blob/495d09657b64b1d82298eb184a555921bd9e383e/tdmpc2_jax/world_model.py#L291

Would love to hear your thoughts on this, I've been stuck on this for a couple hours over the past few days. Thanks!
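One plausible explanation: `nn.BatchNorm` branches on `use_running_average` in Python at trace time, so `train` has to be a concrete Python bool rather than a traced array. Marking it static with `static_argnames` avoids the TracerBoolConversionError, at the cost of one recompile per value of `train` (just True/False here). A self-contained sketch, where `q_apply` is a stand-in for the real `apply_fn` call:

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=('train',))
def q_apply(params, z, train: bool):
    # `train` is a concrete Python bool under jit, so Python-level
    # branching (like BatchNorm's use_running_average check) works.
    if train:
        return z @ params['w']
    return z @ params['w'] * 0.5

params = {'w': jnp.ones((4, 2))}
q_train = q_apply(params, jnp.ones((3, 4)), train=True)   # compiles for train=True
q_eval = q_apply(params, jnp.ones((3, 4)), train=False)   # compiles for train=False
```

The same `partial(jax.jit, static_argnames=('train',))` decoration should work on the repo's `Q` method, since methods are jitted the same way as free functions.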

Saving trained model

Are there any plans to add support for saving single- and multi-task trained models, and for continuing training from a saved model?
