[ICML 2023 Oral] Official environments and implementations for "Subequivariant Graph Reinforcement Learning in 3D Environments"

Home Page: https://alpc91.github.io/SGRL/

License: MIT License

Python 99.90% Shell 0.10%
deep-learning generalization locomotion modularity reinforcement-learning equivariance graph transformer decentralized-control geometric-deep-learning modular-control multi-task-learning symmetry zero-shot

sgrl's Introduction

Subequivariant Graph Reinforcement Learning in 3D Environments

ICML 2023 Oral

Runfa Chen (1), Jiaqi Han (1), Fuchun Sun (1, 2), Wenbing Huang (3, 4)

(1) Department of Computer Science and Technology, Institute for AI, BNRist Center, Tsinghua University; (2) THU-Bosch JCML Center; (3) Gaoling School of Artificial Intelligence, Renmin University of China; (4) Beijing Key Laboratory of Big Data Management and Analysis Methods

This is a PyTorch implementation of our Subequivariant Graph Reinforcement Learning. In this work, we introduce 3D-SGRL, a new morphology-agnostic RL benchmark that extends the widely adopted 2D-planar setting to 3D, giving agents a significantly larger exploration space with arbitrary initial locations and target directions. To learn a policy in this massive search space, we design SET, a novel model that preserves geometric symmetry by construction. Experimental results strongly support the necessity of encoding symmetry into the policy network and show its wide applicability to learning to navigate in various 3D environments.

If you find this work useful in your research, please cite using the following BibTeX:

@inproceedings{chen2023sgrl,
    title = {Subequivariant Graph Reinforcement Learning in 3D Environments},
    author = {Chen, Runfa and Han, Jiaqi and Sun, Fuchun and Huang, Wenbing},
    booktitle = {International Conference on Machine Learning},
    year = {2023},
    organization = {PMLR}
}

Setup

Requirements

Installing Dependencies

pip install --upgrade pip
pip install -r requirements.txt

Running Code

Flags and Parameters Description

--env_name <STRING> The name of the experiment project folder and the project name in wandb
--morphologies <STRING> Find existing environments matching each keyword for training (e.g., walker, hopper, humanoid, cheetah, whh, cwhh)
--expID <STRING> Experiment name used to create the saving directory
--exp_path <STRING> The directory path where the experimental results are saved
--config_path <STRING> The path to the configuration file
--gpu <INT> The GPU device ID (e.g., 0, 1, 2, 3)
--custom_xml <PATH> Path to a custom XML file for training the morphology-agnostic policy.
    When <PATH> is a file, train with that XML morphology only.
    When <PATH> is a directory, train on all XML morphologies found in the directory.
--actor_type <STRING> Type of the actor to use (e.g., smp, swat, set, mlp)
--critic_type <STRING> Type of the critic to use (e.g., smp, swat, set, mlp)
--seed <INT> (Optional) Seed for Gym, PyTorch and NumPy
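
The file-or-directory behaviour described for --custom_xml can be sketched as below. This is a minimal illustration and not the repository's code: the helper names `collect_xml_files` and `demo` are hypothetical, and matching files by the `.xml` suffix is an assumption.

```python
import os
import tempfile


def collect_xml_files(custom_xml):
    """Return the XML morphology files selected by a --custom_xml style path.

    Hypothetical helper for illustration; not taken from the repository.
    """
    if os.path.isfile(custom_xml):
        # A single file: train with that morphology only.
        return [custom_xml]
    if os.path.isdir(custom_xml):
        # A directory: train on every XML morphology found in it. Sorting
        # keeps the order reproducible across runs. Filtering on the ".xml"
        # suffix is an assumption of this sketch.
        return sorted(
            os.path.join(custom_xml, name)
            for name in os.listdir(custom_xml)
            if name.endswith(".xml")
        )
    raise FileNotFoundError(f"--custom_xml path does not exist: {custom_xml}")


def demo():
    """Build a throwaway directory and resolve it both ways."""
    with tempfile.TemporaryDirectory() as root:
        for name in ("3d_hopper_3_shin.xml", "3d_hopper_5_full.xml", "notes.txt"):
            with open(os.path.join(root, name), "w") as handle:
                handle.write("<mujoco/>")
        from_dir = [os.path.basename(p) for p in collect_xml_files(root)]
        from_file = [
            os.path.basename(p)
            for p in collect_xml_files(os.path.join(root, "3d_hopper_5_full.xml"))
        ]
    return from_dir, from_file


if __name__ == "__main__":
    print(demo())
```

Passing a directory picks up both hopper files and skips `notes.txt`, while passing a single file returns just that file.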

Train with existing environment

  • Train SET on 3D_Hopper++ (3 variants of hopper):
cd src/
bash start.sh

3D-SGRL Environments

3D Hopper

3d_hopper_3_shin

3d_hopper_4_lower_shin

3d_hopper_5_full

3D Walker

3d_walker_2_right_leg_left_knee

3d_walker_3_left_leg_right_foot

3d_walker_4_right_knee_left_foot

3d_walker_5_foot

3d_walker_5_left_knee

3d_walker_7_full

3d_walker_3_left_knee_right_knee

3d_walker_6_right_foot

3D Humanoid

3d_humanoid_7_left_arm

3d_humanoid_7_lower_arms

3d_humanoid_7_right_arm

3d_humanoid_7_right_leg

3d_humanoid_8_left_knee

3d_humanoid_9_full

3d_humanoid_7_left_leg

3d_humanoid_8_right_knee

3D Cheetah

3d_cheetah_10_tail_leftbleg

3d_cheetah_11_leftfleg

3d_cheetah_11_tail_rightfknee

3d_cheetah_12_rightbknee

3d_cheetah_12_tail_leftbfoot

3d_cheetah_13_rightffoot

3d_cheetah_13_tail

3d_cheetah_14_full

3d_cheetah_11_leftbkneen_rightffoot

3d_cheetah_12_tail_leftffoot

For the results reported in the paper, the following agents are in the held-out set for the corresponding experiments:

  • 3D_Walker++: 3d_walker_3_left_knee_right_knee, 3d_walker_6_right_foot
  • 3D_Humanoid++: 3d_humanoid_7_left_leg, 3d_humanoid_8_right_knee
  • 3D_Cheetah++: 3d_cheetah_11_leftbkneen_rightffoot, 3d_cheetah_12_tail_leftffoot

All other agents in the corresponding experiments are used for training.
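
The train/held-out split described above can be written down as a small sketch. The agent names are copied from the lists in this README; the `HELD_OUT` and `WALKERS` constants and the `split_agents` helper are hypothetical names introduced for illustration and are not part of the repository.

```python
# Held-out agents per experiment, as listed in this README.
HELD_OUT = {
    "3D_Walker++": ["3d_walker_3_left_knee_right_knee", "3d_walker_6_right_foot"],
    "3D_Humanoid++": ["3d_humanoid_7_left_leg", "3d_humanoid_8_right_knee"],
    "3D_Cheetah++": [
        "3d_cheetah_11_leftbkneen_rightffoot",
        "3d_cheetah_12_tail_leftffoot",
    ],
}

# All 3D Walker agents, in the order they appear in this README.
WALKERS = [
    "3d_walker_2_right_leg_left_knee",
    "3d_walker_3_left_leg_right_foot",
    "3d_walker_4_right_knee_left_foot",
    "3d_walker_5_foot",
    "3d_walker_5_left_knee",
    "3d_walker_7_full",
    "3d_walker_3_left_knee_right_knee",
    "3d_walker_6_right_foot",
]


def split_agents(agents, held_out):
    """Split agent names into (training, held-out) lists, preserving order."""
    held = set(held_out)
    train = [name for name in agents if name not in held]
    test = [name for name in agents if name in held]
    return train, test


if __name__ == "__main__":
    train, test = split_agents(WALKERS, HELD_OUT["3D_Walker++"])
    print(len(train), "training agents:", train)
    print(len(test), "held-out agents:", test)
```

For 3D_Walker++ this leaves six walkers for training and two for zero-shot evaluation.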

Acknowledgement

The RL code is based on this open-source implementation, and the morphology-agnostic implementation is built on top of the SMP (Huang et al., ICML 2020), Amorpheus (Kurin et al., ICLR 2021) and SWAT (Hong et al., ICLR 2022) repositories.

sgrl's Issues

mujoco_py.builder.MujocoException: Unknown warning type Time = 0.0000.Check for NaN in simulation.

Hello:
I get the following errors when running "bash start.sh" or "bash start_humanoid.sh":
  File "cymj.pyx", line 156, in mujoco_py.cymj.wrap_mujoco_warning.__exit__
  File "cymj.pyx", line 77, in mujoco_py.cymj.c_warning_callback
  File "/home/amax/anaconda3/envs/sgrl/lib/python3.8/site-packages/mujoco_py/builder.py", line 368, in user_warning_raise_exception
    raise MujocoException(warn + 'Check for NaN in simulation.')
  File "/home/amax/anaconda3/envs/sgrl/lib/python3.8/site-packages/gym/envs/mujoco/mujoco_env.py", line 125, in do_simulation
    self.sim.step()
  File "/home/amax/anaconda3/envs/sgrl/lib/python3.8/site-packages/gym/envs/mujoco/mujoco_env.py", line 125, in do_simulation
    self.sim.step()
mujoco_py.builder.MujocoException: Unknown warning type Time = 0.0000.Check for NaN in simulation.
  File "cymj.pyx", line 77, in mujoco_py.cymj.c_warning_callback
  File "/home/amax/anaconda3/envs/sgrl/lib/python3.8/site-packages/mujoco_py/builder.py", line 368, in user_warning_raise_exception
    raise MujocoException(warn + 'Check for NaN in simulation.')
  File "mjsim.pyx", line 126, in mujoco_py.cymj.MjSim.step
  File "mjsim.pyx", line 126, in mujoco_py.cymj.MjSim.step
  File "cymj.pyx", line 156, in mujoco_py.cymj.wrap_mujoco_warning.__exit__
mujoco_py.builder.MujocoException: Unknown warning type Time = 0.0000.Check for NaN in simulation.
  File "/home/amax/anaconda3/envs/sgrl/lib/python3.8/site-packages/mujoco_py/builder.py", line 368, in user_warning_raise_exception
    raise MujocoException(warn + 'Check for NaN in simulation.')
mujoco_py.builder.MujocoException: Unknown warning type Time = 0.0000.Check for NaN in simulation.
  File "cymj.pyx", line 156, in mujoco_py.cymj.wrap_mujoco_warning.__exit__
  File "cymj.pyx", line 77, in mujoco_py.cymj.c_warning_callback
  File "cymj.pyx", line 77, in mujoco_py.cymj.c_warning_callback
  File "/home/amax/anaconda3/envs/sgrl/lib/python3.8/site-packages/mujoco_py/builder.py", line 368, in user_warning_raise_exception
    raise MujocoException(warn + 'Check for NaN in simulation.')
mujoco_py.builder.MujocoException: Unknown warning type Time = 0.0000.Check for NaN in simulation.
  File "/home/amax/anaconda3/envs/sgrl/lib/python3.8/site-packages/mujoco_py/builder.py", line 368, in user_warning_raise_exception
    raise MujocoException(warn + 'Check for NaN in simulation.')
mujoco_py.builder.MujocoException: Unknown warning type Time = 0.0000.Check for NaN in simulation.
mujoco_py.builder.MujocoException: Unknown warning type Time = 0.0000.Check for NaN in simulation.
