qiuweimin1332499 / segmentation_models

This project forked from qubvel/segmentation_models

Segmentation models with pretrained backbones. Keras and TensorFlow Keras.

License: MIT License

Python library with Neural Networks for Image Segmentation based on Keras and TensorFlow.

The main features of this library are:

  • High-level API (just two lines of code to create a segmentation model)
  • 4 model architectures for binary and multi-class image segmentation (including the legendary Unet)
  • 32 available backbones for each architecture
  • All backbones have pre-trained weights for faster and better convergence
  • Helpful segmentation losses (Jaccard, Dice, Focal) and metrics (IoU, F-score)

Important note

Some models in version 1.* are not compatible with models trained using earlier versions. If you have such models and want to load them, roll back with:

$ pip install -U segmentation-models==0.2.1

Quick start

The library is built to work with both the Keras and TensorFlow Keras (tf.keras) frameworks:

import segmentation_models as sm
# prints: Segmentation Models: using `keras` framework.

By default, it tries to import keras; if keras is not installed, it falls back to the tensorflow.keras framework. There are two ways to choose the framework explicitly:

  • Set the environment variable SM_FRAMEWORK=keras / SM_FRAMEWORK=tf.keras before importing segmentation_models
  • Switch the framework with sm.set_framework('keras') / sm.set_framework('tf.keras')
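The selection order described above can be sketched in plain Python. This is a hypothetical illustration, not the segmentation_models source: `resolve_framework` and `module_available` are names made up here, rejecting unknown `SM_FRAMEWORK` values is our own assumption, and runtime calls to `sm.set_framework(...)` are not modeled.

```python
import importlib.util
from typing import Callable, Mapping

# Hypothetical sketch of the framework-selection order described above;
# this is NOT the segmentation_models implementation.
SUPPORTED = ("keras", "tf.keras")


def module_available(name: str) -> bool:
    # find_spec returns None when a top-level module cannot be found.
    return importlib.util.find_spec(name) is not None


def resolve_framework(env: Mapping[str, str],
                      is_available: Callable[[str], bool] = module_available) -> str:
    # 1. An explicit SM_FRAMEWORK environment variable takes precedence.
    requested = env.get("SM_FRAMEWORK")
    if requested is not None:
        if requested not in SUPPORTED:
            # Assumption: rejecting unknown values is a choice made for this sketch.
            raise ValueError(f"unsupported framework: {requested!r}")
        return requested
    # 2. Otherwise prefer standalone keras if it is installed.
    if is_available("keras"):
        return "keras"
    # 3. Fall back to tensorflow.keras.
    if is_available("tensorflow"):
        return "tf.keras"
    raise ImportError("neither keras nor tensorflow is installed")
```

For example, `resolve_framework(os.environ)` reports what this sketch would pick in the current environment; passing a custom `is_available` makes the fallback order easy to check without installing anything.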

You can also specify which image_data_format to use; segmentation-models works with both channels_last and channels_first. This can be useful when later converting the model to the NVIDIA TensorRT format, or when optimizing it for CPU/GPU computation.

import keras
# or from tensorflow import keras

keras.backend.set_image_data_format('channels_last')
# or keras.backend.set_image_data_format('channels_first')
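Assuming `input_shape` follows the active data format (as it does for Keras layers), the two settings differ only in where the channel axis sits. The helper below, `input_shape_for`, is hypothetical and not part of segmentation_models; with a real Keras install, `keras.backend.image_data_format()` returns the current setting.

```python
def input_shape_for(height, width, channels, data_format="channels_last"):
    """Return a Keras-style input_shape tuple (batch dimension excluded).

    Hypothetical helper, not part of segmentation_models. height/width may be
    None to accept images of any spatial size.
    """
    if data_format == "channels_last":
        return (height, width, channels)   # (H, W, C): the Keras default
    if data_format == "channels_first":
        return (channels, height, width)   # (C, H, W)
    raise ValueError(f"unknown data_format: {data_format!r}")


# A 3-channel input of any spatial size in each layout:
last = input_shape_for(None, None, 3, "channels_last")
first = input_shape_for(None, None, 3, "channels_first")
```

So with channels_first you would pass something like `input_shape=(3, None, None)` rather than the channels_last form `(None, None, 3)`.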

The created segmentation model is just an instance of a Keras Model, which can be built as easily as:

model = sm.Unet()

Depending on the task, you can change the network architecture by choosing a backbone with fewer or more parameters, and use pretrained weights to initialize it:

model = sm.Unet('resnet34', encoder_weights='imagenet')

Change the number of output classes in the model (choose your case):

# binary segmentation (these parameters are the defaults when you call Unet('resnet34'))
model = sm.Unet('resnet34', classes=1, activation='sigmoid')
# multiclass segmentation with non-overlapping class masks (your classes + background)
model = sm.Unet('resnet34', classes=3, activation='softmax')
# multiclass segmentation with independent overlapping/non-overlapping class masks
model = sm.Unet('resnet34', classes=3, activation='sigmoid')
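The choice between softmax and sigmoid above comes down to how each pixel's class scores are normalized. A minimal pure-Python sketch for a single pixel with three output channels (the logit values are made up for illustration):

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    # Subtract the max for numerical stability; outputs sum to 1,
    # so classes compete for each pixel (non-overlapping masks).
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(logits: List[float]) -> List[float]:
    # Each channel is squashed into (0, 1) independently,
    # so a pixel may belong to several classes (overlapping masks).
    return [1.0 / (1.0 + math.exp(-z)) for z in logits]


# Raw scores for one pixel: channels 0 and 1 both score high.
pixel_logits = [2.0, 2.0, -1.0]
probs_softmax = softmax(pixel_logits)
probs_sigmoid = sigmoid(pixel_logits)
```

With softmax the two high-scoring channels split the probability mass (about 0.49 each), so the pixel cannot confidently belong to both; with sigmoid each of them is about 0.88 on its own, which is what overlapping masks require.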

Change the input shape of the model:

# if the number of input channels is not 3, you have to set encoder_weights=None
# how to handle this case with encoder_weights='imagenet' is described in the docs
model = sm.Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
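One common way to keep encoder_weights='imagenet' with non-RGB input (and, as we understand it, the approach the docs describe) is to place a trainable 1x1 convolution in front of the pretrained model, mapping N input channels to the 3 the encoder expects; in Keras that would be a `Conv2D(3, (1, 1))` layer. The sketch below shows only the arithmetic of such a layer in plain Python; `conv1x1`, the weights and the pixel values are all hypothetical.

```python
from typing import List

Image = List[List[List[float]]]  # nested lists indexed as [row][col][channel]


def conv1x1(image: Image, weights: List[List[float]], bias: List[float]) -> Image:
    """Apply a 1x1 convolution: the same linear map at every pixel.

    weights[k][c] is the weight from input channel c to output channel k.
    """
    out: Image = []
    for row in image:
        out_row = []
        for pixel in row:
            # Each output channel is a weighted sum of this pixel's input channels.
            out_row.append([
                sum(w * x for w, x in zip(w_k, pixel)) + b_k
                for w_k, b_k in zip(weights, bias)
            ])
        out.append(out_row)
    return out


# A 1x2 image with 6 channels mapped to 3 channels.
# Hypothetical weights: output k averages input channels 2k and 2k+1.
image = [[[1, 3, 2, 4, 0, 6], [0, 0, 1, 1, 5, 5]]]
weights = [[0.5 if c // 2 == k else 0.0 for c in range(6)] for k in range(3)]
bias = [0.0, 0.0, 0.0]
rgb_like = conv1x1(image, weights, bias)
```

Here `rgb_like` is `[[[2.0, 3.0, 3.0], [0.0, 1.0, 5.0]]]`. In a real model these weights would be learned, so the network can discover which mix of the extra channels best matches what the ImageNet-pretrained encoder expects.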

Simple training pipeline

import segmentation_models as sm

BACKBONE = 'resnet34'
preprocess_input = sm.get_preprocessing(BACKBONE)

# load your data
x_train, y_train, x_val, y_val = load_data(...)

# preprocess input
x_train = preprocess_input(x_train)
x_val = preprocess_input(x_val)

# define model
model = sm.Unet(BACKBONE, encoder_weights='imagenet')
model.compile(
    'Adam',
    loss=sm.losses.bce_jaccard_loss,
    metrics=[sm.metrics.iou_score],
)

# fit model
# if you use a data generator, use model.fit_generator(...) instead of model.fit(...)
# (in recent TensorFlow 2.x releases, model.fit accepts generators directly and fit_generator is deprecated)
# more about `fit_generator` here: https://keras.io/models/sequential/#fit_generator
model.fit(
    x=x_train,
    y=y_train,
    batch_size=16,
    epochs=100,
    validation_data=(x_val, y_val),
)
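The loss and metric used above combine binary cross-entropy with the Jaccard index (IoU). The pure-Python sketch below shows the arithmetic on flattened masks. It assumes bce_jaccard_loss is binary cross-entropy plus (1 - soft IoU); the `SMOOTH` and `EPS` constants are illustrative choices, and the sketch ignores the library's batch and per-class averaging, so its numbers will not match `sm.losses` exactly.

```python
import math
from typing import Sequence

SMOOTH = 1e-5  # illustrative smoothing term so empty masks do not give 0/0
EPS = 1e-7     # illustrative clipping so log() stays finite


def soft_iou(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    # "Soft" IoU uses predicted probabilities directly, without thresholding.
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    union = sum(y_true) + sum(y_pred) - intersection
    return (intersection + SMOOTH) / (union + SMOOTH)


def jaccard_loss(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    # Perfect overlap gives IoU = 1, hence loss = 0.
    return 1.0 - soft_iou(y_true, y_pred)


def binary_crossentropy(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, EPS), 1.0 - EPS)  # clip to avoid log(0)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(y_true)


def bce_jaccard(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    # Sum of a per-pixel term (BCE) and a region-overlap term (Jaccard).
    return binary_crossentropy(y_true, y_pred) + jaccard_loss(y_true, y_pred)
```

For example, a prediction covering one of two foreground pixels, `soft_iou([1, 1, 0, 0], [1, 0, 0, 0])`, is about 0.5. Pairing the two terms is a common design: BCE gives smooth per-pixel gradients, while the Jaccard term directly rewards overlap and is less sensitive to class imbalance.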

The same steps work for Linknet, PSPNet and FPN. For more detailed information about the models API and use cases, see Read the Docs.

Examples

Model training examples:
  • [Jupyter Notebook] Binary segmentation (cars) on CamVid dataset here.
  • [Jupyter Notebook] Multi-class segmentation (cars, pedestrians) on CamVid dataset here.

Models and Backbones

Models

  • Unet
  • Linknet
  • PSPNet
  • FPN

Backbones

Type Names
VGG 'vgg16' 'vgg19'
ResNet 'resnet18' 'resnet34' 'resnet50' 'resnet101' 'resnet152'
SE-ResNet 'seresnet18' 'seresnet34' 'seresnet50' 'seresnet101' 'seresnet152'
ResNeXt 'resnext50' 'resnext101'
SE-ResNeXt 'seresnext50' 'seresnext101'
SENet154 'senet154'
DenseNet 'densenet121' 'densenet169' 'densenet201'
Inception 'inceptionv3' 'inceptionresnetv2'
MobileNet 'mobilenet' 'mobilenetv2'
EfficientNet 'efficientnetb0' 'efficientnetb1' 'efficientnetb2' 'efficientnetb3' 'efficientnetb4' 'efficientnetb5' 'efficientnetb6' 'efficientnetb7'
All backbones have weights trained on 2012 ILSVRC ImageNet dataset (encoder_weights='imagenet').

Installation

Requirements

  1. python 3
  2. keras >= 2.2.0 or tensorflow >= 1.13
  3. keras-applications >= 1.0.7, <=1.0.8
  4. image-classifiers == 1.0.*
  5. efficientnet == 1.0.*
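Requirements 3 and 4 are version ranges. As a sketch of what they mean, the hypothetical helpers below check them with tuple comparison; they handle plain dotted versions only. In practice the `packaging` library's `SpecifierSet` covers the full version syntax, and `importlib.metadata.version()` (Python 3.8+) reads an installed package's version.

```python
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    # Handles plain dotted releases such as "1.0.8"; pre-release tags are not supported.
    return tuple(int(part) for part in text.split("."))


def keras_applications_ok(version: str) -> bool:
    # Requirement 3: keras-applications >= 1.0.7, <= 1.0.8
    return (1, 0, 7) <= parse_version(version) <= (1, 0, 8)


def image_classifiers_ok(version: str) -> bool:
    # Requirement 4: image-classifiers == 1.0.* (major.minor must be 1.0)
    return parse_version(version)[:2] == (1, 0)
```

Tuples compare element by element, so `(1, 0, 9)` falls outside the keras-applications range even though it shares the `1.0` prefix.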

PyPI stable package

$ pip install -U segmentation-models

PyPI latest package (including pre-releases)

$ pip install -U --pre segmentation-models

Source latest version

$ pip install git+https://github.com/qubvel/segmentation_models

Documentation

The latest documentation is available on Read the Docs.

Change Log

To see important changes between versions, see CHANGELOG.md.

Citing

@misc{Yakubovskiy:2019,
  Author = {Pavel Yakubovskiy},
  Title = {Segmentation Models},
  Year = {2019},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models}}
}

License

The project is distributed under the MIT License.

Contributors

qubvel, habi, gazay, mathandy, btrotta, zcoder, hasibzunair, ilyaovodov, gagolucasm, sluki, tyler-d, chawater, 678098
