Batch Normalization vs. Dropout

Code to compare Dropout and Batch Normalization, published in the paper Dropout vs. batch normalization: an empirical study of their impact to deep learning. A free version of the paper is available here.

The experiments compare the effect of dropout and batch normalization on:

  • Training time
  • Test (inference) time
  • Accuracy
  • Memory usage

Originally, this work was a term project for Florida Atlantic University's CAP-6619 Deep Learning class, Fall 2018. The report that originated the paper is available here.

Since its first release, the project has been updated to use TensorFlow 2.x. There are no changes to the logic, just the changes needed to adapt to TensorFlow 2.x. Migration notes are available in this file. The TensorFlow 1.x version is available through the v1.2 tag in the code.

What the experiments cover

The report compares the performance and accuracy of Dropout and Batch Normalization:

  • How long it takes to train a network (to run a specific number of epochs, to be more precise).
  • How long it takes to make a prediction with a trained network.
  • How much memory the network uses, measured indirectly with Keras' count_params() (a sketch follows this list).
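
This is an indirect measure: more parameters generally means more memory. A minimal sketch, using a toy model (not the repository's exact models), of how Keras reports the parameter count:

    import tensorflow as tf
    from tensorflow.keras import layers

    # Toy MLP; the layer sizes are illustrative only
    model = tf.keras.Sequential([
        layers.Dense(1024, activation='relu', input_shape=(784,)),
        layers.BatchNormalization(),
        layers.Dense(10, activation='softmax'),
    ])
    # count_params() sums trainable and non-trainable parameters
    print(model.count_params())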

To gather those numbers, the code runs a series of tests with different network configurations and hyperparameters.

The network configurations tested:

  • MLP - a multilayer perceptron network (only densely connected layers), tested with MNIST.
  • CNN - a convolutional neural network (convolutional and max-pooling layers), tested with CIFAR-10. Minimal sketches of both types follow.
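
A minimal sketch of the two network types, with illustrative layer sizes (the repository's models are larger and configurable):

    import tensorflow as tf
    from tensorflow.keras import layers

    # MLP for MNIST: densely connected layers only
    mlp = tf.keras.Sequential([
        layers.Flatten(input_shape=(28, 28)),
        layers.Dense(1024, activation='relu'),
        layers.Dense(10, activation='softmax'),
    ])

    # CNN for CIFAR-10: convolutional and max-pooling layers
    cnn = tf.keras.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(10, activation='softmax'),
    ])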

Several hyperparameters were tested: number of layers, number of units in each layer, learning rate, weight decay, and dropout rates for the input and hidden layers. The MLP network was also tested with a non-adaptive optimizer (SGD) and an adaptive optimizer (RMSProp). See the report for more details.
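
To make the comparison concrete, here is a hedged sketch of how one MLP variant differs from the other: the same base network with either a Dropout or a BatchNormalization layer, compiled with SGD or RMSProp. The rates and sizes are illustrative, not the repository's exact values:

    import tensorflow as tf
    from tensorflow.keras import layers

    def build_mlp(regularization):
        # Same base network, different regularization layer
        reg_layer = {'dropout': layers.Dropout(0.5),
                     'batch_normalization': layers.BatchNormalization()}[regularization]
        return tf.keras.Sequential([
            layers.Flatten(input_shape=(28, 28)),
            layers.Dense(1024, activation='relu'),
            reg_layer,
            layers.Dense(10, activation='softmax'),
        ])

    # Non-adaptive vs. adaptive optimizer
    sgd = tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.95)
    rmsprop = tf.keras.optimizers.RMSprop(learning_rate=0.001)

    model = build_mlp('batch_normalization')
    model.compile(optimizer=sgd, loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])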

The raw results, and the analysis done on them, are described in the next section.

Results from experiments

Raw data generated from the experiments executed for the report are available in this directory.

Analysis of the results is available in the report.

More analysis could be done with the collected data; the report reflects what I could do with the time and knowledge I had at that point.

How to install the environment and run the experiments

How to install the environment and dependencies

Install Python 3

The project uses Python 3.

Verify that you have Python 3.x installed: python --version should print Python 3.x.y. If it prints Python 2.x.y, try python3 --version. If that still doesn't work, please install Python 3.x before proceeding. The official Python download site is here.

From this point on, the instructions assume that Python 3 is installed as python3.

Clone the repository

git clone https://github.com/fau-masters-collected-works-cgarbin/cap6619-deep-learning-term-project.git

The repository is now in the directory cap6619-deep-learning-term-project.

Create a Python virtual environment

IMPORTANT: The project is configured for TensorFlow without GPU support. If you are using it on a GPU-enabled system, open requirements.txt and follow the instructions there to use the GPU version of TensorFlow.

The project depends on specific versions of TensorFlow and other packages. The safest way to install the correct versions, without affecting other projects on your computer, is to create a Python virtual environment specifically for this project.

The official guide to Python virtual environments is here.

Execute these commands to create and activate a virtual environment for the project:

  1. cd cap6619-deep-learning-term-project (if you are not yet in the project directory)
  2. python3 -m venv env
  3. source env/bin/activate (or in Windows: env\Scripts\activate.bat)

Install the dependencies

The project dependencies are listed in the requirements.txt file. To install them, execute this command:

pip install -r requirements.txt

This may take several minutes to complete. Once it is done, you are ready to run the experiments.

How to run the experiments

The code is split into these directories:

  • mlp: the MLP tests
  • cnn: the CNN tests

Within each directory, the files are named with the network configuration they test.

MLP experiments

The experiments are driven by the combination of parameters defined in the test generator file. The parameters are specified in named tuples. This is the one used to generate MLP Batch Normalization tests with the SGD optimizer:

    batchnorm_sgd = Parameters(
        experiment_name='batchnorm_mnist_mlp_sgd',
        network=['batch_normalization'],
        optimizer=['sgd'],
        hidden_layers=['2', '3', '4'],
        units_per_layer=['1024', '2048'],
        epochs=['5', '20', '50'],
        batch_size=['128'],
        # Test with the Keras default 0.01 and a higher rate because the paper
        # recommends 'Increase learning rate.'
        learning_rate=['0.01', '0.1'],
        # Test with Keras default 0.0 (no decay) and a small decay
        decay=['0.0', '0.0001'],
        # Test with Keras default (no momentum) and some momentum
        sgd_momentum=['0.0', '0.95'],
    )
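
Each list in the tuple is one dimension of the test grid. A hypothetical sketch (the actual generator scripts may differ) of how such a tuple can be expanded into one test per combination with itertools.product:

    import itertools
    from collections import namedtuple

    # Reduced Parameters tuple, for illustration only
    Parameters = namedtuple('Parameters',
                            ['experiment_name', 'hidden_layers', 'epochs'])

    def expand(params):
        # Yield one dict of scalar values per combination of list-valued fields
        fields = params._asdict()
        list_keys = [k for k, v in fields.items() if isinstance(v, list)]
        for combo in itertools.product(*(fields[k] for k in list_keys)):
            yield {'experiment_name': params.experiment_name,
                   **dict(zip(list_keys, combo))}

    demo = Parameters('demo', hidden_layers=['2', '3'], epochs=['5', '20'])
    print(list(expand(demo)))  # 2 x 2 = 4 test configurations

For example, batchnorm_sgd above expands to 3 x 2 x 3 x 2 x 2 x 2 = 144 configurations.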
MLP with batch normalization
cd mlp/batch_normalization
python3 CAP6619_term_project_mnist_mlp_batchnorm_test_generator.py

# For a quick check of the environment
./batchnorm_mnist_mlp_quick_test.sh

# Using SGD
./batchnorm_mnist_mlp_sgd.sh
# Using RMSProp
./batchnorm_mnist_mlp_rmsprop.sh

# Merge all test results into one file and extract the top 10 results
python3 CAP6619_term_project_mnist_mlp_batchnorm_analysis.py
MLP with dropout
cd mlp/dropout
python3 CAP6619_term_project_mnist_mlp_dropout_test_generator.py

# For a quick check of the environment
./dropout_mnist_mlp_quick_test.sh
# Regular MLP network (no dropout) with SGD to use as baseline
./dropout_mnist_mlp_standard_sgd.sh
# Regular MLP network (no dropout) with RMSprop to use as baseline
./dropout_mnist_mlp_standard_rmsprop.sh
# Dropout MLP network without adjustment and with SGD
./dropout_mnist_mlp_dropout_no_adjustment_sgd.sh
# Dropout MLP network with adjustment and with SGD
./dropout_mnist_mlp_dropout_sgd.sh
# Dropout MLP network without adjustment and with RMSprop
./dropout_mnist_mlp_dropout_no_adjustment_rmsprop.sh
# Dropout MLP network with adjustment and with RMSprop
./dropout_mnist_mlp_dropout_rmsprop.sh

# Merge all test results into one file and extract the top 10 results
python3 CAP6619_term_project_mnist_mlp_dropout_analysis.py

CNN experiments

The experiments are driven by command-line parameters. Shell scripts encapsulate the experiments.
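
A hypothetical sketch of the kind of command-line interface such a test script exposes (the real scripts' flags may differ):

    import argparse

    parser = argparse.ArgumentParser(description='CNN dropout/batchnorm test')
    parser.add_argument('--network',
                        choices=['standard', 'dropout', 'batch_normalization'])
    parser.add_argument('--epochs', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=128)
    args = parser.parse_args()
    print(args)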

cd cnn

# For a quick check of the environment
./cnn_test_quick.sh
# All CNN experiments
./cnn_test_all.sh

Where results are saved

Results are saved in these files:

  • A ..._results.txt file collects the data for each test, such as training time and model parameter count. There is one line for each test. See an example here.
  • Several ..._history.json files, one for each test, containing the training and validation loss/accuracy. Each is a JSON file with the contents of the History object created by Keras during training. The name of the file encodes the values of the hyperparameters used for that test. See several examples in this directory, and the loading sketch below.
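
A hedged sketch of reading one of the history files (the file name here is illustrative):

    import json

    # Hypothetical file name; real names encode the hyperparameter values
    with open('batchnorm_mnist_mlp_sgd_history.json') as f:
        history = json.load(f)

    # Keys mirror Keras' History.history, e.g. 'loss', 'acc', 'val_loss',
    # 'val_acc', with one value per epoch
    for metric, values in history.items():
        print(metric, values[-1])  # final-epoch value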

What needs to be improved

Gathering the data was a great learning experience. Knowing what I know now, I'd have done a few things differently:

  • Force overfit: Dropout and Batch Normalization fight overfitting, so more interesting data would have been produced if I had first made sure the network was overfitting. This could have been done by reducing the number of samples used to train the network (a sketch follows this list).
  • Extract more data from the results: the experiments collected quite a bit of data for each combination of parameters, but only some basic analysis was done to write the report. More analysis, e.g. the effect of learning rate changes or momentum changes, could be done.
  • Extract repeated code: there is a fair bit of copy and paste in the code, especially in the CNN tests. It should be refactored to remove the duplication.
  • Split the standard MLP tests from the Dropout MLP tests: they are embedded in one file. It would be easier to manage them if they were in separate files, as was done for the CNN code.
  • See the "TODO" comments in the code: there are a few of them, pointing to more specific improvements that could be made.
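
As an illustration of the first item, a hedged sketch of forcing overfit by training on a small slice of MNIST (the sample count is arbitrary):

    import tensorflow as tf

    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # Keep only a small sample so the network overfits, which makes the
    # regularization effect of Dropout/Batch Normalization easier to see
    x_small = x_train[:1000] / 255.0
    y_small = y_train[:1000]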

License and citation

Licensed as MIT. If you use parts of this project, please link back to this repository and cite the paper (how to cite).

And please let me know if you are referring to it - personal curiosity.

Miscellanea

This was my first real-life (of sorts) experience with machine learning and the first time I wrote a significant amount of Python and Keras code. It's not a polished result by any means.

Suggestions for improvements are always appreciated.
