- Building a Neural Network Classifier with the MNIST Dataset
- For MNIST, we implement a data loader and two classification models.
- We build a CNN model (LeNet-5) and an MLP model (CustomMLP) and compare their accuracy and loss.
- Python version is 3.8.
- Built with PyTorch; the device type is GPU.
- A `requirements.txt` file is required to set up the virtual environment for running the program. It lists all the libraries the program needs, with their versions.

```
$ conda create -n [your virtual environment name] python=3.9
$ conda activate [your virtual environment name]
$ pip install -r requirements.txt
```

- Create your own virtual environment.
- Activate the Anaconda virtual environment where you want to install the packages. For example, if your virtual environment is named `test`, type `conda activate test`.
- Run `pip install -r requirements.txt` to install the libraries.
- The tar files contain the 60,000 training images and the 10,000 test images, respectively.
- Each image has a filename of the form `{ID}_{Label}.png`.
- Run `dataset.py` to extract the .tar archives:

```
python dataset.py
```

The extracted directory structure:

```
|-- data
|   |-- train
|   |   |-- 00000_5.png
|   |   |-- 00001_0.png
|   |   |-- ...
|   |-- test
|       |-- 00000_7.png
|       |-- 00001_2.png
|       |-- ...
```
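Since the class label is encoded directly in the filename, a data loader only needs to parse it back out. A minimal sketch (the helper name `parse_label` is my own, not from the repo):

```python
import os

def parse_label(path):
    """Extract the integer class label from a '{ID}_{Label}.png' filename."""
    stem = os.path.splitext(os.path.basename(path))[0]  # e.g. '00000_5'
    _image_id, label = stem.split("_")
    return int(label)

print(parse_label("data/train/00000_5.png"))  # 5
print(parse_label("data/test/00001_2.png"))   # 2
```

A custom `torch.utils.data.Dataset` can call a helper like this in `__getitem__` to pair each image with its label.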
- Run `main.py` to train the models:

```
python main.py
```
- Model training configuration can be set via args.
- The default settings are as follows:

```
model_type = 'LeNet5'
epochs = 20
batch_size = 64
```
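One plausible way to expose these settings is `argparse`; a sketch with defaults mirroring the list above (the flag names are assumptions, and the actual `main.py` may differ):

```python
import argparse

def get_args(argv=None):
    """Parse training configuration; defaults mirror the settings above."""
    parser = argparse.ArgumentParser(description="MNIST classifier training")
    parser.add_argument("--model_type", default="LeNet5",
                        choices=["LeNet5", "CustomMLP"])
    parser.add_argument("--epochs", type=int, default=20)
    parser.add_argument("--batch_size", type=int, default=64)
    return parser.parse_args(argv)

args = get_args([])  # empty argv -> all defaults
print(args.model_type, args.epochs, args.batch_size)  # LeNet5 20 64
```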
- You can check the structure of each model by running `model.py`:

```
python model.py
```
LeNet-5:

| Layer (type) | Output Shape | Param # |
|---|---|---|
| Conv2d-1 | [-1, 6, 24, 24] | 156 |
| MaxPool2d-2 | [-1, 6, 12, 12] | 0 |
| Conv2d-3 | [-1, 16, 8, 8] | 2,416 |
| MaxPool2d-4 | [-1, 16, 4, 4] | 0 |
| Conv2d-5 | [-1, 120, 1, 1] | 30,840 |
| Linear-6 | [-1, 84] | 10,164 |
| Linear-7 | [-1, 10] | 850 |
| Total | | 44,426 |
CustomMLP:

| Layer (type) | Output Shape | Param # |
|---|---|---|
| Linear-1 | [-1, 56] | 43,960 |
| Linear-2 | [-1, 28] | 1,596 |
| Linear-3 | [-1, 10] | 290 |
| Total | | 45,846 |
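The two architectures implied by the layer shapes above can be sketched in PyTorch as follows. This is a reconstruction from the tables, not the repo's actual `model.py`: the activation functions (Tanh in LeNet-5, ReLU in the MLP) and module names are assumptions.

```python
import torch
from torch import nn

class LeNet5(nn.Module):
    """LeNet-5 variant matching the summary table (28x28 input, no padding)."""
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5), nn.Tanh(),     # -> [6, 24, 24]
            nn.MaxPool2d(2),                               # -> [6, 12, 12]
            nn.Conv2d(6, 16, kernel_size=5), nn.Tanh(),    # -> [16, 8, 8]
            nn.MaxPool2d(2),                               # -> [16, 4, 4]
            nn.Conv2d(16, 120, kernel_size=4), nn.Tanh(),  # -> [120, 1, 1]
        )
        self.classifier = nn.Sequential(
            nn.Linear(120, 84), nn.Tanh(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))

class CustomMLP(nn.Module):
    """Three-layer MLP sized to roughly match LeNet-5's parameter count."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784, 56), nn.ReLU(),
            nn.Linear(56, 28), nn.ReLU(),
            nn.Linear(28, 10),
        )

    def forward(self, x):
        return self.net(x.flatten(1))

# Parameter counts match the tables: 44,426 and 45,846.
print(sum(p.numel() for p in LeNet5().parameters()))
print(sum(p.numel() for p in CustomMLP().parameters()))
```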
- LeNet-5
  - Conv2d-1 (conv1): input channels: 1, output channels: 6, kernel_size: 5 × 5, bias: 6
    total params: (5 × 5 × 1 × 6) + 6 = 156
  - Conv2d-3 (conv2): input channels: 6, output channels: 16, kernel_size: 5 × 5, bias: 16
    total params: (5 × 5 × 6 × 16) + 16 = 2,416
  - Conv2d-5 (conv3): input channels: 16, output channels: 120, kernel_size: 4 × 4, bias: 120
    total params: (4 × 4 × 16 × 120) + 120 = 30,840
  - Linear-6 (fc1): input features: 120, output features: 84, bias: 84
    total params: (120 × 84) + 84 = 10,164
  - Linear-7 (fc2): input features: 84, output features: 10, bias: 10
    total params: (84 × 10) + 10 = 850
  - Total parameters of LeNet-5 = 156 + 2,416 + 30,840 + 10,164 + 850 = 44,426
- CustomMLP
  - Linear-1 (fc1): input features: 784, output features: 56, bias: 56
    total params: (784 × 56) + 56 = 43,960
  - Linear-2 (fc2): input features: 56, output features: 28, bias: 28
    total params: (56 × 28) + 28 = 1,596
  - Linear-3 (fc3): input features: 28, output features: 10, bias: 10
    total params: (28 × 10) + 10 = 290
  - Total parameters of CustomMLP = 43,960 + 1,596 + 290 = 45,846
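The arithmetic above follows two simple rules (weights plus one bias per output channel or unit), which can be checked with a couple of helpers:

```python
def conv_params(c_in, c_out, k):
    """Conv2d: k*k*c_in weights per output channel, plus one bias each."""
    return k * k * c_in * c_out + c_out

def linear_params(n_in, n_out):
    """Linear: n_in weights per output unit, plus one bias each."""
    return n_in * n_out + n_out

lenet5 = (conv_params(1, 6, 5) + conv_params(6, 16, 5)
          + conv_params(16, 120, 4)
          + linear_params(120, 84) + linear_params(84, 10))
mlp = linear_params(784, 56) + linear_params(56, 28) + linear_params(28, 10)
print(lenet5, mlp)  # 44426 45846
```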
- Accuracy for each model (plot)
- Loss for each model (plot)
- Comparing the two models after 20 epochs of training, test accuracy was 97.66% for CustomMLP and 98.80% for LeNet-5. Although the two models have similar parameter counts (CustomMLP: 45,846; LeNet-5: 44,426), the CNN-based model outperforms the MLP.
- The learning curves of both models show the loss decreasing roughly exponentially as training progresses, converging to about 0.10-0.15 for CustomMLP and 0.04-0.06 for LeNet-5.
- To improve LeNet-5's performance, two regularization techniques were applied: batch normalization and dropout. Accuracy was 98.80% for plain LeNet-5 and 99.13% for LeNet-5 with regularization, a slight improvement. The gap between the two variants is probably small because of the low complexity of the data and model used in the experiment.
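A sketch of how batch normalization and dropout might be inserted into the LeNet-5 architecture. The placement (BatchNorm after each conv, Dropout before the classifier) and activations are assumptions; the repo's actual regularized model may differ.

```python
import torch
from torch import nn

class LeNet5Regularized(nn.Module):
    """LeNet-5 with BatchNorm after each conv and Dropout before the FC head."""
    def __init__(self, p_drop=0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, 5), nn.BatchNorm2d(6), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(6, 16, 5), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16, 120, 4), nn.BatchNorm2d(120), nn.ReLU(),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p_drop),          # active only in training mode
            nn.Linear(120, 84), nn.ReLU(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))

model = LeNet5Regularized().eval()  # eval(): fixed BN stats, dropout disabled
out = model(torch.randn(2, 1, 28, 28))
print(out.shape)  # torch.Size([2, 10])
```

Note that both layers behave differently at train and test time, so `model.train()` / `model.eval()` must be called appropriately around the training and evaluation loops.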
- For the known accuracy of the original LeNet-5, published results were used as a reference (about 97.5% and 97.64%; see the links below).
- LeNet5 model
  - https://deep-learning-study.tistory.com/503
  - https://deep-learning-study.tistory.com/368?category=963091
- Train
  - https://velog.io/@skarb4788/%EB%94%A5-%EB%9F%AC%EB%8B%9D-MNIST-%EB%8D%B0%EC%9D%B4%ED%84%B0PyTorch