GithubHelp home page GithubHelp logo

shadowkun / stacked_capsule_autoencoders Goto Github PK

View Code? Open in Web Editor NEW

This project forked from akosiorek/stacked_capsule_autoencoders

0.0 0.0 0.0 101 KB

License: Apache License 2.0

Python 96.13% Shell 3.87%

stacked_capsule_autoencoders's Introduction

Stacked Capsule Autoencoders

This code is reposted from the official google-research repository.

This is a Tensorflow implementation of the Stacked Capsule Autoencoder (SCAE), which was introduced in the in the following paper: A. R. Kosiorek, Sara Sabour, Y. W. Teh, and Geoffrey E. Hinton, "Stacked Capsule Autoencoders".

  • Author: Adam R. Kosiorek, Oxford Robotics Institute & Department of Statistics, University of Oxford
  • Email: adamk(at)robots.ox.ac.uk
  • Webpage: http://akosiorek.github.io/

This work was done during Adam's internship at Google Brain in Toronto.

About the project

If you look at natural images containing objects, you will quickly see that the same object can be captured from various viewpoints. Capsule Networks are specifically designed to be robust to viewpoint changes, which makes learning more data-efficient and allows better generalization to unseen viewpoints. This project introduces a novel unsupervised version of Capsule Networks called Stacked Capsule Autoencoders (SCAE). Unlike in the original Capsules, SCAE is a generative model with an affine-aware decoder. This forces the encoder to learn image representation that is equivariant to viewpoint changes, and which leads to state-of-the-art unsupervised classification performance on MNIST and SVHN. For a more detailed description please have a look at the paper or at Adam's blog.

TSNE embeddings of object-capsule presence probabilities on MNIST digits.

Fig 1: TSNE embeddings of object-capsule presence probabilities on MNIST digits color-coded by digit class.

Dependencies

If you execute setup_virtualenv.sh, it will create a virtual environment and install all required dependencies. Alternatively, you can install all the dependencies using pip install -r requirements.txt. You can also manually install Tensorflow v1.15 and the following dependencies:

  • absl_py>=0.8.1
  • imageio>=2.6.1
  • matplotlib>=3.0.3
  • monty>=3.0.2
  • numpy>=1.16.2
  • Pillow>=6.2.1
  • scikit_learn>=0.20.4
  • scipy>=1.2.1
  • dm_sonnet==1.35
  • tensorflow==1.15.0
  • tensorflow_probability==0.8.0
  • tensorflow_datasets==1.3.0

The current implementation is not compatible with tensorflow==2.0.

Understanding the Code

  • The training loop is defined in train.py.
  • The model is built in capsules/configs/model_config.py.
  • The part capsule encoder and decoder are defined in capsules/primary.py.
  • The object capsule can be found in capsules/capsule.py.
  • The SCAE model is defined in capsules/models/scae.py.

Running Experiments

You can train the model by invoking:

# From ../
python -m stacked_capsule_autoencoders.scripts.train --name=experiment_name

Different model configurations and datasets can be chosen by command-line flags. For an overview of the flags, have a look at the top of the training script (train.py) and the top of model and data configs in capsules/configs.

To replicate experiments reported in the paper you can run one of the run_*.sh files:

  1. run_mnist.sh runs the full MNIST model as reported in the paper.
  2. run_mnist_coupled.sh runs the full MNIST model, with the difference that the part mixing probabilities are constrained to be the same as pixel intensities. We used a similar setup for the SVHN experiments.
  3. run_constellation.sh runs the constellation experiment.

You can also pass command line arguments to the above scripts. For example,

./run.sh --batch_size=128

will invoke run.sh script and set the batch size to 128.

Model snapshots, tensorboard logs and intermediate plots (needs passing --plot=True) will be stored in {logdir}/{name} directory where logdir and name can be set as command-line arguments. logdir defaults to google-research/stacked_capsule_autoencoders/checkpoints.

Evaluation

You can use the eval_mnist_model.py script to evaluate a trained (on MNIST) model snapshot. To do so, you need to invoke the script with the same command-line arguments you used to train the model. We provide eval_mnist.sh and eval_mnist_coupled.sh for evaluating snapshots produced by run_mnist.sh and run_mnist_coupeld.sh. Invoking this script will load the model snapshot and report linear classification accuracy, unsupervised classification accuracy (clustering + bipartite graph matching), and also produce a TSNE plot of capsule probabilities. The plot will be stored in the checkpoint directory.

Citation

If you find this repo or the corresponding paper useful in your research, please consider citing:

@inproceedings{Kosiorek2019scae,
  title={Stacked Capsule Autoencoders},
  author={Kosiorek, Adam Roman and Sabour, Sara and Teh, Yee Whye and Hinton, Geoffrey Everest},
  booktitle={Advances in Neural Information Processing Systems},
  url = {https://arxiv.org/abs/1906.06818},
  pdf = {https://arxiv.org/pdf/1906.06818},
  year={2019}
}

Release Notes

Version 1.0

  • Original implementation; contains the constellation and MNIST experiments.

stacked_capsule_autoencoders's People

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.