luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations

License: Apache License 2.0
DreamerV3 is a SOTA model-based RL (MBRL) algorithm, and it's already written in JAX. I want to gauge whether anyone here is interested in making its environment interactions JAX-native as well: if you have a pure-JAX environment, can we modify the current DreamerV3 code to get a blazing-fast setup similar to the other algorithms here? I'm just warming up to JAX, so I don't have much insight into it yet.
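To make concrete what I mean by environment interactions in JAX, here's a minimal sketch (my own toy, nothing from DreamerV3) where the whole rollout lives inside jit via gymnax and lax.scan, the way the algorithms in this repo do it:

import jax
import jax.numpy as jnp
import gymnax

env, env_params = gymnax.make("CartPole-v1")

def rollout(rng, num_steps=100):
    # Reset once, then step the env entirely inside lax.scan: no Python loop.
    rng, reset_rng = jax.random.split(rng)
    obs, state = env.reset(reset_rng, env_params)

    def step(carry, _):
        rng, obs, state = carry
        rng, act_rng, step_rng = jax.random.split(rng, 3)
        action = env.action_space(env_params).sample(act_rng)  # random policy
        obs, state, reward, done, _ = env.step(step_rng, state, action, env_params)
        return (rng, obs, state), reward

    _, rewards = jax.lax.scan(step, (rng, obs, state), None, num_steps)
    return rewards

rewards = jax.jit(rollout)(jax.random.PRNGKey(0))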
There are strange gaps in speed and performance between JAX v0.4.8 and JAX v0.4.13. Currently investigating. I would recommend sticking with <=0.4.8 for now.
Hi!
Thanks for releasing this codebase. It's so awesome to see how much faster we can iterate now with tools like this :)
I'm just wondering whether you've tried any environments that use mocaps with weld equality constraints. I'm trying to get Meta-World running with Brax and am curious whether there's a way around this issue. Thanks!
Hi! I noticed that the linear learning-rate schedule is actually not linear when the number of steps is small. I doubt it makes much difference to the final results, but I thought I'd show it anyway. Maybe I made a mistake somewhere?
import matplotlib.pyplot as plt

total_timesteps = 10_000
num_minibatches = 32
num_steps = 128
num_envs = 4
num_epochs = 10
lr = 0.1
num_updates = total_timesteps // num_steps // num_envs

def linear_schedule_purejaxrl(count):
    frac = 1.0 - (count // (num_minibatches * num_epochs)) / num_updates
    return lr * frac

def linear_schedule(count):
    frac = 1.0 - count / (num_epochs * num_minibatches * num_updates)
    return lr * frac

total_gradient_updates = (num_minibatches * num_epochs) * num_updates
plt.plot([linear_schedule(i) for i in range(1, total_gradient_updates + 1)], label="linear")
plt.plot([linear_schedule_purejaxrl(i) for i in range(1, total_gradient_updates + 1)], label="purejaxrl")
plt.legend()
plt.xlabel("Update")
plt.ylabel("Learning Rate")
plt.show()
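Digging a little further, I think I see what's happening (my reading, so correct me if I'm wrong): the purejaxrl schedule floors count to the outer-update index, so the learning rate is held constant across all num_minibatches * num_epochs gradient steps of one update and decays in a staircase. That staircase only looks linear when num_updates is large. A toy with made-up numbers:

# With 2 updates and 4 gradient steps per update, each learning rate
# is repeated 4 times before stepping down.
steps_per_update, num_updates, lr = 4, 2, 0.1
sched = [lr * (1.0 - (c // steps_per_update) / num_updates)
         for c in range(steps_per_update * num_updates)]
print(sched)  # [0.1, 0.1, 0.1, 0.1, 0.05, 0.05, 0.05, 0.05]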
Hi,

I have just read the blog post and think this is really cool work. I'm guessing that a full academic version of the work is coming out soon. I have a couple of questions about Figure 5. Thanks for any information!
Tried running the Colab example to look at the speed, but it seems that Colab doesn't come with distrax, brax, or gymnax installed:
https://colab.research.google.com/github/luchris429/purejaxrl/blob/main/examples/example_1.ipynb
ModuleNotFoundError Traceback (most recent call last)
<ipython-input-1-91b777b64c88> in <cell line: 9>()
7 from typing import Sequence, NamedTuple, Any
8 from flax.training.train_state import TrainState
----> 9 import distrax
10 import gymnax
11 from gymnax.wrappers.purerl import LogWrapper, FlattenObservationWrapper
ModuleNotFoundError: No module named 'distrax'
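For anyone else hitting this, adding an install cell at the top of the notebook fixes it for me (package names guessed from the imports above):

!pip install distrax gymnax brax flax optax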
I know this sounds ridiculous, but I've spent ages trying to implement checkpoint saving in your example/walkthrough training code and have been getting nowhere.

Similarly (as it's something one does every n steps or every epoch), I've been trying to reduce the frequency of metric collection, as it has been giving me VRAM errors on my lowly NVIDIA 3080.

Any advice / solutions would be very gratefully received.
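In case it helps frame the question, this is the rough pattern I've been attempting (my own sketch, assuming orbax-checkpoint and the walkthrough's make_train/config; ordinary file I/O can't run inside the jitted train function, so I save after it returns):

import jax
import orbax.checkpoint as ocp

out = jax.jit(make_train(config))(jax.random.PRNGKey(42))
train_state = out["runner_state"][0]  # TrainState is the first element

# Save the final parameters outside of jit, where file I/O is allowed.
# (Fails if the directory already exists; pass force=True to overwrite.)
ocp.PyTreeCheckpointer().save("/tmp/purejaxrl_ckpt", train_state.params)

For the metric frequency, one approach that should reduce memory is slicing traj_batch.info (e.g. keeping every k-th step) before storing it in metric, so the outer scan stacks less data.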
It would be great to support recurrent networks in PPO.
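To make the request concrete: the main addition is a hidden state threaded through the rollout and reused during the update. A minimal flax sketch of just the recurrent cell (my illustration, not repo code):

import jax
import jax.numpy as jnp
import flax.linen as nn

# The carry is the recurrent hidden state that a recurrent PPO would
# thread through lax.scan alongside the observations and done flags.
cell = nn.GRUCell(features=64)
x = jnp.ones((8, 4))  # (num_envs, obs_dim)
carry = cell.initialize_carry(jax.random.PRNGKey(0), x.shape)
params = cell.init(jax.random.PRNGKey(1), carry, x)
carry, out = cell.apply(params, carry, x)  # new hidden state + features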
I was wondering if there are any plans to release multi-gpu training code?
Naively pmapping and using DDPPO does not seem to scale well, as the GPUs remain idle while syncing the gradients.
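For reference, this is the naive data-parallel pattern I mean, boiled down to a runnable toy (my own sketch, not repo code): each device computes gradients on its own shard and they are averaged with lax.pmean inside the pmapped step:

import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    return jnp.mean((x @ params - y) ** 2)

def update(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    # All-reduce: every device blocks here until gradients are averaged.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return params - 1e-2 * grads

n_dev = jax.local_device_count()
params = jax.device_put_replicated(jnp.zeros(3), jax.local_devices())
x = jnp.ones((n_dev, 8, 3))  # one shard of data per device
y = jnp.ones((n_dev, 8))
params = jax.pmap(update, axis_name="devices")(params, x, y)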
I noticed that the critic loss implemented in PPO (https://github.com/luchris429/purejaxrl/blob/main/purejaxrl/ppo.py#L179) is quite different from the traditional TD error and looks more like the style of PPO's actor loss. Could you please point me to a reference? If there is no such reference, is there a reason for doing it this way?
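For reference, here is my paraphrase of what that line computes (so please correct me if I've misread it): a clipped value loss that mirrors the actor's ratio clipping, as popularized by OpenAI's baselines PPO2 implementation, rather than a plain TD error:

import jax.numpy as jnp

def clipped_value_loss(value, value_old, targets, clip_eps):
    # Take the worse of the clipped and unclipped squared errors, so the
    # value function cannot move too far from its old predictions per update.
    value_clipped = value_old + (value - value_old).clip(-clip_eps, clip_eps)
    return 0.5 * jnp.maximum(
        jnp.square(value - targets), jnp.square(value_clipped - targets)
    ).mean()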
Hey all! I'm trying to track down an apparent reproducibility issue with the PPO implementation after adding some simple WandB logging. I ran the same code 7 times; in 4 of the runs the results are identical, but in the other 3 they differ:
Would anyone have any ideas as to why this might be happening?
Here's my slightly modified code:
from typing import NamedTuple, Sequence

import distrax
import flax.linen as nn
import gymnax
import jax
import jax.numpy as jnp
import numpy as np
import optax
import wandb
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from purejaxrl.wrappers import FlattenObservationWrapper, LogWrapper

# Uncomment to disable jit for step-through debugging:
# from jax import config as jax_config
# jax_config.update("jax_disable_jit", True)


class ActorCritic(nn.Module):
    action_dim: Sequence[int]
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(
            critic
        )
        return pi, jnp.squeeze(critic, axis=-1)


class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray


def make_train(config):
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = gymnax.make(config["ENV_NAME"])
    env = FlattenObservationWrapper(env)
    env = LogWrapper(env)

    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac

    def train(rng):
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).n, activation=config["ACTIVATION"]
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(config["LR"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, time_state, env_state, last_obs, rng = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = jax.vmap(
                    env.step, in_axes=(0, 0, 0, None)
                )(rng_step, env_state, action, env_params)
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, time_state, env_state, obsv, rng)
                return runner_state, transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step, runner_state, None, config["NUM_STEPS"]
            )

            # CALCULATE ADVANTAGE
            train_state, time_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)

            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae

                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minbatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = (
                            0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()
                        )

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["CLIP_EPS"],
                                1.0 + config["CLIP_EPS"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(
                        train_state.params, traj_batch, advantages, targets
                    )
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)

                # Batching and Shuffling
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_STEPS"] * config["NUM_ENVS"]
                ), "batch size must be equal to number of steps * number of envs"
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch
                )

                # Mini-batch Updates
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                train_state, total_loss = jax.lax.scan(
                    _update_minbatch, train_state, minibatches
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss

            # Updating Training State and Metrics:
            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch, update_state, None, config["UPDATE_EPOCHS"]
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]
            time_state["timesteps"] = (
                time_state["timesteps"] + config["NUM_STEPS"] * config["NUM_ENVS"]
            )  # env steps per update: 128 * 4 = 512 with this config
            metric["time"] = time_state["timesteps"]
            metric["loss"] = loss_info

            # Debugging mode
            if config.get("DEBUG"):

                def callback(info):
                    return_values = info["returned_episode_returns"][
                        info["returned_episode"]
                    ]
                    timesteps = (
                        info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    )
                    for t in range(len(timesteps)):
                        print(
                            f"global step={timesteps[t]}, episodic return={return_values[t]}"
                        )
                    data_log = {
                        "losses/value_loss": info["loss"][1][0].mean().item(),
                        "losses/policy_loss": info["loss"][1][1].mean().item(),
                        "losses/entropy": info["loss"][1][2].mean().item(),
                        "losses/total_loss": info["loss"][0].mean().item(),
                    }
                    if return_values.size > 0:
                        data_log["misc/episodic_return"] = return_values.mean().item()
                    # wandb expects a plain int step; the callback receives numpy values.
                    wandb.log(data_log, step=int(info["time"]))

                jax.debug.callback(callback, metric)

            runner_state = (train_state, time_state, env_state, last_obs, rng)
            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        time_state = {"timesteps": jnp.array(0)}
        runner_state = (train_state, time_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step, runner_state, None, config["NUM_UPDATES"]
        )
        return {"runner_state": runner_state, "metrics": metric}

    return train


if __name__ == "__main__":
    config = {
        "LR": 2.5e-4,
        "NUM_ENVS": 4,
        "NUM_STEPS": 128,
        "TOTAL_TIMESTEPS": 100000,
        "UPDATE_EPOCHS": 4,
        "NUM_MINIBATCHES": 4,
        "GAMMA": 0.99,
        "GAE_LAMBDA": 0.95,
        "CLIP_EPS": 0.2,
        "ENT_COEF": 0.01,
        "VF_COEF": 0.5,
        "MAX_GRAD_NORM": 0.5,
        "ACTIVATION": "tanh",
        "ENV_NAME": "CartPole-v1",
        "ANNEAL_LR": True,
        "DEBUG": True,
    }
    run_name = "test"
    wandb_id = wandb.util.generate_id()
    run_id = f"{run_name}_{wandb_id}"
    wandb.init(
        id=run_id,
        project="test5",
        name=run_name,
        mode="online",
    )
    rng = jax.random.PRNGKey(30)
    train_jit = jax.jit(make_train(config))
    out = train_jit(rng)
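One hypothesis I'm testing (an assumption on my part, not a diagnosis): on GPU, some XLA reductions are nondeterministic, so runs can diverge even with identical seeds. Forcing deterministic ops is one way to check whether that's the cause, at the price of some speed:

import os

# Must be set before jax is imported; flag name as I understand XLA's options.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"

import jax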