dereypl / rbm-for-mnist

This project forked from juanpablo2310/rbm-for-mnist

Tensorflow implementation of a Restricted Boltzmann Machine

Restricted Boltzmann Machine (RBM) for MNIST reconstruction

TensorFlow implementation of a Restricted Boltzmann Machine (RBM) for MNIST digit reconstruction.

Check this video for some background.

Requirements

RBM Graphical Model

[Figure: RBM graphical model (rbm)]

Restricted Boltzmann Machines (RBMs) are a class of undirected probabilistic graphical models containing a layer of observable variables and a single layer of latent variables. In RBMs, there are no connections within a layer.

The whole system (hidden and visible nodes) is described by an energy function:

  • $E(v,h) = -v^{T}Wh - v^{T}b - h^{T}c$, where W holds the visible-to-hidden weights, b the visible biases and c the hidden biases

As in statistical physics, high-energy configurations are less probable. The joint probability distribution is defined as:

  • $p(v,h) = e^{-E(v,h)}/Z$, where $Z = \sum_{v,h} e^{-E(v,h)}$ is the partition function (intractable, since the sum runs over every joint configuration)
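As a sanity check on these definitions, the sketch below evaluates $E(v,h)$ and $Z$ by brute force for a tiny RBM. The sizes (3 visible, 2 hidden units) and all parameter values are made-up assumptions, not taken from this repository. Enumerating all $2^{n_v + n_h}$ configurations is only feasible for toy models, which is exactly why $Z$ is intractable for MNIST with its 784 visible units.

```python
import itertools
import math

# Illustrative toy parameters (assumed): 3 visible units, 2 hidden units.
# W[i][j] connects visible unit i to hidden unit j; b are visible biases, c hidden biases.
W = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]
b = [0.1, -0.1, 0.0]
c = [0.2, -0.3]


def energy(v, h):
    """E(v, h) = -v^T W h - v^T b - h^T c."""
    vWh = sum(v[i] * W[i][j] * h[j] for i in range(len(b)) for j in range(len(c)))
    vb = sum(vi * bi for vi, bi in zip(v, b))
    hc = sum(hj * cj for hj, cj in zip(h, c))
    return -vWh - vb - hc


def all_states(n):
    """Every binary vector of length n."""
    return list(itertools.product((0, 1), repeat=n))


# Partition function: sum of exp(-E) over all 2^(3+2) = 32 joint configurations.
Z = sum(math.exp(-energy(v, h)) for v in all_states(len(b)) for h in all_states(len(c)))


def joint_prob(v, h):
    """p(v, h) = exp(-E(v, h)) / Z."""
    return math.exp(-energy(v, h)) / Z
```

Lower-energy configurations get higher probability, and the probabilities of all 32 configurations sum to one.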

Our goal is to learn the parameters that maximize the probability of the training data under the model (the likelihood), which depends on the marginal distribution of the visible units:

  • $p(v) = \sum_{h} p(v,h) = e^{-F(v)}/Z$, where $F(v) = -v^{T}b - \sum_{j} \log(1 + e^{c_{j} + v^{T}W_{:j}})$ is called the free energy
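The free energy has a closed form because the sum over h factorizes across hidden units: $\sum_{h} e^{-E(v,h)} = e^{v^{T}b} \prod_{j} (1 + e^{c_{j} + v^{T}W_{:j}})$. The sketch below checks that closed form against the definition $F(v) = -\log \sum_{h} e^{-E(v,h)}$ on made-up toy parameters (an assumption for illustration only).

```python
import itertools
import math

# Illustrative toy parameters (assumed): W[i][j] links visible unit i to hidden unit j.
W = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]
b = [0.1, -0.1, 0.0]   # visible biases
c = [0.2, -0.3]        # hidden biases


def energy(v, h):
    """E(v, h) = -v^T W h - v^T b - h^T c."""
    vWh = sum(v[i] * W[i][j] * h[j] for i in range(len(b)) for j in range(len(c)))
    return -vWh - sum(vi * bi for vi, bi in zip(v, b)) - sum(hj * cj for hj, cj in zip(h, c))


def free_energy(v):
    """Closed form: F(v) = -v^T b - sum_j log(1 + exp(c_j + v^T W[:, j]))."""
    hidden_terms = sum(
        math.log1p(math.exp(c[j] + sum(v[i] * W[i][j] for i in range(len(b)))))
        for j in range(len(c))
    )
    return -sum(vi * bi for vi, bi in zip(v, b)) - hidden_terms


def free_energy_by_definition(v):
    """F(v) = -log sum_h exp(-E(v, h)), summing over all 2^n_hidden hidden states."""
    states = itertools.product((0, 1), repeat=len(c))
    return -math.log(sum(math.exp(-energy(v, h)) for h in states))
```

Because $p(v) \propto e^{-F(v)}$, the free energy gives unnormalized log-probabilities of visible vectors in time linear in the number of hidden units; only the normalizer Z remains intractable.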

Inference

The conditional distributions factorize over units because there are no intra-layer connections:

  • $p(h_{j}=1 \mid v) = \frac{p(h_{j}=1, v)}{p(h_{j}=0, v) + p(h_{j}=1, v)} = \mathrm{sigmoid}(c_{j} + v^{T}W_{:j})$

  • $p(v_{i}=1 \mid h) = \mathrm{sigmoid}(b_{i} + W_{i:}h)$
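A minimal sketch of these two conditionals in plain Python follows (lists stand in for TensorFlow tensors; the toy parameters and the helper names are hypothetical). Because each conditional factorizes, sampling a whole layer is just independent Bernoulli draws per unit, and one pass v → h → v is a single step of block Gibbs sampling.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function 1 / (1 + exp(-x))."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def prob_h_given_v(v, W, c):
    """p(h_j = 1 | v) = sigmoid(c_j + v^T W[:, j]) for each hidden unit j."""
    return [sigmoid(c[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(c))]


def prob_v_given_h(h, W, b):
    """p(v_i = 1 | h) = sigmoid(b_i + W[i, :] h) for each visible unit i."""
    return [sigmoid(b[i] + sum(W[i][j] * h[j] for j in range(len(h)))) for i in range(len(b))]


def sample_bernoulli(probs, rng):
    """Sample each unit independently; valid because the conditionals factorize."""
    return [1 if rng.random() < p else 0 for p in probs]


# Usage on illustrative toy parameters (assumed): 3 visible units, 2 hidden units.
W = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]
b = [0.1, -0.1, 0.0]
c = [0.2, -0.3]
rng = random.Random(0)
v = [1, 0, 1]
h = sample_bernoulli(prob_h_given_v(v, W, c), rng)        # h ~ p(h | v)
v_recon = sample_bernoulli(prob_v_given_h(h, W, b), rng)  # v ~ p(v | h)
```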

Learning

The parameters of our model are the weights W, the visible biases b and the hidden biases c.
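For MNIST these parameters have fixed shapes: one visible unit per pixel (28 × 28 = 784) and a chosen number of hidden units. The initializer below is a hypothetical sketch: small Gaussian weights with zero biases is one common choice, and the 64 hidden units and std of 0.01 are arbitrary assumptions rather than this repository's settings.

```python
import random


def init_rbm(n_visible, n_hidden, std=0.01, seed=0):
    """Hypothetical initializer: small Gaussian weights, zero biases.

    W has shape (n_visible, n_hidden) so that W[i][j] links visible unit i to
    hidden unit j, matching v^T W h in the energy; b has one entry per visible
    unit and c one entry per hidden unit.
    """
    rng = random.Random(seed)
    W = [[rng.gauss(0.0, std) for _ in range(n_hidden)] for _ in range(n_visible)]
    b = [0.0] * n_visible
    c = [0.0] * n_hidden
    return W, b, c


# 784 = 28 x 28 MNIST pixels; 64 hidden units is an arbitrary illustrative choice.
W, b, c = init_rbm(784, 64)
```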

Maximizing the log-likelihood

Derive log-likelihood and gradient formulas. (TODO)

It is impractical to compute the exact log-likelihood gradient, because it contains an expectation under the model's joint distribution p(v,h), which involves the intractable partition function Z.
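For reference, the standard form of this gradient can be written with the definitions used in this README (the textbook result, not a derivation taken from the original code). Since $\log p(v) = -F(v) - \log Z$, differentiating the free energy gives the first term of each line and differentiating $\log Z$ gives an expectation under the model:

```latex
\begin{aligned}
\frac{\partial \log p(v)}{\partial W_{ij}} &= v_{i}\, p(h_{j}=1 \mid v) - \mathbb{E}_{p(v,h)}[v_{i} h_{j}] \\
\frac{\partial \log p(v)}{\partial b_{i}} &= v_{i} - \mathbb{E}_{p(v,h)}[v_{i}] \\
\frac{\partial \log p(v)}{\partial c_{j}} &= p(h_{j}=1 \mid v) - \mathbb{E}_{p(v,h)}[h_{j}]
\end{aligned}
```

The first (positive) term only needs $p(h \mid v)$ and is cheap; the second (negative) term requires samples from the model, and that is the part contrastive divergence approximates.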

Contrastive divergence

Idea:

  1. Replace the expectation by a point estimate at v'
  2. Obtain the point v' by Gibbs Sampling
  3. Start the sampling chain at the training example v(t)

1-step contrastive divergence (CD-1), where h(v) denotes the vector of hidden probabilities p(h=1|v):

  • Positive phase: $v\,h(v)^{T}$
  • Negative phase: $v'\,h(v')^{T}$, where v' is reconstructed from a sample h ~ p(h|v)
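The two statistics can be sketched in plain Python (lists instead of TensorFlow tensors; the toy sizes, parameter values and helper names are assumptions). Each statistic is an outer product with the same shape as W, and their difference is the CD-1 estimate of the weight gradient.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def hidden_probs(v, W, c):
    """h(v): vector of p(h_j = 1 | v)."""
    return [sigmoid(c[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(c))]


def visible_probs(h, W, b):
    """Vector of p(v_i = 1 | h)."""
    return [sigmoid(b[i] + sum(W[i][j] * h[j] for j in range(len(h)))) for i in range(len(b))]


def outer(v, h):
    """v h^T as a nested list of shape (len(v), len(h)), the same shape as W."""
    return [[vi * hj for hj in h] for vi in v]


def cd1_statistics(v, W, b, c, rng):
    """Positive and negative CD-1 statistics for one binary training vector v."""
    h_pos = hidden_probs(v, W, c)
    h_sample = [1 if rng.random() < p else 0 for p in h_pos]                         # h ~ p(h | v)
    v_prime = [1 if rng.random() < p else 0 for p in visible_probs(h_sample, W, b)]  # v' ~ p(v | h)
    positive = outer(v, h_pos)                                # v h(v)^T
    negative = outer(v_prime, hidden_probs(v_prime, W, c))    # v' h(v')^T
    return positive, negative, v_prime


# Illustrative toy parameters (assumed): 3 visible units, 2 hidden units.
W = [[0.5, -0.2], [0.1, 0.3], [-0.4, 0.2]]
b = [0.1, -0.1, 0.0]
c = [0.2, -0.3]
positive, negative, v_prime = cd1_statistics([1, 0, 1], W, b, c, random.Random(0))
```

Using the probabilities h(v) rather than sampled binary hidden states in both statistics is a common choice that lowers the variance of the estimate; the README does not say which variant main.ipynb uses.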

Pseudocode:

  1. For each training example v(t):

    i. Generate a negative sample v' using k steps of Gibbs sampling, starting at v(t)

    ii. Update the parameters (gradient ascent with learning rate $\epsilon$)

     $W_{new} = W_{old} + \epsilon \, (v(t)\,h(v(t))^{T} - v'\,h(v')^{T})$

     $b_{new} = b_{old} + \epsilon \, (v(t) - v')$

     $c_{new} = c_{old} + \epsilon \, (h(v(t)) - h(v'))$

  2. Go back to 1. until a stopping criterion is met
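The pseudocode can be turned into the following pure-Python sketch. It does not use TensorFlow, so it does not reflect how main.ipynb is written. Assumptions: per-example updates, hidden probabilities (not samples) in the update statistics, Gaussian initialization with std 0.01, and a fixed number of epochs standing in for the stopping criterion; `cd_k_update` and `train` are hypothetical names.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def hidden_probs(v, W, c):
    return [sigmoid(c[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(c))]


def visible_probs(h, W, b):
    return [sigmoid(b[i] + sum(W[i][j] * h[j] for j in range(len(h)))) for i in range(len(b))]


def bernoulli(probs, rng):
    return [1 if rng.random() < p else 0 for p in probs]


def cd_k_update(v, W, b, c, k, lr, rng):
    """One CD-k parameter update for a single training example v (W, b, c updated in place)."""
    if k < 1:
        raise ValueError("k must be at least 1")
    h_pos = hidden_probs(v, W, c)           # h(v(t))
    h_sample = bernoulli(h_pos, rng)
    for _ in range(k):                      # k steps of Gibbs sampling starting at v(t)
        v_neg = bernoulli(visible_probs(h_sample, W, b), rng)
        h_neg = hidden_probs(v_neg, W, c)   # h(v') for the current negative sample
        h_sample = bernoulli(h_neg, rng)
    for i in range(len(b)):
        for j in range(len(c)):
            W[i][j] += lr * (v[i] * h_pos[j] - v_neg[i] * h_neg[j])
        b[i] += lr * (v[i] - v_neg[i])      # visible biases
    for j in range(len(c)):
        c[j] += lr * (h_pos[j] - h_neg[j])  # hidden biases
    return v_neg


def train(data, n_hidden, epochs=10, k=1, lr=0.1, seed=0):
    """Fixed epoch count as a simple stand-in for the stopping criterion."""
    rng = random.Random(seed)
    n_visible = len(data[0])
    W = [[rng.gauss(0.0, 0.01) for _ in range(n_hidden)] for _ in range(n_visible)]
    b = [0.0] * n_visible
    c = [0.0] * n_hidden
    for _ in range(epochs):
        for v in data:
            cd_k_update(v, W, b, c, k, lr, rng)
    return W, b, c


# Toy binary "images" (assumed): pixel 0 is always on, pixel 5 always off.
data = [[1, 1, 1, 0, 0, 0], [1, 0, 1, 0, 1, 0], [1, 1, 0, 1, 0, 0]]
W, b, c = train(data, n_hidden=4, epochs=20, k=1)
```

In practice the updates are usually averaged over mini-batches rather than applied per example, which is cheaper and less noisy.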

The following figure shows the feature detectors learned by the model. The hidden nodes encode a lower-dimensional representation of the data (visible nodes).

[Figure: learned feature detectors (bernoulli_ft)]
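One way to produce such a figure, sketched here under the assumption that each feature detector is the column of W feeding one hidden unit (consistent with the $v^{T}W_{:j}$ term in the formulas), is to reshape that column to the 28 × 28 MNIST grid. The same weights map an image to its lower-dimensional hidden representation. The helper names and the zero-valued example parameters are hypothetical.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def feature_detector(W, j, side=28):
    """Reshape column j of W (weights from every pixel to hidden unit j) into a side x side grid."""
    column = [row[j] for row in W]
    if len(column) != side * side:
        raise ValueError("expected one weight per pixel")
    return [column[r * side:(r + 1) * side] for r in range(side)]


def encode(v, W, c):
    """Lower-dimensional code for v: the hidden probabilities p(h_j = 1 | v)."""
    return [sigmoid(c[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(c))]


# Hypothetical shapes: 784 = 28 x 28 visible units, 64 hidden units; zero parameters for brevity.
W = [[0.0] * 64 for _ in range(784)]
c = [0.0] * 64
image = feature_detector(W, j=0)  # 28 rows of 28 weights
code = encode([0] * 784, W, c)    # 64 numbers instead of 784 pixels
```

Each grid could then be displayed with an image-plotting function such as matplotlib's `imshow`.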

Usage

Run the main.ipynb notebook in Jupyter, e.g. with `jupyter notebook main.ipynb`.

Results

Work in progress.

[Figure: reconstruction results (recon2)]

Extensions

Deep Boltzmann Machines and Deep Belief Networks.

Contrastive Divergence with k > 1 steps of MCMC simulation (CD-k), with weight cost or temperature [Tieleman 08]. See the video on Persistent Contrastive Divergence.
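Persistent Contrastive Divergence (PCD), introduced by Tieleman in 2008, changes only where the negative chain starts: the Gibbs chain's state is kept between parameter updates instead of being restarted at each training example. Below is a minimal pure-Python sketch of that idea with an L2 weight cost. The single chain, learning rate 0.05, weight cost 1e-4 and the helper name `pcd_update` are illustrative assumptions, not settings from this repository; practical PCD typically runs many chains over mini-batches.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def hidden_probs(v, W, c):
    return [sigmoid(c[j] + sum(v[i] * W[i][j] for i in range(len(v)))) for j in range(len(c))]


def visible_probs(h, W, b):
    return [sigmoid(b[i] + sum(W[i][j] * h[j] for j in range(len(h)))) for i in range(len(b))]


def bernoulli(probs, rng):
    return [1 if rng.random() < p else 0 for p in probs]


def pcd_update(v, chain, W, b, c, lr, weight_cost, rng):
    """One persistent-CD update. The negative chain continues from its previous state
    `chain` instead of restarting at v; the new chain state is returned for the next call."""
    h_pos = hidden_probs(v, W, c)
    h_chain = bernoulli(hidden_probs(chain, W, c), rng)
    chain = bernoulli(visible_probs(h_chain, W, b), rng)   # one Gibbs step of the persistent chain
    h_neg = hidden_probs(chain, W, c)
    for i in range(len(b)):
        for j in range(len(c)):
            # weight_cost * W[i][j] is an L2 penalty (weight decay) on the weights
            W[i][j] += lr * (v[i] * h_pos[j] - chain[i] * h_neg[j] - weight_cost * W[i][j])
        b[i] += lr * (v[i] - chain[i])
    for j in range(len(c)):
        c[j] += lr * (h_pos[j] - h_neg[j])
    return chain


# Usage on toy data (assumed): the chain state persists across all updates.
rng = random.Random(0)
data = [[1, 1, 1, 0, 0, 0], [1, 0, 1, 0, 1, 0], [1, 1, 0, 1, 0, 0]]
W = [[rng.gauss(0.0, 0.01) for _ in range(4)] for _ in range(6)]
b, c = [0.0] * 6, [0.0] * 4
chain = bernoulli([0.5] * 6, rng)   # fantasy particle initialized at random
for epoch in range(20):
    for v in data:
        chain = pcd_update(v, chain, W, b, c, lr=0.05, weight_cost=1e-4, rng=rng)
```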

...

Contributors

micheldeudon
