Probabilistic Torch

License: Apache License 2.0

Probabilistic Torch is a library for deep generative models that extends PyTorch. It is similar in spirit and design goals to Edward and Pyro, and shares many design characteristics with the latter.

The design of Probabilistic Torch is intended to be as PyTorch-like as possible. Probabilistic Torch models are written just like you would write any PyTorch model, but make use of three additional constructs:

  1. A library of reparameterized distributions that implement methods for sampling and for evaluating log probability mass and density functions.

  2. A Trace data structure, which is used both to instantiate and to store random variables (see the short sketch after this list).

  3. Objective functions that approximate the lower bound on the log marginal likelihood using Monte Carlo and importance-weighted estimators.
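For orientation, here is a minimal sketch of how a Trace can be used on its own (the value and log_prob attribute names follow probtorch's RandomVariable class; treat the exact signatures as illustrative rather than authoritative):

import torch
import probtorch

# Instantiate a named random variable in a trace.
q = probtorch.Trace()
z = q.normal(torch.zeros(3), torch.ones(3), name='z')  # sample and record 'z'

# The trace stores both the sampled value and its log probability.
print(q['z'].value)     # the sampled tensor
print(q['z'].log_prob)  # pointwise log density under Normal(0, 1)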

This repository accompanies the NIPS 2017 paper:

@inproceedings{narayanaswamy2017learning,
    title = {Learning Disentangled Representations with Semi-Supervised Deep Generative Models},
    author = {Narayanaswamy, Siddharth and Paige, T. Brooks and van de Meent, Jan-Willem and Desmaison, Alban and Goodman, Noah and Kohli, Pushmeet and Wood, Frank and Torr, Philip},
    booktitle = {Advances in Neural Information Processing Systems 30},
    editor = {I. Guyon and U. V. Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett},
    pages = {5927--5937},
    year = {2017},
    publisher = {Curran Associates, Inc.},
    url = {http://papers.nips.cc/paper/7174-learning-disentangled-representations-with-semi-supervised-deep-generative-models.pdf}
}

Contributors

(in order of joining)

  • Jan-Willem van de Meent
  • Siddharth Narayanaswamy
  • Brooks Paige
  • Alban Desmaison
  • Alican Bozkurt
  • Amirsina Torfi

Installation

  1. Install PyTorch [instructions]

  2. Clone this repository

git clone git@github.com:probtorch/probtorch.git

  3. Refer to the examples/ subdirectory for Jupyter notebooks that illustrate usage.

  4. To build and read the API documentation, run the following:

cd docs
pip install -r requirements.txt
make html
open build/html/index.html

Mini-Tutorial: Semi-supervised MNIST

Models in Probabilistic Torch define variational autoencoders. Both the encoder and the decoder model can be implemented as standard PyTorch models that subclass nn.Module.

In the __init__ method we initialize network layers, just as we would in a PyTorch model. In the forward method, we additionally initialize a Trace variable, which is a write-once, dictionary-like object. The Trace data structure implements methods for instantiating named random variables, whose values and log probabilities are stored under the specified key.

Here is an implementation of the encoder for a standard semi-supervised VAE, as introduced by Kingma and colleagues [1]:

import torch
import torch.nn as nn
import probtorch

class Encoder(nn.Module):
    def __init__(self, num_pixels=784, num_hidden=50, num_digits=10, num_style=2):
        super(Encoder, self).__init__()
        self.h = nn.Sequential(
                    nn.Linear(num_pixels, num_hidden),
                    nn.ReLU())
        self.y_log_weights = nn.Linear(num_hidden, num_digits)
        self.z_mean = nn.Linear(num_hidden + num_digits, num_style)
        self.z_log_std = nn.Linear(num_hidden + num_digits, num_style)

    def forward(self, x, y_values=None, num_samples=10):
        q = probtorch.Trace()
        # Add a leading sample dimension for Monte Carlo estimation.
        x = x.expand(num_samples, *x.size())
        if y_values is not None:
            y_values = y_values.expand(num_samples, *y_values.size())
        h = self.h(x)
        # Sample the digit label y (or condition on it when supervised).
        y = q.concrete(self.y_log_weights(h), 0.66,
                       value=y_values, name='y')
        h2 = torch.cat([y, h], -1)
        # Sample the style variable z, conditioned on y.
        z = q.normal(self.z_mean(h2),
                     torch.exp(self.z_log_std(h2)),
                     name='z')
        return q

In the code above, the method q.concrete samples or observes from a Concrete/Gumbel-Softmax relaxation of the discrete distribution, depending on whether supervision values y_values are provided. The method q.normal samples from a univariate normal.
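To make the relaxation concrete, here is an illustrative sketch of the sampling step a Concrete/Gumbel-Softmax distribution performs (not probtorch's internal implementation; the 0.66 passed to q.concrete above plays the role of the temperature):

def gumbel_softmax_sample(log_weights, temperature=0.66, eps=1e-9):
    # Perturb the log weights with Gumbel noise and take a
    # temperature-controlled softmax, giving a differentiable
    # relaxation of a one-hot categorical sample.
    u = torch.rand_like(log_weights)
    gumbel = -torch.log(-torch.log(u + eps) + eps)
    return torch.softmax((log_weights + gumbel) / temperature, dim=-1)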

The resulting trace q now contains two entries, q['y'] and q['z'], which are instances of a RandomVariable class that stores both the value and the log probability associated with each variable. The stored values are now used to condition execution of the decoder model:

from torch.autograd import Variable

def binary_cross_entropy(x_mean, x, EPS=1e-9):
    return - (torch.log(x_mean + EPS) * x + 
              torch.log(1 - x_mean + EPS) * (1 - x)).sum(-1)

class Decoder(nn.Module):
    def __init__(self, num_pixels=784, num_hidden=50, num_digits=10, num_style=2):
        super(Decoder, self).__init__()
        self.num_digits = num_digits
        self.h = nn.Sequential(
                   nn.Linear(num_style + num_digits, num_hidden),
                   nn.ReLU())
        self.x_mean = nn.Sequential(
                        nn.Linear(num_hidden, num_pixels),
                        nn.Sigmoid())

    def forward(self, x, q=None):
        if q is None:
            q = probtorch.Trace()
        p = probtorch.Trace()
        # Score y and z under their priors, conditioning on values from the
        # encoder trace q when present (q[name] returns None otherwise).
        y = p.concrete(Variable(torch.zeros(x.size(0), self.num_digits)), 0.66,
                       value=q['y'], name='y')
        z = p.normal(0.0, 1.0, value=q['z'], name='z')
        h = self.h(torch.cat([y, z], -1))
        # Record the reconstruction likelihood as a loss term in the trace.
        p.loss(binary_cross_entropy, self.x_mean(h), x, name='x')
        return p

The model above can be used both for conditioned forward execution and for generation. The reason is that q[k] returns None for variable names k that have not been instantiated.
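A quick sketch of the two modes of use (assuming enc, dec, and an input batch x as in the training loop below):

# Conditioned execution: reuse the values recorded in the encoder trace.
q = enc(x)
p = dec(x, q)

# Generation: with no trace supplied, q['y'] and q['z'] are None inside
# Decoder.forward, so p.concrete and p.normal sample fresh values from the priors.
p_prior = dec(x)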

To train the model components above, Probabilistic Torch provides objectives that compute an estimate of a lower bound on the log marginal likelihood, which can then be maximized with standard PyTorch optimizers:

from probtorch.objectives.montecarlo import elbo
from random import random

# initialize model and optimizer
enc = Encoder()
dec = Decoder()
optimizer = torch.optim.Adam(list(enc.parameters())
                             + list(dec.parameters()))

# define the subset of batches that will be supervised
# (data is assumed to be an iterable of (x, y) mini-batches)
supervise = [random() < 0.01 for _ in data]

# train model for 10 epochs
for epoch in range(10):
    for b, (x, y) in enumerate(data):
        x = Variable(x)
        if supervise[b]:
            y = Variable(y)
            q = enc(x, y)
        else:
            q = enc(x)
        p = dec(x, q)
        loss = -elbo(q, p, sample_dim=0, batch_dim=1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
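The sample_dim=0 and batch_dim=1 arguments tell the objective which tensor dimensions index the Monte Carlo samples (added by x.expand(num_samples, ...) in Encoder.forward) and the mini-batch. A shape walkthrough with illustrative sizes, assuming the value attribute on stored random variables:

# With num_samples=10 and a mini-batch of 128 flattened 28x28 images,
# x enters the encoder with shape (128, 784).
x = Variable(torch.rand(128, 784))
q = enc(x)
# Encoder.forward expands x to (10, 128, 784), so the stored style
# variable has shape (10, 128, 2): (sample_dim=0, batch_dim=1, num_style).
print(q['z'].value.size())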

For more details, see the Jupyter notebooks in the examples/ subdirectory.

References

[1] Kingma, Diederik P, Danilo J Rezende, Shakir Mohamed, and Max Welling. 2014. “Semi-Supervised Learning with Deep Generative Models.” http://arxiv.org/abs/1406.5298.
