
Recurrent DQN (atari) · OPEN · 31 comments

Kaixhin commented on July 17, 2024

Recurrent DQN


Comments (31)

Kaixhin commented on July 17, 2024

Yep, a switch for using a DRQN architecture would be great. For now I'd use histLen as the number of frames to run BPTT over for a single-frame DRQN. It would be good to base it on the rnn library, especially since it now has the optimised SeqLSTM.
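
Roughly the kind of head swap I have in mind, as a sketch with placeholder sizes (featSize, hiddenSize, numActions here are not the real options):

```lua
require 'nn'
require 'rnn'

-- Placeholder sizes, not the real options
local featSize, hiddenSize, numActions, histLen = 64, 32, 3, 10

-- Replace the DQN's first fully-connected layer with an LSTM;
-- SeqLSTM expects input of shape seqLen x batchSize x featSize
local head = nn.Sequential()
head:add(nn.SeqLSTM(featSize, hiddenSize))
head:add(nn.Select(1, histLen)) -- keep only the last time step's output
head:add(nn.Linear(hiddenSize, numActions))

-- A batch of 4 sequences of histLen frames' features
local input = torch.randn(histLen, 4, featSize)
print(head:forward(input):size()) -- 4 x numActions
```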


lake4790k commented on July 17, 2024

This is the Caffe implementation from the paper:
https://github.com/mhauskn/dqn/tree/recurrent

Although I've never looked at Caffe, it will probably help.


lake4790k commented on July 17, 2024

@Kaixhin I see you started working on this, cool. I'll have some time now, so I'll look at the multigpu and async modes.


Kaixhin commented on July 17, 2024

@lake4790k Almost have something working. Disabling this line lets the DRQN train, as otherwise it crashes here, somehow propagating a batch of size 20 forward but expecting the normal batch size of 32 backwards.

I'm new to the rnn library, so let me know if you have any ideas. Performance is considerably slower, which will be due to having to process several time steps sequentially. This is in line with Appendix B in that paper though.


lake4790k commented on July 17, 2024

@Kaixhin Awesome! I have no experience with rnn either, I will need to study it to have an idea. I have two 980TIs and will be able to run longer experiments to see if it goes anywhere.


Kaixhin commented on July 17, 2024

@lake4790k I'd have to delve into the original paper/code, but it looks like they train the network every step (as opposed to every 4). This seems like it'll be a problem for BPTT. In any case if you haven't used rnn before I'll focus on this.


lake4790k commented on July 17, 2024

@Kaixhin cool, I'll have my hands full with async for now, but in the meantime I'll be able to help with running longer DRQN experiments on my workstation when you think it's worth trying.


Kaixhin commented on July 17, 2024

Here's the result of running ./run.sh demo -recurrent true, so I'm reasonably confident that the DRQN is capable of learning. I'm not testing this further for now though, so I'm leaving this issue open. In any case, I still haven't solved this issue (which I mentioned above).

[scores plot]


Kaixhin commented on July 17, 2024

Pinging @JoostvDoorn since he's contributed to rnn and may have ideas about the minibatch problem/performance improvements/whether it's possible to save and restore state before and after training (and if that should be done since the parameters have changed slightly).


JoostvDoorn commented on July 17, 2024

@Kaixhin I will have a look later.


lake4790k commented on July 17, 2024

@Kaixhin I'm not getting the error you mentioned when doing validation on the last batch of size 20 when running demo. I'm using the master code, which has sequencer:remember('both') enabled. You mentioned you had to disable that to avoid the crash...? master runs fine for me as it is.


JoostvDoorn commented on July 17, 2024

I think this is in the rnn branch. This may or may not be a bug when using FastLSTM with the nngraph version. Setting nn.FastLSTM.usenngraph = false changed the error for me, but I only got the chance to look at this for a moment.


lake4790k commented on July 17, 2024

ok so there are two issues:

  1. nn.FastLSTM.usenngraph = true
    nngraph/gmodule.lua:335: split(4) cannot split 32 outputs
    This is an issue in both rnn and master.
  2. nn.FastLSTM.usenngraph = false
    Wrong size for view. Input size: 20x1x3. Output size: 32x3
    This only occurs in rnn, because @Kaixhin fixed #16 in master (but not in rnn) by returning before the backward pass during validation, since it isn't needed there; so maybe it's no issue after all?


Kaixhin commented on July 17, 2024
  1. With nn.FastLSTM.usenngraph = true, I get the same error as @lake4790k. This seems to be Element-Research/rnn#172, which is a shame, as apparently it's significantly faster with this flag enabled (see Element-Research/rnn#182).
  2. Yes, so if you remove the return on line 374 in master then it fails. So I consider this a bug, albeit one that is being hidden by that return: why is this occurring even when states is 20x4x1x24x24 and QCurr is 20x1x3? If the error is dependent on previous batches then the learning must be incorrect. Also, I was wrong: removing sequencer:remember('both') doesn't stop the crash.


lake4790k commented on July 17, 2024

@Kaixhin re: 2. agreed, this error is bad, so returning early is not a solution. I'm not sure if learning is bad with the normal batch sizes; it could just be that a batch size change isn't handled properly somewhere. I tried an isolated FastLSTM + Sequencer net, and there switching batch sizes worked fine, weirdly. I'm looking at adding LSTM to async; once I get that working I'll experiment with this further.


Kaixhin commented on July 17, 2024

@lake4790k I also tried a simple FastLSTM + Sequencer net with different batch sizes - no problem. I agree with it being likely that some module is not switching its internal variables to the correct size, but finding out exactly where the problem lies is tricky. It may be that I haven't set up the recurrency correctly, but apart from this batch size issue it seems to work fine.
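
For reference, my standalone test looked roughly like this (arbitrary sizes); forward and backward both run cleanly across the batch size change:

```lua
require 'nn'
require 'rnn'

-- Standalone check with arbitrary sizes: a Sequencer-wrapped FastLSTM
-- fed minibatches of changing size
local net = nn.Sequencer(nn.Sequential()
  :add(nn.FastLSTM(10, 8))
  :add(nn.Linear(8, 3)))

for _, batchSize in ipairs({32, 20, 32}) do
  local inputs, gradOutputs = {}, {}
  for t = 1, 4 do
    inputs[t] = torch.randn(batchSize, 10)
    gradOutputs[t] = torch.randn(batchSize, 3)
  end
  net:forward(inputs)
  net:backward(inputs, gradOutputs) -- no size mismatch in isolation
end
```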


lake4790k commented on July 17, 2024

@Kaixhin I need to refresh async from master for the recurrent work; should I do a merge or a rebase (I'm leaning towards a merge)? Does it even matter when merging back from async to master eventually?


Kaixhin commented on July 17, 2024

@lake4790k I'd go with a merge since it preserves history correctly. It's better to make sure all the changes in master are integrated sooner rather than later.


lake4790k commented on July 17, 2024

Done the merge and added recurrent support for 1-step Q in async. This is 7 minutes of training; it seems to work well:

[scores plot]

The agent sees only the latest frame per step and backpropagates by unrolling 5 steps on every step; weights are updated every 5 steps (or on terminal). No Sequencer is needed in this algorithm. I used sharedRmsProp and kept the ReLU after the FastLSTM to have a setup comparable to my usual async testing.

Pretty cool that it works. I'll test whether it performs similarly on a flickering Catch, as they did in the paper with flickering Pong. Also, in the async paper they added a half-size LSTM layer after the linear layer instead of replacing it; I'll try that as well (although the DRQN paper says replacing is best).

I'll add support for the n-step methods as well. There it's a bit trickier to get right, as steps are taken forwards and backwards to calculate the n-step returns; I'll have to take care that the forwards/backwards are also correct for the LSTM.
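
The update is roughly this shape; a sketch with placeholder sizes and a stubbed-out gradient, not the actual async code:

```lua
require 'nn'
require 'rnn'

-- Placeholder sizes; the real net has conv layers in front of the LSTM
local net = nn.Sequential()
  :add(nn.FastLSTM(64, 32))
  :add(nn.ReLU())
  :add(nn.Linear(32, 3))

local inputs, gradOutputs = {}, {}
for step = 1, 5 do
  inputs[step] = torch.randn(1, 64)   -- features of the latest frame only
  net:forward(inputs[step])           -- hidden state carries over each step
  gradOutputs[step] = torch.zeros(1, 3)
  gradOutputs[step][1][1] = 0.1       -- stub TD-error gradient on one action
end

-- BPTT over the 5 steps: with rnn's recurrent modules, backward is
-- called in the reverse order of the forwards
for step = 5, 1, -1 do
  net:backward(inputs[step], gradOutputs[step])
end
-- ...apply sharedRmsProp to the accumulated gradients...
net:forget() -- simplest reset here; the agent only resets on terminal steps
```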


lake4790k commented on July 17, 2024

I also tried replacing the FastLSTM with a GRU, with everything else the same; interestingly, that did not converge even when run for longer.


JoostvDoorn commented on July 17, 2024

@lake4790k Do you have the flickering catch version somewhere?


lake4790k commented on July 17, 2024

@JoostvDoorn I haven't got around to it since, but it probably only takes a few lines to add to rlenvs.


Kaixhin commented on July 17, 2024

@JoostvDoorn I can add that to rlenvs.Catch if you want? You may also be interested in the obscured option I set up, which blanks a strip of screen at the bottom so that the agent has to infer the motion of the ball properly. Quick enough to test by adding opt.obscured = true in Setup.lua.


Kaixhin commented on July 17, 2024

@JoostvDoorn Done. Just get the latest version of rlenvs and this repo. -flickering is a probability between 0 and 1 of the screen blanking out.
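
Conceptually it amounts to something like this; a sketch, not the exact rlenvs code:

```lua
require 'torch'

-- Sketch: with probability `flickering`, the agent observes a blank
-- screen instead of the real one
local flickering = 0.5

local function observe(screen)
  if torch.uniform() < flickering then
    return screen:clone():zero() -- blanked-out observation
  end
  return screen
end
```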


JoostvDoorn commented on July 17, 2024

@Kaixhin Great thanks.

Have you tried storing the state instead of calling forget for every time step? I am doing this now; it takes longer to train, but it will probably converge. I agree this has to do with the changing state distribution, but we can't really let the agent explore without considering the history if we want to take full advantage of the LSTM.
[scores4 plot]


Kaixhin commented on July 17, 2024

@JoostvDoorn I thought that this line would actually set remember for all internal modules, but I'm not certain? If that is not the case then yes I agree it should be set on the LSTM units themselves.

In summary, in Agent:observe, the only place that forget is called is at a terminal state. Of course, when learning it should call forget before passing the minibatch through, and after learning as well. This means that memSampleFreq is the maximum amount of history the LSTMs keep during training, but they receive the entire history during validation/evaluation.


JoostvDoorn commented on July 17, 2024

@Kaixhin Yes that line is enough, I will change that in my pull request.

I missed memSampleFreq, so I assumed it was calling forget every time. I guess memSampleFreq >= histLen is a good thing here, so that training and updating have a similar distribution. Do note though that the 5th action will update based on the 2nd, 3rd, 4th, and 5th states in the Q-learning update, while the policy followed will only be based on the 5th state, right?


Kaixhin commented on July 17, 2024

@JoostvDoorn Yep, memSampleFreq >= histLen would be sensible. Sorry, I'm not sure I understand your last question though. During learning updates for recurrent networks, histLen is used to determine the sequence length of states fed in (no concatenating frames in time as with a normal DQN). During training the hidden state will go back until the last time forget was called (and forget is called every memSampleFreq).
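
So a learning step looks roughly like this; a sketch with placeholder sizes, not the actual Agent code:

```lua
require 'nn'
require 'rnn'

-- Placeholder sizes, not the actual Agent code
local histLen, batchSize = 10, 32
local net = nn.Sequential()
  :add(nn.FastLSTM(24, 16))
  :add(nn.Linear(16, 3))

local states = torch.randn(histLen, batchSize, 24) -- one sequence per sample

net:forget() -- clear any hidden state carried over from acting
local QCurr
for t = 1, histLen do
  QCurr = net:forward(states[t]) -- hidden state accumulates over the sequence
end
-- ...compute TD targets from QCurr and backpropagate in reverse over the steps...
net:forget() -- clear again so acting resumes from a clean state
```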


JoostvDoorn commented on July 17, 2024

I guess it's like this: forget is called at the first time step, so the LSTM will not have accumulated any information at this point; once here, it will start accumulating state information (note, however, that on torch.uniform() < epsilon we don't accumulate info, which is a bug). Now after calling Agent:learn we call forget again. Then once the episode continues and reaches the point here, the state information is the same as at the start of the episode, which, depending on the environment, is a problem.
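
i.e. the fix is to forward the observation unconditionally, something like this (names hypothetical):

```lua
-- Names hypothetical; the point is that the forward happens unconditionally,
-- so the LSTM accumulates state even when the action is chosen randomly
local function act(net, state, epsilon, numActions)
  local q = net:forward(state) -- must run even when exploring
  if torch.uniform() < epsilon then
    return torch.random(numActions) -- explore: random action
  end
  local _, a = q:max(1) -- exploit: greedy action
  return a[1]
end
```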


Kaixhin commented on July 17, 2024

Thanks for spotting the bug. @lake4790k please check 626712b to make sure async agents are accounted for as well.

@JoostvDoorn If I understand correctly then the issue is that the agent can't retain information during training because observe is interspersed with forget calls during learn? That's what I was wondering about above. My reasoning comes from the rnn docs. Also, it would be prohibitive to keep old states from before learn and pass them all through the network before starting again.


lake4790k commented on July 17, 2024

@Kaixhin yes this is needed for async, just created #47 to do it a bit differently.

