sbi-dev / sbi

Simulation-based inference toolkit

Home Page: https://sbi-dev.github.io/sbi/

License: Apache License 2.0

Topics: simulation-based-inference, likelihood-free-inference, bayesian-inference, parameter-estimation, pytorch, machine-learning

sbi's Introduction


sbi: simulation-based inference

Getting Started | Documentation

sbi is a PyTorch package for simulation-based inference. Simulation-based inference is the process of inferring the parameters of a simulator from observations of its output.

sbi takes a Bayesian approach and returns a full posterior distribution over the parameters of the simulator, conditional on the observations. The package implements a variety of inference algorithms, including amortized and sequential methods. Amortized methods return a posterior that can be applied to many different observations without retraining; sequential methods focus the inference on one particular observation to be more simulation-efficient. See below for an overview of implemented methods.

sbi offers a simple interface for posterior inference in a few lines of code:

from sbi.inference import SNPE
# import your simulator, define your prior over the parameters
# sample parameters theta and observations x
inference = SNPE(prior=prior)
_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()
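For readers who want to run this end to end, here is a minimal, self-contained sketch; the toy simulator, the uniform prior, and the observation x_o below are illustrative assumptions, not part of the snippet above.

import torch
from sbi.inference import SNPE
from sbi.utils import BoxUniform

# Illustrative setup: a 3-dimensional uniform prior and a simulator that
# adds Gaussian noise to the parameters.
prior = BoxUniform(low=-2 * torch.ones(3), high=2 * torch.ones(3))

def simulator(theta):
    return theta + 0.1 * torch.randn_like(theta)

# Sample parameters from the prior and simulate corresponding observations.
theta = prior.sample((1000,))
x = simulator(theta)

# Train the density estimator and build the posterior.
inference = SNPE(prior=prior)
_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()

# Draw posterior samples for a (hypothetical) observation x_o.
x_o = torch.zeros(1, 3)
samples = posterior.sample((100,), x=x_o)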

Installation

sbi requires Python 3.8 or higher. A GPU is not required, but can speed up training in some cases. We recommend using a conda virtual environment (Miniconda installation instructions). If conda is installed on the system, an environment for installing sbi can be created as follows:

# Create an environment for sbi (indicate Python 3.8 or higher); activate it
$ conda create -n sbi_env python=3.10 && conda activate sbi_env

Whether or not you are using conda, sbi can be installed using pip:

pip install sbi

To test the installation, drop into a Python prompt and run:

from sbi.examples.minimal import simple
posterior = simple()
print(posterior)

Tutorials

If you're new to sbi, we recommend starting with our Getting Started tutorial.

You can easily access and run these tutorials by opening a Codespace on this repo. To do so, click on the green "Code" button and select "Open with Codespaces". This will provide you with a fully functional environment where you can run the tutorials as Jupyter notebooks.

Inference Algorithms

The following inference algorithms are currently available. You can find instructions on how to run each of these methods here.

Neural Posterior Estimation: amortized (NPE) and sequential (SNPE)

Neural Likelihood Estimation: amortized (NLE) and sequential (SNLE)

Neural Ratio Estimation: amortized (NRE) and sequential (SNRE)

Neural Variational Inference: amortized (NVI) and sequential (SNVI)

Mixed Neural Likelihood Estimation (MNLE)
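All of the above share the interface from the code snippet in the introduction. As a hedged sketch (reusing the prior, theta, and x from that snippet), switching methods only changes the imported class; for likelihood- and ratio-based methods, posterior sampling then typically runs MCMC internally:

from sbi.inference import SNLE

inference = SNLE(prior=prior)
_ = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()  # sampling uses MCMC under the hood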

Feedback and Contributions

We welcome any feedback on how sbi is working for your inference problems (see Discussions) and are happy to receive bug reports, pull requests, and other feedback (see contribute). We wish to maintain a positive community; please read our Code of Conduct.

Acknowledgments

sbi is the successor (using PyTorch) of the delfi package. It started as a fork of Conor M. Durkan's lfi. sbi runs as a community project. See also credits.

Support

sbi has been supported by the German Federal Ministry of Education and Research (BMBF) through project ADIMEM (FKZ 01IS18052 A-D), project SiMaLeSAM (FKZ 01IS21055A) and the Tübingen AI Center (FKZ 01IS18039A).

License

Apache License Version 2.0 (Apache-2.0)

Citation

If you use sbi, consider citing the sbi software paper, in addition to the original research articles describing the specific sbi algorithm(s) you are using.

@article{tejero-cantero2020sbi,
  doi = {10.21105/joss.02505},
  url = {https://doi.org/10.21105/joss.02505},
  year = {2020},
  publisher = {The Open Journal},
  volume = {5},
  number = {52},
  pages = {2505},
  author = {Alvaro Tejero-Cantero and Jan Boelts and Michael Deistler and Jan-Matthis Lueckmann and Conor Durkan and Pedro J. Gonçalves and David S. Greenberg and Jakob H. Macke},
  title = {sbi: A toolkit for simulation-based inference},
  journal = {Journal of Open Source Software}
}

The above citation refers to the original version of the sbi project and has a persistent DOI. Additionally, new releases of sbi are citable via Zenodo, where we create a new DOI for every release.


sbi's Issues

SNPE-A

Problem: we do not yet have an implementation of the post-hoc correction for SNPE-A.

  • We will have to see which features we require from the MDN itself and rewrite the MDN class accordingly.
  • Then implement the post-hoc correction.
  • Add functional tests on the linear and nonlinear Gaussian.

Use types pervasively

Problem

We use types in some function signatures for documentation, but unless one has IDE-level checking, they are not otherwise taken advantage of.

Solution

  • use types pervasively (see the sketch below)
  • activate mypy in the CI pipeline (GitHub Actions)
  • see also #75 for the impact on docstrings
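As an illustration (the helper name below is hypothetical), pervasive typing means annotating every argument and return value, so that a CI run of, e.g., mypy sbi/ can verify call sites even without IDE-level checking:

from torch import Tensor

def repeat_observation(x_o: Tensor, num_samples: int) -> Tensor:
    # With full annotations, mypy flags callers that pass, e.g., a numpy
    # array or a wrongly typed num_samples.
    return x_o.repeat(num_samples, 1)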

SRE evaluating posterior density

Problem
Calling log_prob() on a Posterior instance fit with SRE fails at unnormalized_log_prob = self.neural_net.log_prob(inputs, context) -- a classifier network does not implement log_prob.

As described in Durkan et al. (2019), SRE can evaluate the posterior if trained in a single round, but not if trained over multiple rounds.


Proposed solution

  • Add an attribute self.round to Posterior in sbi_posterior.py and assert at the beginning of log_prob() that, for SRE, evaluation only happens if self.round == 1.

  • Additionally, we still have to implement the equation for evaluating the probability: p(theta|x) = r(theta, x) * p(theta). A sketch combining both points follows below.
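A minimal sketch of the proposed check and evaluation, assuming the classifier's logit approximates log r(theta, x) after single-round training (attribute names follow the issue text; the surrounding Posterior class is assumed):

import torch

def log_prob(self, inputs: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
    # Only valid after a single round of SRE training.
    assert self.round == 1, "SRE posterior evaluation requires single-round training."
    # In log-space: log p(theta|x) = log r(theta, x) + log p(theta).
    log_ratio = self.neural_net(torch.cat((inputs, context), dim=-1)).squeeze(-1)
    return log_ratio + self._prior.log_prob(inputs)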

Inconsistent indication of types in docstrings

Problem

We have a variety of styles, including

  • Pyro-style (warranted where looking forward to near-term integration in Pyro)
  • name: type, explanation (for example - inputs: torch.tensor(), parameters theta)
  • name (type): explanation (for example - num_samples (int): desired number of samples). This is what PyTorch uses
  • variants of the above using braces { } or square brackets [ ] for the type

Suggested solution

Adopt the PyTorch convention across the board, or eliminate types from docstrings completely, given that we are going to specify them in the signature anyway. The latter is what thinc does (see its loss.py for an example) and would be @Meteore's favorite. An example of the PyTorch convention is sketched below.
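For reference, the PyTorch convention applied to a hypothetical function would look like:

from torch import Tensor

def sample_posterior(num_samples: int, x_o: Tensor) -> Tensor:
    """Draw samples from the posterior (hypothetical example).

    Args:
        num_samples (int): Desired number of samples.
        x_o (Tensor): Observation to condition on, of shape (1, data_dim).

    Returns:
        Tensor: Posterior samples of shape (num_samples, parameter_dim).
    """
    ...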

Action

  1. @jan-matthis @janfb @michaeldeistler: argue & vote in the comments.
  2. Once decided, adjust, if necessary, the plugin that generates the docstring stubs for you.

FYI I get PyTorch-style docstring templates from VSCode's plugin autoDocstring by Nils Werner (0.4.0).

More functional tests

Add more functional tests.

For example, currently, all inference algorithms are tested purely with MAFs. We should ensure that everything also works with MDNs, e.g. with a single parameter or data dimension, uniform/Gaussian priors, and MCMC sampling.

SNPE-B bug importance weights

There is a bug in SNPE-B: the importance weights are in log-space, i.e., they are added to the log-probability instead of being exponentiated and multiplied.

In line 116 in inference/snpe/snpe_b, instead of:

log_prob = self.calibration_kernel(context) * (
    log_prob_posterior + log_prob_prior - log_prob_proposal
)

we need

log_prob = self.calibration_kernel(context) * (
    log_prob_posterior * torch.exp(log_prob_prior - log_prob_proposal)
)

z-scoring for MDNs

Problem

Currently, z-scoring of the observations x is not implemented for MDNs in SNPE.

Background

  • The z-scoring of observations is implemented in SnpeBase. There, we build an nn.Sequential called embedding, which consists of a z-scoring layer and a neural network to encode the observations x, see here.
  • embedding is then set as an attribute of the neural_net itself (termed _neural_posterior.embedding_net), which is called at the beginning of every network pass, see here.
  • Through this, z-scoring is part of the network and therefore needs to be built into MDNs.

Solution

  • Once MDNs have an attribute embedding_net, we only have to get rid of
    if not isinstance(self._neural_posterior.neural_net, MultivariateGaussianMDN): here in SnpeBase.
  • Then, the function set_embedding_net() in the Posterior class should set the embedding net to the nn.Sequential described above (sketched below) and everything should work.
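A minimal sketch of that nn.Sequential; the standardization layer is written out inline here, and x_mean, x_std, and encoder_net are assumed to exist (e.g. estimated from pilot simulations):

import torch
from torch import nn

class Standardize(nn.Module):
    # Z-scores inputs with fixed mean and standard deviation.
    def __init__(self, mean: torch.Tensor, std: torch.Tensor):
        super().__init__()
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / self.std

# The embedding called at the beginning of every network pass: z-scoring
# followed by a network encoding the observations x.
embedding = nn.Sequential(Standardize(x_mean, x_std), encoder_net)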

Simulator base methods

Right now, simulators implement a number of methods. Besides simulate, these include get_ground_truth_parameters, get_ground_truth_observation, normalization_parameters, and parameter_plotting_limits.

Ideally, however, a simulator can simply be a function that returns outputs given parameters. Ground truth is not always available, and normalization and plotting parameters should be inferred automatically.

SNPE-B functional test on nonlinear Gaussian

We currently have a functional test for SNPE-B on the linearGaussian, but we also want a test on the nonlinearGaussian. The current blocker is loading the target samples due to a pickle error, see #86.

Repeated boilerplate across Inference classes: derive from common parent

Problem

Some of the argument checking across SnpeBase, SNL and SRE is repeated code.

Solution

Refactor that commonality and others into a common parent class, NeuralInference.

Follow-up

Each of SNPE-A, -B, and -C overrides/defines only one method, _get_log_prob_proposal_posterior(self, inputs, context, masks); consider dealing with it in SnpeBase directly, unless we dispatch elsewhere on the concrete class.

Port summarizer for Lotka-Volterra from numpy to PyTorch

The modules lotka_volterra_pytorch and mg1_pytorch both import from a non-checked-in module, summarizers. By default they will try to summarize observations (https://github.com/mackelab/sbi/blob/7f62e3c23a1cc56ffd92f6355636e61934c80881/lfi/simulators/lotka_volterra_pytorch.py#L74, https://github.com/mackelab/sbi/blob/7f62e3c23a1cc56ffd92f6355636e61934c80881/lfi/simulators/mg1_pytorch.py#L23) and fail.

These PyTorch simulator variants do not seem to be in use at the moment.

User input checks

Problem

It is not clear what we currently expect from the user, e.g., what the requirements on the prior, the simulator, and the observed data are:

  • should they return numpy arrays or torch Tensors?
  • should batch dimensions always be used?

Suggestion

Require the user to provide three inputs with certain properties:

  • prior: a torch.distributions.Distribution-like object with methods .log_prob() and .sample(), with signatures as in PyTorch
  • simulator: a function that takes a torch.Tensor and returns a torch.Tensor, both with a batch dimension, even if it is only a single simulation
  • observed data: a torch.Tensor with shape [1, data_dim]
  • we should assert these properties at the beginning; see the sketch below

This will make the code more readable because we will not have to take care of numpy arrays, appending dimensions, etc.
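A minimal sketch of such checks (the helper name and error messages are hypothetical):

from torch import Tensor
from torch.distributions import Distribution

def check_user_inputs(prior: Distribution, simulator, x_o: Tensor) -> None:
    theta = prior.sample((2,))  # prior must support batched sampling
    assert isinstance(prior.log_prob(theta), Tensor), "prior must implement log_prob"
    x = simulator(theta)  # simulator must map batched Tensors to batched Tensors
    assert isinstance(x, Tensor) and x.shape[0] == 2, "simulator must keep the batch dimension"
    assert x_o.ndim == 2 and x_o.shape[0] == 1, "observed data must have shape [1, data_dim]"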

Add functional tests for SNPE B

We need tests for SNPE-A and SNPE-B on simple examples like the linear Gaussian and the nonlinear Gaussian.

Similar to the tests in tests/test_linearGaussian_apt.py.

Normalizing training samples

Normalization is hard-coded in the simulators at the moment; we want to change that:

  • run a set of pilot simulations at the beginning and use them for normalization (sketched below)
  • normalization should happen during inference, independent of the simulator
  • ideally, the simulator is just a function taking parameters and returning data
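A sketch of the pilot phase; prior and simulator are assumed to follow the contract discussed in "User input checks" above:

# Run pilot simulations once, before training, to estimate normalization
# statistics independently of the simulator implementation.
theta_pilot = prior.sample((100,))
x_pilot = simulator(theta_pilot)
x_mean = x_pilot.mean(dim=0)
x_std = x_pilot.std(dim=0)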

Return Posterior

After training, we should return the posterior and not require the user to access it via the inference object.

We depend on a custom numpy implementation of slice sampling

Problem

For performance reasons, we are using a home-grown slice sampler written by Conor in numpy. We do not know whether this is strictly necessary, or why.

Solution

Benchmark slice-np against the Pyro samplers and determine whether, and under which precise circumstances, there is a performance gap between slice-pyro and slice-np. If possible, drop our custom sampler in favor of Pyro's sampling methods, which are more capable and will be externally maintained. @jan-matthis

Follow-up work

  • If slice-np needs to stay, make the mcmc_method string explicit about pyro vs. numpy (e.g. slice_pyro, not slice; underscores facilitate identification in parameterized pytests).
  • Refactor the sampling code; potentially, the Posterior object would receive a sampler and not just a potential_function. @Meteore

Pickle error with nonlinear Gaussian tests

Problem: On macOS (not sure about Linux), I cannot load the ground-truth samples for the nonlinear Gaussian; I get the error ValueError: Cannot load file containing pickled data when allow_pickle=False. When I explicitly set allow_pickle=True, I get an error saying that the npy file could not be identified as a pickle file.

Move ABC methods

  • For creating reference posteriors for functional tests, we should have rejection ABC and SMC-ABC in sbi as well.

MCMC samplers not robust to additional dimension in true_observation

Problem: When generating an inference object (no matter whether SNPE, SRE, or SNL), the user needs to provide a true_observation. In the current implementation, the MCMC samplers require the observation to be of shape (num_dim). If the user instead provides (1, num_dim), the MCMC sampler in SRE will fail because it cannot concatenate (theta, x) (which is the input to the classifier).

Suggestion: In sbi_posterior, at the beginning of the sampling method, we check the shape of context (which corresponds to true_observation in the inference classes) and squeeze the vector if necessary, as sketched below.
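A sketch of that check, placed at the top of the sampling method (context corresponds to true_observation):

# Accept both (num_dim,) and (1, num_dim): squeeze a leading singleton
# batch dimension before handing the context to the MCMC sampler.
if context.ndim == 2 and context.shape[0] == 1:
    context = context.squeeze(0)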

Robustness in summary writer

Not all simulators have ground-truth parameters available for tracking in TensorBoard. We should update the summary writers to fail gracefully when these are not available.

Adopt pytorch-lightning

Problem

We don't want to write boilerplate logging code, multi-gpu, etc.

Solution

Adopt pytorch-lightning. We decided on lightning because we see it as a more easily reversible experiment compared with competitors such as Ignite or Catalyst.

Follow-up

Deal with verbose destructuring via a specific function or a module-level setting of the preferred device.

Functional tests

A couple of open questions for functional tests:

  • what is the maximum time they should take?
  • at the moment, we run tests only for APT; should we also run them for SRE and SNL?

Configuration management

Problem

We want to have flexibility to shift towards declarative configuration management. Some benefits:

  • expose every hyperparameter to enhance reproducibility
  • build a library/zoo of models
  • facilitate the creation of user models by modification of existing ones.

Solution

We considered (a) rolling our own and (b) Google's GIN.
Option (a) would lead us to reinvent, in particular, the function-call syntax that thinc and GIN already provide, and would create a non-core code liability.
Option (b) seems too complex / less well adapted to our objective.
Decision: adopt the config module of thinc, possibly without depending on the entirety of thinc.
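As a hedged sketch of what adopting thinc's config module could look like (the section and key names below are invented for illustration):

from thinc.api import Config

# A declarative experiment description: every hyperparameter is exposed.
CONFIG_STR = """
[density_estimator]
flow = "maf"
hidden_features = 50

[training]
learning_rate = 0.0005
batch_size = 50
"""

config = Config().from_str(CONFIG_STR)
print(config["training"]["learning_rate"])  # 0.0005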

Follow-up work

  • deprecate get_sbi_posterior in favor of the config mechanism; consider a rename (density_estimator).

NOTE: we may want to start at pyknós level, defining nflows models via thinc.

Delete broken example notebooks

Problem

Refactorings resulted in non-working notebooks.

Solution

Make the notebooks conform to the new interfaces once these have stabilized.

  • Define a milestone for (somewhat) stable interfaces.

SNPE-C non-atomic

We want the APT loss for MoGs without atomic proposals. That is, we use Gaussian formulas etc. to solve the integrals defining Z in closed form; see Appendix A.1 of https://arxiv.org/pdf/1905.07488.pdf.

We want to use MoG proposals and MoG posteriors. Supporting single-component Gaussians with simpler/faster code might also be desirable.
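The key identity is standard Gaussian algebra, stated here for reference: the product of two Gaussian densities is an unnormalized Gaussian, so the normalizer is available in closed form, and for mixtures this applies component by component:

N(theta; m_1, S_1) * N(theta; m_2, S_2) = Z * N(theta; m, S), where
S = (S_1^{-1} + S_2^{-1})^{-1},
m = S * (S_1^{-1} m_1 + S_2^{-1} m_2),
Z = N(m_1; m_2, S_1 + S_2).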

Multiprocessing and MCMC sampling

Problem: When using the Pyro MCMC samplers, there is an option to parallelize over multiple MCMC chains via the variable num_chains in sbi_posterior. By default, this variable is set to the available number of cores.

However, there is an issue with the multiprocessing package: the code must be started from within a main guard, i.e.,

if __name__ == '__main__':

When using pytest, this is not done and the tests break.

Troubleshooting: The scripts run through when either launched from a main guard as written above, or when run from a Jupyter notebook (or JupyterLab).

Suggestion: I believe we have two options: 1) We expose the number of chains as a parameter and set num_chains=1 when using pytest; in most other cases, users should use either Jupyter or a main guard.
2) We call a main from pytest; I have not tested whether this can be done.
As an aside, it would be best if we could detect when users call the function neither from a notebook nor from a main() and issue a warning, but I am not sure how one would do that.

Other issues: 1) self.true_observation needs different shapes when using a single chain (batch dimension needed) than when using multiple chains (no batch dimension needed).
2) In the latest commit, the MMD sometimes exceeds the threshold when using multiple chains; probably a bug there.
3) potential_function should consistently be named potential_fn for all Pyro samplers (currently wrong for the slice sampler).
