From 748f4916ac32e4611494b08ef88fd8260f36d83d Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Thu, 21 Jan 2021 23:59:28 -0500 Subject: [PATCH] Optimization for tflite loops Signed-off-by: Tom Wildenhain --- tf2onnx/graph.py | 3 + tf2onnx/tflite_handlers/tfl_controlflow.py | 46 +++++- tf2onnx/tflite_rewriters/__init__.py | 7 + .../tfl_scan_output_rewriter.py | 156 ++++++++++++++++++ tf2onnx/tflite_utils.py | 2 +- 5 files changed, 211 insertions(+), 3 deletions(-) create mode 100644 tf2onnx/tflite_rewriters/__init__.py create mode 100644 tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py diff --git a/tf2onnx/graph.py b/tf2onnx/graph.py index ef2a96f43..efff0a286 100644 --- a/tf2onnx/graph.py +++ b/tf2onnx/graph.py @@ -465,6 +465,9 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No self.graph_name = graph_name or "tf2onnx" self._is_subgraph = is_subgraph self.ta_reads = [] + # A list of index, output tuples of potential scan outputs in this graph + # Used by the tflite while loop handler + self.scan_outputs = [] self.func_inputs = [] self._target = set(target) diff --git a/tf2onnx/tflite_handlers/tfl_controlflow.py b/tf2onnx/tflite_handlers/tfl_controlflow.py index e83893eb6..2a9c1c39d 100644 --- a/tf2onnx/tflite_handlers/tfl_controlflow.py +++ b/tf2onnx/tflite_handlers/tfl_controlflow.py @@ -12,6 +12,7 @@ from tf2onnx.handler import tfl_op from tf2onnx import utils from tf2onnx.tf_loader import find_function +from tf2onnx.graph_builder import GraphBuilder from tf2onnx.onnx_opset.controlflow import parameter_binding, inline_subgraph @@ -40,6 +41,19 @@ def version_7(cls, ctx, node, **kwargs): cond_binding = parameter_binding(cond_graph, tfl_while_inputs) cond_outputs = inline_subgraph(ctx, cond_graph, cond_name, cond_binding) + # Potential scan output candidates are identified in the body subgraph using tfl_scan_output_rewriter. + # They can then be optimized in this tfl loop handler provided they are not used in the cond subgraph. + scan_outputs = sorted(body.scan_outputs, reverse=True) + def input_is_unused(g, index): + return len(g.find_output_consumers(g.func_inputs[index])) == 0 + scan_outputs = [(i, out) for i, out in scan_outputs if input_is_unused(cond_graph, i)] + + for idx, _ in scan_outputs: + del tfl_while_inputs[idx] + output_shapes.append(output_shapes.pop(idx)) + output_dtypes.append(output_dtypes.pop(idx)) + output_names.append(output_names.pop(idx)) + max_iterations = ctx.make_const(utils.make_name("max_iterations"), np.array(np.iinfo(np.int64).max)) loop_node = ctx.make_node("Loop", [max_iterations.output[0], cond_outputs[0]] + tfl_while_inputs, @@ -52,15 +66,21 @@ def version_7(cls, ctx, node, **kwargs): for k, v in output_map.items(): ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes() - body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph) + body = wire_tfl_while_body(body, loop_node.inputs, output_shapes, output_dtypes, cond_graph, scan_outputs) + + for i in range(len(scan_outputs)): + squeeze_node = GraphBuilder(body).make_squeeze( + {'data': body.outputs[-1-i], "axes": [0]}, return_node=True) + body.outputs[-1-i] = squeeze_node.output[0] loop_node.set_body_graph_as_attr("body", body) def wire_tfl_while_body(g, loop_node_inputs, output_shapes, - output_dtypes, cond_graph): + output_dtypes, cond_graph, scan_outputs): """Wire subgraph graph into main.""" g = copy.deepcopy(g) + graph_inputs = g.func_inputs.copy() # onnx will pass in cond as argument iter_node = g.make_node("Placeholder", [], name=utils.make_name("iteration_num"), @@ -69,6 +89,28 @@ def wire_tfl_while_body(g, loop_node_inputs, output_shapes, output_count=1, dtypes=[TensorProto.BOOL], shapes=[[]]) cond_binding = parameter_binding(cond_graph, g.outputs) + to_remove = set() + for idx, scan_output in scan_outputs: + inp = g.get_node_by_output(graph_inputs[idx]) + + # Remove consumers of scan input + stack = [inp] + while stack: + node = stack.pop() + if node not in to_remove: + to_remove.add(node) + for out in node.output: + stack += g.find_output_consumers(out) + + # Remove scan input from cond graph + cond_binding = {k: "@@ALLOC" if v == g.outputs[idx] else v for k, v in cond_binding.items()} + del g.func_inputs[idx] + del g.outputs[idx] + g.outputs.append(scan_output) + + for node in to_remove: + g.remove_node(node.name) + # in onnx the body inputs are: index, cond, [loop_vars] g.func_inputs = [iter_node.output[0], cond_node.output[0]] + g.func_inputs # tell graph lib to keep inputs in order diff --git a/tf2onnx/tflite_rewriters/__init__.py b/tf2onnx/tflite_rewriters/__init__.py new file mode 100644 index 000000000..48608aa29 --- /dev/null +++ b/tf2onnx/tflite_rewriters/__init__.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""tf2onnx.tflite_rewriters module""" + +from . import ( + tfl_scan_output_rewriter +) diff --git a/tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py b/tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py new file mode 100644 index 000000000..5cb1223b0 --- /dev/null +++ b/tf2onnx/tflite_rewriters/tfl_scan_output_rewriter.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 + + +""" +tf2onnx.tflite_rewriters.tfl_scan_output_rewriter - Identify a common slice/concat pattern in tflite subgraphs +Effectively replace A = A[:i] + [B] + A[i+1:] with A[i] = B +""" +import numpy as np + +from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher + + +# pylint: disable=missing-docstring + +def rewrite_slice_concat_to_scatter(g, ops): + pattern0 = \ + OpTypePattern('TFL_CONCATENATION', name='concat', inputs=[ + OpTypePattern('TFL_SLICE', name='begin_slice'), + OpTypePattern('*', name='middle'), + OpTypePattern('TFL_SLICE', name='end_slice') + ]) + + matcher = GraphMatcher(pattern0, allow_reorder=False) + match_results = list(matcher.match_ops(ops)) + if match_results: + for match in match_results: + concat = match.get_op("concat") + begin_slice = match.get_op("begin_slice") + middle = match.get_op("middle") + end_slice = match.get_op("end_slice") + middle_shape = g.get_shape(middle.output[0]) + + # Both slices must be slicing the same tensor + if begin_slice.input[0] != end_slice.input[0]: + continue + original_tensor = begin_slice.input[0] + if concat.get_attr_int("axis") != 0: + continue + # The inserted slice must have length 1 (to be a single index) + if middle_shape is None or len(middle_shape) == 0 or middle_shape[0] != 1: + continue + rank = len(middle_shape) + scan_output = middle.output[0] + if not begin_slice.inputs[1].is_const() or not end_slice.inputs[2].is_const(): + continue + # The first slice must start from the beginning (0) for all dims + if not all(v == 0 for v in begin_slice.inputs[1].get_tensor_value()): + continue + # The second slice must slice to the end (-1) for all dims + if not all(v == -1 for v in end_slice.inputs[2].get_tensor_value()): + continue + # The other slice dims are assembled by concatenation if rank > 1 + if rank > 1: + begin_concat = begin_slice.inputs[2] + end_concat = end_slice.inputs[1] + if not begin_concat.type == "TFL_CONCATENATION": + continue + if not end_concat.type == "TFL_CONCATENATION": + continue + # Except for dim 0, slice from beginning to end + if not all(get_uniform_const_val(inp) == -1 for inp in begin_concat.inputs[1:]): + continue + if not all(get_uniform_const_val(inp) == 0 for inp in end_concat.inputs[1:]): + continue + begin_idx = begin_concat.inputs[0] + end_idx = end_concat.inputs[0] + else: + begin_idx = begin_slice.inputs[2] + end_idx = end_slice.inputs[1] + # For dim 0, slice to i for first part and from i+1 for second + if not node_is_one_plus_node(begin_idx, end_idx): + continue + out1, _ = get_out_and_offset(begin_idx) + graph_inps = [n.output[0] for n in g.inputs] + # To be a scan output, i must be a graph input + if out1 not in graph_inps: + continue + # The array being sliced must be a graph input + if original_tensor not in graph_inps: + continue + # The input/output index of i + idx = graph_inps.index(out1) + # The input/output index of the array + scan_output_idx = graph_inps.index(original_tensor) + # For a scan output, i must be assigned to i+1 with each iteration + if not node_is_one_plus_node(g.get_node_by_output(out1), g.get_node_by_output(g.outputs[idx])): + continue + if len(g.find_output_consumers(concat.output[0])) > 1: + continue + + if g.opset < 10 and len(g.find_output_consumers(concat.output[0])) <= 1: + # If opset is < 10, conversion of the subgraph will fail unless we remove the slice nodes + # We add a tmp node to replace them. + shape = g.get_shape(concat.output[0]) + dtype = g.get_dtype(concat.output[0]) + tmp_node = g.make_node("TMP_SCAN_OUTPUT", [original_tensor, scan_output], + shapes=[shape], dtypes=[dtype]) + g.replace_all_inputs(concat.output[0], tmp_node.output[0]) + + to_remove = [] + out = g.outputs[scan_output_idx] + node = g.get_node_by_output(out) + to_remove.append(node) + + while len(node.input) > 0 and node != concat: + out = node.input[0] + node = g.get_node_by_output(out) + to_remove.append(node) + + to_remove += [begin_slice, end_slice, concat] + + out = original_tensor + node = g.get_node_by_output(out) + to_remove.append(node) + + while len(node.input) > 0: + out = node.input[0] + node = g.get_node_by_output(out) + to_remove.append(node) + + if not g.is_safe_to_remove_nodes(to_remove): + continue + + g.scan_outputs.append((scan_output_idx, scan_output)) + return ops + +def get_uniform_const_val(n): + if not n.is_const(): + return None + v = n.get_tensor_value(as_list=False).flatten() + if len(v) == 0: + return None + if np.all(v == v[0]): + return v[0] + return None + +def get_out_and_offset(n): + if n.type in ['TFL_RESHAPE', 'TFL_IDENTITY', 'Identity']: + return get_out_and_offset(n.inputs[0]) + if n.type == 'TFL_ADD': + v1 = get_uniform_const_val(n.inputs[0]) + v2 = get_uniform_const_val(n.inputs[1]) + if v1 is not None and v2 is not None: + return '', v1 + v2 + if v1 is not None: + inp2, o2 = get_out_and_offset(n.inputs[1]) + return inp2, v1 + o2 + if v2 is not None: + inp1, o1 = get_out_and_offset(n.inputs[0]) + return inp1, v2 + o1 + return n.output[0], 0 + +def node_is_one_plus_node(node, one_plus_node): + n1, o1 = get_out_and_offset(node) + n2, o2 = get_out_and_offset(one_plus_node) + return n1 == n2 and o1 + 1 == o2 diff --git a/tf2onnx/tflite_utils.py b/tf2onnx/tflite_utils.py index 5ee6224a6..82cd2f5d8 100644 --- a/tf2onnx/tflite_utils.py +++ b/tf2onnx/tflite_utils.py @@ -180,7 +180,7 @@ def parse_tflite_graph(tflite_g, opcodes_map, model, input_prefix=''): output_shapes[name] = tensor.ShapeSignatureAsNumpy().tolist() buf = model.Buffers(tensor.Buffer()) dtypes[name] = map_tflite_dtype_to_onnx(tensor.Type()) - if not buf.DataIsNone(): + if not buf.DataIsNone() and tensor.Buffer() > 0: # For const values we use TF to decode the binary data from the buffer t = tensor_pb2.TensorProto() t.tensor_content = buf.DataAsNumpy().tobytes()