简体中文 | English
- 2022.12.22,增加了RandAugment数据增强方法。
模型 | 链接 | 论文 |
---|---|---|
mobileone | s0、s1、s2、s3、s4 | An Improved One millisecond Mobile Backbone |
ghostnetv2 | 宽度:1.0 | GhostNetV2: Enhance Cheap Operation with Long-Range Attention |
花朵图像分类数据集 链接:https://pan.baidu.com/s/1zs9U76OmGAIwbYr91KQxgg 提取码:bhjx
-
数据集文件结构
- data - train # 训练集 - flower0 - flower1 - ... - test # 验证集 - flower0 - flower1 - ...
运行
python process_datasets_path.py
命令,将会生成train_cls.txt
和valid_cls.txt
,这是训练时所需要的。 -
训练的参数配置在
train.py
中,注意:预训练权重下载到weights
文件夹。config = { 'is_cuda' : True, 'fp16' : True, # 混合精度训练 'classes_path' : './classes.txt', # 种类 'input_shape' : [224, 224], 'model_name' : 'mobileone', 'pretrained_weights' : False, # 是否需要预训练权重 'model_path' : '', # 整个模型的权重 'batch_size' : 16, 'Epochs' : 400, 'learning_rate' : 1e-2, 'optimizer_type' : 'SGD', 'lr_decay_type' : 'Cosine', 'num_worker' : 4, 'save_dir' : './logs', # 保存权重以及损失的文件夹 'save_period' : 10, # 每隔10Epochs保存一次权重 'loss_func_name' : 'Poly_loss', # 损失函数 'data_aug' : 'original' } # ---------------------------------------------------- # # model_name 可选:mobileone、ghostnetv2 # optimizer_type 可选:SGD、Adam、Ranger # loss_func_name # 可选:Poly_loss、PolyFocal、CE、LabelSmoothSoftmaxCE # 若设置为是双损失函数,则'loss_func_name'设成列表形式 # 如:'loss_func_name': [('Poly_loss', 'LabelSmoothSoftmaxCE'), (0.9, 0.1)] # 后面一个元组为对应损失函数的权重 # data_aug 可选:original、randaugment # lr_decay_type 可选:Cosine # ---------------------------------------------------- #
mobileone网络结构的参数,运行:
python summary.py --backbone mobileone
单GPU训练,运行:
python train.py
运行:
python eval.py --model_name mobileone --model_path weights/mobileone-16e-s0-flower.pth --output_dir eval_out
其中,model_name
表示要评估的图像分类模型,model_path
表示权重路径,output_dir
表示保存评估结果的文件夹。
预测图片运行:
python inference.py --model_name mobileone --model_path weights/mobileone-16e-s0-flower.pth
本仓库暂时只支持onnxruntime部署。
-
导出onnx,运行:
python export_onnx.py --model_name mobileone --model_path weights/mobileone-16e-s0-flower.pth --output_path weights/mobileone-16e-s0-flower.onnx
其中,output_path表示onnx导出的路径。
-
使用onnxruntime推理图片,运行:
python inference.py --model_name mobileone --model_onnx ./weights/mobileone-16e-s0-flower.onnx --infer_onnx 1