ZeroBN

This repository contains a PyTorch implementation of the latency-critical neural network adjustment method.

Dependencies

An environment.yml file is provided for creating the conda environment.

Dataset

The CIFAR-10 dataset is downloaded automatically by the scripts. ImageNet can be downloaded from IMAGENET_2012; the classes we used are listed in imagenet100.txt.
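
One possible way to assemble the 100-class subset folder is sketched below. It is an illustration only: the function name build_subset is hypothetical, and the sketch assumes that each non-empty line of imagenet100.txt starts with a class folder name and that the downloaded dataset stores one folder per class. Check both assumptions against the actual files before relying on it.

```python
import os


def build_subset(class_list_path, src_root, dst_root):
    """Symlink the class folders named in class_list_path from src_root into dst_root.

    Assumption: the first whitespace-separated token of each non-empty line
    is a class folder name. Returns the names that were newly linked.
    """
    with open(class_list_path) as f:
        wanted = [line.split()[0] for line in f if line.strip()]

    os.makedirs(dst_root, exist_ok=True)
    linked = []
    for name in wanted:
        src = os.path.join(src_root, name)
        dst = os.path.join(dst_root, name)
        # Skip classes missing from the source tree and links that already exist.
        if os.path.isdir(src) and not os.path.lexists(dst):
            os.symlink(os.path.abspath(src), dst)
            linked.append(name)
    return linked
```

Symlinks keep the subset cheap to create; if the training scripts expect separate train and validation folders, the function would be called once per split.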

CIFAR-10

cd Cifar10
python train_cifar10.py -sr --s 0.0001 --dataset cifar10 --arch vgg --depth 19 --our 1 --prune_ratio 0.5 --epochs 160 --save ./logs-vgg
python train_cifar10.py -sr --s 0.00001 --dataset cifar10 --arch resnet --depth 164 --our 1 --prune_ratio 0.5 --epochs 160 --save ./logs-resnet
python train_cifar10.py -sr --s 0.00001 --dataset cifar10 --arch densenet --depth 40 --our 1 --prune_ratio 0.5 --epochs 160 --save ./logs-densenet

ImageNet

cd ImageNet
python train_imagenet.py -a vgg19_bn --save ./logs-vgg19bn/ --lr 0.01 -sr --ssr 0.001 ../imagenet_100/ --our 1 --prune_ratio 0.5 --epochs 90 --batch-size 128 --auto
python train_imagenet.py -a googlenet --save ./logs-googlenet/ --lr 0.1 -sr --ssr 0.0001 ../imagenet_100/ --our 1 --prune_ratio 0.5 --epochs 90 --batch-size 128 --auto
python train_imagenet_resnetnew.py -a resnet50_new --save ./logs-resnet50new/ --lr 0.1 -sr --ssr 0.0001 ../imagenet_100/ --our 1 --prune_ratio 0.5 --epochs 90 --batch-size 128 --auto

Saved Models

All models trained with our method in the paper are publicly available on OneDrive. To evaluate them, add two arguments to the commands above:

--resume PATH_TO_CHECKPOINT --evaluate

Latency Predictor

cd LatPredictor

We train a backpropagation (BP) neural network in MATLAB using the script. It saves a .mat file, which you can read in Python (an example is provided here).

Export the results from the script into a latency predictor written in PyTorch, which can then be used during training.
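
As a rough illustration of how exported weights could be evaluated in Python, the sketch below loads a .mat file with SciPy and runs a one-hidden-layer forward pass. Everything about the network shape is an assumption: the function names, the tanh hidden activation, the linear output, and the layout of the weights are hypothetical, and any input or output normalization applied by the MATLAB script is ignored here.

```python
import math


def load_mat_weights(path):
    """Read a .mat file into a dict of arrays (requires SciPy)."""
    from scipy.io import loadmat  # imported lazily so the rest runs without SciPy
    return loadmat(path)


def mlp_forward(x, w1, b1, w2, b2):
    """Forward pass of an assumed one-hidden-layer BP network.

    x:  input features (list of floats)
    w1: hidden weights, one row per hidden unit
    b1: hidden biases
    w2: output weights, one per hidden unit
    b2: output bias
    Returns a single scalar, interpreted here as the predicted latency.
    """
    hidden = [
        math.tanh(sum(w * xi for w, xi in zip(row, x)) + b)
        for row, b in zip(w1, b1)
    ]
    return sum(w * h for w, h in zip(w2, hidden)) + b2
```

The keys of the dict returned by load_mat_weights depend on the variable names saved by the MATLAB script, so they have to be looked up in the actual file.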

TensorFlow Support

As TensorFlow is widely used on edge GPUs, we also support converting our models from PyTorch into TensorFlow frozen .pb files.

cd Torch2pb
python trans_vgg19.py --torch_model /path/to/torch_checkpoint/ --save /path/to/savevgg19.pb
python trans_googlenet.py --torch_model /path/to/torch_checkpoint/ --save /path/to/savegooglenet.pb
python trans_resnet.py --torch_model /path/to/torch_checkpoint/ --save /path/to/saveresnet.pb

If you want to use quantization, refer to the scripts with _half in their names to generate a quantized model.

Checkpoints saved in PyTorch format still include the zeroed weights, so the model sizes do not change. Our TensorFlow converter removes all zeros, so the converted models are smaller.
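
The effect of removing zeros can be illustrated with a small, framework-free sketch. It is not the converter's actual code: the helper names are hypothetical, and a layer is represented simply as a list of output channels, each a flat list of weights.

```python
def prune_zero_channels(weights):
    """Drop output channels whose weights are all zero.

    weights: list of output channels, each a flat list of floats.
    Returns (kept_channels, kept_indices).
    """
    kept_indices = [i for i, ch in enumerate(weights) if any(w != 0.0 for w in ch)]
    return [weights[i] for i in kept_indices], kept_indices


def param_count(weights):
    """Total number of stored weights in the layer."""
    return sum(len(ch) for ch in weights)


# Toy layer with four output channels, two of them fully zeroed.
layer = [[0.1, -0.2], [0.0, 0.0], [0.3, 0.0], [0.0, 0.0]]
compact, kept = prune_zero_channels(layer)
print(param_count(layer), "->", param_count(compact), "kept channels:", kept)
```

In a real network the kept indices also have to be applied to the input channels of the following layer, otherwise the shapes no longer match.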

Results

Please refer to the paper.

Project Information

Copyright (c) HP-NTU Digital Manufacturing Corporate Lab, Nanyang Technological University, Singapore.

If you use or adapt this tool in your work or publications, you are required to cite the following reference:

@inproceedings{huai2021zerobn,
  title={ZeroBN: Learning Compact Neural Networks For Latency-Critical Edge Systems},
  author={Huai, Shuo and Zhang, Lei and Liu, Di and Liu, Weichen and Subramaniam, Ravi},
  booktitle={2021 58th ACM/IEEE Design Automation Conference (DAC)},
  pages={151--156},
  year={2021},
  organization={IEEE}
}

Contributors: Shuo Huai, Lei Zhang, Di Liu, Weichen Liu, Ravi Subramaniam (HP).

If you have any comments, questions, or suggestions, please create an issue on GitHub or contact us via email.

Shuo Huai <shuo [DOT] huai [AT] ntu [DOT] edu [DOT] sg>

This is a contribution from the HP-NTU Corp Lab. It has two public mirror repositories: HP Inc. and ntuliuteam.
