
kunni918 / posteriors


This project forked from normal-computing/posteriors


Uncertainty quantification with PyTorch

Home Page: https://normal-computing.github.io/posteriors/

License: Apache License 2.0



What is posteriors?

A general-purpose Python library for uncertainty quantification with PyTorch.

  • Composable: Use with transformers, lightning, torchopt, torch.distributions and more!
  • Extensible: Add new methods! Add new models!
  • Functional: Easier to test, closer to mathematics!
  • Scalable: Big model? Big data? No problem!
  • Swappable: Swap between algorithms with ease!

Installation

posteriors is available on PyPI and can be installed via pip:

pip install posteriors

Quickstart

posteriors is functional-first and aims to be easy to use and extend. Let's try it out by training a simple model with variational inference:

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch import nn, utils, func
import torchopt
import posteriors

dataset = MNIST(root="./data", transform=ToTensor(), download=True)  # fetches MNIST on first run
train_loader = utils.data.DataLoader(dataset, batch_size=32, shuffle=True)
num_data = len(dataset)

classifier = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
params = dict(classifier.named_parameters())


def log_posterior(params, batch):
    images, labels = batch
    images = images.view(images.size(0), -1)
    output = func.functional_call(classifier, params, images)
    log_post_val = (
        -nn.functional.cross_entropy(output, labels)
        + posteriors.diag_normal_log_prob(params) / num_data
    )
    return log_post_val, output


transform = posteriors.vi.diag.build(
    log_posterior, torchopt.adam(), temperature=1 / num_data
)  # Can swap out for any posteriors algorithm

state = transform.init(params)

for batch in train_loader:
    state = transform.update(state, batch)

Observe that posteriors recommends specifying log_posterior and temperature such that log_posterior remains on the same scale for different batch sizes. posteriors algorithms are designed to be stable as temperature goes to zero.
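The batch-size invariance described above can be checked numerically. The sketch below is plain PyTorch, not part of the posteriors API: it uses a hypothetical toy softmax-regression model with randomly generated data and a fixed weight matrix W, and mirrors the quickstart's scaling (mean cross-entropy plus the log prior divided by the number of data points). Averaging the batch estimates over an equal-size partition of the data recovers the full-data value for every batch size, so the objective stays on the same scale whichever batch size is used.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy problem: N data points, D features, C classes
N, D, C = 120, 5, 3
X = torch.randn(N, D)
y = torch.randint(0, C, (N,))
W = torch.randn(D, C)  # fixed parameters at which the objective is evaluated


def log_prior(W):
    # Unnormalized standard normal prior
    return -0.5 * (W**2).sum()


def scaled_log_posterior(W, xb, yb):
    # Mean log-likelihood over the batch plus the prior divided by N:
    # an estimate of (1 / N) * log p(W | data), up to an additive constant
    return -nn.functional.cross_entropy(xb @ W, yb) + log_prior(W) / N


full_value = scaled_log_posterior(W, X, y)

averages = {}
for B in (10, 30, 60):  # batch sizes that divide N evenly
    batch_values = torch.stack(
        [scaled_log_posterior(W, X[i : i + B], y[i : i + B]) for i in range(0, N, B)]
    )
    averages[B] = batch_values.mean()
    print(f"B={B}: mean of batch estimates = {averages[B].item():.5f}, "
          f"full data = {full_value.item():.5f}")
```

Had the likelihood been summed over the batch instead (reduction="sum"), its value would grow roughly linearly with B. Assuming the tempered target is proportional to exp(log_posterior / temperature), setting temperature=1 / num_data undoes the 1 / N scaling and targets the standard Bayesian posterior, while smaller temperatures concentrate the target around the mode.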

Further, the output of log_posterior is a tuple containing the evaluation (single-element Tensor) and an additional argument (TensorTree) containing any auxiliary information we'd like to retain from the model call, here the model predictions. If you have no auxiliary information, you can simply return torch.tensor([]) as the second element. For more info see torch.func.grad (with has_aux=True) or the documentation.
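To illustrate the (evaluation, auxiliary) convention, the sketch below calls torch.func.grad with has_aux=True directly on a log_posterior-style function; posteriors itself is not imported. The one-dimensional linear-regression model, the parameter names "w" and "b", and the data are hypothetical, chosen so the gradients can be checked by hand.

```python
import torch
from torch import func

num_data = 3  # hypothetical dataset size (the batch below is the whole dataset)


def log_posterior(params, batch):
    x, y = batch
    pred = params["w"] * x + params["b"]
    # Mean Gaussian log-likelihood (unit noise, constants dropped)
    log_lik = -0.5 * ((y - pred) ** 2).mean()
    # Standard normal prior, scaled by 1 / num_data as in the quickstart
    log_prior = -0.5 * (params["w"] ** 2 + params["b"] ** 2) / num_data
    return log_lik + log_prior, pred  # (scalar evaluation, auxiliary output)


def log_posterior_no_aux(params, batch):
    # With nothing to retain, return an empty tensor as the auxiliary output
    value, _ = log_posterior(params, batch)
    return value, torch.tensor([])


params = {"w": torch.tensor(0.5), "b": torch.tensor(0.0)}
batch = (torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 3.0]))

# has_aux=True: differentiate the first output only and pass the second through
grads, preds = func.grad(log_posterior, has_aux=True)(params, batch)
grads_no_aux, empty = func.grad(log_posterior_no_aux, has_aux=True)(params, batch)

print(grads["w"].item(), grads["b"].item())  # ≈ 13/6 and 1.0
print(preds)  # the model predictions, returned unchanged as auxiliary output
```

The gradients are a dict with the same keys as params, which is the TensorTree structure the posteriors transforms operate on.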

Check out the tutorials for more detailed usage!

Methods

posteriors supports a variety of methods for uncertainty quantification, with full details available in the API documentation.

posteriors is designed to be easily extensible. If your favorite method is not yet supported, raise an issue and we'll see what we can do!

Friends

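posteriors interfaces seamlessly with the wider PyTorch ecosystem, including transformers, lightning, torchopt and torch.distributions.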

The functional transform interface is strongly inspired by frameworks such as optax and blackjax, as well as by other uncertainty quantification (UQ) libraries such as fortuna, laplace, numpyro, pymc and uncertainty-baselines.
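As a rough illustration of the functional transform interface, the sketch below builds a toy stochastic-gradient Langevin dynamics (SGLD) sampler exposing the same build/init/update shape as the quickstart. Everything here is hypothetical: build_toy_sgld, ToySGLDState and Transform are illustrative names, not posteriors' own classes, and the update is a textbook SGLD step under the assumption that the target density is proportional to exp(log_posterior / temperature). State is returned as a new NamedTuple rather than mutated in place, in the style of optax and blackjax.

```python
from typing import Any, Callable, NamedTuple

import torch
from torch import func


class ToySGLDState(NamedTuple):
    params: dict
    log_posterior: torch.Tensor
    aux: Any


class Transform(NamedTuple):
    init: Callable
    update: Callable


def build_toy_sgld(log_posterior, lr, temperature=1.0):
    """Hypothetical SGLD transform with an init/update interface (illustration only)."""

    def init(params):
        return ToySGLDState(params, torch.tensor(float("nan")), torch.tensor([]))

    def update(state, batch):
        # Gradient of the first output; value and auxiliary output passed through
        grads, (value, aux) = func.grad_and_value(log_posterior, has_aux=True)(
            state.params, batch
        )
        noise_scale = (2.0 * lr * temperature) ** 0.5
        new_params = {
            name: p + lr * grads[name] + noise_scale * torch.randn_like(p)
            for name, p in state.params.items()
        }
        return ToySGLDState(new_params, value.detach(), aux)

    return Transform(init, update)


# Usage: sample a 1-D unit-variance Gaussian centred at 2 (batch is unused here)
def log_target(params, batch):
    return -0.5 * (params["x"] - 2.0) ** 2, torch.tensor([])


torch.manual_seed(0)
transform = build_toy_sgld(log_target, lr=0.05)
state = transform.init({"x": torch.tensor(0.0)})

samples = []
for step in range(10000):
    state = transform.update(state, None)
    if step >= 500:  # discard burn-in
        samples.append(state.params["x"].item())

samples = torch.tensor(samples)
print(samples.mean().item(), samples.var().item())  # roughly 2 and 1
```

Because the sampling loop only touches init and update, any other transform with the same interface could be dropped in without changing the loop, which mirrors the swappability the quickstart comment describes.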

Contributing

You can report a bug or request a feature by creating a new issue on GitHub.

If you want to contribute code, please check the contributing guide.

Contributors

kaelandt, samduffield
