Auto-assigning NUTS sampler...
INFO:pymc3:Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
INFO:pymc3:Initializing NUTS using jitter+adapt_diag...
---------------------------------------------------------------------------
FilteredStackTrace Traceback (most recent call last)
<ipython-input-20-21adaeaad34c> in <module>
21 with model:
---> 22 pm.sample(mode=jax_mode)
~/projects/pymc/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
480 _log.info("Auto-assigning NUTS sampler...")
--> 481 start_, step = init_nuts(
482 init=init,
~/projects/pymc/pymc3/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, **kwargs)
2133
-> 2134 step = pm.NUTS(potential=potential, model=model, **kwargs)
2135
~/projects/pymc/pymc3/step_methods/hmc/nuts.py in __init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
167 """
--> 168 super().__init__(vars, **kwargs)
169
~/projects/pymc/pymc3/step_methods/hmc/base_hmc.py in __init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **theano_kwargs)
92
---> 93 super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
94
~/projects/pymc/pymc3/step_methods/arraystep.py in __init__(self, vars, model, blocked, dtype, logp_dlogp_func, **theano_kwargs)
253 q = func.dict_to_array(model.test_point)
--> 254 logp, dlogp = func(q)
255 except ValueError:
~/projects/pymc/pymc3/model.py in __call__(self, array, grad_out, extra_vars)
738
--> 739 output = self._theano_function(array)
740 if grad_out is None:
~/projects/Theano-PyMC/theano/compile/function_module.py in __call__(self, *args, **kwargs)
978 outputs = (
--> 979 self.fn()
980 if output_subset is None
~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
702 ):
--> 703 thunk()
704 for old_s in old_storage:
~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
653 ):
--> 654 outputs = [
655 jax_impl_jit(*[x[0] for x in thunk_inputs])
~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
654 outputs = [
--> 655 jax_impl_jit(*[x[0] for x in thunk_inputs])
656 for jax_impl_jit in jax_impl_jits
~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
125 def jax_func(*inputs):
--> 126 func_args = [fn(*inputs) for fn in input_funcs]
127 return return_func(*func_args)
~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
125 def jax_func(*inputs):
--> 126 func_args = [fn(*inputs) for fn in input_funcs]
127 return return_func(*func_args)
~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
126 func_args = [fn(*inputs) for fn in input_funcs]
--> 127 return return_func(*func_args)
128
~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
545 def reshape(x, shape):
--> 546 return jnp.reshape(x, shape)
547
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
1145 try:
-> 1146 return a.reshape(newshape, order=order) # forward to method for ndarrays
1147 except AttributeError:
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
1191 newshape = newshape[0]
-> 1192 return _reshape(a, newshape, order=order)
1193
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
1167 def _reshape(a, newshape, order="C"):
-> 1168 computed_newshape = _compute_newshape(a, newshape)
1169 if order == "C":
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
1158 int, size, "The error arose in jax.numpy.reshape.")
-> 1159 newshape = [check(size) for size in newshape] if iterable else check(newshape)
1160 newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
1158 int, size, "The error arose in jax.numpy.reshape.")
-> 1159 newshape = [check(size) for size in newshape] if iterable else check(newshape)
1160 newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
1156 def check(size):
-> 1157 return size if type(size) is Poly else core.concrete_or_error(
1158 int, size, "The error arose in jax.numpy.reshape.")
FilteredStackTrace: jax.core.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The error arose in jax.numpy.reshape.
While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125, this value became a tracer due to JAX operations on these lines:
operation yn:bool[] = lt yl:int64[] ym:int64[]
from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)
operation yq:int64[] = xla_call[ backend=None
call_jaxpr={ lambda ; a b c.
let d = select a b c
in (d,) }
device=None
donated_invars=(False, False, False)
name=_where ] yn:bool[] yo:int64[] yp:int64[]
from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
ConcretizationTypeError Traceback (most recent call last)
~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
702 ):
--> 703 thunk()
704 for old_s in old_storage:
~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
653 ):
--> 654 outputs = [
655 jax_impl_jit(*[x[0] for x in thunk_inputs])
~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
654 outputs = [
--> 655 jax_impl_jit(*[x[0] for x in thunk_inputs])
656 for jax_impl_jit in jax_impl_jits
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
136 try:
--> 137 return fun(*args, **kwargs)
138 except Exception as e:
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
208 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 209 out = xla.xla_call(
210 flat_fun,
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
1143 def bind(self, fun, *args, **params):
-> 1144 return call_bind(self, fun, *args, **params)
1145
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1134 with maybe_new_sublevel(top_trace):
-> 1135 outs = primitive.process(top_trace, fun, tracers, params)
1136 return map(full_lower, apply_todos(env_trace_todo(), outs))
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1146 def process(self, trace, fun, tracers, params):
-> 1147 return trace.process_call(self, fun, tracers, params)
1148
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
576 def process_call(self, primitive, f, tracers, params):
--> 577 return primitive.impl(f, *tracers, **params)
578 process_map = process_call
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
528 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 529 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
530 *unsafe_map(arg_spec, args))
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
233 else:
--> 234 ans = call(fun, *args)
235 cache[key] = (ans, fun.stores)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
594 if config.omnistaging_enabled:
--> 595 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
596 if any(isinstance(c, core.Tracer) for c in consts):
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
1022 main.jaxpr_stack = () # type: ignore
-> 1023 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1024 del main
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1003 in_tracers = map(trace.new_arg, in_avals)
-> 1004 ans = fun.call_wrapped(*in_tracers)
1005 out_tracers = map(trace.full_raise, ans)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
150 try:
--> 151 ans = self.f(*args, **dict(self.params, **kwargs))
152 except:
~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
125 def jax_func(*inputs):
--> 126 func_args = [fn(*inputs) for fn in input_funcs]
127 return return_func(*func_args)
~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
125 def jax_func(*inputs):
--> 126 func_args = [fn(*inputs) for fn in input_funcs]
127 return return_func(*func_args)
~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
126 func_args = [fn(*inputs) for fn in input_funcs]
--> 127 return return_func(*func_args)
128
~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
545 def reshape(x, shape):
--> 546 return jnp.reshape(x, shape)
547
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
1145 try:
-> 1146 return a.reshape(newshape, order=order) # forward to method for ndarrays
1147 except AttributeError:
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
1191 newshape = newshape[0]
-> 1192 return _reshape(a, newshape, order=order)
1193
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
1167 def _reshape(a, newshape, order="C"):
-> 1168 computed_newshape = _compute_newshape(a, newshape)
1169 if order == "C":
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
1158 int, size, "The error arose in jax.numpy.reshape.")
-> 1159 newshape = [check(size) for size in newshape] if iterable else check(newshape)
1160 newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
1158 int, size, "The error arose in jax.numpy.reshape.")
-> 1159 newshape = [check(size) for size in newshape] if iterable else check(newshape)
1160 newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
1156 def check(size):
-> 1157 return size if type(size) is Poly else core.concrete_or_error(
1158 int, size, "The error arose in jax.numpy.reshape.")
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in concrete_or_error(force, val, context)
873 else:
--> 874 raise_concretization_error(val, context)
875 else:
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in raise_concretization_error(val, context)
852 f"Encountered tracer value: {val}")
--> 853 raise ConcretizationTypeError(msg)
854
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The error arose in jax.numpy.reshape.
While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125, this value became a tracer due to JAX operations on these lines:
operation yn:bool[] = lt yl:int64[] ym:int64[]
from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)
operation yq:int64[] = xla_call[ backend=None
call_jaxpr={ lambda ; a b c.
let d = select a b c
in (d,) }
device=None
donated_invars=(False, False, False)
name=_where ] yn:bool[] yo:int64[] yp:int64[]
from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
During handling of the above exception, another exception occurred:
ConcretizationTypeError Traceback (most recent call last)
<ipython-input-20-21adaeaad34c> in <module>
20
21 with model:
---> 22 pm.sample(mode=jax_mode)
~/projects/pymc/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, return_inferencedata, idata_kwargs, mp_ctx, pickle_backend, **kwargs)
479 # By default, try to use NUTS
480 _log.info("Auto-assigning NUTS sampler...")
--> 481 start_, step = init_nuts(
482 init=init,
483 chains=chains,
~/projects/pymc/pymc3/sampling.py in init_nuts(init, chains, n_init, model, random_seed, progressbar, **kwargs)
2132 raise ValueError(f"Unknown initializer: {init}.")
2133
-> 2134 step = pm.NUTS(potential=potential, model=model, **kwargs)
2135
2136 return start, step
~/projects/pymc/pymc3/step_methods/hmc/nuts.py in __init__(self, vars, max_treedepth, early_max_treedepth, **kwargs)
166 `pm.sample` to the desired number of tuning steps.
167 """
--> 168 super().__init__(vars, **kwargs)
169
170 self.max_treedepth = max_treedepth
~/projects/pymc/pymc3/step_methods/hmc/base_hmc.py in __init__(self, vars, scaling, step_scale, is_cov, model, blocked, potential, dtype, Emax, target_accept, gamma, k, t0, adapt_step_size, step_rand, **theano_kwargs)
91 vars = inputvars(vars)
92
---> 93 super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
94
95 self.adapt_step_size = adapt_step_size
~/projects/pymc/pymc3/step_methods/arraystep.py in __init__(self, vars, model, blocked, dtype, logp_dlogp_func, **theano_kwargs)
252 func.set_extra_values(model.test_point)
253 q = func.dict_to_array(model.test_point)
--> 254 logp, dlogp = func(q)
255 except ValueError:
256 if logp_dlogp_func is not None:
~/projects/pymc/pymc3/model.py in __call__(self, array, grad_out, extra_vars)
737 out = grad_out
738
--> 739 output = self._theano_function(array)
740 if grad_out is None:
741 return output
~/projects/Theano-PyMC/theano/compile/function_module.py in __call__(self, *args, **kwargs)
977 try:
978 outputs = (
--> 979 self.fn()
980 if output_subset is None
981 else self.fn(output_subset=output_subset)
~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
705 old_s[0] = None
706 except Exception:
--> 707 raise_with_op(node, thunk)
708
709 f = streamline_default_f
~/projects/Theano-PyMC/theano/gof/link.py in raise_with_op(node, thunk, exc_info, storage_map)
346 # extra long error message in that case.
347 pass
--> 348 reraise(exc_type, exc_value, exc_trace)
349
350
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/six.py in reraise(tp, value, tb)
700 value = tp()
701 if value.__traceback__ is not tb:
--> 702 raise value.with_traceback(tb)
703 raise value
704 finally:
~/projects/Theano-PyMC/theano/gof/link.py in streamline_default_f()
701 thunks, order, post_thunk_old_storage
702 ):
--> 703 thunk()
704 for old_s in old_storage:
705 old_s[0] = None
~/projects/Theano-PyMC/theano/sandbox/jax.py in thunk(node, jax_impl_jits, thunk_outputs)
652 node=node, jax_impl_jits=jax_impl_jits, thunk_outputs=thunk_outputs
653 ):
--> 654 outputs = [
655 jax_impl_jit(*[x[0] for x in thunk_inputs])
656 for jax_impl_jit in jax_impl_jits
~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
653 ):
654 outputs = [
--> 655 jax_impl_jit(*[x[0] for x in thunk_inputs])
656 for jax_impl_jit in jax_impl_jits
657 ]
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
135 def reraise_with_filtered_traceback(*args, **kwargs):
136 try:
--> 137 return fun(*args, **kwargs)
138 except Exception as e:
139 if not is_under_reraiser(e):
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/api.py in f_jitted(*args, **kwargs)
207 _check_arg(arg)
208 flat_fun, out_tree = flatten_fun(f, in_tree)
--> 209 out = xla.xla_call(
210 flat_fun,
211 *args_flat,
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in bind(self, fun, *args, **params)
1142
1143 def bind(self, fun, *args, **params):
-> 1144 return call_bind(self, fun, *args, **params)
1145
1146 def process(self, trace, fun, tracers, params):
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
1133 tracers = map(top_trace.full_raise, args)
1134 with maybe_new_sublevel(top_trace):
-> 1135 outs = primitive.process(top_trace, fun, tracers, params)
1136 return map(full_lower, apply_todos(env_trace_todo(), outs))
1137
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
1145
1146 def process(self, trace, fun, tracers, params):
-> 1147 return trace.process_call(self, fun, tracers, params)
1148
1149 def post_process(self, trace, out_tracers, params):
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
575
576 def process_call(self, primitive, f, tracers, params):
--> 577 return primitive.impl(f, *tracers, **params)
578 process_map = process_call
579
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
527
528 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name, donated_invars):
--> 529 compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,
530 *unsafe_map(arg_spec, args))
531 try:
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in memoized_fun(fun, *args)
232 fun.populate_stores(stores)
233 else:
--> 234 ans = call(fun, *args)
235 cache[key] = (ans, fun.stores)
236 return ans
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/xla.py in _xla_callable(fun, device, backend, name, donated_invars, *arg_specs)
593 abstract_args, arg_devices = unzip2(arg_specs)
594 if config.omnistaging_enabled:
--> 595 jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
596 if any(isinstance(c, core.Tracer) for c in consts):
597 raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_jaxpr_final(fun, in_avals)
1021 main.source_info = fun_sourceinfo(fun.f) # type: ignore
1022 main.jaxpr_stack = () # type: ignore
-> 1023 jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
1024 del main
1025 return jaxpr, out_avals, consts
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/interpreters/partial_eval.py in trace_to_subjaxpr_dynamic(fun, main, in_avals)
1002 trace = DynamicJaxprTrace(main, core.cur_sublevel())
1003 in_tracers = map(trace.new_arg, in_avals)
-> 1004 ans = fun.call_wrapped(*in_tracers)
1005 out_tracers = map(trace.full_raise, ans)
1006 jaxpr, out_avals, consts = frame.to_jaxpr(in_tracers, out_tracers)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
149
150 try:
--> 151 ans = self.f(*args, **dict(self.params, **kwargs))
152 except:
153 # Some transformations yield from inside context managers, so we have to
~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
124
125 def jax_func(*inputs):
--> 126 func_args = [fn(*inputs) for fn in input_funcs]
127 return return_func(*func_args)
128
~/projects/Theano-PyMC/theano/sandbox/jax.py in <listcomp>(.0)
124
125 def jax_func(*inputs):
--> 126 func_args = [fn(*inputs) for fn in input_funcs]
127 return return_func(*func_args)
128
~/projects/Theano-PyMC/theano/sandbox/jax.py in jax_func(*inputs)
125 def jax_func(*inputs):
126 func_args = [fn(*inputs) for fn in input_funcs]
--> 127 return return_func(*func_args)
128
129 jax_funcs.append(update_wrapper(jax_func, return_func))
~/projects/Theano-PyMC/theano/sandbox/jax.py in reshape(x, shape)
544 def jax_funcify_Reshape(op):
545 def reshape(x, shape):
--> 546 return jnp.reshape(x, shape)
547
548 return reshape
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in reshape(a, newshape, order)
1144 def reshape(a, newshape, order="C"):
1145 try:
-> 1146 return a.reshape(newshape, order=order) # forward to method for ndarrays
1147 except AttributeError:
1148 return _reshape(a, newshape, order=order)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape_method(a, *newshape, **kwargs)
1190 type(newshape[0]) is not Poly):
1191 newshape = newshape[0]
-> 1192 return _reshape(a, newshape, order=order)
1193
1194
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _reshape(a, newshape, order)
1166
1167 def _reshape(a, newshape, order="C"):
-> 1168 computed_newshape = _compute_newshape(a, newshape)
1169 if order == "C":
1170 return lax.reshape(a, computed_newshape, None)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in _compute_newshape(a, newshape)
1157 return size if type(size) is Poly else core.concrete_or_error(
1158 int, size, "The error arose in jax.numpy.reshape.")
-> 1159 newshape = [check(size) for size in newshape] if iterable else check(newshape)
1160 newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
1161 if newsize < 0:
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in <listcomp>(.0)
1157 return size if type(size) is Poly else core.concrete_or_error(
1158 int, size, "The error arose in jax.numpy.reshape.")
-> 1159 newshape = [check(size) for size in newshape] if iterable else check(newshape)
1160 newsize = _prod((newshape,) if type(newshape) is Poly else newshape)
1161 if newsize < 0:
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/numpy/lax_numpy.py in check(size)
1155 else: iterable = True
1156 def check(size):
-> 1157 return size if type(size) is Poly else core.concrete_or_error(
1158 int, size, "The error arose in jax.numpy.reshape.")
1159 newshape = [check(size) for size in newshape] if iterable else check(newshape)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in concrete_or_error(force, val, context)
872 return force(val.aval.val)
873 else:
--> 874 raise_concretization_error(val, context)
875 else:
876 return force(val)
~/miniconda3/envs/pymc3theano/lib/python3.8/site-packages/jax/core.py in raise_concretization_error(val, context)
851 "See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.\n\n"
852 f"Encountered tracer value: {val}")
--> 853 raise ConcretizationTypeError(msg)
854
855
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected.
The error arose in jax.numpy.reshape.
While tracing the function incsubtensor at /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:125, this value became a tracer due to JAX operations on these lines:
operation yn:bool[] = lt yl:int64[] ym:int64[]
from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)
operation yq:int64[] = xla_call[ backend=None
call_jaxpr={ lambda ; a b c.
let d = select a b c
in (d,) }
device=None
donated_invars=(False, False, False)
name=_where ] yn:bool[] yo:int64[] yp:int64[]
from line /Users/twiecki/projects/Theano-PyMC/theano/sandbox/jax.py:127 (jax_func)
You can use transformation parameters such as `static_argnums` for `jit` to avoid tracing particular arguments of transformed functions.
See https://jax.readthedocs.io/en/latest/faq.html#abstract-tracer-value-encountered-where-concrete-value-is-expected-error for more information.
Encountered tracer value: Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
Apply node that caused the error: Sum{acc_dtype=float64}(MakeVector{dtype='float64'}.0)
Toposort index: 46
Inputs types: [TensorType(float64, vector)]
Inputs shapes: [(3,)]
Inputs strides: [(8,)]
Inputs values: [array([0.69049938, 0. , 0. ])]
Outputs clients: [['output']]
Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
File "<ipython-input-20-21adaeaad34c>", line 22, in <module>
pm.sample(mode=jax_mode)
File "/Users/twiecki/projects/pymc/pymc3/sampling.py", line 481, in sample
start_, step = init_nuts(
File "/Users/twiecki/projects/pymc/pymc3/sampling.py", line 2134, in init_nuts
step = pm.NUTS(potential=potential, model=model, **kwargs)
File "/Users/twiecki/projects/pymc/pymc3/step_methods/hmc/nuts.py", line 168, in __init__
super().__init__(vars, **kwargs)
File "/Users/twiecki/projects/pymc/pymc3/step_methods/hmc/base_hmc.py", line 93, in __init__
super().__init__(vars, blocked=blocked, model=model, dtype=dtype, **theano_kwargs)
File "/Users/twiecki/projects/pymc/pymc3/step_methods/arraystep.py", line 245, in __init__
func = model.logp_dlogp_function(
File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1005, in logp_dlogp_function
costs = [self.logpt]
File "/Users/twiecki/projects/pymc/pymc3/model.py", line 1015, in logpt
logp = tt.sum([tt.sum(factor) for factor in factors])
HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint and storage map footprint of this apply node.