
giovanniguidi / deeplabv3-pytorch

Implementation of the DeepLabV3+ model in PyTorch for semantic segmentation, trained on the DeepFashion2 dataset

DeepLab V3+ Network for Semantic Segmentation

This project is based on DeepLabV3+, one of the state-of-the-art algorithms for semantic segmentation, developed by Google researchers (Chen et al. 2018, https://arxiv.org/abs/1802.02611). Semantic segmentation is the task of predicting a "semantic" label for each pixel of an image, such as tree, street, sky, car (and, of course, background).

Here the algorithm is applied to the DeepFashion2 dataset (Ge et al. 2019), one of the most popular datasets in fashion research. The dataset contains 491K images of 13 popular clothing categories, taken from both commercial shopping stores and consumers, with bounding-box annotations, and almost 185K images with segmentation annotations.

The DeepLabV3+ model has many components, but its key difference from other models is the use of "atrous" (dilated) convolutions in the encoder (already suggested in the first DeepLab model by Chen et al. 2016), arranged in a module called Atrous Spatial Pyramid Pooling (ASPP). ASPP consists of several atrous convolution layers applied in parallel, each with a different atrous rate, which lets the network capture information at multiple scales and extract denser feature maps (see the image below and the paper for details).
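The effect of the atrous rate can be made concrete with some shape arithmetic. The sketch below is plain Python, not code from this repository: it computes the effective kernel size of a k x k atrous convolution with rate r, k + (k - 1)(r - 1), and shows that padding equal to the rate keeps the feature map size unchanged, which is what allows the parallel ASPP branches to be concatenated. The 33x33 feature map is hypothetical, and the rates 6, 12, 18 are the values used in the DeepLabV3 paper for output stride 16; check the model code for the values actually used here.

```python
# Sketch (not from the repository): shape arithmetic for 3x3 atrous
# (dilated) convolutions such as those in the ASPP branches.

def effective_kernel_size(kernel_size: int, rate: int) -> int:
    """Span covered by a kernel whose taps are spaced `rate` pixels apart."""
    return kernel_size + (kernel_size - 1) * (rate - 1)


def conv_output_size(in_size: int, kernel_size: int, rate: int,
                     padding: int, stride: int = 1) -> int:
    """Standard convolution output-size formula, using the dilated kernel span."""
    k_eff = effective_kernel_size(kernel_size, rate)
    return (in_size + 2 * padding - k_eff) // stride + 1


if __name__ == "__main__":
    feature_size = 33  # hypothetical encoder feature-map size
    for rate in (1, 6, 12, 18):  # assumed ASPP rates (DeepLabV3 paper, output stride 16)
        k_eff = effective_kernel_size(3, rate)
        # padding = rate keeps the spatial size, so all branches line up
        out = conv_output_size(feature_size, 3, rate, padding=rate)
        print(f"rate={rate:2d}: effective kernel {k_eff}x{k_eff}, output {out}x{out}")
```

Larger rates enlarge the field of view (from 3x3 up to 37x37 here) without adding weights or shrinking the output, which is why the branches can see context at different scales while producing maps of the same size.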

Fig. 1: DeepLabV3+ model (source Chen et al. 2018)

Virtual environment

First you need to create a virtual environment.

Using Conda you can type:

conda create --name deeplab python=3.7.1
conda activate deeplab

Dependencies

This project is based on the PyTorch Deep Learning library.

Install the dependencies by:

pip install -r requirements.txt 

Dataset

Download the dataset from:

https://github.com/switchablenorms/DeepFashion2

Before using the data, you need to convert the labels into a format that a semantic segmentation algorithm can consume directly. Here we use .png images in which the value of each pixel is the clothing class (from 1 to 13), although other choices are possible. The background class has value 0.

You can write a script that converts the polygons in the DeepFashion2 annotations into this format, or download the converted labels from:

https://drive.google.com/drive/folders/1O8KLZa1AABlLS6DlkkzHOgPqvT89GB_9?usp=sharing

This folder also contains the train/val/test split JSON files, in case you want to use the same split I used.
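If you prefer to write the conversion script yourself, the following is a minimal, dependency-free sketch; it is not the script used to produce the labels on the Drive. It assumes the DeepFashion2 per-image annotation layout (keys item1, item2, ..., each with a category_id from 1 to 13 and a segmentation list of flat [x1, y1, x2, y2, ...] polygons), fills each polygon with its category id, leaves the background at 0, and lets later items overwrite earlier ones where they overlap. All function names are hypothetical.

```python
import json
import struct
import zlib


def point_in_polygon(x, y, poly):
    """Even-odd ray casting against a flat [x1, y1, x2, y2, ...] polygon."""
    xs, ys = poly[0::2], poly[1::2]
    inside = False
    j = len(xs) - 1
    for i in range(len(xs)):
        # does edge (j -> i) straddle the horizontal line through y?
        if (ys[i] > y) != (ys[j] > y):
            x_cross = xs[i] + (y - ys[i]) * (xs[j] - xs[i]) / (ys[j] - ys[i])
            if x < x_cross:
                inside = not inside
        j = i
    return inside


def annotation_to_mask(anno, width, height):
    """Rasterise all item polygons into a height x width list of class ids (0 = background)."""
    mask = [[0] * width for _ in range(height)]
    # assumed keys "item1", "item2", ...; other keys (e.g. "source") are skipped
    item_keys = sorted((k for k in anno if k.startswith("item")), key=lambda k: int(k[4:]))
    for key in item_keys:
        category = anno[key]["category_id"]
        for poly in anno[key]["segmentation"]:
            if len(poly) < 6:  # fewer than 3 vertices: nothing to fill
                continue
            for row in range(height):
                for col in range(width):
                    # sample at the pixel centre; later items overwrite earlier ones
                    if point_in_polygon(col + 0.5, row + 0.5, poly):
                        mask[row][col] = category
    return mask


def _png_chunk(tag, data):
    """Length + tag + data + CRC-32 over tag and data, as the PNG format requires."""
    crc = zlib.crc32(tag + data) & 0xFFFFFFFF
    return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)


def encode_grayscale_png(mask):
    """Encode a 2D list of values 0-255 as an 8-bit, single-channel PNG."""
    height, width = len(mask), len(mask[0])
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 0, 0, 0, 0)  # 8-bit grayscale
    raw = b"".join(b"\x00" + bytes(row) for row in mask)  # filter type 0 per scanline
    return (b"\x89PNG\r\n\x1a\n" + _png_chunk(b"IHDR", ihdr)
            + _png_chunk(b"IDAT", zlib.compress(raw)) + _png_chunk(b"IEND", b""))


def convert(json_path, png_path, width, height):
    """Hypothetical driver: one annotation JSON in, one label PNG out."""
    with open(json_path) as f:
        anno = json.load(f)
    with open(png_path, "wb") as f:
        f.write(encode_grayscale_png(annotation_to_mask(anno, width, height)))
```

The image width and height are passed in explicitly (for example, read from the corresponding .jpg). The pixel-by-pixel loop only makes the logic explicit and is slow on full-size images; a polygon-filling routine such as PIL's ImageDraw.Draw.polygon or OpenCV's cv2.fillPoly does the same job much faster.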

Parameters

All the parameters of the model are in configs/config.yml.

Weights

The trained weights can be found here:

https://drive.google.com/drive/folders/1O8KLZa1AABlLS6DlkkzHOgPqvT89GB_9?usp=sharing

The model can be trained with different backbones (resnet, xception, drn, mobilenet). The weights on the Drive were trained with the ResNet backbone, so to use another backbone you need to train from scratch (although the backbone weights are always pre-trained on ImageNet).

Train

To train a model run:

python main.py -c configs/config.yml --train

To resume training after an interruption, set "weights_initialization" to "true" in config.yml.

During training, the best and the last snapshots can be saved if you enable the corresponding options in the "training" section of config.yml.

Inference

To predict on the full test set and compute the metrics, run:

python main.py -c configs/config.yml --predict_on_test

In "./test_images/" there are some images that can be used for testing the model. To predict on a single image you can run:

python main.py -c configs/config.yml --predict --filename test_images/068834.jpg

You can also check the "inference.ipynb" notebook to visually assess the predictions.

Results

Here is an example of the results:

Fig. 2: Prediction on DeepFashion2

On the test set we get these metrics:

accuracy: 0.84
accuracy per class: 0.47
mean IoU: 0.34
freq weighted IoU: 0.79
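The four numbers above are standard confusion-matrix metrics. As a sketch, assuming the common definitions (pixel accuracy, mean per-class accuracy, mean IoU and frequency-weighted IoU, with classes absent from the ground truth skipped; the repository's evaluator may treat such details differently), they can be computed in plain Python from flattened ground-truth and predicted label sequences:

```python
def confusion_matrix(gt, pred, num_classes):
    """cm[i][j] = number of pixels with ground-truth class i predicted as class j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for g, p in zip(gt, pred):
        cm[g][p] += 1
    return cm


def segmentation_metrics(cm):
    n = len(cm)
    total = sum(sum(row) for row in cm)
    tp = [cm[i][i] for i in range(n)]                                  # correctly labelled pixels
    gt_count = [sum(cm[i]) for i in range(n)]                          # pixels of class i in ground truth
    pred_count = [sum(cm[r][i] for r in range(n)) for i in range(n)]   # pixels predicted as class i
    present = [i for i in range(n) if gt_count[i] > 0]                 # skip classes absent from ground truth
    iou = {i: tp[i] / (gt_count[i] + pred_count[i] - tp[i]) for i in present}
    return {
        "accuracy": sum(tp) / total,
        "accuracy per class": sum(tp[i] / gt_count[i] for i in present) / len(present),
        "mean IoU": sum(iou.values()) / len(present),
        "freq weighted IoU": sum(gt_count[i] / total * iou[i] for i in present),
    }
```

Because frequency-weighted IoU weights each class by its share of pixels, the large gap between it (0.79) and mean IoU (0.34) is consistent with frequent classes such as background being segmented well while rarer clothing classes are segmented much less accurately.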

Train on other data

This implementation can easily be used on other datasets. The model expects .jpg images as input and labels in .png format with a single channel (i.e. shape = (y_size, x_size)), where each pixel value is the target class. In principle, you only need to modify the data_generator.
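A minimal sketch of the pairing and label checks a custom data_generator would need, assuming images and labels live in two folders and share file stems (e.g. 000001.jpg and 000001.png). The helper names are hypothetical and not part of the repository; decoding the .png into a 2D array (with PIL or OpenCV, for instance) is left out so the example stays dependency-free.

```python
from pathlib import Path


def pair_images_and_labels(image_dir, label_dir):
    """Return (image, label) path pairs, matching foo.jpg with foo.png."""
    pairs = []
    for image_path in sorted(Path(image_dir).glob("*.jpg")):
        label_path = Path(label_dir) / (image_path.stem + ".png")
        if not label_path.exists():
            raise FileNotFoundError(f"no label for {image_path.name}: expected {label_path}")
        pairs.append((image_path, label_path))
    return pairs


def check_label(mask, num_classes):
    """Check that a decoded label is 2D (y_size x x_size) with integer class ids in [0, num_classes).

    num_classes counts the background class 0, e.g. 14 for DeepFashion2
    (13 clothing classes + background).
    """
    if not mask or not all(isinstance(row, (list, tuple)) for row in mask):
        return False
    width = len(mask[0])
    if width == 0 or any(len(row) != width for row in mask):
        return False
    # a 3-channel label would show up here as per-pixel lists instead of ints
    return all(isinstance(v, int) and 0 <= v < num_classes for row in mask for v in row)
```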

References

[1] Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation

[2] Rethinking Atrous Convolution for Semantic Image Segmentation

deeplabv3-pytorch's Issues

RuntimeError: Given groups=1

An error occurred while I was trying to train on the VOC2007 dataset:
RuntimeError: Given groups=1, weight of size [256, 304, 3, 3], expected input[1, 2096, 129, 129] to have 304 channels, but got 2096 channels instead.
The error code is:
decoder.py, line 45:

    self.last_conv = nn.Sequential(
        nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
        BatchNorm(256),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
        BatchNorm(256),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Conv2d(256, num_classes, kernel_size=1, stride=1))
    ........
    x = self.last_conv(x)
Could you tell me where the 304 input channels come from?
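For reference, the 304 follows from the decoder's concatenation step, assuming the standard DeepLabV3+ settings with a ResNet backbone (an assumption; these constants are not read from this repository's code): the ASPP output is projected to 256 channels, the low-level features are reduced to 48 channels by a 1x1 convolution, and the two are concatenated before last_conv.

```python
# Assumed standard DeepLabV3+ / ResNet settings, not read from this repository.
ASPP_OUT_CHANNELS = 256        # ASPP branches are concatenated, then projected to 256
LOW_LEVEL_REDUCED = 48         # 1x1 conv applied to the low-level (layer1) features
RESNET_LAYER4_CHANNELS = 2048  # channels of the final ResNet stage


def decoder_input_channels(high_level_channels, low_level_channels=LOW_LEVEL_REDUCED):
    """Channels seen by last_conv: upsampled high-level features concatenated with reduced low-level ones."""
    return high_level_channels + low_level_channels


expected = decoder_input_channels(ASPP_OUT_CHANNELS)       # what last_conv is built for
observed = decoder_input_channels(RESNET_LAYER4_CHANNELS)  # what the error message reports
print(expected, observed)
```

The observed 2096 equals 2048 + 48, i.e. the channel count of the final ResNet stage plus the reduced low-level features. This is consistent with the raw backbone output reaching the decoder without passing through ASPP, so that path is a reasonable first place to look.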
