# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""Unit Tests for TFLite utils."""

import os
import tensorflow as tf

from common import *  # pylint: disable=wildcard-import,unused-wildcard-import
from backend_test_base import Tf2OnnxBackendTestBase
from tf2onnx.tf_loader import from_function, tf_session
from tf2onnx.tflite_utils import read_tflite_model, parse_tflite_graph

# pylint: disable=missing-docstring


class TFListUtilsTests(Tf2OnnxBackendTestBase):
    """Tests for the tflite parsing helpers in tf2onnx.tflite_utils."""

    @check_tf_min_version("2.0")
    def test_parse_tflite_graph(self):
        """Build a tiny TF graph, convert it to a .tflite file, then verify the
        node/attr counts, shapes, dtypes and graph inputs/outputs reported by
        parse_tflite_graph."""

        def func(a, b, c):
            alpha = tf.constant(1.1, dtype=tf.float32)
            beta = tf.constant(2.3, dtype=tf.float32)
            mul1 = tf.multiply(alpha, tf.matmul(a, b))
            mul2 = tf.multiply(beta, c)
            x_ = mul1 + mul2
            return tf.identity(x_, name="output")

        inp_shapes = [[2, 3], [3, 1], [2, 1]]
        inp_dtypes = [tf.float32, tf.float32, tf.float32]
        names = ['a', 'b', 'c']
        names_with_port = ['a:0', 'b:0', 'c:0']
        output_names = ['output']
        output_names_with_port = ['output:0']

        input_tensors = [tf.TensorSpec(shape=s, dtype=d, name=n)
                         for s, d, n in zip(inp_shapes, inp_dtypes, names)]

        concrete_func = tf.function(func, input_signature=tuple(input_tensors))
        concrete_func = concrete_func.get_concrete_function()
        graph_def = from_function(concrete_func,
                                  input_names=names_with_port,
                                  output_names=output_names_with_port)
        with tf_session() as sess:
            tf.import_graph_def(graph_def, name='')
            sess_inputs = [sess.graph.get_tensor_by_name(k) for k in names_with_port]
            sess_outputs = [sess.graph.get_tensor_by_name(n) for n in output_names_with_port]
            converter = tf.compat.v1.lite.TFLiteConverter.from_session(sess, sess_inputs, sess_outputs)

        # BUGFIX: converter.convert() was previously called twice back-to-back;
        # the second call discarded the first result and doubled conversion time.
        # Convert once, then write the flatbuffer to the test data directory.
        tflite_model = converter.convert()
        tflite_path = os.path.join(self.test_data_directory, self._testMethodName + ".tflite")
        dir_name = os.path.dirname(tflite_path)
        os.makedirs(dir_name, exist_ok=True)
        with open(tflite_path, 'wb') as f:
            f.write(tflite_model)

        tflite_graphs, opcodes_map, model = read_tflite_model(tflite_path)
        self.assertEqual(1, len(tflite_graphs))
        onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, _ = \
            parse_tflite_graph(tflite_graphs[0], opcodes_map, model)
        # tflite lowers matmul to FULLY_CONNECTED; the two scalar multiplies stay MUL.
        self.assertEqual(2, op_cnt['MUL'])
        self.assertEqual(1, op_cnt['ADD'])
        self.assertEqual(1, op_cnt['FULLY_CONNECTED'])

        # FULLY_CONNECTED carries a WeightsFormat option, counted once.
        self.assertEqual(1, attr_cnt['WeightsFormat'])
        self.assertEqual(names, inputs)
        self.assertEqual(output_names, outputs)

        for name, shape, dtype in zip(names, inp_shapes, inp_dtypes):
            self.assertEqual(shape, output_shapes[name])
            self.assertEqual(dtype, dtypes[name])

        self.assertTrue(len(onnx_nodes) >= 4)
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""
tf2onnx.tflite_utils - utilities for parsing tflite files into onnx graph
"""

import collections
import importlib

from onnx import helper, onnx_pb, numpy_helper
from tensorflow.core.framework import types_pb2, tensor_pb2
from tensorflow.python.framework import tensor_util
from tflite.TensorType import TensorType as TFLiteTensorType
from tflite.Model import Model


# Maps tflite tensor element types to the corresponding ONNX TensorProto dtypes.
TFLITE_TO_ONNX_DTYPE = {
    TFLiteTensorType.FLOAT32: onnx_pb.TensorProto.FLOAT,
    TFLiteTensorType.FLOAT16: onnx_pb.TensorProto.FLOAT16,
    TFLiteTensorType.INT32: onnx_pb.TensorProto.INT32,
    TFLiteTensorType.UINT8: onnx_pb.TensorProto.UINT8,
    TFLiteTensorType.INT64: onnx_pb.TensorProto.INT64,
    TFLiteTensorType.STRING: onnx_pb.TensorProto.STRING,
    TFLiteTensorType.BOOL: onnx_pb.TensorProto.BOOL,
    TFLiteTensorType.INT16: onnx_pb.TensorProto.INT16,
    TFLiteTensorType.COMPLEX64: onnx_pb.TensorProto.COMPLEX64,
    TFLiteTensorType.INT8: onnx_pb.TensorProto.INT8,
    TFLiteTensorType.FLOAT64: onnx_pb.TensorProto.DOUBLE,
    TFLiteTensorType.COMPLEX128: onnx_pb.TensorProto.COMPLEX128,
    TFLiteTensorType.UINT64: onnx_pb.TensorProto.UINT64,
}


# Maps tflite tensor element types to TF's DataType enum (used when decoding
# constant buffers through TensorProto below).
TFLITE_TO_TF_DTYPE = {
    TFLiteTensorType.FLOAT32: types_pb2.DT_FLOAT,
    TFLiteTensorType.FLOAT16: types_pb2.DT_HALF,
    TFLiteTensorType.INT32: types_pb2.DT_INT32,
    TFLiteTensorType.UINT8: types_pb2.DT_UINT8,
    TFLiteTensorType.INT64: types_pb2.DT_INT64,
    TFLiteTensorType.STRING: types_pb2.DT_STRING,
    TFLiteTensorType.BOOL: types_pb2.DT_BOOL,
    TFLiteTensorType.INT16: types_pb2.DT_INT16,
    TFLiteTensorType.COMPLEX64: types_pb2.DT_COMPLEX64,
    TFLiteTensorType.INT8: types_pb2.DT_INT8,
    TFLiteTensorType.FLOAT64: types_pb2.DT_DOUBLE,
    TFLiteTensorType.COMPLEX128: types_pb2.DT_COMPLEX128,
    TFLiteTensorType.UINT64: types_pb2.DT_UINT64,
}


def map_tflite_dtype_to_onnx(dtype):
    """Return the ONNX TensorProto dtype for a tflite TensorType value."""
    return TFLITE_TO_ONNX_DTYPE[dtype]


def map_tflite_dtype_to_tf(dtype):
    """Return the TF DataType enum for a tflite TensorType value."""
    return TFLITE_TO_TF_DTYPE[dtype]


# The tflite schema uses snake case, but the python bindings use proper case
def snake_to_proper_case(name):
    """'fused_activation_function' -> 'FusedActivationFunction'"""
    return ''.join(n.capitalize() for n in name.split('_'))


def proper_to_snake_case(name):
    """'FusedActivationFunction' -> 'fused_activation_function' (inverse of above)"""
    res = ''
    for c in name:
        # No leading underscore for the first character.
        if c.isupper() and res:
            res += '_'
        res += c.lower()
    return res

# Pulled from the tflite schema.fbs file. Needed to decode enum numbers into strings.
NODE_ATTR_NAME_TO_ENUM_TYPE = {
    'fused_activation_function': 'ActivationFunctionType',
    'padding': 'Padding',
    'type': 'LSHProjectionType',
    'weights_format': 'FullyConnectedOptionsWeightsFormat',
    'kernel_type': 'LSTMKernelType',
    'combiner': 'CombinerType',
    'in_data_type': 'TensorType',
    'out_data_type': 'TensorType',
    'output_type': 'TensorType',
    'out_type': 'TensorType',
    'mode': 'MirrorPadMode',
    'idx_out_type': 'TensorType',
}
# Re-key to proper case since the reflection below sees proper-case attribute names.
NODE_ATTR_NAME_TO_ENUM_TYPE = {snake_to_proper_case(key): value for key, value in NODE_ATTR_NAME_TO_ENUM_TYPE.items()}

# Pulled from the tflite schema.fbs file.
FUNCTION_ATTRS = ['then_subgraph_index', 'else_subgraph_index', 'cond_subgraph_index',
                  'body_subgraph_index', 'subgraph']
FUNCTION_ATTRS = [snake_to_proper_case(attr) for attr in FUNCTION_ATTRS]


# Cache of enum_name -> {index: value_name}, filled lazily by lookup_enum.
enum_cache = {}
def lookup_enum(idx, enum_name):
    """Given the name of a tflite enum class and an index, return a string with the name of the enum value"""
    # TensorType is special-cased: we return the ONNX dtype directly instead of a name string.
    if enum_name == 'TensorType':
        return map_tflite_dtype_to_onnx(idx)
    if enum_name in enum_cache:
        return enum_cache[enum_name][idx]
    # The tflite python bindings define each enum as a class of int class attributes;
    # invert that mapping once and cache it.
    module = importlib.import_module('tflite.' + enum_name)
    enum_class = getattr(module, enum_name)
    idx_to_name = {value: key for key, value in enum_class.__dict__.items() if not key.startswith('_')}
    enum_cache[enum_name] = idx_to_name
    return idx_to_name[idx]


def get_options_class(name):
    """Each tflite optype has a flatbuffer Options class (ex: AddOptions). Returns the options class given its name."""
    if name == "NONE":
        return None
    module = importlib.import_module('tflite.' + name)
    return getattr(module, name)


def read_tflite_model(tflite_path):
    """
    Given the path to a tflite model, returns tuple (tflite_graphs, opcodes_map, model)
    Pass these to parse_tflite_graph
    """
    with open(tflite_path, 'rb') as f:
        buf = f.read()
    buf = bytearray(buf)
    model = Model.GetRootAsModel(buf, 0)
    # To save space, each op in the model indicates its opcode as an index into the model's opcode map.
    opcodes_map = {}
    for i in range(model.OperatorCodesLength()):
        op_code = model.OperatorCodes(i)
        # TFlite ran out of opcodes since they only used a byte. Old models store opcodes in DeprecatedBuiltinCode.
        # New models put PLACEHOLDER_FOR_GREATER_OP_CODES in this field to signify that BuiltinCode should be used.
        code = lookup_enum(op_code.DeprecatedBuiltinCode(), 'BuiltinOperator')
        if code == 'PLACEHOLDER_FOR_GREATER_OP_CODES':
            code = lookup_enum(op_code.BuiltinCode(), 'BuiltinOperator')
        opcodes_map[i] = code
    tflite_graphs = [model.Subgraphs(i) for i in range(model.SubgraphsLength())]
    return tflite_graphs, opcodes_map, model


def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''):
    """
    Returns a Graph object along with some op count stats. All tflite op types are prefixed with "TFL_".
    Names of graph inputs are optionally prefixed with a string to prevent name conflicts in subgraphs.
    Quantizatized tensors are surrounded with quantize/dequantize ops
    """
    op_cnt = collections.Counter()
    attr_cnt = collections.Counter()
    onnx_nodes = []
    output_shapes = {}
    dtypes = {}
    tensor_names = {}
    # Map tensor name to tflite Tensor object so we can fetch quantization info as needed
    name_to_tensor = {}
    # If a node takes a quantized tensor as input, we must add a dequantize op after it.
    # Store a mapping so we only need to make at most one dequantize op per tensor.
    tensor_name_to_dequant_output = {}

    # tflite uses generic names (arg0, arg1, etc.) for inputs but full names for other tensors, so
    # prefixing just the inputs should be fine. Other tensors are prefixed when we do inlining.
    input_indices = {tflite_g.Inputs(i) for i in range(tflite_g.InputsLength())}

    for i in range(tflite_g.TensorsLength()):
        tensor = tflite_g.Tensors(i)
        name = tensor.Name().decode()
        if i in input_indices:
            name = input_prefix + name
        tensor_names[i] = name
        name_to_tensor[name] = tensor

        if tensor.ShapeIsNone():
            output_shapes[name] = None
        elif tensor.ShapeSignatureIsNone():
            # The shape signature uses -1 to signify unknown dims. Old models don't have this and use Shape instead.
            output_shapes[name] = tensor.ShapeAsNumpy().tolist()
        else:
            output_shapes[name] = tensor.ShapeSignatureAsNumpy().tolist()
        buf = model.Buffers(tensor.Buffer())
        dtypes[name] = map_tflite_dtype_to_onnx(tensor.Type())
        if not buf.DataIsNone():
            # For const values we use TF to decode the binary data from the buffer
            t = tensor_pb2.TensorProto()
            t.tensor_content = buf.DataAsNumpy().tobytes()
            if output_shapes[name] is None:
                output_shapes[name] = []
            for d in output_shapes[name]:
                t.tensor_shape.dim.add().size = d
            t.dtype = map_tflite_dtype_to_tf(tensor.Type())
            np_data = tensor_util.MakeNdarray(t)
            onnx_tensor = numpy_helper.from_array(np_data, name=name)
            onnx_node = helper.make_node("Const", [], outputs=[name], name=name, value=onnx_tensor)
            onnx_nodes.append(onnx_node)
            op_cnt["Const"] += 1

    def get_dequant(tensor_name):
        """Creates a dequantize op for the provided tensor if needed and returns the output of the op, or
        the original tensor name if no dequantization is needed"""
        quant = name_to_tensor[tensor_name].Quantization()
        if quant is None or quant.ScaleIsNone() or quant.ZeroPointIsNone():
            return tensor_name
        if tensor_name in tensor_name_to_dequant_output:
            return tensor_name_to_dequant_output[tensor_name]
        dequant_name = tensor_name + "_dequant"
        attr = {}
        attr['scale'] = quant.ScaleAsNumpy().tolist()
        attr['zero_point'] = quant.ZeroPointAsNumpy().tolist()
        attr['quantized_dimension'] = quant.QuantizedDimension()
        onnx_node = helper.make_node("TFL_DEQUANTIZE", [tensor_name], [dequant_name], name=dequant_name, **attr)
        onnx_nodes.append(onnx_node)
        tensor_name_to_dequant_output[tensor_name] = dequant_name
        # NOTE(review): assumes output_shapes[tensor_name] is a list here (quantized
        # tensors with a None shape would raise) -- confirm against real models.
        output_shapes[dequant_name] = output_shapes[tensor_name].copy()
        dtypes[dequant_name] = onnx_pb.TensorProto.FLOAT
        return dequant_name

    def get_prequant(tensor_name):
        """Called by nodes with the name of the tensor they must output.
        If the output is supposed to be quantized, creates a Quantize op outputting the tensor.
        Returns the name that should be used for the "prequantized" tensor, or the original tensor if no quantization
        is needed"""
        quant = name_to_tensor[tensor_name].Quantization()
        if quant is None or quant.ScaleIsNone() or quant.ZeroPointIsNone():
            return tensor_name
        prequant_name = tensor_name + "_prequant"
        quantize_name = tensor_name + "_quantize"
        attr = {}
        attr['scale'] = quant.ScaleAsNumpy().tolist()
        attr['zero_point'] = quant.ZeroPointAsNumpy().tolist()
        attr['quantized_dimension'] = quant.QuantizedDimension()
        onnx_node = helper.make_node("TFL_QUANTIZE", [prequant_name], [tensor_name], name=quantize_name, **attr)
        onnx_nodes.append(onnx_node)
        output_shapes[prequant_name] = output_shapes[tensor_name].copy()
        dtypes[prequant_name] = onnx_pb.TensorProto.FLOAT
        return prequant_name

    for i in range(tflite_g.OperatorsLength()):
        op = tflite_g.Operators(i)
        optype = opcodes_map[op.OpcodeIndex()]
        op_cnt[optype] += 1
        attr = {}
        options_type_name = lookup_enum(op.BuiltinOptionsType(), 'BuiltinOptions')
        option_class = get_options_class(options_type_name)
        wants_dequantized_input = True
        has_prequantized_output = True
        # QUANTIZE/DEQUANTIZE nodes keep their quantized side as-is and carry the
        # quantization params as attributes instead of getting wrapper ops.
        if optype == 'QUANTIZE':
            out_tensor = tflite_g.Tensors(op.Outputs(0))
            quant = out_tensor.Quantization()
            has_prequantized_output = False
            if quant is not None and not quant.ScaleIsNone() and not quant.ZeroPointIsNone():
                attr['scale'] = quant.ScaleAsNumpy().tolist()
                attr['zero_point'] = quant.ZeroPointAsNumpy().tolist()
                attr['quantized_dimension'] = quant.QuantizedDimension()
        elif optype == 'DEQUANTIZE':
            in_tensor = tflite_g.Tensors(op.Inputs(0))
            quant = in_tensor.Quantization()
            wants_dequantized_input = False
            if quant is not None and not quant.ScaleIsNone() and not quant.ZeroPointIsNone():
                attr['scale'] = quant.ScaleAsNumpy().tolist()
                attr['zero_point'] = quant.ZeroPointAsNumpy().tolist()
                attr['quantized_dimension'] = quant.QuantizedDimension()
        if option_class is not None:
            options = option_class()
            options.Init(op.BuiltinOptions().Bytes, op.BuiltinOptions().Pos)
            # All flatbuffer objects have these properties.
            block_list = [options_type_name + 'BufferHasIdentifier', 'Init', 'GetRootAs' + options_type_name]
            # The rest of the properties of the options class provide its attribute names
            attr_names = {opt for opt in dir(options) if not opt.startswith('_') and opt not in block_list}
            for a in list(attr_names):
                # Flatbufffer list properties have 3 functions: *Length, *IsNone, and *AsNumpy
                if a + 'Length' in attr_names:
                    attr_names.remove(a + 'Length')
                    attr_names.remove(a + 'IsNone')
                    attr_names.remove(a)
            for a in attr_names:
                if a.endswith('AsNumpy'):
                    value = getattr(options, a)().tolist()
                    a = a[:-len('AsNumpy')]
                else:
                    # For enums we use a string with the value name, not enum index
                    value = getattr(options, a)()
                    if a in NODE_ATTR_NAME_TO_ENUM_TYPE:
                        value = lookup_enum(value, NODE_ATTR_NAME_TO_ENUM_TYPE[a])
                    elif a in FUNCTION_ATTRS:
                        # Subgraph-index attrs are resolved to the subgraph's name.
                        value = model.Subgraphs(value).Name().decode()
                attr_cnt[a] += 1
                attr[proper_to_snake_case(a)] = value
        # -1 marks an optional tensor that is absent; skip those slots.
        input_names = [tensor_names[op.Inputs(i)] for i in range(op.InputsLength()) if op.Inputs(i) != -1]
        if wants_dequantized_input:
            input_names = [get_dequant(inp) for inp in input_names]
        output_names = [tensor_names[op.Outputs(i)] for i in range(op.OutputsLength()) if op.Outputs(i) != -1]
        if has_prequantized_output:
            output_names = [get_prequant(out) for out in output_names]
        onnx_node = helper.make_node("TFL_" + optype, input_names, output_names, name=output_names[0], **attr)
        onnx_nodes.append(onnx_node)

    inputs = [tensor_names[tflite_g.Inputs(i)] for i in range(tflite_g.InputsLength())]
    outputs = [tensor_names[tflite_g.Outputs(i)] for i in range(tflite_g.OutputsLength())]
    # TODO: Allow input/outputs to be overridden

    for inp in inputs:
        onnx_node = helper.make_node("Placeholder", [], outputs=[inp], name=inp)
        onnx_nodes.append(onnx_node)

    graph_name = (tflite_g.Name() or b'tflite graph').decode()
    return onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, graph_name