
enigmatisms / capsnet


My own implementation of the NIPS 2017 paper: Dynamic Routing Between Capsules

License: Apache License 2.0

Python 100.00%
capsule-network pytorch deep-learning mnist


CapsNet


My own implementation of the NIPS 2017 paper Dynamic Routing Between Capsules by Sara Sabour, Nicholas Frosst and Geoffrey E. Hinton. This is, by far, the hardest model I've ever trained. The first version of the capsule network is simple: the paper applies it to MNIST classification, and so does my implementation. My implementation in PyTorch 1.7 is, I think, simpler than others. Still, it took me a long time to find out what was wrong with my original implementation, and I found some parts of the paper confusing.
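For readers new to the model, here is a minimal sketch of the two core pieces described in the paper: the squashing non-linearity and routing-by-agreement. The number of routing iterations is presumably what --routing_iter sets (the paper's reported experiments use 3). The sketch is written in NumPy rather than PyTorch so it runs without a GPU, handles a single sample, and is not the code in this repository; the function names are illustrative. The usage example follows the paper's MNIST sizes (32×6×6 = 1152 primary capsules feeding ten 16-dimensional digit capsules), with random prediction vectors standing in for learned ones.

```python
import numpy as np


def squash(s, axis=-1, eps=1e-8):
    """Squashing non-linearity: v = (|s|^2 / (1 + |s|^2)) * s / |s|.

    Keeps the direction of s and maps its length into [0, 1).
    """
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)


def softmax(x, axis):
    """Numerically stable softmax along `axis`."""
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def dynamic_routing(u_hat, num_iters=3):
    """Routing-by-agreement for a single sample.

    u_hat[i, j] is the prediction vector from lower capsule i for upper
    capsule j, shape (num_in, num_out, dim_out). Returns v, shape (num_out, dim_out).
    """
    if num_iters < 1:
        raise ValueError("num_iters must be at least 1")
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))  # routing logits start at zero (uniform coupling)
    for _ in range(num_iters):
        c = softmax(b, axis=1)  # coupling coefficients, each row sums to 1 over upper capsules j
        s = np.einsum("ij,ijd->jd", c, u_hat)  # s_j = sum_i c_ij * u_hat_{j|i}
        v = squash(s)  # v_j = squash(s_j)
        b = b + np.einsum("ijd,jd->ij", u_hat, v)  # b_ij += u_hat_{j|i} . v_j (agreement)
    return v


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical input: random predictions in place of the learned u_hat = W_ij u_i.
    u_hat = 0.01 * rng.normal(size=(1152, 10, 16))
    v = dynamic_routing(u_hat, num_iters=3)
    print(v.shape)                     # (10, 16)
    print(np.linalg.norm(v, axis=-1))  # capsule lengths, one per digit class
```

The length of each output capsule acts as the class score, which is why squashing keeps every length below 1.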


Usage Guidance

To run the code (on Windows 10), do the following. test_caps.py is the only executable file in this repository; make sure the working directory is set to .\py\.

cd .\py\
python .\test_caps.py -d -c --batch_size=100

The available arguments are listed below (run python .\test_caps.py --help to see them):

usage: test_caps.py [-h] [--epochs EPOCHS] [--batch_size BATCH_SIZE]
                    [--routing_iter ROUTING_ITER] [--save_time SAVE_TIME]
                    [--gamma GAMMA] [--recons_ratio RECONS_RATIO]
                    [--test_time TEST_TIME] [-d] [-c] [-l]

optional arguments:
  -h, --help            show this help message and exit
  --epochs EPOCHS       Training lasts for <> epochs
  --batch_size BATCH_SIZE
                        Batch size
  --routing_iter ROUTING_ITER
                        Dynamic routing iteration number
  --save_time SAVE_TIME
                        Save generated img every <> batches
  --gamma GAMMA         Exponential lr coefficient
  --recons_ratio RECONS_RATIO
                        The ratio of reconstruction error
  --test_time TEST_TIME
                        Test frequency for validation set
  -d, --del_dir         Delete dir ./logs and start new tensorboard records
  -c, --cuda            Use CUDA to speed up training
  -l, --load            Load the trained '.pth' model
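The --recons_ratio flag weights the reconstruction term in the training objective. The paper's loss is a margin loss on the digit-capsule lengths (m+ = 0.9, m- = 0.1, λ = 0.5) plus a sum-of-squared-errors reconstruction loss scaled down by 0.0005. The sketch below assumes --recons_ratio plays the role of that 0.0005 factor and that both terms are averaged over the batch; it uses NumPy for portability, the names are hypothetical, and it is not the repository's implementation.

```python
import numpy as np


def margin_loss(lengths, labels, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss from the paper, summed over classes and averaged over the batch.

    lengths: capsule lengths ||v_k||, shape (batch, num_classes)
    labels:  integer class labels, shape (batch,)
    """
    t = np.eye(lengths.shape[1])[labels]  # one-hot targets T_k
    present = t * np.maximum(0.0, m_plus - lengths) ** 2  # correct class should be long
    absent = lam * (1.0 - t) * np.maximum(0.0, lengths - m_minus) ** 2  # others short
    return np.mean(np.sum(present + absent, axis=1))


def total_loss(lengths, labels, recon, images, recons_ratio=0.0005):
    """Margin loss plus a down-weighted reconstruction error.

    recon, images: flattened reconstructions and inputs, shape (batch, num_pixels).
    recons_ratio: assumed meaning of --recons_ratio (the paper's 0.0005 scale).
    """
    recon_err = np.mean(np.sum((recon - images) ** 2, axis=1))  # per-sample SSE, batch mean
    return margin_loss(lengths, labels) + recons_ratio * recon_err
```

Keeping the reconstruction weight small stops the decoder from dominating the margin loss during training, which is the reason the paper gives for the 0.0005 factor.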

Also, make sure the directories imgs, models, etc. exist in the repository root: my .gitignore excludes all of them to keep the repository tidy.


Directory Structure

  • imgs: Reconstructed images.
  • models: Saved .pth model files.
  • py: Source code (implementation).
  • logs: TensorBoard scalar records.
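Because these directories are git-ignored, a fresh clone does not contain them. A small helper like the following (hypothetical, not part of the repository) can create them. It assumes the directories live in the repository root and that scripts are run from inside py/, hence the default root of "..". Only imgs, models and logs are created; extend the tuple if the scripts expect other directories.

```python
from pathlib import Path

# Hypothetical list: the output directories named in the README's directory overview.
OUTPUT_DIRS = ("imgs", "models", "logs")


def ensure_output_dirs(root="..", names=OUTPUT_DIRS):
    """Create each directory under `root` if it does not exist yet.

    root defaults to ".." on the assumption that this runs from inside py/.
    Returns the list of directory paths.
    """
    paths = []
    for name in names:
        path = Path(root) / name
        path.mkdir(parents=True, exist_ok=True)  # no error if it already exists
        paths.append(path)
    return paths
```

Called from inside py/, ensure_output_dirs() creates ../imgs, ../models and ../logs, and calling it again is harmless.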

Requirements

To install the dependencies, run:

python -m pip install -r requirements.txt

The most important requirements are PyTorch 1.7 and CUDA 10.1. TensorBoard support is also needed, which requires TensorFlow on your device.
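To check which of these packages are present before training, a short standard-library script (Python 3.8+) can query the installed versions and ask PyTorch whether CUDA is usable. The package names are PyPI distribution names, and the script is a hypothetical helper rather than part of the repository.

```python
from importlib import metadata


def installed_versions(packages=("torch", "tensorboard", "tensorflow")):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def cuda_status():
    """Return (cuda_available, cuda_version) from PyTorch, or (False, None) without it."""
    try:
        import torch
    except ImportError:
        return False, None
    return torch.cuda.is_available(), torch.version.cuda


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:12s} {version or 'not installed'}")
    available, cuda_version = cuda_status()
    print(f"CUDA available: {available} (PyTorch built with CUDA {cuda_version})")
```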


Results

I trained the model for 5 epochs with a batch size of 50, which gives 1200 batches per epoch (60,000 MNIST training images / 50). Five epochs are enough to show that the capsule network works and that my implementation is correct.

After these 5 epochs, I ran 2 more single-epoch training sessions with an exponentially decaying learning rate (gamma = 0.998) as light fine-tuning. The training accuracy, test accuracy and confusion matrix are shown below:
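With an exponential schedule, the learning rate after n scheduler steps is lr0 · gamma^n, which is what PyTorch's torch.optim.lr_scheduler.ExponentialLR computes. How much gamma = 0.998 matters depends on how often the scheduler steps, which the README does not say. The sketch below assumes one step per batch and uses 1200 batches per epoch (60,000 MNIST training images at batch size 50); the initial rate is left symbolic because the README does not give it.

```python
def exponential_lr(initial_lr, gamma, steps):
    """Learning rate after `steps` decay steps: initial_lr * gamma ** steps.

    PyTorch equivalent (not executed here):
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998)
        # ... then call scheduler.step() once per decay step
    """
    return initial_lr * gamma ** steps


MNIST_TRAIN_SIZE = 60_000
BATCH_SIZE = 50
batches_per_epoch = MNIST_TRAIN_SIZE // BATCH_SIZE  # 1200

# Assumption: one scheduler step per batch. One fine-tuning epoch then scales
# the rate by 0.998 ** 1200, i.e. to roughly 9% of its starting value.
decay_per_epoch = exponential_lr(1.0, 0.998, batches_per_epoch)
print(f"{batches_per_epoch} steps -> lr multiplied by {decay_per_epoch:.4f}")
```

If the scheduler instead stepped once per epoch, each fine-tuning epoch would only scale the rate by 0.998, so the per-batch reading is the one that yields a noticeable decay.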

Accuracy without fine-tuning: about 98.4% on the test set and nearly 100% on the training set.

[Figures: accuracy on the training set, accuracy on the test set, loss, confusion matrix, reconstructed images (1) and (2)]
