
mi2g / accelerated-langevin-imla


Code for the paper "Accelerated Bayesian imaging by relaxed proximal-point Langevin sampling" (https://arxiv.org/abs/2308.09460).

License: GNU General Public License v3.0

MATLAB 52.04% Python 47.96%


Code for the paper "Accelerated Bayesian imaging by relaxed proximal-point Langevin sampling"

by Teresa Klatzer, Paul Dobson, Yoann Altmann, Marcelo Pereyra, Jesús María Sanz-Serna and Konstantinos C. Zygalakis (https://arxiv.org/abs/2308.09460).

If you have any questions, need help, or find a problem with the source code, please contact us.

Abstract

This paper presents a new accelerated proximal Markov chain Monte Carlo methodology to perform Bayesian inference in imaging inverse problems with an underlying convex geometry. The proposed strategy takes the form of a stochastic relaxed proximal-point iteration that admits two complementary interpretations. For models that are smooth or regularised by Moreau-Yosida smoothing, the algorithm is equivalent to an implicit midpoint discretisation of an overdamped Langevin diffusion targeting the posterior distribution of interest. This discretisation is asymptotically unbiased for Gaussian targets and shown to converge in an accelerated manner for any target that is κ-strongly log-concave (i.e., requiring in the order of √κ iterations to converge, similarly to accelerated optimisation schemes), comparing favorably to [M. Pereyra, L. Vargas Mieles, K.C. Zygalakis, SIAM J. Imaging Sciences, 13(2) (2020), pp. 905-935] which is only provably accelerated for Gaussian targets and has bias. For models that are not smooth, the algorithm is equivalent to a Leimkuhler–Matthews discretisation of a Langevin diffusion targeting a Moreau-Yosida approximation of the posterior distribution of interest, and hence achieves a significantly lower bias than conventional unadjusted Langevin strategies based on the Euler-Maruyama discretisation. For targets that are κ-strongly log-concave, the provided non-asymptotic convergence analysis also identifies the optimal time step which maximizes the convergence speed. The proposed methodology is demonstrated through a range of experiments related to image deconvolution with Gaussian and Poisson noise, with assumption-driven and data-driven convex priors.
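To make the "relaxed proximal-point" reading of the implicit midpoint scheme concrete, here is a minimal NumPy sketch. It is not taken from the repository; the function names (`imla_step`, `gaussian_prox`, `run_chain`) are hypothetical. It uses the identity that the implicit midpoint step x' = x - δ∇U((x + x')/2) + √(2δ)ξ can be rewritten as z = prox_{(δ/2)U}(x + √(δ/2)ξ) followed by the relaxation x' = 2z - x. For a scalar Gaussian target U(u) = a u²/2 (an assumption chosen because the prox is closed-form), the chain's stationary variance matches 1/a exactly, which illustrates the "asymptotically unbiased for Gaussian targets" claim.

```python
import numpy as np

def imla_step(x, prox, delta, rng):
    """One implicit-midpoint Langevin step written as a relaxed proximal point.

    x_new = x - delta * grad_U((x + x_new) / 2) + sqrt(2 delta) xi  is equivalent to
    z = prox_{(delta/2) U}(x + sqrt(delta/2) xi),  x_new = 2 z - x,
    where prox(v, t) returns argmin_u U(u) + |u - v|^2 / (2 t).
    """
    xi = rng.standard_normal(np.shape(x))
    z = prox(x + np.sqrt(delta / 2.0) * xi, delta / 2.0)  # midpoint (x + x_new) / 2
    return 2.0 * z - x                                     # relaxation step

def gaussian_prox(precision):
    """Prox of U(u) = precision * u**2 / 2, i.e. the solution of u + t * precision * u = v."""
    return lambda v, t: v / (1.0 + t * precision)

def run_chain(x0, prox, delta, n_iter, seed=0):
    """Run n_iter steps from x0 (an array, so many independent chains run in parallel)."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x0, dtype=float)
    samples = np.empty((n_iter,) + x.shape)
    for k in range(n_iter):
        x = imla_step(x, prox, delta, rng)
        samples[k] = x
    return samples

if __name__ == "__main__":
    a, delta = 4.0, 0.4
    s = run_chain(np.zeros(1000), gaussian_prox(a), delta, 2000)
    print("empirical variance:", s[500:].var(), "target variance:", 1.0 / a)
```

For a non-quadratic potential, only `prox` changes; in imaging models it would typically be computed by an inner solver rather than in closed form.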

Languages

The code is written in Python (motion deblurring experiments) and MATLAB (Poisson and 1D experiments).

Images

Images are taken from the Berkeley BSDS300 data set from this subset.

Preparations

To run the MATLAB code, you will need to install the library included as a submodule in libs/L-BFGS-B-C. Detailed instructions can be found in the original repository by Stephen Becker here. The sampling code uses MATLAB's Parallel Computing Toolbox, but this is not essential.

To run the Python code, you will need Python 3.9 and the following packages:

torch
torchmetrics
tqdm
matplotlib
mpl_toolkits (ships with matplotlib)
hdf5storage
PIL (installed via the Pillow package)
numpy
scipy
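Before running the scripts, it can help to check which of the packages above are importable. The following is a small sketch, not part of the repository; `REQUIRED_MODULES` and `missing_modules` are hypothetical names, and the list simply mirrors the import names given above.

```python
import importlib.util
import sys

# Import names of the packages listed above (PIL comes from Pillow,
# mpl_toolkits is installed together with matplotlib).
REQUIRED_MODULES = ["torch", "torchmetrics", "tqdm", "matplotlib", "mpl_toolkits",
                    "hdf5storage", "PIL", "numpy", "scipy"]

def missing_modules(names=REQUIRED_MODULES):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    if sys.version_info[:2] != (3, 9):
        print(f"Note: the code was prepared for Python 3.9, found {sys.version.split()[0]}")
    missing = missing_modules()
    print("Missing packages:", ", ".join(missing) if missing else "none")
```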

We recommend using CUDA, but it is not required. If no GPU is available, look for statements such as device = 'cuda:0' and replace them with device = 'cpu'.
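As an alternative to editing each hard-coded statement by hand, the device can be chosen once at start-up. A minimal sketch, assuming PyTorch is the backend (as the package list suggests); `pick_device` is a hypothetical helper, not a function from the repository, and it falls back to the plain string `'cpu'` if torch is not installed.

```python
try:
    import torch
except ImportError:  # torch missing: fall back to a plain device string
    torch = None

def pick_device(preferred="cuda:0"):
    """Return `preferred` when CUDA is available, otherwise the CPU device."""
    if torch is not None and torch.cuda.is_available():
        return torch.device(preferred)
    return torch.device("cpu") if torch is not None else "cpu"

# Use this in place of a hard-coded device = 'cuda:0'.
device = pick_device()
```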

We also use the sampling-tools package, which can be downloaded and installed from here.

To use the convex ridge regularizer in the motion deconvolution experiments, I have created a fork (here) that can be installed as a package. Download the repository and run:

$ cd cvx_nn_models
$ pip install .

The package can then be imported in Python using import cvx_nn_models.

Motion deconvolution experiments

In each script, you can choose the experiment (castle, person or lizard) by setting a configuration parameter. The scripts contain image-specific hyperparameters and select a different blur kernel for each image (see the paper for details).
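For readers new to the setting, the sketch below shows a generic Gaussian deblurring forward model and the likelihood gradient a Langevin sampler would use. It is an illustration under stated assumptions, not the repository's implementation: it assumes circular (periodic) boundary conditions so the blur is diagonalised by the FFT, and the kernel in the example is a hypothetical 9-pixel horizontal motion blur rather than one of the kernels used in the paper. All function names are hypothetical.

```python
import numpy as np

def psf_to_otf(kernel, shape):
    """Zero-pad a small blur kernel to `shape`, move its centre to pixel (0, 0), and FFT it."""
    padded = np.zeros(shape)
    kh, kw = kernel.shape
    padded[:kh, :kw] = kernel
    # Centring avoids a spatial shift of the blurred image.
    padded = np.roll(padded, (-(kh // 2), -(kw // 2)), axis=(0, 1))
    return np.fft.fft2(padded)

def blur(x, otf):
    """A x: circular convolution of image x with the kernel."""
    return np.real(np.fft.ifft2(np.fft.fft2(x) * otf))

def blur_adjoint(y, otf):
    """A^T y: circular correlation, i.e. multiplication by the conjugate transfer function."""
    return np.real(np.fft.ifft2(np.fft.fft2(y) * np.conj(otf)))

def grad_data_fidelity(x, y, otf, sigma):
    """Gradient of ||y - A x||^2 / (2 sigma^2) with respect to x."""
    return blur_adjoint(blur(x, otf) - y, otf) / sigma**2

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.random((64, 64))
    kernel = np.zeros((9, 9))
    kernel[4, :] = 1.0 / 9  # hypothetical horizontal motion blur
    otf = psf_to_otf(kernel, x.shape)
    sigma = 0.01
    y = blur(x, otf) + sigma * rng.standard_normal(x.shape)
    g = grad_data_fidelity(x, y, otf, sigma)
    print("gradient norm at the true image:", np.linalg.norm(g))
```

Checking that `blur_adjoint` really is the adjoint of `blur` (for example with random inner products) is a cheap safeguard worth keeping in any such implementation.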

To run the sampling using IMLA, run motion_deconvolution/deblur_imla_motion.py.

To run the sampling using SKROCK, run motion_deconvolution/deblur_skrock_motion.py.

To run the sampling using ULA, run motion_deconvolution/deblur_ula_motion.py.

To run the sampling using PnP-ULA, run motion_deconvolution/PnP_ULA_motion.py.

To compute the MAP solution, run motion_deconvolution/deblur_map.py.
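The difference between IMLA and ULA can be seen in closed form on a scalar Gaussian target U(x) = a x²/2. This is our own derivation, included as a sketch; the function names are hypothetical. ULA (Euler-Maruyama) gives x' = (1 - δa)x + √(2δ)ξ, with stationary variance 1/(a(1 - δa/2)), which overestimates 1/a and requires δa < 2 for stability. The implicit midpoint step gives x' = r x + c ξ with r = (1 - h)/(1 + h), c = √(2δ)/(1 + h), h = δa/2, whose stationary variance equals 1/a for every δ > 0.

```python
def ula_stationary_variance(precision, delta):
    """Stationary variance of ULA for U(x) = precision * x**2 / 2.

    x_new = (1 - delta*a) x + sqrt(2 delta) xi, so v = 2 delta / (1 - (1 - delta*a)**2).
    Only meaningful when 0 < delta * a < 2 (otherwise the chain diverges).
    """
    a = precision
    return 2.0 * delta / (1.0 - (1.0 - delta * a) ** 2)

def imla_stationary_variance(precision, delta):
    """Stationary variance of the implicit midpoint scheme for the same target.

    x_new = r x + c xi with r = (1 - h) / (1 + h), c**2 = 2 delta / (1 + h)**2, h = delta*a/2.
    """
    h = delta * precision / 2.0
    r = (1.0 - h) / (1.0 + h)
    c2 = 2.0 * delta / (1.0 + h) ** 2
    return c2 / (1.0 - r ** 2)

if __name__ == "__main__":
    a = 1.0  # target variance 1 / a = 1
    for delta in (0.1, 0.5, 1.0, 1.5):
        print(delta, ula_stationary_variance(a, delta), imla_stationary_variance(a, delta))
```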

Poisson experiments

To run the sampling using the Reflected Implicit Midpoint Langevin Algorithm (R-IMLA), run poisson/grid_rimla.m. Results will be saved in an automatically created directory called poisson/results.

To run the sampling using the Reflected SKROCK algorithm, run poisson/grid_rskrock.m.

To run the sampling using Reflected MYULA, run poisson/poisson_deblurring_TV_rmyula.m.

To run the sampling using Reflected PMALA, run poisson/poisson_deblurring_TV_rpmala.m.
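The reflected samplers above keep the iterates nonnegative, as Poisson intensities must be. A common way to do this, and the one assumed here, is to reflect each proposal at zero by taking coordinatewise absolute values. The Python sketch below (the repository's Poisson code is in MATLAB) illustrates this with a reflected ULA step on a Poisson likelihood. It deliberately simplifies: the forward operator is the identity, the background `beta` is a hypothetical constant, and the TV prior used in the scripts is omitted. All names are hypothetical.

```python
import numpy as np

def poisson_nll(x, y, beta):
    """Poisson negative log-likelihood for mean x + beta (identity operator, log(y!) dropped)."""
    mean = x + beta
    return np.sum(mean - y * np.log(mean))

def grad_poisson_nll(x, y, beta):
    """Gradient of poisson_nll with respect to x."""
    return 1.0 - y / (x + beta)

def reflected_ula_step(x, grad_U, delta, rng):
    """Euler-Maruyama (ULA) proposal followed by reflection at 0, so that x stays >= 0."""
    proposal = x - delta * grad_U(x) + np.sqrt(2.0 * delta) * rng.standard_normal(x.shape)
    return np.abs(proposal)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    beta = 0.1
    y = rng.poisson(5.0, size=100).astype(float)  # synthetic photon counts
    x = y + 1.0
    for _ in range(1000):
        x = reflected_ula_step(x, lambda v: grad_poisson_nll(v, y, beta), 0.05, rng)
    print("min sample value:", x.min())
```

Note that the gradient blows up as x + beta approaches 0 when y > 0, which is one reason the Poisson experiments rely on regularised and carefully tuned schemes rather than a plain step like this one.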

Evaluation scripts and the chains/data required to reproduce the figures in the paper are available upon request.

1D experiments

Code to sample the distributions shown in Figure 3 of the paper can be found in one_d_examples/one_d_prox.m.
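A one-dimensional example is a convenient way to see the non-smooth case of the method. The sketch below is not a port of one_d_prox.m, and the Laplace target π(x) ∝ exp(-|x|) is our own choice for illustration. It runs the stochastic relaxed proximal-point step z = prox_{(δ/2)U}(x + √(δ/2)ξ), x' = 2z - x, with U(x) = |x|, whose proximal operator is soft thresholding. Following the abstract, for a non-smooth U this corresponds to a Leimkuhler–Matthews discretisation targeting a Moreau-Yosida approximation of π, so the samples are only approximately Laplace distributed. Function names are hypothetical.

```python
import numpy as np

def prox_abs(v, t):
    """prox_{t |.|}(v): soft thresholding."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def sample_laplace_relaxed_prox(n_iter, delta, n_chains=1000, seed=0):
    """Run z = prox_{(delta/2)|.|}(x + sqrt(delta/2) xi), x_new = 2 z - x
    for n_chains independent chains started at 0; returns an (n_iter, n_chains) array."""
    rng = np.random.default_rng(seed)
    x = np.zeros(n_chains)
    out = np.empty((n_iter, n_chains))
    for k in range(n_iter):
        noisy = x + np.sqrt(delta / 2.0) * rng.standard_normal(n_chains)
        z = prox_abs(noisy, delta / 2.0)
        x = 2.0 * z - x
        out[k] = x
    return out

if __name__ == "__main__":
    s = sample_laplace_relaxed_prox(2000, 0.1)[500:]  # discard burn-in
    print("mean:", s.mean(), "variance:", s.var(), "(Laplace variance is 2)")
```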

Citation

If you find our code helpful in your research or work, please cite our paper.

Funding

We acknowledge funding from projects BOLT, BLOOM and LEXCI: This work was supported by the UK Research and Innovation (UKRI) Engineering and Physical Sciences Research Council (EPSRC) grants EP/V006134/1, EP/V006177/1, EP/T007346/1, EP/W007673/1 and EP/W007681/1. JMS has been funded by Ministerio de Ciencia e Innovación (Spain) through project PID2022-136585NB-C21, MCIN/AEI/10.13039/501100011033/FEDER, UE.

Contributors

freyyia