thohidu / pytorch-toolbelt

This project forked from bloodaxe/pytorch-toolbelt

PyTorch extensions for fast R&D prototyping and Kaggle farming

License: MIT License

Makefile 0.11% Python 99.89%

Pytorch-toolbelt

pytorch-toolbelt is a Python library with a set of bells and whistles for PyTorch, aimed at fast R&D prototyping and Kaggle farming.

What's inside

  • Easy model building using flexible encoder-decoder architecture.
  • Modules: CoordConv, SCSE, Hypercolumn, Depthwise separable convolution and more.
  • GPU-friendly test-time augmentation (TTA) for segmentation and classification
  • GPU-friendly inference on huge (5000x5000) images
  • Everyday common routines (fix/restore random seed, filesystem utils, metrics)
  • Losses: BinaryFocalLoss, Focal, ReducedFocal, Lovasz, Jaccard and Dice losses, Wing Loss and more.
  • Extras for Catalyst library (Visualization of batch predictions, additional metrics)

Showcase: Catalyst, Albumentations, Pytorch Toolbelt example: Semantic Segmentation @ CamVid

Why

The honest answer is "I needed a convenient way to re-use code for my Kaggle career". During 2018 I achieved a Kaggle Master badge, and this has been a long path. Very often I found myself re-using most of the old pipelines over and over again. At some point it crystallized into this repository.

This lib is not meant to replace catalyst / ignite / fast.ai. Instead, it's designed to complement them.

Installation

pip install pytorch_toolbelt

Showcase

Encoder-decoder models construction

from torch import nn

from pytorch_toolbelt.modules import encoders as E
from pytorch_toolbelt.modules import decoders as D


class FPNSegmentationModel(nn.Module):
    def __init__(self, encoder: E.EncoderModule, num_classes, fpn_features=128):
        super().__init__()  # must run before submodules are assigned
        self.encoder = encoder
        self.decoder = D.FPNDecoder(encoder.output_filters, fpn_features=fpn_features)
        self.fuse = D.FPNFuse()
        # Channel count seen by the final 1x1 convolution after fusing
        input_channels = sum(self.decoder.output_filters)
        self.logits = nn.Conv2d(input_channels, num_classes, kernel_size=1)

    def forward(self, input):
        features = self.encoder(input)
        features = self.decoder(features)
        features = self.fuse(features)
        logits = self.logits(features)
        return logits


def fpn_resnext50(num_classes):
    encoder = E.SEResNeXt50Encoder()
    return FPNSegmentationModel(encoder, num_classes)


def fpn_mobilenet(num_classes):
    encoder = E.MobilenetV2Encoder()
    return FPNSegmentationModel(encoder, num_classes)
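
The model above sizes its final convolution with `sum(self.decoder.output_filters)`, which suggests that the fuse step stacks all decoder outputs along the channel axis. The following is a minimal plain-PyTorch sketch of that idea, not the library's `FPNFuse` implementation: `fuse_by_concat` is a hypothetical helper, and nearest-neighbour resizing is an assumption made for illustration.

```python
import torch
import torch.nn.functional as F


def fuse_by_concat(feature_maps):
    """Resize every NCHW map to the first map's spatial size, then concatenate channels.

    Hypothetical helper for illustration; assumes the first map is the largest one.
    """
    target_size = tuple(feature_maps[0].shape[-2:])
    resized = [F.interpolate(fm, size=target_size, mode="nearest") for fm in feature_maps]
    return torch.cat(resized, dim=1)


# Three pyramid levels with 8 channels each, at decreasing resolution
pyramid = [torch.zeros(2, 8, 32, 32), torch.zeros(2, 8, 16, 16), torch.zeros(2, 8, 8, 8)]
fused = fuse_by_concat(pyramid)
```

With three 8-channel inputs the fused tensor has 24 channels, which is why the classifier head needs the summed channel count.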

Compose multiple losses

from pytorch_toolbelt import losses as L

loss = L.JointLoss(L.FocalLoss(), 1.0, L.LovaszLoss(), 0.5)
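
To illustrate what combining two losses with weights amounts to, here is a minimal plain-PyTorch sketch. `WeightedSumLoss` is a hypothetical name and is not the library's `JointLoss` implementation; the stock `nn.L1Loss` and `nn.MSELoss` stand in for the focal and Lovasz losses so the example runs without extra dependencies.

```python
import torch
from torch import nn


class WeightedSumLoss(nn.Module):
    """Hypothetical helper: w1 * first(pred, target) + w2 * second(pred, target)."""

    def __init__(self, first: nn.Module, first_weight: float,
                 second: nn.Module, second_weight: float):
        super().__init__()
        self.first = first
        self.second = second
        self.first_weight = first_weight
        self.second_weight = second_weight

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return (self.first_weight * self.first(prediction, target)
                + self.second_weight * self.second(prediction, target))


# Stand-in losses (assumption: any two losses sharing the (prediction, target) signature work)
criterion = WeightedSumLoss(nn.L1Loss(), 1.0, nn.MSELoss(), 0.5)
```

Because the result is an ordinary tensor, calling `.backward()` on it propagates gradients through both components in proportion to their weights.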

Test-time augmentation

from pytorch_toolbelt.inference import tta

# Truly functional TTA for image classification using horizontal flips:
logits = tta.fliplr_image2label(model, input)

# Truly functional TTA for image segmentation using D4 augmentation:
logits = tta.d4_image2mask(model, input)

# TTA using wrapper module:
tta_model = tta.TTAWrapper(model, tta.fivecrop_image2label, crop_size=512)
logits = tta_model(input)
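
For readers unfamiliar with the idea, flip-based TTA can be sketched in a few lines of plain PyTorch. The two functions below are hypothetical illustrations, not the library's code: they assume NCHW input and a model that returns logits, and they cover only the horizontal flip rather than the full D4 group.

```python
import torch


def hflip_tta_classification(model, image: torch.Tensor) -> torch.Tensor:
    """Average class logits over the input and its horizontal mirror."""
    logits = model(image)
    logits_flipped = model(torch.flip(image, dims=[3]))
    return 0.5 * (logits + logits_flipped)


def hflip_tta_segmentation(model, image: torch.Tensor) -> torch.Tensor:
    """Average mask logits; the mirrored prediction is flipped back before averaging."""
    logits = model(image)
    logits_flipped = torch.flip(model(torch.flip(image, dims=[3])), dims=[3])
    return 0.5 * (logits + logits_flipped)
```

The difference between the two is the de-augmentation step: class logits have no spatial layout, so they can be averaged directly, while a mask must be mapped back to the original orientation first.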

Inference on huge images:

import numpy as np
import torch
import cv2
from torch.utils.data import DataLoader

from pytorch_toolbelt.inference.tiles import ImageSlicer, CudaTileMerger
from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, to_numpy


image = cv2.imread('really_huge_image.jpg')
model = get_model(...)

# Cut large image into overlapping tiles
tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight='pyramid')

# HWC -> CHW. Optionally, do normalization here
tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)]

# Allocate a CUDA buffer to hold the entire mask
merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight)

# Run predictions for tiles and accumulate them
for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True):
    tiles_batch = tiles_batch.float().cuda()
    pred_batch = model(tiles_batch)

    merger.integrate_batch(pred_batch, coords_batch)

# Normalize accumulated mask and convert back to numpy
merged_mask = np.moveaxis(to_numpy(merger.merge()), 0, -1).astype(np.uint8)
merged_mask = tiler.crop_to_orignal_size(merged_mask)
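
The merge step in the pipeline above boils down to a weighted average over overlapping tiles. Here is a NumPy-only sketch of that idea under simplifying assumptions: a single-channel image, square tiles, and a tile step that covers the image exactly without padding. `split_into_tiles`, `pyramid_weight` and `merge_tiles` are hypothetical helpers, not the library's `ImageSlicer` or `CudaTileMerger`.

```python
import numpy as np


def split_into_tiles(image, tile_size, tile_step):
    """Cut an HxW image into overlapping square tiles and return them with their corners."""
    corners = [(y, x)
               for y in range(0, image.shape[0] - tile_size + 1, tile_step)
               for x in range(0, image.shape[1] - tile_size + 1, tile_step)]
    tiles = [image[y:y + tile_size, x:x + tile_size] for y, x in corners]
    return tiles, corners


def pyramid_weight(size):
    """Weights that grow from the tile border towards its centre (always positive)."""
    ramp = np.minimum(np.arange(1, size + 1), np.arange(size, 0, -1)).astype(np.float64)
    return np.minimum.outer(ramp, ramp)


def merge_tiles(tiles, corners, image_shape, tile_weight):
    """Accumulate weight * tile and the weights themselves, then normalize."""
    accumulated = np.zeros(image_shape, dtype=np.float64)
    weight_sum = np.zeros(image_shape, dtype=np.float64)
    for tile, (y, x) in zip(tiles, corners):
        h, w = tile.shape
        accumulated[y:y + h, x:x + w] += tile * tile_weight
        weight_sum[y:y + h, x:x + w] += tile_weight
    return accumulated / weight_sum


# Round trip with an identity "model": merging unmodified tiles restores the image
image = np.arange(64, dtype=np.float64).reshape(8, 8)
tiles, corners = split_into_tiles(image, tile_size=4, tile_step=2)
merged = merge_tiles(tiles, corners, image.shape, pyramid_weight(4))
```

Weighting tile centres more heavily than tile borders reduces visible seams, since predictions near a tile edge usually have the least context.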

Advanced examples

  1. Inria Satellite Segmentation
  2. CamVid Semantic Segmentation

Contributors

bloodaxe
