This project forked from xifengguo/capsnet-keras

A Keras implementation of CapsNet in NIPS2017 paper "Dynamic Routing Between Capsules". Now test error = 0.34%.

License: MIT License

CapsNet-Keras

A Keras implementation of CapsNet in the paper:
Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017
The current average test error = 0.34% and best test error = 0.30%.

Differences from the paper:

  • We use learning rate decay with decay factor = 0.9 and step = 1 epoch,
    while the paper does not give the detailed parameters (or may not have used decay at all).
  • We only report the test errors after 50 epochs of training.
    In the paper, I suppose they trained for 1250 epochs according to Figure A.1? That sounds crazy, so maybe I misunderstood.
  • We use MSE (mean squared error) as the reconstruction loss, with loss coefficient lam_recon=0.0005*784=0.392.
    This should be equivalent to using SSE (sum of squared errors) with lam_recon=0.0005 as in the paper.
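
The two numeric points above (the per-epoch decay and the MSE/SSE equivalence) can be checked with a few lines of plain Python. This is only a sketch and is not taken from capsulenet.py: the helper names `sse`, `mse`, and `decayed_lr` are made up for illustration, the random vectors stand in for a flattened 28x28 image and its reconstruction, and the starting learning rate of 0.001 is an assumed value that the text above does not state.

```python
import random

N_PIXELS = 28 * 28  # 784 pixels in one MNIST image


def sse(x, x_hat):
    """Sum of squared errors between an image and its reconstruction."""
    return sum((a - b) ** 2 for a, b in zip(x, x_hat))


def mse(x, x_hat):
    """Mean squared error: SSE divided by the number of pixels."""
    return sse(x, x_hat) / len(x)


def decayed_lr(initial_lr, epoch, decay=0.9):
    """Learning rate after `epoch` epochs, multiplied by `decay` once per epoch."""
    return initial_lr * decay ** epoch


# Stand-in data (hypothetical): a random "image" and a random "reconstruction".
random.seed(0)
x = [random.random() for _ in range(N_PIXELS)]
x_hat = [random.random() for _ in range(N_PIXELS)]

loss_with_mse = 0.0005 * N_PIXELS * mse(x, x_hat)  # lam_recon = 0.392, MSE
loss_with_sse = 0.0005 * sse(x, x_hat)             # lam_recon = 0.0005, SSE
assert abs(loss_with_mse - loss_with_sse) < 1e-9

# 0.001 is an assumed starting rate, used here only to show the schedule.
for epoch in range(3):
    print(epoch, decayed_lr(0.001, epoch))
```

Because MSE is SSE divided by 784, multiplying the coefficient by 784 cancels the division, which is why the two weighted losses agree up to floating-point rounding.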

Recent updates:

  • Added an example for multi-GPU training. On two GTX 1080Ti GPUs, speed=55s/epoch, compared with 80s/epoch on a single GTX 1080Ti.
  • Corrected the routing algorithm. The gradients in the inner iterations are now blocked. Reorganized the tensor dimensions in this part and optimized some operations for speed. About 100s / epoch on a single GTX 1070 GPU.
  • Renamed dim_vector to dim_capsule.
  • Changed the prior b from a Variable to a constant and moved it from build to call. This is equivalent to the former version but easier to understand. Thanks to #15.
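
For readers unfamiliar with the routing step mentioned in these updates, here is a rough pure-Python sketch of routing-by-agreement as described in the paper. It is not the code from this repository: the function names, the nested-list layout `u_hat[i][j]` (prediction of input capsule i for output capsule j), and the zero-vector guard in `squash` are assumptions made for illustration. Plain Python has no gradients, so the "blocked gradients" point appears only as a comment.

```python
import math


def squash(s):
    """Shrink vector s so its length lies in [0, 1) while keeping its direction."""
    norm_sq = sum(x * x for x in s)
    if norm_sq == 0.0:  # guard added for this sketch
        return [0.0] * len(s)
    scale = norm_sq / (1.0 + norm_sq) / math.sqrt(norm_sq)
    return [scale * x for x in s]


def softmax(row):
    """Numerically stable softmax over a list of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def dynamic_routing(u_hat, num_routing=3):
    """u_hat[i][j] is the vector predicted by input capsule i for output capsule j."""
    n_in, n_out, dim = len(u_hat), len(u_hat[0]), len(u_hat[0][0])
    # The prior b starts as a constant zero matrix, not a trainable variable.
    b = [[0.0] * n_out for _ in range(n_in)]
    v = []
    for r in range(num_routing):
        c = [softmax(row) for row in b]  # coupling coefficients per input capsule
        v = []
        for j in range(n_out):
            s_j = [sum(c[i][j] * u_hat[i][j][d] for i in range(n_in))
                   for d in range(dim)]
            v.append(squash(s_j))
        if r < num_routing - 1:
            # Agreement update. In an autodiff framework, this inner update is
            # where one would stop gradients from flowing through u_hat.
            for i in range(n_in):
                for j in range(n_out):
                    b[i][j] += sum(u_hat[i][j][d] * v[j][d] for d in range(dim))
    return v


# Two input capsules, two output capsules, 2-D vectors (toy values).
u_hat = [[[1.0, 0.0], [0.0, 1.0]],
         [[1.0, 0.0], [0.0, -1.0]]]
print(dynamic_routing(u_hat, num_routing=1))
print(dynamic_routing(u_hat, num_routing=3))
```

In the toy input both input capsules agree on output capsule 0 and disagree on output capsule 1, so extra routing iterations shift coupling toward capsule 0 and lengthen its output vector.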

TODO

  • Conduct experiments on other datasets.
  • Explore interesting characteristics of CapsuleNet.

Contacts

  • Contributions to the repo are always welcome. Open an issue or contact me by e-mail at [email protected] or on WeChat (wenlong-guo).

Usage

Step 1. Install Keras>=2.0 with TensorFlow>=1.2 backend.

pip install tensorflow-gpu
pip install keras

Step 2. Clone this repository to local.

git clone https://github.com/XifengGuo/CapsNet-Keras.git
cd CapsNet-Keras

Step 3. Train a CapsNet on MNIST

Training with default settings:

python capsulenet.py

Training with one routing iteration (the default is 3):

python capsulenet.py --num_routing 1

Other parameters, including batch_size, epochs, lam_recon, shift_fraction, and save_dir, can be passed on the command line in the same way. Please refer to capsulenet.py for details.
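
As a rough picture of how such flags are usually wired up, the sketch below builds an argparse parser using the flag names mentioned in this README. It is not copied from capsulenet.py, and `build_parser` is a hypothetical helper. The default of 3 for num_routing and the result directory come from the text here; the defaults shown for batch_size, epochs, and shift_fraction are placeholder assumptions, and the lam_recon default simply reuses the 0.392 value quoted above.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the flag names mentioned in this README."""
    parser = argparse.ArgumentParser(description="CapsNet on MNIST (sketch)")
    parser.add_argument("--num_routing", type=int, default=3,
                        help="number of routing iterations (README default: 3)")
    parser.add_argument("--lam_recon", type=float, default=0.392,
                        help="coefficient of the reconstruction loss")
    parser.add_argument("--save_dir", default="result",
                        help="directory for logs and the trained model")
    parser.add_argument("--is_training", type=int, default=1,
                        help="1 to train, 0 to test a saved model")
    parser.add_argument("--weights", default=None,
                        help="path to saved weights, e.g. result/trained_model.h5")
    # The three defaults below are placeholders, not confirmed by the README.
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--shift_fraction", type=float, default=0.1)
    return parser


# Equivalent to running: python capsulenet.py --num_routing 1
args = build_parser().parse_args(["--num_routing", "1"])
print(args.num_routing, args.lam_recon, args.save_dir)
```

Check capsulenet.py itself for the real defaults before relying on any value shown here.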

Step 4. Test a pre-trained CapsNet model

If you trained a model using the above command, the trained model is saved to result/trained_model.h5. Run the following command to get the test results:

python capsulenet.py --is_training 0 --weights result/trained_model.h5

This outputs the test accuracy and shows the reconstructed images. The test data is the same as the validation data. To test on new data, modify the code as needed.

You can also download a model I trained from https://pan.baidu.com/s/1sldqQo1

Step 5. Train on multiple GPUs

This requires Keras>=2.0.9. After updating Keras:

python capsulenet-multi-gpu.py --gpus 2

This automatically trains on multiple GPUs for 50 epochs and then outputs the performance on the test dataset. No validation accuracy is reported during training.

Results

Test Errors

CapsNet classification test error on MNIST. The average and standard deviation are computed over 3 trials. The results can be reproduced by running the following commands.

python capsulenet.py --num_routing 1 --lam_recon 0.0    # CapsNet-v1
python capsulenet.py --num_routing 1 --lam_recon 0.392  # CapsNet-v2
python capsulenet.py --num_routing 3 --lam_recon 0.0    # CapsNet-v3
python capsulenet.py --num_routing 3 --lam_recon 0.392  # CapsNet-v4

Method      Routing  Reconstruction  MNIST (%)     Paper
Baseline    --       --              --            0.39
CapsNet-v1  1        no              0.39 (0.024)  0.34 (0.032)
CapsNet-v2  1        yes             0.36 (0.009)  0.29 (0.011)
CapsNet-v3  3        no              0.40 (0.016)  0.35 (0.036)
CapsNet-v4  3        yes             0.34 (0.016)  0.25 (0.005)

Losses and accuracies:

Training Speed

About 100s / epoch on a single GTX 1070 GPU.
About 80s / epoch on a single GTX 1080Ti GPU.
About 55s / epoch on two GTX 1080Ti GPUs by using capsulenet-multi-gpu.py.
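
To put the multi-GPU figure in perspective, the short calculation below derives the speedup and per-GPU scaling efficiency from the two GTX 1080Ti timings quoted above. The variable names are illustrative; the numbers are the approximate per-epoch times from this section.

```python
single_gpu_s = 80.0  # seconds per epoch on one GTX 1080Ti (from the text above)
two_gpu_s = 55.0     # seconds per epoch on two GTX 1080Ti GPUs
num_gpus = 2

speedup = single_gpu_s / two_gpu_s  # how many times faster an epoch runs
efficiency = speedup / num_gpus     # fraction of ideal linear scaling

print(f"speedup: {speedup:.2f}x, efficiency: {efficiency:.0%}")
# prints "speedup: 1.45x, efficiency: 73%"
```

In other words, the second GPU shortens an epoch by roughly a third rather than by half.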

Reconstruction result

The result of CapsNet-v4, obtained by running:

python capsulenet.py --is_training 0 --weights result/trained_model.h5

The digits in the top 5 rows are real images from MNIST, and the digits in the bottom rows are the corresponding reconstructed images.

The model structure:

Other Implementations
