shuijing725 / crowdnav_prediction_attngraph

[ICRA 2023] Intention Aware Robot Crowd Navigation with Attention-Based Interaction Graph

Home Page: https://sites.google.com/view/intention-aware-crowdnav/home

License: MIT License

Python 97.89% Shell 2.11%
attention-is-all-you-need collision-avoidance crowd-navigation deep-reinforcement-learning human-robot-interaction trajectory-prediction deep-learning reinforcement-learning


CrowdNav++

This repository contains the code for our ICRA 2023 paper "Intention Aware Robot Crowd Navigation with Attention-Based Interaction Graph". For more details, please refer to the project website and the arXiv preprint. For experiment demonstrations, please refer to the YouTube video.

[News]

  • Please check out our open-sourced sim2real tutorial here.
  • Please check out my curated paper list for robot social navigation here (under active development).

Abstract

We study the problem of safe and intention-aware robot navigation in dense and interactive crowds. Most previous reinforcement learning (RL) based methods fail to consider different types of interactions among all agents or ignore the intentions of people, which results in performance degradation. In this paper, we propose a novel recurrent graph neural network with attention mechanisms to capture heterogeneous interactions among agents through space and time. To encourage long-sighted robot behaviors, we infer the intentions of dynamic agents by predicting their future trajectories for several timesteps. The predictions are incorporated into a model-free RL framework to prevent the robot from intruding into the intended paths of other agents. We demonstrate that our method enables the robot to achieve good navigation performance and non-invasiveness in challenging crowd navigation scenarios. We successfully transfer the policy learned in simulation to a real-world TurtleBot 2i.

Setup

  1. In a conda environment or virtual environment with Python 3.x, install the required Python packages:

     pip install -r requirements.txt

  2. Install PyTorch 1.12.1 following the instructions here.

  3. Install OpenAI Baselines:

     git clone https://github.com/openai/baselines.git
     cd baselines
     pip install -e .

  4. Install the Python-RVO2 library.
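
As an optional sanity check (a minimal sketch, assuming the packages install under their usual import names; this snippet is not part of the repository), you can verify the environment from Python:

    # Quick import check after setup; the import names below are assumptions about
    # the installed packages, not modules defined in this repository.
    import torch       # PyTorch 1.12.1
    import rvo2        # Python-RVO2
    import baselines   # OpenAI Baselines
    import gym         # installed via requirements.txt

    print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())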

Overview

This repository is organized in five parts:

  • crowd_nav/ folder contains configurations and policies used in the simulator.
  • crowd_sim/ folder contains the simulation environment.
  • gst_updated/ folder contains the code for running inference of a human trajectory predictor, named Gumbel Social Transformer (GST) [2].
  • rl/ folder contains the code for the RL policy networks, wrappers for the prediction network, and the PPO algorithm.
  • trained_models/ contains some pretrained models provided by us.

Note that this repository does not include code for training the trajectory prediction network. Please refer to the GST repository [2] instead.

Run the code

Training

  • Modify the configurations.
    1. Environment configurations: Modify crowd_nav/configs/config.py. In particular (a combined sketch of these settings appears at the end of this section):

      • Choice of human trajectory predictor:
        • Set sim.predict_method = 'inferred' if a learning-based GST predictor is used [2]. Please also change pred.model_dir to the directory of a trained GST model. We provide two pretrained models here.
        • Set sim.predict_method = 'const_vel' if the constant velocity model is used.
        • Set sim.predict_method = 'truth' if the ground-truth predictor is used.
        • Set sim.predict_method = 'none' if you do not want to use future trajectories to change the observation and reward.
      • Randomization of human behaviors: If you want to randomize the ORCA humans,
        • set env.randomize_attributes to True to randomize the preferred velocity and radius of humans;
        • set humans.random_goal_changing to True to let humans randomly change goals before they arrive at their original goals.
    2. PPO and network configurations: modify arguments.py

      • env_name (must be consistent with sim.predict_method in crowd_nav/configs/config.py):
        • If you use the GST predictor, set to CrowdSimPredRealGST-v0.
        • If you use the ground truth predictor or constant velocity predictor, set to CrowdSimPred-v0.
        • If you don't want to use prediction, set to CrowdSimVarNum-v0.
      • use_self_attn: if set to True, the human-human attention network is included; otherwise there is no human-human attention.
      • use_hr_attn: if set to True, the robot-human attention network is included; otherwise there is no robot-human attention.
  • After you change the configurations, run
    python train.py 
    
  • The checkpoints and configuration files will be saved to the folder specified by output_dir in arguments.py.
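
For reference, the sketch below mirrors the settings described above as plain Python objects. It is only an illustration of how the options relate to each other; the real attribute definitions live in crowd_nav/configs/config.py and arguments.py and may be laid out differently.

    from types import SimpleNamespace

    # Hypothetical mirror of the configuration options discussed above.
    sim = SimpleNamespace(predict_method='inferred')     # 'inferred' | 'const_vel' | 'truth' | 'none'
    pred = SimpleNamespace(model_dir='path/to/trained/GST/model')  # only needed when predict_method == 'inferred'
    env = SimpleNamespace(randomize_attributes=True)      # randomize preferred velocity and radius of humans
    humans = SimpleNamespace(random_goal_changing=True)   # humans may change goals before reaching them

    # arguments.py: env_name must be consistent with sim.predict_method
    env_name_for = {
        'inferred':  'CrowdSimPredRealGST-v0',
        'const_vel': 'CrowdSimPred-v0',
        'truth':     'CrowdSimPred-v0',
        'none':      'CrowdSimVarNum-v0',
    }
    args = SimpleNamespace(
        env_name=env_name_for[sim.predict_method],
        use_self_attn=True,   # include the human-human attention network
        use_hr_attn=True,     # include the robot-human attention network
    )
    print(args.env_name)      # -> CrowdSimPredRealGST-v0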

Testing

Please modify the test arguments in lines 20-33 of test.py (do not set the argument values on the command line!), and run

python test.py 

Note that the config.py and arguments.py in the testing folder will be loaded, instead of those in the root directory.
The testing results are logged in trained_models/your_output_dir/test/ folder, and are also printed on terminal.
If you set visualize=True in test.py, you will see a rendering of each test episode.

Test pre-trained models provided by us

Method                                 | --model_dir in test.py               | --test_model in test.py
Ours without randomized humans         | trained_models/GST_predictor_no_rand | 41200.pt
ORCA without randomized humans         | trained_models/ORCA_no_rand          | 00000.pt
Social force without randomized humans | trained_models/SF_no_rand            | 00000.pt
Ours with randomized humans            | trained_models/GST_predictor_rand    | 41665.pt
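
As a hypothetical illustration (the actual variable names and layout of lines 20-33 in test.py may differ), testing the last row of the table above would amount to settings like:

    # Illustrative values only; edit the corresponding arguments inside test.py,
    # not on the command line.
    model_dir = 'trained_models/GST_predictor_rand'   # --model_dir
    test_model = '41665.pt'                           # --test_model
    visualize = True                                  # render each test episode
    save_slides = False                               # set to True to dump frames to disk (slow)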

Plot predicted future human positions

To visualize episodes with the predicted human trajectories and save the visualizations to disk, please refer to the save_slides branch.
Note that this visualization and file saving will slow down testing significantly!

  • Set save_slides=True in test.py, and all rendered frames will be saved in a subfolder inside trained_models/your_output_dir/social_eval/.

Plot the training curves

python plot.py

Here are example learning curves of our proposed network model with GST predictor.
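
If you prefer to inspect the training logs directly, the sketch below is a hedged alternative to plot.py (not the script itself): it reads the OpenAI Baselines Monitor logs written during training and plots a smoothed reward curve. The log path is a placeholder.

    import glob

    import pandas as pd
    import matplotlib.pyplot as plt

    # Each *.monitor.csv has one comment line followed by columns r (reward), l (length), t (time).
    files = glob.glob('trained_models/your_output_dir/*.monitor.csv')
    log = pd.concat(pd.read_csv(f, skiprows=1) for f in files).sort_values('t')

    log['r'].rolling(100, min_periods=1).mean().reset_index(drop=True).plot()
    plt.xlabel('episode')
    plt.ylabel('reward (100-episode moving average)')
    plt.show()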

Sim2Real

We are happy to announce that our sim2real tutorial and code are released here!
Note: this repo only serves as a reference point for the sim2real transfer of crowd navigation. Since there are many uncertainties in real-world experiments that can affect performance, we cannot guarantee that our results are reproducible in all cases.

Disclaimer

  1. We have only tested our code on Ubuntu with Python 3.6 and Python 3.8. The code may work on other operating systems or other versions of Python, but we make no guarantees.

  2. The performance of our code can vary depending on the choice of hyperparameters and random seeds (see this reddit post). Unfortunately, we do not have the time or resources for a thorough hyperparameter search. Thus, it is normal if your results are slightly worse than those claimed in the paper. To achieve the best performance, we recommend some manual hyperparameter tuning.

Citation

If you find the code or the paper useful for your research, please cite the following papers:

@inproceedings{liu2022intention,
  title={Intention Aware Robot Crowd Navigation with Attention-Based Interaction Graph},
  author={Liu, Shuijing and Chang, Peixin and Huang, Zhe and Chakraborty, Neeloy and Hong, Kaiwen and Liang, Weihang and Livingston McPherson, D. and Geng, Junyi and Driggs-Campbell, Katherine},
  booktitle={IEEE International Conference on Robotics and Automation (ICRA)},
  year={2023},
  pages={12015-12021}
}

@inproceedings{liu2020decentralized,
  title={Decentralized Structural-RNN for Robot Crowd Navigation with Deep Reinforcement Learning},
  author={Liu, Shuijing and Chang, Peixin and Liang, Weihang and Chakraborty, Neeloy and Driggs-Campbell, Katherine},
  booktitle={IEEE International Conference on Robotics and Automation (ICRA)},
  year={2021},
  pages={3517-3524}
}

Credits

Other contributors:
Peixin Chang
Zhe Huang
Neeloy Chakraborty

Part of the code is based on the following repositories:

[1] S. Liu, P. Chang, W. Liang, N. Chakraborty, and K. Driggs-Campbell, "Decentralized Structural-RNN for Robot Crowd Navigation with Deep Reinforcement Learning," in IEEE International Conference on Robotics and Automation (ICRA), 2021, pp. 3517-3524. (GitHub: https://github.com/Shuijing725/CrowdNav_DSRNN)

[2] Z. Huang, R. Li, K. Shin, and K. Driggs-Campbell. "Learning Sparse Interaction Graphs of Partially Detected Pedestrians for Trajectory Prediction," in IEEE Robotics and Automation Letters, vol. 7, no. 2, pp. 1198–1205, 2022. (Github: https://github.com/tedhuang96/gst)

Contact

If you have any questions or find any bugs, please feel free to open an issue or pull request.


crowdnav_prediction_attngraph's Issues

weight for DS-RNN (original)

Hi @Shuijing725!
I read your paper and code. It was awesome!

Anyway, I have a single question.
I want to use your old version (DS-RNN CrowdNav) within this recent version.
In your config, there is an option for robot.policy.
So I modified this part to "srnn", and I tested with an example weight from the recent version.
Did I run the original version or the recent version?

simulation for ros_turtlebot2i_env

Hello,

I have a question about ros_turtlebot2i_env.
In the code

humanStatesSub = Subscriber('/dr_spaam_detections', PoseArray) # human px, py, vi

How is this topic published in the ROS workspace?

Thanks.

GST training question

Hi @Shuijing725, I read your reference.
But there is no script describing how to create the training dataset.
Could you tell me how to create the dataset?
(In my situation, I actually use a different platform for the DRL environment.)

Values contained in /map topic do not seem meaningful

Dear researcher,

Thank you for your fantastic work. I am working on running your project. However, the messages received on /map do not seem meaningful. In a map that is almost empty, the map.data I receive is almost entirely -1. If I understand it correctly, the value 0 stands for open space without obstacles. So what is the meaning of -1 and 100?

Some test.py settings are not working

Hello, do you still remember me? I still have some issues with test.py. When I set save_slides=True, the folder trained_models/GST_predictor_rand/social_eval/41665 where the results should be saved is empty. There is also a problem with the --test_case parameter: when I set it to any other positive number, the test still runs for 500 episodes. I'm confused; I made some modifications to evaluation.py but it doesn't work. Do you have any suggestions? Maybe the uploaded files are incomplete.

Prediction horizons, intervals between prediction steps, and simulation frequency

Hi researchers,

Thank you for your elegant work. I am working on a project that requires a deeper dive into prediction models, and I am having some difficulty finding the prediction horizons, intervals between prediction steps, and intervals between simulation steps. Would you mind telling me the values or indicating which files contain these parameter settings? (I suppose if I use pretrained GST models, the prediction horizons and prediction frequency are set and cannot be changed by config files, correct?)

Thank you so much!

Test result storage

When I finished executing test.py, I could not find the saved test results in the trained_models/your_output_dir/test/ folder.
I hope you can give me some advice, thanks.

Training and compute resources

Hi, could you please provide information about what compute resources you used to train the model and how long training took?

How to extend to multiple agents?

Hello, I really like your work! I have encountered some questions and would like to ask you.
Recently, I have been trying to turn the single agent into multiple agents.

I would like these multiple agents to use the same reinforcement learning policy, with sim.predict_method = 'inferred'.

Which code would need to be modified?
Thank you.

training issue about DS-RNN

Hi @Shuijing725!
I have an issue with your code when training with srnn.
I changed robot.policy in config.py from "selfAttn_merge_srnn" to "srnn".
Then when I run train.py, the following error occurs:

<Monitor<CrowdSimPredRealGST>>
No ghost version.
new gst
new st model
LOADED MODEL
device: cuda:0

Traceback (most recent call last):
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/train.py", line 247, in <module>
    main()
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/train.py", line 90, in main
    actor_critic = Policy(
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/rl/networks/model.py", line 28, in __init__
    self.base = base(obs_shape, base_kwargs)
  File "/root/dsrnn_ws/CrowdNav_Prediction_AttnGraph/rl/networks/srnn_model.py", line 378, in __init__
    robot_size = 7 if args.env_type == 'crowd_sim' else 2
AttributeError: 'Namespace' object has no attribute 'env_type'. Did you mean: 'env_name'?

Could you help solve this issue? Thank you for reading!
