This project forked from jihongju/keras-fcn

A playable implementation of Fully Convolutional Networks with Keras.

License: MIT License

keras-fcn

A re-implementation of Fully Convolutional Networks with Keras

Installation

Dependencies

  1. keras
  2. tensorflow/theano/CNTK (CNTK is not tested.)

Install with pip

$ pip install git+https://github.com/JihongJu/keras-fcn.git

Build from source

$ git clone https://github.com/JihongJu/keras-fcn.git
$ cd keras-fcn
$ pip install --editable .

Usage

FCN with VGG16

from keras_fcn import FCN
fcn_vgg16 = FCN(input_shape=(500, 500, 3), classes=21,  
                weights='imagenet', trainable_encoder=True)
fcn_vgg16.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
fcn_vgg16.fit(X_train, y_train, batch_size=1)
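The call above assumes that `X_train` and `y_train` already exist. As a minimal sketch, assuming the model expects channels-last images and one-hot labels for every pixel (which is what `categorical_crossentropy` over 21 classes suggests), placeholder arrays of the right shape could be built as follows. The helper `make_dummy_batch` is hypothetical and not part of keras-fcn; the data is random and only useful for checking shapes.

```python
import numpy as np


def make_dummy_batch(n, height=500, width=500, classes=21, seed=0):
    """Return random images and one-hot per-pixel labels (placeholder data)."""
    rng = np.random.RandomState(seed)
    # Images: (n, height, width, 3), channels last.
    X = rng.rand(n, height, width, 3).astype("float32")
    # One integer class id per pixel.
    labels = rng.randint(0, classes, size=(n, height, width))
    # Indexing an identity matrix turns class ids into one-hot vectors.
    y = np.eye(classes, dtype="float32")[labels]
    return X, y


X_train, y_train = make_dummy_batch(n=1)
print(X_train.shape, y_train.shape)
```

Real training data would replace the random arrays with images and segmentation masks loaded from disk, encoded in the same layout.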

FCN with VGG19

from keras_fcn import FCN_VGG19
fcn_vgg19 = FCN_VGG19(input_shape=(500, 500, 3), classes=21,  
                      weights='imagenet', trainable_encoder=True)
fcn_vgg19.compile(optimizer='rmsprop',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
fcn_vgg19.fit(X_train, y_train, batch_size=1)

Custom FCN (VGG16 as an example)

from keras.layers import Input
from keras.models import Model
from keras_fcn.encoders import Encoder
from keras_fcn.decoders import VGGDecoder
from keras_fcn.blocks import (vgg_conv, vgg_fc)
inputs = Input(shape=(224, 224, 3))
blocks = [vgg_conv(64, 2, 'block1'),
          vgg_conv(128, 2, 'block2'),
          vgg_conv(256, 3, 'block3'),
          vgg_conv(512, 3, 'block4'),
          vgg_conv(512, 3, 'block5'),
          vgg_fc(4096)]
encoder = Encoder(inputs, blocks, weights='imagenet',
                  trainable=True)
feat_pyramid = encoder.outputs   # A feature pyramid with 5 scales
feat_pyramid = feat_pyramid[:3]  # Select only the top three scales of the pyramid
feat_pyramid.append(inputs)      # Add image to the bottom of the pyramid


outputs = VGGDecoder(feat_pyramid, scales=[1, 1e-2, 1e-4], classes=21)

fcn_custom = Model(inputs=inputs, outputs=outputs)

Implementing a custom Fully Convolutional Network then amounts to defining a series of convolutional blocks and stacking them one on top of another.
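To get a feel for the scales in such a feature pyramid, the sketch below computes the spatial size after each of five stride-2 pooling stages. It is plain arithmetic, not keras-fcn code: the function `pyramid_sizes` is hypothetical, and whether sizes round up or down depends on the padding the encoder actually uses, so both variants are shown.

```python
import math


def pyramid_sizes(size, stages=5, round_up=True):
    """Spatial size after each stride-2 pooling stage.

    round_up=True mimics 'same' padding (ceil); round_up=False mimics
    'valid' padding (floor). Which one applies is an assumption here.
    """
    sizes = []
    for _ in range(stages):
        size = math.ceil(size / 2) if round_up else size // 2
        sizes.append(size)
    return sizes


print(pyramid_sizes(224))                  # [112, 56, 28, 14, 7]
print(pyramid_sizes(500))                  # [250, 125, 63, 32, 16]
print(pyramid_sizes(500, round_up=False))  # [250, 125, 62, 31, 15]
```

For a 500-pixel input the coarsest scale does not upsample back to exactly 500 (16 x 32 = 512), which is why a decoder needs a cropping step to align scores with the image.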

Custom decoders

from keras_fcn.blocks import vgg_deconv, vgg_score
from keras_fcn.decoders import Decoder
decode_blocks = [
    vgg_deconv(classes=21, scale=1),
    vgg_deconv(classes=21, scale=0.01),
    vgg_deconv(classes=21, scale=0.0001, kernel_size=(16, 16), strides=(8, 8)),
    # A functional block that crops the output scores to match the image.
    # It can be used together with other custom blocks.
    vgg_score(crop_offset='centered'),
]
outputs = Decoder(feat_pyramid, decode_blocks)
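One way to picture how a decoder can combine a pyramid with a list of blocks is as a fold: each block receives one feature and the result accumulated so far. The sketch below uses plain numbers instead of tensors; `decode` and `toy_block` are hypothetical names, and the pairing of one block per pyramid level is an assumption based on the example above (four features, four blocks), not a statement about the library's internals.

```python
def decode(pyramid, blocks):
    """Fold blocks over the pyramid, threading the running result through."""
    if len(pyramid) != len(blocks):
        raise ValueError("need exactly one block per pyramid level")
    y = None  # Nothing has been accumulated before the first block.
    for feat, block in zip(pyramid, blocks):
        y = block(feat, y)
    return y


def toy_block(scale):
    """Return a block f(x, y) that weights x and adds it to the running sum."""
    def f(x, y):
        x = x * scale
        return x if y is None else x + y
    return f


blocks = [toy_block(1), toy_block(0.5), toy_block(0.25)]
result = decode([4.0, 4.0, 4.0], blocks)
print(result)  # 4*1 + 4*0.5 + 4*0.25 = 7.0
```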

The decode_blocks can be customized as well.

from keras.layers import Conv2D, Conv2DTranspose, Lambda, add
from keras_fcn.layers import CroppingLike2D

def my_decode_block(classes, scale):
    """A functional decoder block.

    :param classes: Integer, number of classes
    :param scale: Float, weight of the current pyramid scale, varying from 0 to 1

    :return f: A function that takes a feature from the feature pyramid, x,
               and the result accumulated from the top of the pyramid, y,
               applies upsampling, and returns the updated accumulation.
    """
    def f(x, y):
        x = Lambda(lambda xx: xx * scale)(x)  # Weight the feature by its scale first.
        # 1x1 convolution with stride 1, replacing the traditional FC layer.
        x = Conv2D(filters=classes, kernel_size=(1, 1))(x)
        if y is None:  # The first block has no y.
            # Deconvolutional or upsampling layer (remaining arguments elided).
            y = Conv2DTranspose(filters=classes, ...)(x)
        else:
            # Crop the upsampled feature to match the output of one scale up.
            x = CroppingLike2D(target=y, offset='centered')(x)
            y = add([y, x])
            y = Conv2DTranspose(filters=classes, ...)(y)  # Deconv/upsampling again.
        return y  # Output of the current scale in the feature pyramid.
    return f  # Return the block function itself, not a tensor.
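The crop-then-add step in the block above can be illustrated with NumPy arrays standing in for Keras tensors. This is only a sketch of the idea: `crop_like_centered` and `fuse` are hypothetical helpers, not the library's `CroppingLike2D` layer, and the arrays are assumed to be laid out as (height, width, channels).

```python
import numpy as np


def crop_like_centered(x, target):
    """Crop the two spatial axes of x so they match target, keeping the center."""
    dh = x.shape[0] - target.shape[0]
    dw = x.shape[1] - target.shape[1]
    if dh < 0 or dw < 0:
        raise ValueError("x must be at least as large as target")
    top, left = dh // 2, dw // 2
    return x[top:top + target.shape[0], left:left + target.shape[1]]


def fuse(x, y, scale):
    """Weight x by scale, crop it to y's spatial size, and add it to y."""
    return y + crop_like_centered(x * scale, y)


x = np.ones((10, 10, 21))   # Larger feature map from the current scale.
y = np.zeros((8, 8, 21))    # Accumulated scores from the scales above.
fused = fuse(x, y, scale=0.01)
print(fused.shape)  # (8, 8, 21)
```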

Try Examples (The example is out-of-date for now)

  1. Download VOC2011 dataset
$ wget "http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar"
$ tar -xvf VOCtrainval_25-May-2011.tar
$ mkdir ~/Datasets
$ mv TrainVal/VOCdevkit/VOC2011 ~/Datasets
  2. Mount the dataset from the host into the container and start bash in the container image

From repository keras-fcn

$ nvidia-docker run -it --rm -v `pwd`:/root/workspace -v ${HOME}/Datasets/:/root/workspace/data jihong/keras-gpu bash

or equivalently,

$ make bash
  3. Within the container, run the following commands.
$ cd ~/workspace
$ source venv/bin/activate
$ pip install -r requirements
$ python setup.py build
$ cd voc2011
$ python train.py

For more details, see the source code of the example in Training Pascal VOC2011 Segmentation.

Model Architecture

FCN8s with VGG16 as base net:

fcn_vgg16

TODO

  • Add ResNet
