lollcat / fab-torch

Flow Annealed Importance Sampling Bootstrap (FAB). ICLR 2023.

License: MIT License

Python 86.16% Jupyter Notebook 13.84%
annealed-importance-sampling boltzmann-distribution normalizing-flow boltzmann-generator

fab-torch's Introduction

Flow Annealed Importance Sampling Bootstrap (FAB)

Overview

Code for the paper Flow Annealed Importance Sampling Bootstrap (FAB).

FAB in JAX: See the JAX implementation of the FAB algorithm in the fab-jax repo. The fab-jax code is cleaner, faster and easier to use, so we recommend it over the fab-torch code. Additionally, the fab-jax code applies FAB to some new problems, including the commonly used, challenging, 1600 dimensional log Gaussian Cox process [Møller et al., 1998, Arbel et al., 2021, Mathews et al., 2022, Zhang et al., 2023].

See About the code for further details on how to use the FAB codebase on new problems. Please contact us if you need any help running the code and replicating our experiments.

Methods of Installation

The package can be installed via pip by navigating to the repository directory and running:

pip install --upgrade .

To run the alanine dipeptide experiments, you will need to install the OpenMM Library as well as openmmtools. This can be done via conda.

conda install -c conda-forge openmm openmmtools

Experiments

NB: See README within experiments/{problem-name} for further details on training and evaluation for each problem.

NB: Quickstart notebooks are simply to get up and running with the code with some visualisation of results after a little bit of training. To replicate the results from the paper run the python commands described below.

Gaussian Mixture Model

Quickstart (NB just for getting started, to replicate results from paper see python command below)

For this problem we use a mixture of 40 two dimensional Gaussian distributions. This allows for easy visualisation of the various methods for training the flow. We provide a colab notebook with an example of training a flow on the GMM problem, comparing FAB to training a flow with KL divergence minimisation. This can be run in a short period of time (10 min) and provides a clear visualisation of how FAB is able to discover new modes and fit them.

To run the FAB experiment with a prioritised replay buffer (for the first seed), use the following command:

python experiments/gmm/run.py training.use_buffer=True training.prioritised_buffer=True

To run the full set of experiments see the README for the GMM experiments.

The plot below shows samples from various trained models, with the GMM problem target contours in the background.

Gaussian Mixture Model samples vs contours

Many Well distribution

Quickstart (NB just for getting started, to replicate results from paper see python command below)

The Many Well distribution is made up of multiple repeats of the Double Well distribution, from the original Boltzmann generators paper.

We provide a colab notebook comparing FAB to training a flow via KL divergence minimisation, on the 6 dimensional Many Well problem, where the difference between the two methods is apparent after a short (<5 min) training period. This experiment can be run locally on a laptop using just CPU.

Additionally, we provide a colab notebook which demos inference with the flow trained with FAB (+prioritised buffer) on the 32 dimensional Many Well problem.

To run the FAB experiment with a prioritised replay buffer (for the first seed) on the 32 dimensional Many Well problem, use the following command:

python experiments/many_well/run.py training.use_buffer=True training.prioritised_buffer=True

To run the full set of experiments see the README for the Many Well experiments.

The plot below shows samples from our model (FAB) vs a flow trained by reverse KL divergence minimisation, with the Many Well problem target contours in the background. The visualisation shows the marginal distributions over pairs of the first four elements of x.

Many Well distribution FAB vs training by KL divergence minimisation

Alanine dipeptide

In our final experiment, we approximate the Boltzmann distribution of alanine dipeptide in an implicit solvent. Alanine dipeptide is a molecule with 22 atoms and a popular model system. The molecule is visualized in the figure below. The right figure shows the probability density of the dihedral angle $\phi$, comparing the ground truth, which was obtained with a molecular dynamics (MD) simulation, with the models trained with our method and with maximum likelihood on MD samples.

Alanine dipeptide and its dihedral angles; Comparison of probability densities

Furthermore, we compared the Ramachandran plots of the different methods in the following figure.

Ramachandran plot of alanine dipeptide

The weights for the flow model trained with FAB are available on Hugging Face. Additionally, we provide a colab notebook which demos inference with this trained model.

To reproduce our experiment, use the experiments/aldp/train.py script. The respective configuration files are located in experiments/aldp/config. We used the seeds 0, 1, and 2 in our runs.

The data used to evaluate our models and to train the flow model with maximum likelihood is provided on Zenodo. If you want to use the configuration files in experiments/aldp/config as is, you should put the data in the experiments/aldp/data folder.

About the code

The main FAB loss can be found in core.py, and we provide a simple training loop to train a flow with this loss (or other flow-loss combinations that meet the spec) in train.py. The FAB training algorithm with the prioritised buffer can be found in train_with_prioritised_buffer.py. Additionally, we provide the code for running the SNR/dimensionality analysis, with p and q set to independent Gaussians, in the fab-jax-old repository. For training the CRAFT model on the GMM problem we forked the Annealed Flow Transport repository. This fork may be found here, and may be used for training the CRAFT model.

As we are still improving the efficiency and stability of the code, make sure you use the latest version. Additionally, if you spot any areas of the code that could be improved, open an issue and we will be more than happy to fix it. For the version of the code that was used in the paper see our releases.

Applying FAB to a new problem:

The most important thing to get right when applying FAB to a given problem is to make sure that AIS is returning reasonable samples, where by reasonable we just mean that the samples from AIS are closer to the target than the samples from the flow. Simply visualising the samples from the flow and AIS provides a good check for whether this is the case. Making sure that the transition kernel (e.g. HMC) is working well (e.g. has a well tuned step size) is key for AIS to work well.

An additional source of instability can arise if the target energy function gives spurious values for points that have extreme values. For example, evaluating the density of a zero-mean unit variance Gaussian on a point that has a value of 100 will give a spurious value. One can fix this by manually setting the log prob of the target to be -inf for regions that are known to be far outside of where samples from the target lie.
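The truncation described above can be sketched as a wrapper around the target's log-density. This is a minimal, dependency-free illustration and not code from the repository: `std_normal_log_prob`, `bounded_log_prob` and the bound of 10 are hypothetical names and values chosen for the example, and in practice the same masking would be applied to batched tensors.

```python
import math


def std_normal_log_prob(x):
    """Log-density of a zero-mean, unit-variance Gaussian at a point x (list of floats)."""
    return sum(-0.5 * xi * xi - 0.5 * math.log(2.0 * math.pi) for xi in x)


def bounded_log_prob(log_prob_fn, bound):
    """Wrap log_prob_fn so that any point with |coordinate| > bound gets log prob -inf."""
    def wrapped(x):
        if any(abs(xi) > bound for xi in x):
            # Region assumed to lie far outside where target samples can occur.
            return -math.inf
        return log_prob_fn(x)
    return wrapped


# Hypothetical bound: the user must choose it from knowledge of the target.
safe_log_prob = bounded_log_prob(std_normal_log_prob, bound=10.0)
```

The bound has to be chosen conservatively, since any genuine target mass outside it is discarded.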

Feel free to contact us if you would like any help getting FAB to work nicely!

Normalizing Flow Libraries

We offer a simple wrapper that allows for various normalising flow libraries to be plugged into this repository. The main library we rely on is normflows.

Citation

If you use this code in your research, please cite it as:

Laurence I. Midgley, Vincent Stimper, Gregor N. C. Simm, Bernhard Schölkopf, José Miguel Hernández-Lobato. Flow Annealed Importance Sampling Bootstrap. The Eleventh International Conference on Learning Representations. 2023.

Bibtex

@inproceedings{
midgley2023flow,
title={Flow Annealed Importance Sampling Bootstrap},
author={Laurence Illing Midgley and Vincent Stimper and Gregor N. C. Simm and Bernhard Sch{\"o}lkopf and Jos{\'e} Miguel Hern{\'a}ndez-Lobato},
booktitle={The Eleventh International Conference on Learning Representations},
year={2023},
url={https://openreview.net/forum?id=XCTVFJwS9LJ}
}

fab-torch's People

Contributors

lollcat, thargreaves, vincentstimper

fab-torch's Issues

add val.pt to experiments/aldp/data

I would like to play around with the experiments for the alanine dipeptide.
In order to use the train.py file in experiments/aldp, I need the val.pt file that is used for evaluate_aldp.

It would be really helpful if this data could also be provided, so that I can run the code on my local machine

fix: ais geometric spacing

Fix a bug in the geometric spacing for AIS. This doesn't currently affect any of the experiments, as they use linear spacing.
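For readers unfamiliar with the two options, the difference between linear and geometric spacing of the AIS inverse temperatures can be sketched as follows. This is a generic illustration under stated assumptions, not the repository's schedule: the function names, the `beta_min` parameter and the choice to prepend 0 to the geometric sequence are all hypothetical.

```python
def linear_betas(n_steps):
    """n_steps + 1 inverse temperatures evenly spaced on [0, 1]."""
    return [i / n_steps for i in range(n_steps + 1)]


def geometric_betas(n_steps, beta_min=1e-3):
    """beta_0 = 0, followed by n_steps values geometrically spaced from beta_min to 1."""
    if n_steps == 1:
        return [0.0, 1.0]
    ratio = (1.0 / beta_min) ** (1.0 / (n_steps - 1))
    betas = [beta_min * ratio ** i for i in range(n_steps)]
    betas[-1] = 1.0  # guard against floating-point drift at the endpoint
    return [0.0] + betas
```

Geometric spacing concentrates the intermediate distributions near the starting distribution, whereas linear spacing spreads them evenly.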

feat: clean configs

Configs for dw4 and many well should be made cleaner.

  • Comments explaining hyper-parameter choices.
  • Better structuring, e.g. specifying buffer length in terms of number of batches.
  • Remove un-used arguments (e.g. remove separate options for use_buffer and prioritised_buffer)

Add HMC mass matrix

  • Add option to specify the mass matrix, instead of always assuming unit mass.
  • Clean up HMC code, making it more modular
  • Add an option for evaluation mode where none of HMC's parameters are tuned.
  • Add example notebook visualising AIS, including the effect of the number of distributions on the sample size.
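To make the mass-matrix item above concrete, here is a minimal sketch of a leapfrog integrator with a diagonal mass matrix. It is an illustration only and does not reflect the repository's HMC code: `leapfrog`, `kinetic_energy` and the list-based representation are assumptions made for the example. Setting every entry of `mass` to 1 recovers the unit-mass case.

```python
def kinetic_energy(p, mass):
    """Kinetic energy 0.5 * p^T M^{-1} p for a diagonal mass matrix given as a list."""
    return 0.5 * sum(pi * pi / mi for pi, mi in zip(p, mass))


def leapfrog(x, p, grad_log_prob, step_size, n_steps, mass):
    """Integrate Hamiltonian dynamics with potential -log_prob and diagonal mass."""
    x, p = list(x), list(p)
    g = grad_log_prob(x)
    for _ in range(n_steps):
        # Half step for momentum, full step for position, half step for momentum.
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]
        x = [xi + step_size * pi / mi for xi, pi, mi in zip(x, p, mass)]
        g = grad_log_prob(x)
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]
    return x, p
```

In a full HMC step the momentum would be drawn from a Gaussian with covariance equal to the mass matrix, and the proposal accepted or rejected using the change in total energy.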

BNNs

Create BNN target problem

Code efficiency improvements

  • Allow transition operator to be an optional argument
  • Re-use flow and target evaluations (and \nabla_x log prob(x)) in AIS
  • Clean up config files to make running experiments easier

Add defensive importance sampling

  • Use a mixture distribution with a flow, and a "defensive distribution".
  • The defensive distribution should help with the regions where the flow misses important parts of the target (initially during training) or if the tail of the flow is overly light in certain regions.
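The mixture proposed above has a simple log-density, sketched below for a single point. This is an assumption-laden illustration: the function names and the mixing weight `eps` are hypothetical, and the choice of defensive distribution is left open.

```python
import math


def log_add_exp(a, b):
    """Numerically stable log(exp(a) + exp(b))."""
    if a == -math.inf and b == -math.inf:
        return -math.inf
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def defensive_mixture_log_prob(log_q_flow, log_q_defensive, eps=0.05):
    """Log-density of (1 - eps) * flow + eps * defensive, evaluated at one point."""
    return log_add_exp(math.log1p(-eps) + log_q_flow,
                       math.log(eps) + log_q_defensive)
```

Because the mixture density is at least `eps` times the defensive density, the importance weight p(x)/q(x) stays bounded wherever the defensive distribution covers the target, even if the flow assigns a point almost no mass.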

Improve sample efficiency

Look into using replay memory (saving samples, log weights and target log probs in a large data structure and re-using them), and/or PPO-style re-use of samples, as using each sample only once, per the current approach, seems very inefficient.
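A minimal sketch of the replay memory idea, assuming a first-in-first-out store with uniform re-sampling, is given below. The class and method names are hypothetical and this is not the buffer implemented in the repository; a prioritised variant would sample in proportion to the stored weights instead.

```python
import random
from collections import deque


class ReplayBuffer:
    """FIFO store of (sample, log_w, log_p_target) tuples with uniform re-sampling."""

    def __init__(self, max_length, seed=0):
        self._data = deque(maxlen=max_length)  # oldest entries are dropped first
        self._rng = random.Random(seed)

    def add(self, samples, log_ws, log_p_targets):
        """Append a batch; the stored values avoid re-running AIS and the target."""
        self._data.extend(zip(samples, log_ws, log_p_targets))

    def sample(self, batch_size):
        """Draw a batch with replacement so stored points can be re-used many times."""
        return self._rng.choices(list(self._data), k=batch_size)

    def __len__(self):
        return len(self._data)
```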

Jupyter/Colab Notebook for Alanine Dipeptide

  • Simplified version of the alanine dipeptide training script as an easy to use jupyter notebook that can run on Google Colab
  • Should produce some meaningful results after ~20min-2h of training

GMM improvements

The following improvements can be made to the GMM problem

  • Set the Metropolis step size to a constant (initialised to 1). Currently it is tuned with a low target_p_accept, which causes the step size to become quite large, so many of the points from AIS are not useful for updating the flow. We could set target_p_accept to something more reasonable (e.g. 0.65), but fixing the step size at its initial value simplifies things nicely, and we should still get good performance. This also helps the comparison to SNFs: our current implementation of SNFs has no step size tuning, so a fixed step size would give identical MCMC transition kernels for FAB and SNFs.
  • Look into numerical instability in the buffer: Currently we sometimes get numerical instability where a large importance weight causes other importance weights to go to zero. In the other problems this isn't much of an issue, but in the GMM problem close to initialisation the flow often places very low probability on some of the modes, causing high importance weights.
  • The Metropolis MCMC transition kernel has no option to turn off step size tuning for the whole of training, because set_eval_mode=False results in the step size being tuned even if adjust_step_size was initially set to False.
  • The evaluation scripts can also be made to be cleaner, and run by default at the end of training.
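The numerical instability mentioned in the list above, where one large importance weight drives the others to zero, can be reproduced in a few lines. The sketch below is illustrative only: the function names are hypothetical, and the effective sample size (ESS) is a standard diagnostic added here as an assumption, not something the issue itself proposes.

```python
import math


def normalised_weights(log_w):
    """Self-normalised importance weights, computed stably by subtracting the max log weight."""
    m = max(log_w)
    unnorm = [math.exp(lw - m) for lw in log_w]
    total = sum(unnorm)
    return [u / total for u in unnorm]


def effective_sample_size(log_w):
    """ESS = 1 / sum(w_i^2) for normalised weights w_i; between 1 and len(log_w)."""
    w = normalised_weights(log_w)
    return 1.0 / sum(wi * wi for wi in w)


# One sample from a mode the flow barely covers dominates the whole batch.
degenerate = [0.0, 0.0, 0.0, 800.0]
print(normalised_weights(degenerate), effective_sample_size(degenerate))
```

An ESS close to 1 signals that a batch is effectively a single point, which matches the situation described for the GMM problem near initialisation.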

Create sample based test set for many well problem.

Currently, we have manually placed points on the distribution's modes. Additionally, we can create a 2D test set via MCMC and then sample from it for pairs of dimensions in higher dimensional Many Well problems, to get samples approximately from p(x).

Cleaning

Add linting, typing and documentation for all functions

feat: many well evaluation

For the Many Well problem, allow evaluation of samples and log weights without having to specify log_prob_fn.

Benchmarking

Currently there are many decisions that can be made in the algorithm - these should be benchmarked in simple tests to get a good idea of their effects. This includes:

Testing various versions of the loss

  • Standard FAB loss: bootstrap estimate of alpha divergence lower bound with alpha=2
  • Could also estimate forward KL divergence, and this will have an equation with a simpler form.
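For reference, the two divergences named in the list above can be written out as follows, for a normalised target p and flow q. The constant in front of the alpha divergence depends on the convention used; the form below is one common convention and is stated as an assumption rather than as the paper's exact definition.

```latex
\begin{align}
D_{\alpha=2}(p \,\|\, q) &= \frac{1}{2}\left(\int \frac{p(x)^2}{q(x)}\,\mathrm{d}x - 1\right)
  = \frac{1}{2}\left(\mathbb{E}_{q}\!\left[w(x)^2\right] - 1\right),
  \qquad w(x) = \frac{p(x)}{q(x)}, \\
\mathrm{KL}(p \,\|\, q) &= \int p(x)\,\log\frac{p(x)}{q(x)}\,\mathrm{d}x
  = \mathbb{E}_{p}\!\left[\log p(x) - \log q(x)\right].
\end{align}
```

Since $\mathbb{E}_{q}[w(x)] = 1$, the first expression equals half the variance of the importance weights under q, so minimising it targets low-variance importance sampling. In the forward KL only the term $-\mathbb{E}_{p}[\log q(x)]$ depends on the flow, which is why its estimate has the simpler form.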

Use exponential moving average of normalisation constant
When calculating the FAB loss, we can (1) use the unnormalised log weights returned by AIS, as this will still give an expectation proportional to the alpha divergence with alpha=2, (2) normalise using the current batch of weights, or (3) use an exponential moving average of the normalisation constant, calculated during the training process.

  • this also applies to the log_w of the flow model (and not just the log_w of AIS)

Currently we are doing (2), but it may be better to do (3), and it is worth comparing the performance of all three approaches.
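The three options can be compared with a small sketch. It assumes that "normalising with the current batch" means dividing by the batch mean of the weights, and it tracks the moving average in log space; `log_mean_exp`, `EmaLogZ`, `normalise` and the `decay` parameter are hypothetical names, not identifiers from the codebase.

```python
import math


def log_mean_exp(log_w):
    """Stable log of the mean of exp(log_w): a batch estimate of the log normalisation constant."""
    m = max(log_w)
    return m + math.log(sum(math.exp(lw - m) for lw in log_w) / len(log_w))


class EmaLogZ:
    """Exponential moving average of the normalisation constant, stored as log Z."""

    def __init__(self, decay=0.99):
        self.decay = decay
        self.log_z = None

    def update(self, log_w):
        batch_log_z = log_mean_exp(log_w)
        if self.log_z is None:
            self.log_z = batch_log_z  # initialise from the first batch
        else:
            # log(decay * Z_old + (1 - decay) * Z_batch), computed stably
            a = math.log(self.decay) + self.log_z
            b = math.log1p(-self.decay) + batch_log_z
            m = max(a, b)
            self.log_z = m + math.log(math.exp(a - m) + math.exp(b - m))
        return self.log_z


def normalise(log_w, log_z):
    """Subtract an estimate of log Z from every log weight."""
    return [lw - log_z for lw in log_w]
```

Option (1) uses `log_w` unchanged, option (2) is `normalise(log_w, log_mean_exp(log_w))`, and option (3) is `normalise(log_w, ema.update(log_w))` with a single `EmaLogZ` instance kept across training steps.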

Testing various transition operators

  • Metropolis
  • Hamiltonian Monte Carlo (with various settings for tuning the step size)

Performance bottlenecks
Which parts of FAB are the slowest - can we add JIT to speed these up?

Implement more losses

Add the following alternative losses:

  • Maximise log prob of AIS samples
  • Importance weighted estimate of forward KL with AIS samples
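The two alternative losses listed above can be sketched for a batch of AIS samples as follows. This is a hedged illustration: `log_q` stands for the flow's log prob at each AIS sample and `log_w` for the AIS log weights, and both function names are hypothetical. The forward KL estimate drops the term that does not depend on the flow and uses self-normalised weights, which is an assumption about how the estimate would be formed.

```python
import math


def max_log_prob_loss(log_q):
    """Negative mean flow log prob on AIS samples (unweighted maximum likelihood)."""
    return -sum(log_q) / len(log_q)


def weighted_forward_kl_loss(log_q, log_w):
    """Self-normalised importance-weighted estimate of -E_p[log q(x)]."""
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]  # stable unnormalised weights
    total = sum(w)
    return -sum(wi * lq for wi, lq in zip(w, log_q)) / total
```

With equal weights the two losses coincide; they differ exactly when AIS has not fully corrected the samples towards the target.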

Run tests on double well / many well & GMM problems.

feat: Add valid samples abstraction

Make it easy for the user to enter a criterion for samples to be valid.
Automatically filter invalid samples from the buffer.
By default, a sample can be treated as invalid if it is out of bounds, or if the target/flow log prob is infinite/NaN.
Currently this is done in the buffer directly, but it would be better to expose the option to the user so that they can control it, and so that its effects are clear.
For alanine dipeptide this could be used to optionally filter based on chirality.
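One way the abstraction described above could look is sketched below: a default validity criterion plus a filter that accepts any user-supplied predicate. All names and the default bound are hypothetical, and the sketch works on plain Python lists instead of the tensors the buffer would hold.

```python
import math


def default_is_valid(x, log_q, log_p, bound=100.0):
    """Default criterion: inside a box and both log probs finite (neither inf nor NaN)."""
    in_bounds = all(abs(xi) <= bound for xi in x)
    return in_bounds and math.isfinite(log_q) and math.isfinite(log_p)


def filter_valid(samples, log_qs, log_ps, is_valid=default_is_valid):
    """Keep only the (x, log_q, log_p) triples accepted by the given criterion."""
    return [(x, lq, lp)
            for x, lq, lp in zip(samples, log_qs, log_ps)
            if is_valid(x, lq, lp)]
```

A problem-specific check, such as the chirality filter mentioned for alanine dipeptide, would be passed in through the `is_valid` argument.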
