
xiwuchen / torch-pruning


This project forked from vainf/torch-pruning


[CVPR-2023] Towards Any Structural Pruning; LLaMA / YOLOv8 / CNNs / Transformers

Home Page: https://arxiv.org/abs/2301.12900

License: MIT License


Towards Any Structural Pruning


[中文README | README in Chinese]

Torch-Pruning (TP) is a versatile library for Structural Network Pruning with the following features:

  • General-purpose Pruning Toolkit: TP enables structural pruning for a wide range of neural networks, including LLaMA, Vision Transformers, YOLOv7, YOLOv8, FasterRCNN, SSD, KeypointRCNN, MaskRCNN, ResNe(X)t, ConvNext, DenseNet, RegNet, FCN, DeepLab, etc. Unlike torch.nn.utils.prune, which zeroes out parameters through masking, Torch-Pruning employs a (non-deep) graph algorithm called DepGraph to physically remove parameters and channels.
  • Reproducible Performance Benchmark and Prunability Benchmark: Currently, TP is able to prune approximately 81/85=95.3% of the models from Torchvision 0.13.1. Try this Colab Demo for a quick start.
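
The difference between masking and physical removal can be illustrated without any library. The sketch below is a toy in plain Python that uses a nested list as a stand-in for a weight matrix. The helper names mask_rows and remove_rows are hypothetical and are not part of torch.nn.utils.prune or Torch-Pruning.

```python
# Toy weight "matrix": 4 output channels (rows) x 3 inputs (columns).
weight = [
    [0.5, -1.0, 0.2],
    [0.0, 0.1, -0.1],
    [1.5, 0.3, 0.7],
    [-0.2, 0.0, 0.1],
]

def mask_rows(w, idxs):
    """Masking: zero the selected rows; the shape is unchanged."""
    drop = set(idxs)
    return [[0.0] * len(row) if i in drop else list(row) for i, row in enumerate(w)]

def remove_rows(w, idxs):
    """Structural pruning: delete the selected rows; the matrix shrinks."""
    drop = set(idxs)
    return [list(row) for i, row in enumerate(w) if i not in drop]

masked = mask_rows(weight, [1, 3])   # still 4 rows, two of them all zeros
pruned = remove_rows(weight, [1, 3]) # only 2 rows remain

print(len(masked), len(pruned))
```

Masking keeps the stored size and compute cost of the layer, whereas removal changes its shape, which is why every layer that consumes those channels has to be adjusted as well.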

For more technical details, please refer to our CVPR'23 paper:

DepGraph: Towards Any Structural Pruning
Gongfan Fang, Xinyin Ma, Mingli Song, Michael Bi Mi, Xinchao Wang

Please do not hesitate to open a discussion or issue if you encounter any problems with the library or the paper.

TODO List:

  • A strong baseline pruner with bags of tricks from existing methods.
  • A benchmark for Torchvision compatibility (81/85=95.3%, ✔️) and timm compatibility.
  • Pruning from Scratch / at Initialization.
  • Language (e.g., LLaMA✔️), Speech and Generative Models.
  • More high-level pruners like FisherPruner, GrowingReg, etc.
  • More Transformers like Vision Transformers (✔️), Swin Transformers, PoolFormers.
  • Block/Layer/Depth Pruning
  • Pruning benchmarks for CIFAR, ImageNet and COCO.

Installation

Torch-Pruning is compatible with PyTorch 1.x and 2.x. PyTorch 1.12.1 is recommended!

pip install torch-pruning # v1.1.6

or

git clone https://github.com/VainF/Torch-Pruning.git

Quickstart

Here we provide a quick start for Torch-Pruning. More detailed explanations can be found in the tutorials.

0. How It Works

In structural pruning, Group is the minimal removable unit within deep networks. Each group contains several interdependent layers that must be pruned simultaneously to maintain the integrity of the resulting structures. However, deep networks often present complex dependencies among layers, making structural pruning a challenging endeavor. This work addresses this challenge by offering an automated mechanism, DepGraph, for parameter grouping, which facilitates effortless pruning for a wide range of deep networks.
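
The notion of a group can be pictured as a connected component in a graph of channel dependencies. The following is a minimal sketch under that assumption, not DepGraph's actual implementation: the edge list is hand-written for a toy residual block, and build_graph and find_group are hypothetical helpers.

```python
from collections import deque

# Hand-written toy dependencies. An edge means "these two must be pruned
# with the same channel indices". ".out"/".in" mark which side of a
# convolution is affected.
edges = [
    ("conv1.out", "bn1"),
    ("bn1", "layer1.0.conv1.in"),
    ("bn1", "layer1.0.bn2"),                  # joined by the residual addition
    ("layer1.0.bn2", "layer1.0.conv2.out"),
    ("layer1.0.conv1.out", "layer1.0.bn1"),   # a separate, independent group
    ("layer1.0.bn1", "layer1.0.conv2.in"),
]

def build_graph(edges):
    """Build an undirected adjacency map from the edge list."""
    graph = {}
    for a, b in edges:
        graph.setdefault(a, set()).add(b)
        graph.setdefault(b, set()).add(a)
    return graph

def find_group(graph, root):
    """Collect every node reachable from root with a breadth-first search."""
    group, seen, queue = [root], {root}, deque([root])
    while queue:
        node = queue.popleft()
        for nxt in sorted(graph[node]):
            if nxt not in seen:
                seen.add(nxt)
                group.append(nxt)
                queue.append(nxt)
    return group

graph = build_graph(edges)
group = find_group(graph, "conv1.out")
print(group)
```

Starting from conv1.out, the search reaches the batch norm, the next convolution's input side, and the layers joined through the residual addition, but not the output side of layer1.0.conv1, which belongs to a different group.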

1. A Minimal Example

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True).eval()

# 1. build dependency graph for resnet18
DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1,3,224,224))

# 2. Specify the to-be-pruned channels. Here we prune those channels indexed by [2, 6, 9].
group = DG.get_pruning_group( model.conv1, tp.prune_conv_out_channels, idxs=[2, 6, 9] )

# 3. prune all grouped layers that are coupled with model.conv1 (included).
if DG.check_pruning_group(group): # avoid full pruning, i.e., channels=0.
    group.prune()

The above example demonstrates the fundamental pruning pipeline using DepGraph. The target layer model.conv1 is coupled with several other layers, all of which must be pruned together in structural pruning. Let's print the group and observe how one pruning operation "triggers" others. In the following output, A => B means that pruning operation A triggers pruning operation B. group[0] refers to the pruning root passed to DG.get_pruning_group.

--------------------------------
          Pruning Group
--------------------------------
[0] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)), idxs=[2, 6, 9] (Pruning Root)
[1] prune_out_channels on conv1 (Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)) => prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[2] prune_out_channels on bn1 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on _ElementWiseOp_20(ReluBackward0), idxs=[2, 6, 9]
[3] prune_out_channels on _ElementWiseOp_20(ReluBackward0) => prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0), idxs=[2, 6, 9]
[4] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_out_channels on _ElementWiseOp_18(AddBackward0), idxs=[2, 6, 9]
[5] prune_out_channels on _ElementWiseOp_19(MaxPool2DWithIndicesBackward0) => prune_in_channels on layer1.0.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[6] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[7] prune_out_channels on _ElementWiseOp_18(AddBackward0) => prune_out_channels on _ElementWiseOp_17(ReluBackward0), idxs=[2, 6, 9]
[8] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_out_channels on _ElementWiseOp_16(AddBackward0), idxs=[2, 6, 9]
[9] prune_out_channels on _ElementWiseOp_17(ReluBackward0) => prune_in_channels on layer1.1.conv1 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[10] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)), idxs=[2, 6, 9]
[11] prune_out_channels on _ElementWiseOp_16(AddBackward0) => prune_out_channels on _ElementWiseOp_15(ReluBackward0), idxs=[2, 6, 9]
[12] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.downsample.0 (Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)), idxs=[2, 6, 9]
[13] prune_out_channels on _ElementWiseOp_15(ReluBackward0) => prune_in_channels on layer2.0.conv1 (Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[14] prune_out_channels on layer1.1.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.1.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
[15] prune_out_channels on layer1.0.bn2 (BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)) => prune_out_channels on layer1.0.conv2 (Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)), idxs=[2, 6, 9]
--------------------------------

For more details about grouping, please refer to tutorials/2 - Exploring Dependency Groups

How to scan all groups (Advanced):

We can use DG.get_all_groups(ignored_layers, root_module_types) to scan all groups sequentially. Each group will begin with a layer that matches a type in the "root_module_types" parameter. Note that DG.get_all_groups is only responsible for grouping and does not have any knowledge or understanding of which parameters should be pruned. Therefore, it is necessary to specify the pruning idxs using group.prune(idxs=idxs).

import torch.nn as nn

for group in DG.get_all_groups(ignored_layers=[model.conv1], root_module_types=[nn.Conv2d, nn.Linear]):
    # handle groups in sequential order
    idxs = [2,4,6] # your pruning indices
    group.prune(idxs=idxs)
    print(group)

2. High-level Pruners

Leveraging the DependencyGraph, we developed several high-level pruners in this repository to facilitate effortless pruning. By specifying the desired channel sparsity, you can prune the entire model and fine-tune it using your own training code. For detailed information on this process, please refer to this tutorial, which shows how to implement a slimming pruner from scratch. Additionally, you can find more practical examples in benchmarks/main.py.

import torch
from torchvision.models import resnet18
import torch_pruning as tp

model = resnet18(pretrained=True)

# Importance criteria
example_inputs = torch.randn(1, 3, 224, 224)
imp = tp.importance.MagnitudeImportance(p=2)

ignored_layers = []
for m in model.modules():
    if isinstance(m, torch.nn.Linear) and m.out_features == 1000:
        ignored_layers.append(m) # DO NOT prune the final classifier!

iterative_steps = 5 # progressive pruning
pruner = tp.pruner.MagnitudePruner(
    model,
    example_inputs,
    importance=imp,
    iterative_steps=iterative_steps,
    ch_sparsity=0.5, # remove 50% channels, ResNet18 = {64, 128, 256, 512} => ResNet18_Half = {32, 64, 128, 256}
    ignored_layers=ignored_layers,
)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, example_inputs)
for i in range(iterative_steps):
    pruner.step()
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...
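
To see how ch_sparsity and iterative_steps interact, the sketch below computes the channel count left after each step. It assumes a linear schedule (the target sparsity at step i is ch_sparsity * i / iterative_steps) and rounds the number of pruned channels down. Both are assumptions made for illustration, and the library's actual schedule and rounding may differ. The function channels_per_step is hypothetical.

```python
def channels_per_step(n_channels, ch_sparsity, iterative_steps):
    """Return the channel count remaining after each pruning step."""
    remaining = []
    for step in range(1, iterative_steps + 1):
        # Assumed linear schedule: sparsity grows evenly up to ch_sparsity.
        target_sparsity = ch_sparsity * step / iterative_steps
        n_pruned = int(n_channels * target_sparsity)  # assumed: round down
        remaining.append(n_channels - n_pruned)
    return remaining

# The four stage widths of ResNet18 mentioned in the comment above.
for width in (64, 128, 256, 512):
    print(width, channels_per_step(width, ch_sparsity=0.5, iterative_steps=5))
```

Under these assumptions every stage ends at half its original width after the fifth step, matching the {32, 64, 128, 256} configuration noted in the example.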

Sparse Training

Some pruners like BNScalePruner and GroupNormPruner require sparse training before pruning. This can be easily achieved by inserting just one line of code pruner.regularize(model) in your training script. The pruner will update the gradient of trainable parameters.

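import torch.nn.functional as F

# model, pruner, optimizer, train_loader, device and epochs come from your own training script
for epoch in range(epochs):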
    model.train()
    for i, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out, target)
        loss.backward()
        pruner.regularize(model) # <== for sparse learning
        optimizer.step()
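
What "updating the gradient" can look like is sketched below in plain Python. The sketch assumes an L1 penalty on batch-norm scale factors, in the spirit of network slimming. The penalty that each pruner actually applies may differ, and the regularize function, its reg argument and the sample values are hypothetical.

```python
def sign(x):
    """Return -1, 0 or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)

def regularize(scales, grads, reg=1e-4):
    """Add the subgradient of reg * |scale| to each existing gradient, in place."""
    for i, s in enumerate(scales):
        grads[i] += reg * sign(s)

bn_scales = [0.8, -0.05, 0.0, 1.2]   # toy batch-norm scale factors
grads = [0.01, 0.02, -0.03, 0.0]     # gradients left by loss.backward()

regularize(bn_scales, grads, reg=0.1)
print(grads)
```

Because the extra term is added after the backward pass and before the optimizer step, the task loss is left untouched while the scale factors are nudged toward zero, which makes unimportant channels easier to identify later.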

Interactive Pruning (Advanced)

All high-level pruners support interactive pruning. You can use pruner.step(interactive=True) to get all groups and interactively prune them by calling group.prune(). This feature is useful if you want to control/monitor the pruning process.

for i in range(iterative_steps):
    for group in pruner.step(interactive=True): # Warning: groups must be handled sequentially. Do not keep them as a list.
        print(group) 
        # do whatever you like with the group 
        # ...
        group.prune() # remember to call group.prune()
        # group.prune(idxs=[0, 2, 6]) # It is even possible to change the pruning behaviour with the idxs parameter
    macs, nparams = tp.utils.count_ops_and_params(model, example_inputs)
    # finetune your model here
    # finetune(model)
    # ...

Group-level Pruning

With DepGraph, it is easy to design "group-level" criteria that estimate the importance of a whole group rather than a single layer. In Torch-Pruning, all pruners work at the group level.
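
One way to picture a group-level criterion is to score each channel index by combining the norms of the corresponding weights in every coupled layer. The sketch below assumes the per-layer L2 norms are simply summed. The normalisation and reduction used by the library's importance classes may differ, and all names and values are hypothetical.

```python
import math

def l2_per_channel(layer):
    """L2 norm of each channel's weights; layer is a list of per-channel lists."""
    return [math.sqrt(sum(w * w for w in channel)) for channel in layer]

def group_importance(layers):
    """Sum the per-channel norms over all coupled layers in the group."""
    scores = [0.0] * len(layers[0])
    for layer in layers:
        for i, norm in enumerate(l2_per_channel(layer)):
            scores[i] += norm
    return scores

def least_important(scores, n):
    """Sorted indices of the n lowest-scoring channels."""
    return sorted(sorted(range(len(scores)), key=lambda i: scores[i])[:n])

# Two coupled layers sharing the same 4 channels (toy values).
conv_out = [[3.0, 4.0], [0.1, 0.0], [0.05, 0.0], [0.0, 0.2]]
next_conv_in = [[1.0], [0.1], [5.0], [0.1]]

single_layer_choice = least_important(l2_per_channel(conv_out), 2)
group_choice = least_important(group_importance([conv_out, next_conv_in]), 2)
print(single_layer_choice, group_choice)
```

In this toy case channel 2 looks unimportant when only the first layer is inspected, yet it carries large weights in the coupled layer, so the group-level score keeps it and removes channel 3 instead.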

3. Save & Load

A Naive Way

The following script saves the whole model object (structure + weights) as 'model.pth'. This .pth file can be very large due to the additional information stored in the object.

torch.save(model, 'model.pth') # without .state_dict
model = torch.load('model.pth') # load the model object (PyTorch >= 2.6 needs weights_only=False here)

Pruning History

We introduce pruning_history to save and load your pruned model in a way similar to state_dict in PyTorch. This feature is currently not available in the PyPI package. An example can be found in tests/test_load.py.

...
# Save
state_dict = {
    'model': model.state_dict(), # model weights
    'pruning': pruner.pruning_history(), # DG also supports DG.pruning_history & DG.load_pruning_history.
}
torch.save(state_dict, 'pruned_model.pth')

# Load
model = resnet18() # Create a new unpruned model
DG = tp.DependencyGraph().build_dependency(model, example_inputs) # Create a new DepGraph or Pruner (both OK!)
state_dict = torch.load('pruned_model.pth') # load the .pth file
DG.load_pruning_history(state_dict['pruning']) # load the pruning history to replay the pruning process
model.load_state_dict(state_dict['model']) # then, we can load the pruned weights into this model.
print(model)
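
The reason the history must be replayed before load_state_dict is that a fresh model still has the unpruned shapes, so the saved weights would not fit. The sketch below illustrates the idea with a table of channel counts standing in for a real model. The record layout (layer name, whether output channels were pruned, indices) is an assumption made for illustration and is not necessarily the format the library stores.

```python
def fresh_shapes():
    """Channel counts of an unpruned toy model: [out_channels, in_channels]."""
    return {"conv1": [64, 3], "layer1.0.conv1": [64, 64]}

def apply_history(shapes, history):
    """Replay (layer, is_out_channel, idxs) records on a shape table."""
    for layer, is_out_channel, idxs in history:
        dim = 0 if is_out_channel else 1
        shapes[layer][dim] -= len(idxs)
    return shapes

# Hypothetical history recorded while pruning.
history = [
    ("conv1", True, [2, 6, 9]),
    ("layer1.0.conv1", False, [2, 6, 9]),
]

saved = apply_history(fresh_shapes(), history)     # shapes at save time
restored = apply_history(fresh_shapes(), history)  # rebuilt at load time
print(restored)
```

Since the replay is deterministic, the restored shapes equal the saved ones, and only then can the pruned weights be copied in.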

4. Low-level Pruning Functions

While it is possible to manually prune your model using low-level functions, this approach can be quite laborious, as it requires careful management of the associated dependencies. As a result, we recommend utilizing the aforementioned high-level pruners to streamline the pruning process.

tp.prune_conv_out_channels( model.conv1, idxs=[2,6,9] )

# fix the broken dependencies manually
tp.prune_batchnorm_out_channels( model.bn1, idxs=[2,6,9] )
tp.prune_conv_in_channels( model.layer2[0].conv1, idxs=[2,6,9] )
...

The following pruning functions are available:

'prune_conv_out_channels',
'prune_conv_in_channels',
'prune_depthwise_conv_out_channels',
'prune_depthwise_conv_in_channels',
'prune_batchnorm_out_channels',
'prune_batchnorm_in_channels',
'prune_linear_out_channels',
'prune_linear_in_channels',
'prune_prelu_out_channels',
'prune_prelu_in_channels',
'prune_layernorm_out_channels',
'prune_layernorm_in_channels',
'prune_embedding_out_channels',
'prune_embedding_in_channels',
'prune_parameter_out_channels',
'prune_parameter_in_channels',
'prune_multihead_attention_out_channels',
'prune_multihead_attention_in_channels',
'prune_groupnorm_out_channels',
'prune_groupnorm_in_channels',
'prune_instancenorm_out_channels',
'prune_instancenorm_in_channels',
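
The out/in naming in these functions refers to which axis of a layer's weight is sliced. The sketch below assumes a weight laid out as [out_channels, in_channels], as in a PyTorch convolution with the kernel dimensions omitted. The helpers drop_out_channels and drop_in_channels are hypothetical stand-ins, not the library functions listed above.

```python
def drop_out_channels(weight, idxs):
    """Remove whole rows: each row holds the weights of one output channel."""
    drop = set(idxs)
    return [row for i, row in enumerate(weight) if i not in drop]

def drop_in_channels(weight, idxs):
    """Remove columns: each column holds the weights read from one input channel."""
    drop = set(idxs)
    return [[w for j, w in enumerate(row) if j not in drop] for row in weight]

def shape(weight):
    """Return (out_channels, in_channels) of a nested-list weight."""
    return (len(weight), len(weight[0]))

conv_a = [[0.0] * 3 for _ in range(8)]    # 3 -> 8 channels
conv_b = [[0.0] * 8 for _ in range(16)]   # 8 -> 16 channels, consumes conv_a

idxs = [2, 6]
conv_a = drop_out_channels(conv_a, idxs)  # conv_a now produces 6 channels
conv_b = drop_in_channels(conv_b, idxs)   # conv_b must now accept 6 channels
print(shape(conv_a), shape(conv_b))
```

Pruning the producer's output channels without pruning the consumer's input channels would leave the two shapes incompatible, which is the dependency bookkeeping that the high-level pruners automate.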

5. Customized Layers

Please refer to tests/test_customized_layer.py.

6. Benchmarks

Our results on {ResNet-56 / CIFAR-10 / 2.00x}

Method          Base (%)   Pruned (%)   Δ Acc (%)   Speed Up
NIPS [1]        -          -            -0.03       1.76x
Geometric [2]   93.59      93.26        -0.33       1.70x
Polar [3]       93.80      93.83        +0.03       1.88x
CP [4]          92.80      91.80        -1.00       2.00x
AMC [5]         92.80      91.90        -0.90       2.00x
HRank [6]       93.26      92.17        -0.09       2.00x
SFP [7]         93.59      93.36        +0.23       2.11x
ResRep [8]      93.71      93.71        +0.00       2.12x
Ours-L1         93.53      92.93        -0.60       2.12x
Ours-BN         93.53      93.29        -0.24       2.12x
Ours-Group      93.53      93.77        +0.38       2.13x

Please refer to benchmarks for more details.

Citation

@article{fang2023depgraph,
  title={DepGraph: Towards Any Structural Pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  journal={The IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2023}
}

