
fengxingxiang / modelcompression-2019


This project forked from ahraza/modelcompression-2019



Thesis Test Bench

Prerequisites

  • Python 3.6+
  • pip3
  • Virtualenv (optional)

Setup

To get started, optionally set up and activate a virtualenv:

~ virtualenv venv
~ . venv/bin/activate

Next, install all dependencies:

~ pip install -r requirements.txt

Everything should now be ready.

Usage

Train a model from scratch

To re-train one of the available models from scratch (without pruning or quantization), run the corresponding *_classifier.py script. For example:

~ python cifar_classifier.py --save-model

By default, this will train a model and save it to models/cifar_classifier.pt.
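As a minimal sketch of how a --save-model flag and this default output path could fit together, the following uses argparse. The build_parser and save_path helpers, and the assumption that the file name is derived from the script name, are hypothetical; only --save-model and models/cifar_classifier.pt come from this README.

```python
import argparse
import os


def build_parser():
    # Hypothetical parser: --save-model is the only flag documented in this README.
    parser = argparse.ArgumentParser(description="Train a classifier from scratch")
    parser.add_argument("--save-model", action="store_true",
                        help="save the trained weights after training")
    return parser


def save_path(script_name, model_dir="models"):
    # Assumption: cifar_classifier.py -> models/cifar_classifier.pt
    stem = os.path.splitext(os.path.basename(script_name))[0]
    return os.path.join(model_dir, stem + ".pt")


if __name__ == "__main__":
    args = build_parser().parse_args(["--save-model"])
    if args.save_model:
        # A PyTorch script would typically call torch.save(model.state_dict(), path) here.
        print("would save to", save_path("cifar_classifier.py"))
```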

Evaluate Pruning

Step 1

To prune and evaluate a model, first make sure that a trained model exists. If you are not sure how to create one, follow the instructions under Train a model from scratch above.

Step 2

Once you have confirmed that a trained model exists, make sure there is an appropriate test-time configuration for that model. Configurations are defined in main.py, in the configurations variable. This will be replaced by a more robust solution in the future.

Step 3 (WILL CHANGE)

Run the command:

python main.py

You will be presented with a list of available configurations:

Select a model type to prune. Models available:
    0:{'model': <class 'cifar_classifier.MaskedCifar'>, 'dataset': <class 'torchvision.datasets.cifar.CIFAR10'>, 'transforms': [ToTensor(), Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))], 'loss_fn': <function cross_entropy at 0x000001C3C9EB1E18>}
    
    1:{'model': <class 'mnist_classifier.MaskedMNist'>, 'dataset': <class 'torchvision.datasets.mnist.MNIST'>, 'transforms': [ToTensor(), Normalize(mean=(0.1307,), std=(0.3081,))], 'loss_fn': <function nll_loss at 0x000001C3C9EB1AE8>}

Select the appropriate configuration by entering its index:

Enter selected index and press enter: 0

Then, select the appropriate trained model parameters to load:

Select a model params file to load by index. Models available:
    0:cifar_classifier.pt
    1:mnist_classifier.pt
    2:mnist_cnn.pt

Enter selected index and press enter: 0

Finally, input the desired pruning percentage:

Select pruning percentage: (0-100)%: 50
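The prompts above amount to choosing an index from a numbered list and entering a percentage. A minimal sketch of that input handling follows, assuming the parameter-file menu simply lists the *.pt files in the models directory in sorted order (main.py may build it differently). The names list_param_files, format_menu, parse_index, and parse_percentage are hypothetical.

```python
import os
import tempfile


def list_param_files(model_dir):
    # Assumption: the menu lists every *.pt file in the models directory, sorted by name.
    return sorted(name for name in os.listdir(model_dir) if name.endswith(".pt"))


def format_menu(title, options):
    # Render "index:option" lines in the style of the prompts above.
    lines = [title] + ["    {}:{}".format(i, opt) for i, opt in enumerate(options)]
    return "\n".join(lines)


def parse_index(text, options):
    # Accept only an integer that is a valid position in the option list.
    index = int(text.strip())
    if not 0 <= index < len(options):
        raise ValueError("index out of range: {}".format(index))
    return index


def parse_percentage(text):
    # Accept "50" or "50%" and require a value in the closed range 0-100.
    value = float(text.strip().rstrip("%"))
    if not 0 <= value <= 100:
        raise ValueError("pruning percentage must be between 0 and 100")
    return value


if __name__ == "__main__":
    # Demo with a throwaway directory standing in for models/.
    with tempfile.TemporaryDirectory() as model_dir:
        for name in ["mnist_cnn.pt", "cifar_classifier.pt", "mnist_classifier.pt", "notes.txt"]:
            open(os.path.join(model_dir, name), "w").close()
        files = list_param_files(model_dir)
    print(format_menu("Select a model params file to load by index. Models available:", files))
    chosen = files[parse_index("0", files)]
    print(chosen, parse_percentage("50"))
```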

The script will evaluate the pre-pruned model, prune the model, print relevant pruning statistics by layer, and finally evaluate the pruned model.
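In the sample output, the per-layer rates vary widely (about 15% to 56%) while the overall rate lands near the requested 50%. That pattern is consistent with global magnitude pruning, where a single threshold is chosen over all weights at once, although this README does not confirm which method main.py uses. The following is a minimal pure-Python sketch of that technique, using flat weight lists instead of tensors; global_magnitude_prune, apply_masks, and report are hypothetical helpers.

```python
def global_magnitude_prune(layers, percentage):
    """Return one 0/1 mask per layer (1 = keep, 0 = prune).

    layers: list of (kind, weights) pairs, weights being a flat list of floats.
    percentage: share of *all* weights to prune, 0-100.
    """
    magnitudes = sorted(abs(w) for _, weights in layers for w in weights)
    n_prune = int(len(magnitudes) * percentage / 100)
    # A single global threshold: everything strictly below it is pruned.
    # Weights equal to the threshold are kept, so ties can leave the
    # achieved rate slightly under the requested percentage.
    threshold = magnitudes[n_prune] if n_prune < len(magnitudes) else float("inf")
    return [[0 if abs(w) < threshold else 1 for w in weights] for _, weights in layers]


def apply_masks(layers, masks):
    # Zero the pruned weights while keeping each layer's length.
    return [(kind, [w * m for w, m in zip(weights, mask)])
            for (kind, weights), mask in zip(layers, masks)]


def report(layers, masks):
    # Per-layer and overall pruning rates, in the style of the script's output.
    lines, total, pruned_total = [], 0, 0
    for i, ((kind, _), mask) in enumerate(zip(layers, masks), start=1):
        pruned = mask.count(0)
        total += len(mask)
        pruned_total += pruned
        lines.append("Layer {} | {} layer | {:.2f}% parameters pruned".format(
            i, kind, 100 * pruned / len(mask)))
    lines.append("Final pruning rate: {:.2f}%".format(100 * pruned_total / total))
    return lines


if __name__ == "__main__":
    layers = [("Conv", [0.5, -0.3, 0.4, 0.05]), ("Linear", [0.02, -0.1, 0.01, 0.2])]
    masks = global_magnitude_prune(layers, 50)
    print("\n".join(report(layers, masks)))
    # Layer 1 | Conv layer | 25.00% parameters pruned
    # Layer 2 | Linear layer | 75.00% parameters pruned
    # Final pruning rate: 50.00%
```

Even in this toy example, one global threshold prunes the two layers at very different rates while the total matches the request, which mirrors the shape of the statistics printed by main.py.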

Loading file cifar_classifier.pt for model {'model': <class 'cifar_classifier.MaskedCifar'>, 'dataset': <class 'torchvision.datasets.cifar.CIFAR10'>, 'transforms': [ToTensor(), Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))], 'loss_fn': <function cross_entropy at 0x000001C3C9EB1E18>}

Testing pre-pruned model..

Test set: Average loss: 1.1923, Accuracy: 5865/10000 (59%)

Pruning model..
Layer 1 | Conv layer | 15.11% parameters pruned
Layer 2 | Conv layer | 26.54% parameters pruned
Layer 3 | Linear layer | 56.20% parameters pruned
Layer 4 | Linear layer | 30.25% parameters pruned
Layer 5 | Linear layer | 18.21% parameters pruned
Final pruning rate: 49.81%

Evaluating pruned model..

Test set: Average loss: 1.2210, Accuracy: 5714/10000 (57%)
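The "Test set" lines appear to follow the format string of the PyTorch MNIST example, which rounds the accuracy to a whole percentage. A small sketch reproducing that formatting is shown below (format_test_line and accuracy_drop are hypothetical helpers). It also shows that the drop from 59% to 57% is really 58.65% to 57.14%, about 1.51 percentage points.

```python
def format_test_line(avg_loss, correct, total):
    # Whole-number percentage, as in the PyTorch MNIST example's format string.
    accuracy = 100.0 * correct / total
    return "Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)".format(
        avg_loss, correct, total, accuracy)


def accuracy_drop(correct_before, correct_after, total):
    # Unrounded accuracy lost by pruning, in percentage points.
    return 100.0 * (correct_before - correct_after) / total


if __name__ == "__main__":
    print(format_test_line(1.1923, 5865, 10000))  # ... Accuracy: 5865/10000 (59%)
    print(format_test_line(1.2210, 5714, 10000))  # ... Accuracy: 5714/10000 (57%)
    print("{:.2f} points".format(accuracy_drop(5865, 5714, 10000)))  # 1.51 points
```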


Contributors

ahraza, bobaraki
