dqn's Introduction

dqn

This is a very basic DQN (with experience replay) implementation, which uses OpenAI's gym environment and Keras/Theano neural networks.

Requirements

  • gym
  • keras
  • theano
  • numpy

and all their dependencies.

Usage

To run: python example.py <env_name>. It defaults to MsPacman-v0 if no environment is specified. Uncomment the env.render() line to watch the game during training; note, however, that this is likely to slow training down.

Currently, the code assumes that the observation is an image, i.e., a 3D array, which is the case for all Atari games and other Atari-like environments.

Purpose

This is meant to be a very simple implementation, to be used as starter code. I aimed for it to be easy to comprehend rather than feature-complete.

Pull requests welcome!

TODO

  • Extend to other environments. Currently this only works for Atari and Atari-like environments where the observation space is a 3D Box.

dqn's People

Contributors

sherjilozair


dqn's Issues

Cropping won't work for all environments

Hey @sherjilozair/everyone,

Just letting you know: this almost certainly won't work correctly for an arbitrary game in its current form. The problem is the cropping: cropping a static 84x84 window from the center of the screen misses a lot of important information.

Take SpaceInvaders-v0, for instance. This is what the convolutional network sees:
figure_1. Where is the ship that we control? It has been cropped off the bottom of the screen, so that information can never reach the network.

I appreciate that this is just a simple implementation, but it's worth noting that a complete solution would need hand-picked crop regions for each game (unfortunately).
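To make the loss concrete, here is a minimal NumPy sketch (the frame contents are synthetic, not a real SpaceInvaders screen) showing that a fixed center crop of a 210x160 Atari frame discards the bottom 63 rows entirely:

```python
import numpy as np

def center_crop(img, size=84):
    """Crop a fixed size x size window from the center of a 2D image."""
    h, w = img.shape
    top, left = (h - size) // 2, (w - size) // 2
    return img[top:top + size, left:left + size], top

# A raw Atari frame is 210x160 pixels (shown grayscale for simplicity).
frame = np.zeros((210, 160), dtype=np.float32)
frame[200, :] = 1.0  # pretend the player's ship lives near the bottom row

patch, top = center_crop(frame)
print(patch.shape)  # (84, 84)
print(patch.max())  # 0.0 -- the ship row (200) lies below row top+84 = 147
```

Anything living in those bottom rows, like the player's ship here, is simply invisible to the network.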

A couple of questions ...

  1. The default memory is set to 50 episodes in your code, while the DeepMind Atari paper uses a replay memory of 1,000,000 transitions. This seems like a major limitation during training -- can you achieve similar performance with 50?
  2. The DeepMind paper trains for roughly 100 epochs of 50,000 steps each. Do you see roughly the same training time to proficiency here, with similar scores on the same games?
  3. Looking at a few new environments and value-function networks, I have seen somewhat unstable average max value-function outputs (unbounded growth or large fluctuations). I'm wondering whether some kind of action-value target clipping or rescaling makes sense. The instability seems to occur when the network has difficulty differentiating next states from current states (most of the inputs are the same or indistinguishable, but a few differ slightly). Is this kind of instability expected, and what kind of training time is needed to achieve more stability?

dqn.py indexing is not right

I found the indexing in build_functions to be incorrect. You can run the commented-out code below to verify the wrong result produced by VS[:, A].

This indexing should instead be written like lines 51-52 in https://github.com/ShibiHe/DQN_OpenAI_keras/blob/master/agents.py or https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/q_network.py

def build_functions(self):
    S = Input(shape=self.state_size)
    NS = Input(shape=self.state_size)
    A = Input(shape=(1,), dtype='int32')
    R = Input(shape=(1,), dtype='float32')
    T = Input(shape=(1,), dtype='int32')
    self.build_model()
    self.value_fn = K.function([S], self.model(S))

    VS = self.model(S)
    VNS = disconnected_grad(self.model(NS))
    future_value = (1-T) * VNS.max(axis=1, keepdims=True)
    discounted_future_value = self.discount * future_value
    target = R + discounted_future_value
    cost0 = VS[:, A] - target              # incorrect indexing (see above)
    cost = ((VS[:, A] - target)**2).mean()
    opt = RMSprop(0.0001)
    params = self.model.trainable_weights
    updates = opt.get_updates(params, [], cost)

    self.train_fn = K.function([S, NS, A, R, T], [cost, cost0, target, A], updates=updates)
    # import numpy as np
    # t = self.train_fn([np.random.rand(10, *self.state_size), np.random.rand(10, *self.state_size), np.ones((10, 1)), np.ones((10, 1)), np.zeros((10, 1))])
    # print('cost=', t[0])
    # print('cost0=', t[1])
    # print('target=', t[2])
    # print('A=', t[3])
    # raw_input()
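For anyone following along, here is a plain NumPy sketch (not the Theano code above, but the same indexing semantics) of why VS[:, A] with a column vector A does not pick one Q-value per row, and what a correct per-row gather looks like:

```python
import numpy as np

VS = np.array([[1., 2., 3.],
               [4., 5., 6.]])             # (batch=2, n_actions=3) Q-values
A = np.array([[2], [0]], dtype=np.int64)  # chosen action per row, shape (2, 1)

wrong = VS[:, A]        # fancy-indexes the column axis for EVERY row
print(wrong.shape)      # (2, 2, 1) -- not one value per row

# Correct gather of Q(s_i, a_i): pair each row index with its own action.
right = VS[np.arange(VS.shape[0]), A.ravel()]
print(right)            # [3. 4.]
```

Theano's symbolic indexing behaves the same way here, which is why the linked implementations pair an arange over the batch with the action vector.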

Areas for improvement

I'm working with this code and have made a few changes already (I would submit a pull request, but the way I've done them is pretty hacky, and I've never made a pull request before :) ). They are:

  • changing hyperparameters (.9 --> .99 for the discount, .1 --> .05 for epsilon). These are based on the Mnih et al. 2015 hyperparameters. I'm not sure what the performance effect is, but .1 seems high for a final epsilon level.
  • adding video output
  • epsilon annealing (this was done hackily, by manually specifying epsilons for different episode intervals, but it could be done more cleanly. Per Mnih et al., I start with an epsilon of 1 and anneal it roughly linearly over a thousand episodes).
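The annealing scheme described above could be sketched like this (the schedule constants are illustrative, not values used in this repo):

```python
def epsilon(episode, eps_start=1.0, eps_end=0.05, anneal_episodes=1000):
    """Linearly anneal epsilon from eps_start down to eps_end, then hold."""
    if episode >= anneal_episodes:
        return eps_end
    return eps_start - (eps_start - eps_end) * episode / anneal_episodes

print(epsilon(0))              # 1.0
print(round(epsilon(500), 3))  # 0.525
print(epsilon(2000))           # 0.05
```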

Other possible areas for improvement:

  • grayscaling for efficiency
  • frame skip for efficiency
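Grayscaling is cheap to add; here is a NumPy sketch using the standard ITU-R 601 luma weights (the to_gray helper is hypothetical, not part of this repo):

```python
import numpy as np

def to_gray(frame):
    """Convert an (H, W, 3) RGB frame to an (H, W) float grayscale image."""
    return frame @ np.array([0.299, 0.587, 0.114])  # ITU-R 601 luma weights

rgb = np.zeros((210, 160, 3), dtype=np.uint8)
rgb[..., 0] = 255  # a pure-red frame
gray = to_gray(rgb)
print(gray.shape)             # (210, 160)
print(round(float(gray[0, 0]), 3))  # 76.245, i.e. 255 * 0.299
```

Besides a 3x memory saving, a single channel also lets frame stacking reuse the channel axis.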

I'd potentially be interested in submitting pull requests for some of these once I figure out how, but I thought I'd post here first to get thoughts on the above and see whether people have other ideas for key areas of improvement.

Running example.py gives an error

When I try to run example.py, I get the following error:

raise TypeError('outputs of a TensorFlow backend function '
TypeError: outputs of a TensorFlow backend function should be a list or tuple.

any help would be appreciated

Fails to run with the following error

File "/Users/fox/PycharmProjects/dqn-master/dqn.py", line 53, in build_functions
self.build_model()

assert type(outputs) in {list, tuple}, 'Output to a TensorFlow backend function should be a list or tuple.'
AssertionError: Output to a TensorFlow backend function should be a list or tuple.

Does not work with Tensorflow

I get the following error when running with TensorFlow (with GPU)

current context was not created by the StreamExecutor cuda_driver API: 0x2ccf300; a CUDA runtime call was likely performed without using a StreamExecutor context

The issue is that the code references Theano directly:
from theano.gradient import disconnected_grad

Reference: tensorflow/tensorflow#916

differences with Mnih

Hello! According to Mnih et al., the function phi applies preprocessing to the last 4 frames of a history and stacks them to produce the input to the Q function. However, reading your code, I understand that it feeds only one raw frame to the Q network. Am I right?
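For reference, the phi from Mnih et al. could be sketched like this in NumPy (a hypothetical helper, not what the repo currently does):

```python
from collections import deque
import numpy as np

history = deque(maxlen=4)  # holds the 4 most recent preprocessed frames

def phi(new_frame):
    """Append the newest 84x84 frame; return a (4, 84, 84) stacked input."""
    history.append(new_frame)
    while len(history) < 4:      # at episode start, pad by repeating
        history.append(new_frame)
    return np.stack(history, axis=0)

state = phi(np.zeros((84, 84), dtype=np.float32))
print(state.shape)  # (4, 84, 84)
```

The stacking matters because a single frame cannot encode velocities (e.g. which way the ball is moving in Pong), so the single-frame input here makes the problem partially observable.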

I also noticed that in dqn.py, the procedure iterate(self) has a for loop containing:
episode = random.randint(max(0, N-50), N-1)

Shouldn't this be N-self.memory instead of N-50?
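The generalization the question suggests could look like this (sample_episode_index is a hypothetical name; the repo inlines the expression):

```python
import random

def sample_episode_index(N, memory=50):
    """Pick a random episode index from the last `memory` episodes seen."""
    return random.randint(max(0, N - memory), N - 1)

# With memory=200, all samples come from the last 200 of 1000 episodes:
print(all(800 <= sample_episode_index(1000, 200) <= 999 for _ in range(100)))
# Early in training (N < memory), it falls back to the full history:
print(all(0 <= sample_episode_index(10, 200) <= 9 for _ in range(100)))
```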

This is my first interaction here, hope you understand 😄
