
roll920 / curl

Stars: 22 · Watchers: 1 · Forks: 2 · Size: 24.51 MB

PyTorch implementation of CURL: Neural Network Pruning with Residual-Connections and Limited-Data

License: MIT License

Languages: Python 99.18%, Shell 0.82%
Topics: pytorch-cnn, pruning, deep-learning

curl's Introduction

PyTorch Implementation of CURL

  • Neural Network Pruning with Residual-Connections and Limited-Data, CVPR 2020, Oral.
  • [CVF open access]

Requirements

PyTorch environment:

Prune on ImageNet

  1. clone this repository.
  2. download the ImageNet dataset and organize it into train and val folders.
  3. enter the subfolder:
    cd ImageNet/CURL
    
  4. start pruning and fine-tuning:
    ./run_this.sh
    

Note: The training log files on ImageNet are missing. We provide the pruned model: ImageNet/released_model/ResNet50-CURL-1G.pth. You can run ImageNet/released_model/run_this.sh to test its accuracy.
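Step 2 above only asks for train and val folders. Assuming the common torchvision ImageFolder convention of one subfolder per class inside each split (an assumption; the README does not state which layout the scripts expect), a small standard-library check such as the hypothetical check_split_layout below can catch a misorganized dataset before pruning starts.

```python
import os


def check_split_layout(root, splits=("train", "val")):
    """Return {split: number_of_class_folders} for an ImageFolder-style root.

    Assumes (hypothetically) root/<split>/<class_name>/<image files>.
    Raises FileNotFoundError if a split folder is missing, and ValueError if a
    split has no class subfolders or the splits disagree on class names.
    """
    classes_per_split = {}
    for split in splits:
        split_dir = os.path.join(root, split)
        if not os.path.isdir(split_dir):
            raise FileNotFoundError(f"missing split folder: {split_dir}")
        classes = sorted(
            name
            for name in os.listdir(split_dir)
            if os.path.isdir(os.path.join(split_dir, name))
        )
        if not classes:
            raise ValueError(f"no class subfolders in {split_dir}")
        classes_per_split[split] = classes

    # Every split should expose exactly the same class folders.
    reference = classes_per_split[splits[0]]
    for split, classes in classes_per_split.items():
        if classes != reference:
            raise ValueError(f"class folders in '{split}' differ from '{splits[0]}'")
    return {split: len(classes) for split, classes in classes_per_split.items()}
```

For example, check_split_layout('/path/to/imagenet') should report 1000 class folders per split for ImageNet-1k organized this way.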

Prune on CUB200

  1. clone this repository.
  2. download the CUB200 dataset and organize it into train and val folders.
  3. expand the small dataset with CUB200/expand_dataset.py.
  4. enter the subfolder:
    cd CUB200/mobilenetv2/
    
  5. edit the configuration file config.yaml.
  6. calculate the importance score for each filter:
    ./evaluate_importance.sh
    
  7. fine-tune the pruned model:
    ./run_this.sh
    

Note: The training log files are provided in corresponding folders.
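Step 2 leaves the train/val organization to the reader. Assuming the CUB-200-2011 release, which ships an images/ folder together with images.txt (`<image_id> <relative_path>`) and train_test_split.txt (`<image_id> <is_training_image>`), one minimal way to build train and val folders is to follow that official split. This is a standard-library sketch, not the authors' script; read_pairs and split_cub are hypothetical names, and we have not checked which folder names expand_dataset.py or config.yaml expect.

```python
import os
import shutil


def read_pairs(path):
    """Parse a whitespace-separated '<image_id> <value>' file into a dict."""
    pairs = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                image_id, value = line.split(maxsplit=1)
                pairs[image_id] = value
    return pairs


def split_cub(cub_root, out_root):
    """Copy images into out_root/{train,val}/<class_folder>/<file>.

    Follows train_test_split.txt (1 = training image, 0 = test image);
    the test images become the 'val' split. Returns per-split counts.
    """
    names = read_pairs(os.path.join(cub_root, "images.txt"))
    is_train = read_pairs(os.path.join(cub_root, "train_test_split.txt"))
    counts = {"train": 0, "val": 0}
    for image_id, rel_path in names.items():
        split = "train" if is_train[image_id] == "1" else "val"
        src = os.path.join(cub_root, "images", rel_path)
        # rel_path already starts with the class folder, so it is preserved.
        dst = os.path.join(out_root, split, rel_path)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        shutil.copy2(src, dst)
        counts[split] += 1
    return counts
```

Copying (rather than moving) keeps the original download intact, so the split can be regenerated if the expected layout turns out to differ.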

Results

We prune ResNet50 on the ImageNet dataset:

| Architecture | Top-1 Acc. | Top-5 Acc. | #MACs | #Param. |
|---|---|---|---|---|
| ResNet-50 | 76.15% | 92.87% | 4.09G | 25.56M |
| CURL | 73.39% | 91.46% | 1.11G | 6.67M |

The results of MobileNetV2 on CUB200:

| Architecture | Top-1 Acc. | #MACs |
|---|---|---|
| MobileNetV2-1.0 | 78.77% | 299.77M |
| MobileNetV2-0.5 | 73.96% | 96.12M |
| CURL | 78.72% | 96.07M |

The results of ResNet50 on CUB200:

| Architecture | Top-1 Acc. | #MACs |
|---|---|---|
| ResNet50 | 84.76% | 4.09G |
| ResNet50-CURL | 81.33% | 1.11G |
| CURL | 83.64% | 1.10G |
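The tables report absolute numbers. The relative savings follow from simple arithmetic on those values, shown below. Each baseline is compared with the row labelled "CURL". The "ResNet50-CURL" row is left out because the README does not explain how it was obtained.

```python
def reduction(before, after):
    """Percent of a cost (MACs or parameters) removed by pruning."""
    return 100.0 * (1.0 - after / before)


# (setting, baseline top-1 %, baseline MACs, CURL top-1 %, CURL MACs), copied from
# the tables above; MACs share a unit within each row (G for ResNet, M for MobileNetV2).
ROWS = [
    ("ResNet-50 on ImageNet", 76.15, 4.09, 73.39, 1.11),
    ("MobileNetV2-1.0 on CUB200", 78.77, 299.77, 78.72, 96.07),
    ("ResNet50 on CUB200", 84.76, 4.09, 83.64, 1.10),
]

for setting, base_acc, base_macs, curl_acc, curl_macs in ROWS:
    print(
        f"{setting}: {reduction(base_macs, curl_macs):.1f}% fewer MACs, "
        f"top-1 change {curl_acc - base_acc:+.2f} points"
    )

# Parameter counts are only reported for ImageNet (25.56M -> 6.67M).
print(f"ResNet-50 on ImageNet: {reduction(25.56, 6.67):.1f}% fewer parameters")
```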

Citation

If you find this work useful for your research, please cite:

@InProceedings{Luo_2020_CVPR,
author = {Luo, Jian-Hao and Wu, Jianxin},
title = {Neural Network Pruning With Residual-Connections and Limited-Data},
booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2020},
pages = {1458-1467}
}

Contact

Feel free to contact me if you have any questions (Jian-Hao Luo, [email protected] or [email protected]).

curl's People

Contributors

roll920

Forkers

wzb1005 yvonnedl

curl's Issues

kl_loss

Dear @Roll920,

Computational cost aside, KL divergence is a very interesting criterion for assigning filter importance. However, I'm a bit confused about the KL loss implemented in your code (e.g., lines 128-129 in generate_mask.py). Based on Equation 1 in the paper, shouldn't kl_loss = torch.mean(torch.sum(softmax(output) * (logsoftmax(output) - logsoftmax(logits)), dim=1)) instead be kl_loss = torch.mean(torch.sum(softmax(logits) * (logsoftmax(logits) - logsoftmax(output)), dim=1))?

Thanks in advance,
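The two expressions in this issue compute KL divergence in opposite directions. Suppose logits holds the original network's outputs and output holds the pruned network's (our assumption; the issue does not say which is which). Then the code computes KL(P_output ‖ P_logits), and the proposed version computes KL(P_logits ‖ P_output). KL divergence is not symmetric, so the two generally give different values. The pure-Python sketch below shows this on made-up numbers. It does not settle which direction Equation 1 of the paper uses.

```python
import math


def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def kl_rows(p_logits_batch, q_logits_batch):
    """Batch mean of KL(P || Q) with P = softmax(p), Q = softmax(q), row by row.

    Mirrors torch.mean(torch.sum(softmax(p) * (logsoftmax(p) - logsoftmax(q)), dim=1)).
    """
    total = 0.0
    for p_row, q_row in zip(p_logits_batch, q_logits_batch):
        log_p = log_softmax(p_row)
        log_q = log_softmax(q_row)
        total += sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    return total / len(p_logits_batch)


# Hypothetical outputs for a batch of two samples and three classes.
logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # assumed: original model
output = [[1.0, 1.0, 0.0], [0.0, 1.5, 1.5]]   # assumed: pruned model

as_in_code = kl_rows(output, logits)   # KL(P_output || P_logits)
as_in_issue = kl_rows(logits, output)  # KL(P_logits || P_output)
print(as_in_code, as_in_issue)
```

In PyTorch, the proposed direction can also be written as F.kl_div(F.log_softmax(output, dim=1), F.softmax(logits, dim=1), reduction='batchmean'). This works because torch.nn.functional.kl_div expects log-probabilities as its input and probabilities as its target. It computes KL(target ‖ input), and reduction='batchmean' averages that over the batch.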

Question about how to use it

Dear Luo, I have read your paper on how to select filters to prune, and I have three questions. First, do I need to calculate the score of every channel in every layer? If so, that seems a little inconvenient to use. Second, each layer removes its top-k filters; how should I set k? Does it need to change for every dataset? I could set a single k for the whole network, but I doubt that would perform well. Third, how can KL divergence be applied to object detection or instance segmentation? Real-world applications go beyond classification, and many of them are object detection, so how would filter pruning apply to such models?
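The second question contrasts removing a fixed k filters in every layer with a single setting for the whole network. Given precomputed importance scores (for example from step 6 of the CUB200 instructions), the two selection rules can be sketched as follows. The score layout {layer_name: [score per filter]} and both function names are hypothetical. Lower scores are assumed to mean less important filters. Neither rule is claimed to be how CURL chooses k.

```python
def prune_per_layer(scores, k):
    """Remove the k lowest-scoring filters in every layer (same k everywhere).

    scores: {layer_name: [importance of filter 0, filter 1, ...]} (hypothetical layout).
    Returns {layer_name: sorted indices of filters to remove}.
    """
    removed = {}
    for layer, layer_scores in scores.items():
        if k >= len(layer_scores):
            raise ValueError(f"k={k} would remove every filter in {layer}")
        order = sorted(range(len(layer_scores)), key=lambda i: layer_scores[i])
        removed[layer] = sorted(order[:k])
    return removed


def prune_global(scores, total):
    """Remove up to `total` lowest-scoring filters ranked across all layers.

    Layers can lose different numbers of filters; each layer keeps at least one.
    """
    candidates = sorted(
        (score, layer, i)
        for layer, layer_scores in scores.items()
        for i, score in enumerate(layer_scores)
    )
    removed = {layer: [] for layer in scores}
    for score, layer, i in candidates:
        if total == 0:
            break
        if len(removed[layer]) + 1 < len(scores[layer]):  # keep >= 1 filter
            removed[layer].append(i)
            total -= 1
    return {layer: sorted(idx) for layer, idx in removed.items()}
```

A global ranking is only meaningful if the scores are comparable across layers. That should be checked for whatever criterion is used.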

CUB200 dataset organization

First, thank you very much for your code; it is really great work. However, could you release the code you used to organize the CUB200 dataset into train and val folders?
Thanks very much.

Seems like some lines need to be changed

  File "/*/CURL/ImageNet/released_model/pruned_model.py", line 76, in __init__
    index_list = list(np.load('index.npy'))
  File "/opt/conda/lib/python3.8/site-packages/numpy/lib/npyio.py", line 439, in load
    return format.read_array(fid, allow_pickle=allow_pickle,
  File "/opt/conda/lib/python3.8/site-packages/numpy/lib/format.py", line 727, in read_array
    raise ValueError("Object arrays cannot be loaded when "
ValueError: Object arrays cannot be loaded when allow_pickle=False

Changing line 76 of pruned_model.py to

index_list = list(np.load('index.npy', allow_pickle=True))

fixed it. I'm not sure whether this is safe, though.
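Since NumPy 1.16.3, np.load defaults to allow_pickle=False, and object arrays can only be stored by pickling. The traceback shows that index.npy holds such an object array. Unpickling can execute arbitrary code, so allow_pickle=True is reasonable only for files from a source you trust, such as the index.npy shipped with this repository. The sketch below reproduces the error on a hypothetical ragged list of per-layer indices. It then shows a pickle-free alternative with np.savez. It is not how the repository actually stores its indices.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for index.npy: one index array per layer with ragged
# lengths, which NumPy can only store as an object array (i.e. via pickle).
ragged = [np.array([0, 3, 5]), np.array([1, 2])]
obj = np.empty(len(ragged), dtype=object)
for i, arr in enumerate(ragged):
    obj[i] = arr

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "index.npy")
    np.save(path, obj)  # np.save pickles object arrays (allow_pickle=True by default)

    default_load_failed = False
    try:
        np.load(path)  # default allow_pickle=False
    except ValueError as err:
        default_load_failed = True
        print("default load fails:", err)

    # Acceptable for a trusted file only: unpickling can run arbitrary code.
    index_list = list(np.load(path, allow_pickle=True))

    # Pickle-free alternative: one plain integer array per layer in an .npz archive.
    npz_path = os.path.join(tmp, "index.npz")
    np.savez(npz_path, *ragged)  # stored under the keys arr_0, arr_1, ...
    with np.load(npz_path) as npz:  # loads fine with allow_pickle=False
        index_list_npz = [npz[f"arr_{i}"] for i in range(len(npz.files))]
```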
