
jax-ml / bayeux

104 stars, 3 watchers, 4 forks, 1.39 MB

State-of-the-art inference for your Bayesian models.

Home Page: https://jax-ml.github.io/bayeux/

License: Apache License 2.0

Python 100.00%

bayeux's Introduction

Bayeux

Stitching together models and samplers


bayeux lets you write a probabilistic model in JAX and immediately have access to state-of-the-art inference methods. The API aims to be simple, self-descriptive, and helpful. Simply provide a log density function (which doesn't even have to be normalized), along with a single point (specified as a pytree) where that log density is finite. Then let bayeux do the rest!
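To make the "log density plus one pytree test point" contract concrete, here is a minimal pure-JAX sketch; it does not call bayeux itself. The logistic-regression model, the parameter names `w` and `b`, and the toy data are hypothetical. The test point is a dict holding a vector and a scalar, and `jax.grad` returns gradients with the same structure.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: 4 observations, 2 features, binary outcomes.
X = jnp.array([[0.5, 1.0], [-1.0, 0.3], [1.5, -0.2], [0.0, 0.8]])
y = jnp.array([1.0, 0.0, 1.0, 0.0])

def log_density(params):
  # Unnormalized: standard-normal priors with the constant terms dropped.
  log_prior = -0.5 * jnp.sum(params["w"] ** 2) - 0.5 * params["b"] ** 2
  logits = X @ params["w"] + params["b"]
  # Bernoulli log likelihood, written with log_sigmoid for numerical stability.
  log_lik = jnp.sum(y * jax.nn.log_sigmoid(logits)
                    + (1.0 - y) * jax.nn.log_sigmoid(-logits))
  return log_prior + log_lik

# One point where the log density is finite; its structure (a dict holding a
# length-2 vector and a scalar) defines the parameter pytree.
test_point = {"w": jnp.zeros(2), "b": 0.0}

print(log_density(test_point))  # 4 * log(0.5) ≈ -2.7726
print(jax.tree_util.tree_map(jnp.shape, jax.grad(log_density)(test_point)))
```

This pair is what the Quickstart passes as `bx.Model(log_density=..., test_point=...)`.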

Installation

pip install bayeux-ml

Quickstart

We define a model by providing a log density in JAX. The log density could come from a probabilistic programming language (PPL) such as NumPyro, PyMC, TFP, distrax, oryx, or coix, or be written directly in JAX.

import bayeux as bx
import jax

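# Unnormalized log density: no normalizing constant is needed.
normal_density = bx.Model(
  log_density=lambda x: -x*x,
  test_point=1.)

seed = jax.random.key(0)

# Optimize with optax's Adam.
opt_results = normal_density.optimize.optax_adam(seed=seed)
# OR! Sample with NumPyro's NUTS; returns an ArviZ InferenceData.
idata = normal_density.mcmc.numpyro_nuts(seed=seed)
# OR! Fit a factored surrogate posterior with TFP variational inference.
surrogate_posterior, loss = normal_density.vi.tfp_factored_surrogate_posterior(seed=seed)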
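The Quickstart's `-x*x` is an unnormalized log density: exp(-x²) is proportional to a normal density with mean 0 and standard deviation 1/√2 (not a standard normal). The sketch below, using only `jax.scipy.stats.norm`, checks that the normalized and unnormalized versions differ by the same constant at every point, which is why the normalizing term can be left out.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_density(x):
  # Same unnormalized log density as in the Quickstart.
  return -x * x

xs = jnp.linspace(-3.0, 3.0, 7)
# Normal(0, sigma) with sigma**2 = 1/2, i.e. sigma = 1/sqrt(2).
normalized = norm.logpdf(xs, loc=0.0, scale=1.0 / jnp.sqrt(2.0))
offset = normalized - log_density(xs)

# The offset is identical at every x: -0.5 * log(pi) ≈ -0.572.
print(offset)
```

A constant offset in the log density rescales the density without changing its shape, so samplers and optimizers that only use ratios or gradients are unaffected.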


This is not an officially supported Google product.

bayeux's People

Contributors

colcarroll, theorashid, thomascolthurst, thomasjpfan


bayeux's Issues

TensorFlow Probability MCMC and VI methods do not work with Bambi models

TensorFlow Probability samplers fail when attempting to sample from a Bambi model.

import bambi as bmb
import bayeux as bx
import jax

data = bmb.load_data("ANES")
clinton_data = data.loc[data["vote"].isin(["clinton", "trump"]), :]

model = bmb.Model("vote['clinton'] ~ party_id + party_id:age", clinton_data, family="bernoulli")
model.build()

bx_model = bx.Model.from_pymc(model.backend.model)
bx_model.mcmc.tfp_hmc(seed=jax.random.key(0))
TypeError: float() argument must be a string or a real number, not 'ShapedArray'

The same TypeError occurs with every TFP MCMC algorithm.

When attempting to use the TFP VI method, the following error is raised:

bx_model.vi.tfp_factored_surrogate_posterior(seed=jax.random.key(0))
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

I haven't looked into why these errors happen yet; I just wanted to bring this to your attention. Since the Bambi backend model, model.backend.model, is a PyMC model, these errors may also occur with plain PyMC models.
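Both messages match generic JAX errors, and the snippet below reproduces each in isolation with plain JAX: `jax.vmap` refuses to map over axis 0 of a rank-0 (scalar) array, and calling Python's `float()` on a value being traced under `jax.jit` raises a `TypeError` subclass. This is a sketch for orientation only; it does not establish what bayeux or the PyMC-to-JAX conversion does internally, and the exact message wording may vary across JAX versions.

```python
import jax
import jax.numpy as jnp

def raised(fn, *args):
  """Call fn(*args) and return the exception it raises, or None."""
  try:
    fn(*args)
  except Exception as err:  # broad on purpose: we only inspect the error
    return err
  return None

# 1. vmap over axis 0 of a shape-() array: there is no axis to map over.
vmap_err = raised(jax.vmap(lambda x: x + 1.0), jnp.array(1.0))
print(type(vmap_err).__name__, "-", vmap_err)

# 2. float() on a traced (abstract) value inside jit: the concrete value is
#    unknown at trace time, so it cannot become a Python float.
jit_err = raised(jax.jit(lambda x: float(x)), jnp.array(1.0))
print(type(jit_err).__name__)

# Giving vmap a leading axis (here of length 3) works as expected.
print(jax.vmap(lambda x: x + 1.0)(jnp.ones(3)))
```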

Suggestions for Enhancing Tutorials

Description

The existing tutorials are valuable guides to the concepts they cover, but there is room to expand them. Here are some suggestions:

  • Corner Plots: Adding corner plots would make it easier to visualize the marginal and pairwise joint distributions of the parameters.
  • Detailed Tutorial on Statistical Models: A comprehensive tutorial on implementing more sophisticated statistical models, such as a Poisson process, a Gaussian process, or a hierarchical likelihood, would be very helpful.
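As an illustration of what a hierarchical-model tutorial could cover, here is a hedged sketch of the classic eight-schools model written directly as a JAX log density with a pytree test point, using the non-centered parameterization. The priors (Normal(0, 5) on `mu`, half-Cauchy(5) on `tau`, sampled through `log_tau`) are assumptions chosen for illustration, not something bayeux prescribes.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import cauchy, norm

# Classic eight-schools data: estimated treatment effects and standard errors.
y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

def log_density(params):
  mu, log_tau, z = params["mu"], params["log_tau"], params["z"]
  tau = jnp.exp(log_tau)
  # Assumed priors: mu ~ Normal(0, 5), tau ~ HalfCauchy(5), z ~ Normal(0, 1).
  lp = norm.logpdf(mu, 0.0, 5.0)
  lp += jnp.log(2.0) + cauchy.logpdf(tau, 0.0, 5.0)  # half-Cauchy for tau > 0
  lp += log_tau  # log-Jacobian of tau = exp(log_tau)
  lp += jnp.sum(norm.logpdf(z))
  # Non-centered parameterization: school effects theta = mu + tau * z.
  theta = mu + tau * z
  lp += jnp.sum(norm.logpdf(y, theta, sigma))
  return lp

# Pytree test point whose leaves have different shapes: scalars and a vector.
test_point = {"mu": 0.0, "log_tau": 0.0, "z": jnp.zeros(8)}
print(log_density(test_point))
```

In bayeux this pair would be passed as `bx.Model(log_density=log_density, test_point=test_point)`. The non-centered form is a common choice because it tends to ease sampling when `tau` is small.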

Save PyMC `Deterministic`s to idata

Hey, I've been playing around with this a bit from PyMC, and I'm so glad this exists now! Unfortunately, `Deterministic`s are not being recorded in `idata.posterior`. I'm happy to attempt a PR if you can point me to where to start.
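Until such a change exists, one possible workaround (an assumption, not a bayeux feature) is to recompute a deterministic quantity from the posterior draws after sampling. The sketch below does this in plain JAX with two nested `jax.vmap` calls over the chain and draw axes; the names `mu`, `log_sigma`, and `sigma` and the (chains, draws) layout are hypothetical, and the draws are random stand-ins.

```python
import jax
import jax.numpy as jnp

def deterministic(params):
  # Hypothetical Deterministic, defined per draw: sigma = exp(log_sigma).
  return {"sigma": jnp.exp(params["log_sigma"])}

# Stand-in posterior draws shaped (chains, draws); in practice these would
# come from the sampler output.
keys = jax.random.split(jax.random.key(0))
draws = {
    "mu": jax.random.normal(keys[0], (4, 100)),
    "log_sigma": 0.1 * jax.random.normal(keys[1], (4, 100)),
}

# Outer vmap maps over chains, inner over draws, so `deterministic` sees scalars.
extra = jax.vmap(jax.vmap(deterministic))(draws)
print(extra["sigma"].shape)
```

If the draws live in an ArviZ `InferenceData`, the result could then be added to `idata.posterior` as a new variable with the same `chain` and `draw` dimensions.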

Can Bayeux help leverage TFP for Apple Silicon GPU?

I am on an M1 Ultra with an Apple Silicon GPU. Given that TensorFlow supports Apple Silicon GPUs, can I leverage that via Bayeux? I installed TensorFlow and the TensorFlow Metal plugin (and TensorFlow sees the GPU), but I cannot figure out how to tell Bayeux to have TensorFlow Probability use TensorFlow instead of JAX, and thus the GPU.

[Screenshot 2024-02-21 at 12:51:55 PM]
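Some context, based on the file paths in the traceback of the next issue: bayeux's TFP methods call into `tensorflow_probability.substrates.jax`, TFP's JAX backend, so the computation runs on whatever device JAX is using rather than through TensorFlow. The snippet below uses standard JAX calls to report that device; whether a Metal-enabled JAX build works with bayeux is not addressed here.

```python
import jax

# Platform JAX uses by default, e.g. "cpu" or "gpu".
print(jax.default_backend())

# All devices visible to JAX in this process.
print(jax.devices())
```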

Error for Bambi/PyMC model when using TFP

When I run the following code, I get this error:

TypeError: float() argument must be a string or a real number, not 'ShapedArray'

import bayeux as bx
import bambi as bmb
import pymc as pm
import pandas as pd
import jax
import arviz as az

dist = pm.Normal.dist(mu=100, sigma=30)

draws = pm.draw(dist, draws=1_000, random_seed=1000)

df = pd.DataFrame(data=draws, columns=['heights'])

formula = bmb.Formula('heights ~ 1')

model = bmb.Model(formula=formula, family='gaussian', data=df)

model.build()

bx_model = bx.Model.from_pymc(model.backend.model)

idata = bx_model.mcmc.tfp_nuts(seed=jax.random.key(0))

az.summary(idata)
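For comparison, here is a hedged sketch of the same intercept-only Gaussian model written directly as a JAX log density, which bayeux accepts without going through PyMC. The priors (Normal(100, 50) on the intercept, half-normal with scale 50 on `sigma`) are assumptions, not Bambi's defaults, and the data are simulated with JAX to mimic `pm.draw(pm.Normal.dist(mu=100, sigma=30), draws=1_000)` (the values will differ). Because `sigma` must be positive, the parameter exposed to the sampler is `log_sigma`, and the log density adds the log-Jacobian of `sigma = exp(log_sigma)`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Simulated heights, mimicking 1,000 draws from Normal(100, 30).
heights = 100.0 + 30.0 * jax.random.normal(jax.random.key(1000), (1_000,))

def log_density(params):
  intercept, log_sigma = params["intercept"], params["log_sigma"]
  sigma = jnp.exp(log_sigma)
  lp = norm.logpdf(intercept, 100.0, 50.0)            # assumed prior
  lp += jnp.log(2.0) + norm.logpdf(sigma, 0.0, 50.0)  # assumed half-normal(50)
  lp += log_sigma  # log-Jacobian of sigma = exp(log_sigma)
  lp += jnp.sum(norm.logpdf(heights, intercept, sigma))  # likelihood
  return lp

test_point = {"intercept": 100.0, "log_sigma": jnp.log(30.0)}
print(log_density(test_point))
```

Passed as `bx.Model(log_density=log_density, test_point=test_point)`, this model could then be tried with `mcmc.tfp_nuts` as in the issue; whether that avoids the `TypeError` has not been verified here.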

The traceback is as follows:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 22
     18 model.build()
     20 bx_model = bx.Model.from_pymc(model.backend.model)
---> 22 idata = bx_model.mcmc.tfp_nuts(seed=jax.random.key(0))
     24 az.summary(idata)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/bayeux/_src/mcmc/tfp.py:205, in _TFPBase.__call__(self, seed, **kwargs)
    194 initial_running_variance = [
    195     tfp.experimental.stats.sample_stats.RunningVariance.from_stats(
    196         num_samples=jnp.array(1, part.dtype),
    197         mean=jnp.zeros_like(part),
    198         variance=jnp.ones_like(part))
    199     for part in initial_transformed_position]
    201 # The public API expects a JointDistribution. Much of the above is adapted
    202 # from the source code for
    203 # `tfp.experimental.mcmc.windowed_adaptive_{nuts|hmc}`, but handling a raw
    204 # log density, and doing the structure flattening with `jax.tree_utils`.
--> 205 draws, trace = tfp.experimental.mcmc.windowed_sampling._do_sampling(
    206     kind=self.algorithm,
    207     proposal_kernel_kwargs=proposal_kernel_kwargs,
    208     dual_averaging_kwargs=dual_averaging_kwargs,
    209     num_draws=extra_parameters["num_draws"],
    210     num_burnin_steps=extra_parameters["num_adaptation_steps"],
    211     initial_position=initial_transformed_position,
    212     initial_running_variance=initial_running_variance,
    213     bijector=None,
    214     trace_fn=_TRACE_FNS[self.algorithm],
    215     return_final_kernel_results=False,
    216     chain_axis_names=None,
    217     shard_axis_names=None,
    218     seed=sample_key)
    220 draws = self.transform_fn(jax.tree_util.tree_unflatten(treedef, draws))
    221 if extra_parameters["return_pytree"]:

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/experimental/mcmc/windowed_sampling.py:551, in _do_sampling(kind, proposal_kernel_kwargs, dual_averaging_kwargs, num_draws, num_burnin_steps, initial_position, initial_running_variance, trace_fn, bijector, return_final_kernel_results, seed, chain_axis_names, shard_axis_names)
    543 """Sample from base HMC kernel."""
    544 kernel = make_windowed_adapt_kernel(
    545     kind=kind,
    546     proposal_kernel_kwargs=proposal_kernel_kwargs,
   (...)
    549     chain_axis_names=chain_axis_names,
    550     shard_axis_names=shard_axis_names)
--> 551 return sample.sample_chain(
    552     num_draws,
    553     initial_position,
    554     kernel=kernel,
    555     num_burnin_steps=num_burnin_steps,
    556     # pylint: disable=g-long-lambda
    557     trace_fn=lambda state, pkr: trace_fn(
    558         state, bijector, pkr.step <= dual_averaging_kwargs[
    559             'num_adaptation_steps'], pkr.inner_results.inner_results.
    560         inner_results),
    561     # pylint: enable=g-long-lambda
    562     return_final_kernel_results=return_final_kernel_results,
    563     seed=seed)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/mcmc/sample.py:359, in sample_chain(num_results, current_state, previous_kernel_results, kernel, num_burnin_steps, num_steps_between_results, trace_fn, return_final_kernel_results, parallel_iterations, seed, name)
    352   seed, next_state, current_kernel_results = loop_util.smart_for_loop(
    353       loop_num_iter=num_steps,
    354       body_fn=_seeded_one_step,
    355       initial_loop_vars=list(seed_state_and_results),
    356       parallel_iterations=parallel_iterations)
    357   return seed, next_state, current_kernel_results
--> 359 (_, _, final_kernel_results), (all_states, trace) = loop_util.trace_scan(
    360     loop_fn=_trace_scan_fn,
    361     initial_state=(seed, current_state, previous_kernel_results),
    362     elems=tf.one_hot(
    363         indices=0,
    364         depth=num_results,
    365         on_value=1 + num_burnin_steps,
    366         off_value=1 + num_steps_between_results,
    367         dtype=tf.int32),
    368     # pylint: disable=g-long-lambda
    369     trace_fn=lambda seed_state_and_results: (
    370         seed_state_and_results[1], trace_fn(*seed_state_and_results[1:])),
    371     # pylint: enable=g-long-lambda
    372     parallel_iterations=parallel_iterations)
    374 if return_final_kernel_results:
    375   return CheckpointableStatesAndTrace(
    376       all_states=all_states,
    377       trace=trace,
    378       final_kernel_results=final_kernel_results)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/loop_util.py:232, in trace_scan(loop_fn, initial_state, elems, trace_fn, trace_criterion_fn, static_trace_allocation_size, condition_fn, parallel_iterations, name)
    224   trace_arrays, num_steps_traced = ps.cond(
    225       trace_criterion_fn(state) if trace_criterion_fn else True,
    226       lambda: (trace_one_step(num_steps_traced, trace_arrays, state),  # pylint: disable=g-long-lambda
    227                num_steps_traced + 1),
    228       lambda: (trace_arrays, num_steps_traced))
    230   return i + 1, state, num_steps_traced, trace_arrays
--> 232 _, final_state, _, trace_arrays = tf.while_loop(
    233     cond=condition_fn if condition_fn is not None else lambda *_: True,
    234     body=_body,
    235     loop_vars=(0, initial_state, 0, trace_arrays),
    236     maximum_iterations=length,
    237     parallel_iterations=parallel_iterations)
    239 # unflatten
    240 stacked_trace = tf.nest.pack_sequence_as(
    241     initial_trace, [ta.stack() for ta in trace_arrays],
    242     expand_composites=True)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py:102, in _while_loop_jax(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
     99       args = pack_body(body(*args))
    100     return args, ()
--> 102   loop_vars, _ = lax.scan(
    103       override_body_fn, loop_vars, xs=None, length=maximum_iterations)
    104   return loop_vars
    105 else:

    [... skipping hidden 9 frame]

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py:99, in _while_loop_jax.<locals>.override_body_fn(args, _)
     96   args = lax.cond(c, args, lambda args: pack_body(body(*args)), args,
     97                   lambda args: args)
     98 elif sc:
---> 99   args = pack_body(body(*args))
    100 return args, ()

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/loop_util.py:222, in trace_scan.<locals>._body(i, state, num_steps_traced, trace_arrays)
    220 def _body(i, state, num_steps_traced, trace_arrays):
    221   elem = elems_array.read(i)
--> 222   state = loop_fn(state, elem)
    224   trace_arrays, num_steps_traced = ps.cond(
    225       trace_criterion_fn(state) if trace_criterion_fn else True,
    226       lambda: (trace_one_step(num_steps_traced, trace_arrays, state),  # pylint: disable=g-long-lambda
    227                num_steps_traced + 1),
    228       lambda: (trace_arrays, num_steps_traced))
    230   return i + 1, state, num_steps_traced, trace_arrays

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/mcmc/sample.py:352, in sample_chain.<locals>._trace_scan_fn(seed_state_and_results, num_steps)
    351 def _trace_scan_fn(seed_state_and_results, num_steps):
--> 352   seed, next_state, current_kernel_results = loop_util.smart_for_loop(
    353       loop_num_iter=num_steps,
    354       body_fn=_seeded_one_step,
    355       initial_loop_vars=list(seed_state_and_results),
    356       parallel_iterations=parallel_iterations)
    357   return seed, next_state, current_kernel_results

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/substrates/jax/internal/loop_util.py:111, in smart_for_loop(loop_num_iter, body_fn, initial_loop_vars, parallel_iterations, unroll_threshold, name)
    101 if (loop_num_iter_ is None
    102     or tf.executing_eagerly()
    103     # large values for loop_num_iter_ will cause ridiculously slow
   (...)
    108   # Cast to int32 to run the comparison against i in host memory,
    109   # where while/LoopCond needs it.
    110   loop_num_iter = tf.cast(loop_num_iter, dtype=tf.int32)
--> 111   return tf.while_loop(
    112       cond=lambda i, *args: i < loop_num_iter,
    113       body=lambda i, *args: [i + 1] + list(body_fn(*args)),
    114       loop_vars=[np.int32(0)] + initial_loop_vars,
    115       parallel_iterations=parallel_iterations
    116   )[1:]
    117 result = initial_loop_vars
    118 for _ in range(loop_num_iter_):

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/control_flow.py:90, in _while_loop_jax(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, maximum_iterations, name)
     88   def override_cond_fn(args):
     89     return cond(*args)
---> 90   return lax.while_loop(override_cond_fn, override_body_fn, loop_vars)
     91 elif back_prop:
     92   def override_body_fn(args, _):

    [... skipping hidden 4 frame]

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/linalg.py:102, in register_pytrees.<locals>.register.<locals>.unflatten(info, xs)
    100 keys, metadata = info
    101 parameters = dict(list(zip(keys, xs)), **metadata)
--> 102 return cls(**parameters)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_diag.py:171, in LinearOperatorDiag.__init__(self, diag, is_non_singular, is_self_adjoint, is_positive_definite, is_square, name)
    161 parameters = dict(
    162     diag=diag,
    163     is_non_singular=is_non_singular,
   (...)
    167     name=name
    168 )
    170 with ops.name_scope(name, values=[diag]):
--> 171   self._diag = linear_operator_util.convert_nonref_to_tensor(
    172       diag, name="diag")
    173   self._check_diag(self._diag)
    175   # Check and auto-set hints.

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/gen/linear_operator_util.py:134, in convert_nonref_to_tensor(value, dtype, dtype_hint, name)
    130     raise TypeError(
    131         f"Argument `value` must be of dtype `{dtype_name(dtype_base)}` "
    132         f"Received: `{dtype_name(value_dtype_base)}`.")
    133   return value
--> 134 return ops.convert_to_tensor(
    135     value, dtype=dtype, dtype_hint=dtype_hint, name=name
    136 )

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:167, in _convert_to_tensor(value, dtype, dtype_hint, name)
    164     pass
    166 if ret is None:
--> 167   ret = conversion_func(value, dtype=dtype)
    168 return ret

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:243, in _default_convert_to_tensor(value, dtype)
    240 # If no dtype is provided, we try the inferred dtype and fallback to int64 or
    241 # float32 depending on the type of conversion error we see.
    242 try:
--> 243   return _default_convert_to_tensor_with_dtype(value, inferred_dtype)
    244 except _Int64ToInt32Error as e:
    245   return np.array(value, dtype=np.int64)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:286, in _default_convert_to_tensor_with_dtype(value, dtype, error_if_mismatch)
    283 is_arraylike = hasattr(value, 'dtype')
    284 if is_arraylike:
    285   # Duck-typed for `onp.array`/`oonp.generic`
--> 286   arr = np.array(value)
    287   if dtype is not None:
    288     # arr.astype(None) forces conversion to float64
    289     return arr.astype(dtype)

File ~/mambaforge/envs/bambi_tf/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2158, in array(object, dtype, copy, order, ndmin)
   2151 out: ArrayLike
   2153 if all(not isinstance(leaf, Array) for leaf in leaves):
   2154   # TODO(jakevdp): falling back to numpy here fails to overflow for lists
   2155   # containing large integers; see discussion in
   2156   # https://github.com/google/jax/pull/6047. More correct would be to call
   2157   # coerce_to_array on each leaf, but this may have performance implications.
-> 2158   out = np.array(object, dtype=dtype, ndmin=ndmin, copy=False)  # type: ignore[arg-type]
   2159 elif isinstance(object, Array):
   2160   assert object.aval is not None

TypeError: float() argument must be a string or a real number, not 'ShapedArray'

numpyro models run in numpyro but not using bayeux

Here are a couple of examples of models that run in numpyro but not in bayeux. The first example runs but does not produce the correct answer. The second example does not run and raises shape errors related to the number of chains.

numpyro==0.15.0
bayeux-ml==0.1.12
import jax.numpy as jnp
from jax import random

import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC

N = 100
true_alpha = 1.1
true_sigma = 0.1

key = random.PRNGKey(0)
data = true_alpha + true_sigma * random.normal(key=key, shape=(N,))

def model():
	alpha = numpyro.sample("alpha", dist.Normal(0, 3))
	sigma = numpyro.sample("sigma", dist.HalfNormal(1))
	numpyro.sample("y", dist.Normal(alpha, sigma), obs=data)


# this runs fine, samples only from alpha and sigma and recovers the parameters
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.key(0))
mcmc.print_summary()

# this does not work and seems to sample from the observed sites
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0))

# it would also be nice to write the numpyro model as def model(data=None)
# and call bayeux as bx.Model.from_numpyro(model, data=data)
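The wish in the last comment, binding observed data rather than treating `y` as a site to sample, can be illustrated without `from_numpyro`. The following is a minimal pure-JAX sketch, assuming a hand-written log density for the first model; `log_density`, `bound_log_density` and the `log_sigma` parameterization are hypothetical illustration choices, not bayeux or numpyro API. Only the latent sites appear in the parameter pytree, the observed data is bound with `functools.partial`, and `sigma` lives on the log scale with its Jacobian term added so the density is defined everywhere.

```python
import functools

import jax
import jax.numpy as jnp
from jax.scipy import stats


def log_density(params, data):
    """Hypothetical hand-written log density for the alpha/sigma model.

    `params` holds only the latent sites; `data` is the observed `y`.
    """
    alpha = params["alpha"]
    log_sigma = params["log_sigma"]
    sigma = jnp.exp(log_sigma)
    # alpha ~ Normal(0, 3)
    lp = stats.norm.logpdf(alpha, 0.0, 3.0)
    # sigma ~ HalfNormal(1): density is 2 * Normal(sigma; 0, 1) for sigma > 0
    lp += jnp.log(2.0) + stats.norm.logpdf(sigma, 0.0, 1.0)
    # Jacobian of the change of variables sigma = exp(log_sigma)
    lp += log_sigma
    # y ~ Normal(alpha, sigma), observed: contributes a likelihood, not a parameter
    lp += jnp.sum(stats.norm.logpdf(data, alpha, sigma))
    return lp


# Simulated data, mirroring the issue's setup.
N = 100
true_alpha, true_sigma = 1.1, 0.1
data = true_alpha + true_sigma * jax.random.normal(jax.random.key(0), (N,))

# Bind the observed data so the remaining function takes only latent params.
bound_log_density = functools.partial(log_density, data=data)
test_point = {"alpha": 0.0, "log_sigma": 0.0}

print(bound_log_density(test_point))
print(jax.grad(bound_log_density)(test_point))
```

Under the same assumptions, `bx.Model(log_density=bound_log_density, test_point=test_point)` would follow the `log_density`/`test_point` pattern from the Quickstart; draws would then be of `log_sigma`, so exponentiate them to recover `sigma`.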
import jax.numpy as jnp
from jax import random

import arviz as az
import bayeux as bx
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC

N = 100
true_alpha = 1.1
true_sigma = 0.1
true_beta = 0.8

key = random.PRNGKey(0)
x = jnp.linspace(0, 1, N)
data = true_alpha + true_sigma * random.normal(key=key, shape=(N,)) + true_beta * x

def model():
	alpha = numpyro.sample("alpha", dist.Normal(0, 3))
	sigma = numpyro.sample("sigma", dist.HalfNormal(1))
	beta = numpyro.sample("beta", dist.Normal(0, 3))
	mu = alpha + beta * x
	numpyro.sample("y", dist.Normal(mu, sigma), obs=data)

# this runs fine and recovers the parameters
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=2)
mcmc.run(random.key(0))
mcmc.print_summary()

# this does not work
bx_model = bx.Model.from_numpyro(model)
idata = bx_model.mcmc.numpyro_nuts(seed=random.key(0), num_chains=2)

# mul got incompatible shapes for broadcasting: (2,), (100,).
# issue with multiple chains
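One plausible reading of the error above (an assumption, not confirmed in the thread) is that the log density gets evaluated on parameters that carry a leading chain axis, so `beta * x` multiplies a `(2,)` array by a `(100,)` array. The pure-JAX sketch below reproduces that failure with a hypothetical hand-written log density for the regression model (`regression_log_density` and the log-scale `sigma` are illustration choices) and shows that mapping it over the chain axis with `jax.vmap` yields one value per chain.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

# Simulated data, mirroring the issue's setup.
N = 100
x = jnp.linspace(0, 1, N)
y = 1.1 + 0.1 * jax.random.normal(jax.random.key(0), (N,)) + 0.8 * x


def regression_log_density(params):
    """Hypothetical log density for a single chain's (scalar) parameters."""
    alpha, beta = params["alpha"], params["beta"]
    sigma = jnp.exp(params["log_sigma"])
    mu = alpha + beta * x  # shape (N,) when alpha and beta are scalars
    lp = stats.norm.logpdf(alpha, 0.0, 3.0) + stats.norm.logpdf(beta, 0.0, 3.0)
    # HalfNormal(1) prior on sigma plus the Jacobian of sigma = exp(log_sigma)
    lp += jnp.log(2.0) + stats.norm.logpdf(sigma, 0.0, 1.0) + params["log_sigma"]
    lp += jnp.sum(stats.norm.logpdf(y, mu, sigma))
    return lp


num_chains = 2
chain_params = {
    "alpha": jnp.zeros(num_chains),
    "beta": jnp.zeros(num_chains),
    "log_sigma": jnp.zeros(num_chains),
}

# Calling the single-chain function on batched params broadcasts (2,) against (100,).
try:
    regression_log_density(chain_params)
except (TypeError, ValueError) as err:
    print("unbatched call failed:", err)

# Mapping over the leading chain axis evaluates each chain separately.
per_chain = jax.vmap(regression_log_density)(chain_params)
print(per_chain.shape)  # (2,)
```

The sketch only shows the shape mechanics behind the message; whether the chain axis leaks in through `from_numpyro` is not established in the issue.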

<class 'tensorflow.python.framework.ops.EagerTensor'> is not a valid JAX type

I get the error below when I run NUTS on this simple TFP model:

x = jnp.array([1.1, 2, 3])

@tfd.JointDistributionCoroutineAutoBatched
def model():
  mu = yield tfd.Normal(0, 1, name='mu')
  sigma = yield tfd.Gamma(1, 1, name='sigma')
  yield tfd.Normal(mu, sigma, name='observed')

bx_model = bx.Model.from_tfp(model.experimental_pin(observed=x))
idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))

Error:

TypeError: Value <tf.Tensor: shape=(8, 3), dtype=float32, numpy=
array([[-... ]], dtype=float32)> with type <class 'tensorflow.python.framework.ops.EagerTensor'> is not a valid JAX type
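`EagerTensor` is a TensorFlow type, which suggests (an assumption, since the snippet omits its imports) that `tfd` came from the TensorFlow backend of TFP (`import tensorflow_probability as tfp`) rather than the JAX substrate (`from tensorflow_probability.substrates import jax as tfp`). The sketch below is a hypothetical pre-flight check, not part of bayeux, that calls a log density on its test point and fails early with a clearer message when the result is not a JAX array. It uses NumPy as a stand-in for a non-JAX backend so it runs without TensorFlow installed.

```python
import jax
import jax.numpy as jnp
import numpy as np


def check_jax_log_density(log_density, test_point):
    """Hypothetical check: the log density must return a JAX array and trace under jit."""
    value = log_density(test_point)
    if not isinstance(value, jax.Array):
        raise TypeError(
            f"log_density returned {type(value).__name__}, not a JAX array; "
            "for TFP models, build them with the JAX substrate."
        )
    # Tracing under jit catches non-JAX operations that only fail when traced.
    jax.jit(log_density)(test_point)
    return value


def jax_density(x):
    return -0.5 * jnp.sum(x ** 2)


def numpy_density(x):
    # Stand-in for a density computed by a non-JAX backend.
    return -0.5 * np.sum(np.asarray(x) ** 2)


print(check_jax_log_density(jax_density, jnp.array([1.1, 2.0, 3.0])))
try:
    check_jax_log_density(numpy_density, jnp.array([1.1, 2.0, 3.0]))
except TypeError as err:
    print("rejected:", err)
```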

PyMC + Blackjax Fail with latest version 1.2.0

Hi 👋! I am trying to run the PyMC example with blackjax==1.2.0 and I get this error (it works fine with blackjax==1.1.1) 🥲

Should I open an issue in the blackjax repo as well? I am not sure where the error is coming from.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[1], line 24
     15     pm.Normal(
     16         "observed",
     17         avg_effect + avg_stddev * school_effects,
     18         treatment_stddevs,
     19         observed=treatment_effects,
     20     )
     22 bx_model = bx.Model.from_pymc(model)
---> 24 idata = bx_model.mcmc.blackjax_nuts(seed=jax.random.key(0))
     26 az.summary(idata)

File ~/micromamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/bayeux/_src/mcmc/blackjax.py:73, in _BlackjaxSampler.__call__(self, seed, **kwargs)
     71 def __call__(self, seed, **kwargs):
     72   init_key, sample_key = jax.random.split(seed)
---> 73   kwargs = self.get_kwargs(**kwargs)
     74   initial_state = self.get_initial_state(
     75       init_key, num_chains=kwargs["extra_parameters"]["num_chains"])
     77   return _sample_blackjax(
     78       initial_state=self.inverse_transform_fn(initial_state),
     79       algorithm=_ALGORITHMS[self.algorithm],
   (...)
     82       seed=sample_key,
     83       kwargs=kwargs)

File ~/micromamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/bayeux/_src/mcmc/blackjax.py:63, in _BlackjaxSampler.get_kwargs(self, **kwargs)
     61 extra_parameters = get_extra_kwargs(kwargs)
     62 constrained_log_density = self.constrained_log_density()
---> 63 adaptation_kwargs, run_kwargs = get_adaptation_kwargs(
     64     adapt_fn, algorithm, constrained_log_density, extra_parameters | kwargs)
     65 return {adapt_fn: adaptation_kwargs,
     66         "adapt.run": run_kwargs,
     67         algorithm: get_algorithm_kwargs(
     68             algorithm, constrained_log_density, kwargs),
     69         "extra_parameters": extra_parameters}

File ~/micromamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/bayeux/_src/mcmc/blackjax.py:260, in get_adaptation_kwargs(adaptation_algorithm, algorithm, log_density, kwargs)
    257   adaptation_required.remove("algorithm")
    258   adaptation_kwargs["algorithm"] = algorithm
    259   adaptation_kwargs = (
--> 260       get_algorithm_kwargs(algorithm, log_density, kwargs) | adaptation_kwargs
    261   )
    263 adaptation_required = adaptation_required - adaptation_kwargs.keys()
    265 if adaptation_required:

File ~/micromamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/bayeux/_src/mcmc/blackjax.py:310, in get_algorithm_kwargs(algorithm, log_density, kwargs)
    303 kwargs_with_defaults = {
    304     "logdensity_fn": log_density,
    305     "step_size": 0.5,
    306     "num_integration_steps": 16,
    307 } | kwargs
    308 shared.update_with_kwargs(
    309     algorithm_kwargs, reqd=algorithm_required, kwargs=kwargs_with_defaults)
--> 310 algorithm_required.remove("logdensity_fn")
    311 algorithm_required.discard("inverse_mass_matrix")
    312 algorithm_required.discard("alpha")

KeyError: 'logdensity_fn'
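The failing line is `algorithm_required.remove("logdensity_fn")`. `set.remove` raises `KeyError` when the element is absent, whereas the `discard` calls on the next two lines do not, so under blackjax 1.2.0 the name `logdensity_fn` was presumably missing from the set of parameter names bayeux computed for the algorithm. The sketch below reproduces that pattern, assuming the required names are read from a function signature with `inspect.signature`; `required_params`, `old_style` and `new_style` are hypothetical stand-ins, not BlackJAX's real signatures. Given the report that blackjax==1.1.1 works, pinning that version is a plausible stopgap.

```python
import inspect


def required_params(fn):
    """Hypothetical helper: names of parameters that have no default."""
    return {
        name
        for name, p in inspect.signature(fn).parameters.items()
        if p.default is inspect.Parameter.empty
    }


# Hypothetical stand-ins for two versions of an algorithm's constructor.
def old_style(logdensity_fn, step_size, inverse_mass_matrix):
    ...


def new_style(step_size, inverse_mass_matrix):
    ...


for fn in (old_style, new_style):
    required = required_params(fn)
    required.discard("inverse_mass_matrix")  # never raises, even if absent
    try:
        required.remove("logdensity_fn")  # raises KeyError if the name is absent
        print(fn.__name__, "ok:", sorted(required))
    except KeyError as err:
        print(fn.__name__, "KeyError:", err)
```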