luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations

License: Apache License 2.0

Python 100.00%
deep-reinforcement-learning jax reinforcement-learning reinforcement-learning-algorithms ppo

purejaxrl's People

Contributors

cool-rr, lcipolina, luchris429, lupuandr, mttga

purejaxrl's Issues

DreamerV3 purejaxrl-ify

DreamerV3 is a SOTA model-based RL (MBRL) algorithm, and it's already written in JAX. I want to gauge whether anyone here is interested in moving its environment interactions into JAX as well, i.e., if you have a pure-JAX environment, can we modify the current DreamerV3 code to get a blazing-fast setup similar to the other algorithms here? I am just warming up to JAX, so I don't have much insight into it yet.

JAX Version Strangeness

There are strange gaps in speed and performance between JAX v0.4.8 and JAX v0.4.13. Currently investigating. I would recommend sticking with <=0.4.8 for now.
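
For anyone who wants to guard a script against the newer versions, here is a minimal sketch of a version check. The helper names (`parse_version`, `is_recommended`) are hypothetical, and it assumes plain release strings such as "0.4.13" with no suffixes. The point is that the components must be compared as integers: as plain strings, "0.4.13" sorts before "0.4.8".

```python
# Upper bound taken from the recommendation above (<=0.4.8).
MAX_RECOMMENDED = (0, 4, 8)


def parse_version(version):
    """Turn '0.4.13' into (0, 4, 13) so components compare numerically."""
    return tuple(int(part) for part in version.split(".")[:3])


def is_recommended(version):
    """Return True if the version is at or below the recommended bound."""
    return parse_version(version) <= MAX_RECOMMENDED


for candidate in ("0.4.8", "0.4.13"):
    print(candidate, is_recommended(candidate))
```

In a real script the argument would be `jax.__version__`.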

MuJoCo Support for weld equality constraint

Hi!
Thanks for releasing this codebase. It's so awesome to see how much faster we can iterate now with tools like this :)

I am just wondering if you attempted to use any environments that use mocaps with weld equality constraints. I am trying to get Meta-World running with Brax and am curious if there's a way around this issue. Thanks!

linear schedule is not linear (?)

Hi! I noticed that the linear learning-rate schedule is not actually linear when the number of steps is small. I doubt it makes much difference to the final results, but thought I'd show it anyway. Maybe I made a mistake somewhere?

import matplotlib.pyplot as plt

total_timesteps = 10_000
num_minibatches = 32
num_steps = 128
num_envs = 4
num_epochs = 10
lr = 0.1
num_updates = total_timesteps // num_steps // num_envs

# Schedule as written in purejaxrl: the integer division keeps the
# learning rate constant within each update.
def linear_schedule_purejaxrl(count):
    frac = 1.0 - (count // (num_minibatches * num_epochs)) / num_updates
    return lr * frac

# Schedule that decays on every gradient step.
def linear_schedule(count):
    frac = 1.0 - count / (num_epochs * num_minibatches * num_updates)
    return lr * frac


total_gradient_updates = (num_minibatches * num_epochs) * num_updates

plt.plot([linear_schedule(i) for i in range(1, total_gradient_updates + 1)], label="linear")
plt.plot([linear_schedule_purejaxrl(i) for i in range(1, total_gradient_updates + 1)], label="purejaxrl")
plt.legend()
plt.xlabel("Update")
plt.ylabel("Learning Rate")
plt.show()

Result:
[image: example]
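
The same effect can be checked without a plot. This is a minimal sketch that reuses the settings above and counts how many distinct learning rates each schedule produces. The names `steps_per_update`, `staircase_levels` and `linear_levels` are introduced here for illustration only.

```python
# Same settings as in the snippet above.
total_timesteps = 10_000
num_minibatches = 32
num_steps = 128
num_envs = 4
num_epochs = 10
lr = 0.1
num_updates = total_timesteps // num_steps // num_envs

# Gradient steps taken within one update (all epochs, all minibatches).
steps_per_update = num_minibatches * num_epochs
total_gradient_updates = steps_per_update * num_updates


def linear_schedule_purejaxrl(count):
    # Integer division: the fraction only changes once per update.
    frac = 1.0 - (count // steps_per_update) / num_updates
    return lr * frac


def linear_schedule(count):
    # True division: the fraction changes on every gradient step.
    frac = 1.0 - count / total_gradient_updates
    return lr * frac


counts = range(1, total_gradient_updates + 1)
staircase_levels = {linear_schedule_purejaxrl(c) for c in counts}
linear_levels = {linear_schedule(c) for c in counts}

print("distinct learning rates (purejaxrl):", len(staircase_levels))
print("distinct learning rates (linear):", len(linear_levels))
```

With these settings the purejaxrl schedule takes only a handful of values because it is piecewise constant: the learning rate is held fixed for every gradient step within an update and drops once per update. With many updates the steps become too small to notice, which may be why it only shows up for short runs.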

Applications of meta-rl to DQN

Hi,

I have just read the blog post and think this is really cool work.
I'm guessing that a full academic version of the work is coming out soon.

I have a couple of questions about Figure 5:

  1. Do you have any more understanding of what is happening there? Why is this preferred over L2?
  2. I'm reminded of "Revisiting the Arcade Learning Environment: Evaluation Protocols and Open Problems for General Agents", which found that the Huber loss also does a bit better than the standard L2 loss for DQN, so I would be interested in replicating this work for DQN and investigating the type of loss function found there. Is there a similarity to the asymmetry and non-convexity of the PPO loss function?
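
For reference, a minimal sketch of the two baseline losses mentioned in the second question, using the standard textbook definitions. The threshold `delta=1.0` is an assumption chosen for illustration, not a value taken from the blog post or the paper.

```python
def l2_loss(error):
    """Standard squared-error loss (with the conventional 1/2 factor)."""
    return 0.5 * error ** 2


def huber_loss(error, delta=1.0):
    """Quadratic for small errors, linear once |error| exceeds delta."""
    abs_error = abs(error)
    if abs_error <= delta:
        return 0.5 * error ** 2
    return delta * (abs_error - 0.5 * delta)


for error in (0.5, 1.0, 3.0):
    print(error, l2_loss(error), huber_loss(error))
```

The two agree inside the threshold and diverge outside it, where the Huber loss grows linearly and so penalizes large errors less than L2. Both are symmetric and convex, unlike the loss discussed in the questions above.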

Thanks for any information

Colab examples need to install distrax, gymnax, and brax?

I tried running the Colab example to look at the speed, but it seems Colab doesn't come with distrax, brax, or gymnax installed.

https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/example_1.ipynb

ModuleNotFoundError                       Traceback (most recent call last)

<ipython-input-1-91b777b64c88> in <cell line: 9>()
      7 from typing import Sequence, NamedTuple, Any
      8 from flax.training.train_state import TrainState
----> 9 import distrax
     10 import gymnax
     11 from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper

ModuleNotFoundError: No module named 'distrax'


---------------------------------------------------------------------------
NOTE: If your import is failing due to a missing package, you can
manually install dependencies using either !pip or !apt.

To view examples of installing some common dependencies, click the
"Open Examples" button below.
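
A minimal sketch of a check that could go in the first notebook cell, assuming the import names match the pip package names (which is the case for these three). The helper `missing_packages` is hypothetical, not part of the notebook.

```python
import importlib.util


def missing_packages(names):
    """Return the names that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]


required = ["distrax", "gymnax", "brax"]
missing = missing_packages(required)
if missing:
    # In Colab, paste the printed command into a cell and run it.
    print("Missing packages, run: !pip install " + " ".join(missing))
else:
    print("All required packages are installed.")
```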

Checkpointing and less regular metric collection

I know this sounds ridiculous, but I've spent ages trying to implement checkpoint saving in your example/walkthrough training code and have been getting nowhere.

Similarly (as it's something to do every n steps or every epoch) I've been trying to reduce the frequency of metric collection, as it has been giving me VRAM errors with my lowly NVIDIA 3080.

Any advice / solutions would be very gratefully received.
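
One possible approach, sketched here under stated assumptions rather than taken from the repository: split the run into chunks, run each chunk as a single compiled call, and do the saving and metric summarizing on the host between chunks. The code below only illustrates the control flow in plain Python. `chunk_lengths` and `run_chunk` are hypothetical names, and `run_chunk` is a stand-in for a jitted function that scans the update step `length` times and returns a small summary instead of every per-step metric.

```python
def chunk_lengths(num_updates, checkpoint_every):
    """Split num_updates into chunks of at most checkpoint_every updates."""
    full, remainder = divmod(num_updates, checkpoint_every)
    return [checkpoint_every] * full + ([remainder] if remainder else [])


def run_chunk(state, length):
    # Hypothetical stand-in for the compiled training chunk. Here the
    # "state" is just a counter of completed updates.
    new_state = state + length
    summary = {"updates_done": new_state}
    return new_state, summary


state = 0
checkpoints = []
summaries = []
# 195 updates is what TOTAL_TIMESTEPS=100000, NUM_STEPS=128, NUM_ENVS=4 gives.
for length in chunk_lengths(num_updates=195, checkpoint_every=50):
    state, summary = run_chunk(state, length)
    checkpoints.append(state)   # a real script would write the state to disk here
    summaries.append(summary)   # only one small summary is kept per chunk

print(checkpoints)
```

Because only a summary leaves each chunk, the full per-step metrics never have to be held for the whole run, which may help with the memory errors. A shorter final chunk would likely trigger one extra compilation if the chunk length is a static argument.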

Multi-GPU support

I was wondering if there are any plans to release multi-GPU training code?
Naively pmapping and using DDPPO does not seem to scale well, as the GPUs remain idle while syncing the gradients.

Different results when running PPO with the same seed multiple times?

Hey all! I'm trying to track down a seeming reproducibility issue I'm having with the PPO implementation after I added some simple WandB logging. I ran the same code 7 times, and 4 of the times the results are identical. However, 3 of the times the results differ:

[image]

Would anyone have any ideas as to why this might be happening?

Here's my slightly modified code:

from typing import NamedTuple, Sequence

import distrax
import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from purejaxrl.wrappers import FlattenObservationWrapper, LogWrapper
import wandb

from jax import config
#config.update("jax_disable_jit", True)

class ActorCritic(nn.Module):
    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )

        return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray


def make_train(config):
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).n, activation=config["ACTIVATION"]
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, time_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0, None)
                )(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, time_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, time_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                # Batching and Shuffling
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )
                # Mini-batch Updates
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            # Updating Training State and Metrics:
            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]
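            # Steps collected per update (128 * 4 = 512 with the config below).
            time_state["timesteps"] = (
                time_state["timesteps"] + config["NUM_STEPS"] * config["NUM_ENVS"]
            )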
            metric["time"] = time_state["timesteps"]
            metric["loss"] = loss_info

            # Debugging mode
            if config.get("DEBUG"):

                def callback(info):
                    return_values = info["returned_episode_returns"][
                        info["returned_episode"]
                    ]
                    timesteps = (
                        info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    )
                    for t in range(len(timesteps)):
                        print(
                            f"global step={timesteps[t]}, episodic return={return_values[t]}"
                        )
                        
                    data_log = {
                        "losses/value_loss": info["loss"][1][0].mean().item(),
                        "losses/policy_loss": info["loss"][1][1].mean().item(),
                        "losses/entropy": info["loss"][1][2].mean().item(),
                        "losses/total_loss": info["loss"][0].mean().item(),
                    }
                    if return_values.size > 0:
                        data_log["misc/episodic_return"] = return_values.mean().item()
                    wandb.log(data_log, step=info["time"])

                jax.debug.callback(callback, metric)

            runner_state = (train_state, time_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        time_state = {"timesteps": jnp.array(0)}
        runner_state = (train_state, time_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train


if __name__ == "__main__":
    config = {
        "LR": 2.5e-4,
        "NUM_ENVS": 4,
        "NUM_STEPS": 128,
        "TOTAL_TIMESTEPS": 100000,
        "UPDATE_EPOCHS": 4,
        "NUM_MINIBATCHES": 4,
        "GAMMA": 0.99,
        "GAE_LAMBDA": 0.95,
        "CLIP_EPS": 0.2,
        "ENT_COEF": 0.01,
        "VF_COEF": 0.5,
        "MAX_GRAD_NORM": 0.5,
        "ACTIVATION": "tanh",
        "ENV_NAME": "CartPole-v1",
        "ANNEAL_LR": True,
        "DEBUG": True,
    }
    run_name = "test"
    wandb_id = wandb.util.generate_id()
    run_id = f"{run_name}_{wandb_id}"
    
    wandb.init(
        id=run_id,
        project="test5",
        name=run_name,
        mode="online",
    )
    
    rng = jax.random.PRNGKey(30)
    train_jit = jax.jit(make_train(config))
    out = train_jit(rng)
