
Implements sharpness-aware minimization (https://arxiv.org/abs/2010.01412) in TensorFlow 2.

generalization computer-vision deep-neural-networks loss-landscape tpu-acceleration tensorflow

sharpness-aware-minimization-tensorflow's Introduction

Sharpness-Aware-Minimization-TensorFlow

This repository provides a minimal implementation of sharpness-aware minimization (SAM) (Sharpness-Aware Minimization for Efficiently Improving Generalization) in TensorFlow 2. SAM is motivated by the connections between the geometry of the loss landscape of deep neural networks and their generalization ability. SAM attempts to simultaneously minimize the loss value and the loss curvature, thereby seeking parameters that lie in neighborhoods with uniformly low loss. This differs from traditional SGD-based optimization, which seeks parameters with low loss values on an individual basis. The figure below (taken from the original paper) demonstrates the effect of using SAM.
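For reference, the objective SAM optimizes (sketched schematically here from the paper, omitting the weight-decay term) is the worst-case training loss within an L2 ball of radius rho around the current weights:

    min_w  max_{||ε||₂ ≤ ρ}  L_train(w + ε)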

My goal with this repository is to be able to quickly train neural networks with and without SAM. All the experiments are shown in the SAM.ipynb notebook (Open In Colab). The notebook is end-to-end executable on Google Colab. Furthermore, the notebook utilizes the free TPUs (TPUv2-8) that Google Colab provides, allowing quick experimentation.

Notes

Before moving to the findings, please be aware of the following notable differences in my implementation:

  • ResNet20 (attributed to this repository) is used as opposed to PyramidNet and WideResNet.
  • ShakeDrop regularization has not been used.
  • Two simple augmentation transformations (random crop and random brightness) have been used as opposed to Cutout and AutoAugment.
  • Adam has been used as the optimizer with the default arguments provided by TensorFlow, along with a ReduceLROnPlateau callback. Table 1 of the original paper suggests using SGD with different configurations.
  • Instead of training for the full number of epochs, I used early stopping with a patience of 10.

SAM has only one hyperparameter, namely rho, which controls the size of the neighborhood in parameter space. In my experiments it defaults to 0.05. For other details of the training configuration (network depth, learning rate, batch size, etc.), please refer to the notebook.
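For illustration, here is a minimal sketch (not the notebook's exact code) of how a SAM-style train_step can be written in TensorFlow 2. The class name SAMModel follows the notebook; inner_model, the 1e-12 stabilizer, and the other details are assumptions of this sketch.

import tensorflow as tf

class SAMModel(tf.keras.Model):
    def __init__(self, inner_model, rho=0.05):
        super().__init__()
        self.inner_model = inner_model
        self.rho = rho

    def call(self, inputs):
        return self.inner_model(inputs)

    def train_step(self, data):
        x, y = data

        # First forward/backward pass: gradients at the current weights.
        with tf.GradientTape() as tape:
            predictions = self.inner_model(x, training=True)
            loss = self.compiled_loss(y, predictions)
        grads = tape.gradient(loss, self.inner_model.trainable_variables)

        # Ascent step: perturb the weights by epsilon = rho * g / ||g||.
        grad_norm = tf.linalg.global_norm(grads) + 1e-12
        epsilons = [g * (self.rho / grad_norm) for g in grads]
        for var, eps in zip(self.inner_model.trainable_variables, epsilons):
            var.assign_add(eps)

        # Second forward/backward pass: gradients at the perturbed weights.
        with tf.GradientTape() as tape:
            predictions = self.inner_model(x, training=True)
            loss = self.compiled_loss(y, predictions)
        sam_grads = tape.gradient(loss, self.inner_model.trainable_variables)

        # Restore the original weights, then apply the SAM gradients.
        for var, eps in zip(self.inner_model.trainable_variables, epsilons):
            var.assign_sub(eps)
        self.optimizer.apply_gradients(
            zip(sam_grads, self.inner_model.trainable_variables))

        self.compiled_metrics.update_state(y, predictions)
        return {m.name: m.result() for m in self.metrics}

Note that in this sketch the second forward pass happens after the weights are perturbed, and the original weights are restored before the optimizer update; this ordering is what the "Potential bug" issue below discusses.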

Findings

              Number of Parameters (million)   Final Test Accuracy (%)
With SAM      0.575114                         82.78
Without SAM   0.575114                         79.51

Note that with the current experiment setup,

  • With SAM, the model runs for the full 100 epochs in about 8 minutes.
  • Without SAM, the model tends to stop early due to slow convergence. The run shown ended at 78 epochs in about 6 minutes; another run ended at 72 epochs with an accuracy of 79.5%.

Acknowledgements

sharpness-aware-minimization-tensorflow's People

Contributors

rainwoodman, sayakpaul


sharpness-aware-minimization-tensorflow's Issues

Potential bug

The structure of the train_step in cell 8 of the notebook is very unconventional.

def train_step(self, data):
    # first model evaluation
    # first tape gradient
    # second model evaluation
    # update parameters
    # second tape gradient

Usually, for the parameter update to affect the second tape gradient, the update should come before the second model evaluation:

def train_step(self, data):
    # first model evaluation
    # first tape gradient
    # update parameters
    # second model evaluation
    # second tape gradient

Verification

Have you conducted experiments to verify that your implementation reproduces results similar to those of the original implementation?
Thanks

Reproducing WRN-28-10 (SAM) for SVHN dataset

I am trying to reproduce the results for WRN-28-10 (SAM) trained on 10-class classification SVHN dataset (Percentage Error 0.99) - https://paperswithcode.com/sota/image-classification-on-svhn

I'm able to train WRN-28-10 using https://github.com/hysts/pytorch_wrn (I modified the script to incorporate SAM).

I'm achieving a test accuracy of 93%. How can I replicate the SOTA percentage error of 0.99 for WRN-28-10 (SAM)? Which hyperparameters should I use?

Any help is appreciated!!

Adding a data generator

That's a great example - thanks. When I try replacing train_ds with a data generator though, I get "NotImplementedError: When subclassing the Model class, you should implement a call method." I also tried adding this call method to SAMModel, but that didn't work either:

def call(self, inputs):
    return self.resnet_model(inputs)

Any ideas? The attached main.py runs your code plus a simple data generator which just shuffles the training data.

main.zip
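For what it's worth, one possible workaround (a sketch, not verified against the attached script) is to wrap the Python generator in tf.data instead of passing it to fit directly; my_generator and the output shapes below are hypothetical and should be adjusted to the real data.

import tensorflow as tf

# Hypothetical generator name and CIFAR-10-like shapes; adjust to the actual data.
train_ds = tf.data.Dataset.from_generator(
    my_generator,
    output_signature=(
        tf.TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None,), dtype=tf.int64),
    ),
).prefetch(tf.data.AUTOTUNE)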

About the result

Hi, Sayak

Thanks for your contribution. I have a question about why the model with SAM is not better than the model without SAM. Thank you!

Yong

Memory requirements

Hi,

Great work. From the notebook it looks like the runtime is about a third longer with SAM. Is that right? That's not too bad.

But how much more memory does the training require? We are calculating gradients twice for each step, right? Don't we need more memory for that?

PS: If I set the learning rate to 1e-2, the model without SAM considerably outperforms the model with SAM.

With SAM:
Epoch 97/100 49/49 [==============================] - 4s 78ms/step - loss: 0.5391 - accuracy: 0.7984 - val_loss: 0.5719 - val_accuracy: 0.8091 - lr: 0.0025

Without SAM:
Epoch 53/200 49/49 [==============================] - 5s 101ms/step - loss: 0.2032 - accuracy: 0.9299 - val_loss: 0.4490 - val_accuracy: 0.8669 - lr: 7.8125e-05
