This project forked from phixerino/cloud-segmentation

Cloud segmentation on satellite images from Sentinel-2 Cloud Mask Catalogue. Implemented in PyTorch and exported to ONNX for inference.

cloud-segmentation

Cloud segmentation on satellite images from Sentinel-2 Cloud Mask Catalogue, written in PyTorch. The DataLoader is implemented for binary segmentation with two classes: CLOUD and CLEAR/CLOUD_SHADOW. By default the repo assumes 4 input channels (RGB + NIR), but this can be changed through command-line arguments or the config.

Install

pip install -r requirements.txt

or build the Docker image and run the container with:

make build
make

You might have to change the base image to match your CUDA and cuDNN versions. If you want to train, also change LOCAL_DATASET in the Makefile to your dataset path.

Inference

Inference is accelerated with ONNX Runtime. You can download the model here and then run inference:

python inference.py --model DeepLabV3Plus_resnet101_1678896432.onnx --source data/test_subscene.npy --save --out_folder data/preds/

Predicted masks are saved in the folder given by --out_folder. You can also plot the masks with --show.

A short example of running inference in Google Colab is in notebooks/sentinel_segmentation_inference.ipynb.

Training

TODO: Custom dataset

Sentinel-2 Cloud Mask Catalogue

Download and unzip subscenes.zip and masks.zip to your dataset path. This path can be changed in cfg/config.json under dataset_path.

Data loading

Images can be loaded with different tiling strategies by changing these settings:

  • subscene_width, subscene_height - manually resize subscenes and masks before tiling
  • train_tile_stride_x, train_tile_stride_y, val_tile_stride_x, val_tile_stride_y - a stride smaller than the tile width/height makes neighbouring tiles overlap
  • train_scale, val_scale - automatically scale subscenes and masks so that tiles fit the entire image. This is done after manual resizing with subscene_width/subscene_height. Allowed values are [None, 'down', 'up']
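How tile size, stride, and scaling interact can be sketched in a few lines. This is only an illustration, not the repository's implementation: the function names tile_origins and fit_length, the 256-pixel tile, and the rounding rule are all assumptions. In particular, it assumes that "fit" means choosing an image length at which tiles placed at the given stride end exactly on the image border.

```python
def tile_origins(length, tile, stride):
    """Start offsets of tiles along one axis; tiles that would overrun are dropped."""
    if length < tile:
        return []
    return list(range(0, length - tile + 1, stride))


def fit_length(length, tile, stride, mode=None):
    """Length to resize an axis to so that tiles cover it exactly.

    Assumed semantics: None keeps the length, 'down' shrinks to the nearest
    fitting length, 'up' grows to the nearest fitting length.
    """
    if mode is None:
        return length
    if mode not in ("down", "up"):
        raise ValueError("mode must be None, 'down' or 'up'")
    if length <= tile:
        # An axis shorter than one tile is resized to exactly one tile.
        return tile
    remainder = (length - tile) % stride
    if remainder == 0:
        return length
    if mode == "down":
        return length - remainder
    return length + (stride - remainder)


# Hypothetical numbers: a 1022-pixel axis, 256-pixel tiles, stride 128 (50% overlap).
print(fit_length(1022, 256, 128, "down"))  # 896
print(fit_length(1022, 256, 128, "up"))    # 1024
print(tile_origins(1024, 256, 128))        # [0, 128, 256, 384, 512, 640, 768]
```

Without scaling, the same axis yields tiles starting at 0 through 640, and the last 126 pixels are not covered by any tile.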

Train

Modify cfg/config.json to your liking and run:

python train.py

You can also override any setting through command-line arguments without modifying the config file, for example:

python train.py --epochs 100 --batch_size 128 --lr 0.01 --optimizer AdamW --scheduler cos --warmup_epochs 5 --decoder_name UnetPlusPlus --encoder_name resnet50 --loss CE --no_wandb_log
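One common way to layer command-line overrides on top of a JSON config is sketched below. It does not show how train.py actually does it. The helper name load_settings is hypothetical, and the config keys are assumed to match the flag names. Only int, float, and string settings are handled; boolean flags such as --no_wandb_log would need separate treatment.

```python
import argparse
import json


def load_settings(config_text, argv):
    """Merge JSON config values with command-line overrides (hypothetical helper)."""
    settings = json.loads(config_text)
    parser = argparse.ArgumentParser()
    for key, value in settings.items():
        if isinstance(value, bool) or not isinstance(value, (int, float, str)):
            continue  # booleans, lists, and nulls are skipped in this sketch
        # A default of None marks "not given on the command line".
        parser.add_argument(f"--{key}", type=type(value), default=None)
    overrides = vars(parser.parse_args(argv))
    settings.update({k: v for k, v in overrides.items() if v is not None})
    return settings


# Assumed config content; a real script would read cfg/config.json and sys.argv[1:].
config_text = '{"epochs": 50, "lr": 0.0005, "optimizer": "AdamW"}'
print(load_settings(config_text, ["--epochs", "100", "--lr", "0.01"]))
# {'epochs': 100, 'lr': 0.01, 'optimizer': 'AdamW'}
```

Settings not passed on the command line keep their config-file values, so the config file stays the single source of defaults.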

To train on multiple GPUs, launch the script with torchrun to enable Distributed Data Parallel (DDP); --nproc_per_node sets the number of GPUs:

torchrun --standalone --nproc_per_node 2 train.py
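torchrun starts one process per GPU and tells each process its position through the environment variables RANK, LOCAL_RANK, and WORLD_SIZE. The sketch below reads them. The helper name distributed_context is hypothetical, and how train.py consumes these values is not shown in this README.

```python
import os


def distributed_context(env=None):
    """Read the process layout that torchrun exports as environment variables."""
    env = os.environ if env is None else env
    world_size = int(env.get("WORLD_SIZE", 1))
    return {
        # A plain `python train.py` run sets none of these, so it counts as 1 process.
        "distributed": world_size > 1,
        # Global index of this process; rank 0 typically handles logging and checkpoints.
        "rank": int(env.get("RANK", 0)),
        # Index of this process on its machine; typically selects the GPU.
        "local_rank": int(env.get("LOCAL_RANK", 0)),
        "world_size": world_size,
    }


# Simulated environment of the second process under --nproc_per_node 2.
print(distributed_context({"WORLD_SIZE": "2", "RANK": "1", "LOCAL_RANK": "1"}))
# {'distributed': True, 'rank': 1, 'local_rank': 1, 'world_size': 2}
```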

Example of training progress:

Results

Model                        mIoU
DeepLabV3+ with ResNet101    86.43
Unet++ with ResNet101        84.55
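For reference, one common definition of mean IoU (mIoU) is sketched below: compute intersection over union separately for each class, then average. Whether the repository averages over both classes in exactly this way is an assumption, and the function name mean_iou is hypothetical. The values in the table appear to be percentages.

```python
def mean_iou(pred, target, num_classes=2):
    """Mean of per-class IoU over flat sequences of integer class labels."""
    ious = []
    for cls in range(num_classes):
        intersection = sum(1 for p, t in zip(pred, target) if p == cls and t == cls)
        union = sum(1 for p, t in zip(pred, target) if p == cls or t == cls)
        if union > 0:  # skip classes absent from both prediction and target
            ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else float("nan")


# Toy example with 1 = CLOUD and 0 = CLEAR/CLOUD_SHADOW.
pred = [1, 1, 0, 0, 1, 0]
target = [1, 0, 0, 0, 1, 1]
print(mean_iou(pred, target))  # 0.5
```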

The DeepLabV3+ model with ResNet101 encoder was trained with these settings:

  • encoder pretrained on ImageNet
  • 9:1 train/val split
  • 50 epochs
  • 20 epochs early stop
  • 64 batch size
  • AdamW optimizer
  • 0.0005 learning rate
  • 0.005 weight decay
  • 3 linear warmup epochs
  • linear lr scheduler
  • Dice loss
  • mean IoU val metric
  • augmentations: rotation (max 60 degrees, probability 0.5), horizontal flip (probability 0.5), vertical flip (probability 0.5)
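The learning-rate settings above (0.0005 base rate, 3 linear warmup epochs, linear scheduler, 50 epochs) can be read as the schedule sketched below. The exact shape is an assumption: this version ramps up to the base rate by the last warmup epoch, then decays linearly toward zero. The repository's scheduler may differ in its endpoints and may step per iteration rather than per epoch. The function name lr_at_epoch is hypothetical.

```python
def lr_at_epoch(epoch, base_lr=0.0005, warmup_epochs=3, total_epochs=50):
    """Learning rate for a zero-based epoch index under warmup plus linear decay."""
    if epoch < warmup_epochs:
        # Ramp linearly from base_lr / warmup_epochs up to base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Decay linearly from base_lr toward zero over the remaining epochs.
    remaining = total_epochs - warmup_epochs
    return base_lr * (total_epochs - epoch) / remaining


for epoch in (0, 1, 2, 3, 25, 49):
    print(epoch, lr_at_epoch(epoch))
```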

TODO

  • Multi-class segmentation
  • Resume training
  • Automated hyperparameter tuning
  • Add more metrics

Export

After training, the PyTorch model can be exported to ONNX with:

python export.py --model_file weights/my_model.pt

Contributors

  • phixerino
