Brain Agent

Brain Agent is a distributed agent learning system for large-scale and multi-task reinforcement learning, developed by Kakao Brain. It builds on IMPALA, the V-trace actor-critic framework, and modifies Sample Factory to use multiple GPUs and CPUs, which gives a much higher throughput during training.
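For readers new to V-trace, here is a minimal NumPy sketch of the V-trace value targets and policy-gradient advantages for a single trajectory, following the recursion described in the IMPALA paper. It is not the project's `brain_agent/core/algos/vtrace.py`: the function name, the argument layout, and the default truncation levels (1.0 for both the rho and c clipping thresholds) are assumptions made for illustration.

```python
import numpy as np


def vtrace_targets(log_rhos, discounts, rewards, values, bootstrap_value,
                   clip_rho_threshold=1.0, clip_c_threshold=1.0):
    """Illustrative V-trace for one trajectory of length T (not Brain Agent's vtrace.py).

    log_rhos:  log(pi(a_t|x_t) / mu(a_t|x_t)) for target policy pi and behaviour policy mu
    discounts: gamma_t, set to 0 where the episode ended at step t
    rewards:   r_t
    values:    V(x_t) from the learner's critic
    bootstrap_value: V(x_T), the value of the state after the last step
    """
    log_rhos, discounts, rewards, values = (
        np.asarray(a, dtype=np.float64) for a in (log_rhos, discounts, rewards, values))
    rhos = np.exp(log_rhos)
    clipped_rhos = np.minimum(clip_rho_threshold, rhos)  # rho_t = min(rho_bar, pi/mu)
    cs = np.minimum(clip_c_threshold, rhos)              # c_t = min(c_bar, pi/mu)

    next_values = np.append(values[1:], bootstrap_value)
    deltas = clipped_rhos * (rewards + discounts * next_values - values)

    # Backward recursion: v_t - V(x_t) = delta_t + gamma_t * c_t * (v_{t+1} - V(x_{t+1}))
    vs_minus_v = np.zeros_like(values)
    acc = 0.0
    for t in reversed(range(len(values))):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v[t] = acc
    vs = values + vs_minus_v

    # The policy-gradient advantage uses the V-trace target of the next state.
    next_vs = np.append(vs[1:], bootstrap_value)
    pg_advantages = clipped_rhos * (rewards + discounts * next_vs - values)
    return vs, pg_advantages
```

With on-policy data (all `log_rhos` equal to 0), the targets reduce to the n-step discounted return bootstrapped from `bootstrap_value`; when the behaviour policy lags behind the learner, the clipped importance weights limit how far off-policy samples can move the targets.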

Features

  1. The first publicly available implementation that reproduces SOTA results on DMLAB30.
  2. Scalable and massive throughput.
    Brain Agent can generate and train on 20B frames per week (about 34K fps) with 16 V100 GPUs by scaling up Sample Factory, a high-throughput single-node system.
  3. Based on the following algorithms and architectures.
  4. For self-supervised representation learning, we include two additional features:
    • A ResNet-based decoder that reconstructs the original input image (trxl_recon)
    • An additional autoregressive transformer that predicts the images of future steps from the current state embedding and the future action sequence (trxl_future_pred)
  5. Code for both training and evaluation, along with a SOTA model checkpoint with 28M parameters.
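For trxl_recon the reconstruction target is simply the current input image, but trxl_future_pred needs targets several steps ahead. The sketch below shows one way such training samples could be arranged from a trajectory; `future_prediction_windows` is a hypothetical helper, and the actual `aux_future_predict.py` may slice its data differently.

```python
from typing import List, Sequence, Tuple, TypeVar

O = TypeVar("O")
A = TypeVar("A")


def future_prediction_windows(observations: Sequence[O], actions: Sequence[A], horizon: int
                              ) -> List[Tuple[int, List[A], List[O]]]:
    """Hypothetical helper: build (t, future actions, future observations) samples.

    observations holds o_0..o_T and actions holds a_0..a_{T-1}, where a_t is taken
    after observing o_t. For each t, a model conditioned on the state at t and on
    a_t..a_{t+horizon-1} would be trained to predict o_{t+1}..o_{t+horizon}.
    """
    if horizon < 1:
        raise ValueError("horizon must be at least 1")
    T = len(actions)
    if len(observations) != T + 1:
        raise ValueError("expected exactly one more observation than actions")
    samples = []
    for t in range(T - horizon + 1):  # keep only steps with a full horizon of future data
        future_actions = list(actions[t:t + horizon])
        target_observations = list(observations[t + 1:t + 1 + horizon])
        samples.append((t, future_actions, target_observations))
    return samples
```

Conditioning on the future actions matters here: without them the predictor would have to guess which of many possible futures the policy will produce.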

How to Install

  • Python 3.7
  • PyTorch 1.9.0
  • CUDA 11.1
  • Install the DMLab environment (see the DMLab GitHub repository)
  • pip install -r requirements.txt
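A quick way to compare a machine against the versions listed above is a small check script. This is a hypothetical helper, not part of the repository; it assumes DMLab's Python bindings are importable as `deepmind_lab`, the module name used in the DMLab repository's examples.

```python
import importlib
import sys


def check_environment():
    """Hypothetical helper: report installed versions next to the ones listed in this README."""
    report = {
        "python": "%d.%d.%d" % sys.version_info[:3],
        # The README lists Python 3.7; other versions are untested, not necessarily broken.
        "python_ok": sys.version_info[:2] == (3, 7),
    }
    for module_name in ("torch", "deepmind_lab"):
        try:
            module = importlib.import_module(module_name)
        except (ImportError, OSError):  # missing package or missing native library
            report[module_name] = None
            continue
        report[module_name] = getattr(module, "__version__", "installed")
    if report["torch"] is not None:
        torch = sys.modules["torch"]
        report["torch_cuda"] = torch.version.cuda  # CUDA version torch was built with, or None
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```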

Description of Codes

  • dist_launch.py -> distributed training launcher
  • eval.py -> entry point for evaluation
  • train.py -> entry point for training
  • brain_agent
    • core
      • agents
        • dmlab_multitask_agent.py
      • algos
        • aux_future_predict.py -> Computes the auxiliary loss by predicting future state transitions with an autoregressive transformer. Used only for trxl_future_pred.
        • popart.py
        • vtrace.py
      • actor_worker.py
      • learner_worker.py
      • policy_worker.py
      • shared_buffer.py -> Defines the SharedBuffer class for zero-copy communication between workers.
    • envs
      • dmlab
    • utils
      • ...
  • configs
    • ... -> Hyperparameter configs for training and evaluation.
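The idea behind zero-copy communication between workers (as in shared_buffer.py) is that rollouts are written once into shared memory and every process reads them through a view instead of receiving pickled copies. The sketch below illustrates that idea with `multiprocessing.RawArray` and `np.frombuffer`; `SimpleSharedBuffer` and `_actor` are hypothetical names, and Brain Agent's own SharedBuffer (which works with PyTorch tensors across actor, policy and learner workers) is more involved.

```python
import ctypes
import multiprocessing as mp

import numpy as np


class SimpleSharedBuffer:
    """Hypothetical minimal shared buffer (not Brain Agent's SharedBuffer)."""

    def __init__(self, num_slots, slot_shape, dtype=np.float32):
        self.shape = (num_slots,) + tuple(slot_shape)
        self.dtype = np.dtype(dtype)
        num_bytes = int(np.prod(self.shape)) * self.dtype.itemsize
        # RawArray: zero-initialised, lock-free shared memory that child processes can access.
        self._raw = mp.RawArray(ctypes.c_byte, num_bytes)

    def view(self):
        # np.frombuffer wraps the shared bytes without copying them.
        return np.frombuffer(self._raw, dtype=self.dtype).reshape(self.shape)


def _actor(buffer, slot, value):
    # A toy "actor" writes its rollout directly into its own slot; nothing is sent back.
    buffer.view()[slot] = value


if __name__ == "__main__":
    buf = SimpleSharedBuffer(num_slots=4, slot_shape=(2, 3))
    procs = [mp.Process(target=_actor, args=(buf, i, float(i))) for i in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(buf.view()[:, 0, 0])  # [0. 1. 2. 3.]
```

Because each actor owns a distinct slot, no lock is needed for the writes; a real system still needs some signalling (for example a queue of slot indices) so the consumer knows when a slot is ready.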

How to Run

Training

  • 1 node x 1 GPU

       python train.py cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
  • 1 node x 4 GPUs = 4 GPUs

       python -m dist_launch --nnodes=1 --node_rank=0 --nproc_per_node=4 -m train \
         cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
  • 4 nodes x 4 GPUs each = 16 GPUs (on node i, run the command with --node_rank=i; $MASTER_ADDR is the address of the node_rank=0 node)

       sleep 120; python -m dist_launch --nnodes=4 --node_rank=0 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
         cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
       sleep 120; python -m dist_launch --nnodes=4 --node_rank=1 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
         cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
       sleep 120; python -m dist_launch --nnodes=4 --node_rank=2 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
         cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
       sleep 120; python -m dist_launch --nnodes=4 --node_rank=3 --nproc_per_node=4 --master_addr=$MASTER_ADDR -m train \
         cfg=configs/trxl_recon_train.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR
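The launch commands above differ only in --nnodes, --node_rank and, for multi-node runs, --master_addr and the leading sleep. A small generator can produce one command per node so that each machine runs exactly its own line. `build_launch_commands` is a hypothetical helper; it only reuses the flags shown above and inserts values verbatim (so `$TRAIN_DIR` stays a shell variable).

```python
def build_launch_commands(nnodes, nproc_per_node, cfg, train_dir, experiment,
                          master_addr=None, delay_s=120):
    """Hypothetical helper: return one shell command per node; run commands[i] on node i."""
    if nnodes > 1 and not master_addr:
        raise ValueError("multi-node runs need master_addr (the node_rank=0 machine)")
    commands = []
    for rank in range(nnodes):
        parts = ["python -m dist_launch", f"--nnodes={nnodes}", f"--node_rank={rank}",
                 f"--nproc_per_node={nproc_per_node}"]
        if nnodes > 1:
            parts.append(f"--master_addr={master_addr}")
        parts += ["-m train", f"cfg={cfg}", f"train_dir={train_dir}", f"experiment={experiment}"]
        command = " ".join(parts)
        if nnodes > 1:
            # Mirror the README's multi-node commands, which start with a fixed delay.
            command = f"sleep {delay_s}; {command}"
        commands.append(command)
    return commands


if __name__ == "__main__":
    for cmd in build_launch_commands(4, 4, "configs/trxl_recon_train.yaml",
                                     "$TRAIN_DIR", "$EXPERIMENT_DIR", master_addr="$MASTER_ADDR"):
        print(cmd)
```

With nnodes=4 and nproc_per_node=4 this yields the four 16-GPU commands; with nnodes=1 it reproduces the single-node, 4-GPU command.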

Evaluation

python eval.py cfg=configs/trxl_recon_eval.yaml train_dir=$TRAIN_DIR experiment=$EXPERIMENT_DIR test.checkpoint=$CHECKPOINT_FILE_PATH 

Setting Hyperparameters

  • All default hyperparameters are defined in configs/default.yaml.
  • Other config files override the values in configs/default.yaml.
  • To reuse the pre-defined hyperparameters from our experiments, use configs/trxl_recon_train.yaml or configs/trxl_future_pred.yaml.
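The layering described above (defaults first, then an experiment config, then key=value arguments such as `test.checkpoint=...` on the command line) can be sketched with plain dictionaries. This is an illustration of the override order only, not Brain Agent's config loader; the function names and the example keys (`optim.lr`, `optim.batch_size`) are hypothetical, and in this sketch command-line values stay strings.

```python
import copy


def deep_merge(base, override):
    """Recursively merge `override` into a copy of `base`; values from `override` win."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def apply_cli_overrides(cfg, args):
    """Apply `dotted.key=value` arguments, e.g. `test.checkpoint=/path/to/ckpt`."""
    cfg = copy.deepcopy(cfg)
    for arg in args:
        dotted_key, sep, raw_value = arg.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {arg!r}")
        *parents, leaf = dotted_key.split(".")
        node = cfg
        for part in parents:
            child = node.setdefault(part, {})
            if not isinstance(child, dict):
                raise ValueError(f"{part!r} in {dotted_key!r} is not a section")
            node = child
        node[leaf] = raw_value  # kept as a string in this sketch
    return cfg
```

Deep-merging, rather than replacing whole sections, lets an experiment file change a single nested value while inheriting every other default.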

Results for DMLAB30

  • Settings

    • 3 runs with different seeds
    • 100 episodes per run
    • HNS: human-normalised score
  • Results

    Model                                Mean HNS        Median HNS      Mean Capped HNS
    MERLIN                               115.2           -               89.4
    GTrXL                                117.6           -               89.1
    CoBERL                               115.47          110.86          -
    R2D2+                                -               99.5            85.7
    LASER                                -               97.2            81.7
    PBL                                  104.16          -               81.5
    PopArt-IMPALA                        -               -               72.8
    IMPALA                               -               -               58.4
    Ours (lstm_baseline, 20B ckpt)       103.03 ± 0.37   92.04 ± 0.73    81.35 ± 0.25
    Ours (trxl_baseline, 20B ckpt)       111.95 ± 1.00   105.43 ± 2.61   85.57 ± 0.20
    Ours (trxl_recon, 20B ckpt)          123.60 ± 0.84   108.63 ± 1.20   91.25 ± 0.41
    Ours (trxl_future_pred, 20B ckpt)    128.00 ± 0.43   108.80 ± 0.99   90.53 ± 0.26
Results for all 30 tasks
Level                                          lstm_baseline                trxl_baseline                trxl_recon                   trxl_future_pred
rooms_collect_good_objects_(train / test)      94.22 ± 0.84 / 95.13 ± 0.61  97.85 ± 0.31 / 95.20 ± 1.26  97.58 ± 0.20 / 89.39 ± 1.42  98.19 ± 0.18 / 98.52 ± 0.95
rooms_exploit_deferred_effects_(train / test)  37.84 ± 2.23 / 4.36 ± 1.84   38.40 ± 3.82 / 1.73 ± 0.63   38.86 ± 3.48 / 4.04 ± 0.89   40.93 ± 3.12 / 2.26 ± 0.71
rooms_select_nonmatching_object                50.13 ± 2.95                 98.78 ± 1.38                 99.52 ± 0.97                 113.20 ± 1.14
rooms_watermaze                                45.09 ± 4.70                 36.92 ± 6.90                 111.20 ± 2.29                55.82 ± 0.74
rooms_keys_doors_puzzle                        51.75 ± 8.90                 55.86 ± 4.25                 61.24 ± 9.09                 64.95 ± 8.43
language_select_described_object               150.57 ± 0.58                154.90 ± 0.22                155.35 ± 0.17                158.23 ± 0.90
language_select_located_object                 225.97 ± 1.93                244.46 ± 1.56                252.04 ± 0.31                261.20 ± 1.15
language_execute_random_task                   126.49 ± 2.35                139.63 ± 1.23                145.21 ± 0.36                150.20 ± 1.35
language_answer_quantitative_question          153.92 ± 2.35                162.99 ± 2.42                163.72 ± 1.36                166.07 ± 1.72
lasertag_one_opponent_small                    234.90 ± 6.19                243.52 ± 3.96                249.99 ± 6.64                279.54 ± 4.14
lasertag_three_opponents_small                 235.61 ± 1.92                242.61 ± 3.75                246.68 ± 5.99                264.20 ± 3.76
lasertag_one_opponent_large                    74.88 ± 5.06                 83.51 ± 1.31                 82.55 ± 2.15                 94.86 ± 3.64
lasertag_three_opponents_large                 84.78 ± 2.42                 92.04 ± 2.17                 96.54 ± 0.67                 105.83 ± 0.47
natlab_fixed_large_map                         98.10 ± 1.77                 110.74 ± 1.34                120.53 ± 1.79                118.17 ± 1.79
natlab_varying_map_regrowth                    108.54 ± 1.20                107.16 ± 2.68                108.14 ± 1.25                104.83 ± 1.26
natlab_varying_map_randomized                  85.33 ± 6.52                 86.33 ± 7.30                 85.53 ± 6.69                 77.74 ± 0.84
skymaze_irreversible_path_hard                 55.29 ± 9.08                 60.63 ± 4.73                 61.63 ± 2.52                 66.30 ± 5.69
skymaze_irreversible_path_varied               77.02 ± 3.57                 77.41 ± 0.67                 81.31 ± 2.34                 79.36 ± 7.95
psychlab_arbitrary_visuomotor_mapping          52.17 ± 2.06                 51.46 ± 0.45                 101.82 ± 0.19                101.80 ± 0.00
psychlab_continuous_recognition                52.57 ± 0.46                 52.41 ± 0.92                 102.46 ± 0.32                102.30 ± 0.00
psychlab_sequential_comparison                 76.82 ± 0.45                 75.48 ± 1.16                 75.74 ± 0.58                 76.13 ± 0.77
psychlab_visual_search                         101.54 ± 0.10                101.58 ± 0.04                101.91 ± 0.00                101.90 ± 0.00
explore_object_locations_small                 118.89 ± 0.93                121.47 ± 0.26                123.54 ± 2.61                126.67 ± 2.08
explore_object_locations_large                 111.46 ± 2.91                120.70 ± 2.12                115.43 ± 1.64                129.83 ± 2.41
explore_obstructed_goals_small                 136.92 ± 6.02                148.05 ± 1.96                166.75 ± 3.63                174.30 ± 3.72
explore_obstructed_goals_large                 92.36 ± 5.81                 106.73 ± 7.86                153.44 ± 3.20                176.43 ± 1.50
explore_goal_locations_small                   143.21 ± 8.21                154.87 ± 4.41                177.16 ± 0.37                193.00 ± 3.75
explore_goal_locations_large                   98.50 ± 9.61                 117.33 ± 6.75                160.39 ± 3.32                178.13 ± 7.15
explore_object_rewards_few                     76.29 ± 1.52                 108.64 ± 0.89                109.58 ± 3.53                110.07 ± 1.42
explore_object_rewards_many                    72.33 ± 0.87                 105.33 ± 1.52                105.15 ± 0.75                107.23 ± 1.59
  • Learning curves
    [Figure: learning curves]
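The aggregate columns in the results tables (mean HNS, median HNS and mean capped HNS) summarise per-level scores. The sketch below uses the usual DMLab-30 definition, HNS = 100 * (agent - random) / (human - random), with capped HNS clipping each level at 100 before averaging; the README does not spell out its exact formula, the function names are hypothetical, and the numbers in the usage check are made up rather than taken from the tables.

```python
import statistics


def human_normalised_score(agent_score, random_score, human_score):
    """HNS for one level, in percent: 0 matches a random agent, 100 matches the human reference."""
    return 100.0 * (agent_score - random_score) / (human_score - random_score)


def aggregate_hns(per_level_hns, cap=100.0):
    """Mean, median and mean capped HNS over levels (each level capped at `cap`)."""
    scores = list(per_level_hns.values())
    if not scores:
        raise ValueError("need at least one level")
    return {
        "mean": statistics.mean(scores),
        "median": statistics.median(scores),
        "mean_capped": statistics.mean([min(s, cap) for s in scores]),
    }
```

Capping keeps a few levels with very high scores (such as the lasertag levels) from dominating the average, which is why mean capped HNS is often reported alongside the uncapped mean.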

Distributed RL System Overview

[Figure: distributed RL system overview]

Notes

Citation

@misc{kakaobrain2022brain_agent,
  title        = {Brain Agent},
  author       = {Donghoon Lee and Taehwan Kwon and Seungeun Rho and Daniel Wontae Nam and Jongmin Kim and Daejin Jo and Sungwoong Kim},
  year         = {2022},
  howpublished = {\url{https://github.com/kakaobrain/brain_agent}}
}
