mrsebai / aerial-tile-segmentation

Large satellite image semantic segmentation into 6 classes using Tensorflow 2.0 and the ISPRS benchmark dataset.

tensorflow2 semantic-segmentation satellite-images deep-learning u-net potsdam

aerial-tile-segmentation's Introduction

Semantic Segmentation of Satellite Images using Tensorflow 2.0

This repo details the steps carried out to perform a semantic segmentation task on satellite and/or aerial images (a.k.a. tiles). A Tensorflow 2.0 deep learning model is trained on the ISPRS dataset. The dataset contains 38 patches (of 6000x6000 pixels), each consisting of a true orthophoto (TOP) extracted from a larger TOP mosaic.

Each tile is paired with a reference segmentation mask depicting the 6 classes in different colors (see the example image below).

Development Environment

Tools and libraries:

  • Python 3.5
  • Imageio 2.6
  • Deep Learning Libraries:
      • Low-Level API: Tensorflow 2.0 (with eager execution enabled)
      • High-Level API: Keras 2.2
      • Input pipeline API: tf.data
      • Monitoring API: TensorBoard

Infrastructure:

  • 16-Core, 64GB RAM
  • Nvidia 16GB GPU (Tesla P100)
  • VM instance on GCP

Patch extraction and data augmentation using tf.data input pipeline

Due to the size of the tiles (6000x6000 pixels), it is not possible to feed them directly to the Tensorflow model, whose image input size is limited to 256x256 pixels. It is therefore crucial to build an efficient and flexible input pipeline that reads the tile file, extracts smaller patches, and applies data augmentation, while remaining fast enough to avoid starving the model sitting on the GPU during training. Fortunately, Tensorflow's tf.data allows the building of such a pipeline. The tile and its corresponding reference mask are processed in parallel, and the resulting smaller patches look like the grid shown in the example figure.
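
A minimal sketch of such a pipeline is given below. It assumes the tiles and masks are already loaded as tensors of shape (H, W, channels) with statically known channel counts (masks as (H, W, 1) integer maps); the helper names, grid layout, and augmentation choices are illustrative, not the repo's exact code.

```python
import tensorflow as tf

PATCH = 256                                         # model input size
AUTOTUNE = tf.data.experimental.AUTOTUNE

def extract_patches(tile, mask):
    """Split a tile and its mask into a grid of PATCH x PATCH pairs."""
    def grid(img):
        p = tf.image.extract_patches(images=img[tf.newaxis],
                                     sizes=[1, PATCH, PATCH, 1],
                                     strides=[1, PATCH, PATCH, 1],
                                     rates=[1, 1, 1, 1],
                                     padding="VALID")
        return tf.reshape(p, [-1, PATCH, PATCH, img.shape[-1]])
    return grid(tile), grid(mask)

def augment(patch, mask):
    """Apply the same random left-right flip to the patch and its mask."""
    flip = tf.random.uniform([]) > 0.5
    patch = tf.cond(flip, lambda: tf.image.flip_left_right(patch), lambda: patch)
    mask = tf.cond(flip, lambda: tf.image.flip_left_right(mask), lambda: mask)
    return patch, mask

# tile_ds is assumed to yield (tile, mask) pairs of full-size images
train_ds = (tile_ds
            .map(extract_patches, num_parallel_calls=AUTOTUNE)
            .unbatch()                               # one element per patch
            .map(augment, num_parallel_calls=AUTOTUNE)
            .shuffle(1024)
            .batch(16)
            .prefetch(AUTOTUNE))
```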

Tensorflow model architecture

The model is based on the U-Net convolutional neural network, enhanced with skip connections and residual blocks borrowed from Residual Neural Networks (ResNets), which improve the flow of the gradient during backpropagation. The Keras functional API was used to implement the model.
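
The sketch below shows the general idea in the Keras functional API: a residual convolution block plus a U-Net style skip connection. The filter counts and depth are placeholders, not the repo's exact architecture.

```python
from tensorflow.keras import layers, Model, Input

def residual_block(x, filters):
    """Two 3x3 convolutions with a ResNet-style shortcut connection."""
    shortcut = layers.Conv2D(filters, 1, padding="same")(x)   # match channel count
    y = layers.Conv2D(filters, 3, padding="same", activation="relu")(x)
    y = layers.Conv2D(filters, 3, padding="same")(y)
    y = layers.Add()([shortcut, y])                           # residual connection
    return layers.Activation("relu")(y)

inputs = Input(shape=(256, 256, 3))
e1 = residual_block(inputs, 32)                 # encoder stage
p1 = layers.MaxPooling2D()(e1)
b  = residual_block(p1, 64)                     # bottleneck (truncated for brevity)
u1 = layers.UpSampling2D()(b)
u1 = layers.Concatenate()([u1, e1])             # U-Net skip connection
d1 = residual_block(u1, 32)                     # decoder stage
outputs = layers.Conv2D(6, 1, activation="softmax")(d1)   # 6 classes
model = Model(inputs, outputs)
```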

Model training

We experimented with several loss functions drawn from recent AI literature.

In addition, we adopted the learning rate finder to spot the optimal learning rate for a chosen loss function. The finding process produces a loss curve showing the learning-rate sweet spot that should be picked (right before the global minimum) for optimal training.
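
A minimal learning-rate-finder sketch is shown below: a Keras callback that sweeps the learning rate exponentially over a few hundred batches and records the loss so the sweet spot can be read off the curve. The range and number of steps are assumptions, not the repo's settings.

```python
import numpy as np
import tensorflow as tf

class LRFinder(tf.keras.callbacks.Callback):
    """Sweep the learning rate exponentially and record (lr, loss) pairs."""
    def __init__(self, min_lr=1e-6, max_lr=1.0, steps=300):
        super().__init__()
        self.lrs = np.geomspace(min_lr, max_lr, steps)   # exponential LR schedule
        self.history = []                                # (lr, loss) pairs to plot

    def on_train_batch_begin(self, batch, logs=None):
        if batch < len(self.lrs):
            tf.keras.backend.set_value(self.model.optimizer.lr, self.lrs[batch])

    def on_train_batch_end(self, batch, logs=None):
        if batch < len(self.lrs):
            self.history.append((self.lrs[batch], logs["loss"]))

# usage sketch: run one short epoch, then plot finder.history on a log-x axis
# finder = LRFinder()
# model.fit(train_ds, epochs=1, steps_per_epoch=300, callbacks=[finder])
```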

Once the optimal learning rate is found, training is performed using the one-cycle policy strategy. The curves below depict the evolution of the learning rate and the SGD momentum during training.
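
Below is a hedged sketch of such a schedule as a Keras callback, assuming a plain SGD optimizer compiled with momentum; the exact bounds and phase shape used in the repo may differ.

```python
import tensorflow as tf

class OneCycle(tf.keras.callbacks.Callback):
    """Ramp the learning rate up then down, with momentum moving inversely."""
    def __init__(self, total_steps, min_lr=1e-4, max_lr=1e-2,
                 min_mom=0.85, max_mom=0.95):
        super().__init__()
        self.total_steps = total_steps
        self.min_lr, self.max_lr = min_lr, max_lr
        self.min_mom, self.max_mom = min_mom, max_mom
        self.step = 0

    def on_train_batch_begin(self, batch, logs=None):
        half = self.total_steps / 2
        up = self.step < half                     # first half: LR rises, momentum falls
        t = (self.step if up else self.step - half) / half
        frac = t if up else 1.0 - t               # position within the current phase
        lr = self.min_lr + (self.max_lr - self.min_lr) * frac
        mom = self.max_mom - (self.max_mom - self.min_mom) * frac
        tf.keras.backend.set_value(self.model.optimizer.lr, lr)
        tf.keras.backend.set_value(self.model.optimizer.momentum, mom)
        self.step = min(self.step + 1, self.total_steps - 1)

# usage sketch (assumes SGD with momentum and known STEPS_PER_EPOCH):
# model.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.95), loss=..., metrics=[...])
# model.fit(train_ds, epochs=EPOCHS,
#           callbacks=[OneCycle(total_steps=EPOCHS * STEPS_PER_EPOCH)])
```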

Naturally, during training we monitor the performance metrics (accuracy, IoU, and the loss function) as shown below. Training is halted by the early stopping strategy once the performance metrics stagnate.
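
A sketch of the corresponding training call is shown below; the optimizer settings, the patience value, and the train_ds/valid_ds names are assumptions, and the small MeanIoU subclass simply converts softmax outputs to class indices before updating the metric.

```python
import tensorflow as tf

class SparseMeanIoU(tf.keras.metrics.MeanIoU):
    """MeanIoU variant that accepts softmax outputs and integer labels."""
    def update_state(self, y_true, y_pred, sample_weight=None):
        return super().update_state(y_true, tf.argmax(y_pred, axis=-1), sample_weight)

model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=1e-3, momentum=0.9),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", SparseMeanIoU(num_classes=6, name="iou")])

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir="logs"),          # loss/metric curves
    tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10,
                                     restore_best_weights=True),
]

model.fit(train_ds, validation_data=valid_ds, epochs=100, callbacks=callbacks)
```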

Best Model Performance Metrics

The model performance measured on the validation dataset is strong, especially on the Building, Road, and Car classes (IoU > 0.8). Below are the confusion matrix and the per-class IoU metrics, along with some reference visuals for the IoU metric.
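
For reference, per-class IoU can be read directly off an accumulated confusion matrix whose rows are true labels and columns are predictions (a generic sketch, not necessarily the repo's implementation): IoU_c = TP_c / (TP_c + FP_c + FN_c).

```python
import tensorflow as tf

def per_class_iou(total_cm):
    """Per-class IoU from a confusion matrix (rows: truth, columns: prediction)."""
    cm = tf.cast(total_cm, tf.float32)
    tp = tf.linalg.diag_part(cm)             # correctly classified pixels per class
    fp = tf.reduce_sum(cm, axis=0) - tp      # predicted as the class but actually other
    fn = tf.reduce_sum(cm, axis=1) - tp      # belonging to the class but missed
    return tp / (tp + fp + fn + 1e-7)        # one IoU value per class
```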

Tile Prediction using Test-Time Augmentation

Applying the inference pipeline to a new tile of the same size (6000x6000) would be slow if we looped over the tile to extract the patches, made a batch prediction, and stitched the patches back together to reconstruct the tile. Fortunately, we can perform such inference without any loop thanks to a clever tile reconstruction trick using Tensorflow's tf.scatter_nd. Inference time on a tile is reduced from minutes to seconds.
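
The sketch below illustrates the idea, assuming the predicted patches come from a regular row-major grid and the tile side is a multiple of the patch size (e.g. after padding); it is not the repo's exact code.

```python
import tensorflow as tf

def reconstruct_tile(pred_patches, tile_size, patch_size):
    """Scatter a row-major batch of predicted class maps back into one tile.

    pred_patches: (n_patches, patch_size, patch_size) integer class maps, ordered
    row by row over the patch grid; tile_size must be a multiple of patch_size
    (pad the 6000x6000 tile beforehand if necessary).
    """
    n = tile_size // patch_size
    py, px = tf.meshgrid(tf.range(n), tf.range(n), indexing="ij")   # patch grid coords
    offy = tf.reshape(py * patch_size, [-1, 1, 1])                  # top-left row of each patch
    offx = tf.reshape(px * patch_size, [-1, 1, 1])                  # top-left column
    yy, xx = tf.meshgrid(tf.range(patch_size), tf.range(patch_size), indexing="ij")
    indices = tf.stack([offy + yy, offx + xx], axis=-1)             # absolute pixel coordinates
    return tf.scatter_nd(indices, pred_patches, [tile_size, tile_size])

# usage sketch (padded tile of 6144 = 24 * 256 pixels, an illustrative choice):
# preds = tf.argmax(model.predict(patches), axis=-1)
# tile_mask = reconstruct_tile(preds, tile_size=6144, patch_size=256)
```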

In addition, once inference has been performed on a tile, the test-time augmentation technique improves the prediction quality by several points, as shown below (a sketch of the approach follows the comparison):

  • Without test-time augmentation

  • With test-time augmentation
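
A simple sketch of the test-time augmentation step (the flip set and the averaging scheme are assumptions): predict on flipped copies of the patches, map the softmax outputs back, and average before taking the argmax.

```python
import tensorflow as tf

def predict_tta(model, patches):
    """Average softmax predictions over flipped copies of the input patches."""
    transforms = [
        (lambda x: x,              lambda p: p),               # identity
        (tf.image.flip_left_right, tf.image.flip_left_right),  # horizontal flip
        (tf.image.flip_up_down,    tf.image.flip_up_down),     # vertical flip
    ]
    probs = []
    for forward, inverse in transforms:
        p = model(forward(patches), training=False)   # softmax maps for augmented input
        probs.append(inverse(p))                      # undo the flip on the prediction
    return tf.argmax(tf.add_n(probs) / len(probs), axis=-1)
```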

Implementation and Report

aerial-tile-segmentation's Issues

How to combine CE loss with Jaccard loss?

Hi,
with some experiments, I found that CE loss and Jaccard loss both have their own advantages, and they are used separately in the code. Did you ever consider combining them? Would this improve the IoU score compared to using CE loss or Jaccard loss alone?
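
For reference, one straightforward way to combine them (a sketch, not the repo's code) is a weighted sum of the sparse cross-entropy and a soft Jaccard loss computed on the softmax output, assuming integer labels of shape (batch, H, W):

```python
import tensorflow as tf

def ce_jaccard_loss(y_true, y_pred, alpha=0.5, num_classes=6, smooth=1e-6):
    """Weighted sum of sparse cross-entropy and soft Jaccard loss.

    Assumes y_true holds integer labels of shape (batch, H, W) and y_pred holds
    softmax probabilities of shape (batch, H, W, num_classes).
    """
    ce = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred))
    y_true_oh = tf.one_hot(tf.cast(y_true, tf.int32), num_classes)
    intersection = tf.reduce_sum(y_true_oh * y_pred, axis=[1, 2])
    union = tf.reduce_sum(y_true_oh + y_pred, axis=[1, 2]) - intersection
    jaccard = 1.0 - tf.reduce_mean((intersection + smooth) / (union + smooth))
    return alpha * ce + (1.0 - alpha) * jaccard

# e.g. model.compile(optimizer="sgd", loss=ce_jaccard_loss)
```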

compute_dataset_iou error

Hi, I changed BATCH_SIZE = 9 and trained the model. When I ran compute_dataset_iou(model, valid_ds, VALID_STEPS_PER_EPOCH), I got the following error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
Input In [263], in <cell line: 2>()
      1 # Iterate over the whole dataset and output the confusion matrix and per class IOU
----> 2 compute_dataset_iou(model, valid_ds, VALID_STEPS_PER_EPOCH)

Input In [262], in compute_dataset_iou(model, ds, steps)
      2 def compute_dataset_iou(model, ds, steps):
----> 3     total_cm = compute_dataset_cm(model, ds, steps)
      4     normalized_cm = total_cm / tf.reduce_sum(total_cm, axis=1, keepdims=True)
      5     sum_over_row = tf.reduce_sum(total_cm, axis=0)

Input In [261], in compute_dataset_cm(model, ds, steps, normalize)
      7     pred = tf.argmax(pred, axis=-1, output_type=tf.int32)
      8     label = patch_pairs[1]
----> 9     total_cm += tf.math.confusion_matrix(tf.reshape(label, [-1]), tf.reshape(pred, [-1]))
     10 # normalizing the confusing matrix
     11 if normalize: 

File ~/anaconda3/envs/airbus/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File ~/anaconda3/envs/airbus/lib/python3.9/site-packages/tensorflow/python/framework/ops.py:7186, in raise_from_not_ok_status(e, name)
   7184 def raise_from_not_ok_status(e, name):
   7185   e.message += (" name: " + name if name is not None else "")
-> 7186   raise core._status_to_exception(e) from None

InvalidArgumentError: Incompatible shapes: [6,6] vs. [5,5] [Op:AddV2]

I found that for one batch the confusion matrix shape is [5,5] while total_cm's shape is [6,6], so they don't match. Maybe the patches in this batch lack one of the classes.

tf.Tensor(
[[  3846      3 158742   5137      0]
 [     0      0      0      0      0]
 [    80      6   9529    176   1924]
 [ 12928      0 135083 121233 102881]
 [     0      0      0   1061  37195]], shape=(5, 5), dtype=int32)
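
If a missing class is indeed the cause, passing num_classes explicitly to tf.math.confusion_matrix keeps every per-batch matrix at 6x6, since otherwise the number of classes is inferred from the data; for example:

```python
# inside compute_dataset_cm: fix the class count so that batches missing a class
# still produce a 6x6 confusion matrix
total_cm += tf.math.confusion_matrix(tf.reshape(label, [-1]),
                                     tf.reshape(pred, [-1]),
                                     num_classes=6)
```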

Which configuration produces the best model?

Hi,
when I use the code for training, my configuration is:

  • unet with input 256x256x3
  • noScale
  • skip4 Residual
  • serialCenter
  • sgd optimizer
  • SparseCategoricalCrossentropy loss

I get the following result on the validation dataset:

Background    Building    Roads       Vegetation    Tree        Car
0.199138      0.770326    0.678813    0.652222      0.652314    0.66924

However, in your presentation, your best model gives the result:

Background    Building    Roads       Vegetation    Tree        Car
0.3132        0.8701      0.8030      0.6928        0.7061      0.8243

Which configuration did you use to produce the best model? And could you share the best model for download?
Thanks!
