
puyuan1996 commented on June 4, 2024

Of course, the current modifications are only a workaround for compatibility. The better approach is to find the place where the shape of the action tensor changes from torch.Size([128, 2]) to torch.Size([128, 2, 1]). We would greatly appreciate it if you could debug the code and locate that position.
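One way to locate the change is to assert the expected action shape at each stage of the pipeline, so the first stage that adds an axis fails loudly. Below is a minimal sketch; `assert_action_shape` is a hypothetical helper written for illustration, not part of LightZero, and the sizes 128 and 2 come from the shapes mentioned above.

```python
import torch


def assert_action_shape(name: str, actions: torch.Tensor, batch_size: int, action_dim: int) -> None:
    """Raise if `actions` is not (batch_size, action_dim).

    Hypothetical debugging helper, not part of LightZero.
    """
    expected = (batch_size, action_dim)
    if tuple(actions.shape) != expected:
        raise ValueError(f"{name}: expected shape {expected}, got {tuple(actions.shape)}")


# A correctly shaped batch passes; one with a stray trailing axis fails.
good = torch.zeros(128, 2)
bad = good.unsqueeze(-1)  # torch.Size([128, 2, 1])
assert_action_shape("good", good, batch_size=128, action_dim=2)
try:
    assert_action_shape("bad", bad, batch_size=128, action_dim=2)
except ValueError as err:
    print(err)  # bad: expected shape (128, 2), got (128, 2, 1)
```

Calling such a check right after each transformation of the action batch narrows the search to a single line.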

ekiefl commented on June 4, 2024

Identifying the problem

I found the problem. In these examples I'm using a batch size of 32 and a continuous action space of dimension 2.

At this point in the code, the action batch appears to be properly shaped:

obs_batch_ori, action_batch, child_sampled_actions_batch, mask_batch, indices, weights, make_time = current_batch

shape of action_batch: (32, 5, 2)

Then, there's this line:

# shape: (batch_size, num_unroll_steps, action_dim)
# NOTE: .float(), in continuous action space.
action_batch = torch.from_numpy(action_batch).to(self._cfg.device).float().unsqueeze(-1)

shape of action_batch: torch.Size([32, 5, 2, 1])
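The effect of the quoted line can be reproduced in isolation. This sketch uses random numbers as a stand-in for the sampled batch (the real values appear further down) and omits `.to(self._cfg.device)`, since the device does not affect the shape.

```python
import numpy as np
import torch

# Synthetic stand-in for the sampled batch: (batch_size, num_unroll_steps, action_dim).
batch_size, num_unroll_steps, action_dim = 32, 5, 2
action_batch_np = np.random.uniform(-1.0, 1.0, size=(batch_size, num_unroll_steps, action_dim))

# The conversion quoted above: .unsqueeze(-1) appends a singleton axis at the end.
with_unsqueeze = torch.from_numpy(action_batch_np).float().unsqueeze(-1)
print(with_unsqueeze.shape)  # torch.Size([32, 5, 2, 1])

# Without .unsqueeze(-1), the shape matches the "# shape:" comment in the quoted code.
without_unsqueeze = torch.from_numpy(action_batch_np).float()
print(without_unsqueeze.shape)  # torch.Size([32, 5, 2])
```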

This causes a problem the next time action_batch is used:

network_output = self._learn_model.recurrent_inference(
    latent_state, reward_hidden_state, action_batch[:, step_k]
)

shape of action_batch[:, step_k]: torch.Size([32, 2, 1]), instead of the expected torch.Size([32, 2])
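Indexing one unroll step keeps the trailing singleton axis, which is why `recurrent_inference` receives three dimensions. A minimal sketch with random data; the network call itself is not reproduced, and `squeeze(-1)` is used only to show the shape without the extra axis, not as the fix applied upstream.

```python
import torch

# Random batch with the shape produced by the quoted conversion:
# (batch_size, num_unroll_steps, action_dim, 1).
action_batch = torch.randn(32, 5, 2).unsqueeze(-1)

step_k = 0
step_actions = action_batch[:, step_k]
print(step_actions.shape)  # torch.Size([32, 2, 1])

# Without the extra axis, each unroll step yields (batch_size, action_dim).
fixed_batch = action_batch.squeeze(-1)
print(fixed_batch[:, step_k].shape)  # torch.Size([32, 2])
```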

Testing variables

Here is the initial action_batch:

from numpy import array
array([[[-0.4490683972835541, -0.9950420260429382],
        [-0.4014930129051208, -0.7829872369766235],
        [-0.4213773608207703, -0.7039519548416138],
        [-0.4044414460659027, -0.834709882736206 ],
        [-0.1590566784143448, -0.2195666283369064]],

       [[-0.4014930129051208, -0.7829872369766235],
        [-0.4213773608207703, -0.7039519548416138],
        [-0.4044414460659027, -0.834709882736206 ],
        [-0.1590566784143448, -0.2195666283369064],
        [-0.7565191984176636, -0.2842720746994019]],

       [[-0.4213773608207703, -0.7039519548416138],
        [-0.4044414460659027, -0.834709882736206 ],
        [-0.1590566784143448, -0.2195666283369064],
        [-0.7565191984176636, -0.2842720746994019],
        [ 0.6578846573829651,  0.8156458139419556]],

       [[-0.4044414460659027, -0.834709882736206 ],
        [-0.1590566784143448, -0.2195666283369064],
        [-0.7565191984176636, -0.2842720746994019],
        [ 0.6578846573829651,  0.8156458139419556],
        [ 0.7340825796127319,  0.8159331679344177]],

       [[-0.1590566784143448, -0.2195666283369064],
        [-0.7565191984176636, -0.2842720746994019],
        [ 0.6578846573829651,  0.8156458139419556],
        [ 0.7340825796127319,  0.8159331679344177],
        [-0.4984613955020905,  0.2341146618127823]],

       [[ 0.6578846573829651,  0.8156458139419556],
        [ 0.7340825796127319,  0.8159331679344177],
        [-0.4984613955020905,  0.2341146618127823],
        [ 0.2147858291864395,  0.8802192211151123],
        [ 0.3157244141177024,  0.122345928192084 ]],

       [[ 0.7340825796127319,  0.8159331679344177],
        [-0.4984613955020905,  0.2341146618127823],
        [ 0.2147858291864395,  0.8802192211151123],
        [ 0.0447234260510479, -0.0061242674195186],
        [ 2.076320333718342 ,  0.0612052095300819]],

       [[-0.4984613955020905,  0.2341146618127823],
        [ 0.2147858291864395,  0.8802192211151123],
        [-0.0753794324983024,  0.018401418199892 ],
        [ 0.230092215421578 , -0.154497093666865 ],
        [ 0.8510225575898217,  0.0945327318063604]],

       [[ 0.0692359060049057, -0.5581157207489014],
        [-0.0967390909790993,  0.6544007062911987],
        [-0.5564347505569458, -0.5217161774635315],
        [ 0.227927178144455 , -0.5079849362373352],
        [-0.7363128662109375,  0.7372166514396667]],

       [[-0.5564347505569458, -0.5217161774635315],
        [ 0.227927178144455 , -0.5079849362373352],
        [-0.7363128662109375,  0.7372166514396667],
        [ 0.8861203789710999,  0.7655220031738281],
        [-0.8919497132301331,  0.7664076685905457]],

       [[ 0.8861203789710999,  0.7655220031738281],
        [-0.8919497132301331,  0.7664076685905457],
        [-0.8111757040023804,  0.3714333176612854],
        [ 0.9299460443093313,  0.62140516600195  ],
        [-0.7294452606963147, -0.5507582288642449]],

       [[-0.4489460289478302, -0.4065003991127014],
        [-0.0877294093370438,  0.6240969896316528],
        [-0.145333394408226 ,  0.0135438507422805],
        [ 0.7303228378295898,  0.324310839176178 ],
        [ 0.1930731981992722, -0.6004103422164917]],

       [[-0.0877294093370438,  0.6240969896316528],
        [-0.145333394408226 ,  0.0135438507422805],
        [ 0.7303228378295898,  0.324310839176178 ],
        [ 0.1930731981992722, -0.6004103422164917],
        [-0.2952068150043488,  0.7421513199806213]],

       [[-0.145333394408226 ,  0.0135438507422805],
        [ 0.7303228378295898,  0.324310839176178 ],
        [ 0.1930731981992722, -0.6004103422164917],
        [-0.2952068150043488,  0.7421513199806213],
        [ 0.3888700902462006, -0.6838327646255493]],

       [[ 0.7991305589675903, -0.6135419011116028],
        [-1.4756834812262265, -1.5111042458027768],
        [-1.0753735258729593, -0.0687822111641535],
        [-0.6506161214320586, -0.2692845067856774],
        [-1.509456993847584 ,  0.4623811175903334]],

       [[ 0.9131344556808472, -0.9124650955200195],
        [-0.6182714104652405,  0.8433078527450562],
        [ 0.9737944006919861,  0.3280912041664124],
        [-0.503142774105072 , -0.4597557783126831],
        [ 0.1592527478933334, -0.3711529672145844]],

       [[ 0.97501140832901  , -0.9102342128753662],
        [-0.6536456346511841, -0.6314908266067505],
        [-0.3102632761001587,  0.5417718887329102],
        [ 0.5856770137135533, -0.1399733906201436],
        [ 0.9160984176427144,  0.0171331457675763]],

       [[-0.6536456346511841, -0.6314908266067505],
        [-0.3102632761001587,  0.5417718887329102],
        [ 1.1057067309019377,  0.2459597965178402],
        [-2.78280593866222  , -2.279670235885282 ],
        [ 0.3490433187008873,  0.2584693277416903]],

       [[-0.3102632761001587,  0.5417718887329102],
        [ 1.7954634122483244, -0.1587822891046758],
        [ 0.8435877018233144, -1.792088634720315 ],
        [-0.6243843906905057,  1.004054336454013 ],
        [ 1.0931276206988136, -0.8503439391331218]],

       [[-0.0523804016411304,  0.9104548692703247],
        [-0.982227087020874 , -0.4803890585899353],
        [-0.9744316339492798,  0.9138423204421997],
        [-0.8951768279075623, -0.8566049337387085],
        [ 0.6736979484558105, -0.8682780265808105]],

       [[-0.9744316339492798,  0.9138423204421997],
        [-0.8951768279075623, -0.8566049337387085],
        [ 0.6736979484558105, -0.8682780265808105],
        [-0.9753063321113586, -0.1303187161684036],
        [ 0.0419043824076653, -0.9880978465080261]],

       [[ 0.6736979484558105, -0.8682780265808105],
        [-0.9753063321113586, -0.1303187161684036],
        [ 0.0419043824076653, -0.9880978465080261],
        [-0.8325971961021423, -0.4706742167472839],
        [ 0.4065617024898529, -0.1309379935264587]],

       [[-0.9753063321113586, -0.1303187161684036],
        [ 0.0419043824076653, -0.9880978465080261],
        [-0.8325971961021423, -0.4706742167472839],
        [ 0.4065617024898529, -0.1309379935264587],
        [-0.9776009976312126,  0.0929417044544182]],

       [[ 0.0419043824076653, -0.9880978465080261],
        [-0.8325971961021423, -0.4706742167472839],
        [ 0.4065617024898529, -0.1309379935264587],
        [-1.2604611072245877,  0.3255282013572517],
        [-0.6375371714023212, -0.2479576952053556]],

       [[ 0.4065617024898529, -0.1309379935264587],
        [ 1.0853122572285252, -1.113450095645758 ],
        [ 0.6383299872593756,  0.4615825320021247],
        [-0.4526769121750209, -0.5026150726186746],
        [-0.4429909946263973, -0.6435901784670306]],

       [[ 0.8820233941078186,  0.4469195306301117],
        [ 0.1421182006597519, -0.3563036918640137],
        [ 0.4080510437488556, -0.0753544196486473],
        [-0.9183218479156494, -0.715552031993866 ],
        [ 0.3346443176269531, -0.7762950658798218]],

       [[-0.9183218479156494, -0.715552031993866 ],
        [ 0.3346443176269531, -0.7762950658798218],
        [ 0.3761063516139984,  0.7810404896736145],
        [-0.9150596857070923,  0.6707392930984497],
        [ 0.1624187082052231, -0.0102332159876823]],

       [[ 0.3346443176269531, -0.7762950658798218],
        [ 0.3761063516139984,  0.7810404896736145],
        [-0.9150596857070923,  0.6707392930984497],
        [ 0.1624187082052231, -0.0102332159876823],
        [ 0.9643478393554688,  0.5449203252792358]],

       [[ 0.3761063516139984,  0.7810404896736145],
        [-0.9150596857070923,  0.6707392930984497],
        [ 0.1624187082052231, -0.0102332159876823],
        [ 0.9643478393554688,  0.5449203252792358],
        [-0.6978347897529602,  0.3335686028003693]],

       [[ 0.1624187082052231, -0.0102332159876823],
        [ 0.9643478393554688,  0.5449203252792358],
        [-0.6978347897529602,  0.3335686028003693],
        [ 0.1176412274461428, -1.5554193438851447],
        [-0.639407819647646 ,  0.070470209032387 ]],

       [[ 0.9643478393554688,  0.5449203252792358],
        [-0.6978347897529602,  0.3335686028003693],
        [ 0.9371315195670832, -0.5187535949083031],
        [ 1.6102262403242997,  0.7461320383571783],
        [-1.0328825079779835,  0.5870174496427473]],

       [[-0.6978347897529602,  0.3335686028003693],
        [ 0.4909302677064525, -1.8273185492431   ],
        [-1.3642024846362164,  0.4962908880233692],
        [ 0.9805795313754992, -1.0712426831281652],
        [ 0.7456210692367394, -2.0334058101172436]]])

And here is action_batch after the transformation:

from torch import tensor
tensor([[[[-0.4491],
          [-0.9950]],

         [[-0.4015],
          [-0.7830]],

         [[-0.4214],
          [-0.7040]],

         [[-0.4044],
          [-0.8347]],

         [[-0.1591],
          [-0.2196]]],


        [[[-0.4015],
          [-0.7830]],

         [[-0.4214],
          [-0.7040]],

         [[-0.4044],
          [-0.8347]],

         [[-0.1591],
          [-0.2196]],

         [[-0.7565],
          [-0.2843]]],


        [[[-0.4214],
          [-0.7040]],

         [[-0.4044],
          [-0.8347]],

         [[-0.1591],
          [-0.2196]],

         [[-0.7565],
          [-0.2843]],

         [[ 0.6579],
          [ 0.8156]]],


        [[[-0.4044],
          [-0.8347]],

         [[-0.1591],
          [-0.2196]],

         [[-0.7565],
          [-0.2843]],

         [[ 0.6579],
          [ 0.8156]],

         [[ 0.7341],
          [ 0.8159]]],


        [[[-0.1591],
          [-0.2196]],

         [[-0.7565],
          [-0.2843]],

         [[ 0.6579],
          [ 0.8156]],

         [[ 0.7341],
          [ 0.8159]],

         [[-0.4985],
          [ 0.2341]]],


        [[[ 0.6579],
          [ 0.8156]],

         [[ 0.7341],
          [ 0.8159]],

         [[-0.4985],
          [ 0.2341]],

         [[ 0.2148],
          [ 0.8802]],

         [[ 0.3157],
          [ 0.1223]]],


        [[[ 0.7341],
          [ 0.8159]],

         [[-0.4985],
          [ 0.2341]],

         [[ 0.2148],
          [ 0.8802]],

         [[ 0.0447],
          [-0.0061]],

         [[ 2.0763],
          [ 0.0612]]],


        [[[-0.4985],
          [ 0.2341]],

         [[ 0.2148],
          [ 0.8802]],

         [[-0.0754],
          [ 0.0184]],

         [[ 0.2301],
          [-0.1545]],

         [[ 0.8510],
          [ 0.0945]]],


        [[[ 0.0692],
          [-0.5581]],

         [[-0.0967],
          [ 0.6544]],

         [[-0.5564],
          [-0.5217]],

         [[ 0.2279],
          [-0.5080]],

         [[-0.7363],
          [ 0.7372]]],


        [[[-0.5564],
          [-0.5217]],

         [[ 0.2279],
          [-0.5080]],

         [[-0.7363],
          [ 0.7372]],

         [[ 0.8861],
          [ 0.7655]],

         [[-0.8919],
          [ 0.7664]]],


        [[[ 0.8861],
          [ 0.7655]],

         [[-0.8919],
          [ 0.7664]],

         [[-0.8112],
          [ 0.3714]],

         [[ 0.9299],
          [ 0.6214]],

         [[-0.7294],
          [-0.5508]]],


        [[[-0.4489],
          [-0.4065]],

         [[-0.0877],
          [ 0.6241]],

         [[-0.1453],
          [ 0.0135]],

         [[ 0.7303],
          [ 0.3243]],

         [[ 0.1931],
          [-0.6004]]],


        [[[-0.0877],
          [ 0.6241]],

         [[-0.1453],
          [ 0.0135]],

         [[ 0.7303],
          [ 0.3243]],

         [[ 0.1931],
          [-0.6004]],

         [[-0.2952],
          [ 0.7422]]],


        [[[-0.1453],
          [ 0.0135]],

         [[ 0.7303],
          [ 0.3243]],

         [[ 0.1931],
          [-0.6004]],

         [[-0.2952],
          [ 0.7422]],

         [[ 0.3889],
          [-0.6838]]],


        [[[ 0.7991],
          [-0.6135]],

         [[-1.4757],
          [-1.5111]],

         [[-1.0754],
          [-0.0688]],

         [[-0.6506],
          [-0.2693]],

         [[-1.5095],
          [ 0.4624]]],


        [[[ 0.9131],
          [-0.9125]],

         [[-0.6183],
          [ 0.8433]],

         [[ 0.9738],
          [ 0.3281]],

         [[-0.5031],
          [-0.4598]],

         [[ 0.1593],
          [-0.3712]]],


        [[[ 0.9750],
          [-0.9102]],

         [[-0.6536],
          [-0.6315]],

         [[-0.3103],
          [ 0.5418]],

         [[ 0.5857],
          [-0.1400]],

         [[ 0.9161],
          [ 0.0171]]],


        [[[-0.6536],
          [-0.6315]],

         [[-0.3103],
          [ 0.5418]],

         [[ 1.1057],
          [ 0.2460]],

         [[-2.7828],
          [-2.2797]],

         [[ 0.3490],
          [ 0.2585]]],


        [[[-0.3103],
          [ 0.5418]],

         [[ 1.7955],
          [-0.1588]],

         [[ 0.8436],
          [-1.7921]],

         [[-0.6244],
          [ 1.0041]],

         [[ 1.0931],
          [-0.8503]]],


        [[[-0.0524],
          [ 0.9105]],

         [[-0.9822],
          [-0.4804]],

         [[-0.9744],
          [ 0.9138]],

         [[-0.8952],
          [-0.8566]],

         [[ 0.6737],
          [-0.8683]]],


        [[[-0.9744],
          [ 0.9138]],

         [[-0.8952],
          [-0.8566]],

         [[ 0.6737],
          [-0.8683]],

         [[-0.9753],
          [-0.1303]],

         [[ 0.0419],
          [-0.9881]]],


        [[[ 0.6737],
          [-0.8683]],

         [[-0.9753],
          [-0.1303]],

         [[ 0.0419],
          [-0.9881]],

         [[-0.8326],
          [-0.4707]],

         [[ 0.4066],
          [-0.1309]]],


        [[[-0.9753],
          [-0.1303]],

         [[ 0.0419],
          [-0.9881]],

         [[-0.8326],
          [-0.4707]],

         [[ 0.4066],
          [-0.1309]],

         [[-0.9776],
          [ 0.0929]]],


        [[[ 0.0419],
          [-0.9881]],

         [[-0.8326],
          [-0.4707]],

         [[ 0.4066],
          [-0.1309]],

         [[-1.2605],
          [ 0.3255]],

         [[-0.6375],
          [-0.2480]]],


        [[[ 0.4066],
          [-0.1309]],

         [[ 1.0853],
          [-1.1135]],

         [[ 0.6383],
          [ 0.4616]],

         [[-0.4527],
          [-0.5026]],

         [[-0.4430],
          [-0.6436]]],


        [[[ 0.8820],
          [ 0.4469]],

         [[ 0.1421],
          [-0.3563]],

         [[ 0.4081],
          [-0.0754]],

         [[-0.9183],
          [-0.7156]],

         [[ 0.3346],
          [-0.7763]]],


        [[[-0.9183],
          [-0.7156]],

         [[ 0.3346],
          [-0.7763]],

         [[ 0.3761],
          [ 0.7810]],

         [[-0.9151],
          [ 0.6707]],

         [[ 0.1624],
          [-0.0102]]],


        [[[ 0.3346],
          [-0.7763]],

         [[ 0.3761],
          [ 0.7810]],

         [[-0.9151],
          [ 0.6707]],

         [[ 0.1624],
          [-0.0102]],

         [[ 0.9643],
          [ 0.5449]]],


        [[[ 0.3761],
          [ 0.7810]],

         [[-0.9151],
          [ 0.6707]],

         [[ 0.1624],
          [-0.0102]],

         [[ 0.9643],
          [ 0.5449]],

         [[-0.6978],
          [ 0.3336]]],


        [[[ 0.1624],
          [-0.0102]],

         [[ 0.9643],
          [ 0.5449]],

         [[-0.6978],
          [ 0.3336]],

         [[ 0.1176],
          [-1.5554]],

         [[-0.6394],
          [ 0.0705]]],


        [[[ 0.9643],
          [ 0.5449]],

         [[-0.6978],
          [ 0.3336]],

         [[ 0.9371],
          [-0.5188]],

         [[ 1.6102],
          [ 0.7461]],

         [[-1.0329],
          [ 0.5870]]],


        [[[-0.6978],
          [ 0.3336]],

         [[ 0.4909],
          [-1.8273]],

         [[-1.3642],
          [ 0.4963]],

         [[ 0.9806],
          [-1.0712]],

         [[ 0.7456],
          [-2.0334]]]])
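The two dumps differ only in layout: the tensor is the array with an extra axis of size 1, printed at 4 decimals. A quick check on the first sample of the batch, with values copied from the array dump:

```python
import numpy as np
import torch

# First sample of the initial action_batch: shape (num_unroll_steps, action_dim) = (5, 2).
before = np.array([
    [-0.4490683972835541, -0.9950420260429382],
    [-0.4014930129051208, -0.7829872369766235],
    [-0.4213773608207703, -0.7039519548416138],
    [-0.4044414460659027, -0.834709882736206],
    [-0.1590566784143448, -0.2195666283369064],
])

after = torch.from_numpy(before).float().unsqueeze(-1)
print(after.shape)  # torch.Size([5, 2, 1])

# Removing the singleton axis gives back exactly the float32 values,
# so only the layout changed, not the data.
assert torch.equal(after.squeeze(-1), torch.from_numpy(before).float())
```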

puyuan1996 commented on June 4, 2024

Hello, thank you for your detailed feedback. We have confirmed that this was a redundant operation and it has been fixed in the latest commit on the main branch. Thank you once again for your active contribution.
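The thread does not show the contents of the fix. One plausible reading is that `.unsqueeze(-1)` is only needed when actions are discrete indices. The sketch below makes that split explicit; `to_action_tensor` and its `continuous` flag are hypothetical names, and the discrete branch assumes downstream code expects a trailing axis of size 1.

```python
import numpy as np
import torch


def to_action_tensor(action_batch: np.ndarray, continuous: bool, device: str = "cpu") -> torch.Tensor:
    """Convert sampled actions to a tensor. Hypothetical helper, not LightZero's code.

    Assumed shapes:
      continuous: (batch_size, num_unroll_steps, action_dim) -> float32, shape unchanged
      discrete:   (batch_size, num_unroll_steps) -> int64, plus a trailing axis of size 1
    """
    actions = torch.from_numpy(action_batch).to(device)
    if continuous:
        # Each action already carries its own last axis of size action_dim.
        return actions.float()
    # Assumption: discrete indices need an extra axis for the downstream network.
    return actions.long().unsqueeze(-1)


cont = to_action_tensor(np.zeros((32, 5, 2)), continuous=True)
disc = to_action_tensor(np.zeros((32, 5), dtype=np.int64), continuous=False)
print(cont.shape, cont.dtype)  # torch.Size([32, 5, 2]) torch.float32
print(disc.shape, disc.dtype)  # torch.Size([32, 5, 1]) torch.int64
```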

puyuan1996 commented on June 4, 2024

Hello, thank you for your feedback. We have identified this bug and it has now been fixed in the latest commit 27188cf. BTW, as you can see, there might be a more efficient implementation for this data processing code. We would greatly appreciate it if you could provide an optimized version. Best wishes!
