oscillations-qat's Introduction

Overcoming Oscillations in Quantization-Aware Training

This repository contains the implementation and experiments for the paper presented in

Markus Nagel*1, Marios Fournarakis*1, Yelysei Bondarenko1, Tijmen Blankevoort1 "Overcoming Oscillations in Quantization-Aware Training", ICML 2022. [ArXiv]

*Equal contribution 1 Qualcomm AI Research (Qualcomm AI Research is an initiative of Qualcomm Technologies, Inc.)

You can use this code to recreate the results in the paper.

Reference

If you find our work useful, please cite

@InProceedings{pmlr-v162-nagel22a,
  title = 	 {Overcoming Oscillations in Quantization-Aware Training},
  author =       {Nagel, Markus and Fournarakis, Marios and Bondarenko, Yelysei and Blankevoort, Tijmen},
  booktitle = 	 {Proceedings of the 39th International Conference on Machine Learning},
  pages = 	 {16318--16330},
  year = 	 {2022},
  editor = 	 {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
  volume = 	 {162},
  series = 	 {Proceedings of Machine Learning Research},
  month = 	 {17--23 Jul},
  publisher =    {PMLR},
  pdf = 	 {https://proceedings.mlr.press/v162/nagel22a/nagel22a.pdf},
  url = 	 {https://proceedings.mlr.press/v162/nagel22a.html}
  }

Method and Results

When training neural networks with simulated quantization, we observe that quantized weights can, rather unexpectedly, oscillate between two grid points. This is an inherent problem caused by the straight-through estimator (STE). In our paper, we delve deeper into this little-understood phenomenon and show that oscillations harm accuracy by corrupting the EMA statistics of the batch-normalization layers and by preventing convergence to local minima.
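
The effect is easy to reproduce in isolation. The toy example below (a sketch for illustration, not code from this repository) trains a single latent weight with an STE-rounded quantizer towards a target that lies between two grid points:

import torch

# Toy illustration (not repository code): one latent weight, quantized with the STE,
# trained towards a target that sits between the grid points 0.3 and 0.4.
scale = 0.1                       # quantization step size
target = 0.35                     # optimum lies exactly between two grid points
w = torch.tensor(0.30, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.05)

def quantize_ste(x, s):
    # Forward: round to the nearest grid point. Backward: identity gradient (STE).
    return (torch.round(x / s) * s - x).detach() + x

for step in range(30):
    opt.zero_grad()
    w_q = quantize_ste(w, scale)
    loss = (w_q - target) ** 2    # simple quadratic loss on the quantized weight
    loss.backward()
    opt.step()
    print(f"step {step:2d}   latent w = {w.item():+.4f}   quantized = {w_q.item():+.2f}")

Once the latent weight reaches the rounding boundary at 0.35, the STE gradient changes sign every time the rounding decision flips, so the quantized weight keeps oscillating between 0.3 and 0.4.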

We propose two novel methods to tackle oscillations at their source: oscillation dampening and iterative weight freezing. We demonstrate that our algorithms achieve state-of-the-art accuracy for low-bit (3 & 4 bits) weight and activation quantization of efficient architectures, such as MobileNetV2, MobileNetV3, and EfficientNet-Lite, on ImageNet.

How to install

Make sure to have Python ≥3.6 (tested with Python 3.6.8) and ensure the latest version of pip (tested with 21.3.1):

source env/bin/activate
pip install --upgrade --no-deps pip

Next, install PyTorch 1.9.1 with the appropriate CUDA version (tested with CUDA 10.0, cuDNN 7.6.3):

pip install torch==1.9.1 torchvision==0.10.1

Finally, install the remaining dependencies using pip:

pip install -r requirements.txt

Running experiments

The main run file to reproduce all experiments is main.py. It contains commands for quantization-aware training (QAT) and validating quantized models. You can see the full list of options for each command using python main.py [COMMAND] --help.

Usage: main.py [OPTIONS] COMMAND [ARGS]...

Options:
  --help  Show this message and exit.

Commands:
  train-quantized

Quantization-Aware Training (QAT)

All models are fine-tuned starting from pre-trained FP32 weights. Pretrained weights may be found here.

MobileNetV2

To train with oscillation dampening, run:

python main.py train-quantized --architecture mobilenet_v2_quantized
--images-dir path/to/raw_imagenet --act-quant-method MSE  --weight-quant-method MSE 
--optimizer SGD --weight-decay 2.5e-05 --sep-quant-optimizer 
--quant-optimizer Adam --quant-learning-rate 1e-5 --quant-weight-decay 0.0 
--model-dir /path/to/mobilenet_v2.pth.tar --learning-rate-schedule cosine:0
# Dampening loss configurations 
--oscillations-dampen-weight 0 --oscillations-dampen-weight-final 0.1 
# 4-bit best learning rate
--n-bits 4 --learning-rate 0.0033 
# 3-bits best learning rate
--n-bits 3 --learning-rate 0.01
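
In the command above, --oscillations-dampen-weight and --oscillations-dampen-weight-final control the strength of the dampening regularization, which is annealed over training. Conceptually, the dampening loss penalizes the distance between each latent weight and its quantized value, pulling weights towards the centre of their quantization bin so the rounding decision stops flipping. A minimal sketch of such a penalty (our own illustration under assumptions, not the repository implementation, which may differ in details such as the annealing schedule):

import torch

def dampening_loss(weight: torch.Tensor, scale: float) -> torch.Tensor:
    # Illustration only: penalize the distance between the latent weight and its
    # (detached) quantized value, discouraging weights from hovering around a
    # rounding decision boundary.
    w_q = (torch.round(weight / scale) * scale).detach()
    return ((weight - w_q) ** 2).sum()

def dampen_weight(step: int, total_steps: int, start: float = 0.0, final: float = 0.1) -> float:
    # Hypothetical (linear) annealing of the regularization strength, mirroring the
    # --oscillations-dampen-weight / --oscillations-dampen-weight-final flags.
    return start + (final - start) * step / max(total_steps - 1, 1)

The total training objective would then be the task loss plus dampen_weight(step, total_steps) times the dampening loss summed over all quantized weight tensors.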

To train with iterative weight freezing, run:

python main.py train-quantized --architecture mobilenet_v2_quantized
--images-dir path/to/raw_imagenet --act-quant-method MSE  --weight-quant-method MSE 
--optimizer SGD  --sep-quant-optimizer 
--quant-optimizer Adam --quant-learning-rate 1e-5 --quant-weight-decay 0.0 
--model-dir /path/to/mobilenet_v2.pth.tar --learning-rate-schedule cosine:0
# Iterative weight freezing configuration
--oscillations-freeze-threshold 0.1
# 4-bit best configuration
--n-bits 4 --learning-rate 0.0033 --weight-decay 5e-05 --oscillations-freeze-threshold-final 0.01 
# 3-bit best configuration
--n-bits 3 --learning-rate 0.01 --weight-decay 2.5e-05 --oscillations-freeze-threshold-final 0.011
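
Here the --oscillations-freeze-threshold and --oscillations-freeze-threshold-final flags set the initial and final value of the freezing threshold over training: the integer value of each weight is tracked during optimization, and weights that oscillate more frequently than the threshold are frozen at their current quantized value. The sketch below illustrates the idea (our own simplified illustration, not the repository's oscillation_tracking_utils.py):

import torch

class OscillationFreezer:
    # Simplified illustration: track how often each weight's integer value flips with an
    # EMA of the flip frequency, and freeze weights that flip more often than a threshold.
    def __init__(self, momentum: float = 0.01, threshold: float = 0.1):
        self.momentum = momentum
        self.threshold = threshold
        self.prev_x_int = None     # integer values at the previous update
        self.ema_freq = None       # EMA of the per-weight flip frequency
        self.frozen = None         # boolean mask of frozen weights
        self.frozen_x_int = None   # integer value a frozen weight is pinned to

    def update(self, x_int: torch.Tensor) -> torch.Tensor:
        if self.prev_x_int is None:
            self.prev_x_int = x_int.clone()
            self.ema_freq = torch.zeros_like(x_int, dtype=torch.float)
            self.frozen = torch.zeros_like(x_int, dtype=torch.bool)
            self.frozen_x_int = x_int.clone()
            return x_int
        flipped = (x_int != self.prev_x_int).float()
        self.ema_freq = (1 - self.momentum) * self.ema_freq + self.momentum * flipped
        newly_frozen = (self.ema_freq > self.threshold) & ~self.frozen
        self.frozen_x_int = torch.where(newly_frozen, x_int, self.frozen_x_int)
        self.frozen |= newly_frozen
        self.prev_x_int = x_int.clone()
        # Frozen weights keep their pinned integer value for the rest of training.
        return torch.where(self.frozen, self.frozen_x_int, x_int)

In this sketch, a quantizer would call update(x_int) on the integer weight values at every training step and use the returned tensor in place of x_int.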

For convenience, bash scripts for reproducing our experiments are provided under /bash/:

./bash/train_mobilenetv2.sh --IMAGES_DIR path_to_raw_imagenet --MODEL_DIR path_to_pretrained_weights # QAT training of MobileNetV2 with defaults (method 'freeze' and 3 bits)
./bash/train_efficientnet.sh --IMAGES_DIR path_to_raw_imagenet --METHOD damp --N_BITS 4

oscillations-qat's Issues

Code of LSQ-BR

Hi! Thanks for your work.
I wonder if you will share the code for LSQ-BR, as I would like to run the same comparison experiment as in the paper.
Thanks!

Weight oscillation code

Thanks for your great work. I have a question about the code in oscillation_tracking_utils.py: self.ema_x_int and self.frozen_x_int are never used anywhere else, and the __call__ function of the TrackOscillation class still returns x_int instead of self.ema_x_int, which suggests that the weight oscillations are not actually dampened.

Questions about binary QAT training

Great work!
It seems the idea of dampening oscillations during QAT is consistent with the C2I & FF ratio from "How Do Adam and Training Strategies Help BNNs Optimization?", which argues that some flip-flops do not contribute to the final weights but only harm training stability. You also write in the paper that "the lower the bit-width b, ......, they cause a proportionally larger shift in the output distribution". So it seems fairly plausible that this dampening loss would also work for binary QAT training. Have you done any related experiments, and if so, what were the results?
Looking forward to your reply!
Best regards.

Reproduction of baseline

Hi! Thanks for your excellent work.

I noticed that Table 8 of your paper states that the 4-bit LSQ (baseline) results for EfficientNet were reproduced by yourselves. However, when I reproduced the baseline with this code (with osc_freeze.threshold=None and osc_damp.weight_final=None), I got an accuracy of 73.2, which does not match the result in your paper (72.3).

I can reproduce the LSQ+Freeze and LSQ+Dampen results from the paper, but the baseline results are much higher than those listed in the paper. What code does your baseline use? Did I get something wrong?
Thanks! Looking forward to your reply!
