uyzhang / yolov5_prune
YOLOv5 pruning on COCO Dataset
License: Apache License 2.0
What's the error?
Traceback (most recent call last):
  File "prune.py", line 223, in <module>
    main()
  File "prune.py", line 209, in main
    prune(**params_prune)
  File "/usr/local/lib/python3.7/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "prune.py", line 85, in prune
    model = DetectMultiBackend(weights, device=device, dnn=dnn, fuse=False)
  File "/content/yolov5/yolov5_prune/models/common.py", line 308, in __init__
    model = attempt_load(weights if isinstance(weights, list) else w, map_location=device, fuse=fuse)
  File "/content/yolov5/yolov5_prune/models/experimental.py", line 96, in attempt_load
    ckpt = torch.load(attempt_download(w), map_location=map_location)  # load
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 712, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 1049, in _load
    result = unpickler.load()
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 1042, in find_class
    return super().find_class(mod_name, name)
AttributeError: Can't get attribute 'DetectionModel' on <module 'models.yolo' from '/content/yolov5/yolov5_prune/models/yolo.py'>
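The last frames show the failure happens while unpickling the checkpoint: torch.load hands the file to pickle, and pickle resolves every stored class by module path and class name (the find_class call above). The error therefore means this repository's models/yolo.py has no class named DetectionModel. A plausible but unconfirmed cause is that the weights were saved by newer upstream YOLOv5 code that uses that class name. The standard-library sketch below reproduces the mechanism with a hypothetical module fake_yolo standing in for models.yolo, and shows one possible workaround (aliasing the missing name to an existing class); whether such an alias is safe for real YOLOv5 checkpoints is an assumption you would need to verify.

```python
import pickle
import sys
import types

# Hypothetical stand-in for models/yolo.py; the name "fake_yolo" is an assumption.
fake_yolo = types.ModuleType("fake_yolo")
sys.modules["fake_yolo"] = fake_yolo


class DetectionModel:
    def __init__(self, nc=80):
        self.nc = nc


class Model:
    pass


# Register both classes under fake_yolo so pickle records "fake_yolo.<name>".
for cls in (DetectionModel, Model):
    cls.__module__ = "fake_yolo"
    setattr(fake_yolo, cls.__name__, cls)

# "Checkpoint" written by code that defines DetectionModel.
blob = pickle.dumps(DetectionModel(nc=3))

# Simulate an older yolo.py that only defines Model.
del fake_yolo.DetectionModel


def try_load(data):
    """Return the unpickled object, or the AttributeError raised by find_class."""
    try:
        return pickle.loads(data)
    except AttributeError as exc:
        return exc


err = try_load(blob)
print(type(err).__name__, err)

# Workaround sketch (assumption): point the missing name at a class that exists.
fake_yolo.DetectionModel = fake_yolo.Model
restored = try_load(blob)
print(type(restored).__name__, restored.nc)
```

The alias only works if the substituted class can accept the stored attributes; otherwise the usual fix is to load the weights with the same code version that saved them.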
In the get_prune_threshold function, the percent threshold is calculated as follows:

import torch

def get_prune_threshold(model_list, percent):
    # gather_bn_weights is defined elsewhere in the repository
    bn_weights = gather_bn_weights(model_list)
    sorted_bn = torch.sort(bn_weights)[0]
    # Highest threshold that avoids pruning all channels of a layer
    # (the minimum over BN layers of each layer's largest |gamma| is the upper bound)
    highest_thre = []
    for bnlayer in model_list.values():
        highest_thre.append(bnlayer.weight.data.abs().max().item())
    highest_thre = min(highest_thre)
    # Find the index of highest_thre in the sorted weights and the percentage it corresponds to
    threshold_index = (sorted_bn == highest_thre).nonzero().squeeze()
    if len(threshold_index.shape) > 0:
        threshold_index = threshold_index[0]
    percent_threshold = threshold_index.item() / len(bn_weights)
    print('Suggested Gamma threshold should be less than {}'.format(highest_thre))
    print('The corresponding prune ratio is {}, but you can set higher'.format(percent_threshold))
    thre_index = int(len(sorted_bn) * percent)
    thre_prune = sorted_bn[thre_index]
    print('Gamma value that less than {} are set to zero'.format(thre_prune))
    print("=" * 94)
    print(f"|\t{'layer name':<25}{'|':<10}{'origin channels':<20}{'|':<10}{'remaining channels':<20}|")
    return thre_prune
We run prune.py with the --percent parameter, for example:
python prune.py --percent 0.5 --weights runs/train/coco_sparsity2/weights/last.pt --data data/coco.yaml --cfg models/yolov5s.yaml --imgsz 640
If --percent is larger than the calculated percent_threshold, an error occurs.
I am trying to use this code for training and pruning on a custom dataset.
Please let me know why the percent threshold is limited. Thanks.
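The limit follows from how get_prune_threshold picks its cut-off. highest_thre is the smallest per-layer maximum |gamma|, so any global threshold above it zeroes every channel of at least one BN layer, leaving that layer with no channels, which is a likely source of the error. The pure-Python sketch below mirrors the calculation on made-up values; the layer names, the gamma values, and the keep rule (|gamma| >= threshold) are assumptions for illustration, not the repository's code.

```python
def gamma_thresholds(bn_gammas, percent):
    """Mirror get_prune_threshold on plain lists.

    bn_gammas: dict mapping a (hypothetical) BN layer name to its gamma values.
    """
    all_abs = sorted(abs(g) for gammas in bn_gammas.values() for g in gammas)
    # Smallest per-layer max |gamma|: the largest cut-off that spares one channel per layer.
    highest_thre = min(max(abs(g) for g in gammas) for gammas in bn_gammas.values())
    # Fraction of all gammas strictly below highest_thre (first matching index).
    percent_threshold = all_abs.index(highest_thre) / len(all_abs)
    # Global cut-off selected by --percent.
    thre_prune = all_abs[int(len(all_abs) * percent)]
    return highest_thre, percent_threshold, thre_prune


def remaining_channels(bn_gammas, thre_prune):
    # Assumption: channels with |gamma| >= thre_prune are kept, the rest are pruned.
    return {name: sum(abs(g) >= thre_prune for g in gammas)
            for name, gammas in bn_gammas.items()}


# Made-up gammas for two BN layers; "model.1.bn" has the smallest maximum (0.3).
layers = {
    "model.0.bn": [0.9, 0.8, 0.05, 0.7],
    "model.1.bn": [0.3, 0.02, 0.01, 0.2],
}

for percent in (0.5, 0.75):
    highest, limit, thre = gamma_thresholds(layers, percent)
    print(percent, highest, limit, thre, remaining_channels(layers, thre))
```

Here the suggested limit is 0.5. With percent = 0.75 the cut-off (0.8) exceeds highest_thre (0.3), so model.1.bn keeps no channels and the pruned network cannot be built.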
Hello,
I pruned a yolov5s model with a percent of 0.1 and then fine-tuned it for 30 epochs, but it still cannot detect objects in bus.jpg.
Could you please guide me?
I pruned the model and tried to export it to a TFLite model, but I got the error below.
TensorFlow SavedModel: starting export with tensorflow 2.13.0...
                 from  n    params  module                                  arguments
  0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2]
  1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]
TensorFlow SavedModel: export failure: __init__() missing 3 required positional arguments: 'cv2out', 'cv3out', and 'bottle_args'
TensorFlow Lite: starting export with tensorflow 2.13.0...
TensorFlow Lite: export failure: 'NoneType' object has no attribute 'call'
UnboundLocalError: local variable 'srtmp' referenced before assignment
https://github.com/uyzhang/yolov5_prune#steps
In step 2 (basic training), coco-hand starts from pretrained weights, but the dataset is not COCO. Which one is correct?
As far as I know, it is related to the bottleneck ignore.
Problem with sparsity training: should we proceed from the existing pretrained weights before sparsity training?
Hello, after basic training I tried to run sparsity training under the same conditions, but an out-of-memory error occurred. As far as I can tell, the only addition is the L1 regularization of the BN parameters (scaling factor and bias). Does computing the subgradient of those parameters take that much memory, or is there another reason? Sparsity training also takes about 2 hours per epoch (the image size is large, so basic training already took about 40 minutes). Thank you for your contribution.
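A common way to implement the L1 sparsity penalty on BN scaling factors (as in network slimming) is to add the subgradient s * sign(gamma) to each gamma's gradient after loss.backward(), in place; in PyTorch this is typically written as bn.weight.grad.add_(s * torch.sign(bn.weight.data)). Whether this repository does it exactly this way is an assumption. The plain-Python sketch below shows only the update rule; the function names and the value of s are hypothetical, and the same rule can be applied to the bias. Implemented in place, the extra memory is roughly one temporary vector per BN layer, which is small next to the activations kept for backpropagation.

```python
def sign(x):
    # sign(0) is taken as 0, a valid subgradient of |x| at zero.
    return (x > 0) - (x < 0)


def add_l1_subgradient(params, grads, s):
    """Add the subgradient of s * sum(|p|) to grads, in place.

    params/grads: parallel lists standing in for a BN layer's weight (gamma)
    and its .grad after loss.backward(); s is the sparsity rate.
    """
    for i, p in enumerate(params):
        grads[i] += s * sign(p)
    return grads


gamma = [0.5, -0.2, 0.0]
gamma_grad = [0.1, 0.1, 0.1]  # pretend gradient from the detection loss
add_l1_subgradient(gamma, gamma_grad, s=0.001)
print(gamma_grad)
```

Positive gammas get a slightly larger gradient and negative ones a slightly smaller one, so gradient descent pushes both toward zero; gammas already at zero are left alone.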
finetuning epoch = 100;
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.319
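After fine-tuning, the accuracy loss is roughly the same as in your experiments.
Because pruning directly gave even worse results, I have not fine-tuned that model yet!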
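Question 1: I see that your experiment 13 setting also trains for 100 epochs, so I am curious why running sparsity training for more epochs gives worse results.
Or does this mean that more sparsity training is not always better, similar to how normal training can overfit when it runs for too many epochs?
Question 2: Sparsity training itself causes an accuracy loss, and pruning causes another one; even after fine-tuning, some loss remains. This means the whole pruning pipeline suffers two accuracy drops. How do you view this problem? Could accuracy be improved by running sparsity training for more epochs, or by running fewer sparsity-training epochs but fine-tuning longer after pruning to recover more of the accuracy?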
Looking forward to your reply! Thanks!
Excellent work on the implementation. I've successfully pruned various YOLOv5 models by adjusting the width and depth multipliers. However, I'm encountering an issue: I seem unable to prune any models when the head is discarded. Is there an available solution or method for this situation?
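For reference, upstream YOLOv5 applies the two multipliers from the model YAML (depth_multiple and width_multiple) when it parses the config: repeat counts of multi-block layers are scaled by depth_multiple and rounded, and output channels of ordinary layers are scaled by width_multiple and rounded up to a multiple of 8. The sketch below restates that rule as I understand it from parse_model in models/yolo.py; the helper scale_layer is hypothetical and the (n, channels) pairs are illustrative.

```python
import math


def make_divisible(x, divisor=8):
    # Round a channel count up to the nearest multiple of `divisor`.
    return math.ceil(x / divisor) * divisor


def scale_layer(n, c2, depth_multiple, width_multiple):
    """Return (repeats, output channels) after applying the model multipliers."""
    n_scaled = max(round(n * depth_multiple), 1) if n > 1 else n
    c2_scaled = make_divisible(c2 * width_multiple, 8)
    return n_scaled, c2_scaled


# yolov5s uses depth_multiple=0.33 and width_multiple=0.50.
for n, c2 in [(1, 64), (3, 128), (9, 512)]:
    print(n, c2, "->", scale_layer(n, c2, 0.33, 0.50))
```

Unlike BN-based channel pruning, this shrinks every layer uniformly, which is why the first Conv in the TFLite export log above has 32 output channels.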