chandar-lab / recall2imagine
Recall 2 Imagine, a World Model with superhuman memory. Oral (1.2%) @ ICLR 2024
Home Page: https://recall2imagine.github.io/
License: MIT License
Hello, I ran R2I with the command:
current_date=$(date "+%Y%m%d-%H%M%S")
python recall2imagine/train.py \
--configs dmc_vision \
--ssm_type mimo \
--wdb_name dmc_original_${current_date} \
--logdir ./logs/dmc_original_${current_date}
and got very low scores. According to the DreamerV3 paper, it can achieve a score > 900.
However, R2I only achieves a score < 200.
I think the hyperparameters are the same; in R2I's config.yaml I saw:
run.train_ratio: 512
run.steps: 1e6
rssm.deter: 512
.*\.cnn_depth: 32
.*\.layers: 2
.*\.units: 512
Can you explain why this happens? Maybe the SSM backbone is not as good as the GRU in this task? Or is there something wrong with my hyperparameters?
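When comparing runs at the same train_ratio, it can help to turn that number into a count of gradient updates. A minimal sketch, assuming R2I keeps DreamerV3's convention that train_ratio counts replayed steps per environment step, so updates per environment step = train_ratio / (batch_size * batch_length). The batch shapes in the demo are placeholders, not values read from R2I's config.yaml; if R2I samples longer sequences than DreamerV3, the same train_ratio of 512 yields fewer updates over run.steps.

```python
def updates_per_env_step(train_ratio: float, batch_size: int, batch_length: int) -> float:
    """Gradient updates per environment step, assuming the DreamerV3 convention
    train_ratio = replayed steps / environment steps."""
    replayed_steps_per_update = batch_size * batch_length
    return train_ratio / replayed_steps_per_update


def total_updates(env_steps: float, train_ratio: float, batch_size: int, batch_length: int) -> int:
    """Approximate number of updates over a run of env_steps environment steps."""
    return int(env_steps * updates_per_env_step(train_ratio, batch_size, batch_length))


if __name__ == "__main__":
    # Placeholder batch shapes; substitute the values from your own config.
    for batch_size, batch_length in [(16, 64), (16, 256)]:
        n = total_updates(1e6, 512, batch_size, batch_length)
        print(f"batch {batch_size}x{batch_length}: {n} updates over 1e6 steps")
```

Under this assumption, a 4x longer batch_length at the same train_ratio means 4x fewer updates, which is one difference worth ruling out before attributing the gap to the SSM backbone.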
Running with:
python3 ./recall2imagine/train.py \
--logdir /logdir/$(date +%Y%m%d-%H%M%S) \
--configs atari100k tiny \
--task atari_pong \
--jax.platform=cpu
Fails with:
AttributeError: 'function' object has no attribute 'serializer'
Traceback:
Traceback (most recent call last):
File "./recall2imagine/train.py", line 241, in <module>
main()
File "./recall2imagine/train.py", line 81, in main
embodied.run.train_eval(
File "/code/recall2imagine/embodied/run/train_eval.py", line 74, in train_eval
dataset_train = agent.dataset(train_replay.dataset)
File "/code/recall2imagine/jaxagent.py", line 104, in dataset
batcher = embodied.BatcherSM(
File "/code/recall2imagine/embodied/core/batcher.py", line 123, in __init__
self._serializer = self._replay.serializer
AttributeError: 'function' object has no attribute 'serializer'
NOTE:
In train_eval.py, one can change:
dataset_train = agent.dataset(train_replay.dataset)
to:
dataset_train = agent.dataset(train_replay)
which makes sense, since the LFS_FIFO object has a serializer attribute, but then the same error occurs on the next line:
dataset_train = agent.dataset(train_replay)
dataset_eval = agent.dataset(eval_replay.dataset) <-------------
because eval_replay is a Uniform replay, which does not have a serializer.
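The failure mode can be reproduced without R2I: BatcherSM reads replay.serializer, and a method such as replay.dataset has no such attribute, while the Uniform replay lacks it too. A minimal sketch with hypothetical stand-in classes (not the real LFS_FIFO or Uniform) and a hypothetical dispatch helper that only routes replays exposing a serializer to the shared-memory path; which batcher the "plain" branch should use depends on what jaxagent.py offers and is left open here.

```python
class LFSReplayStub:
    """Stand-in for a replay that exposes serializer (LFS_FIFO in the report)."""

    serializer = object()

    def dataset(self):
        yield {}


class UniformReplayStub:
    """Stand-in for a replay without serializer (Uniform in the report)."""

    def dataset(self):
        yield {}


def batcher_input(replay):
    """Hypothetical dispatch: BatcherSM needs replay.serializer, so pass the
    replay object only when that attribute exists; otherwise return the plain
    dataset method for some non-shared-memory batching path."""
    if hasattr(replay, "serializer"):
        return "shared_memory", replay
    return "plain", replay.dataset


if __name__ == "__main__":
    # Passing replay.dataset (a method) mirrors the reported error:
    # the method object has no serializer attribute.
    print(hasattr(LFSReplayStub().dataset, "serializer"))
    print(batcher_input(LFSReplayStub())[0])
    print(batcher_input(UniformReplayStub())[0])
```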
Hi, @artemZholus, I want to reproduce your results, but currently I only have an A6000 GPU with about 50 GB of VRAM, 48 GB of RAM, and 256 GB of storage. When running
python recall2imagine/train.py \
--configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
--wdb_name memory_maze_9x9 \
--logdir ./logs_memory_maze_9x9
I got:
Policy devices: gpu:0
Train devices: gpu:0
2024-04-11 06:36:20.772287: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-04-11 06:36:20.772379: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 9217966080 bytes free, 51033931776 bytes total.
Traceback (most recent call last):
File "recall2imagine/train.py", line 241, in <module>
main()
File "recall2imagine/train.py", line 63, in main
agent = agt.Agent(env.obs_space, env.act_space, step, config)
File "/home/lyk/Projects/Recall2Imagine/recall2imagine/jaxagent.py", line 20, in __init__
super().__init__(agent_cls, *args, **kwargs)
File "/home/lyk/Projects/Recall2Imagine/recall2imagine/jaxagent.py", line 50, in __init__
self.varibs = self._init_varibs(obs_space, act_space)
File "/home/lyk/Projects/Recall2Imagine/recall2imagine/jaxagent.py", line 245, in _init_varibs
state, varibs = self._init_train(varibs, rng, data['is_first'])
File "/home/lyk/Projects/Recall2Imagine/recall2imagine/ninjax.py", line 199, in wrapper
created = init(statics, rng, *args, **kw)
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
return self.bind_with_trace(top_trace, args, params)
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
return primitive.impl(*tracers, **params)
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
compiled = _pjit_lower(
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
executable = UnloadedMeshExecutable.from_hlo(
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
xla_executable, compile_options = _cached_compilation(
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
xla_executable = dispatch.compile_or_get_cached(
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
return backend_compile(backend, computation, compile_options,
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
I think it is because my VRAM is too limited. Is there any way to run your code in this case?
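The cuDNN log line reports only about 9.2 GB free out of about 51 GB when the handle was created, which suggests most GPU memory was already claimed before cuDNN initialized, for example by JAX's default preallocation. A minimal sketch that reads the free/total figures from that line and sets JAX's XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION environment variables; these must be set before jax is first imported. The helper names are hypothetical, and whether this alone is enough for the mmaze config is untested.

```python
import os
import re

_MEM_LINE = re.compile(r"Memory usage: (\d+) bytes free, (\d+) bytes total")


def parse_cudnn_memory(log_line: str):
    """Return (free_gib, total_gib, free_fraction) from the cuDNN log line, or None."""
    match = _MEM_LINE.search(log_line)
    if match is None:
        return None
    free, total = (int(g) for g in match.groups())
    gib = 1024 ** 3
    return free / gib, total / gib, free / total


def configure_jax_memory(preallocate: bool = False, mem_fraction=None) -> None:
    """Set JAX GPU memory env vars; call before the first import of jax.
    mem_fraction only matters when preallocation is enabled."""
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


if __name__ == "__main__":
    line = ("cuda_dnn.cc:443] Memory usage: 9217966080 bytes free, "
            "51033931776 bytes total.")
    free_gib, total_gib, frac = parse_cudnn_memory(line)
    print(f"{free_gib:.1f} GiB free of {total_gib:.1f} GiB ({frac:.0%})")
    configure_jax_memory(preallocate=False)
```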
Hi, @artemZholus, I'm reproducing your project (POPGym) on my PC.
The rep and dyn losses seem stable, but the score is very different from your figure, especially after 40M steps.
After some digging, I found that the act_opt_loss looks very bad; you can see it here.
I'm using the code and config from your repo, with the command you provided.
Note: the run was interrupted once at 20M and I restarted it; it resumed from 20M and ran to 120M.
Could you help me?
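Since the run was resumed at 20M, one thing worth checking is whether the logged steps repeat or go backwards around the restart, which would happen if the restored checkpoint predates the interruption. A minimal sketch, assuming each line of metrics.jsonl is a JSON object with a step field plus scalar metrics; the metric key passed in (for example whatever name act_opt_loss is logged under) is an assumption to check against your own file.

```python
import json
from typing import Iterable, List, Tuple


def metric_series(lines: Iterable[str], key: str) -> List[Tuple[float, float]]:
    """Collect (step, value) pairs for key from JSON-lines records."""
    series = []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        if key in record and "step" in record:
            series.append((record["step"], record[key]))
    return series


def step_regressions(series: List[Tuple[float, float]]) -> List[Tuple[float, float]]:
    """Return (previous_step, next_step) pairs where the step did not increase."""
    return [(a, b) for (a, _), (b, _) in zip(series, series[1:]) if b <= a]


if __name__ == "__main__":
    # Synthetic example: a restart that resumed from an earlier checkpoint.
    demo = [
        '{"step": 19000000, "act_opt_loss": 0.10}',
        '{"step": 20000000, "act_opt_loss": 0.12}',
        '{"step": 19500000, "act_opt_loss": 0.11}',
    ]
    print(step_regressions(metric_series(demo, "act_opt_loss")))
```

On a real log, open metrics.jsonl and pass the file object as lines; an empty result means the step counter increased monotonically across the restart.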
Hello, I wonder how we can get or generate the scores from R2I experiments. In the code, I saw the JSON logger writing to two files: 'metrics.jsonl' and 'scores.jsonl'. Neither of them has the same format as the .json files under ./score in the original DreamerV3 code, where each uncompressed .json file has the format:
{
"task": "atari_battle_zone",
"method": "dreamerv3",
"seed": "3",
"xs": [ ...],
"ys": [ ...]
}
I'm trying to reproduce the results of R2I and am currently trying to plot the performance curve. Thanks for your help!
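A minimal sketch of turning scores.jsonl into a record with the task, method, seed, xs and ys fields listed above. It assumes each line of scores.jsonl is a JSON object with a step field and an episode/score field (the key name is an assumption; check your file), and that the task, method and seed labels are supplied by you. Grouping several records into one file and gzipping it are left out.

```python
import json
from typing import Iterable


def scores_to_record(lines: Iterable[str], task: str, method: str, seed: str,
                     key: str = "episode/score") -> dict:
    """Build a {task, method, seed, xs, ys} record from JSON-lines score logs."""
    xs, ys = [], []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        entry = json.loads(line)
        if key in entry:
            xs.append(entry["step"])
            ys.append(entry[key])
    return {"task": task, "method": method, "seed": seed, "xs": xs, "ys": ys}


if __name__ == "__main__":
    demo = ['{"step": 5000, "episode/score": 1.0}',
            '{"step": 10000, "episode/score": 3.5}']
    record = scores_to_record(demo, task="your_task", method="r2i", seed="0")
    print(json.dumps(record))
```

For a real run, pass the opened scores.jsonl file as lines and write the result with json.dump; one record per seed can then be plotted as xs against ys.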
When I use the default config.yaml to run the Memory Maze and POPGym experiments, the following error occurs:
Process env:
Traceback (most recent call last):
File "/conda/miniconda/envs/py38/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
self.run()
File "/conda/miniconda/envs/py38/lib/python3.8/multiprocessing/process.py", line 108, in run
self._target(*self._args, **self._kwargs)
File "/code/recall2imagine/embodied/core/distr.py", line 199, in _wrapper
fn(*args)
File "/code/recall2imagine/embodied/run/parallel.py", line 172, in env
obs = env.step(act)
File "/code/recall2imagine/embodied/core/wrappers.py", line 376, in step
obs = self.env.step(action)
File "/code/recall2imagine/embodied/core/wrappers.py", line 160, in step
obs = self.env.step(action)
File "/code/recall2imagine/embodied/core/wrappers.py", line 115, in step
assert action[self._key].min() == 0.0, action
AssertionError: {'action': array([nan, nan, nan, nan, nan, nan], dtype=float32), 'reset': array(False)}
It seems to be related to the env reset. Could you suggest why this error occurs and how to solve it? Looking forward to your reply.
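The assertion in wrappers.py fires because every entry of the action is NaN: NaN compares unequal to everything, so action.min() == 0.0 is False. That suggests the policy output itself was NaN by the time it reached the one-hot check. A minimal, hypothetical guard that fails earlier with a clearer message; calling it just before env.step(act) in parallel.py is an assumption about where it fits. To locate the first NaN inside the JAX computation, enabling JAX's jax_debug_nans option via jax.config.update("jax_debug_nans", True) is another diagnostic, at a speed cost.

```python
import numpy as np


def check_finite_action(action: dict, key: str = "action") -> dict:
    """Raise early if the policy produced non-finite values for key."""
    value = np.asarray(action[key])
    if not np.all(np.isfinite(value)):
        raise FloatingPointError(f"policy returned non-finite {key!r}: {value}")
    return action


if __name__ == "__main__":
    bad = {"action": np.full(6, np.nan, dtype=np.float32), "reset": np.array(False)}
    # Why the original one-hot assertion fails: NaN == 0.0 is False.
    print(bad["action"].min() == 0.0)
    try:
        check_finite_action(bad)
    except FloatingPointError as err:
        print(err)
```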
I used the conda env and installed jaxlib 0.4.13 with cuda11.cudnn86. When I run the code, this error occurs:
E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
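The log itself states the rule: the cuDNN loaded at runtime must have the same major version as the one jaxlib was compiled against and an equal or higher minor version. Here 8.5.0 was loaded while the cudnn86 wheel was built against 8.6.0, so the library found at runtime needs to be 8.6 or a later 8.x. A small sketch of that check; the function name is hypothetical and version strings are parsed as major.minor with an optional patch.

```python
def cudnn_compatible(runtime: str, compiled: str) -> bool:
    """Rule from the XLA log: same major version, runtime minor >= compiled minor."""
    run_major, run_minor = (int(p) for p in runtime.split(".")[:2])
    comp_major, comp_minor = (int(p) for p in compiled.split(".")[:2])
    return run_major == comp_major and run_minor >= comp_minor


if __name__ == "__main__":
    for runtime in ["8.5.0", "8.6.0", "8.9.2", "9.0.0"]:
        print(runtime, cudnn_compatible(runtime, "8.6.0"))
```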
The paper states that one should
leverage the output state policy in non-memory environments and the full state policy or hidden state policy within memory environments.
But in configs.yaml, only the POPGym task sets actor.inputs to [stoch, hidden]. Why? Looking forward to your reply.
Hi, I would like to verify which run flavour was used for the official metrics in the paper.
Was it train, train_eval, train_save, train_holdout, or parallel?
Thanks in advance
Hi, I tried to run the command
python recall2imagine/train.py \
--configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
--wdb_name memory_maze_9x9 \
--logdir ./logs_memory_maze_9x9
on my headless server, but got this error:
Traceback (most recent call last):
File "recall2imagine/train.py", line 6, in <module>
import mujoco
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/mujoco/__init__.py", line 47, in <module>
from mujoco.gl_context import *
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/mujoco/gl_context.py", line 38, in <module>
from mujoco.osmesa import GLContext as _GLContext
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/mujoco/osmesa/__init__.py", line 31, in <module>
from OpenGL import GL
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/OpenGL/GL/__init__.py", line 4, in <module>
from OpenGL.GL.VERSION.GL_1_1 import *
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/OpenGL/GL/VERSION/GL_1_1.py", line 14, in <module>
from OpenGL.raw.GL.VERSION.GL_1_1 import *
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/OpenGL/raw/GL/VERSION/GL_1_1.py", line 7, in <module>
from OpenGL.raw.GL import _errors
File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/OpenGL/raw/GL/_errors.py", line 4, in <module>
_error_checker = _ErrorChecker( _p, _p.GL.glGetError )
AttributeError: 'NoneType' object has no attribute 'glGetError'
Is there any way to solve this?
Thanks!
I have installed all the dependencies specified in the readme.md. My system config is: Ubuntu 22.04, x86_64, with CUDA version 12.4.
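The traceback shows MuJoCo taking its OSMesa path (mujoco.osmesa), after which PyOpenGL's _p.GL is None. That error commonly means the OSMesa shared library could not be loaded, so installing it (on Ubuntu, the libosmesa6-dev package) or switching both MuJoCo and PyOpenGL to EGL through the MUJOCO_GL and PYOPENGL_PLATFORM environment variables is worth trying. A minimal sketch that checks which library ctypes can locate and sets both variables to the same backend; it has to run before import mujoco (line 6 of train.py), and it does not check whether the GPU driver actually supports EGL.

```python
import ctypes.util
import os
from typing import Optional


def pick_headless_backend() -> Optional[str]:
    """Return 'egl' or 'osmesa' depending on which shared library ctypes can locate."""
    if ctypes.util.find_library("EGL"):
        return "egl"
    if ctypes.util.find_library("OSMesa"):
        return "osmesa"
    return None


def configure_headless_gl(backend: str) -> None:
    """Point MuJoCo and PyOpenGL at the same offscreen backend (before importing mujoco)."""
    if backend not in ("egl", "osmesa"):
        raise ValueError(f"unsupported headless backend: {backend!r}")
    os.environ["MUJOCO_GL"] = backend
    os.environ["PYOPENGL_PLATFORM"] = backend


if __name__ == "__main__":
    backend = pick_headless_backend()
    if backend is None:
        print("Neither libEGL nor libOSMesa was found; install one of them first.")
    else:
        configure_headless_gl(backend)
        print("Using", backend)
```

Setting the same two variables in the shell before launching train.py has the equivalent effect, as long as nothing in the script overrides them before mujoco is imported.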