GithubHelp home page GithubHelp logo

image-classification-pytorch's Introduction

简体中文 | English

图像分类模型

更新

  • 2022.12.22,增加了RandAugment数据增强方法。

1. 训练

1.1 预训练权重

模型 链接 论文
mobileone s0s1s2s3s4 An Improved One millisecond Mobile Backbone
ghostnetv2 宽度:1.0 GhostNetV2: Enhance Cheap Operation with Long-Range Attention

花朵图像分类数据集 链接:https://pan.baidu.com/s/1zs9U76OmGAIwbYr91KQxgg 提取码:bhjx

  1. 数据集文件结构

    - data
        - train            # 训练集
            - flower0
            - flower1
            - ...
        - test             # 验证集
            - flower0
            - flower1
            - ...

    运行python process_datasets_path.py命令,将会生成train_cls.txtvalid_cls.txt,这是训练时所需要的。

  2. 训练的参数配置在train.py中,注意:预训练权重下载到weights文件夹

    config = {
        'is_cuda'                  : True,         
        'fp16'                     : True,              # 混合精度训练  
        'classes_path'             : './classes.txt',   # 种类
        'input_shape'              : [224, 224],        
        'model_name'               : 'mobileone',
        'pretrained_weights'       : False,              # 是否需要预训练权重
        'model_path'               : '',                # 整个模型的权重
        'batch_size'               : 16,
        'Epochs'                   : 400,
        'learning_rate'            : 1e-2,
        'optimizer_type'           : 'SGD',
        'lr_decay_type'            : 'Cosine',
        'num_worker'               : 4,
        'save_dir'                 : './logs',          # 保存权重以及损失的文件夹
        'save_period'              : 10,                # 每隔10Epochs保存一次权重
        'loss_func_name'           : 'Poly_loss',        # 损失函数
        'data_aug'                 : 'original'
    }
    
    # ---------------------------------------------------- #
    # model_name                 可选:mobileone、ghostnetv2
    # optimizer_type             可选:SGD、Adam、Ranger
    # loss_func_name
    # 可选:Poly_loss、PolyFocal、CE、LabelSmoothSoftmaxCE
    # 若设置为是双损失函数,则'loss_func_name'设成列表形式
    # 如:'loss_func_name': [('Poly_loss', 'LabelSmoothSoftmaxCE'), (0.9, 0.1)]
    # 后面一个元组为对应损失函数的权重
    # data_aug                   可选:original、randaugment
    # lr_decay_type              可选:Cosine
    # ---------------------------------------------------- #

    mobileone网络结构的参数,运行:

    python summary.py --backbone mobileone

    单GPU训练,运行:

    python train.py

2. 评估

运行:

python eval.py --model_name mobileone --model_path weights/mobileone-16e-s0-flower.pth --output_dir eval_out

其中,model_name表示要评估的图像分类模型,model_path表示权重路径,output_dir表示保存评估结果的文件夹。

3. 推理

预测图片运行:

python inference.py --model_name mobileone --model_path weights/mobileone-16e-s0-flower.pth

4. 部署

本仓库暂时只支持onnxruntime部署。

  1. 导出onnx,运行:

    python export_onnx.py --model_name mobileone --model_path weights/mobileone-16e-s0-flower.pth --output_path weights/mobileone-16e-s0-flower.onnx

    其中,output_path表示onnx导出的路径。

  2. 使用onnxruntime推理图片,运行:

    python inference.py --model_name mobileone --model_onnx ./weights/mobileone-16e-s0-flower.onnx --infer_onnx 1

参考

  1. https://github.com/bubbliiiing/classification-pytorch
  2. https://github.com/apple/ml-mobileone
  3. https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/ghostnetv2_pytorch

image-classification-pytorch's People

Contributors

hao-ux avatar

Watchers

 avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.