arm-software / scalpel

License: BSD 3-Clause "New" or "Revised" License

Python 89.68% Shell 10.32%

scalpel's Introduction

Scalpel

This is a PyTorch implementation of Scalpel. It includes node pruning for five benchmark networks and SIMD-aware weight pruning for LeNet-300-100 and LeNet-5.

Node Pruning

Here are the results for node pruning:

| Network | Dataset | Structure After Pruning | Accuracy Before | Accuracy After |
|---------|---------|-------------------------|-----------------|----------------|
| LeNet-300-100 | MNIST | 784->161/300*(ip)->64/100(ip)->10(ip) | 98.48% | 98.54% |
| LeNet-5 | MNIST | 1->9/20(conv)->18/50(conv)->65/500(ip)->10(ip) | 99.34% | 99.34% |
| ConvNet | CIFAR-10 | 3->28/32(conv)->22/32(conv)->40/64(conv)->10(ip) | 81.38% | 81.52% |
| NIN | CIFAR-10 | 3->117/192(conv)->81/160(conv)->62/96(conv)->183/192(conv)->123/192(conv)->148/192(conv)->91/192(conv)->54/192(conv)->10(conv) | 89.67% | 89.68% |
| AlexNet | ImageNet | 3->83/96(conv)->225/256(conv)->233/384(conv)->238/384(conv)->253/256(conv)->3001/4096(ip)->3029/4096(ip)->1000(ip) | 80.3% | 80.5% |

* <# of nodes after pruning>/<# of original nodes>

MNIST

I use the dataset reader provided by torchvision.
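For reference, here is a minimal sketch of such a reader using the standard torchvision API (the normalization constants below are the commonly used MNIST mean/std, not necessarily the values this repo uses):

import torch
from torchvision import datasets, transforms

# Standard torchvision MNIST reader; downloads the data on first use.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),  # common MNIST mean/std
])
train_set = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)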

LeNet-300-100

To run the pruning:

$ cd <ROOT>/MNIST/
$ bash prune.node.lenet_300_100.sh

This script includes the following commands:

# original training -- 98.48%
python main.py

# stage 0 -- 60 13
python main.py --prune node --stage 0 \
	--pretrained saved_models/LeNet_300_100.best_origin.pth.tar \
	--lr 0.001 --penalty 0.0002 --lr-epochs 30

# stage 1 -- 120 26
python main.py --prune node --stage 1 \
	--pretrained saved_models/LeNet_300_100.prune.node.0.pth.tar \
	--lr 0.001 --penalty 0.0003 --lr-epochs 30

# stage 2 -- 139 36
python main.py --prune node --stage 2 \
	--pretrained saved_models/LeNet_300_100.prune.node.1.pth.tar \
	--lr 0.001 --penalty 0.0010 --lr-epochs 30

# stage 3 retrain -- 98.54%
python main.py --prune node --stage 3 --retrain \
	--pretrained saved_models/LeNet_300_100.prune.node.2.pth.tar \
	--lr 0.1 --lr-epochs 20

It first trains the original model and then applies node pruning (stages 0-2). After node pruning, the model is retrained to recover the original accuracy (stage 3). The number pairs in the stage comments (e.g., 139 36 at stage 2) appear to be the cumulative counts of pruned nodes in the two hidden layers, matching 300-161=139 and 100-64=36 in the table above.
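For intuition, here is a minimal sketch of what one pruning stage does, assuming a mask-based formulation (illustrative only, not the repo's exact code): every hidden node is gated by a trainable scale, the --penalty flag weights an L1 term that drives unimportant gates toward zero, and nodes whose gate falls below a threshold are pruned.

import torch
import torch.nn as nn

class MaskedLinear(nn.Module):
    # A linear layer whose output nodes are gated by trainable scales.
    def __init__(self, in_features, out_features):
        super(MaskedLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.alpha = nn.Parameter(torch.ones(out_features))  # per-node gate

    def forward(self, x):
        return self.linear(x) * self.alpha

def mask_penalty(model, penalty=2e-4):
    # Added to the task loss; corresponds in spirit to the --penalty flag.
    return penalty * sum(m.alpha.abs().sum()
                         for m in model.modules()
                         if isinstance(m, MaskedLinear))

def prune_layer(layer, threshold=1e-2):
    # Zero out nodes whose gate fell below the (illustrative) threshold;
    # the real pipeline would then physically remove those rows to shrink
    # the layer (e.g., 300 -> 161 nodes in the table above).
    with torch.no_grad():
        keep = (layer.alpha.abs() >= threshold).float()
        layer.alpha.mul_(keep)
        layer.linear.weight.mul_(keep.unsqueeze(1))
        layer.linear.bias.mul_(keep)
    return int(keep.sum().item())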

LeNet-5

To run the pruning:

$ cd <ROOT>/MNIST/
$ bash prune.node.lenet_5.sh

It first trains the original model and then applies node pruning. The pre-pruned model can be downloaded here. Download it and put it in <ROOT>/MNIST/saved_models/. To evaluate the pruned model:

$ python main.py --prune node --arch LeNet_5 --pretrained saved_models/LeNet_5.prune.node.5.pth.tar --evaluate

CIFAR-10

The training dataset can be downloaded here. Download and uncompress it to <ROOT>/CIFAR_10/data/.
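If the link is unavailable, the same dataset can likely be fetched with torchvision (an assumption: the repo may expect a particular extracted directory layout under <ROOT>/CIFAR_10/data/, so verify the directory name after download):

# Hedged alternative: download CIFAR-10 via torchvision.
from torchvision import datasets
datasets.CIFAR10(root='data', train=True, download=True)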

ConvNet

To run the pruning:

$ cd <ROOT>/CIFAR_10/
$ bash prune.node.convnet.sh

The pre-pruned model can be downloaded here. Download it and put it in <ROOT>/CIFAR_10/saved_models/. To evaluate the pruned model:

$ python main.py --prune node --pretrained saved_models/ConvNet.prune.node.4.pth.tar --evaluate

Network-in-Network (NIN)

To run the pruning:

$ cd <ROOT>/CIFAR_10/
$ bash prune.node.nin.sh

The pre-pruned model can be downloaded here. Download it and put it in <ROOT>/CIFAR_10/saved_models/. To evaluate the pruned model:

$ python main.py --prune node --arch NIN --pretrained saved_models/NIN.prune.node.7.pth.tar --evaluate

ImageNet

To run the pruning:

$ cd <ROOT>/ImageNet/
$ bash prune.node.alexnet.sh

The pre-pruned model can be downloaded here. Download it and put it in <ROOT>/ImageNet/saved_models/. To evaluate the pruned model:

$ python main.py --prune node --pretrained saved_models/AlexNet.prune.node.8.pth.tar --evaluate

SIMD-Aware Weight Pruning

SIMD-aware weight pruning is provided in ./SIMD_Aware_MNIST. LeNet-300-100 and LeNet-5 on MNIST are tested. The LeNet-300-100 example can be run with:

$ cd ./SIMD_Aware_MNIST/
$ bash prune.simd.lenet_300_100.sh

It first trains the network and then performs SIMD-aware weight pruning with the group width set to 8, removing 92.0% of the weights. The prune.simd.lenet_300_100.sh script contains the following commands:

# original training -- 98.48%
python main.py

# 60.6% pruned
python main.py --prune simd --stage 0 --width 8 \
	--pretrained saved_models/LeNet_300_100.best_origin.pth.tar \
	--lr 0.01 --lr-epochs 20 --threshold 0.04

# 72.6% pruned
python main.py --prune simd --stage 1 --width 8 \
	--pretrained saved_models/LeNet_300_100.prune.simd.0.pth.tar \
	--lr 0.01 --lr-epochs 20 --threshold 0.05

# 82.4% pruned
python main.py --prune simd --stage 2 --width 8 \
	--pretrained saved_models/LeNet_300_100.prune.simd.1.pth.tar \
	--lr 0.01 --lr-epochs 20 --threshold 0.06

# 88.7% pruned
python main.py --prune simd --stage 3 --width 8 \
	--pretrained saved_models/LeNet_300_100.prune.simd.2.pth.tar \
	--lr 0.01 --lr-epochs 20 --threshold 0.07

# 92.0% pruned
python main.py --prune simd --stage 4 --width 8 \
	--pretrained saved_models/LeNet_300_100.prune.simd.3.pth.tar \
	--lr 0.01 --lr-epochs 20 --threshold 0.08
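Conceptually, each stage above groups every weight matrix into blocks of --width consecutive weights and drops whole groups whose magnitude falls below the stage's --threshold, so the surviving non-zeros always come in SIMD-width runs. A minimal sketch follows; it is not the repo's exact code, and both the grouping axis and the group-importance measure (here, the group's maximum absolute value) are assumptions:

import torch

def simd_aware_prune(weight, width=8, threshold=0.04):
    # Zero whole groups of `width` consecutive weights in each row so the
    # remaining non-zeros align with SIMD lanes. Assumes the column count
    # is a multiple of `width`.
    out_f, in_f = weight.shape
    groups = weight.view(out_f, in_f // width, width)
    importance = groups.abs().max(dim=2).values             # one score per group
    mask = (importance >= threshold).float().unsqueeze(2)   # broadcast over lanes
    return (groups * mask).view(out_f, in_f)

w = torch.randn(300, 784) * 0.05  # toy stand-in for an fc layer's weights
pruned = simd_aware_prune(w)
print('pruned fraction: {:.1%}'.format((pruned == 0).float().mean().item()))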

For LeNet-5, the experiment can be run with:

$ bash prune.simd.lenet_5.sh

It will remove 96.8% of the weights in LeNet-5.

SIMD-aware weight pruning for the other benchmark networks is under construction.

Citation

Please cite Scalpel in your publications if it helps your research:

@inproceedings{yu2017scalpel,
  title={Scalpel: Customizing DNN Pruning to the Underlying Hardware Parallelism},
  author={Yu, Jiecao and Lukefahr, Andrew and Palframan, David and Dasika, Ganesh and Das, Reetuparna and Mahlke, Scott},
  booktitle={Proceedings of the 44th Annual International Symposium on Computer Architecture},
  pages={548--560},
  year={2017},
  organization={ACM}
}

scalpel's People

Contributors

ganeshdasikaarm, jiecaoyu

scalpel's Issues

Performance Issue: sparse matrix-matrix multiplication

Hi,
I have a question about sparse matrix-matrix multiplication.
Experiment environment:

  • Odroid-xu4 mali-T628
  • OpenCL

I implemented a sparse matrix-matrix multiplication OpenCL kernel using the algorithm below.

[figure: SpMM algorithm]

However, I get lower performance than the Arm Compute Library (sgemm, dense matrix multiplication) on the Odroid device.

The Scalpel paper says: "Libraries for sparse matrix-vector/matrix multiplication are written in-house."

I wonder how to improve the performance of SpMM on the Mali-T628 GPU with OpenCL.

Thanks.

TypeError: can't assign a bool to a torch.cuda.ByteTensor

When I run the SIMD-aware weight pruning, the following error occurs:

File "main.py", line 231, in <module>
     simd_prune_op = SIMD_Prune_Op(model, args.threshold, args.width)
File "scalpel/SIMD_Aware_MNIST/../util/util.py", line 97, in __init__
    tmp_pruned[:, -1] = False
TypeError: can't assign a bool to a torch.cuda.ByteTensor

How can I solve this problem? Thanks.
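This error usually indicates a PyTorch version mismatch: the pruning mask is a (CUDA) ByteTensor, and the installed PyTorch refuses to assign a Python bool into it. A hedged workaround, assuming the line only needs to clear the mask's last column:

# Hypothetical fix in util/util.py -- verify against your PyTorch version.
tmp_pruned[:, -1] = 0            # integer assignment is accepted by a ByteTensor
# or, on PyTorch >= 1.2, switch to a genuine boolean mask first:
tmp_pruned = tmp_pruned.bool()
tmp_pruned[:, -1] = False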
