lucidrains / anymal-belief-state-encoder-decoder-pytorch


Implementation of the Belief State Encoder / Decoder in the new breakthrough robotics paper from ETH Zürich

License: MIT License


anymal-belief-state-encoder-decoder-pytorch's Introduction

Belief State Encoder / Decoder (Anymal) - Pytorch

Implementation of the Belief State Encoder / Decoder in the new breakthrough robotics paper from ETH Zürich.

This paper is important because its learned approach appears to produce a policy that rivals Boston Dynamics' handcrafted algorithms (quadrupedal Spot).

The results speak for themselves in their video demonstration.

Install

$ pip install anymal-belief-state-encoder-decoder-pytorch

Usage

Teacher

import torch
from anymal_belief_state_encoder_decoder_pytorch import Teacher

teacher = Teacher(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    privileged_dim = 50
)

proprio = torch.randn(1, 133)
extero = torch.randn(1, 4, 52)
privileged = torch.randn(1, 50)

action_logits, values = teacher(proprio, extero, privileged, return_values = True) # action_logits: (1, 10)
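
The teacher returns unnormalized action logits. As a rough illustration only, here is how logits of that shape could be turned into a categorical distribution and sampled, written in plain Python so the arithmetic stays visible. `softmax` and `sample_action` are hypothetical helpers for this sketch and are not part of the library.

```python
import math
import random

def softmax(logits):
    """Convert unnormalized logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_action(logits, rng=random):
    """Draw one action index from the categorical distribution over logits."""
    probs = softmax(logits)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

# 10 made-up action logits, matching num_actions = 10 above
logits = [0.0] * 9 + [5.0]
probs = softmax(logits)
action = sample_action(logits, random.Random(0))
```

With real tensors, the same idea is what `torch.distributions.Categorical(logits = ...)` followed by `.sample()` provides.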

Student

import torch
from anymal_belief_state_encoder_decoder_pytorch import Student

student = Student(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    gru_num_layers = 2,
    gru_hidden_size = 50
)

proprio = torch.randn(1, 133)
extero = torch.randn(1, 4, 52)

action_logits, hiddens = student(proprio, extero) # (1, 10), (2, 1, 50)
action_logits, hiddens = student(proprio, extero, hiddens) # (1, 10), (2, 1, 50)
action_logits, hiddens = student(proprio, extero, hiddens) # (1, 10), (2, 1, 50)

# hiddens are in the shape (num gru layers, batch size, gru hidden dimension)
# train with truncated bptt
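
A minimal sketch of the bookkeeping behind truncated BPTT, assuming fixed-size windows: the rollout is cut into chunks, gradients flow only within a chunk, and the hidden state crosses each boundary as a plain value. `tbptt_windows` is a hypothetical helper, and the window size is an arbitrary choice for illustration.

```python
def tbptt_windows(num_steps, window):
    """Yield (start, end) index pairs that cut a rollout into fixed-size windows."""
    if window <= 0:
        raise ValueError("window must be positive")
    for start in range(0, num_steps, window):
        yield start, min(start + window, num_steps)

# A 10-step rollout with windows of 4 steps; the last window is shorter.
windows = list(tbptt_windows(num_steps=10, window=4))

# Within each window, the loss is backpropagated through every step.
# At each boundary, the hidden state is carried forward without its graph,
# e.g. in PyTorch by calling hiddens.detach() before the next window.
```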

Full Anymal (which contains both Teacher and Student)

import torch
from anymal_belief_state_encoder_decoder_pytorch import Anymal

anymal = Anymal(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    privileged_dim = 50,
    recon_loss_weight = 0.5
)

# mock data

proprio = torch.randn(1, 133)
extero = torch.randn(1, 4, 52)
privileged = torch.randn(1, 50)

# first train teacher

teacher_action_logits = anymal.forward_teacher(proprio, extero, privileged)

# teacher is trained with privileged information in simulation with domain randomization

# after the teacher reaches satisfactory performance, init the student with the teacher weights, excluding the teacher's privileged information encoder (which the student does not have)

anymal.init_student_with_teacher()

# then train the student on the proprioception and noised exteroception, forcing it to reconstruct the privileged information that the teacher had access to (as well as learning to denoise the exteroception) - there is also a behavior loss between the policy logits of the teacher and those of the student

loss, hiddens = anymal(proprio, extero, privileged)
loss.backward()

# finally, you can deploy the student to the real world, zero-shot

anymal.eval()
dist, hiddens = anymal.forward_student(proprio, extero, return_action_categorical_dist = True)
action = dist.sample()
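
The student loss described in the comments above combines a behavior loss with a reconstruction loss weighted by `recon_loss_weight`. The sketch below shows one way such a sum could look. Using mean squared error for both terms is an assumption made for illustration, and `mse` and `student_loss` are hypothetical names, not the library's internals.

```python
def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    if len(pred) != len(target):
        raise ValueError("length mismatch")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def student_loss(student_logits, teacher_logits, recon, recon_target,
                 recon_loss_weight=0.5):
    """Behavior loss plus weighted reconstruction loss (assumed form)."""
    behavior = mse(student_logits, teacher_logits)  # match the teacher's policy
    reconstruction = mse(recon, recon_target)       # recover privileged / clean inputs
    return behavior + recon_loss_weight * reconstruction

loss = student_loss(
    student_logits=[1.0, 0.0],
    teacher_logits=[0.0, 0.0],
    recon=[2.0, 2.0],
    recon_target=[0.0, 2.0],
)
# behavior = 0.5, reconstruction = 2.0, so loss = 0.5 + 0.5 * 2.0 = 1.5
```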

PPO training of the Teacher (using a mock environment, which needs to be replaced with an environment wrapper around a simulator)

import torch
from anymal_belief_state_encoder_decoder_pytorch import Anymal, PPO
from anymal_belief_state_encoder_decoder_pytorch.ppo import MockEnv

anymal = Anymal(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    privileged_dim = 50,
    recon_loss_weight = 0.5
)

mock_env = MockEnv(
    proprio_dim = 133,
    extero_dim = 52,
    privileged_dim = 50
)

ppo = PPO(
    env = mock_env,
    anymal = anymal,
    epochs = 10,
    lr = 3e-4,
    eps_clip = 0.2,
    beta_s = 0.01,
    value_clip = 0.4,
    max_timesteps = 10000,
    update_timesteps = 5000,
)

# train for 10 episodes

for _ in range(10):
    ppo()

# save the weights of the teacher for student training

torch.save(anymal.state_dict(), './anymal-with-trained-teacher.pt')
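
For readers unfamiliar with the `eps_clip` setting, here is a minimal single-sample sketch of the standard PPO clipped surrogate objective. It is not taken from this library's `PPO` class. Treating `beta_s` as the weight of an entropy bonus is an assumption, and both function names are hypothetical.

```python
import math

def clipped_surrogate(logp_new, logp_old, advantage, eps_clip=0.2):
    """PPO clipped surrogate objective for a single (state, action) sample."""
    ratio = math.exp(logp_new - logp_old)  # pi_new(a|s) / pi_old(a|s)
    clipped_ratio = max(1.0 - eps_clip, min(1.0 + eps_clip, ratio))
    # Take the more pessimistic of the unclipped and clipped terms.
    return min(ratio * advantage, clipped_ratio * advantage)

def policy_objective(logp_new, logp_old, advantage, entropy,
                     eps_clip=0.2, beta_s=0.01):
    """Clipped surrogate plus an entropy bonus (assumed role of beta_s); maximized."""
    return clipped_surrogate(logp_new, logp_old, advantage, eps_clip) + beta_s * entropy

# Ratio of 1.5 with a positive advantage is clipped to 1.2.
gain = clipped_surrogate(math.log(1.5), 0.0, advantage=1.0)
# With a negative advantage the unclipped (worse) term is kept.
penalty = clipped_surrogate(math.log(1.5), 0.0, advantage=-1.0)
```

The clip removes the incentive to move the policy far from the one that collected the data, which is why updates are usually run on freshly gathered rollouts.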

To train the student

import torch
from anymal_belief_state_encoder_decoder_pytorch import Anymal
from anymal_belief_state_encoder_decoder_pytorch.trainer import StudentTrainer
from anymal_belief_state_encoder_decoder_pytorch.ppo import MockEnv

anymal = Anymal(
    num_actions = 10,
    num_legs = 4,
    extero_dim = 52,
    proprio_dim = 133,
    privileged_dim = 50,
    recon_loss_weight = 0.5
)

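# first init student with teacher weights, at the very beginning
# if not resuming training
# (assumes the checkpoint saved by the PPO example above)

anymal.load_state_dict(torch.load('./anymal-with-trained-teacher.pt'))
anymal.init_student_with_teacher()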

mock_env = MockEnv(
    proprio_dim = 133,
    extero_dim = 52,
    privileged_dim = 50
)

trainer = StudentTrainer(
    anymal = anymal,
    env = mock_env
)

# for 100 episodes

for _ in range(100):
    trainer()

... You've beaten Boston Dynamics and its team of highly paid control engineers!

But you probably haven't beaten a real quadrupedal "anymal" just yet :)

Todo

  • finish belief state decoder
  • wrapper class that instantiates both teacher and student, handle student forward pass with reconstruction loss + behavioral loss
  • handle noising of exteroception for student
  • add basic PPO logic for teacher
  • add basic student training loop with mock environment
  • make sure all hyperparameters for teacher PPO training + teacher / student distillation are in accordance with appendix
  • noise scheduler for student (curriculum factor that goes from 0 to 1 from epochs 1 to 100)
  • fix student training, it does not look correct
  • make sure tbptt is setup correctly
  • add reward crafting as in paper
  • play around with DeepMind's MuJoCo
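
The noise scheduler item above (a curriculum factor going from 0 to 1 over epochs 1 to 100) could be sketched as follows. The linear ramp, the Gaussian noise model, and the `max_std` value are assumptions for illustration, and neither function is part of the library.

```python
import random

def curriculum_factor(epoch, first_epoch=1, last_epoch=100):
    """Linear ramp: 0.0 at first_epoch, 1.0 at last_epoch, clamped outside."""
    t = (epoch - first_epoch) / (last_epoch - first_epoch)
    return max(0.0, min(1.0, t))

def noise_exteroception(extero, epoch, max_std=0.1, rng=random):
    """Add Gaussian noise whose standard deviation scales with the curriculum factor."""
    std = max_std * curriculum_factor(epoch)
    return [x + rng.gauss(0.0, std) for x in extero]

# At epoch 1 the factor is 0, so the readings pass through unchanged.
clean = noise_exteroception([1.0, 2.0], epoch=1)
# By epoch 100 the noise reaches its full (hypothetical) scale.
noisy = noise_exteroception([1.0, 2.0], epoch=100, rng=random.Random(0))
```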

Citations

@article{2022,
  title     = {Learning robust perceptive locomotion for quadrupedal robots in the wild},
  url       = {http://dx.doi.org/10.1126/scirobotics.abk2822},
  journal   = {Science Robotics},
  publisher = {American Association for the Advancement of Science (AAAS)},
  author    = {Miki, Takahiro and Lee, Joonho and Hwangbo, Jemin and Wellhausen, Lorenz and Koltun, Vladlen and Hutter, Marco},
  year      = {2022},
  month     = {Jan}
}

anymal-belief-state-encoder-decoder-pytorch's Issues

student training

i'm fairly sure i got the student network correct, as well as the teacher -> student distillation code

but not confident about how the rollouts are done (and the subsequent learning and truncated BPTT)

Off-policy PPO?

Doesn't PPO, at least the vanilla variant, only work on-policy? That is, from recent data, not an experience replay?
