GithubHelp home page GithubHelp logo

digit-recognizer's Introduction

深度学习数字识别

项目概况

本项目利用Keras框架,达到了识别手写数字的功能。

train.csv是训练集,这些数据用来建立一个模型。数字以图片的形式存储在电脑中,我们将这个数字图片分成784个像素(pixel),存储在一个向量里面。训练集的一行就表示一个向量,也就是一个数字。784列对应的就是784个像素。第一行label表明这一行是数字几。我们的数据有42000行,即有42000个手写的数字图片。

test.csv是测试集,这些数据用来检测我们的模型。我们的模型最后就是要输出测试集中的这28000张图片都是数字几,然后输出一个文件保存结果。

项目分析

模型的结构图,以及流程分析

模型结构图如下:

Layer (type) Output Shape Param #
conv2d_1 (Conv2D) (None, 24, 24, 32) 832
batch_normalization_1 (BatchNormalization) (None, 24, 24, 32) 96
max_pooling2d_1 (MaxPooling2D) None, 12, 12, 32) 0
zero_padding2d_1 (ZeroPadding) (None, 14, 14, 32) 0
conv2d_2 (Conv2D) (None, 12, 12, 48) 13872
batch_normalization_2 (BatchNormalization) (None, 12, 12, 48) 48
max_pooling2d_2 (MaxPooling2D) (None, 6, 6, 48) 0
dropout_1 (Dropout) (None, 6, 6, 48) 0
flatten_1 (Flatten) (None, 1728) 0
dense_1 (Dense) (None, 1024) 1770496
dense_2 (Dense) (None, 10) 10250

在进行卷积操作的时候用到的是keras.layers.Conv2Dkeras.layers.MaxPool2D来分别进行卷积和最大化池化操作。

在卷积层和池化层之间加入规范层keras.layers.BatchNormalization来在每个batch上将前一层的激活值重新规范化,即使得其输出数据的均值接近0,其标准差接近1。

ZeroPadding层用于保持边界信息,如果没有加padding的话,输入图片最边缘的像素点信息只会被卷积核操作一次,但是图像中间的像素点会被扫描到很多遍,那么就会在一定程度上降低边界信息的参考程度,但是在加入padding之后,在实际处理过程中就会从新的边界进行操作,就从一定程度上解决了这个问题。其中28×28的图片经过第一次卷积后是24×24,第一次池化后变为12×12,Padding后变为14×14,第二次卷积后为12×12,第二次池化后变为6×6,最后得到了48张6×6的平面。

卷积神经网络后接了舍弃机制和两个全连接的神经网络。使用的是keras.layers.Dropoutkeras.layers.Dense的方法。

竞赛最佳精度(Kaggle平台给出)以及提交总次数

总共提交1次,在Kaggle平台上得到准确率:

准确率 = 99.057%

结果截图如下图所示:

image

试比较实验中使用的不同参数效果,并分析原因。

  1. batch_size批量大小决定我们一次训练的样本数目,将影响到模型的优化程度和速度。batch_size的正确选择是为了在内存效率和内存容量之间寻找最佳平衡。
单次epoch=(全部训练样本/batchsize)/iteration=1

如果batch_size过小,训练数据就会非常难收敛,从而导致underfitting。适当地增加batch_size可以通过并行化提高内存利用率、减少单次epoch的迭代次数、提高运行速度。适当的增加batch_size可以使得梯度下降方向准确度增加,训练震动的幅度减小。 2. 采样窗口的大小3×3的效果好于5×5,一般小的采样窗口更加精细。 3. 分类任务的loss function一般适合用cross_entropy,而回归问题的用MSE好一点。

问题思考

Q1:训练什么时候停止是最合适的?简要陈述你的实现方式,并分析固定迭代次数与通过验证集调整等方法的优缺点。

  1. epochs到30左右会有明显的过拟合现象,在测试集上的表现反而不如epochs为10的时候。
  2. 通过validation_split参数将train.csv划分为两部分,一部分用来构建模型,另一部分来验证模型的准确率。这里取的验证集为整个train.csv的20%的大小,是后期调参调出来的,开始一般取10%-25%,慢慢降低比例。

Q2:参数的初始化是怎么做的?不同的方法适合哪些地方?

默认使用的初始化方法为glorot也就是Xavier初始化。另外还有零均值初始化,高斯分布初始化,正交初始化等,keras提供了uniform、lecun_uniform、normal、orthogonal、zero、glorot_normal、he_normal这几种。

  • 使用 RELU(without BN) 激活函数时,最好选用 He 初始化方法,将参数初始化为服从高斯分布或者均匀分布的较小随机数
  • 使用 BN 时,减少了网络对参数初始值尺度的依赖,此时使用较小的标准差(eg:0.01)进行初始化即可
  • 借助预训练模型中参数作为新任务参数初始化的方式也是一种简便易行且十分有效的模型参数初始化方法

Q3:过拟合是深度学习常见的问题,有什么方法可以方式训练过程陷入过拟合?

  1. 使用规范化技术(尤其是弃权和卷积层)来减少过度拟合
  2. 使用充分大的数据集来避免过度拟合

Q4:CNN(卷积神经网络)相对于全连接神经网络的优点?

CNN相对于全联接神经网络的主要不同点分为四个方面:

  1. 局部链接
  2. 权值共享
  3. 池化
  4. 从局部到全局

使用卷积层(卷积核)可以大大减少层中的参数的数目,使学习过程更容易。

心得体会

数据处理

  1. kaggle上的digit recognizer数据集比较干净,都是已经存储到784列里的(有784个像素),所以不用进行图片的放大缩小等处理。
  2. 由于数据集较大,需要减少图像的存储。最大的pixel为255,要用27空间来存储。所以给train和test矩阵都乘以1.0/255.0来进行处理,把像素值控制在0-1之间。像素值越接近于1,就越黑,越接近于0越白。
  3. 对train第一列的label进行one-hot编码,使类别独立出来,能够更好地处理离散型数据。模型中要建立一个softmax层来给类别分配概率,这些概率是独立同分布的。所以要求输入的标签也应该是独立的,因此我们要对标签(0-9)进行one-hot编码,进而计算交叉熵。

模型搭建

  1. BatchNormalization可以:1.加速收敛;2.控制过拟合,可以少用或不用Dropout和正则;3.降低网络对初始化权重不敏感;4.允许使用较大的学习率。
  2. DropoutDense的选取真的是玄学,不知道怎么解释效果的提升和降低。

digit-recognizer's People

Watchers

ChanceXuan avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.