uyzhang / yolov5_prune

YOLOv5 pruning on COCO Dataset

License: Apache License 2.0

Shell 1.26% Python 86.34% Dockerfile 0.18% Jupyter Notebook 12.22%
coco prune yolov5

yolov5_prune's People

Contributors

uyzhang

yolov5_prune's Issues

Can't get attribute 'DetectionModel' on <module 'models.yolo' from '/content/yolov5/yolov5_prune/models/yolo.py'>

what's the error?

Traceback (most recent call last):
  File "prune.py", line 223, in <module>
    main()
  File "prune.py", line 209, in main
    prune(**params_prune)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "prune.py", line 85, in prune
    model = DetectMultiBackend(weights, device=device, dnn=dnn, fuse=False)
  File "/content/yolov5/yolov5_prune/models/common.py", line 308, in __init__
    model = attempt_load(weights if isinstance(weights, list) else w, map_location=device, fuse=fuse)
  File "/content/yolov5/yolov5_prune/models/experimental.py", line 96, in attempt_load
    ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 712, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 1049, in _load
    result = unpickler.load()
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 1042, in find_class
    return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'DetectionModel' on <module 'models.yolo' from '/content/yolov5/yolov5_prune/models/yolo.py'>
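
As a general illustration of how this kind of message can arise, and not a diagnosis of this particular report: pickle stores a class only by its module path and name, so loading fails when the module found at load time does not define that name. The sketch below reproduces the message with the standard library alone. `demo_yolo` is a hypothetical stand-in module, not the repository's models/yolo.py, and whether the checkpoint in this issue was saved by a different YOLOv5 version is an assumption.

```python
import pickle
import sys
import types

# Hypothetical stand-in for a module such as models/yolo.py.
demo_module = types.ModuleType("demo_yolo")
sys.modules["demo_yolo"] = demo_module

# Define a class that claims to live in the stand-in module.
DetectionModel = type("DetectionModel", (), {"__module__": "demo_yolo"})
demo_module.DetectionModel = DetectionModel

# Pickle records the name "demo_yolo.DetectionModel", not the class body.
payload = pickle.dumps(DetectionModel())

# Simulate a copy of the module that does not define the class.
del demo_module.DetectionModel


def try_load(data):
    """Return 'loaded' on success, otherwise the AttributeError text."""
    try:
        pickle.loads(data)
        return "loaded"
    except AttributeError as exc:
        return str(exc)


message = try_load(payload)
print(message)
```

If the same mechanism applies here, the checkpoint would have to be loaded with code whose models/yolo.py defines DetectionModel.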

What is the reason for limiting the percent threshold?

The percent threshold is calculated in the get_prune_threshold function:

def get_prune_threshold(model_list, percent):
    bn_weights = gather_bn_weights(model_list)
    sorted_bn = torch.sort(bn_weights)[0]

    # Upper limit on the threshold that avoids pruning every channel of a layer
    # (the minimum over all BN layers of each layer's maximum gamma).
    highest_thre = []
    for bnlayer in model_list.values():
        highest_thre.append(bnlayer.weight.data.abs().max().item())

    highest_thre = min(highest_thre)
    # Find the index of highest_thre in the sorted gammas and convert it to a percentage.
    threshold_index = (sorted_bn == highest_thre).nonzero().squeeze()
    if len(threshold_index.shape) > 0:
        threshold_index = threshold_index[0]
    percent_threshold = threshold_index.item() / len(bn_weights)
    print('Suggested Gamma threshold should be less than {}'.format(highest_thre))
    print('The corresponding prune ratio is {}, but you can set higher'.format(percent_threshold))
    thre_index = int(len(sorted_bn) * percent)
    thre_prune = sorted_bn[thre_index]
    print('Gamma value that less than {} are set to zero'.format(thre_prune))
    print("=" * 94)
    print(f"|\t{'layer name':<25}{'|':<10}{'origin channels':<20}{'|':<10}{'remaining channels':<20}|")
    return thre_prune

We run prune.py with the --percent parameter, for example:
'python prune.py --percent 0.5 --weights runs/train/coco_sparsity2/weights/last.pt --data data/coco.yaml --cfg models/yolov5s.yaml --imgsz 640'. If the --percent value is larger than the calculated percent_threshold, an error occurs.

I am trying to use this GitHub code for training and pruning on a custom dataset.
Please let me know why the percent threshold is limited. Thanks.
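
The comment in get_prune_threshold describes highest_thre as the upper limit that avoids pruning every channel of a layer. A minimal sketch of that effect follows, using plain Python lists instead of tensors. The layer names and gamma values are made up, and keeping channels whose |gamma| is at least the threshold is an assumption about how the mask is built, not something confirmed by the code quoted above.

```python
def suggested_limit(layer_gammas):
    """Return (highest_thre, percent_threshold) as in get_prune_threshold."""
    all_gammas = sorted(abs(g) for layer in layer_gammas.values() for g in layer)
    # Smallest of the per-layer maximum |gamma| values.
    highest_thre = min(max(abs(g) for g in layer) for layer in layer_gammas.values())
    index = all_gammas.index(highest_thre)
    return highest_thre, index / len(all_gammas)


def remaining_channels(layer_gammas, percent):
    """Count channels per layer whose |gamma| reaches the global threshold."""
    all_gammas = sorted(abs(g) for layer in layer_gammas.values() for g in layer)
    threshold = all_gammas[int(len(all_gammas) * percent)]
    # Assumption: a channel survives when |gamma| >= threshold.
    return {
        name: sum(abs(g) >= threshold for g in layer)
        for name, layer in layer_gammas.items()
    }


# Hypothetical gammas for two BN layers.
layers = {
    "bn_a": [0.9, 0.8, 0.7, 0.6],
    "bn_b": [0.05, 0.1, 0.2, 0.3],
}

limit = suggested_limit(layers)
below_limit = remaining_channels(layers, 0.25)
above_limit = remaining_channels(layers, 0.5)
print(limit, below_limit, above_limit)
```

In this toy case the suggested ratio is 0.375. At 0.25 both layers keep channels, while at 0.5 the layer bn_b keeps none, and a convolution with zero output channels cannot be built.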

Facing issue to convert pruned model to tflite model

I pruned the model and tried to export it to a TFLite model, but I got the error below.

TensorFlow SavedModel: starting export with tensorflow 2.13.0...

     from  n    params  module              arguments
0      -1  1      3520  models.common.Conv  [3, 32, 6, 2, 2]
1      -1  1     18560  models.common.Conv  [32, 64, 3, 2]

TensorFlow SavedModel: export failure: __init__() missing 3 required positional arguments: 'cv2out', 'cv3out', and 'bottle_args'

TensorFlow Lite: starting export with tensorflow 2.13.0...

TensorFlow Lite: export failure: 'NoneType' object has no attribute 'call'

error

UnboundLocalError: local variable 'srtmp' referenced before assignment

Question about GPU memory during sparsity training

Hello. After basic training, I tried to run sparsity training in the same environment and under the same conditions, but it ran out of memory. I think the only addition is the L1 regularization of the BN parameters (scaling factor, bias). Does computing the subgradient for those parameters use a lot of memory, or is there another reason? Sparsity training takes about 2 hours per epoch (the image size was large, so basic training already took about 40 minutes per epoch). Thank you for your contribution.
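
For reference, the L1 subgradient mentioned in the question can be sketched as below. This is the textbook form of the update (add sr * sign(gamma) to the existing gradient), written with plain Python lists. The name `sr` for the sparsity rate and the idea that the repository applies the penalty this way are assumptions, not taken from its code.

```python
def sign(x):
    """Return -1, 0 or 1 according to the sign of x."""
    return (x > 0) - (x < 0)


def add_l1_subgradient(grads, gammas, sr):
    """Add the L1 penalty subgradient sr * sign(gamma) to each gradient."""
    return [g + sr * sign(w) for g, w in zip(grads, gammas)]


# Hypothetical gradients and BN scaling factors for three channels.
updated = add_l1_subgradient([0.0, 0.5, -0.25], [2.0, -1.0, 0.0], sr=0.5)
print(updated)
```

In this form the penalty adds one value per BN channel to gradients that already exist.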

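Longer sparsity training gives worse pruning results at the same pruning ratio!

Hello! I ran two pruning experiments using your code with the same settings as experiment 13.
Experiment 1: sparsity training epochs = 50;
pruning percent = 30%;
The results are as follows: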

| P | origin:0.6032 | after prune:0.5359 | loss ratio:0.1114
| R | origin:0.4673 | after prune:0.3807 | loss ratio:0.1854
| mAP@.5 | origin:0.5084 | after prune:0.4142 | loss ratio:0.1852
| mAP@.5:.95 | origin:0.3244 | after prune:0.2414 | loss ratio:0.2560

finetuning epoch = 100;
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.319
After finetuning, the accuracy loss is roughly the same as in your experiments.
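
The loss ratio column for experiment 1 appears consistent with (origin - after prune) / origin. The sketch below recomputes it from the rounded values in the table. The formula is inferred from the numbers rather than taken from the repository, and small differences in the last digits are expected because the inputs are already rounded.

```python
def loss_ratio(origin, after_prune):
    """Relative drop of a metric after pruning."""
    return (origin - after_prune) / origin


# (origin, after prune) pairs copied from the experiment 1 table.
experiment_1 = {
    "P": (0.6032, 0.5359),
    "R": (0.4673, 0.3807),
    "mAP@.5": (0.5084, 0.4142),
    "mAP@.5:.95": (0.3244, 0.2414),
}

ratios = {
    name: round(loss_ratio(origin, after), 4)
    for name, (origin, after) in experiment_1.items()
}
print(ratios)
```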

Experiment 2: sparsity training epochs = 100;
pruning percent = 30%;
The results are as follows:

| P | origin:0.6460 | after prune:0.4668 | loss ratio:0.2774
| R | origin:0.4777 | after prune:0.3093 | loss ratio:0.3525
| mAP@.5 | origin:0.5308 | after prune:0.3079 | loss ratio:0.4201
| mAP@.5:.95 | origin:0.3413 | after prune:0.1745 | loss ratio:0.4888

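Because the result of direct pruning was worse, I have not run finetuning for this experiment yet!

Question 1: I see that experiment 13 in your settings also trained for 100 epochs, so I am curious why longer sparsity training produced worse results.
Or does this indicate that more sparsity training is not always better, similar to how normal training can overfit when it runs for too many epochs?

Question 2: Sparsity training itself causes a loss of accuracy, pruning causes a further loss, and some loss remains even after finetuning. This means the whole pruning process faces two drops in accuracy, so I would like to ask how you view this issue. Also, can accuracy be improved by running sparsity training for more epochs, or by running fewer sparsity training epochs and finetuning longer after pruning to recover more accuracy?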

Looking forward to your reply! Thanks!

Pruning error: stride=[1, 1]

Hello, have you encountered this situation?
image

The last line of the error is:
RuntimeError: expected stride to be a single integer value or a list of 1 values to match the convolution dimensions, but got stride=[1, 1]

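Question about how parse_pruned_model in yolo.py gets the remaining channel count of each layer

Hello, I have a question, as shown in the image: the keys of the previously computed mask_bn look like model.0.conv.1, but here mask_bn is indexed with named_m_bn, which contains the .bn keyword, so at runtime an error reports that mask_bn has no such key. Did you modify the corresponding keys when mask_bn was computed earlier?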
image

Prune YOLOv5 with head discarded.

Excellent work on the implementation. I've successfully pruned various YOLOv5 models by adjusting the width and height multipliers. However, I'm encountering an issue: I seem unable to prune any models when the head is discarded. Is there an available solution or method for this situation?
