GithubHelp home page GithubHelp logo

keras-semantic-segmentation's Introduction

Keras-Sematic-Segmentation

使用Keras实现深度学习中的一些语义分割模型。

配置

  • tensorflow 1.8.0/1.13.0
  • keras 2.2.4
  • GTX 2080Ti/CPU
  • Cuda 10.0 + Cudnn7
  • opencv

目录结构

  • data 存储输入图像和语义分割标签的文件夹
- data
	- dataset_name
		- train_image
		- train_label
		- test_image
		- test_label
  • Models 存储使用keras实现的一些经典分割模型
  • utils 存储工具代码,如数据预处理
  • data.py 加载1个batch的原始图片和分割标签图片
  • train.py 模型训练
  • test.py 模型测试

已支持的分割模型

Epoch model_name Base Model Segmentation Model Available
50 enet ENet Enet True
50 fcn8 Vanilla CNN FCN8 True
50 unet Vanilla CNN UNet True
50 segnet Vanilla CNN SegNet True
50 icnet Vanilla CNN ICNet True
50 pspnet Vanilla CNN PSPNet True
50 mobilenet_unet MobileNet MobileNetUnet True
50 mobilenet_fcn8 MobileNet MobileNetFCN8 True
50 seunet SENet SEUNet True
50 scseunet SCSENet scSEUNet True
50 vggunet VGGNet VGGUnet True
50 unet_xception_resnetblock XceptionNet Unet_Xception_ResNetBlock True
50 pspnet_resnet50 ResNet50 PSPNet_ResNet50 True
50 deeplab_v2 DeepLab DeepLabV2 True
50 hrnet HRNet HRNet True

训练

使用下面的命令训练和保存模型,模型保存路径,训练超参数需要灵活设置。

python train.py 

可用参数如下:

  • --dataset_name 字符串,代表选择对应的数据集的名称,默认streetscape
  • --n_classes 整型,代表分割图像中有几种类别的像素,默认为2
  • --input_height整型,代表要分割的图像需要resize的长,默认为224
  • --input_width 整型,代表要分割的图像需要resize的宽,默认为224
  • --resize_op 整型,代表resize的方式,如果为1则为默认resize,如果为2,则为letterbox_resize
  • --validate布尔型,代表训练过程中是否需要验证集,默认为True,即使用验证集。
  • --epochs整型,代表要训练多少个epoch,默认为50
  • --train_batch_size整型,代表训练时批量大小,默认为4
  • --model_name 字符串类型,代表训练时使用哪个模型,支持enet,unet,segnet,fcn8等多种模型,默认为unet
  • --train_save_path字符串类型,代表训练时保存模型的路径,默认为weights/unet,即会将模型保存在weights文件夹下,并且每个模型名字前缀以unet开头,后面接迭代次数和准确率构成完整的保存模型的路径。
  • --resume字符串类型,代表继续训练的时候加载的模型路径,默认值为``,即从头训练。
  • --optimizer_name字符串类型,代表训练模型时候的优化方法,支持sgd,adam,adadelta等多种优化方式,默认为adadelta
  • --image_init字符串类型,代表输入图片初始化方式,支持sub_meansub_and_dividedivide,默认为sub_mean
  • --multi_gpus 布尔类型,代表使用是否多卡进行训练,默认为Fasle,如果为True,需要手动调整train.py中的显卡标号,这里默认的是第0,1两块卡。

训练示例

  • 训练本工程提供的二分类数据集:python train.py --model_name unet --image_init divide --n_classes 2
  • 训练12个类别的城市街景分割数据集:python train.py --model_name unet --input_height 320 --input_width 640 --image_init sub_mean --n_classes 12

测试

使用下面的命令测试模型,加载模型的路径,图像输入分辨率等参数需要灵活设置。

python test.py

