GithubHelp home page GithubHelp logo

shape-bias-cnn's Introduction

shape-bias-CNN

一种跨域形状偏好 CNN 设计与实现。

  • 本项目设计并实现了一种基于形状识别的预训练方法,通过迁移学习实验来验证网络特性,结果显著图表明使用该方法预训练的网络具有显著的形状偏好,而且在某些跨域分类任务上能与在 ImageNet 上预训练的网络性能相比较;
  • 基于双任务的形状理解网络示意图:

基于基础形状数据集的分类任务

基础形状分类任务,目的是使网络认识形状,提取形状特征。具体实现都在 shape-classification 文件夹下。

基础形状数据集

这里是形状数据集的生成过程,主要包括以下内容。

  • 用于在 224*224 大小的偏黑或偏白(各占 50%)背景中随机生成不同颜色、不同大小、不同位置的形状;
  • 使用风格迁移对生成形状进行随机风格化;
  • 两者结合成为最终的形状数据集。(注:本次实验用数据集暂未开源。)

随机形状生成 ShapeGen.py

  • 目录下的 xxx-ShapeGen.py 文件用于生成随机颜色、随机大小、随机位置的形状,称之为 原始形状数据集 OSD (original shape dataset)
  • 生成图片大小为 224*224 ;
  • 生成图片背景为偏黑或偏白(各占 50% ,该值可根据需要调节,此处设置 50% 是为了使得样本均匀),目的是为了让生成的形状和背景相区分开,避免由于形状颜色和背景颜色相似或相同导致无法识别形状;
  • 一代形状数据集包含随机生成的四种形状:矩形、圆形、椭圆形、三角形
  • 二代形状数据集包含随机生成的十种形状:矩形、圆形、椭圆形、三角形、棱形、五角星形、五边形、六边形、八边形、梯形
  • 生成结果示例图:

风格化形状图片

  • 由原始形状数据集 OSD 生成 风格化形状数据集 SSD (stylized shape dataset)
  • 风格化代码来源于 https://github.com/rgeirhos/Stylized-ImageNet
  • 按照自己要求修改参数,输入命令进行风格化操作,示例命令如下:
python stylize.py --content-dir F:\pyprj_testfile\shape\rectangle --style-dir F:\graduation_prj\texture-vs-shape_ArticalRec\train --output-dir F:\pyprj_testfile\shape\style_rectangle --num-styles 10 --alpha 0.3 --content-size 0 --style-size 256
  • 生成结果示例图:

形状数据集

将原始形状数据集 OSD 和风格化形状数据集 SSD 相结合,即把对应相同形状类放在同一个文件夹下,得到 形状数据集 SD (shape dataset)

数据集大小

  • 一代形状数据集:4类,每类 550 张(原始形状 50 张,风格化形状 500 张),共 2200 张
  • 二代形状数据集:10类,共 12884 张,每类约 1300 张(对于部分类别生成质量不高的图片进行了清除和筛选)。

训练结果

在以相同方法生成的验证集上准确率达到 94.55% ;不但可以很好的分类简单形状,也可以 检测并提取跨域形状特征 。示例结果如下。

基于组合复杂形状数据集的分类任务

组合复杂形状分类任务,目的是通过学习复杂形状可以由简单形状组合得到,加深网络对形状的理解。具体实现都在 complicated-shape-classification 文件夹下。

组合复杂形状数据集

  • 通过 complicated-shape-generation.py 文件生成复杂形状;
  • 本项目中,复杂形状定义是由多种简单形状组合生成。具体来说,通过矩形、圆形、椭圆、三角形四种简单形状任意组合,共 15 种组合方式,分别生成不同颜色的任意形状;
  • 本项目中,第一版本复杂形状最终生成 15 个类,每类 460 张,共 6900 张复杂形状数据集的图像。第一个版本中同一张图片中,形状的线宽和颜色都相同,主要考虑不想让网络通过颜色和线宽的“捷径”来区分形状,而是通过形状的内在特征。示例如下(示例分别为矩形三角形组合,圆形椭圆形组合)。
  • 第二版本复杂形状增加了不同颜色、不同线宽的形状,每类增加 240 张,共 700 张图片。然后每张图像生成一张风格化图像,每类生成 700 张风格化图像。综上所述,第二个版本复杂形状最终生成 15 个类,每类 1400 张,共 21000 张图片。新增图像示例如下。

训练结果

可以很好的将复杂图像中含有的简单形状检测出来,示例结果如下。

基于双任务的形状理解网络

  • 本项目通过同时训练两个形状相关任务,希望网络加强对形状的理解,并在迁移学习中获得形状偏好;
  • 任务一:基于基础形状数据集的(10 类)分类任务,任务 1 详情可见文件夹 shape-classification
  • 任务二:基于组合复杂形状数据集的(15 类)分类任务,任务 2 详情可见文件夹 complicated-shape-classification
  • 联合损失函数:L_total = L1 + α * L2 ,其中 α 为可调超参数。

迁移学习形状偏好研究 (shape-bias)

本项目设计实验探究网络的形状偏好和跨域性能,部分实验结果如下。

salience maps

  • 显著图实验直观的表明,由形状数据集预训练的网络具有显著的形状偏好;
  • 结果分别为 | Original Image | ImageNet-Pretrained | ShapeDataset-Pretrained freezen-layer1 | ShapeDataset-Pretrained freezen-layer2 |

迁移学习跨域性能研究

method\domain art_painting cartoon photo sketch sketch
baseline(ResNet18) 58.6% 66.4% 34.0% 27.5% 46.6%
baseline + shapetask1 62.68% 59.47% 33.07% 30.44% 46.41%
MixStyle 61.9% 71.5% 41.2% 32.2% 51.7%
EFDM(SOTA) 63.2% 73.9% 42.5% 38.1% 54.4%
EFDM + shapetask1 68.54% 66.24% 53.67% 39.47% 56.98%

探索

  • 基于自然语言指导的偏好分类网络;
  • natural language is at once more expressive and easier to obtain than formal supervision. 自然语言往往比形式监督更具有表现力,而且更容易获得。

shape-bias-cnn's People

Contributors

dialogueeeeee avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.