From b7347a449a47a62db07c74e5b036b2a58e7d4ab6 Mon Sep 17 00:00:00 2001 From: Tom Wildenhain Date: Mon, 21 Jun 2021 17:57:44 -0400 Subject: [PATCH] Fix issue with tf1 loops with tensor array read last pattern Signed-off-by: Tom Wildenhain --- .../azure_pipelines/keras2onnx_unit_test.yml | 12 ++-- tf2onnx/rewriter/loop_rewriter.py | 10 +++- tf2onnx/rewriter/loop_rewriter_base.py | 59 ++++++++++++++----- tf2onnx/rewriter/rnn_utils.py | 3 +- tf2onnx/utils.py | 3 + 5 files changed, 64 insertions(+), 23 deletions(-) diff --git a/ci_build/azure_pipelines/keras2onnx_unit_test.yml b/ci_build/azure_pipelines/keras2onnx_unit_test.yml index 3175d366d..b1206a553 100644 --- a/ci_build/azure_pipelines/keras2onnx_unit_test.yml +++ b/ci_build/azure_pipelines/keras2onnx_unit_test.yml @@ -40,12 +40,12 @@ jobs: INSTALL_ORT: pip install onnxruntime==1.8.0 ############ Pure Keras Unit Tests ############ - # Keras-Py36-tf1.15.0: # Failing, will enable soon. - # python.version: '3.6' - # ONNX_PATH: onnx==1.5.0 - # KERAS: keras==2.2.5 - # TENSORFLOW_PATH: tensorflow==1.15.0 - # INSTALL_ORT: pip install onnxruntime==1.8.0 + Keras-Py36-tf1.15.0: + python.version: '3.6' + ONNX_PATH: onnx==1.5.0 + KERAS: keras==2.2.5 + TENSORFLOW_PATH: tensorflow==1.15.0 + INSTALL_ORT: pip install onnxruntime==1.8.0 Keras-Py37-tf1.15.0: python.version: '3.7' diff --git a/tf2onnx/rewriter/loop_rewriter.py b/tf2onnx/rewriter/loop_rewriter.py index ca84945b6..a95fdb63d 100644 --- a/tf2onnx/rewriter/loop_rewriter.py +++ b/tf2onnx/rewriter/loop_rewriter.py @@ -30,7 +30,7 @@ def create_context(self): def run(self): logger.debug("enter loop rewriter") - return self.run_internal() + return self.run_internal(allow_ta_read_last=True) def need_rewrite(self, context): return True @@ -93,6 +93,10 @@ def rewrite(self, context): logger.error("failed to create loop node during rewrite") return REWRITER_RESULT.FAIL + for unneeded_scan_variable in loop_props.unneeded_scan_variables.values(): + self.g.replace_all_inputs(unneeded_scan_variable.exit_output.id, + unneeded_scan_variable.equivalent_state_variable.exit_output.id) + logger.debug("rewrite successfully") return REWRITER_RESULT.OK @@ -152,7 +156,9 @@ def _create_loop_node(self, context, loop_props, init_cond_output, branches=None n = self.g.get_node_by_output(tensor_value_info.id) self.g.remove_node(n.name) else: - loop_outputs.append(utils.make_name("unused_loop_output_")) + output_name = utils.make_name("unused_loop_output_") + tensor_value_info.id = output_name + loop_outputs.append(output_name) loop_output_shapes.append([-1]) loop_output_dtypes.append(None) diff --git a/tf2onnx/rewriter/loop_rewriter_base.py b/tf2onnx/rewriter/loop_rewriter_base.py index d49320f3a..99a74036f 100644 --- a/tf2onnx/rewriter/loop_rewriter_base.py +++ b/tf2onnx/rewriter/loop_rewriter_base.py @@ -10,7 +10,7 @@ from tf2onnx import utils from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher from tf2onnx.utils import is_tf_loopcond_op, is_tf_tensor_array_op -from tf2onnx.utils import is_tf_tensor_array_gather_op, is_tf_tensor_array_write_op +from tf2onnx.utils import is_tf_tensor_array_gather_op, is_tf_tensor_array_write_op, is_tf_tensor_array_read_op from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT from tf2onnx.utils import TensorValueInfo @@ -47,6 +47,7 @@ def __init__(self): # used as initial input for more than one Enter nodes. self.state_variables = OrderedDict() self.scan_variables = OrderedDict() + self.unneeded_scan_variables = OrderedDict() self.tensor_array_inputs = [] # list of type InputTensorArray @@ -55,10 +56,14 @@ def add_variable(self, var): "variable %s already exists as scan variable.", var.enter_name) utils.make_sure(var.enter_name not in self.state_variables, "variable %s already exists as state variable.", var.enter_name) - if not var.is_tensor_array: - self.state_variables[var.enter_name] = var - else: + if var.tensor_array_type == TensorArrayVariableType.READ_LAST: + # If the variable just returns the last value of the constructed tensor array, it doesn't need to be + # a scan output + self.unneeded_scan_variables[var.enter_name] = var + elif var.tensor_array_type == TensorArrayVariableType.GATHER_ALL: self.scan_variables[var.enter_name] = var + else: + self.state_variables[var.enter_name] = var def get_variables(self, checker): if not checker: @@ -69,6 +74,7 @@ def get_variables(self, checker): def all_variables(self): items = self.state_variables.copy() items.update(self.scan_variables) + items.update(self.unneeded_scan_variables) return items # state inputs and outputs are in pairs, even though some outputs are not depending on corresponding input, @@ -111,6 +117,16 @@ def scan_inputs(self): def scan_inputs_initial_values(self): return [i.data_input_id for i in self.tensor_array_inputs] + def has_variable_with_ta_type(self, tensor_array_type): + for variable in self.all_variables.values(): + if variable.tensor_array_type == tensor_array_type: + return True + return False + +class TensorArrayVariableType: + GATHER_ALL = "GATHER_ALL" + READ_LAST = "READ_LAST" + class LoopVariable(object): """In TensorFlow loop, all loop variables are listed both in iteration body graph's inputs, and outputs. Loop (state variable 1, state variable 2) { @@ -131,7 +147,7 @@ class LoopVariable(object): (e.g. switch_true_identity_output.id). """ def __init__(self, enter_name, enter_input_id, next_iteration_input_id, - switch_true_identity_output_id, exit_output_id, is_tensor_array, ta_index_id, g): + switch_true_identity_output_id, exit_output_id, tensor_array_type, ta_index_id, g): self.enter_name = enter_name self.enter_input_id = enter_input_id @@ -150,7 +166,7 @@ def __init__(self, enter_name, enter_input_id, next_iteration_input_id, self.exit_output = TensorValueInfo(exit_output_id, g) # only applicable for tensor array variable - self.is_tensor_array = is_tensor_array + self.tensor_array_type = tensor_array_type # todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration. # then we can be sure this is equivalent to scan output behavior. self.ta_index_id = ta_index_id @@ -189,7 +205,7 @@ def need_rewrite(self, context): def rewrite(self, context): return REWRITER_RESULT.FAIL - def run_internal(self): + def run_internal(self, allow_ta_read_last=False): loopcond_ops = [] for op in self.g.get_nodes(): if is_tf_loopcond_op(op): @@ -201,7 +217,11 @@ def run_internal(self): context = self.create_context() context.loop_cond = op - self._check_in_read_only_mode(context) + self._check_in_read_only_mode(context) # parses loop variables + + loop_properties = context.loop_properties + if not allow_ta_read_last and loop_properties.has_variable_with_ta_type(TensorArrayVariableType.READ_LAST): + continue if self.need_rewrite(context): # cut off connection between cell/cond graphs and useless nodes like Merge, NextIteration. @@ -241,6 +261,12 @@ def _parse_loop_variables(self, context): loop_var = self._get_loop_var_from_switch(s) context.loop_properties.add_variable(loop_var) + for unneeded_scan_variable in context.loop_properties.unneeded_scan_variables.values(): + for state_variable in context.loop_properties.state_variables.values(): + if unneeded_scan_variable.next_iteration_input.id == state_variable.next_iteration_input.id: + unneeded_scan_variable.equivalent_state_variable = state_variable + break + def _parse_input_ta(self, context): graph_inputs = [v.switch_true_identity_output.id for v in context.loop_properties.all_variables.values() if v.switch_true_identity_output.id] @@ -313,7 +339,7 @@ def _cut_off_connection_for_cell(self, context): n = self.g.get_node_by_output(val.switch_true_identity_output.id) self.g.remove_node(n.name) - if val.is_tensor_array: + if val.tensor_array_type == TensorArrayVariableType.GATHER_ALL: # connect NextIteration to an invalid node, to cut off an ending node of the cell. ta_write_nodes = [n for n in self.g.get_nodes() if is_tf_tensor_array_write_op(n)] self.g.replace_all_inputs(val.next_iteration_input.id, INVALID_INPUT_ID, ops=ta_write_nodes) @@ -382,10 +408,9 @@ def _get_loop_var_from_switch(self, switch_node): else: raise ValueError("unexpected number of switch false consumers") - is_ta = False + ta_type = None ta_index_id = None if is_tf_tensor_array_op(self.g.get_node_by_output(target_node_input_id)): - is_ta = True ta_write_node = self.g.get_node_by_output(last_iteration_output_id) utils.make_sure(is_tf_tensor_array_write_op(ta_write_node), "ta nextiteration is not following ta write op") @@ -396,13 +421,19 @@ def _get_loop_var_from_switch(self, switch_node): # ta.write(), then ta.stack(), because this is the most frequent usage pattern. if exit_output_id: exit_consumers = self.g.find_output_consumers(exit_output_id) - ta_gather_node = [n for n in exit_consumers if is_tf_tensor_array_gather_op(n)][0] + ta_access_node = [n for n in exit_consumers if is_tf_tensor_array_gather_op(n) or \ + is_tf_tensor_array_read_op(n)][0] + + if is_tf_tensor_array_read_op(ta_access_node): + ta_type = TensorArrayVariableType.READ_LAST + else: + ta_type = TensorArrayVariableType.GATHER_ALL # update exit output id, treat the gather output as ta's output - exit_output_id = ta_gather_node.output[0] + exit_output_id = ta_access_node.output[0] loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id, - switch_true_identity_output, exit_output_id, is_ta, ta_index_id, self.g) + switch_true_identity_output, exit_output_id, ta_type, ta_index_id, self.g) return loop_var diff --git a/tf2onnx/rewriter/rnn_utils.py b/tf2onnx/rewriter/rnn_utils.py index 8abd6a9c7..93bdc1c79 100644 --- a/tf2onnx/rewriter/rnn_utils.py +++ b/tf2onnx/rewriter/rnn_utils.py @@ -309,6 +309,7 @@ def parse_rnn_loop(graph, loop_properties, rnn_scope, while_context_scope): 1. iteration counter does not exist in tf1.4 or earlier versions 2. if dynamic_rnn's first input is not consumed, output ta does not exist. """ + from tf2onnx.rewriter.loop_rewriter_base import TensorArrayVariableType # pylint: disable=import-outside-toplevel time_name = rnn_scope + "time" ta_array_name_prefix = rnn_scope + "dynamic_rnn/output_" iteration_counter_name = while_context_scope + "iteration_counter" @@ -319,7 +320,7 @@ def parse_rnn_loop(graph, loop_properties, rnn_scope, while_context_scope): iteration_var = None for val in loop_properties.all_variables.values(): enter_input_node = graph.get_node_by_output(val.enter_input_id) - if val.is_tensor_array: + if val.tensor_array_type == TensorArrayVariableType.GATHER_ALL: ta_name = enter_input_node.get_attr("tensor_array_name").s.decode("utf-8") if not ta_name.startswith(ta_array_name_prefix): is_rnn_out_ta = False diff --git a/tf2onnx/utils.py b/tf2onnx/utils.py index e609d7cf7..095ac6e5e 100644 --- a/tf2onnx/utils.py +++ b/tf2onnx/utils.py @@ -501,6 +501,9 @@ def is_tf_tensor_array_gather_op(op): def is_tf_tensor_array_write_op(op): return op.type in ("TensorArrayWriteV2", "TensorArrayWriteV3") +def is_tf_tensor_array_read_op(op): + return op.type in ("TensorArrayReadV2", "TensorArrayReadV3") + def is_tf_tensor_array_op(op): return op.type in ("TensorArrayV2", "TensorArrayV3")