larslorch / dibs

DiBS: Differentiable Bayesian Structure Learning, NeurIPS 2021

Home Page: https://arxiv.org/abs/2105.11839

License: MIT License

Python 100.00%
bayesian-inference bayesian-networks variational-inference structure-learning causal-inference

DiBS: Differentiable Bayesian Structure Learning


Overview

This is the Python JAX implementation of DiBS (Lorch et al., 2021), a fully differentiable method for joint Bayesian inference of the DAG and parameters of general causal Bayesian networks. In this implementation, DiBS inference is performed with SVGD (Liu and Wang, 2016). Because DiBS and SVGD operate on continuous tensors and rely solely on Monte Carlo estimation and gradient-ascent-like updates, the inference code leverages efficient vectorized operations, automatic differentiation, just-in-time compilation, and hardware acceleration, all implemented with JAX.
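To give a feel for the kind of update SVGD performs, here is a minimal sketch of a single SVGD step in JAX on a toy standard-Gaussian target. It is not taken from the dibs code base: the names rbf_kernel and svgd_step, the fixed kernel bandwidth, and the toy target are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def rbf_kernel(x, y, bandwidth=1.0):
    # k(x, y) = exp(-||x - y||^2 / bandwidth); fixed bandwidth is an assumption
    return jnp.exp(-jnp.sum((x - y) ** 2) / bandwidth)


def svgd_step(particles, log_prob, step_size=0.1, bandwidth=1.0):
    """One SVGD update for particles of shape [n, dim] (illustrative sketch)."""
    n = particles.shape[0]
    k_fn = lambda xj, xi: rbf_kernel(xj, xi, bandwidth)

    # score of the target at every particle: [n, dim]
    scores = jax.vmap(jax.grad(log_prob))(particles)

    # kmat[j, i] = k(x_j, x_i): [n, n]
    kmat = jax.vmap(lambda xj: jax.vmap(lambda xi: k_fn(xj, xi))(particles))(particles)

    # kgrad[j, i] = gradient of k(x_j, x_i) w.r.t. x_j: [n, n, dim]
    kgrad = jax.vmap(
        lambda xj: jax.vmap(lambda xi: jax.grad(k_fn, argnums=0)(xj, xi))(particles)
    )(particles)

    # phi[i] = 1/n * sum_j ( k(x_j, x_i) * score(x_j) + grad_{x_j} k(x_j, x_i) )
    phi = (kmat.T @ scores + kgrad.sum(axis=0)) / n
    return particles + step_size * phi


# toy target: standard Gaussian log-density (up to a constant)
log_prob = lambda x: -0.5 * jnp.sum(x ** 2)

key = jax.random.PRNGKey(0)
particles = 3.0 + 0.5 * jax.random.normal(key, (10, 2))  # start away from the mode

step = jax.jit(lambda p: svgd_step(p, log_prob))
updated = particles
for _ in range(50):
    updated = step(updated)
```

The first term pulls particles toward high-density regions, and the kernel-gradient term pushes them apart so that they do not collapse onto a single mode.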

To install the latest stable release, run:

pip install dibs-lib

The documentation is linked here: Documentation.

Quick Start

The following code snippet demonstrates how to use the dibs package. In this example, we use DiBS to generate 10 DAG and parameter samples from the joint posterior over Gaussian Bayes nets with means modeled by neural networks.

from dibs.inference import JointDiBS
from dibs.target import make_nonlinear_gaussian_model
import jax.random as random
key = random.PRNGKey(0)

# simulate some data
key, subk = random.split(key)
data, graph_model, likelihood_model = make_nonlinear_gaussian_model(key=subk, n_vars=20)

# sample 10 DAG and parameter particles from the joint posterior
dibs = JointDiBS(x=data.x, interv_mask=None, graph_model=graph_model, likelihood_model=likelihood_model)
key, subk = random.split(key)
gs, thetas = dibs.sample(key=subk, n_particles=10, steps=1000)

The argument x for JointDiBS is a matrix of shape [N, d] and could be any real-world data set. interv_mask is a binary mask of the same shape that indicates whether or not a variable was intervened upon in a given sample (interv_mask=None indicates observational data and is equivalent to interv_mask=jax.numpy.zeros_like(x)).
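As a small illustration of the mask layout described above, the following sketch builds an observational mask and a mask for a hypothetical data set in which one variable was intervened upon in a few samples. The data, sizes, and the choice of intervened variable are made up for the example.

```python
import jax.numpy as jnp
import jax.random as random

key = random.PRNGKey(0)
n_samples, n_vars = 6, 4

# hypothetical data matrix of shape [N, d]
x = random.normal(key, (n_samples, n_vars))

# purely observational data: an all-zero mask of the same shape as x
obs_mask = jnp.zeros_like(x)

# assumption for this example: variable 2 was intervened upon
# in the last three samples, so those entries are set to 1
interv_mask = jnp.zeros_like(x).at[3:, 2].set(1)
```

A mask built this way would then be passed as interv_mask=interv_mask in place of interv_mask=None in the Quick Start snippet.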

Example Notebooks

Try out a working example notebook in Google Colab, which runs directly in your browser. Whenever a GPU backend is available to JAX, dibs will automatically leverage it to accelerate its computations, so you can select the free GPU runtime available in Google Colab for a speed-up.

Analogous notebooks can be found inside the examples/ folder. Executing the code will generate samples from the joint posterior with DiBS and simultaneously visualize the matrices of edge probabilities modeled by the individual particles that are transported by SVGD during inference.
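For intuition about what such a matrix of edge probabilities looks like, here is a minimal sketch based on the formulation in the DiBS paper, where each particle is a latent pair of embeddings (u, v) and the probability of edge i -> j is a scaled sigmoid of the inner product of u_i and v_j. This is not the library's API: the function name edge_probs, the array layout [d, k, 2], and the value of alpha are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def edge_probs(z, alpha=1.0):
    """Edge probability matrix for one latent particle z of shape [d, k, 2]."""
    u, v = z[..., 0], z[..., 1]           # each of shape [d, k]
    scores = alpha * (u @ v.T)            # inner products u_i . v_j, shape [d, d]
    probs = jax.nn.sigmoid(scores)
    # no self-loops: zero out the diagonal
    return probs * (1.0 - jnp.eye(probs.shape[0]))


key = jax.random.PRNGKey(0)
n_particles, d, k = 3, 5, 5
z = jax.random.normal(key, (n_particles, d, k, 2))  # hypothetical particles

# one [d, d] matrix of edge probabilities per particle
probs = jax.vmap(edge_probs)(z)
```

Plotting each probs[i] as a heatmap over the course of inference gives a picture similar in spirit to the visualization described above.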





Change Log

  • 4 Jul 2022: Inference from interventional data via the interventional log (marginal) likelihood, assuming known, hard interventions. To model soft or random interventions, the likelihoods in the model classes in dibs/models/ can be easily modified.

  • 14 Mar 2022: Published to PyPI

  • 12 Mar 2022: Extended the BGe marginal likelihood to be well-defined inside the probability simplex. The computation remains exact for binary entries and is also well-behaved for soft relaxations of the graph. This allows reparameterization (Gumbel-softmax) gradient estimation for the BGe score.

  • 14 Dec 2021: Documentation added
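The 12 Mar 2022 entry mentions soft relaxations of the graph and Gumbel-softmax gradient estimation. The sketch below shows the general idea for binary edge variables, using the standard logistic-noise (binary Gumbel-softmax) relaxation. It is a generic illustration rather than the code used in dibs; the function name soft_edge_sample and the temperature tau are assumptions.

```python
import jax
import jax.numpy as jnp


def soft_edge_sample(key, logits, tau=1.0):
    """Relaxed Bernoulli sample in [0, 1] per edge, differentiable in logits."""
    # Logistic noise is the difference of two Gumbel variables
    noise = jax.random.logistic(key, logits.shape)
    return jax.nn.sigmoid((logits + noise) / tau)


key = jax.random.PRNGKey(0)
logits = jnp.zeros((4, 4))  # hypothetical edge logits for a 4-variable graph

# soft adjacency matrix with entries inside the unit interval
soft_g = soft_edge_sample(key, logits, tau=0.5)

# because the noise does not depend on the logits, gradients flow through the sample
grad_logits = jax.grad(lambda l: soft_edge_sample(key, l, tau=0.5).sum())(logits)
```

A score that is well-defined for such non-binary entries can be differentiated through the sample, which is what makes reparameterization gradients possible; as tau decreases, the samples concentrate near 0 and 1.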

Branches and Custom Installation

The repository consists of two branches:

  • master (recommended, on PyPI): Lightweight and easy-to-use package for using DiBS in your research or applications.
  • full: Comprehensive code to reproduce the experimental results in Lorch et al. (2021). The purpose of this branch is reproducibility; the branch is no longer updated and may contain outdated notation and documentation.

The latest stable release is published on PyPI, so the best way to install dibs is using pip as shown above. For custom installations, first clone the code repository:

git clone https://github.com/larslorch/dibs.git

We recommend using conda and generating a new environment via environment.yml. Finally, install the dibs package with

pip install -e .

Reference

@article{lorch2021dibs,
  title={DiBS: Differentiable Bayesian Structure Learning},
  author={Lorch, Lars and Rothfuss, Jonas and Sch{\"o}lkopf, Bernhard and Krause, Andreas},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}

Contributors

haeggee, larslorch, mnazaal, st--
dibs's Issues

Consider referring to igraph instead of python-igraph

Please consider referring to igraph instead of python-igraph in the following three locations:

https://github.com/search?q=repo%3Alarslorch%2Fdibs%20python-igraph&type=code

and dropping the pin to version 0.8.3, which is extremely outdated. Behaviour changes since then are mainly the result of bugfixes. If there is a behaviour change that affects this project, feel free to open an issue in the python-igraph repo and ask about it.

The python-igraph name is deprecated on PyPI and will soon cease to receive updates. See igraph/python-igraph#699 for details. However, the first version available under the igraph name is 0.9.8 from October 2021. 0.8.3 (current pin) is not available under this name.