可用参数如下:

  • --test_images字符串类型,代表测试图所在的文件夹路径,默认为data/test/
  • --output_path字符串类型,代表从测试图预测出的mask图输出路径,默认为data/output/
  • --model_name 字符串类型,代表测试时使用哪个模型,支持enet,unet,segnet,fcn8等多种模型,默认为unet
  • --weights_path字符串类型,代表预测时加载的模型权重,默认为weights/unet.18-0.856895.hdf5,即对应默认模型unet训练出来的模型权重。
  • --input_height整型,代表测试集输入到网络中需要被resize的长,默认为224
  • --input_width整型,代表测试集输入到网络中需要被resize的宽,默认为224
  • --resize_op 整型,代表resize的方式,如果为1则为默认resize,如果为2,则为letterbox_resize
  • --classes整型,代表图片中的像素类别数,默认为2
  • --mIOU布尔型,代表是否启用评测mIOU,默认为False,一旦启用需要提供带有mask图的测试数据集。
  • --val_images字符串类型,代表启用mIOU后测试集原图的路径,默认为data/val_image/
  • --val_annotations字符串类型,代表启用mIOU后测试集mask图的路径,默认为data/val_label/
  • --image_init字符串类型,代表输入图片初始化方式,支持sub_meansub_and_dividedivide,默认为sub_mean

测试示例

  • 测试二分类数据集:python test.py --model_name unet --weights_path weight/unet.xx.hdf5 --classes 2 --image_init divide
  • 测试城市街景分割数据集:python test.py --model_name unet --weights_path weights/unet.xx.hdf5 --classes 12 --image_init sub_mean --input_height 320 --input_width 640 --resize_op 2(2代表使用letterbox方式进行resize)
  • 测试人脸部位分割数据集:

数据增强

我们结合Augmentor这个库实现了一套完整的数据增强策略,即augmentation.py。你可以自由增加,减少各种Augmentor支持的操作。Augmentor这个数据增强库的安装方式为:pip install Augmentor。然后Augmentor是一个独立的脚本需要在你进行训练之前进行本地增强然后将增强出来的数据拷贝到你的原始数据集中去扩充数据。它需要下面4个参数。

  • --train_path 字符串类型,代表训练集的原始图片的路径,默认为./data/images_prepped_train
  • --mask_path字符串类型,代表训练集的分割标签图的路径,默认为./data/annotations_prepped_train
  • --augtrain_path字符串类型,代表增强后的图像的路径,默认为./data/new_img
  • --augtrain_mask 字符串类型,代表增强后的分割标签图的路径,默认为./data/new_mask

其中augtrain_pathaugtrain_mask这两个目录如果没有事先建立的话,程序会为你自动建立。执行数据增强的命令为:

python augmentation.py --train_path xxx --mask_path xxx --augtrain_path xxx --augtrain_mask xxx

然后,我们就会在你指定的增强路径下生成一定数量(数量也可以自己控制,程序中写死了是为每张图像生成5张增强后的图)的增强图了。

数据集

数据集制作使用Labelme即可,然后将得到的json文件使用json_to_dataset.py转换为本工程要用的mask标签图,具体操作步骤为:

  • 使用本工程中的json_to_dataset.py替换掉labelme/cli中的相应文件—json_to_dataset.py 。在cmd中输入python json_to_dateset.py /path/你的json文件夹的路径。注意是把每张图的json文件都放在一个目录下,labelme标注出来的默认是一张图片一个文件夹。
  • 运行后,在json文件夹中会出现mask_png、labelme_json文件夹,mask_png中存放的是所有8位掩码文件!也即是本工程中使用的标签图。
  • 具体来说,我们的标签图就是分别指示每张图片上每一个位置的像素属于几,0是背景,然后你要的类别从1开始往后递增即可。
  • 本工程测试的一个2类的简单分割数据集,下载地址为:https://pan.baidu.com/s/1sVjBfmgALVK7uEjeWgIMug
  • 本工程测试的城市街景分割数据集,下载地址为:https://pan.baidu.com/s/1zequLd0aYXNseGoXn-tdog
  • 本工程测试的人脸部位分割数据集,下载地址为:https://pan.baidu.com/s/1uXZX9c8VFZYVP-ru5MOXXA ,提取码为:09ry 。数据集来源:https://blog.csdn.net/yuanlulu/article/details/89789807

Benchmark(陆续公开)

个人制作2个类别小零件数据集分割结果

