OrthoReg: Robust Network Pruning Using Orthogonality Regularization

Code release for OrthoReg. The main paper can be found at https://arxiv.org/abs/2009.05014.

Brief Summary

OrthoReg is a regularization strategy that makes the importance estimates used by prior state-of-the-art pruning methods based on loss preservation (e.g., Fisher pruning) more reliable and robust at large pruning ratios. We provide code for extracting Early-Bird Tickets from, and for iteratively pruning, VGG-13, MobileNet-V1, and ResNet-34.

The code requires:

  • Python 3.6 or higher

  • PyTorch 1.4 or higher

Requirements

To install requirements (uses pip):

./requirements.sh

Organization

The provided modules serve the following purposes:

  • main.py: Provides functions for training pruned networks in general, including Early-Bird Tickets.

  • eval.py: Calculates train accuracy, test accuracy, degree of orthogonality (layerwise), FLOPs, and compression ratio.

  • imp_estimator.py: Importance estimators for different methods (Fisher, BN, TFO, GraSP, RDT, L1, and SFP).

  • pruner.py: Pruning engine (includes pruned networks' classes).

  • models.py: Model classes for VGG-13, MobileNet-V1, ResNet-34.

  • config.py: Hyperparameters and progress bar for training models.

Pretrained models are to be stored in the directory pretrained and pruned models will be saved in pruned_nets. We provide both minimally and fully trained models for Early-Bird Tickets and iterative pruning experiments, respectively (see below).

Extracting Early-Bird Tickets

To extract Early-Bird Tickets from a model (e.g., VGG-13) on CIFAR-100, run the following command:

python main.py --ebt=True --model=vgg --pretrained=False --pruning_type=orthoreg --prune_percent=25

Iterative Pruning

To iteratively prune a model (e.g., VGG-13) on CIFAR-100, run the following command:

python main.py --model=vgg --pretrained=False --pruning_type=orthoreg --prune_percent=50 --n_rounds=2

Summary of available options

--ebt=<extract_early_bird_tickets>

  • Options: True/False; Default: False.
  • If False, n_rounds (see below) needs to be specified to prune iteratively.

--model=<model_name>

  • Options: vgg/mobilenet/resnet.

--pretrained=<use_pretrained_model>

  • Options: True/False.
  • For True, pretrained models should be stored in the directory pretrained.

--data_path=<path_to_data>

  • Options: string; Default: "CIFAR100".
  • If CIFAR-100 is not already downloaded, it will be downloaded.
  • If you are using your own dataset, indicate the path to the main dataset directory which should contain two subdirectories: train and test.
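The directory layout expected for a custom dataset can be checked before launching a run. The sketch below is not part of the repository: `check_dataset_layout` is a hypothetical helper that only verifies that the `train` and `test` subdirectories named above exist under the given path.

```python
from pathlib import Path


def check_dataset_layout(data_path):
    """Return the (train, test) directories under data_path, or raise if one is missing."""
    root = Path(data_path)
    subdirs = [root / "train", root / "test"]
    # Collect the names of the expected subdirectories that are absent.
    missing = [d.name for d in subdirs if not d.is_dir()]
    if missing:
        raise FileNotFoundError(
            f"{root} is missing subdirectories: {', '.join(missing)}"
        )
    return tuple(subdirs)
```

Calling `check_dataset_layout("my_dataset")` before passing `--data_path=my_dataset` fails early with a readable message instead of partway through data loading.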

--pruning_type=<how_to_estimate_importance>

--prune_percent=<how_much_percent_filters_to_prune>

  • Options: integer in range [0, 95].

--n_rounds=<number_of_pruning_rounds>

  • Options: integer.

--thresholds=<manual_thresholds_for_pruning>

  • Options: 'array'.
  • If you do not want to use the default method for deciding pruning ratios, use this option to define manual thresholds.
  • E.g., for pruning the network by 10%, 50%, 80% of original filters in 3 rounds, respectively, use --thresholds='[10, 50, 80]'
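To illustrate the `--thresholds` format, the sketch below parses the string into a list and converts the cumulative percentages (relative to the original filter count) into the fraction of the remaining filters that each round would remove. This is an assumption about how such a value could be handled, not the parsing done in main.py; `parse_thresholds` and `per_round_fractions` are hypothetical names, and the requirement that thresholds increase is also assumed.

```python
import ast


def parse_thresholds(text):
    """Parse a string such as '[10, 50, 80]' into a validated list of percentages."""
    values = ast.literal_eval(text)  # safe evaluation of a Python literal
    if not isinstance(values, (list, tuple)) or not values:
        raise ValueError("thresholds must be a non-empty list, e.g. '[10, 50, 80]'")
    values = [float(v) for v in values]
    if any(not 0 <= v < 100 for v in values):
        raise ValueError("each threshold must lie in [0, 100)")
    if any(b <= a for a, b in zip(values, values[1:])):
        raise ValueError("thresholds must be strictly increasing")
    return values


def per_round_fractions(thresholds):
    """Fraction of the filters still present that each round has to remove."""
    fractions, pruned_so_far = [], 0.0
    for target in thresholds:
        remaining = 100.0 - pruned_so_far
        fractions.append((target - pruned_so_far) / remaining)
        pruned_so_far = target
    return fractions


print(per_round_fractions(parse_thresholds("[10, 50, 80]")))
```

For the example above, the three rounds remove roughly 10%, 44%, and 60% of the filters that are still present at the start of each round.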

--seed=<change_random_seed>

  • Options: integer; Default: 0.

--only_train=<only_train_do_not_prune>

  • Options: base/ortho/False; Default: False.
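Options such as `--ebt=True` arrive on the command line as strings. How main.py converts them is not shown in this README; the sketch below is one common way to do it with argparse, using a hypothetical `str2bool` converter. A converter of this kind is needed because argparse's `type=bool` would treat the string "False" as true. Only a subset of the options is included, and the default for `--pretrained` is an assumption.

```python
import argparse


def str2bool(value):
    """Convert 'True'/'False' (case-insensitive) to a bool for argparse."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


def build_parser():
    # Only a subset of the options listed above, for illustration.
    parser = argparse.ArgumentParser(description="Sketch of the pruning options")
    parser.add_argument("--ebt", type=str2bool, default=False)
    parser.add_argument("--model", choices=["vgg", "mobilenet", "resnet"], required=True)
    # The README does not state a default for --pretrained; False is assumed here.
    parser.add_argument("--pretrained", type=str2bool, default=False)
    parser.add_argument("--seed", type=int, default=0)
    return parser


args = build_parser().parse_args(["--ebt=True", "--model=vgg", "--pretrained=False"])
print(args.ebt, args.pretrained, args.seed)  # True False 0
```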

Training Settings: To change the number of epochs or the learning rate schedule for training the base models or the pruned models, change the hyperparameters in config.py. By default, models are trained using SGD with momentum (0.9).

Evaluation

To evaluate a model (e.g., a pruned VGG-13 model), use:

python eval.py --model vgg --pruned True --model_path <path_to_model_file> --test_acc True

Summary of available options for evaluating models:

--model=<model_name>

  • Options: vgg/mobilenet/resnet.

--pruned=<evaluating_a_pruned_model>

  • Options: True/False; Default: False.
  • Set to True for evaluating a pruned model.

--model_path=<path_to_model>

  • Options: string.
  • Location where model to be analyzed is stored.

--data_path=<path_to_dataset>

  • Options: string; Default: "CIFAR100".
  • If CIFAR-100 is not already downloaded, it will be downloaded.
  • If you are using your own dataset, indicate the path to the main dataset directory which should contain two subdirectories: train and test.

--train_acc=<evaluate_train_accuracy>

  • Options: True/False; Default: False.

--test_acc=<evaluate_test_accuracy>

  • Options: True/False; Default: False.

--flops=<evaluate_flops_in_model>

  • Options: True/False; Default: False.

--compression=<evaluate_compression_ratio>

  • Options: True/False; Default: False.

--eval_ortho=<evaluate_degree_of_orthogonality>

  • Options: True/False; Default: False.
  • For layer l with weight matrix W(l), this option outputs the ratio ||diag(W(l)^T W(l))|| / ||W(l)^T W(l)||. When W(l)^T W(l) is diagonal, i.e., when the columns of W(l) are mutually orthogonal, this ratio is 1.
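The ratio can be illustrated on a small matrix. The sketch below is a dependency-free illustration, not the code in eval.py; it assumes the Frobenius norm and treats the columns of the matrix as the filters, neither of which is stated in this README. `orthogonality_ratio` is a hypothetical name.

```python
import math


def orthogonality_ratio(weight):
    """Return ||diag(W^T W)||_F / ||W^T W||_F for a matrix given as a list of rows.

    Columns of `weight` are treated as the filters (an assumption).
    """
    n_cols = len(weight[0])
    # Gram matrix G = W^T W: G[i][j] is the inner product of columns i and j.
    gram = [
        [sum(row[i] * row[j] for row in weight) for j in range(n_cols)]
        for i in range(n_cols)
    ]
    diag_norm = math.sqrt(sum(gram[i][i] ** 2 for i in range(n_cols)))
    full_norm = math.sqrt(sum(g ** 2 for gram_row in gram for g in gram_row))
    return diag_norm / full_norm


orthogonal = [[1.0, 0.0], [0.0, 2.0]]  # columns are orthogonal
correlated = [[1.0, 1.0], [0.0, 0.0]]  # columns are identical
print(orthogonality_ratio(orthogonal))            # 1.0
print(round(orthogonality_ratio(correlated), 4))  # 0.7071
```

The ratio drops below 1 as soon as the off-diagonal entries of W^T W are nonzero, that is, as soon as filters overlap.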

Results

We provide sample results for our code. Following are the performances of the Early-Bird Tickets drawn using different pruning methods (OrthoReg, Fisher pruning, and BN-scale based pruning) on VGG-13, MobileNet-V1, and ResNet-34 models for the CIFAR-100 dataset:

To replicate these, use seed 0.

  • Original models:

    Model name      Accuracy
    ResNet-34       73.4%
    VGG-13          65.5%
    MobileNet-V1    67.0%
  • OrthoReg:

    Model name      % Pruned    Accuracy
    ResNet-34       25%         77.4%
    ResNet-34       50%         76.7%
    ResNet-34       75%         74.8%
    VGG-13          25%         71.4%
    VGG-13          50%         71.2%
    VGG-13          75%         67.5%
    MobileNet-V1    25%         67.8%
    MobileNet-V1    50%         67.4%
    MobileNet-V1    75%         65.8%
  • Fisher pruning:

    Model name      % Pruned    Accuracy
    ResNet-34       25%         72.7%
    ResNet-34       50%         72.3%
    ResNet-34       75%         71.3%
    VGG-13          25%         65.4%
    VGG-13          50%         67.1%
    VGG-13          75%         65.3%
    MobileNet-V1    25%         67.9%
    MobileNet-V1    50%         67.6%
    MobileNet-V1    75%         65.9%
  • BN-scale based:

    Model name      % Pruned    Accuracy
    ResNet-34       25%         72.9%
    ResNet-34       50%         72.1%
    ResNet-34       75%         70.1%
    VGG-13          25%         65.5%
    VGG-13          50%         65.6%
    VGG-13          75%         64.8%
    MobileNet-V1    25%         67.7%
    MobileNet-V1    50%         68.3%
    MobileNet-V1    75%         65.5%
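To rerun the OrthoReg Early-Bird Ticket experiments, the extraction command can be repeated for each model and pruning percentage with seed 0. The sketch below only builds the command lines, using flags that appear in this README; `ebt_commands` is a hypothetical helper, `--pretrained=False` simply mirrors the example command, and actually launching the runs (for example with `subprocess.run`) is left to the reader.

```python
import itertools

MODELS = ["resnet", "vgg", "mobilenet"]
PRUNE_PERCENTS = [25, 50, 75]


def ebt_commands(seed=0):
    """Build one main.py invocation per (model, prune percent) pair."""
    commands = []
    for model, percent in itertools.product(MODELS, PRUNE_PERCENTS):
        commands.append([
            "python", "main.py",
            "--ebt=True",
            f"--model={model}",
            "--pretrained=False",  # mirrors the example command in this README
            "--pruning_type=orthoreg",
            f"--prune_percent={percent}",
            f"--seed={seed}",
        ])
    return commands


for command in ebt_commands():
    print(" ".join(command))
```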

Pre-trained Models

We provide fully trained and minimally trained models that can be pruned using our code.

For Early-Bird Tickets experiments, minimally trained models can be found here:

For iterative pruning, fully trained models can be found here:

To use these models:

  • Store the fully trained models in the directory pretrained/iterative.

  • Store the minimally trained models in the directory pretrained/ebt.

Note

  • The optimizer is set to SGD by default. For iterative pruning experiments described in the paper, Adam was used.

Contributors

  • ekdeepslubana
