Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/test_tflite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ def func(a, b, c):
with open(tflite_path, 'wb') as f:
f.write(tflite_model)

tflite_graphs, opcodes_map, model = read_tflite_model(tflite_path)
tflite_graphs, opcodes_map, model, tensor_shapes = read_tflite_model(tflite_path)
self.assertEqual(1, len(tflite_graphs))
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, _ = \
parse_tflite_graph(tflite_graphs[0], opcodes_map, model)
parse_tflite_graph(tflite_graphs[0], opcodes_map, model, tensor_shapes_override=tensor_shapes)
self.assertEqual(2, op_cnt['MUL'])
self.assertEqual(1, op_cnt['ADD'])
self.assertEqual(1, op_cnt['FULLY_CONNECTED'])
Expand Down
7 changes: 4 additions & 3 deletions tf2onnx/tflite_handlers/tfl_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ class TFlSoftmaxOp:
@classmethod
def to_tf(cls, ctx, node, **kwargs):
beta = node.get_attr_value("beta")
beta_node = ctx.make_const(utils.make_name("beta"), np.array(beta, dtype=np.float32))
mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], name=utils.make_name(node.name))
ctx.replace_inputs(mul_node, [node.output[0], beta_node.output[0]])
if beta != 1:
beta_node = ctx.make_const(utils.make_name("beta"), np.array(beta, dtype=np.float32))
mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], name=utils.make_name(node.name))
ctx.replace_inputs(mul_node, [node.output[0], beta_node.output[0]])
32 changes: 28 additions & 4 deletions tf2onnx/tflite_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@

import collections
import importlib
import logging

from onnx import helper, onnx_pb, numpy_helper
from tensorflow.core.framework import types_pb2, tensor_pb2
from tensorflow.python.framework import tensor_util
import tensorflow as tf
from tf2onnx.tflite.TensorType import TensorType as TFLiteTensorType
from tf2onnx.tflite.Model import Model
from tf2onnx.flexbuffers import read_flexbuffer

logger = logging.getLogger(__name__)

TFLITE_TO_ONNX_DTYPE = {
TFLiteTensorType.FLOAT32: onnx_pb.TensorProto.FLOAT,
Expand Down Expand Up @@ -138,7 +141,22 @@ def read_tflite_model(tflite_path):
code = op_code.CustomCode().decode()
opcodes_map[i] = code
tflite_graphs = [model.Subgraphs(i) for i in range(model.SubgraphsLength())]
return tflite_graphs, opcodes_map, model
# Shapes stored in tflite models are not always reliable so we get them from the interpreter if possible.
tensor_shapes = {}
try:
interpreter = tf.lite.Interpreter(tflite_path)
interpreter.allocate_tensors()
tensor_cnt = model.Subgraphs(0).TensorsLength()
for i in range(tensor_cnt):
name = model.Subgraphs(0).Tensors(i).Name().decode()
details = interpreter._get_tensor_details(i) # pylint: disable=protected-access
if "shape_signature" in details:
tensor_shapes[name] = details["shape_signature"].tolist()
elif "shape" in details:
tensor_shapes[name] = details["shape"].tolist()
except Exception as e: # pylint: disable=broad-except
logger.warning("Error loading model into tflite interpreter: %s", e)
return tflite_graphs, opcodes_map, model, tensor_shapes


def get_quantization_attr(quant_params):
Expand All @@ -153,7 +171,7 @@ def get_quantization_attr(quant_params):
return attr


def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix='', tensor_shapes_override=None):
"""
Returns a Graph object along with some op count stats. All tflite op types are prefixed with "TFL_".
Names of graph inputs are optionally prefixed with a string to prevent name conflicts in subgraphs.
Expand All @@ -165,6 +183,8 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
output_shapes = {}
dtypes = {}
tensor_names = {}
if tensor_shapes_override is None:
tensor_shapes_override = {}
# Map tensor name to tflite Tensor object so we can fetch quantization info as needed
name_to_tensor = {}
# If a node takes a quantized tensor as input, we must add a dequantize op after it.
Expand All @@ -183,7 +203,9 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
tensor_names[i] = name
name_to_tensor[name] = tensor

if tensor.ShapeIsNone():
if name in tensor_shapes_override:
output_shapes[name] = tensor_shapes_override[name]
elif tensor.ShapeIsNone():
output_shapes[name] = None
elif tensor.ShapeSignatureIsNone():
# The shape signature uses -1 to signify unknown dims. Old models don't have this and use Shape instead.
Expand Down Expand Up @@ -265,7 +287,9 @@ def get_prequant(tensor_name):
if not op.CustomOptionsIsNone():
custom_ops_format = lookup_enum(op.CustomOptionsFormat(), 'CustomOptionsFormat')
if custom_ops_format == 'FLEXBUFFERS':
attr.update(read_flexbuffer(op.CustomOptionsAsNumpy().tobytes()))
data = read_flexbuffer(op.CustomOptionsAsNumpy().tobytes())
if isinstance(data, dict):
attr.update(read_flexbuffer(op.CustomOptionsAsNumpy().tobytes()))
if option_class is not None:
options = option_class()
options.Init(op.BuiltinOptions().Bytes, op.BuiltinOptions().Pos)
Expand Down
7 changes: 5 additions & 2 deletions tf2onnx/tfonnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,14 +461,17 @@ def rename_tensors_in_nodes(onnx_nodes):
n.output[:] = rename_tensors_in_list(n.output)

if tflite_path is not None:
tflite_graphs, opcodes, model = read_tflite_model(tflite_path)
tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path)
main_g = None
inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw)
for i in reversed(range(len(tflite_graphs))):
tfl_graph = tflite_graphs[i]
prefix = '' if i == 0 else tfl_graph.Name().decode() + '_'
tensor_shapes_from_interpreter = None
if i == 0:
tensor_shapes_from_interpreter = tensor_shapes
onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \
parse_tflite_graph(tfl_graph, opcodes, model, prefix)
parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter)
g_inputs = f_inputs
g_outputs = f_outputs
if i == 0:
Expand Down