stefan252423 / onnx2tflite

This project forked from mpolaris/onnx2tflite

A tool for converting ONNX to Keras or ONNX to TFLite. If this tool is useful to you, please star it.

License: Apache License 2.0

Python 100.00%

onnx2tflite's Introduction

ONNX->Keras and ONNX->TFLite tools

How to use

# basic usage
python converter.py --weights "./your_model.onnx"

# specify the save path
python converter.py --weights "./your_model.onnx" --outpath "./save_path"

# save a tflite model
python converter.py --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite"

# save both keras and tflite models
python converter.py --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" "keras"

# cut off the model: redefine its inputs and outputs (intermediate layers are supported)
python converter.py --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" --input-node-names "layer_name" --output-node-names "layer_name1" "layer_name2"

# quantize the model weights only
python converter.py --weights "./your_model.onnx" --formats "tflite" --weigthquant

# full int8 quantization, including inputs and outputs
## recommended: calibrate with images from a dataset folder
python converter.py --weights "./your_model.onnx" --formats "tflite" --int8 --imgroot "./dataset_path" --int8mean 0 0 0 --int8std 1 1 1
## calibrate with randomly generated data instead of reading image files
python converter.py --weights "./your_model.onnx" --formats "tflite" --int8
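For readers new to the quantization options above, the sketch below shows the standard affine int8 mapping that TFLite-style quantization builds on: a real value x is stored as q = round(x / scale) + zero_point, clamped to [-128, 127]. It is a plain-Python illustration under that assumption, not the converter's own code, and the function names are hypothetical. Roughly, weight-only quantization applies such a mapping to the weights, while full int8 quantization also needs sample data to choose the ranges of activations, inputs and outputs.

```python
from typing import List, Sequence, Tuple

INT8_MIN, INT8_MAX = -128, 127


def choose_qparams(x_min: float, x_max: float) -> Tuple[float, int]:
    """Pick (scale, zero_point) so that [x_min, x_max] maps onto the int8 range."""
    # Extend the range to include 0.0 so that zero is represented exactly.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (INT8_MAX - INT8_MIN)
    if scale == 0.0:
        scale = 1.0  # all-zero tensor: any scale works
    zero_point = round(INT8_MIN - x_min / scale)
    return scale, max(INT8_MIN, min(INT8_MAX, zero_point))


def quantize(values: Sequence[float], scale: float, zero_point: int) -> List[int]:
    """q = clamp(round(x / scale) + zero_point, -128, 127)."""
    # Python's round() rounds halves to even; real kernels may round differently.
    return [max(INT8_MIN, min(INT8_MAX, round(v / scale) + zero_point)) for v in values]


def dequantize(qvalues: Sequence[int], scale: float, zero_point: int) -> List[float]:
    """x is approximately (q - zero_point) * scale."""
    return [(q - zero_point) * scale for q in qvalues]


weights = [-0.5, 0.0, 0.25, 1.0]
scale, zero_point = choose_qparams(min(weights), max(weights))
q = quantize(weights, scale, zero_point)
restored = dequantize(q, scale, zero_point)
print(scale, zero_point, q, restored)
```

Every restored value lies within scale / 2 of the original; that rounding error is the price paid for a smaller, faster model. Symmetric and per-channel schemes are common refinements of the same idea.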

Features

  • High consistency. Compared to the ONNX outputs, the average error is less than 1e-5 per element.
  • Faster. The output TensorFlow Lite model runs 30% faster than one converted with onnx_tf.
  • Automatic channel alignment. Automatically converts the PyTorch layout (NCHW) to the TensorFlow layout (NHWC).
  • Deployment support. Can output quantized models, including fp16 and uint8 quantization.
  • Code friendly. I've tried to keep the code structure simple and clear.
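The channel alignment described above amounts to moving the channel dimension from position 1 to the last position. A minimal sketch, assuming a channels-first tensor laid out as batch, channels, then spatial dimensions: the helpers below (hypothetical names, not the project's Torch2TFAxis or TorchShape2TF) remap a shape and an ONNX axis attribute to the channels-last layout.

```python
from typing import List, Sequence


def nchw_to_nhwc_perm(rank: int) -> List[int]:
    """Permutation that moves the channel axis (1) to the end, e.g. [0, 2, 3, 1] for rank 4."""
    return [0] + list(range(2, rank)) + [1]


def nchw_shape_to_nhwc(shape: Sequence[int]) -> List[int]:
    """Reorder a channels-first shape into channels-last order."""
    return [shape[i] for i in nchw_to_nhwc_perm(len(shape))]


def nchw_axis_to_nhwc(axis: int, rank: int) -> int:
    """Map an axis of a channels-first tensor to the same axis in channels-last layout."""
    axis %= rank           # normalize negative axes such as -1
    if axis == 0:
        return 0           # batch stays first
    if axis == 1:
        return rank - 1    # channels move to the end
    return axis - 1        # spatial axes shift one position left


print(nchw_shape_to_nhwc([1, 3, 224, 224]))  # [1, 224, 224, 3]
print(nchw_axis_to_nhwc(1, 4))               # 3
```

For example, an ONNX Concat with axis=1 on an NCHW tensor becomes a concatenation along axis 3 once the tensor is NHWC.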

Cautions

  • Works well with 2D vision CNNs; 3D CNNs are not supported, and support for math operations (such as channel changes) is poor.
  • After conversion, please run comfirm_acc.py to confirm that the outputs are correct, because some conversion methods are based on empirical practice.
  • comfirm_acc.py only supports TFLite models, and the TFLite model must not be quantized.
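To illustrate the kind of check comfirm_acc.py performs (its actual code may differ), here is a minimal sketch, assuming both model outputs have already been flattened in the same element order; a channels-last TFLite output has to be transposed back to channels-first before flattening. The function name and the default tolerance of 1e-5, taken from the Features list, are illustrative assumptions.

```python
import math
from typing import Sequence, Tuple


def compare_outputs(reference: Sequence[float], converted: Sequence[float],
                    tol: float = 1e-5) -> Tuple[float, float, bool]:
    """Return (mean_abs_error, max_abs_error, passed) for two flattened outputs."""
    if len(reference) != len(converted):
        raise ValueError(f"size mismatch: {len(reference)} vs {len(converted)}")
    if not reference:
        raise ValueError("outputs are empty")
    errors = [abs(a - b) for a, b in zip(reference, converted)]
    mean_err = math.fsum(errors) / len(errors)  # average error per element
    return mean_err, max(errors), mean_err < tol


onnx_out = [0.1, 0.2, 0.3]
tflite_out = [0.1, 0.2000001, 0.3]
mean_err, max_err, ok = compare_outputs(onnx_out, tflite_out)
print(f"mean={mean_err:.2e} max={max_err:.2e} passed={ok}")  # passed=True
```

Reporting the maximum error alongside the mean helps spot a single badly converted region that a small average would hide.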

PyTorch -> ONNX -> TensorFlow-Keras -> TensorFlow-Lite

  • From torchvision to TensorFlow Lite

import torch
import torchvision
_input = torch.randn(1, 3, 224, 224)
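model = torchvision.models.mobilenet_v2(weights="IMAGENET1K_V1")  # on torchvision < 0.13 use pretrained=True
model.eval()  # export in inference mode
# the default export settings are fine
torch.onnx.export(model, _input, './mobilenetV2.onnx', opset_version=11)  # or opset_version=13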

from converter import onnx_converter
onnx_converter(
    onnx_model_path = "./mobilenetV2.onnx",
    need_simplify = True,
    output_path = "./",
    target_formats = ['tflite'], # or ['keras'], ['keras', 'tflite']
    weight_quant = False,
    int8_model = False,
    int8_mean = None,
    int8_std = None,
    image_root = None
)
  • From a custom PyTorch model to TensorFlow Lite int8

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()  # required before assigning submodules
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

model = MyModel()
model.load_state_dict(torch.load("model_checkpoint.pth", map_location="cpu"))
model.eval()  # export BatchNorm in inference mode
_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, _input, './mymodel.onnx', opset_version=11)# or opset_version=13

from converter import onnx_converter
onnx_converter(
    onnx_model_path = "./mymodel.onnx",
    need_simplify = True,
    output_path = "./",
    target_formats = ['tflite'], #or ['keras'], ['keras', 'tflite']
    weight_quant = False,
    int8_model = True, # enable int8 quantization
    int8_mean = [0.485, 0.456, 0.406], # mean used in image preprocessing
    int8_std = [0.229, 0.224, 0.225], # std used in image preprocessing
    image_root = "./dataset/train" # folder of training images used for calibration
)
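The int8_mean, int8_std and image_root arguments describe how calibration images are preprocessed. A minimal sketch of typical preprocessing, assuming pixels are first scaled to [0, 1] (consistent with the ImageNet statistics above) and then normalized per channel; the helper names are hypothetical and the converter's exact preprocessing may differ. When no images are given, random samples stand in, mirroring the CLI's --int8 option without --imgroot.

```python
import random
from typing import Iterator, List, Optional, Sequence, Tuple

# Channels-last image: Image[h][w][c], as TFLite models usually expect.
Image = List[List[List[float]]]


def normalize(image_u8: Sequence[Sequence[Sequence[int]]],
              mean: Sequence[float], std: Sequence[float]) -> Image:
    """Scale 0-255 pixels to [0, 1], then apply (x - mean[c]) / std[c] per channel."""
    return [[[(px[c] / 255.0 - mean[c]) / std[c] for c in range(len(px))]
             for px in row]
            for row in image_u8]


def calibration_samples(images: Optional[List], mean: Sequence[float], std: Sequence[float],
                        num_random: int = 10,
                        shape: Tuple[int, int, int] = (224, 224, 3)) -> Iterator[Image]:
    """Yield normalized calibration inputs, or random ones when no images are available."""
    if images:
        for img in images:
            yield normalize(img, mean, std)
    else:
        # Hypothetical fallback: uniform noise in [0, 1); the converter's choice may differ.
        h, w, c = shape
        for _ in range(num_random):
            yield [[[random.random() for _ in range(c)] for _ in range(w)] for _ in range(h)]


mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
tiny = [[[255, 255, 255], [0, 0, 0]]]  # 1 x 2 image: a white pixel and a black pixel
sample = next(calibration_samples([tiny], mean, std))
print(sample[0][0])  # white pixel after normalization
```

Calibration data that matches the real preprocessing matters because the quantization ranges are chosen from it; random data is convenient but usually gives less accurate ranges.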

Validated models


Add operator by yourself

When you encounter an unsupported operator, you can either add it yourself or open an issue.
Implementing a new operator parser is simple; just follow the steps below.
Step 0: Select the corresponding layer file in the layers folder, such as activations_layers.py for 'HardSigmoid'.
Step 1: Open it and add the operator:

import tensorflow as tf
# OPERATOR is the project's operator registry, assumed to be imported at the top of the layers file.

# All operators are registered through the OPERATOR registry.
# The registered name is the ONNX operator name.
@OPERATOR.register_operator("HardSigmoid")
class TFHardSigmoid():
    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
        '''
        :param tensor_grap: dict, key is the node name, value is the TensorFlow-Keras output tensor of that node.
        :param node_weights: dict, key is the node name, value is static data such as weights, biases or constants; weights usually need to be transformed with TorchWeights2TF.
        :param node_inputs: List[str], names of the node's inputs, i.e. which nodes they come from; a name may be found in tensor_grap or in node_weights.
        :param node_attribute: dict, key is the attribute name, such as 'axis' or 'perm'; the value type varies (List[int], int, float, ...). Note that 'axis' values must be converted from NCHW to NHWC with Torch2TFAxis or TorchShape2TF.
        '''
        super().__init__()
        # ONNX HardSigmoid defaults: alpha=0.2, beta=0.5
        self.alpha = node_attribute.get("alpha", 0.2)
        self.beta = node_attribute.get("beta", 0.5)

    def __call__(self, inputs):
        # y = max(0, min(1, alpha * x + beta))
        return tf.clip_by_value(self.alpha*inputs+self.beta, 0, 1)

Step 2: Make sure it runs without errors.
Step 3: Convert the model to TFLite without any quantization.
Step 4: Run comfirm_acc.py to make sure the outputs are consistent.
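The @OPERATOR.register_operator decorator in Step 1 follows a common registry pattern: a dictionary from ONNX operator names to parser classes, which the converter can look up for each node. A self-contained sketch of that pattern, assuming a minimal registry class (hypothetical, not the project's actual implementation) and a pure-Python stand-in for the HardSigmoid parser so it runs without TensorFlow; a reference like this is also handy when checking outputs in Step 4.

```python
from typing import Callable, Dict, Type


class OperatorRegistry:
    """Hypothetical registry mapping ONNX operator names to converter classes."""

    def __init__(self) -> None:
        self._ops: Dict[str, Type] = {}

    def register_operator(self, name: str) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            if name in self._ops:
                raise KeyError(f"operator {name!r} already registered")
            self._ops[name] = cls
            return cls  # the class itself is returned unchanged
        return decorator

    def get(self, name: str) -> Type:
        if name not in self._ops:
            raise NotImplementedError(f"unsupported ONNX operator: {name}")
        return self._ops[name]


OPERATOR = OperatorRegistry()


@OPERATOR.register_operator("HardSigmoid")
class RefHardSigmoid:
    """Pure-Python reference: y = max(0, min(1, alpha * x + beta))."""

    def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
        self.alpha = node_attribute.get("alpha", 0.2)
        self.beta = node_attribute.get("beta", 0.5)

    def __call__(self, inputs):
        return [max(0.0, min(1.0, self.alpha * x + self.beta)) for x in inputs]


# Look up the parser by the node's ONNX op type, then build and call it.
layer = OPERATOR.get("HardSigmoid")({}, {}, [], {"alpha": 0.2, "beta": 0.5})
print(layer([-5.0, 0.0, 5.0]))  # [0.0, 0.5, 1.0]
```

Raising NotImplementedError for unknown names gives a clear message pointing at the operator that still needs a parser.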

TODO

  • support Transformers (ViT, Swin Transformer, etc.)
  • support cutting off the ONNX model and specifying output layers
  • optimize comfirm_acc.py

License

This software is covered by Apache-2.0 license.

onnx2tflite's People

Contributors

mpolaris