
sgd-opt

A Comparison Among Different Variants of the Gradient Descent Algorithm

### Introduction


This script implements the following algorithms and visualizes their performance on the MNIST handwritten digit recognition dataset:

  • Stochastic Gradient Descent (SGD)
  • Momentum
  • Nesterov Accelerated Gradient (NAG)
  • Adagrad
  • Adadelta
  • RMSprop
  • Adam

All the algorithms are described in detail in Sebastian Ruder's blog post "An overview of gradient descent optimization algorithms".
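The seven update rules can be sketched as follows. This is a minimal illustration that follows the formulas in Ruder's overview and applies them to a single scalar parameter. The function names, the hyperparameter defaults and the toy objective f(θ) = θ² are assumptions made for the example. They are not taken from this script's code or from the settings behind the reported results.

```python
import math
from typing import Callable, Dict

Grad = Callable[[float], float]  # maps a parameter value to its gradient
State = Dict[str, float]         # per-optimizer running statistics


def sgd(theta: float, grad: Grad, state: State, lr: float = 0.1) -> float:
    return theta - lr * grad(theta)


def momentum(theta: float, grad: Grad, state: State, lr: float = 0.1, gamma: float = 0.9) -> float:
    # v_t = gamma * v_{t-1} + lr * g_t ; theta_{t+1} = theta_t - v_t
    state["v"] = gamma * state.get("v", 0.0) + lr * grad(theta)
    return theta - state["v"]


def nag(theta: float, grad: Grad, state: State, lr: float = 0.1, gamma: float = 0.9) -> float:
    # Same as momentum, but the gradient is taken at the look-ahead point theta - gamma * v_{t-1}
    v_prev = state.get("v", 0.0)
    state["v"] = gamma * v_prev + lr * grad(theta - gamma * v_prev)
    return theta - state["v"]


def adagrad(theta: float, grad: Grad, state: State, lr: float = 0.5, eps: float = 1e-8) -> float:
    # Accumulate all past squared gradients; the effective step shrinks as G grows
    g = grad(theta)
    state["G"] = state.get("G", 0.0) + g * g
    return theta - lr / math.sqrt(state["G"] + eps) * g


def adadelta(theta: float, grad: Grad, state: State, rho: float = 0.95, eps: float = 1e-6) -> float:
    # No global learning rate: step = -RMS[past updates] / RMS[gradients] * g
    g = grad(theta)
    state["eg2"] = rho * state.get("eg2", 0.0) + (1 - rho) * g * g
    edx2 = state.get("edx2", 0.0)
    dx = -math.sqrt(edx2 + eps) / math.sqrt(state["eg2"] + eps) * g
    state["edx2"] = rho * edx2 + (1 - rho) * dx * dx
    return theta + dx


def rmsprop(theta: float, grad: Grad, state: State, lr: float = 0.05,
            decay: float = 0.9, eps: float = 1e-8) -> float:
    # Exponentially decaying average of squared gradients instead of Adagrad's full sum
    g = grad(theta)
    state["eg2"] = decay * state.get("eg2", 0.0) + (1 - decay) * g * g
    return theta - lr / math.sqrt(state["eg2"] + eps) * g


def adam(theta: float, grad: Grad, state: State, lr: float = 0.1,
         beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    # Bias-corrected estimates of the gradient's first (m) and second (v) moments
    g = grad(theta)
    t = state.get("t", 0.0) + 1
    state["t"] = t
    state["m"] = beta1 * state.get("m", 0.0) + (1 - beta1) * g
    state["v"] = beta2 * state.get("v", 0.0) + (1 - beta2) * g * g
    m_hat = state["m"] / (1 - beta1 ** t)
    v_hat = state["v"] / (1 - beta2 ** t)
    return theta - lr * m_hat / (math.sqrt(v_hat) + eps)


OPTIMIZERS = {
    "SGD": sgd, "Momentum": momentum, "NAG": nag, "Adagrad": adagrad,
    "Adadelta": adadelta, "RMSprop": rmsprop, "Adam": adam,
}


def minimize(step: Callable[..., float], theta0: float = 5.0, iters: int = 200) -> float:
    """Run one optimizer on the toy objective f(theta) = theta**2 (gradient 2 * theta)."""
    grad = lambda x: 2.0 * x
    state: State = {}
    theta = theta0
    for _ in range(iters):
        theta = step(theta, grad, state)
    return theta
```

The gradient is passed as a callable rather than a number because NAG has to evaluate it at the look-ahead point. Calling `minimize(OPTIMIZERS["Adam"])` runs 200 Adam steps from θ = 5 and returns the final parameter value.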

### Dataset


The MNIST dataset contains 60,000 training samples and 10,000 validation samples. It is naturally divided into 10 classes, corresponding to the digits 0 to 9, and the classes are well balanced. Each digit image is pre-processed and rescaled into a 28×28 grayscale array with pixel values ranging from 0 to 255.
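Before an image can be fed to the network, it has to become an input vector, and its label has to become a target. The sketch below shows one common way to do this. It flattens each 28×28 image into a 784-element vector, rescales pixels from [0, 255] to [0, 1], and one-hot encodes the label. The README does not say whether the script rescales pixels, so treat that step and the helper names `flatten`, `normalize_pixels` and `one_hot` as hypothetical.

```python
from typing import List


def flatten(image: List[List[int]]) -> List[int]:
    """Turn a 28x28 image (a list of rows) into a 784-element vector, row by row."""
    if len(image) != 28 or any(len(row) != 28 for row in image):
        raise ValueError("expected a 28x28 image")
    return [pixel for row in image for pixel in row]


def normalize_pixels(pixels: List[int]) -> List[float]:
    """Rescale grayscale values from [0, 255] to [0.0, 1.0] (an assumed preprocessing step)."""
    return [p / 255.0 for p in pixels]


def one_hot(label: int, num_classes: int = 10) -> List[float]:
    """Encode a digit 0-9 as the target vector for a 10-way softmax output."""
    if not 0 <= label < num_classes:
        raise ValueError(f"label must be in [0, {num_classes})")
    target = [0.0] * num_classes
    target[label] = 1.0
    return target


# Usage: a blank image with one white pixel in the top-left corner, labeled as digit 7
image = [[0] * 28 for _ in range(28)]
image[0][0] = 255
x = normalize_pixels(flatten(image))
y = one_hot(7)
```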

### Model


A traditional three-layer neural network is used, with 784 (28×28) input units, 25 hidden units, and a 10-way softmax output layer.
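A forward pass through this 784-25-10 architecture looks roughly like the sketch below, which also counts the trainable parameters. The README does not name the hidden activation or the initialization. The logistic sigmoid, the uniform ±1/√n_in initialization and all function names here are assumptions. Cross-entropy is used as the cost because it is the usual pairing with a softmax output, not because the source confirms it.

```python
import math
import random
from typing import List, Tuple

Layer = Tuple[List[List[float]], List[float]]  # (weight matrix rows, bias vector)


def init_layer(n_in: int, n_out: int, rng: random.Random) -> Layer:
    # Assumed initialization: uniform in [-1/sqrt(n_in), 1/sqrt(n_in)], zero biases
    scale = 1.0 / math.sqrt(n_in)
    weights = [[rng.uniform(-scale, scale) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out


def affine(layer: Layer, x: List[float]) -> List[float]:
    weights, bias = layer
    return [sum(w * xj for w, xj in zip(row, x)) + b for row, b in zip(weights, bias)]


def sigmoid(z: List[float]) -> List[float]:
    return [1.0 / (1.0 + math.exp(-v)) for v in z]


def softmax(z: List[float]) -> List[float]:
    m = max(z)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def forward(params: List[Layer], x: List[float]) -> List[float]:
    hidden = sigmoid(affine(params[0], x))    # 784 -> 25
    return softmax(affine(params[1], hidden))  # 25 -> 10 class probabilities


def cross_entropy(probs: List[float], label: int) -> float:
    return -math.log(probs[label])


def count_params(sizes: List[int]) -> int:
    # weights plus biases for each consecutive pair of layers
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


SIZES = [784, 25, 10]
rng = random.Random(0)
params = [init_layer(a, b, rng) for a, b in zip(SIZES, SIZES[1:])]
probs = forward(params, [0.5] * 784)
```

With these layer sizes the network has 784·25 + 25 + 25·10 + 10 = 19,885 trainable parameters, which every optimizer above updates on each mini-batch.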

### Performance


Here are the training and validation accuracies of each algorithm (training vs. validation), after 30 epochs with a mini-batch size of 100:

  • SGD: 95.36% vs 94.06%
  • Momentum: 97.52% vs 94.88%
  • NAG: 97.47% vs 94.33%
  • Adagrad: 96.17% vs 93.95%
  • Adadelta: 94.65% vs 93.84%
  • RMSprop: 96.35% vs 94.51%
  • Adam: 96.54% vs 94.12%
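The figures above can be ranked by validation accuracy, and each algorithm's train/validation gap can be computed with a few lines of arithmetic. The sketch uses only the numbers reported here. The dictionary name `REPORTED` and the helper names are illustrative.

```python
from typing import Dict, List, Tuple

# (training accuracy %, validation accuracy %) as reported above
REPORTED: Dict[str, Tuple[float, float]] = {
    "SGD": (95.36, 94.06),
    "Momentum": (97.52, 94.88),
    "NAG": (97.47, 94.33),
    "Adagrad": (96.17, 93.95),
    "Adadelta": (94.65, 93.84),
    "RMSprop": (96.35, 94.51),
    "Adam": (96.54, 94.12),
}


def generalization_gap(train: float, val: float) -> float:
    """Difference between training and validation accuracy, in percentage points."""
    return round(train - val, 2)


gaps: Dict[str, float] = {name: generalization_gap(*acc) for name, acc in REPORTED.items()}
ranking: List[str] = sorted(REPORTED, key=lambda name: REPORTED[name][1], reverse=True)
```

On these numbers, Momentum ranks first on validation accuracy. Adadelta has the smallest train/validation gap (0.81 points) but also the lowest validation accuracy.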

### Visualization


The plot below shows how the cost decreases with each mini-batch during the first 10 epochs:

[Figure: cost per mini-batch for each algorithm during the first 10 epochs]

### Conclusion


The variants of gradient descent can be roughly divided into two types: momentum-like SGD (Momentum, NAG) and adaptive-learning-rate SGD (Adagrad, Adadelta, RMSprop, Adam).

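As noted in the blog post, adaptive-learning-rate SGD is well suited to large-scale sparse optimization problems (e.g., click-through rate (CTR) prediction). The MNIST data used here is not sparse, and momentum-like SGD performs best: Momentum and NAG reach the highest training accuracy (97.52% and 97.47%), and Momentum also achieves the best validation accuracy (94.88%). The validation differences between all algorithms, however, stay within about one percentage point.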
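The sparsity argument can be illustrated with a small example. Adagrad divides the learning rate by the root of each coordinate's accumulated squared gradients. A feature that is rarely active therefore keeps a larger effective step size than one that is active on every step. Plain SGD and momentum apply one global learning rate to every coordinate. The gradients below are synthetic, chosen only to show this effect, and are not derived from MNIST or from this script.

```python
import math
from typing import List


def adagrad_effective_rates(grad_history: List[List[float]],
                            lr: float = 0.1, eps: float = 1e-8) -> List[float]:
    """Per-coordinate Adagrad step size lr / sqrt(G_ii + eps) after a sequence of gradients."""
    accum = [0.0] * len(grad_history[0])
    for g in grad_history:
        for i, gi in enumerate(g):
            accum[i] += gi * gi  # G_ii: running sum of squared gradients for coordinate i
    return [lr / math.sqrt(a + eps) for a in accum]


# Synthetic gradients over 100 steps: feature 0 is active on every step,
# feature 1 (the "sparse" feature) only on every 10th step.
history = [[1.0, 1.0 if t % 10 == 0 else 0.0] for t in range(100)]
dense_rate, sparse_rate = adagrad_effective_rates(history)
```

After 100 steps, the dense feature's effective rate is 0.1/√100 = 0.01, while the sparse feature's is 0.1/√10 ≈ 0.032. The rare feature therefore still takes roughly three times larger steps when it does appear.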
