yukunxia / carla_ilqr_mpc
Implementation of the real-time MPC based on iLQR in Carla simulator
License: MIT License
First of all, thank you for open-sourcing this wonderful project!
You mentioned that the running frequency could be higher than 1000 Hz. Do you mean the MPC control frequency? Could you explain which parameters you used to achieve this speed?
I use an MPC horizon of 10 and a max iteration count of 300, but only achieve 10 Hz: generating one control action takes 0.1 s.
Is this normal? Are there other parameters I should adjust?
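One thing worth ruling out when comparing solve rates: the first call to a `jit`-compiled JAX function includes compilation, and JAX dispatches work asynchronously, so a naive timer can either include compile time or stop before the computation has finished. Below is a minimal timing harness under those assumptions. `mpc_step` is a hypothetical stand-in (a toy linear rollout over the horizon), not the repository's iLQR solver. Also note that if the solver runs close to its iteration cap on every control step, the cost per step grows roughly linearly with Max Iteration.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def mpc_step(x0, u_trj):
    # Hypothetical stand-in for one MPC solve: roll a toy linear system
    # forward over the horizon. Replace with the real iLQR solve.
    A = jnp.eye(x0.shape[0]) * 0.99

    def step(x, u):
        x_next = A @ x + u
        return x_next, x_next

    _, x_trj = jax.lax.scan(step, x0, u_trj)
    return x_trj


def measure_rate(fn, *args, n_runs=100):
    """Return the average time per call in seconds and the implied rate in Hz."""
    fn(*args).block_until_ready()  # warm-up call: triggers JIT compilation once
    start = time.perf_counter()
    for _ in range(n_runs):
        # block_until_ready() waits for the asynchronously dispatched work
        fn(*args).block_until_ready()
    avg = (time.perf_counter() - start) / n_runs
    return avg, 1.0 / avg


if __name__ == "__main__":
    horizon, n_x = 10, 4
    x0 = jnp.zeros(n_x)
    u_trj = jnp.ones((horizon, n_x))
    avg, hz = measure_rate(mpc_step, x0, u_trj)
    print(f"{avg * 1e3:.3f} ms per step, {hz:.0f} Hz")
```

Swapping `mpc_step` for the actual solve call (with its real arguments) gives a per-step time that excludes compilation.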
Your ilqr_jax_MPC code does not run. I tried some modifications to the code, but it still does not work with JAX 0.2.9:
D:\CARLA_0.9.5\PythonAPI\examples>py -3.7 ilqr_jax_MPC.py
pygame 2.0.1 (SDL 2.0.14, Python 3.7.0)
Hello from the pygame community. https://www.pygame.org/contribute.html
call reset
0%| | 0/2000 [00:00<?, ?it/s]WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
0%| | 0/2000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "ilqr_jax_MPC.py", line 385, in <module>
    x_trj, u_trj, cost_trace = run_ilqr_main(state, u_trj, waypoints)
  File "ilqr_jax_MPC.py", line 292, in run_ilqr_main
    x_trj = rollout(x0, u_trj)
  File "ilqr_jax_MPC.py", line 99, in rollout
    return np.array((x0, x_trj,x_final))
  File "C:\Users\Max\AppData\Local\Programs\Python\Python37\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2859, in array
    out = stack([asarray(elt, dtype=dtype) for elt in object])
  File "C:\Users\Max\AppData\Local\Programs\Python\Python37\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2859, in <listcomp>
    out = stack([asarray(elt, dtype=dtype) for elt in object])
  File "C:\Users\Max\AppData\Local\Programs\Python\Python37\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2887, in asarray
    return array(a, dtype=dtype, copy=False, order=order)
  File "C:\Users\Max\AppData\Local\Programs\Python\Python37\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2859, in array
    out = stack([asarray(elt, dtype=dtype) for elt in object])
  File "C:\Users\Max\AppData\Local\Programs\Python\Python37\lib\site-packages\jax\_src\numpy\lax_numpy.py", line 2671, in stack
    raise ValueError("All input arrays must have the same shape.")
jax._src.traceback_util.FilteredStackTrace: ValueError: All input arrays must have the same shape.
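The traceback ends in `rollout`, where `np.array((x0, x_trj,x_final))` (with `np` being `jax.numpy`, as the `lax_numpy.py` frames show) asks JAX to stack arrays into one. `jnp.array` on a nested sequence can only stack pieces of identical shape, so the error means some of the arrays being combined differ in shape, for example a single state next to a whole trajectory. A minimal sketch of one fix, assuming `x0` and `x_final` are single states of shape `(n_x,)` and `x_trj` holds the intermediate states with shape `(N - 1, n_x)` (the repository's actual shapes may differ): concatenate along the time axis instead of stacking. `assemble_trajectory` is a hypothetical helper name.

```python
import jax.numpy as jnp


def assemble_trajectory(x0, x_trj, x_final):
    """Join initial, intermediate and final states into one (N + 1, n_x) array.

    Assumes x0 and x_final have shape (n_x,) and x_trj has shape (N - 1, n_x).
    jnp.array((x0, x_trj, x_final)) would try to *stack* these arrays, which
    requires identical shapes; concatenating along axis 0 does not.
    """
    x0 = jnp.atleast_2d(x0)            # (n_x,) -> (1, n_x)
    x_final = jnp.atleast_2d(x_final)  # (n_x,) -> (1, n_x)
    return jnp.concatenate([x0, x_trj, x_final], axis=0)


if __name__ == "__main__":
    n_x, N = 4, 10
    x0 = jnp.zeros(n_x)
    x_trj = jnp.ones((N - 1, n_x))
    x_final = jnp.full(n_x, 2.0)
    try:
        jnp.array((x0, x_trj, x_final))  # mixed shapes: stacking is rejected
    except (ValueError, TypeError) as err:
        print("stacking fails:", err)
    print(assemble_trajectory(x0, x_trj, x_final).shape)  # (11, 4)
```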
Hi,
Could you tell me what the dependencies and their versions are?
I cannot run the code.
Hi, I'm trying to run this code on my computer, which runs Ubuntu 18.04, and have run into some problems.
pygame 2.1.2 (SDL 2.0.16, Python 3.8.13)
Hello from the pygame community. https://www.pygame.org/contribute.html
Traceback (most recent call last):
  File "ilqr_jax_MPC.py", line 164, in <module>
    jac_l, hes_l, jac_l_final, hes_l_final, jac_f = derivative_init()
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/api.py", line 527, in cache_miss
    out_flat = xla.xla_call(
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1937, in bind
    return call_bind(self, fun, *args, **params)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1953, in call_bind
    outs = top_trace.process_call(primitive, fun, tracers, params)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 687, in process_call
    return primitive.impl(f, *tracers, **params)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/dispatch.py", line 208, in _xla_call_impl
    compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/linear_util.py", line 295, in memoized_fun
    ans = call(fun, *args)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/dispatch.py", line 257, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/dispatch.py", line 302, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/profiler.py", line 294, in wrapper
    return func(*args, **kwargs)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2188, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 2139, in trace_to_subjaxpr_dynamic2
    out_tracers = map(trace.full_raise, ans)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/_src/util.py", line 47, in safe_map
    return list(map(f, *args))
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 415, in full_raise
    return self.pure(val)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/interpreters/partial_eval.py", line 1761, in new_const
    aval = raise_to_shaped(get_aval(c), weak_type=dtypes.is_weakly_typed(c))
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1163, in get_aval
    return concrete_aval(x)
  File "/home/hj/anaconda3/envs/carla/lib/python3.8/site-packages/jax/core.py", line 1155, in concrete_aval
    raise TypeError(f"Value {repr(x)} with type {type(x)} is not a valid JAX "
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Value <CompiledFunction of <function jacfwd.<locals>.jacfun at 0x7fe3d73ac670>> with type <class 'jaxlib.xla_extension.CompiledFunction'> is not a valid JAX type
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "ilqr_jax_MPC.py", line 164, in <module>
    jac_l, hes_l, jac_l_final, hes_l_final, jac_f = derivative_init()
My jax version is 0.3.16 and jaxlib is 0.3.15, with CUDA 11.1 and cuDNN 8.0.5. Do you know how to fix these problems?
Thank you very much; I look forward to your reply.
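The frames `cache_miss` and `_xla_call_impl` show that `derivative_init()` is itself being run under `jit`, and the error says one of its return values is a compiled function (`jacfwd.<locals>.jacfun`). In the JAX version reported here (0.3.16), a `jit`-compiled function may only return arrays (or containers of arrays), not other functions. A minimal sketch of the usual pattern: leave the factory function un-jitted and apply `jit` to each derivative function it returns. The cost functions, the kinematic model and the state/control sizes below are hypothetical placeholders, not the repository's definitions.

```python
import jax
import jax.numpy as jnp


def stage_cost(x, u):
    # Hypothetical quadratic running cost.
    return jnp.sum(x ** 2) + 0.1 * jnp.sum(u ** 2)


def final_cost(x):
    # Hypothetical terminal cost.
    return 10.0 * jnp.sum(x ** 2)


def dynamics(x, u, dt=0.1):
    # Hypothetical kinematic model: state [px, py, yaw, v], control [accel, yaw_rate].
    px, py, yaw, v = x
    accel, yaw_rate = u
    return jnp.array([
        px + v * jnp.cos(yaw) * dt,
        py + v * jnp.sin(yaw) * dt,
        yaw + yaw_rate * dt,
        v + accel * dt,
    ])


def derivative_init():
    """Build jit-compiled derivative functions.

    Deliberately NOT decorated with @jax.jit: a jitted function must return
    arrays, so returning functions from it raises the TypeError above.
    """
    jac_l = jax.jit(jax.jacfwd(stage_cost, argnums=(0, 1)))   # (l_x, l_u)
    hes_l = jax.jit(jax.hessian(stage_cost, argnums=(0, 1)))  # ((l_xx, l_xu), (l_ux, l_uu))
    jac_l_final = jax.jit(jax.jacfwd(final_cost))
    hes_l_final = jax.jit(jax.hessian(final_cost))
    jac_f = jax.jit(jax.jacfwd(dynamics, argnums=(0, 1)))     # (f_x, f_u)
    return jac_l, hes_l, jac_l_final, hes_l_final, jac_f


if __name__ == "__main__":
    jac_l, hes_l, jac_l_final, hes_l_final, jac_f = derivative_init()
    x, u = jnp.array([0.0, 0.0, 0.0, 1.0]), jnp.array([0.5, 0.1])
    f_x, f_u = jac_f(x, u)
    print(f_x.shape, f_u.shape)  # (4, 4) (4, 2)
```

Each returned function is still compiled once per input shape, so the speed benefit of `jit` is kept.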
Hi @Tanman1234 @YukunXia, thanks for the discussion here, and thanks to @YukunXia for the code.
I also have a problem with jaxlib after installing jax 0.1.68, specifically an ImportError: cannot import name 'pytree' from 'jaxlib'.
I am working on Ubuntu 18.04.
Here is what I did to install jax 0.1.68, following https://github.com/google/jax/tree/jaxlib-v0.1.68:
pip install --upgrade pip
sudo ln -s /path/to/cuda /usr/local/cuda-11.1
pip install --upgrade jax==0.1.68 jaxlib==0.1.67+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
- Result of the installation:
Successfully installed jax-0.1.68 jaxlib-0.1.67+cuda111
- Running ilqr_jax_MPC.py then raises an ImportError:
Traceback (most recent call last):
  File "/home/control/Documents/Carla projects/Carla_iLQR_MPC/MPC/ilqr_jax_MPC.py", line 1, in <module>
    from jax import jit, jacfwd, jacrev, hessian, lax
  File "/home/control/.local/lib/python3.7/site-packages/jax/__init__.py", line 16, in <module>
    from .api import (
  File "/home/control/.local/lib/python3.7/site-packages/jax/api.py", line 38, in <module>
    from . import core
  File "/home/control/.local/lib/python3.7/site-packages/jax/core.py", line 30, in <module>
    from . import dtypes
  File "/home/control/.local/lib/python3.7/site-packages/jax/dtypes.py", line 31, in <module>
    from .lib import xla_client
  File "/home/control/.local/lib/python3.7/site-packages/jax/lib/__init__.py", line 51, in <module>
    from jaxlib import pytree
ImportError: cannot import name 'pytree' from 'jaxlib' (/home/control/.local/lib/python3.7/site-packages/jaxlib/__init__.py)
For your information, CUDA 11.5 is installed. I downloaded cudnn-linux-x86_64-8.3.1.22_cuda11.5 and copied its files (including the symlinks) from its include and lib directories into /usr/local/cuda/include and /usr/local/cuda/lib64 with:
cd folder/extracted/cdnn_contents
sudo cp include/cudnn.h /usr/local/cuda/include
sudo cp lib/libcudnn* /usr/local/cuda/lib64
sudo chmod a+r /usr/local/cuda/lib64/libcudnn*
I have also tried to build jaxlib from source, without success; the build failed with this problem: https://stackoverflow.com/questions/70324228/how-to-deal-with-error-infinite-symlink-expansion-detected-in-building-jax-from
I also tried installing jaxlib and jax==0.2.14 and 0.2.16 with pip, but ilqr_jax_MPC.py still raises (different) errors.
Do you know how to deal with these pitfalls?
Many thanks.
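The ImportError is raised inside `jax/lib/__init__.py` while it imports from `jaxlib`, which most likely points to a jax/jaxlib pair that do not belong together: jax and jaxlib are released as companions, and each jax release only supports a range of jaxlib versions. Note that the pip command above pins `jax==0.1.68` but `jaxlib==0.1.67+cuda111`. When reporting versions in an issue like this, it helps to read them without importing jax (which fails here). A minimal sketch, assuming Python 3.8+ for `importlib.metadata`; on Python 3.7, as in the traceback, the `importlib_metadata` backport provides the same API. The helper names are hypothetical.

```python
import sys

try:
    from importlib import metadata  # Python 3.8+
except ImportError:  # Python 3.7: requires `pip install importlib_metadata`
    import importlib_metadata as metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def environment_report(packages=("jax", "jaxlib", "numpy", "pygame")):
    """Collect package versions without importing the packages themselves."""
    report = {"python": sys.version.split()[0]}
    for name in packages:
        report[name] = installed_version(name)
    return report


if __name__ == "__main__":
    for name, version in environment_report().items():
        print(f"{name:8s} {version or 'not installed'}")
```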