Epoch model_name Base Model Segmentation Model Train Acc Train Loss Val Acc Val Loss Test mIOU
50 enet ENet Enet 0.99 0.02 0.98 0.02 0.91
50 fcn8 Vanilla CNN FCN8 0.99 0.02 0.98 0.04 0.93
50 unet Vanilla CNN UNet 0.99 0.02 0.99 0.03 0.94
50 segnet Vanilla CNN SegNet 0.99 0.02 0.99 0.02 0.94
50 icnet Vanilla CNN ICNet 0.99 0.02 0.99 0.02 0.94
50 pspnet Vanilla CNN PSPNet 0.99 0.02 0.99 0.02 0.94
50 mobilenet_unet MobileNet MobileNetUnet 0.99 0.02 0.99 0.02 0.94
50 mobilenet_fcn8 MobileNet MobileNetFCN8 0.99 0.02 0.99 0.02 0.94
50 seunet SENet SEUNet
50 scseunet SCSENet scSEUNet
50 vggunet VGGNet VGGUnet
50 unet_xception_resnetblock XceptionNet Unet_Xception_ResNetBlock
50 pspnet_resnet50 ResNet50 PSPNet_ResNet50
50 deeplab_v2 DeepLab DeepLabV2
50 hrnet HRNet HRNet

城市街景分割数据集分割结果

Epoch model_name Base Model Segmentation Model Train Acc Train Loss Val Acc Val Loss Test mIOU
50 enet ENet Enet 0.71 0.90 0.64 1.02 0.20
50 fcn8 Vanilla CNN FCN8 0.85 0.53 0.67 1.01 0.25
50 unet Vanilla CNN UNet 0.78 0.78 0.62 1.14 0.19
50 segnet Vanilla CNN SegNet 0.41 1.64 0.31 1.94 0.19
50 icnet Vanilla CNN ICNet 0.89 0.38 0.72 0.84 0.33
50 pspnet(576x384) Vanilla CNN PSPNet 0.88 0.41 0.73 0.88 0.33
50 mobilenet_unet MobileNet MobileNetUnet 0.90 0.36 0.73 0.87 0.34
50 mobilenet_fcn8 MobileNet MobileNetFCN8 0.70 0.92 0.58 1.25 0.17
50 seunet Vanilla CNN SEUnet 0.84 0.59 0.77 0.79 0.34
50 seunet SENet SEUNet
50 scseunet SCSENet scSEUNet
50 vggunet VGGNet VGGUnet
50 unet_xception_resnetblock XceptionNet Unet_Xception_ResNetBlock
50 pspnet_resnet50 ResNet50 PSPNet_ResNet50
50 deeplab_v2 DeepLab DeepLabV2
50 hrnet HRNet HRNet

人脸部位分割数据集

Epoch model_name Base Model Segmentation Model Train Acc Train Loss Val Acc Val Loss Test mIOU
50 enet ENet Enet 0.80 0.61 0.84 0.54 0.10
50 fcn8 Vanilla CNN FCN8 0.89 0.36 0.86 0.47 0.17
50 unet Vanilla CNN UNet 0.91 0.31 0.83 0.62 0.17
50 segnet Vanilla CNN SegNet 0.76 0.86 0.79 0.79 0.06
50 icnet Vanilla CNN ICNet 0.86 0.42 0.85 0.44 0.11
50 pspnet Vanilla CNN PSPNet 0.86 0.40 0.85 0.42 0.18
50 mobilenet_unet MobileNet MobileNetUnet 0.83 0.49 0.95 0.45 0.14
50 mobilenet_fcn8 MobileNet MobileNetFCN8 0.76 0.87 0.79 0.77 0.06
50 seunet SENet SEUNet
50 scseunet SCSENet scSEUNet
50 vggunet VGGNet VGGUnet
50 unet_xception_resnetblock XceptionNet Unet_Xception_ResNetBlock
50 pspnet_resnet50 ResNet50 PSPNet_ResNet50
50 deeplab_v2 DeepLab DeepLabV2
50 hrnet HRNet HRNet

个人制作2个类别小零件数据集分割可视化结果

Input Image Output Segmentation Image

城市街景分割数据集分割可视化结果

Input Image Output Segmentation Image

人脸部位分割数据集分割可视化结果

TODO

  • 支持DeepLab,UNet++等。
  • 支持OpenVINO和TensorRT部署。

参考

我的微信公众号

keras-semantic-segmentation's People

Contributors

bbuf avatar pprp avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.