
gyang274 / capsulesem


A TensorFlow implementation of Hinton's [matrix capsules with EM routing](https://openreview.net/pdf?id=HJWLfGWRb)

Home Page: https://gyang274.github.io/capsulesEM/


capsulesem's Introduction

Capsule

A TensorFlow Implementation of Hinton's Matrix Capsules with EM Routing.

Quick Start

$ git clone https://github.com/gyang274/capsulesEM.git && cd capsulesEM

$ cd src

$ python train.py

# open a new terminal (ctrl + alt + t)

$ python tests.py

Note:

  1. TensorFlow v1.4.0.

  2. train.py and tests.py assume the user has 2 GPU cards: train.py uses the first GPU and tests.py uses the second. If a different setup is required, or if multiple GPUs are available for training, modify visible_device_list in session_config in slim.learning.train() in train.py, or visible_device_list in session_config in slim.evaluation.evaluation_loop() in tests.py.

Status

MNIST

  1. (R0I1) Network architecture is the same as in the paper Matrix Capsules with EM Routing, Figure 1.

    • Spread loss only, no reconstruction loss.

    • Adam optimizer, default learning rate 0.001, no learning rate decay.

    • Batch size 24 (limited by GPU memory), 1 routing iteration.

    • GPU: half of a K80 (12 GB memory), 2-3 s per training step.

    • Step: 43942, Test Accuracy: 99.37%.

    (Screenshot: TensorBoard)

    Remark: Because allow_smaller_final_batch=False and batch_size=24, the test runs on a random sample of 9984 of the 10000 test images, so the worst-case accuracy on the full test set could be 99.21%. Modify src/datasets/mnist.py and src/tests.py to run the test on the full test dataset.

  2. (R0I2) As above, except with 2 routing iterations. (TODO)

  3. (R1I2) As above, plus reconstruction loss, with 2 routing iterations. (TODO)
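The spread loss used in R0I1 penalizes every wrong class i whose activation a_i comes within a margin m of the target activation a_t: L = sum over i != t of max(0, m - (a_t - a_i))^2. Below is a minimal NumPy sketch, assuming class activations of shape [N, num_classes] and integer labels. The name spread_loss and the averaging over the batch are assumptions for illustration, not taken from this repository's TensorFlow code. The paper grows m linearly from 0.2 to 0.9 during training; the sketch simply takes m as an argument.

```python
import numpy as np

def spread_loss(activations, labels, margin):
    """Spread loss: sum over wrong classes i of max(0, m - (a_t - a_i))^2.

    activations: float array [N, num_classes] of class-capsule activations.
    labels: int array [N], index t of the target class.
    margin: scalar m.
    Returns the mean over the batch (batch averaging is an assumption).
    """
    activations = np.asarray(activations, dtype=np.float64)
    labels = np.asarray(labels)
    n = activations.shape[0]
    # a_t: activation of the target class for each example, shape [N, 1]
    a_t = activations[np.arange(n), labels][:, None]
    # Squared hinge term for every class
    per_class = np.maximum(0.0, margin - (a_t - activations)) ** 2
    # Drop the target class itself (its term would otherwise be m^2)
    per_class[np.arange(n), labels] = 0.0
    return per_class.sum(axis=1).mean()
```

With a small margin, a confident correct prediction costs nothing; as m grows, wrong classes that are not far enough below the target start to contribute, which is why the paper anneals m upward.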
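The worst-case figure in the R0I1 remark follows from simple arithmetic: with batch_size=24 and allow_smaller_final_batch=False, only 10000 // 24 = 416 full batches (9984 images) are evaluated, and the 16 skipped images could in the worst case all be misclassified. A small sketch; the helper name full_test_bounds is hypothetical, and the number of correct predictions is recovered from the rounded 99.37%, so it is approximate.

```python
def full_test_bounds(subset_accuracy, subset_size, total_size):
    """Worst/best full-set accuracy when only subset_size of total_size examples were scored."""
    correct = round(subset_accuracy * subset_size)  # approximate: accuracy was reported rounded
    skipped = total_size - subset_size
    worst = correct / total_size                    # every skipped example is wrong
    best = (correct + skipped) / total_size         # every skipped example is right
    return worst, best

batch_size = 24
total = 10000
evaluated = (total // batch_size) * batch_size      # 416 full batches
worst, best = full_test_bounds(0.9937, evaluated, total)
print(evaluated, f"{worst:.2%}", f"{best:.2%}")
```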

Matrix Capsules Nets and Layers

Build a matrix capsule neural network the same way you would build a CNN:

def capsules_net(inputs, num_classes, iterations, name='CapsuleEM-V0'):
  """Replicate the network in `Matrix Capsules with EM Routing.`
  """

  with tf.variable_scope(name) as scope:

    # inputs [N, H, W, C] -> conv2d, 5x5, strides 2, channels 32 -> nets [N, OH, OW, 32]
    nets = _conv2d_wrapper(
      inputs, shape=[5, 5, 1, 32], strides=[1, 2, 2, 1], padding='SAME', add_bias=True, activation_fn=tf.nn.relu, name='conv1'
    )
    # inputs [N, H, W, C] -> conv2d, 1x1, strides 1, channels 32x(4x4+1) -> (poses, activations)
    nets = capsules_init(
      nets, shape=[1, 1, 32, 32], strides=[1, 1, 1, 1], padding='VALID', pose_shape=[4, 4], name='capsule_init'
    )
    # inputs: (poses, activations) -> capsule-conv 3x3x32x32x4x4, strides 2 -> (poses, activations)
    nets = capsules_conv(
      nets, shape=[3, 3, 32, 32], strides=[1, 2, 2, 1], iterations=iterations, name='capsule_conv1'
    )
    # inputs: (poses, activations) -> capsule-conv 3x3x32x32x4x4, strides 1 -> (poses, activations)
    nets = capsules_conv(
      nets, shape=[3, 3, 32, 32], strides=[1, 1, 1, 1], iterations=iterations, name='capsule_conv2'
    )
    # inputs: (poses, activations) -> capsule-fc 1x1x32x10x4x4 shared view transform matrix within each channel -> (poses, activations)
    nets = capsules_fc(
      nets, num_classes, iterations=iterations, name='capsule_fc'
    )

    poses, activations = nets

  return poses, activations

In particular,

  • capsules_init() takes a CNN layer as input and produces a matrix capsule layer (e.g., PrimaryCaps) as output.

    This operation corresponds to layer A -> B in the paper.

  • capsules_conv() takes a matrix capsule layer (e.g., PrimaryCaps, ConvCaps1) as input and produces a matrix capsule layer (e.g., ConvCaps1, ConvCaps2) as output.

    This operation corresponds to layers B -> C and C -> D in the paper.

  • capsules_fc() takes a matrix capsule layer (e.g., ConvCaps2) as input and produces the output matrix capsule layer with poses and activations (e.g., Class Capsules).

    This operation corresponds to layer D -> E in the paper.
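For 28x28 MNIST inputs, the layers of capsules_net() yield the spatial sizes below. This small sketch applies TensorFlow's SAME/VALID output-size rules. VALID padding for the two capsules_conv() layers is an assumption: capsules_conv() is called without a padding argument, and VALID is what reproduces the paper's 6x6 ConvCaps1 and 4x4 ConvCaps2 grids.

```python
import math

def conv_out(size, kernel, stride, padding):
    """Spatial output size under TensorFlow's SAME/VALID padding rules."""
    if padding == 'SAME':
        return math.ceil(size / stride)
    if padding == 'VALID':
        return math.ceil((size - kernel + 1) / stride)
    raise ValueError(padding)

# (name, kernel, stride, padding); padding of the capsule-conv layers is assumed VALID
layers = [
    ('conv1', 5, 2, 'SAME'),
    ('capsule_init', 1, 1, 'VALID'),
    ('capsule_conv1', 3, 2, 'VALID'),
    ('capsule_conv2', 3, 1, 'VALID'),
]

size = 28  # MNIST height/width
sizes = {}
for name, k, s, p in layers:
    size = conv_out(size, k, s, p)
    sizes[name] = size
    print(f"{name}: {size}x{size}x32")

values_per_capsule = 4 * 4 + 1           # 4x4 pose matrix + 1 activation
num_inputs_fc = size * size * 32          # capsules feeding capsule_fc
print(f"capsule_init conv channels: {32 * values_per_capsule}")
print(f"capsule_fc: {num_inputs_fc} input capsules -> 10 class capsules")
```

The 32x(4x4+1) = 544 channels match the comment on the capsules_init() call, and each of the 512 ConvCaps2 capsules casts one vote per class capsule.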

TODO

  1. How should tf.stop_gradient() be used in EM routing? Why do iterations > 1 cause NaN in the loss and in the capsules_init() activations?

  2. Add learning-rate decay to train.py.

  3. Add train.py/tests.py for smallNORB.
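As a starting point for the learning-rate-decay TODO: TensorFlow 1.x's tf.train.exponential_decay computes lr = base_lr * decay_rate ^ (global_step / decay_steps), with integer division of the exponent when staircase=True. The pure-Python sketch below reproduces that formula. The decay_steps and decay_rate values in the usage lines are hypothetical and not tuned for this model. In train.py, the TF op's output would be passed as learning_rate to the Adam optimizer.

```python
def exponential_decay(base_lr, global_step, decay_steps, decay_rate, staircase=False):
    """Same formula as TF1's tf.train.exponential_decay."""
    exponent = global_step / decay_steps
    if staircase:
        exponent = global_step // decay_steps  # decay in discrete steps
    return base_lr * decay_rate ** exponent

# Hypothetical schedule: starting at the README's default 0.001
for step in (0, 10000, 20000, 40000):
    print(step, exponential_decay(0.001, step, 10000, 0.96, staircase=True))
```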

Questions

  1. The schedule for $$\lambda$$ (the inverse temperature in EM routing) is never specified in the paper.

  2. The transition from place coding in lower-level capsules to rate coding in higher-level capsules is not discussed, other than the coordinate addition in the last layer.
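The paper's coordinate addition adds the scaled (row, column) coordinates of each capsule's receptive-field center to the first two elements of the right-hand column of its 4x4 vote matrix. Below is a NumPy sketch, assuming votes of shape [H, W, C, num_classes, 4, 4]. Using the capsule's grid position, scaled as (index + 0.5) / size, as the receptive-field center is an assumption; the function name is hypothetical.

```python
import numpy as np

def coordinate_addition(votes):
    """Add scaled (row, col) of each capsule to vote entries [0, 3] and [1, 3].

    votes: [H, W, C, num_classes, 4, 4] votes from the last conv-capsule layer.
    The (index + 0.5) / size scaling is an assumed choice, not from the paper.
    """
    votes = np.array(votes, dtype=np.float64, copy=True)  # do not modify the input
    h, w = votes.shape[0], votes.shape[1]
    rows = (np.arange(h) + 0.5) / h
    cols = (np.arange(w) + 0.5) / w
    votes[..., 0, 3] += rows[:, None, None, None]  # broadcast over W, C, classes
    votes[..., 1, 3] += cols[None, :, None, None]  # broadcast over H, C, classes
    return votes
```

Because the class capsules share their transformation matrices across positions, this is the only place where "where" a part was seen enters the class-level votes.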

GitHub Page

The gh-pages site includes all notes.

GitHub Repository

This GitHub repository includes all of the source code.


capsulesem's Issues

Using the nets.capsules_v0 function

When I use this function to get the poses and activations, I get the following error. I searched a lot but cannot figure out the reason:

ValueError: Tensor("capsulesEM-V0/conv1/weights:0", shape=(5, 5, 1, 32), dtype=float32_ref, device=/device:CPU:0) must be from the same graph as Tensor("shuffle_batch:0", shape=(24, 28, 28, 1), dtype=float32).

I did not change anything; I just tried to run your code.

Thanks in advance.

License not clear

I really enjoy seeing that someone developed code for this interesting paper.
For me it would be great if I could use parts of the code, but the license is unclear.
Perhaps you could put a license online?

The E-step and stop_gradient

Nice to see there's already an implementation of this!

I just stumbled across TensorFlow's "stop_gradient" function. Among the examples of where it might be needed, the documentation mentions "The EM algorithm where the M-step should not involve backpropagation through the output of the E-step."

Does this also apply when using the EM algorithm for routing? I don't think I read anything about this in the paper, but then again the paper is very sparse with information about the backpropagation...
Not calculating the gradients for the E-step might considerably speed up training, I believe.
Thoughts?

"NaN" and the initialization of poses and transformation matrices

I read in your Readme that you're having issues with getting NaNs as result.
I've had similar issues and found that after the first iteration, my activations were sometimes negative. This causes a problem in the next iteration: the activations are used as weights when computing the variance in the M-step, so the variance can become negative. Taking its square root to get the standard deviation then yields NaN.

I believe the issues went away when I changed how the initial poses and transformation matrices are initialized (make sure the transformation matrices are initialized randomly and the initial poses aren't all the same). I have to investigate further, but I believe the trouble arises when all votes lie close together.

All of this might be a bug in my code, but I thought I'd mention it in case your NaNs come from a similar issue.
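The failure described above can be guarded against in the M-step itself. Below is a NumPy sketch of the weighted mean and standard deviation. It assumes votes flattened to [num_in, num_out, D], assignments R of shape [num_in, num_out], and input activations a of shape [num_in]. In the paper, activations come out of a logistic function and are therefore positive, so a negative value points to a bug upstream. Clamping them, and adding eps before the division and the square root, are defensive choices, not something the paper prescribes. The function name is hypothetical.

```python
import numpy as np

def m_step_stats(votes, assignments, activations, eps=1e-8):
    """Weighted mean and standard deviation of votes for the EM-routing M-step.

    votes: [num_in, num_out, D]; assignments: [num_in, num_out]; activations: [num_in].
    Returns (mu, sigma), each [num_out, D].
    """
    # Defensive clamp: negative weights are what can make the variance negative.
    a = np.clip(np.asarray(activations, dtype=np.float64), 0.0, None)
    r = assignments * a[:, None]                      # R_ij * a_i, all >= 0
    r_sum = r.sum(axis=0, keepdims=True) + eps        # [1, num_out]; eps avoids 0/0
    w = (r / r_sum)[:, :, None]                       # normalised weights [num_in, num_out, 1]
    mu = (w * votes).sum(axis=0)                      # [num_out, D]
    var = (w * (votes - mu[None]) ** 2).sum(axis=0)   # weighted variance, never negative
    sigma = np.sqrt(var + eps)                        # eps keeps sqrt/log/division finite
    return mu, sigma
```

With nonnegative weights, the variance is a sum of nonnegative terms, so the square root cannot produce NaN. The remaining eps covers the degenerate case where all votes coincide and the variance is exactly zero.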

About the initialization of the poses: The paper uses a simple convolution to produce the initial poses in the primary layer, which I find confusing. Why throw away the spatial information if you can use it? I'm currently testing an initialization which is just the pose expressed as a transformation matrix:

[ 1  0  0  x/w - 0.5 ]
[ 0  1  0  y/h - 0.5 ]
[ 0  0  1  0         ]
[ 0  0  0  1         ]

where
w: image width
h: image height
x: x-component of the pixel's position in the image (0 to w)
y: y-component of the pixel's position in the image (0 to h)

I then initialize the transformations as random values (similar to the weights in a normal convolutional network).
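The coordinate-based pose initialization proposed above can be built for a whole image at once. This is a NumPy sketch, assuming x and y are integer pixel indices (0 to w-1 and 0 to h-1). The function name coordinate_poses is hypothetical.

```python
import numpy as np

def coordinate_poses(h, w):
    """One 4x4 pose per pixel: identity rotation, translation (x/w - 0.5, y/h - 0.5, 0).

    Returns an array of shape [h, w, 4, 4]; x and y are integer pixel indices (assumed).
    """
    poses = np.tile(np.eye(4), (h, w, 1, 1))                      # [h, w, 4, 4] identities
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
    poses[..., 0, 3] = xs / w - 0.5                               # x translation
    poses[..., 1, 3] = ys / h - 0.5                               # y translation
    return poses
```

Each pose then encodes where its pixel sits in the image, so the initial poses are never all identical. This is one of the two conditions the comment above links to avoiding NaNs.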

Question about routing at test time

Hi!
Thanks for kindly sharing the code. There is one point that confuses me.
I am not sure whether EM routing should be executed iteratively during testing. In your code, I see differences between the EM routing procedures for training and testing. Does this mean that, for every batch, the parameters involved in EM routing need to be initialized and updated as in Procedure 1 of the paper and EM_ROUTING in your code? If I am wrong, please correct me.
Thanks a lot!
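As we read Procedure 1 of the paper, EM routing is part of the forward pass. The assignments R are reset to uniform for each input and refined for a fixed number of iterations, both during training and at test time. Only the transformation matrices and the beta parameters are learned. Below is a NumPy sketch for a single example, assuming flattened votes [num_in, num_out, D]. beta_u, beta_a, and the inverse temperature lambda are fixed constants here: the first two are learned in the paper, and the paper gives no schedule for lambda. The function name em_routing is hypothetical and this is not the repository's EM_ROUTING.

```python
import numpy as np

def em_routing(votes, activations, iterations=3, beta_u=0.0, beta_a=0.0,
               inverse_temperature=1.0, eps=1e-8):
    """Sketch of EM routing for ONE example (Procedure 1 of the paper).

    votes: [num_in, num_out, D]; activations: [num_in].
    beta_u, beta_a, inverse_temperature: fixed constants here (assumption).
    Returns (mu [num_out, D], a_out [num_out]).
    """
    if iterations < 1:
        raise ValueError("iterations must be >= 1")
    votes = np.asarray(votes, dtype=np.float64)
    activations = np.asarray(activations, dtype=np.float64)
    num_in, num_out, _ = votes.shape
    # R starts uniform for every new input: routing weights are recomputed, not learned.
    r = np.full((num_in, num_out), 1.0 / num_out)
    for _ in range(iterations):
        # M-step: weight R by input activations, fit one Gaussian per output capsule.
        rw = r * activations[:, None]
        r_sum = rw.sum(axis=0) + eps                               # [num_out]
        w = rw / r_sum                                             # normalised weights
        mu = np.einsum('ij,ijd->jd', w, votes)                     # [num_out, D]
        var = np.einsum('ij,ijd->jd', w, (votes - mu) ** 2) + eps  # [num_out, D]
        cost = (beta_u + 0.5 * np.log(var)) * r_sum[:, None]       # log(sigma) = 0.5 * log(var)
        z = np.clip(inverse_temperature * (beta_a - cost.sum(axis=1)), -50.0, 50.0)
        a_out = 1.0 / (1.0 + np.exp(-z))                           # logistic
        # E-step: log diagonal-Gaussian density of each vote, then renormalise R per input.
        log_p = -0.5 * np.sum((votes - mu) ** 2 / var + np.log(2.0 * np.pi * var), axis=2)
        logits = np.log(a_out + eps) + log_p                       # [num_in, num_out]
        logits -= logits.max(axis=1, keepdims=True)                # numerical stability
        r = np.exp(logits)
        r /= r.sum(axis=1, keepdims=True)
    return mu, a_out
```

Because no state in R survives between calls, running the same input twice gives the same poses and activations. Under this reading, test time repeats exactly the same iterations, just without any parameter updates.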
