mackelab / amortized-decision-making

Repository for the paper "Amortized Bayesian Decision Making for simulation-based models" - Mila @milagorecki and Michael @michaeldeistler


Amortized Bayesian Decision Making for Simulation-based Models

This repository provides the implementation used for the paper Amortized Bayesian Decision Making for simulation-based models. For the full git commit history with correct timestamps, see the `paper` branch.

In this work we address the question of how to perform Bayesian decision making on stochastic simulators, and how one can circumvent the need to compute an explicit approximation to the posterior. We propose two methods to obtain amortized Bayesian decisions:

(1) NPE-MC uses the parametric posterior returned by neural posterior estimation (NPE, see sbi) and computes the expected cost of decisions by Monte Carlo sampling from this posterior approximation.

(2) BAM circumvents the need to learn the (potentially high-dimensional) posterior distribution explicitly. Instead, it only requires training a feedforward neural network to directly predict the expected costs for any data and action.
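For intuition, the NPE-MC decision step can be sketched as follows. This is not the repository's code: the squared-error cost, the Gaussian stand-in for samples from a trained NPE posterior, the action grid, and the function names are all hypothetical. The expected cost of each candidate action is estimated by averaging the cost over posterior samples, and the action with the lowest estimate is chosen. BAM replaces this Monte Carlo average with a single forward pass of its trained network for each (data, action) pair.

```python
import numpy as np


def cost(theta, a):
    # Hypothetical cost: squared distance between parameter and action.
    return (theta - a) ** 2


def npe_mc_expected_cost(posterior_samples, actions):
    """Monte Carlo estimate of E_{p(theta|x)}[c(theta, a)] for every candidate action.

    posterior_samples: shape (n_samples,), draws from a posterior approximation q(theta | x)
    actions: shape (n_actions,), candidate actions
    """
    # Broadcast to (n_samples, n_actions), then average over the posterior samples.
    costs = cost(posterior_samples[:, None], actions[None, :])
    return costs.mean(axis=0)


rng = np.random.default_rng(0)
# Hypothetical stand-in for samples drawn from a trained NPE posterior at an observation x_o.
samples = rng.normal(loc=2.0, scale=1.0, size=10_000)
actions = np.linspace(0.0, 5.0, 501)

expected = npe_mc_expected_cost(samples, actions)
best_action = actions[np.argmin(expected)]  # decision with the lowest estimated expected cost
print(best_action, expected.min())
```

With a squared-error cost, the minimizer is the posterior mean, so `best_action` lands close to 2.0 here.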

Installation

To install, clone the project and run the following from the repository root:

pip install .

Dependencies are listed in pyproject.toml and will be installed automatically when installing the package.

Data generation

Both methods use a dataset that is generated in the following way:

  1. sample parameters from the prior, $\theta\sim p(\theta)$
  2. simulate $\theta$ to obtain observations, $x\sim p(x|\theta)$
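The two steps above can be sketched in a few lines. The prior and simulator below are hypothetical placeholders chosen only for illustration (they are not the `toy_example` task from the paper), and `generate_dataset` is a made-up helper; the sizes simply echo the `--ntrain 500 --ntest 500` values from the example command.

```python
import numpy as np


def sample_prior(n, rng):
    # Hypothetical prior: theta ~ Uniform(0, 5).
    return rng.uniform(0.0, 5.0, size=n)


def simulate(theta, rng):
    # Hypothetical stochastic simulator: x = theta + Gaussian noise.
    return theta + rng.normal(0.0, 0.5, size=theta.shape)


def generate_dataset(n, seed=0):
    rng = np.random.default_rng(seed)
    theta = sample_prior(n, rng)  # step 1: theta ~ p(theta)
    x = simulate(theta, rng)      # step 2: x ~ p(x | theta)
    return theta, x


theta_train, x_train = generate_dataset(500, seed=0)
theta_test, x_test = generate_dataset(500, seed=1)
```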

To generate a dataset of (parameter, data) pairs for training BAM or NPE-MC, run, for example:

python generate_data.py --task toy_example --type continuous --ntrain 500 --ntest 500

During the training of BAM, a third step is performed in every epoch:

  3. sample actions $a\sim p(a)$ and compute the ground-truth costs $c(\theta, a)$ for every (parameter, data) pair
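A minimal sketch of this per-epoch step, assuming a hypothetical uniform action distribution $p(a)$ and a hypothetical squared-error cost (the tasks in the paper define their own). Because actions are redrawn every epoch, the network sees a fresh action for each fixed (parameter, data) pair on every pass; the function name is illustrative only.

```python
import numpy as np


def cost(theta, a):
    # Hypothetical ground-truth cost c(theta, a); each task defines its own.
    return (theta - a) ** 2


def resample_actions_and_costs(theta, rng, low=0.0, high=5.0):
    # Step 3, repeated every epoch: draw one fresh action a ~ p(a) per pair
    # (here a hypothetical Uniform(low, high)) and compute its ground-truth cost.
    a = rng.uniform(low, high, size=theta.shape)
    return a, cost(theta, a)


rng = np.random.default_rng(0)
theta = rng.uniform(0.0, 5.0, size=500)  # stand-in for the parameters of the generated dataset
for _epoch in range(3):
    actions, costs = resample_actions_and_costs(theta, rng)
    # The regression of costs onto (x, actions) would happen here.
```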

This way, we can train a feedforward network to regress onto the expected costs of taking action $a$ when $x$ is observed.
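Why regressing onto single-sample costs yields the *expected* cost follows from the standard fact that the conditional mean minimizes squared error. Assuming a mean-squared-error loss (the exact loss used in the repository is not stated here) and writing $f_\phi$ for the network, the objective below is minimized over all functions, for every $(x, a)$ in the support of $p(x)\,p(a)$, by the expected posterior cost $f^*$; the Bayes-optimal decision is then the action that minimizes it.

```latex
\mathcal{L}(\phi) = \mathbb{E}_{\theta \sim p(\theta),\; x \sim p(x \mid \theta),\; a \sim p(a)}
  \left[ \big( f_\phi(x, a) - c(\theta, a) \big)^2 \right]

f^*(x, a) = \mathbb{E}_{p(\theta \mid x)}\left[ c(\theta, a) \right] = \int c(\theta, a)\, p(\theta \mid x)\, d\theta

a^*(x) = \arg\min_{a} f^*(x, a)
```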

Training BAM and NPE

To train BAM, run for example:

python train_nn.py task.name=toy_example action=continuous seed=0
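The internals of `train_nn.py` are not described here. The principle behind BAM can nonetheless be checked on a hypothetical conjugate-Gaussian problem (not a task from the paper) where the expected posterior cost is known in closed form. A linear least-squares fit on hand-picked features stands in for the feedforward network: regressing only onto single-sample costs $c(\theta, a)$ recovers the coefficients of $\mathbb{E}_{p(\theta|x)}[c(\theta, a)]$ without ever computing the posterior.

```python
import numpy as np

rng = np.random.default_rng(0)
n, noise_var = 200_000, 1.0

# Hypothetical toy problem: theta ~ N(0, 1), x | theta ~ N(theta, noise_var),
# a ~ Uniform(-3, 3), and cost c(theta, a) = (theta - a)^2.
theta = rng.normal(0.0, 1.0, size=n)
x = theta + rng.normal(0.0, np.sqrt(noise_var), size=n)
a = rng.uniform(-3.0, 3.0, size=n)
c = (theta - a) ** 2

# Stand-in for the network: least squares on features in which the
# true expected cost happens to be linear.
features = np.stack([np.ones(n), x**2, a * x, a**2], axis=1)
coef, *_ = np.linalg.lstsq(features, c, rcond=None)

# Analytic expected cost: E[(theta - a)^2 | x] = (a - m x)^2 + v, with
# posterior mean m x, m = 1 / (1 + noise_var), and posterior variance v = noise_var / (1 + noise_var).
m, v = 1.0 / (1.0 + noise_var), noise_var / (1.0 + noise_var)
expected_coef = np.array([v, m**2, -2.0 * m, 1.0])

print("fitted:  ", np.round(coef, 2))
print("analytic:", expected_coef)
```

In BAM, a neural network replaces the hand-picked features, so no closed form for the expected cost is needed.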

To train NPE, run for example:

python train_npe.py task.name=toy_example action=continuous model=npe seed=0

Results

We systematically evaluate the performance of NPE-MC and BAM on four tasks for which the ground-truth posterior is available: a synthetic toy example introduced in the paper and three previously published simulators with ground-truth posteriors (see the Simulation-Based Inference Benchmark).

We demonstrate that BAM can substantially improve accuracy over NPE-MC on challenging benchmark tasks, and that both methods can infer decisions for the Bayesian Virtual Epileptic Patient.
