sithu31296 / pytorch-onnx-tflite

Conversion of PyTorch Models into TFLite

License: MIT License

Python 100.00%
onnx-tf tflite pytorch tflite-conversion

pytorch-onnx-tflite's Introduction

TFLite Conversion

PyTorch -> ONNX -> TF -> TFLite

Convert PyTorch models to TFLite and run inference with the TFLite Python API.

Tested Environment

  • pytorch==1.7.1
  • tensorflow==2.4.1
  • onnx==1.8.0
  • onnx-tf==1.7.0

PyTorch to ONNX

Load the PyTorch Model:

model = Model()
model.load_state_dict(torch.load(pt_model_path, map_location='cpu'))
model.eval()

Prepare the Input:

sample_input = torch.rand((batch_size, channels, height, width))

Export to ONNX format:

torch.onnx.export(
    model,                   # PyTorch model
    sample_input,            # Input tensor
    onnx_model_path,         # Output file (e.g. 'output_model.onnx')
    opset_version=12,        # Operator support version
    input_names=['input'],   # Input tensor name (arbitrary)
    output_names=['output']  # Output tensor name (arbitrary)
)

opset-version: opset_version is very important. Some PyTorch operators are still not supported in ONNX even with opset_version=12. The default opset_version used by torch.onnx.export depends on the PyTorch release, so it is safest to set it explicitly. Please check the official ONNX repo for the supported PyTorch operators. If your model includes unsupported operators, rewrite them in terms of supported ones. For example, torch.repeat_interleave() is not supported, but it can be replaced with torch.repeat() + torch.view() to achieve the same result, as shown below.
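As a minimal sketch of that rewrite (assuming a 2-D tensor repeated along dim=0), the repeat() + view() combination produces the same result while using only operators with ONNX support:

import torch

x = torch.arange(6).reshape(3, 2)   # sample tensor of shape (3, 2)
k = 2                               # number of repeats

# Original op, which may not be exportable for your opset:
a = torch.repeat_interleave(x, k, dim=0)   # shape (6, 2)

# ONNX-friendly equivalent using repeat() + view():
b = x.repeat(1, k).view(-1, x.shape[1])    # shape (6, 2)

assert torch.equal(a, b)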

output-names: If your model returns more than one output, provide exactly that many names. For example, if your model returns 3 outputs, output_names should be ['output0', 'output1', 'output2']. If the number of names does not match, the PyTorch-to-ONNX conversion may still succeed, but the ONNX-to-TFLite conversion will fail.
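For example, a sketch of the export call for a hypothetical model that returns three tensors:

torch.onnx.export(
    model,
    sample_input,
    onnx_model_path,
    opset_version=12,
    input_names=['input'],
    output_names=['output0', 'output1', 'output2']  # one name per returned tensor
)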

For more information about ONNX model conversion, please check ONNX_DETAILS.

Verification

You can verify the ONNX protobuf with the onnx library.

Install onnx:

pip install onnx

import onnx

# Load the ONNX model
model = onnx.load("model.onnx")

# Check that the IR is well formed
onnx.checker.check_model(model)

# Print a human-readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

ONNX Model Inference

Install onnxruntime:

pip install onnxruntime

Then run the inference:

import numpy as np
import onnxruntime as ort

ort_session = ort.InferenceSession('model.onnx')

outputs = ort_session.run(
    None,
    {'input': np.random.randn(batch_size, channels, height, width).astype(np.float32)}
)

ONNX to TF

You cannot convert an ONNX model directly into a TFLite model; you must first convert it to a TensorFlow model.

Use onnx-tensorflow to convert models from ONNX to TensorFlow.

Install as follows:

git clone https://github.com/onnx/onnx-tensorflow.git && cd onnx-tensorflow
pip install -e .

Load the ONNX model:

import onnx

onnx_model = onnx.load(onnx_model_path)

Convert with onnx-tf:

from onnx_tf.backend import prepare

tf_rep = prepare(onnx_model)

Export TF model:

tf_rep.export_graph(tf_model_path)

You will get a TensorFlow model in SavedModel format.

Note: tf_model_path should not contain an extension like .pb.
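For example, with a hypothetical path 'model_tf':

tf_rep.export_graph('model_tf')  # creates the SavedModel directory ./model_tf/ (saved_model.pb + variables/)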

TF Model Inference

import tensorflow as tf

model = tf.saved_model.load(tf_model_path)
model.trainable = False

input_tensor = tf.random.uniform([batch_size, channels, height, width])
out = model(**{'input': input_tensor})

TF to TFLite

To convert the TF SavedModel format into a TFLite model, you can use the official tf.lite.TFLiteConverter class.

# Convert the model
converter = tf.lite.TFLiteConverter.from_saved_model(tf_model_path)
tflite_model = converter.convert()

# Save the model
with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)

TFLite Model Inference

import numpy as np
import tensorflow as tf

# Load the TFLite model and allocate tensors
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on random input data
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# get_tensor() returns a copy of the tensor data
# use tensor() in order to get a pointer to the tensor
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

Supported Ops and Limitations

TFLite supports a subset of TF operations with some limitations. For the full list of operations and limitations, see the TF Lite Ops page.

Most TFLite ops target float32 and quantized uint8 or int8 inference, but many ops don't support other types like float16 and strings.

TFLite with TF ops

Since TFLite builtin ops support only a limited number of TF operators, not every model is convertible.

To allow conversion, the usage of certain TF ops can be enabled in a TFLite model.

However, running TFLite models with TF ops requires pulling in the core TF runtime, which increases the TFLite interpreter binary size.

See the list of TF ops that can be enabled in TFLite.

To convert to TFLite model with additional TF ops:

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TFLite ops
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TF ops
]
tflite_model = converter.convert()

open("converted_model.tflite", "wb").write(tflite_model)

Run Inference

When using a TFLite model that has been converted with support for select TF ops, the client must also use a TFLite runtime that includes the necessary library of TF ops.

In Python, no extra steps are needed to use select TF ops: the TFLite interpreter that ships with the TensorFlow pip package is already built with that support.
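For example (a minimal sketch, assuming the full TensorFlow pip package is installed), a model converted with SELECT_TF_OPS loads with the same interpreter API shown earlier:

import tensorflow as tf

# The interpreter bundled with the TensorFlow pip package already links the
# select TF ops (Flex delegate), so loading needs no extra configuration.
interpreter = tf.lite.Interpreter(model_path="converted_model.tflite")
interpreter.allocate_tensors()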

Performance

The following table shows inference on MobileNet on a Pixel 2.

| Build                | Time (ms) | APK Size |
|----------------------|-----------|----------|
| Builtin ops          | 260.7     | 561 KB   |
| Builtin ops + TF ops | 264.5     | 8 MB     |

Model Optimization

TFLite supports optimization via quantization, pruning and clustering.

Quantization

Quantization works by reducing the precision of the numbers used to represent a model's parameters (by default, 32-bit floats). This results in a smaller model size and faster computation.

| Technique                                  | Data Requirements      | Size Reduction | Accuracy                    |
|--------------------------------------------|------------------------|----------------|-----------------------------|
| Post-training float16 quantization         | No data                | Up to 50%      | Insignificant accuracy loss |
| Post-training dynamic range quantization   | No data                | Up to 75%      | Accuracy loss               |
| Post-training integer quantization         | Unlabelled data        | Up to 75%      | Smaller accuracy loss       |
| Quantization-aware training                | Labelled training data | Up to 75%      | Smallest accuracy loss      |

Post-training float16 quantization

Use this when you are deploying to a float16-capable GPU.

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
tflite_quant_model = converter.convert()

Post-training dynamic range quantization

Don't use this. Use integer quantization.

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

Post-training integer quantization

Integer with float fallback (using default float input/output)

The model runs in integer but falls back to float operators for ops that have no integer implementation.

This is a common use case for ARM CPUs.

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

def representative_dataset_gen():
    for _ in range(num_calibration_steps):
        # `input` is a placeholder: yield one sample input batch as a float32 numpy array
        yield [input]

converter.representative_dataset = representative_dataset_gen
tflite_quant_model = converter.convert()
Integer only

A common use case for 8-bit MCU and Coral Edge TPU.

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

def representative_dataset_gen():
    for _ in range(num_calibration_steps):
        # `input` is a placeholder: yield one sample input batch as a float32 numpy array
        yield [input]

converter.representative_dataset = representative_dataset_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_quant_model = converter.convert()

Pruning

Pruning works by removing parameters within a model that have only a minor impact on its predictions.

Pruned models are the same size on disk, and have the same runtime latency, but can be compressed more effectively. This makes pruning a useful technique for reducing model download size.
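Pruning is typically applied with the tensorflow_model_optimization (tfmot) package, which is not used elsewhere in this guide; the following is a minimal, hedged sketch assuming a Keras model named model and placeholder training data x_train, y_train:

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Example schedule: ramp up to 50% sparsity over the first 1000 training steps.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.5, begin_step=0, end_step=1000)
}
model_for_pruning = prune_low_magnitude(model, **pruning_params)
model_for_pruning.compile(optimizer='adam', loss='mse')

# UpdatePruningStep must be passed as a callback during fine-tuning.
model_for_pruning.fit(x_train, y_train, epochs=2,
                      callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

# Strip the pruning wrappers before exporting/converting as usual.
final_model = tfmot.sparsity.keras.strip_pruning(model_for_pruning)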

Clustering

Clustering works by grouping the weights of each layer in a model into a predefined number of clusters, then sharing the centroid values for the weights belonging to each individual cluster. This reduces the number of unique weight values in a model, thus reducing its complexity.

As a result, clustered models can be compressed more effectively, providing deployment benefits similar to pruning.
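Similarly, a minimal weight-clustering sketch with tensorflow_model_optimization (again assuming a Keras model named model and placeholder data x_train, y_train):

import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

# Group each layer's weights into 16 clusters with linearly spaced centroids.
clustering_params = {
    'number_of_clusters': 16,
    'cluster_centroids_init': CentroidInitialization.LINEAR,
}
clustered_model = cluster_weights(model, **clustering_params)
clustered_model.compile(optimizer='adam', loss='mse')
clustered_model.fit(x_train, y_train, epochs=1)

# Strip the clustering wrappers before converting to TFLite.
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)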

Note: For pruning and clustering, check out the official TFLite Guide for more information.



pytorch-onnx-tflite's Issues

Size not decreasing from onnx to tflite

Hello! I have followed your (very useful) guide on the conversion from PyTorch to TFLite, using ONNX as the "middle step".

While all the conversions work fine, I have noticed that the model does not reduce in size going from PyTorch to TFLite, while the ONNX model is 30-ish KB smaller than the others. Could this be a sign I'm doing something wrong?

tf_rep.export_graph(tf_model_path)

Dear author, thanks for your great work.
When I run tf_rep.export_graph(tf_model_path) in "onnx_to_tf.py", it throws an error like the following:

File "/home/ateam/code/PyTorch-ONNX-TFLite/conversion/onnx_to_tf.py", line 9, in <module>
    tf_rep.export_graph(tf_model_path)
File "/home/ateam/onnx-tensorflow/onnx_tf/backend_rep.py", line 107, in export_graph
    file = open(path, "wb")
IsADirectoryError: [Errno 21] Is a directory: 'model_tf'

channel problem

Hello
How are you?
Thanks for contributing to this project.
I converted my PyTorch model to a TFLite model using your method.
But this TFLite model is different from the original TFLite model converted from the Keras model.

This figure is the structure of the base TFLite model converted from the Keras model: [image]

The figure below shows the structure of the TFLite model converted from the PyTorch model: [image]

The figure below shows the structure of the ONNX model converted from the PyTorch model: [image]

The main problem is that the TFLite model converted from the PyTorch model runs more slowly than the base TFLite model converted from the Keras model.

As you know, PyTorch models use channel-first tensor ordering and TensorFlow uses channel-last.
I looked at the structure of the TFLite model converted from the PyTorch model.
Compared with the base TFLite model converted from the Keras model, there are many Transpose layers in the TFLite model converted from the PyTorch model.
I think that this affects the model inference speed.
Can we get a TFLite model that uses channel-last ordering from the PyTorch model?

When calling converter.convert() I encountered a weird problem

Hi, thank you for the detailed implementation.
I was trying to convert the torch model to a quantized TFLite model for the Coral Edge TPU.
However, I found that when I call converter.convert(), it pops up a
"Javascript Error: Too much recursion."
Did you encounter this problem before?

converting .pt file to .tflite file

Hello,
First, I'd like to thank you all for your hard work.
Second, I'm still learning about Python and all kinds of models.
I've seen the instructions about converting PyTorch -> ONNX -> TF -> TFLite
here: https://github.com/sithu31296/PyTorch-ONNX-TFLite
and I'm still confused about how to do that. I looked it up online but there is no step-by-step tutorial.
If anyone can explain it to me, I'd be grateful, because I've been really struggling with it.
Thank you.
