pyTorch about ga3c (OPEN)

nvlabs commented on May 14, 2024
pyTorch

from ga3c.

Comments (6)

ifrosio commented on May 14, 2024

We are not planning to implement it for now, but some people have indeed suggested that PyTorch may be faster than TF. It would be great if someone could implement GA3C in PyTorch following our guidelines.

etienne87 commented on May 14, 2024

I did a quick trial in one of my branches. Actually, TF is almost twice as fast, probably because the naive way I wrote the vectorized loss involves a lot of function calls. The same issue arises in the Chainer version. The loss takes almost as long to compute as the CNN, if not longer. I think it could run faster if the loss were implemented as a dedicated layer.

ppwwyyxx commented on May 14, 2024

Just FYI, my friend was able to reproduce both the speed and performance of my a3c implementation with his pytorch code.
It batches data differently from GA3C, but the overall structure is similar.

etienne87 commented on May 14, 2024

Interesting, @ppwwyyxx!
My naive implementation gives something like this:

results txt

I am not sure whether the problem is in the batching or rather in the explicit calls and the many computation steps for the loss.

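        # Forward pass over the multi-step batch: policy scores p and values v
        p, v = self.model.forward_multistep(x_, c, h)
        probs = F.softmax(p, dim=1)  # dim=1 is the action axis (it is summed over below)
        # Clamp instead of relu(probs - eps), so that log() never receives an exact zero
        probs = torch.clamp(probs, min=Config.LOG_EPSILON)
        log_probs = torch.log(probs)
        adv = rewards - v
        adv = torch.masked_select(adv, mask)  # keep only the entries selected by mask
        log_probs_a = torch.masked_select(log_probs, a)  # we cannot use it because of variable length input
        piloss = -torch.sum(log_probs_a * adv.detach(), 0)  # no gradient through the advantage
        entropy = torch.sum(torch.sum(log_probs * probs, 1), 0) * self.beta  # negative entropy, scaled by beta
        vloss = torch.sum(adv.pow(2), 0) / 2
        loss = piloss + entropy + vloss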

Does anyone know how to do this more quickly in PyTorch?
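One direction that might reduce the number of separate ops, offered only as a sketch and not benchmarked against the snippet above: use the fused `F.log_softmax` instead of softmax, epsilon handling and log, pick the taken action's log-probability with `gather` on integer indices, and multiply by a 0/1 mask instead of calling `masked_select`. The function name `a3c_loss`, its argument names, and the flat `(N, A)` / `(N,)` shapes are assumptions made for illustration; they are not taken from the branch discussed here. Note two deliberate differences: there is no `LOG_EPSILON` clamp (log-softmax is used instead), and the entropy term is masked as well.

```python
import torch
import torch.nn.functional as F


def a3c_loss(logits, values, actions, returns, mask, beta=0.01):
    """Masked A3C-style loss over a padded batch (hypothetical helper).

    logits:  (N, A) float, unnormalized policy scores
    values:  (N,)   float, value estimates
    actions: (N,)   long,  index of the action taken at each step
    returns: (N,)   float, discounted returns
    mask:    (N,)   float, 1.0 for real steps and 0.0 for padding
    """
    # Fused, numerically stable log(softmax(logits)).
    log_probs = F.log_softmax(logits, dim=1)
    probs = log_probs.exp()
    # Select log pi(a|s) by index rather than with a mask over actions.
    log_probs_a = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
    adv = returns - values
    # The advantage is detached so the policy term does not train the critic.
    pi_loss = -(log_probs_a * adv.detach() * mask).sum()
    # Negative entropy scaled by beta, so minimizing it encourages exploration.
    neg_entropy = ((probs * log_probs).sum(dim=1) * mask).sum() * beta
    v_loss = 0.5 * (adv.pow(2) * mask).sum()
    return pi_loss + neg_entropy + v_loss


# Usage with made-up data: four steps, three actions, last step is padding.
torch.manual_seed(0)
logits = torch.randn(4, 3, requires_grad=True)
values = torch.randn(4, requires_grad=True)
actions = torch.tensor([0, 2, 1, 1])
returns = torch.randn(4)
mask = torch.tensor([1.0, 1.0, 1.0, 0.0])
loss = a3c_loss(logits, values, actions, returns, mask)
loss.backward()
```

Because padded steps are zeroed by multiplication, all tensors keep a fixed shape, which avoids the variable-length outputs of `masked_select`. Whether this is actually faster would need to be measured on the real model.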

dylanthomas commented on May 14, 2024

@ppwwyyxx Is there a public git repo for your friend's PyTorch implementation?

ppwwyyxx commented on May 14, 2024

Unfortunately, no.
