chandar-lab / recall2imagine

Recall 2 Imagine, a World Model with superhuman memory. Oral (1.2%) @ ICLR 2024

Home Page: https://recall2imagine.github.io/

License: MIT License

Dockerfile 0.50% Python 98.96% Shell 0.54%
dreamer iclr2024 model-based-reinforcement-learning model-based-rl reinforcement-learning state-space-models deep-reinforcement-learning deep-rl s4 ssm

recall2imagine's Issues

The dmc vision (task=dmc_walker_walk) has very bad performance

Hello, I ran R2I with the following command:

current_date=$(date "+%Y%m%d-%H%M%S")
python recall2imagine/train.py \
    --configs dmc_vision \
    --ssm_type mimo \
    --wdb_name  dmc_original_${current_date} \
    --logdir ./logs/dmc_original_${current_date}

and got very low scores. According to the DreamerV3 paper, DreamerV3 achieves a score > 900 on this task.

However, R2I only reaches a score < 200.

I think the hyperparameters are the same. In R2I's config.yaml, I saw

  run.train_ratio: 512
  run.steps: 1e6
  rssm.deter: 512
  .*\.cnn_depth: 32
  .*\.layers: 2
  .*\.units: 512

which aligns with the DreamerV3 settings.

Can you explain why this happens? Maybe the SSM backbone is not as good as the GRU on this task? Or is there something wrong with my hyperparameters?

atari100k:atari_pong fails on Batcher creation

Running with:

python3 ./recall2imagine/train.py \
  --logdir /logdir/$(date +%Y%m%d-%H%M%S) \
  --configs atari100k tiny \
  --task atari_pong \
  --jax.platform=cpu

Fails with:

AttributeError: 'function' object has no attribute 'serializer'

Traceback:

Traceback (most recent call last):
  File "./recall2imagine/train.py", line 241, in <module>
    main()
  File "./recall2imagine/train.py", line 81, in main
    embodied.run.train_eval(
  File "/code/recall2imagine/embodied/run/train_eval.py", line 74, in train_eval
    dataset_train = agent.dataset(train_replay.dataset)
  File "/code/recall2imagine/jaxagent.py", line 104, in dataset
    batcher = embodied.BatcherSM(
  File "/code/recall2imagine/embodied/core/batcher.py", line 123, in __init__
    self._serializer = self._replay.serializer
AttributeError: 'function' object has no attribute 'serializer'

NOTE:
In train_eval.py, one can change:

dataset_train = agent.dataset(train_replay.dataset)

to:

dataset_train = agent.dataset(train_replay)

which makes sense since the LFS_FIFO object has a serializer attribute, but then the same error occurs on the next line:

dataset_train = agent.dataset(train_replay)
dataset_eval = agent.dataset(eval_replay.dataset)  # <-- fails here

because eval_replay is of type Uniform, which does not have a serializer attribute.
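
The traceback is consistent with the batcher receiving a method rather than a replay object, so the attribute lookup fails on the callable itself. The following is only a sketch of that failure mode: `FakeReplay` and `FakeBatcher` are hypothetical stand-ins, not classes from the repository.

```python
class FakeReplay:
    """Hypothetical stand-in for a replay buffer that exposes a serializer."""

    def __init__(self):
        self.serializer = "serializer-object"

    def dataset(self):
        # Generator method, analogous in role to `train_replay.dataset`.
        yield {"obs": 0}


class FakeBatcher:
    """Hypothetical stand-in for a batcher that reads `.serializer` from its source."""

    def __init__(self, replay):
        self._serializer = replay.serializer


replay = FakeReplay()
batcher = FakeBatcher(replay)  # works: the object itself has `.serializer`

try:
    FakeBatcher(replay.dataset)  # passes the method, not the replay object
    error = None
except AttributeError as exc:
    error = str(exc)

print(error)
```

Under this reading, passing the replay object only helps for replay types that actually define `serializer`, which matches the observation about Uniform above.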

How to run on a GPU with limited VRAM

Hi, @artemZholus, I want to reproduce your results. But currently I only have an A6000 GPU with about 50 GB of VRAM, 48 GB of RAM, and 256 GB of storage. When running

python recall2imagine/train.py \
    --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
    --wdb_name memory_maze_9x9 \
    --logdir ./logs_memory_maze_9x9

I got the following error:

Policy devices: gpu:0
Train devices:  gpu:0
2024-04-11 06:36:20.772287: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:439] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2024-04-11 06:36:20.772379: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:443] Memory usage: 9217966080 bytes free, 51033931776 bytes total.
Traceback (most recent call last):
  File "recall2imagine/train.py", line 241, in <module>
    main()
  File "recall2imagine/train.py", line 63, in main
    agent = agt.Agent(env.obs_space, env.act_space, step, config)
  File "/home/lyk/Projects/Recall2Imagine/recall2imagine/jaxagent.py", line 20, in __init__
    super().__init__(agent_cls, *args, **kwargs)
  File "/home/lyk/Projects/Recall2Imagine/recall2imagine/jaxagent.py", line 50, in __init__
    self.varibs = self._init_varibs(obs_space, act_space)
  File "/home/lyk/Projects/Recall2Imagine/recall2imagine/jaxagent.py", line 245, in _init_varibs
    state, varibs = self._init_train(varibs, rng, data['is_first'])
  File "/home/lyk/Projects/Recall2Imagine/recall2imagine/ninjax.py", line 199, in wrapper
    created = init(statics, rng, *args, **kw)
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 250, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr = _python_pjit_helper(
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 163, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 2677, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/core.py", line 815, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1203, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums,
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1187, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/pjit.py", line 1120, in _pjit_call_impl_python
    compiled = _pjit_lower(
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2323, in compile
    executable = UnloadedMeshExecutable.from_hlo(
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2645, in from_hlo
    xla_executable, compile_options = _cached_compilation(
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/interpreters/pxla.py", line 2555, in _cached_compilation
    xla_executable = dispatch.compile_or_get_cached(
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 497, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

I think it is because my VRAM is too limited. Is there any way to run your code in this case?
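
The "Memory usage" line in the log can be read directly: it reports how much device memory was free when cuDNN tried to initialize. The sketch below parses that line with a hypothetical helper, `parse_memory_usage`. It also sets `XLA_PYTHON_CLIENT_PREALLOCATE=false`, a documented JAX setting that disables up-front GPU memory preallocation. Whether that resolves this particular failure is an assumption that has not been verified here.

```python
import os
import re

# Assumption: this must be set before `import jax` to take effect, and it may
# or may not fix the cuDNN initialization failure shown above.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

LOG_LINE = "Memory usage: 9217966080 bytes free, 51033931776 bytes total."


def parse_memory_usage(line):
    """Return (free_gib, total_gib, free_fraction) from an XLA memory-usage line."""
    match = re.search(r"(\d+) bytes free, (\d+) bytes total", line)
    if match is None:
        raise ValueError("no memory usage found in line")
    free, total = (int(group) for group in match.groups())
    gib = 1024 ** 3
    return free / gib, total / gib, free / total


free_gib, total_gib, free_fraction = parse_memory_usage(LOG_LINE)
print(f"{free_gib:.1f} GiB free of {total_gib:.1f} GiB ({free_fraction:.0%})")
```

For the line quoted in the log, this gives roughly 8.6 GiB free out of 47.5 GiB, so most of the device memory was already claimed before cuDNN initialization.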

POPGym performance issue.

Hi @artemZholus, I'm reproducing your project (POPGym) on my PC.
The rep and dyn losses seem stable, but the score is very different from your figure, especially after 40M steps.
After investigating, I found that act_opt_loss is very bad.

I'm using the code and config from your repo as-is, with the command you provided.
Note: the run was interrupted once at 20M steps; I restarted it, and it resumed from 20M and ran to 120M.
Could you help me?

Getting raw data for plots

Hello, I wonder how we can get or generate the scores from the R2I experiments. In the code, I saw the JSON logger writing to two files: 'metrics.jsonl' and 'scores.jsonl'. Neither of them has the same format as the .json files under ./score in the original DreamerV3 code, where each uncompressed .json file has the format:

"task": "atari_battle_zone",
"method": "dreamerv3",
"seed": "3",
"xs": [ ...],
"ys": [ ...],
I'm trying to reproduce the results of R2I and am currently trying to plot the performance curve. Thanks for your help!
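
One possible way to bridge the two formats is to collect (step, score) pairs from a JSONL log into a single run record with the fields listed above. This is a minimal sketch, not the authors' plotting pipeline. The helper `jsonl_to_run` is hypothetical, and the key names `step` and `episode/score` are assumptions about the log schema that should be checked against the actual files.

```python
import json
import tempfile
from pathlib import Path


def jsonl_to_run(path, task, method, seed, x_key="step", y_key="episode/score"):
    """Collect (x, y) pairs from a JSONL log into one DreamerV3-style run dict.

    `x_key` and `y_key` are assumed key names; adjust them to the keys that
    actually appear in the log.
    """
    xs, ys = [], []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            # Skip records that do not carry both the step and the score.
            if x_key in record and y_key in record:
                xs.append(record[x_key])
                ys.append(record[y_key])
    return {"task": task, "method": method, "seed": str(seed), "xs": xs, "ys": ys}


# Usage with a small synthetic log (not real experiment data).
with tempfile.TemporaryDirectory() as tmp:
    log = Path(tmp) / "scores.jsonl"
    log.write_text(
        '{"step": 1000, "episode/score": 1.5}\n'
        '{"step": 2000, "episode/length": 500}\n'
        '{"step": 3000, "episode/score": 4.0}\n'
    )
    run = jsonl_to_run(log, task="atari_battle_zone", method="r2i", seed=3)

print(json.dumps([run]))
```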

AssertionError: {'action': array([nan, nan, nan, nan, nan, nan], dtype=float32), 'reset': array(False)}

When I use the default config.yaml to run the experiments of Memory Maze and POPGym, the following error occurs:
Process env:
Traceback (most recent call last):
  File "/conda/miniconda/envs/py38/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/conda/miniconda/envs/py38/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/code/recall2imagine/embodied/core/distr.py", line 199, in _wrapper
    fn(*args)
  File "/code/recall2imagine/embodied/run/parallel.py", line 172, in env
    obs = env.step(act)
  File "/code/recall2imagine/embodied/core/wrappers.py", line 376, in step
    obs = self.env.step(action)
  File "/code/recall2imagine/embodied/core/wrappers.py", line 160, in step
    obs = self.env.step(action)
  File "/code/recall2imagine/embodied/core/wrappers.py", line 115, in step
    assert action[self._key].min() == 0.0, action
AssertionError: {'action': array([nan, nan, nan, nan, nan, nan], dtype=float32), 'reset': array(False)}
It seems to be related to env reset. Could you give me some suggestions on why the error occurs and how to solve it? Looking forward to your reply.
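
The assertion message shows that the action passed to `env.step` contains only NaN values. Any comparison with NaN is false, so the `min() == 0.0` check fails regardless of the reset flag. The sketch below illustrates this with plain Python floats. `check_onehot_action` is a hypothetical helper, and it assumes the wrapper expects a one-hot action (only the `min()` check is visible in the traceback).

```python
import math


def check_onehot_action(action):
    """Return the selected index, or raise a descriptive error for bad input."""
    values = [float(v) for v in action]
    if not all(math.isfinite(v) for v in values):
        raise ValueError(f"action contains non-finite values: {values}")
    if min(values) != 0.0 or max(values) != 1.0 or sum(values) != 1.0:
        raise ValueError(f"action is not one-hot: {values}")
    return values.index(1.0)


nan_action = [float("nan")] * 6

# NaN compares unequal to everything, so this mirrors the failing assert.
print(min(nan_action) == 0.0)

print(check_onehot_action([0, 0, 1, 0, 0, 0]))
```

A check like this separates "the policy produced NaN" from "the action has the wrong shape", which narrows down where to look next.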

CuDNN library incompatible error

I used the conda env and installed jaxlib 0.4.13 with cuda11.cudnn86. When I run the code, the following error occurs:
E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:427] Loaded runtime CuDNN library: 8.5.0 but source was compiled with: 8.6.0. CuDNN library needs to have matching major version and equal or higher minor version. If using a binary install, upgrade your CuDNN library. If building from sources, make sure the library loaded at runtime is compatible with the version specified during compile configuration.
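
The compatibility rule quoted in the error message (matching major version, and a runtime minor version equal to or higher than the compiled one) can be written as a small check. This is a sketch with a hypothetical helper name, `cudnn_runtime_compatible`, applied to the versions from the log.

```python
def parse_version(text):
    """Split a dotted version string such as '8.6.0' into a tuple of ints."""
    return tuple(int(part) for part in text.split("."))


def cudnn_runtime_compatible(loaded, compiled):
    """Apply the rule from the error message: same major, loaded minor >= compiled minor."""
    loaded_major, loaded_minor = parse_version(loaded)[:2]
    compiled_major, compiled_minor = parse_version(compiled)[:2]
    return loaded_major == compiled_major and loaded_minor >= compiled_minor


# Versions from the log: runtime 8.5.0, compiled against 8.6.0.
print(cudnn_runtime_compatible("8.5.0", "8.6.0"))
print(cudnn_runtime_compatible("8.6.0", "8.6.0"))
```

By this rule the loaded 8.5.0 runtime is too old for a jaxlib built against 8.6.0, so the cuDNN library found at runtime would need to be 8.6 or newer within the 8.x series.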

About the input of policy

The paper states that you

leverage the output state policy in nonmemory environments and the full state policy or hidden state policy within memory environments.

But in configs.yaml, actor.inputs is set to [stoch, hidden] only for the POPGym task. Why? Looking forward to your reply.

Flavour that was used in official metrics

Hi, I would like to verify which flavour was used for the official metrics in the paper.

E.g., is it train, train_eval, train_save, train_holdout, or parallel?

Thanks in advance

OpenGL error on the headless server

Hi, I tried to run the command

python recall2imagine/train.py \
    --configs mmaze --task gym_memory_maze:MemoryMaze-9x9-v0 \
    --wdb_name memory_maze_9x9 \
    --logdir ./logs_memory_maze_9x9

on my headless server. But I got this error:

Traceback (most recent call last):
  File "recall2imagine/train.py", line 6, in <module>
    import mujoco
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/mujoco/__init__.py", line 47, in <module>
    from mujoco.gl_context import *
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/mujoco/gl_context.py", line 38, in <module>
    from mujoco.osmesa import GLContext as _GLContext
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/mujoco/osmesa/__init__.py", line 31, in <module>
    from OpenGL import GL
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/OpenGL/GL/__init__.py", line 4, in <module>
    from OpenGL.GL.VERSION.GL_1_1 import *
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/OpenGL/GL/VERSION/GL_1_1.py", line 14, in <module>
    from OpenGL.raw.GL.VERSION.GL_1_1 import *
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/OpenGL/raw/GL/VERSION/GL_1_1.py", line 7, in <module>
    from OpenGL.raw.GL import _errors
  File "/home/lyk/miniconda3/envs/recall2imagine/lib/python3.8/site-packages/OpenGL/raw/GL/_errors.py", line 4, in <module>
    _error_checker = _ErrorChecker( _p, _p.GL.glGetError )
AttributeError: 'NoneType' object has no attribute 'glGetError'

Is there any way to solve this?
Thanks!

I have installed all the dependencies specified in the readme.md. My system config is: Ubuntu 22.04, x86_64, with CUDA version 12.4.
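
The traceback passes through `mujoco/osmesa`, which suggests the OSMesa rendering backend was selected. MuJoCo's Python bindings choose the backend from the `MUJOCO_GL` environment variable, and PyOpenGL reads `PYOPENGL_PLATFORM`. On headless machines with a GPU driver, `egl` is a commonly used value. The sketch below assumes EGL is available on the server, which has not been verified for this setup. `select_mujoco_backend` is a hypothetical helper, and it must run before `mujoco` is imported.

```python
import os


def select_mujoco_backend(backend="egl"):
    """Set the rendering backend variables; call this before importing mujoco."""
    allowed = {"egl", "osmesa", "glfw"}
    if backend not in allowed:
        raise ValueError(f"unsupported backend {backend!r}, expected one of {sorted(allowed)}")
    os.environ["MUJOCO_GL"] = backend
    if backend in ("egl", "osmesa"):
        # Keep PyOpenGL consistent with the backend MuJoCo will load.
        os.environ["PYOPENGL_PLATFORM"] = backend
    return backend


select_mujoco_backend("egl")
# import mujoco  # import only after the environment variables are set
```

The same effect can be had by exporting the variables in the shell before launching train.py.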
