Randomized-Ensembled-Double-Q-learning-REDQ-

PyTorch implementation of Randomized Ensembled Double Q-learning (REDQ). This repo contains a notebook version and a script version to run REDQ or SAC.

For more information about REDQ, check out the paper or my Medium article about it. In the future I will add REDQ to my Soft-Actor-Critic-and-Extensions repository so you can combine it with several other performance-increasing extensions like PER, D2RL, or Munchausen RL.

Dependencies

Trained and tested on:

Python 3.6
PyTorch 1.7.0  
Numpy 1.15.2 
gym 0.10.11 
pybulletgym
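
A possible install sketch for the pinned versions above (the exact commands are an assumption, not taken from the repo; pybulletgym is usually installed from its GitHub repository rather than from PyPI):

pip install torch==1.7.0 numpy==1.15.2 gym==0.10.11
pip install git+https://github.com/benelot/pybullet-gym.git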

How to use:

The training script combines all extensions; the add-ons can be enabled simply by setting the corresponding flags.

python train.py -info redq

To train plain SAC, simply set the REDQ-specific parameters to N=2, M=2, G=1.

python train.py --N 2 --M 2 --G 1 -info sac
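
For orientation, below is a minimal sketch of the update loop that the N, M and G flags control, following the description in the REDQ paper. It is not this repository's exact code; the names used (batch, critics, critic_targets, critic_opts, actor, actor_opt, alpha, gamma) are assumptions for illustration.

    import random
    import torch
    import torch.nn.functional as F

    def redq_update(batch, critics, critic_targets, critic_opts,
                    actor, actor_opt, alpha, gamma, N, M, G):
        states, actions, rewards, next_states, dones = batch
        for _ in range(G):                                 # G: gradient updates per environment step
            idx = random.sample(range(N), M)               # M: size of the random critic subset
            with torch.no_grad():
                next_a, next_logp, _ = actor.sample(next_states)
                q_next = torch.stack([critic_targets[i](next_states, next_a)
                                      for i in idx]).min(dim=0).values
                target = rewards + gamma * (1 - dones) * (q_next - alpha * next_logp)
            for i in range(N):                             # all N critics regress onto the same target
                critic_loss = F.mse_loss(critics[i](states, actions), target)
                critic_opts[i].zero_grad()
                critic_loss.backward()
                critic_opts[i].step()
            # (target networks are soft-updated elsewhere)
        # policy update with the average over all N critics, once per environment step
        new_a, logp, _ = actor.sample(states)
        q_avg = torch.stack([critics[i](states, new_a) for i in range(N)]).mean(dim=0)
        actor_loss = (alpha * logp - q_avg).mean()
        actor_opt.zero_grad()
        actor_loss.backward()
        actor_opt.step()

With N=2, M=2, G=1 this essentially reduces to standard SAC, which is why the SAC run above only changes those flags.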

Observe training results

tensorboard --logdir=runs

Results

REDQ was trained with N=5, M=2, G=5 (the REDQ paper actually recommends N=10, M=2, G=20); for faster training I used these adapted parameters. If someone finds a way to speed up training, please let me know. With N=10, G=20 training takes roughly 10x longer than regular SAC, since G gradient updates are performed per environment step.

Pendulum

LunarLanderContinuous (REDQ-212 is the regular SAC implementation with N=2, M=2, G=1).

ToDos:

  • Currently this REDQ version only supports a subsample size of M=2 for the REDQ hyperparameter M; the repository will be updated over time.
  • Do comparison runs for REDQ and SAC [currently running for pybullet environments like cheetah and hopper]
  • improve training speed (wall-clock time)
  • add requirements.txt

Author

  • Sebastian Dittert

Feel free to use this code for your own projects or research.

@misc{REDQ,
  author = {Dittert, Sebastian},
  title = {PyTorch Implementation of Randomized-Ensembled-Double-Q-learning-REDQ-},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/BY571/Randomized-Ensembled-Double-Q-learning-REDQ-}},
}

Issues

Application to discrete action space

Hi,

I was wondering whether you have applied this idea to environments with a discrete action space, or have any idea how it would perform?

Thanks
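
No answer is recorded here; purely as an illustration, a REDQ-style target for a discrete action space could be formed along the lines of the sketch below. This is a hypothetical adaptation in the spirit of discrete SAC, not part of this repository, and critic_targets, actor, alpha, gamma, N, M and the batch tensors are all assumed names.

    # Hypothetical sketch: REDQ-style target for discrete actions, assuming
    # critics that map a state to a vector of per-action Q-values and an
    # actor that returns action probabilities (discrete-SAC style).
    import random
    import torch

    idx = random.sample(range(N), M)                     # random subset of M target critics
    with torch.no_grad():
        probs = actor(next_states)                       # (batch, num_actions)
        log_probs = torch.log(probs + 1e-8)
        q_next = torch.stack([critic_targets[i](next_states)
                              for i in idx]).min(dim=0).values
        v_next = (probs * (q_next - alpha * log_probs)).sum(dim=1, keepdim=True)
        target = rewards + gamma * (1 - dones) * v_next  # per-sample Bellman target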

Actor Update Bug

Should the actor update not use idx[0] and idx[1] for Q1 and Q2? Currently it gets the same Q value twice from the same critic:

    # ---------------------------- update actor ----------------------------
    if step == self.G-1:

        actions_pred, log_prob, _ = self.actor_local.sample(states)

        # TODO: make this variable for possible more than two critics
        Q1 = self.critics[idx[0]](states, actions_pred.squeeze(0)).cpu()
        Q2 = self.critics[idx[0]](states, actions_pred.squeeze(0)).cpu()
        Q = torch.min(Q1, Q2)
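
For reference, a sketch of the change the reporter appears to suggest, indexing two different critics from the sampled subset (variable names follow the quoted snippet); note that the REDQ paper itself averages over all N critics for the policy update rather than taking a minimum:

        # suggested fix (sketch): take the minimum over two *different* sampled critics
        Q1 = self.critics[idx[0]](states, actions_pred.squeeze(0)).cpu()
        Q2 = self.critics[idx[1]](states, actions_pred.squeeze(0)).cpu()
        Q = torch.min(Q1, Q2)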

How to design the Average Actor network?

When I design an ensemble Q like this, the algorithm never converges. Do you have any idea why?

        q_ensemble = self.critic[0](states, sample_action)
        for k in range(1, self.N):
            q_curr = self.critic[k](states, sample_action)
            q_ensemble = q_ensemble + q_curr
        q = q_ensemble / self.N
        policy_loss = torch.mean(-q - self.alpha * entropy)
        # update the Actor parameters
        self.policy_optim.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optim.step()
