sunfanyunn / imma

Official code for "Interaction Modeling with Multiplex Attention" (NeurIPS 2022)

Languages: Python 97.11%, HTML 2.54%, Shell 0.35%
Topics: graph-neural-networks, multi-agent-systems, interaction-modeling, multiplex-attention

imma's Introduction

Interaction Modeling with Multiplex Attention

Authors: Fan-Yun Sun, Isaac Kauvar, Ruohan Zhang, Jiachen Li, Mykel Kochenderfer, Jiajun Wu, Nick Haber

Abstract: Modeling multi-agent systems requires understanding how agents interact. Such systems are often difficult to model because they can involve a variety of types of interactions that layer together to drive rich social behavioral dynamics. Here we introduce a method for accurately modeling multi-agent systems. We present Interaction Modeling with Multiplex Attention (IMMA), a forward prediction model that uses a multiplex latent graph to represent multiple independent types of interactions and attention to account for relations of different strengths. We also introduce Progressive Layer Training, a training strategy for this architecture. We show that our approach outperforms state-of-the-art models in trajectory forecasting and relation inference, spanning three multi-agent scenarios: social navigation, cooperative task achievement, and team sports. We further demonstrate that our approach can improve zero-shot generalization and allows us to probe how different interactions impact agent behavior.
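As a rough illustration of the multiplex latent graph described above, the sketch below computes one attention matrix per interaction layer. It is not the authors' implementation: the per-layer query/key projections (`wq`, `wk`), the scaled dot-product scoring, and the helper names are assumptions chosen for clarity, written in plain Python so it runs without extra dependencies.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]
Vector = List[float]


def matvec(w: Matrix, x: Vector) -> Vector:
    """Multiply a (d_out x d_in) matrix by a length-d_in vector."""
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def multiplex_attention(features: List[Vector], layers: List[Dict[str, Matrix]]) -> List[Matrix]:
    """Return one N x N attention matrix per latent interaction layer.

    Hypothetical sketch: each layer has its own query/key projections
    ("wq", "wk"), so each layer can weight agent pairs differently.
    Row i of a layer's matrix is a softmax over the other agents j != i,
    and the diagonal (self-interaction) is 0.
    """
    n = len(features)
    if n < 2:
        raise ValueError("need at least two agents")
    graphs = []
    for layer in layers:
        queries = [matvec(layer["wq"], x) for x in features]
        keys = [matvec(layer["wk"], x) for x in features]
        scale = math.sqrt(len(queries[0]))
        graph = []
        for i in range(n):
            others = [j for j in range(n) if j != i]
            scores = [
                sum(q * k for q, k in zip(queries[i], keys[j])) / scale
                for j in others
            ]
            row = [0.0] * n
            for j, w in zip(others, softmax(scores)):
                row[j] = w
            graph.append(row)
        graphs.append(graph)
    return graphs
```

Because each layer has its own projections, the layers can assign different weights to the same pair of agents, which is the intuition behind representing several interaction types at once; the per-row softmax lets a single layer express relations of different strengths.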

This repository contains the code for our paper, which has been accepted at NeurIPS 2022. For more details, please refer to the paper (arXiv, OpenReview).

Environment Setup

  1. Install the Python-RVO2 library
  2. Install the socialforce library
  3. Install the remaining packages with pip:
pip install -r requirements.txt
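Before running the data scripts, it can help to confirm that both simulators are importable. A minimal sketch, assuming Python-RVO2 installs as the `rvo2` module and socialforce as `socialforce` (adjust the names if your installation differs); `check_modules` is a hypothetical helper, not part of the repository.

```python
import importlib.util
from typing import Dict, Iterable

# Assumed top-level import names for the two simulator dependencies.
REQUIRED_MODULES = ("rvo2", "socialforce")


def check_modules(names: Iterable[str] = REQUIRED_MODULES) -> Dict[str, bool]:
    """Map each top-level module name to whether Python can locate it."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, ok in check_modules().items():
        print(f"{name}: {'found' if ok else 'MISSING'}")
```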

Data Setup

Social Navigation Environment

This is a simulated environment inspired by https://github.com/vita-epfl/CrowdNav. After installing the necessary dependencies, use the following sample commands to run the simulation and generate a dataset.

randomseed=17
dataset_size=100000
obs_frames=24
rollouts=10

cd data_utils/socialnav
python generate_dataset.py --dataset_size $dataset_size \
                           --randomseed $randomseed \
                           --obs_frames ${obs_frames} \
                           --rollouts ${rollouts}
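If you want several datasets with different seeds, the same command can be driven from Python. This sketch uses only the four flags shown above; `build_generate_command` and `generate_for_seeds` are hypothetical helpers, and it assumes it is run from the repository root so that `data_utils/socialnav` is a valid working directory. Whether runs with different seeds overwrite each other's output depends on how `generate_dataset.py` names its files, so check that before launching a sweep.

```python
import subprocess
import sys
from typing import Iterable, List


def build_generate_command(dataset_size: int, randomseed: int,
                           obs_frames: int, rollouts: int) -> List[str]:
    """Return the argv for generate_dataset.py, mirroring the shell example."""
    return [
        sys.executable, "generate_dataset.py",
        "--dataset_size", str(dataset_size),
        "--randomseed", str(randomseed),
        "--obs_frames", str(obs_frames),
        "--rollouts", str(rollouts),
    ]


def generate_for_seeds(seeds: Iterable[int], dataset_size: int = 100000,
                       obs_frames: int = 24, rollouts: int = 10,
                       cwd: str = "data_utils/socialnav") -> None:
    """Run the generator once per seed; stops on the first failing run."""
    for seed in seeds:
        cmd = build_generate_command(dataset_size, seed, obs_frames, rollouts)
        subprocess.run(cmd, cwd=cwd, check=True)
```

For example, `generate_for_seeds([17, 18, 19])` would run the sample configuration three times.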

The resulting dataset will be stored at datasets/*.tensor. To change the simulation settings, modify the config file data_utils/socialnav/configs/default.py.
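To pick up the generated files, a small sketch like the following can be used. It assumes the `.tensor` files were written with `torch.save` (suggested by the extension, not confirmed here), so the default loader is `torch.load`; the repository's own loading code in `data_utils/load_dataset.py` remains the reference. The helper names are hypothetical.

```python
from pathlib import Path
from typing import Callable, List, Optional


def find_dataset_files(root: str = "datasets", pattern: str = "*.tensor") -> List[Path]:
    """Return the generated dataset files under `root`, sorted by name."""
    return sorted(Path(root).glob(pattern))


def load_all(files: List[Path], loader: Optional[Callable] = None) -> list:
    """Load each file with `loader`.

    Assumption: the files were written with torch.save, so torch.load is
    the default. Pass a different loader if that is not the case.
    """
    if loader is None:
        import torch  # imported lazily so file discovery works without PyTorch
        loader = torch.load
    return [loader(f) for f in files]
```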

To inspect and interact with the environment (control the embodied agent with your arrow keys):

cd data_utils/socialnav
python human_play.py

PHASE

The preprocessed dataset is under datasets/phase/collab. To load the dataset, refer to the function prepare_dataset in data_utils/load_dataset.py.

NBA dataset

Download the preprocessed dataset here (or run gdown 1eJbDHy3fOHfzOStf-jSuYCz_YQloQU3s) and place it under datasets. Alternatively, you can create your own dataset from raw SportVU logs (refer to this repository or the code under data_utils/bball). To load the dataset, refer to the function prepare_dataset in data_utils/load_dataset.py.
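The same download can be scripted with gdown's Python API. A sketch, assuming a recent gdown release where `gdown.download(url, quiet=...)` saves to the current directory and returns the saved file name; `drive_url` and `download_nba` are hypothetical helpers. If the downloaded file is an archive, extract it under `datasets` afterwards.

```python
import os
import shutil
from typing import Optional

# Google Drive file ID given above for the preprocessed NBA dataset.
NBA_FILE_ID = "1eJbDHy3fOHfzOStf-jSuYCz_YQloQU3s"


def drive_url(file_id: str) -> str:
    """Build the Google Drive download URL form that gdown accepts."""
    return f"https://drive.google.com/uc?id={file_id}"


def download_nba(output_dir: str = "datasets") -> Optional[str]:
    """Download the NBA dataset and move it into output_dir.

    Returns the final path, or None if gdown reported a failure.
    """
    import gdown  # third-party: pip install gdown

    downloaded = gdown.download(drive_url(NBA_FILE_ID), quiet=False)
    if downloaded is None:
        return None
    os.makedirs(output_dir, exist_ok=True)
    target = os.path.join(output_dir, os.path.basename(downloaded))
    shutil.move(downloaded, target)
    return target
```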

Training and Evaluation

Sample commands are provided in run_socialnav.sh, run_phase.sh, and run_bball.sh.
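Trajectory-forecasting results are commonly reported as ADE/FDE (average and final displacement error). A minimal sketch of the standard definitions in plain Python; the exact averaging and coordinate normalization used by this repository's evaluation scripts may differ, so treat those scripts as authoritative. `ade_fde` is a hypothetical helper.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def ade_fde(pred: Sequence[Sequence[Point]],
            gt: Sequence[Sequence[Point]]) -> Tuple[float, float]:
    """Average and final displacement error.

    pred and gt are indexed [agent][timestep] -> (x, y) with matching shapes.
    ADE: mean Euclidean error over all agents and predicted timesteps.
    FDE: mean Euclidean error at the last predicted timestep.
    """
    if len(pred) != len(gt) or len(pred) == 0:
        raise ValueError("pred and gt must contain the same, non-zero number of agents")
    total, count, final_total = 0.0, 0, 0.0
    for p_traj, g_traj in zip(pred, gt):
        if len(p_traj) != len(g_traj) or len(p_traj) == 0:
            raise ValueError("trajectories must have the same, non-zero length")
        errors = [math.dist(p, g) for p, g in zip(p_traj, g_traj)]
        total += sum(errors)
        count += len(errors)
        final_total += errors[-1]
    return total / count, final_total / len(pred)
```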

Progressive Layer Training

Figure (teaser image): Loss curve visualized over the course of training IMMA with PLT on the NBA dataset. New layers are added after the model "converges".
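The text above only says that new layers are added once the model "converges". The sketch below makes that schedule concrete under an assumed convergence test (a patience/min_delta plateau rule); the rule, its defaults, and the helper names are hypothetical, and it does not address whether earlier layers are frozen.

```python
from typing import List


def has_plateaued(losses: List[float], patience: int = 5, min_delta: float = 1e-3) -> bool:
    """True if the best loss in the last `patience` epochs did not improve
    on the best earlier loss by at least `min_delta`."""
    if len(losses) <= patience:
        return False
    best_before = min(losses[:-patience])
    best_recent = min(losses[-patience:])
    return best_before - best_recent < min_delta


def plt_layer_schedule(losses: List[float], max_layers: int,
                       patience: int = 5, min_delta: float = 1e-3) -> List[int]:
    """Number of active interaction layers at each epoch.

    Start with one layer; whenever the loss since the last layer was added
    plateaus, add another layer (from the next epoch on), up to max_layers.
    """
    layers, start, schedule = 1, 0, []
    for epoch in range(len(losses)):
        schedule.append(layers)
        window = losses[start:epoch + 1]  # losses since the last layer was added
        if layers < max_layers and has_plateaued(window, patience, min_delta):
            layers += 1
            start = epoch + 1
    return schedule
```

Restarting the plateau window after each addition keeps the loss spike that usually follows a new layer from being mistaken for lack of progress on the old configuration.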

Citation

If you find the code or paper useful for your research, please cite our paper:

@article{sun2022interaction,
  title={Interaction Modeling with Multiplex Attention},
  author={Sun, Fan-Yun and Kauvar, Isaac and Zhang, Ruohan and Li, Jiachen and Kochenderfer, Mykel and Wu, Jiajun and Haber, Nick},
  journal={Advances in Neural Information Processing Systems},
  year={2022}
}

Acknowledgement

In this project we use (parts of) the implementations from the following works:

We thank the authors for open sourcing their methods.

imma's People

Contributors

sunfanyunn

imma's Issues

questions about training NBA dataset

Thanks for providing such great work! I have some questions about training on the NBA dataset:

  1. I used the default config in main.py and run_bball.sh to train on the NBA dataset. However, it seems to take weeks to get results for the MG+PLT setting and seven or eight days for the MG setting. Is the default config perhaps meant for the other two datasets? Could you please share the config and hyperparameters for the NBA dataset?

  2. As mentioned above, I trained with the default config and obtained an ADE/FDE of 0.66/1.31, which is considerably better than the result reported for the same setting in your paper. So I wonder whether there is an error in the config.

  3. Does the random seed influence the final result? If so, could you please share the seeds that produced the results reported in the paper?

Thanks again for answering my questions!