sss135 / pytorch-rl-kit

Proximal Policy Optimization in PyTorch

License: MIT License

Python 100.00%
pytorch ppo proximal-policy-optimization deep-reinforcement-learning deep-learning deep-neural-networks actor-critic gym atari ale

Proximal Policy Optimization in PyTorch

Changes from PPO paper

A More General Robust Loss Function (Barron loss)

https://arxiv.org/abs/1701.03077

The state value is optimized with the Barron loss, and advantages are scaled using the derivative of the Barron loss. To use MSE loss for the state value and leave advantages unscaled, set barron_alpha_c = (2, 1).

On average, using it on Atari instead of MSE / Huber loss does not change performance much.
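As a rough illustration of why barron_alpha_c = (2, 1) recovers MSE and unscaled advantages, here is a minimal scalar sketch of the loss from the paper and its derivative. The function names are hypothetical and this is not the library's tensor implementation. It also assumes that "scaled advantage" means the loss derivative evaluated at the advantage.

```python
import math


def barron_loss(x, alpha, c):
    """General robust loss from arXiv:1701.03077 for a scalar residual x."""
    z = (x / c) ** 2
    if alpha == 2:  # limit case: scaled squared error
        return 0.5 * z
    if alpha == 0:  # limit case: Cauchy / Lorentzian loss
        return math.log(0.5 * z + 1.0)
    if alpha == -math.inf:  # limit case: Welsch loss
        return 1.0 - math.exp(-0.5 * z)
    b = abs(alpha - 2.0)
    return (b / alpha) * ((z / b + 1.0) ** (alpha / 2.0) - 1.0)


def barron_loss_derivative(x, alpha, c):
    """Derivative of barron_loss with respect to x."""
    z = (x / c) ** 2
    if alpha == 2:
        return x / c ** 2
    if alpha == 0:
        return 2.0 * x / (x ** 2 + 2.0 * c ** 2)
    if alpha == -math.inf:
        return (x / c ** 2) * math.exp(-0.5 * z)
    b = abs(alpha - 2.0)
    return (x / c ** 2) * (z / b + 1.0) ** (alpha / 2.0 - 1.0)


# With alpha = 2 and c = 1 the loss is 0.5 * x**2 and the derivative is x,
# so a value passed through the derivative comes back unchanged.
advantage = 3.0
scaled_advantage = barron_loss_derivative(advantage, 2, 1)
```

Smaller alpha values flatten the derivative for large residuals, which is what makes the loss robust to outliers.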

Policy / value clip constraint is multiplied by abs(advantages)

This makes the constraint different for each element in the batch. Set advantage_scaled_clip = False to disable it.

As with Barron loss, on average I haven't observed much difference with or without it.
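One way to picture the advantage-scaled clip is the per-element sketch below. It assumes the standard PPO clipped surrogate with the clip range multiplied by abs(advantage). The function and the clip parameter name are hypothetical, and the library's actual code may differ.

```python
def clipped_surrogate(ratio, advantage, clip, advantage_scaled_clip=True):
    """PPO clipped surrogate for one batch element.

    ratio is pi_new(a|s) / pi_old(a|s). When advantage_scaled_clip is True,
    the clip range is multiplied by abs(advantage) (assumed behaviour).
    """
    eps = clip * abs(advantage) if advantage_scaled_clip else clip
    clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
    # Pessimistic bound: take the smaller of the two surrogate values.
    return min(ratio * advantage, clipped_ratio * advantage)


# Same ratio and advantage, with and without scaling the clip range.
scaled = clipped_surrogate(1.5, 2.0, clip=0.1)
unscaled = clipped_surrogate(1.5, 2.0, clip=0.1, advantage_scaled_clip=False)
```

With scaling, elements with a large absolute advantage are allowed a larger policy change, and elements with a small one are held closer to the old policy.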

KL Divergence penalty implementation is different

When kl < kl_target, the penalty is not applied. When kl > kl_target, it is scaled quadratically based on abs(kl - kl_target), and the policy and entropy maximization objectives are disabled.

I've found this implementation to be much easier to tune than the original KL divergence penalty.
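The rule above could be sketched as follows for a single scalar loss. This is only one reading of the description: the function name, the kl_scale coefficient, and the exact form of the quadratic term are assumptions, not the library's code.

```python
def kl_penalized_loss(policy_loss, entropy_bonus, kl, kl_target, kl_scale=1.0):
    """Loss to minimize, following the rule described above.

    Below the target: ordinary policy loss minus the entropy bonus, no penalty.
    Above the target: only a quadratic penalty on abs(kl - kl_target); the
    policy and entropy terms are switched off.
    kl_scale is a hypothetical coefficient added for illustration.
    """
    if kl <= kl_target:
        return policy_loss - entropy_bonus
    return kl_scale * abs(kl - kl_target) ** 2


inside = kl_penalized_loss(0.3, 0.05, kl=0.005, kl_target=0.01)
outside = kl_penalized_loss(0.3, 0.05, kl=0.03, kl_target=0.01)
```

Because the penalty is zero at kl_target and grows smoothly from there, the update only pushes the policy back toward the old one once the target has been exceeded.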

Additional clipping constraint

A new constraint type that clips the raw network output vector instead of action log probs. See 'opt' in the PPO constraint documentation.

It sometimes helps with convergence on continuous control tasks when used together with the clip or kl constraints.
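A minimal sketch of the idea, assuming that "clipping the raw output" means keeping each element of the new output vector within a fixed distance of the output the old policy produced for the same state. The function name and the clip parameter are hypothetical. See the PPO constraint documentation for the actual behaviour.

```python
def clip_raw_output(new_output, old_output, clip):
    """Clamp each new output element to [old - clip, old + clip]."""
    return [
        min(max(new, old - clip), old + clip)
        for new, old in zip(new_output, old_output)
    ]


# The first two elements moved too far and are pulled back; the third is kept.
clipped = clip_raw_output([1.5, -0.2, 0.05], [1.0, 0.0, 0.0], clip=0.1)
```

Unlike a log-prob ratio clip, this bound acts directly on the values the network emits, for example the mean of a Gaussian policy in continuous control.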

Several different constraints can be applied at the same time

See the PPO constraint documentation.

Entropy added to reward

Entropy maximization helps in some games. See entropy_reward_scale in PPO.
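A minimal sketch of adding entropy to the reward for a discrete policy, assuming the bonus is the policy entropy at the current state multiplied by entropy_reward_scale. The helper names are hypothetical and the library may compute this differently.

```python
import math


def categorical_entropy(probs):
    """Shannon entropy (in nats) of a discrete action distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def reward_with_entropy(reward, action_probs, entropy_reward_scale):
    """Environment reward plus a scaled entropy bonus (assumed form)."""
    return reward + entropy_reward_scale * categorical_entropy(action_probs)


# A uniform two-action policy has entropy ln(2); a deterministic one has 0.
uniform = reward_with_entropy(1.0, [0.5, 0.5], entropy_reward_scale=0.1)
greedy = reward_with_entropy(1.0, [1.0, 0.0], entropy_reward_scale=0.1)
```

Putting the bonus into the reward means it is propagated through returns and advantages, so the agent is also rewarded for reaching states where it can keep acting randomly.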

Extra Atari network architectures

In addition to the original network architecture, a bigger one is available. See cnn_kind in CNNActor.

Quasi-Recurrent Neural Networks

https://arxiv.org/abs/1611.01576

See PPO_QRNN, QRNNActor, CNN_QRNNActor. The QRNN implementation requires https://github.com/salesforce/pytorch-qrnn. With some effort, QRNN could be replaced with another RNN architecture such as LSTM or GRU.

Installation

pip install git+https://github.com/SSS135/ppo-pytorch

Required packages:

Training

The training code does not print any information to the console. Instead, it logs various info to Tensorboard.

Classic control

CartPole-v1 for 500K steps without CUDA (pass --force-cuda to enable it; it won't improve performance)

python example.py --env-name CartPole-v1 --steps 500_000 --tensorboard-path /tensorboard/output/path

Atari

PongNoFrameskip-v4 for 10M steps (40M emulator frames) with CUDA

python example.py --atari --env-name PongNoFrameskip-v4 --steps 10_000_000 --tensorboard-path /tensorboard/output/path

New gym environments

When the library is imported, the following gym environments are registered:

Continuous versions of Acrobot and CartPole: AcrobotContinuous-v1, CartPoleContinuous-v0, CartPoleContinuous-v1

CartPole with a 10000-step limit: CartPoleContinuous-v2, CartPole-v2

Results

PongNoFrameskip-v4

Activations of first convolution layer

Absolute value of gradients of state pixels (sort of pixel importance)

BreakoutNoFrameskip-v4

QbertNoFrameskip-v4

SpaceInvadersNoFrameskip-v4

SeaquestNoFrameskip-v4

CartPole-v1

CartPoleContinuous-v2
