From 50b77e0498a6af1374185f8b43c53808b172b4eb Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Wed, 24 Feb 2021 22:06:15 -0500 Subject: [PATCH] Improve reliability of reading tflite models Signed-off-by: Tom Wildenhain --- tests/test_tflite_utils.py | 4 ++-- tf2onnx/tflite_handlers/tfl_math.py | 7 ++++--- tf2onnx/tflite_utils.py | 32 +++++++++++++++++++++++++---- tf2onnx/tfonnx.py | 7 +++++-- 4 files changed, 39 insertions(+), 11 deletions(-) diff --git a/tests/test_tflite_utils.py b/tests/test_tflite_utils.py index 6930afd50..d7bde093d 100644 --- a/tests/test_tflite_utils.py +++ b/tests/test_tflite_utils.py @@ -55,10 +55,10 @@ def func(a, b, c): with open(tflite_path, 'wb') as f: f.write(tflite_model) - tflite_graphs, opcodes_map, model = read_tflite_model(tflite_path) + tflite_graphs, opcodes_map, model, tensor_shapes = read_tflite_model(tflite_path) self.assertEqual(1, len(tflite_graphs)) onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, _ = \ - parse_tflite_graph(tflite_graphs[0], opcodes_map, model) + parse_tflite_graph(tflite_graphs[0], opcodes_map, model, tensor_shapes_override=tensor_shapes) self.assertEqual(2, op_cnt['MUL']) self.assertEqual(1, op_cnt['ADD']) self.assertEqual(1, op_cnt['FULLY_CONNECTED']) diff --git a/tf2onnx/tflite_handlers/tfl_math.py b/tf2onnx/tflite_handlers/tfl_math.py index f463976a9..eba80baf2 100644 --- a/tf2onnx/tflite_handlers/tfl_math.py +++ b/tf2onnx/tflite_handlers/tfl_math.py @@ -222,6 +222,7 @@ class TFlSoftmaxOp: @classmethod def to_tf(cls, ctx, node, **kwargs): beta = node.get_attr_value("beta") - beta_node = ctx.make_const(utils.make_name("beta"), np.array(beta, dtype=np.float32)) - mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], name=utils.make_name(node.name)) - ctx.replace_inputs(mul_node, [node.output[0], beta_node.output[0]]) + if beta != 1: + beta_node = ctx.make_const(utils.make_name("beta"), np.array(beta, dtype=np.float32)) + mul_node = ctx.insert_new_node_on_output("Mul", node.output[0], name=utils.make_name(node.name)) + ctx.replace_inputs(mul_node, [node.output[0], beta_node.output[0]]) diff --git a/tf2onnx/tflite_utils.py b/tf2onnx/tflite_utils.py index bfcdde68c..56559d1e9 100644 --- a/tf2onnx/tflite_utils.py +++ b/tf2onnx/tflite_utils.py @@ -7,14 +7,17 @@ import collections import importlib +import logging from onnx import helper, onnx_pb, numpy_helper from tensorflow.core.framework import types_pb2, tensor_pb2 from tensorflow.python.framework import tensor_util +import tensorflow as tf from tf2onnx.tflite.TensorType import TensorType as TFLiteTensorType from tf2onnx.tflite.Model import Model from tf2onnx.flexbuffers import read_flexbuffer +logger = logging.getLogger(__name__) TFLITE_TO_ONNX_DTYPE = { TFLiteTensorType.FLOAT32: onnx_pb.TensorProto.FLOAT, @@ -138,7 +141,22 @@ def read_tflite_model(tflite_path): code = op_code.CustomCode().decode() opcodes_map[i] = code tflite_graphs = [model.Subgraphs(i) for i in range(model.SubgraphsLength())] - return tflite_graphs, opcodes_map, model + # Shapes stored in tflite models are not always reliable so we get them from the interpreter if possible. + tensor_shapes = {} + try: + interpreter = tf.lite.Interpreter(tflite_path) + interpreter.allocate_tensors() + tensor_cnt = model.Subgraphs(0).TensorsLength() + for i in range(tensor_cnt): + name = model.Subgraphs(0).Tensors(i).Name().decode() + details = interpreter._get_tensor_details(i) # pylint: disable=protected-access + if "shape_signature" in details: + tensor_shapes[name] = details["shape_signature"].tolist() + elif "shape" in details: + tensor_shapes[name] = details["shape"].tolist() + except Exception as e: # pylint: disable=broad-except + logger.warning("Error loading model into tflite interpreter: %s", e) + return tflite_graphs, opcodes_map, model, tensor_shapes def get_quantization_attr(quant_params): @@ -153,7 +171,7 @@ def get_quantization_attr(quant_params): return attr -def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''): +def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix='', tensor_shapes_override=None): """ Returns a Graph object along with some op count stats. All tflite op types are prefixed with "TFL_". Names of graph inputs are optionally prefixed with a string to prevent name conflicts in subgraphs. @@ -165,6 +183,8 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''): output_shapes = {} dtypes = {} tensor_names = {} + if tensor_shapes_override is None: + tensor_shapes_override = {} # Map tensor name to tflite Tensor object so we can fetch quantization info as needed name_to_tensor = {} # If a node takes a quantized tensor as input, we must add a dequantize op after it. @@ -183,7 +203,9 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''): tensor_names[i] = name name_to_tensor[name] = tensor - if tensor.ShapeIsNone(): + if name in tensor_shapes_override: + output_shapes[name] = tensor_shapes_override[name] + elif tensor.ShapeIsNone(): output_shapes[name] = None elif tensor.ShapeSignatureIsNone(): # The shape signature uses -1 to signify unknown dims. Old models don't have this and use Shape instead. @@ -265,7 +287,9 @@ def get_prequant(tensor_name): if not op.CustomOptionsIsNone(): custom_ops_format = lookup_enum(op.CustomOptionsFormat(), 'CustomOptionsFormat') if custom_ops_format == 'FLEXBUFFERS': - attr.update(read_flexbuffer(op.CustomOptionsAsNumpy().tobytes())) + data = read_flexbuffer(op.CustomOptionsAsNumpy().tobytes()) + if isinstance(data, dict): + attr.update(read_flexbuffer(op.CustomOptionsAsNumpy().tobytes())) if option_class is not None: options = option_class() options.Init(op.BuiltinOptions().Bytes, op.BuiltinOptions().Pos) diff --git a/tf2onnx/tfonnx.py b/tf2onnx/tfonnx.py index a2ef25729..05d471e73 100644 --- a/tf2onnx/tfonnx.py +++ b/tf2onnx/tfonnx.py @@ -461,14 +461,17 @@ def rename_tensors_in_nodes(onnx_nodes): n.output[:] = rename_tensors_in_list(n.output) if tflite_path is not None: - tflite_graphs, opcodes, model = read_tflite_model(tflite_path) + tflite_graphs, opcodes, model, tensor_shapes = read_tflite_model(tflite_path) main_g = None inputs_as_nchw = rename_tensors_in_list(inputs_as_nchw) for i in reversed(range(len(tflite_graphs))): tfl_graph = tflite_graphs[i] prefix = '' if i == 0 else tfl_graph.Name().decode() + '_' + tensor_shapes_from_interpreter = None + if i == 0: + tensor_shapes_from_interpreter = tensor_shapes onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, f_inputs, f_outputs, graph_name = \ - parse_tflite_graph(tfl_graph, opcodes, model, prefix) + parse_tflite_graph(tfl_graph, opcodes, model, prefix, tensor_shapes_from_interpreter) g_inputs = f_inputs g_outputs = f_outputs if i == 0: