Comments (4)
Of course, the current modifications are made for the sake of program compatibility. The better approach is to identify the location where the shape of the action tensor changes from torch.Size([128, 2]) to torch.Size([128, 2, 1]). We would greatly appreciate it if you could debug and locate the corresponding position.
from lightzero.
Identifying the problem
I found the problem. I'm using a batch size of 32 in these examples, and a continuous action space size of 2.
At this point in the code, the action space for the whole batch seems properly shaped:
shape of action_batch: (32, 5, 2)
Then, there's this line:
LightZero/lzero/policy/sampled_efficientzero.py, lines 333 to 335 in 95e94b9
shape of action_batch: torch.Size([32, 5, 2, 1])
This is a problem the next time action_batch is used:
LightZero/lzero/policy/sampled_efficientzero.py, lines 416 to 418 in 95e94b9
shape of action_batch[:, step_k]: torch.Size([32, 2, 1])
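The shape change described above can be reproduced in isolation. The snippet below is only a sketch: it assumes the extra dimension comes from appending a trailing axis (for example via an unsqueeze-style call), which is not confirmed by the lines quoted here. It uses NumPy rather than torch so it runs without a GPU stack, and the names `num_steps`, `expanded`, and `step_action` are made up for illustration.

```python
import numpy as np

# Sizes taken from the report: batch of 32, 5 steps per sample, 2 action dims.
# `num_steps` is an illustrative name, not one from the LightZero code.
batch_size, num_steps, action_dim = 32, 5, 2

# Stand-in for the correctly shaped continuous action batch.
action_batch = np.zeros((batch_size, num_steps, action_dim), dtype=np.float32)

# Assumed culprit: a trailing axis is appended, as an unsqueeze(-1) would do.
expanded = np.expand_dims(action_batch, axis=-1)

# Indexing one step now yields (batch, action_dim, 1) instead of (batch, action_dim).
step_k = 0
step_action = expanded[:, step_k]

# Without the extra axis, the per-step slice has the expected 2-D shape.
expected_step_action = action_batch[:, step_k]

print(action_batch.shape, expanded.shape)
print(step_action.shape, expected_step_action.shape)
```

If the trailing axis is needed for discrete actions, one option would be to add it only in that case, or to drop it again with a squeeze on the last axis before the per-step slice is consumed.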
Testing variables
Here is the initial action_batch:
from numpy import array
array([[[-0.4490683972835541, -0.9950420260429382],
[-0.4014930129051208, -0.7829872369766235],
[-0.4213773608207703, -0.7039519548416138],
[-0.4044414460659027, -0.834709882736206 ],
[-0.1590566784143448, -0.2195666283369064]],
[[-0.4014930129051208, -0.7829872369766235],
[-0.4213773608207703, -0.7039519548416138],
[-0.4044414460659027, -0.834709882736206 ],
[-0.1590566784143448, -0.2195666283369064],
[-0.7565191984176636, -0.2842720746994019]],
[[-0.4213773608207703, -0.7039519548416138],
[-0.4044414460659027, -0.834709882736206 ],
[-0.1590566784143448, -0.2195666283369064],
[-0.7565191984176636, -0.2842720746994019],
[ 0.6578846573829651, 0.8156458139419556]],
[[-0.4044414460659027, -0.834709882736206 ],
[-0.1590566784143448, -0.2195666283369064],
[-0.7565191984176636, -0.2842720746994019],
[ 0.6578846573829651, 0.8156458139419556],
[ 0.7340825796127319, 0.8159331679344177]],
[[-0.1590566784143448, -0.2195666283369064],
[-0.7565191984176636, -0.2842720746994019],
[ 0.6578846573829651, 0.8156458139419556],
[ 0.7340825796127319, 0.8159331679344177],
[-0.4984613955020905, 0.2341146618127823]],
[[ 0.6578846573829651, 0.8156458139419556],
[ 0.7340825796127319, 0.8159331679344177],
[-0.4984613955020905, 0.2341146618127823],
[ 0.2147858291864395, 0.8802192211151123],
[ 0.3157244141177024, 0.122345928192084 ]],
[[ 0.7340825796127319, 0.8159331679344177],
[-0.4984613955020905, 0.2341146618127823],
[ 0.2147858291864395, 0.8802192211151123],
[ 0.0447234260510479, -0.0061242674195186],
[ 2.076320333718342 , 0.0612052095300819]],
[[-0.4984613955020905, 0.2341146618127823],
[ 0.2147858291864395, 0.8802192211151123],
[-0.0753794324983024, 0.018401418199892 ],
[ 0.230092215421578 , -0.154497093666865 ],
[ 0.8510225575898217, 0.0945327318063604]],
[[ 0.0692359060049057, -0.5581157207489014],
[-0.0967390909790993, 0.6544007062911987],
[-0.5564347505569458, -0.5217161774635315],
[ 0.227927178144455 , -0.5079849362373352],
[-0.7363128662109375, 0.7372166514396667]],
[[-0.5564347505569458, -0.5217161774635315],
[ 0.227927178144455 , -0.5079849362373352],
[-0.7363128662109375, 0.7372166514396667],
[ 0.8861203789710999, 0.7655220031738281],
[-0.8919497132301331, 0.7664076685905457]],
[[ 0.8861203789710999, 0.7655220031738281],
[-0.8919497132301331, 0.7664076685905457],
[-0.8111757040023804, 0.3714333176612854],
[ 0.9299460443093313, 0.62140516600195 ],
[-0.7294452606963147, -0.5507582288642449]],
[[-0.4489460289478302, -0.4065003991127014],
[-0.0877294093370438, 0.6240969896316528],
[-0.145333394408226 , 0.0135438507422805],
[ 0.7303228378295898, 0.324310839176178 ],
[ 0.1930731981992722, -0.6004103422164917]],
[[-0.0877294093370438, 0.6240969896316528],
[-0.145333394408226 , 0.0135438507422805],
[ 0.7303228378295898, 0.324310839176178 ],
[ 0.1930731981992722, -0.6004103422164917],
[-0.2952068150043488, 0.7421513199806213]],
[[-0.145333394408226 , 0.0135438507422805],
[ 0.7303228378295898, 0.324310839176178 ],
[ 0.1930731981992722, -0.6004103422164917],
[-0.2952068150043488, 0.7421513199806213],
[ 0.3888700902462006, -0.6838327646255493]],
[[ 0.7991305589675903, -0.6135419011116028],
[-1.4756834812262265, -1.5111042458027768],
[-1.0753735258729593, -0.0687822111641535],
[-0.6506161214320586, -0.2692845067856774],
[-1.509456993847584 , 0.4623811175903334]],
[[ 0.9131344556808472, -0.9124650955200195],
[-0.6182714104652405, 0.8433078527450562],
[ 0.9737944006919861, 0.3280912041664124],
[-0.503142774105072 , -0.4597557783126831],
[ 0.1592527478933334, -0.3711529672145844]],
[[ 0.97501140832901 , -0.9102342128753662],
[-0.6536456346511841, -0.6314908266067505],
[-0.3102632761001587, 0.5417718887329102],
[ 0.5856770137135533, -0.1399733906201436],
[ 0.9160984176427144, 0.0171331457675763]],
[[-0.6536456346511841, -0.6314908266067505],
[-0.3102632761001587, 0.5417718887329102],
[ 1.1057067309019377, 0.2459597965178402],
[-2.78280593866222 , -2.279670235885282 ],
[ 0.3490433187008873, 0.2584693277416903]],
[[-0.3102632761001587, 0.5417718887329102],
[ 1.7954634122483244, -0.1587822891046758],
[ 0.8435877018233144, -1.792088634720315 ],
[-0.6243843906905057, 1.004054336454013 ],
[ 1.0931276206988136, -0.8503439391331218]],
[[-0.0523804016411304, 0.9104548692703247],
[-0.982227087020874 , -0.4803890585899353],
[-0.9744316339492798, 0.9138423204421997],
[-0.8951768279075623, -0.8566049337387085],
[ 0.6736979484558105, -0.8682780265808105]],
[[-0.9744316339492798, 0.9138423204421997],
[-0.8951768279075623, -0.8566049337387085],
[ 0.6736979484558105, -0.8682780265808105],
[-0.9753063321113586, -0.1303187161684036],
[ 0.0419043824076653, -0.9880978465080261]],
[[ 0.6736979484558105, -0.8682780265808105],
[-0.9753063321113586, -0.1303187161684036],
[ 0.0419043824076653, -0.9880978465080261],
[-0.8325971961021423, -0.4706742167472839],
[ 0.4065617024898529, -0.1309379935264587]],
[[-0.9753063321113586, -0.1303187161684036],
[ 0.0419043824076653, -0.9880978465080261],
[-0.8325971961021423, -0.4706742167472839],
[ 0.4065617024898529, -0.1309379935264587],
[-0.9776009976312126, 0.0929417044544182]],
[[ 0.0419043824076653, -0.9880978465080261],
[-0.8325971961021423, -0.4706742167472839],
[ 0.4065617024898529, -0.1309379935264587],
[-1.2604611072245877, 0.3255282013572517],
[-0.6375371714023212, -0.2479576952053556]],
[[ 0.4065617024898529, -0.1309379935264587],
[ 1.0853122572285252, -1.113450095645758 ],
[ 0.6383299872593756, 0.4615825320021247],
[-0.4526769121750209, -0.5026150726186746],
[-0.4429909946263973, -0.6435901784670306]],
[[ 0.8820233941078186, 0.4469195306301117],
[ 0.1421182006597519, -0.3563036918640137],
[ 0.4080510437488556, -0.0753544196486473],
[-0.9183218479156494, -0.715552031993866 ],
[ 0.3346443176269531, -0.7762950658798218]],
[[-0.9183218479156494, -0.715552031993866 ],
[ 0.3346443176269531, -0.7762950658798218],
[ 0.3761063516139984, 0.7810404896736145],
[-0.9150596857070923, 0.6707392930984497],
[ 0.1624187082052231, -0.0102332159876823]],
[[ 0.3346443176269531, -0.7762950658798218],
[ 0.3761063516139984, 0.7810404896736145],
[-0.9150596857070923, 0.6707392930984497],
[ 0.1624187082052231, -0.0102332159876823],
[ 0.9643478393554688, 0.5449203252792358]],
[[ 0.3761063516139984, 0.7810404896736145],
[-0.9150596857070923, 0.6707392930984497],
[ 0.1624187082052231, -0.0102332159876823],
[ 0.9643478393554688, 0.5449203252792358],
[-0.6978347897529602, 0.3335686028003693]],
[[ 0.1624187082052231, -0.0102332159876823],
[ 0.9643478393554688, 0.5449203252792358],
[-0.6978347897529602, 0.3335686028003693],
[ 0.1176412274461428, -1.5554193438851447],
[-0.639407819647646 , 0.070470209032387 ]],
[[ 0.9643478393554688, 0.5449203252792358],
[-0.6978347897529602, 0.3335686028003693],
[ 0.9371315195670832, -0.5187535949083031],
[ 1.6102262403242997, 0.7461320383571783],
[-1.0328825079779835, 0.5870174496427473]],
[[-0.6978347897529602, 0.3335686028003693],
[ 0.4909302677064525, -1.8273185492431 ],
[-1.3642024846362164, 0.4962908880233692],
[ 0.9805795313754992, -1.0712426831281652],
[ 0.7456210692367394, -2.0334058101172436]]])
And here is action_batch after transformation:
from torch import tensor
tensor([[[[-0.4491],
[-0.9950]],
[[-0.4015],
[-0.7830]],
[[-0.4214],
[-0.7040]],
[[-0.4044],
[-0.8347]],
[[-0.1591],
[-0.2196]]],
[[[-0.4015],
[-0.7830]],
[[-0.4214],
[-0.7040]],
[[-0.4044],
[-0.8347]],
[[-0.1591],
[-0.2196]],
[[-0.7565],
[-0.2843]]],
[[[-0.4214],
[-0.7040]],
[[-0.4044],
[-0.8347]],
[[-0.1591],
[-0.2196]],
[[-0.7565],
[-0.2843]],
[[ 0.6579],
[ 0.8156]]],
[[[-0.4044],
[-0.8347]],
[[-0.1591],
[-0.2196]],
[[-0.7565],
[-0.2843]],
[[ 0.6579],
[ 0.8156]],
[[ 0.7341],
[ 0.8159]]],
[[[-0.1591],
[-0.2196]],
[[-0.7565],
[-0.2843]],
[[ 0.6579],
[ 0.8156]],
[[ 0.7341],
[ 0.8159]],
[[-0.4985],
[ 0.2341]]],
[[[ 0.6579],
[ 0.8156]],
[[ 0.7341],
[ 0.8159]],
[[-0.4985],
[ 0.2341]],
[[ 0.2148],
[ 0.8802]],
[[ 0.3157],
[ 0.1223]]],
[[[ 0.7341],
[ 0.8159]],
[[-0.4985],
[ 0.2341]],
[[ 0.2148],
[ 0.8802]],
[[ 0.0447],
[-0.0061]],
[[ 2.0763],
[ 0.0612]]],
[[[-0.4985],
[ 0.2341]],
[[ 0.2148],
[ 0.8802]],
[[-0.0754],
[ 0.0184]],
[[ 0.2301],
[-0.1545]],
[[ 0.8510],
[ 0.0945]]],
[[[ 0.0692],
[-0.5581]],
[[-0.0967],
[ 0.6544]],
[[-0.5564],
[-0.5217]],
[[ 0.2279],
[-0.5080]],
[[-0.7363],
[ 0.7372]]],
[[[-0.5564],
[-0.5217]],
[[ 0.2279],
[-0.5080]],
[[-0.7363],
[ 0.7372]],
[[ 0.8861],
[ 0.7655]],
[[-0.8919],
[ 0.7664]]],
[[[ 0.8861],
[ 0.7655]],
[[-0.8919],
[ 0.7664]],
[[-0.8112],
[ 0.3714]],
[[ 0.9299],
[ 0.6214]],
[[-0.7294],
[-0.5508]]],
[[[-0.4489],
[-0.4065]],
[[-0.0877],
[ 0.6241]],
[[-0.1453],
[ 0.0135]],
[[ 0.7303],
[ 0.3243]],
[[ 0.1931],
[-0.6004]]],
[[[-0.0877],
[ 0.6241]],
[[-0.1453],
[ 0.0135]],
[[ 0.7303],
[ 0.3243]],
[[ 0.1931],
[-0.6004]],
[[-0.2952],
[ 0.7422]]],
[[[-0.1453],
[ 0.0135]],
[[ 0.7303],
[ 0.3243]],
[[ 0.1931],
[-0.6004]],
[[-0.2952],
[ 0.7422]],
[[ 0.3889],
[-0.6838]]],
[[[ 0.7991],
[-0.6135]],
[[-1.4757],
[-1.5111]],
[[-1.0754],
[-0.0688]],
[[-0.6506],
[-0.2693]],
[[-1.5095],
[ 0.4624]]],
[[[ 0.9131],
[-0.9125]],
[[-0.6183],
[ 0.8433]],
[[ 0.9738],
[ 0.3281]],
[[-0.5031],
[-0.4598]],
[[ 0.1593],
[-0.3712]]],
[[[ 0.9750],
[-0.9102]],
[[-0.6536],
[-0.6315]],
[[-0.3103],
[ 0.5418]],
[[ 0.5857],
[-0.1400]],
[[ 0.9161],
[ 0.0171]]],
[[[-0.6536],
[-0.6315]],
[[-0.3103],
[ 0.5418]],
[[ 1.1057],
[ 0.2460]],
[[-2.7828],
[-2.2797]],
[[ 0.3490],
[ 0.2585]]],
[[[-0.3103],
[ 0.5418]],
[[ 1.7955],
[-0.1588]],
[[ 0.8436],
[-1.7921]],
[[-0.6244],
[ 1.0041]],
[[ 1.0931],
[-0.8503]]],
[[[-0.0524],
[ 0.9105]],
[[-0.9822],
[-0.4804]],
[[-0.9744],
[ 0.9138]],
[[-0.8952],
[-0.8566]],
[[ 0.6737],
[-0.8683]]],
[[[-0.9744],
[ 0.9138]],
[[-0.8952],
[-0.8566]],
[[ 0.6737],
[-0.8683]],
[[-0.9753],
[-0.1303]],
[[ 0.0419],
[-0.9881]]],
[[[ 0.6737],
[-0.8683]],
[[-0.9753],
[-0.1303]],
[[ 0.0419],
[-0.9881]],
[[-0.8326],
[-0.4707]],
[[ 0.4066],
[-0.1309]]],
[[[-0.9753],
[-0.1303]],
[[ 0.0419],
[-0.9881]],
[[-0.8326],
[-0.4707]],
[[ 0.4066],
[-0.1309]],
[[-0.9776],
[ 0.0929]]],
[[[ 0.0419],
[-0.9881]],
[[-0.8326],
[-0.4707]],
[[ 0.4066],
[-0.1309]],
[[-1.2605],
[ 0.3255]],
[[-0.6375],
[-0.2480]]],
[[[ 0.4066],
[-0.1309]],
[[ 1.0853],
[-1.1135]],
[[ 0.6383],
[ 0.4616]],
[[-0.4527],
[-0.5026]],
[[-0.4430],
[-0.6436]]],
[[[ 0.8820],
[ 0.4469]],
[[ 0.1421],
[-0.3563]],
[[ 0.4081],
[-0.0754]],
[[-0.9183],
[-0.7156]],
[[ 0.3346],
[-0.7763]]],
[[[-0.9183],
[-0.7156]],
[[ 0.3346],
[-0.7763]],
[[ 0.3761],
[ 0.7810]],
[[-0.9151],
[ 0.6707]],
[[ 0.1624],
[-0.0102]]],
[[[ 0.3346],
[-0.7763]],
[[ 0.3761],
[ 0.7810]],
[[-0.9151],
[ 0.6707]],
[[ 0.1624],
[-0.0102]],
[[ 0.9643],
[ 0.5449]]],
[[[ 0.3761],
[ 0.7810]],
[[-0.9151],
[ 0.6707]],
[[ 0.1624],
[-0.0102]],
[[ 0.9643],
[ 0.5449]],
[[-0.6978],
[ 0.3336]]],
[[[ 0.1624],
[-0.0102]],
[[ 0.9643],
[ 0.5449]],
[[-0.6978],
[ 0.3336]],
[[ 0.1176],
[-1.5554]],
[[-0.6394],
[ 0.0705]]],
[[[ 0.9643],
[ 0.5449]],
[[-0.6978],
[ 0.3336]],
[[ 0.9371],
[-0.5188]],
[[ 1.6102],
[ 0.7461]],
[[-1.0329],
[ 0.5870]]],
[[[-0.6978],
[ 0.3336]],
[[ 0.4909],
[-1.8273]],
[[-1.3642],
[ 0.4963]],
[[ 0.9806],
[-1.0712]],
[[ 0.7456],
[-2.0334]]]])
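As a quick sanity check on the two dumps above, the sketch below compares the first sample before and after the transformation. It suggests the values are unchanged up to print rounding and that only a trailing axis was added. The numbers are copied from the dumps; the variable names and the tolerance are chosen for illustration.

```python
import numpy as np

# First sample of the initial action_batch, copied from the dump above.
before = np.array([
    [-0.4490683972835541, -0.9950420260429382],
    [-0.4014930129051208, -0.7829872369766235],
    [-0.4213773608207703, -0.7039519548416138],
    [-0.4044414460659027, -0.834709882736206],
    [-0.1590566784143448, -0.2195666283369064],
])

# First sample after the transformation, as printed (4 decimal places).
after = np.array([
    [[-0.4491], [-0.9950]],
    [[-0.4015], [-0.7830]],
    [[-0.4214], [-0.7040]],
    [[-0.4044], [-0.8347]],
    [[-0.1591], [-0.2196]],
])

# Dropping the trailing axis should recover the original layout.
restored = after.squeeze(-1)

# Tolerance covers the rounding in the printed tensor.
same_values = np.allclose(restored, before, rtol=0.0, atol=1e-4)

print(before.shape, after.shape, restored.shape, same_values)
```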
Hello, thank you for your detailed feedback. We have confirmed that this was a redundant operation and it has been fixed in the latest commit on the main branch. Thank you once again for your active contribution.
Hello, thank you for your feedback. We have identified this bug and it has now been fixed in the latest commit 27188cf. BTW, as you can see, there might be a more efficient implementation for this data processing code. We would greatly appreciate it if you could provide an optimized version. Best wishes!