hbgtjxzbbx / vnet-tensorflow

This project forked from miguelmonteiro/vnet-tensorflow

Tensorflow implementation of the V-Net architecture for medical imaging segmentation.

Tensorflow implementation of V-Net

This is a TensorFlow implementation of the "V-Net" architecture used for 3D medical imaging segmentation. This code implements only the TensorFlow graph; it must be used within a training program.

Visual Representation of the Network

This is an example of a network that this code implements.

[Figure: VNetDiagram]

How to use

The function v_net(tf_input, input_channels, output_channels, n_channels) has the following arguments:

  1. tf_input: a rank 5 tensor with shape [batch_size, X, Y, Z, input_channels], where X, Y, Z are the spatial dimensions of the images and input_channels is the number of channels the images have;

  2. input_channels: the number of channels of the input images;

  3. output_channels: the number of desired output channels. v_net() returns a tensor with the same shape as tf_input but with a different number of channels, i.e. [batch_size, X, Y, Z, output_channels];

  4. n_channels: the number of channels used internally by the network. In the original paper this number is 16. It doubles at every level of the contracting path; see the image for a better understanding of this number.
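As a rough illustration of items 3 and 4, the sketch below computes the output shape v_net() is documented to return and the channel width at each level of the contracting path. The helper names and the five-level depth (16, 32, 64, 128, 256 channels, as in the original paper's diagram) are assumptions for illustration, not code from this repository.

```python
def expected_output_shape(input_shape, output_channels):
    """Shape v_net() is documented to produce: the input shape
    [batch_size, X, Y, Z, input_channels] with the last axis replaced."""
    if len(input_shape) != 5:
        raise ValueError("tf_input must be a rank 5 tensor")
    return tuple(input_shape[:-1]) + (output_channels,)


def channels_per_level(n_channels, n_levels=5):
    """Channel width at each level of the contracting path (doubles per level).
    n_levels=5 is an assumption matching the original paper's diagram."""
    return [n_channels * 2 ** level for level in range(n_levels)]


print(expected_output_shape((10, 190, 190, 20, 6), 1))  # (10, 190, 190, 20, 1)
print(channels_per_level(16))  # [16, 32, 64, 128, 256]
```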

Notes

Apart from the number of input channels (input_channels), none of the dimensions of tf_input need to be known in advance. This allows reading examples from queues and even training the network on examples of different sizes.
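To make this concrete: in TensorFlow 1.x one could declare tf.placeholder(tf.float32, shape=(None, None, None, None, input_channels)), leaving every dimension except the channel axis unknown. The pure-Python sketch below mimics the compatibility check such a partially known shape implies; the function name is hypothetical and it does not call TensorFlow.

```python
def is_compatible(shape, spec):
    """Return True if a concrete shape matches a partially known spec.

    None in the spec means "unknown": any size is accepted on that axis,
    which is how a placeholder with shape (None, None, None, None, C) behaves.
    """
    if len(shape) != len(spec):
        return False
    return all(s is None or s == d for d, s in zip(shape, spec))


input_channels = 6
spec = (None, None, None, None, input_channels)

# Volumes of different batch and spatial sizes are all accepted...
print(is_compatible((10, 190, 190, 20, 6), spec))  # True
print(is_compatible((1, 128, 128, 64, 6), spec))   # True
# ...but the channel count must still equal input_channels.
print(is_compatible((1, 128, 128, 64, 3), spec))   # False
```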

In a binary segmentation problem you could use output_channels=1 with a sigmoid loss; in a three-class problem you could use output_channels=3 with a softmax loss.
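A minimal per-voxel sketch of the two losses mentioned above, written in plain Python rather than TensorFlow. With output_channels=1 each voxel has one logit and a binary label; with output_channels=3 each voxel has three logits and an integer class label. The sigmoid version uses the numerically stable expression max(x, 0) - x*z + log(1 + exp(-|x|)), the formula documented for tf.nn.sigmoid_cross_entropy_with_logits. The function names and sample values are hypothetical.

```python
import math


def sigmoid_cross_entropy(logit, label):
    """Binary case (output_channels=1): one logit per voxel, label in {0, 1}."""
    # Stable form of -label*log(sigmoid(x)) - (1 - label)*log(1 - sigmoid(x))
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def softmax_cross_entropy(logits, label):
    """Multi-class case (e.g. output_channels=3): one logit per class per voxel."""
    # Subtract the max logit before exponentiating (log-sum-exp trick)
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


# Mean loss over a flattened list of voxels
binary_logits = [2.0, -1.0, 0.0]
binary_labels = [1, 0, 1]
binary_loss = sum(map(sigmoid_cross_entropy, binary_logits, binary_labels)) / len(binary_logits)

voxel_logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
voxel_labels = [0, 2]
multi_loss = sum(map(softmax_cross_entropy, voxel_logits, voxel_labels)) / len(voxel_logits)
print(round(binary_loss, 4), round(multi_loss, 4))
```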

Example Usage

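import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)
from VNet import v_net

input_channels = 6
output_channels = 1

# A batch of 10 volumes of size 190 x 190 x 20, each with 6 channels
tf_input = tf.placeholder(dtype=tf.float32, shape=(10, 190, 190, 20, input_channels))

# Keyword arguments make the roles of output_channels and n_channels explicit
logits = v_net(tf_input, input_channels, output_channels=output_channels, n_channels=16)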

logits will have shape [10, 190, 190, 20, 1]; it can then be flattened and passed to a sigmoid cross-entropy loss such as tf.nn.sigmoid_cross_entropy_with_logits.

Implementation details

There are two slightly different implementations in the code.

VNetOriginal.py implements the network as described in the original paper, with three small differences:

  1. The input can have more than one channel. If the input has more than one channel, then one more convolution is added in level one to increase the number of input channels to match n_channels. If the input has only one channel, then it is broadcast in the first skip connection (repeated n_channels times).
  2. n_channels does not need to be 16.
  3. The output does not need to have two channels as in the original architecture.
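The first difference can be sketched per voxel in plain Python. A single-channel voxel is tiled n_channels times, as described above. For multi-channel input the repository adds a full 3D convolution; this sketch substitutes a per-voxel linear projection (effectively a 1x1x1 convolution without bias, using a hypothetical weight matrix) only to show the channel count changing from input_channels to n_channels. It is not the repository's code.

```python
def match_channels(voxel, n_channels, weights=None):
    """Map one voxel's channel vector to length n_channels.

    voxel:   list of input_channels values for a single spatial position.
    weights: hypothetical n_channels x input_channels matrix, used only when
             input_channels > 1 (stand-in for the extra convolution).
    """
    if len(voxel) == 1:
        # Single channel: broadcast (repeat) the value n_channels times
        return voxel * n_channels
    if (weights is None or len(weights) != n_channels
            or any(len(row) != len(voxel) for row in weights)):
        raise ValueError("multi-channel input needs an n_channels x input_channels weight matrix")
    # Multi-channel: per-voxel linear projection (1x1x1 convolution without bias)
    return [sum(w * x for w, x in zip(row, voxel)) for row in weights]


print(match_channels([0.7], 4))  # [0.7, 0.7, 0.7, 0.7]
print(match_channels([1.0, 2.0], 3, weights=[[1, 0], [0, 1], [1, 1]]))  # [1.0, 2.0, 3.0]
```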

If you set input_channels=1, n_channels=16 and output_channels=2, the function v_net() implements the original architecture.

VNet.py is an updated version of the architecture with the following improvements/fixes:

  1. ReLU non-linearities are replaced with PReLU (Parametric ReLU).
  2. The "V-Net" paper implemented the element-wise sum of the skip connection after the non-linearity. However, according to the original residual network paper, this should be done before the last non-linearity of the convolution block. This is fixed in this implementation.
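The two changes in VNet.py can be compared on scalar values in plain Python. PReLU returns x for positive inputs and alpha * x otherwise, where alpha is learned during training (fixed at 0.25 here for illustration). The two residual helpers place the element-wise sum after versus before the final non-linearity. The helper names and the stand-in "convolution output" values are hypothetical.

```python
def prelu(x, alpha=0.25):
    """Parametric ReLU: identity for x > 0, slope alpha (learned in practice) otherwise."""
    return x if x > 0 else alpha * x


def residual_sum_after(conv_out, block_input, act=prelu):
    """Ordering attributed to the V-Net paper: activate first, then add the skip."""
    return act(conv_out) + block_input


def residual_sum_before(conv_out, block_input, act=prelu):
    """Ordering used in VNet.py per this README: add the skip, then apply the last activation."""
    return act(conv_out + block_input)


conv_out, block_input = -2.0, 1.0  # hypothetical values for one voxel and channel
print(residual_sum_after(conv_out, block_input))   # 0.5
print(residual_sum_before(conv_out, block_input))  # -0.25
```

The two orderings give different results whenever the pre-activation sum and the convolution output fall on different sides of zero, which is why the placement of the sum matters.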
