Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions ci_build/azure_pipelines/keras2onnx_unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ jobs:
INSTALL_ORT: pip install onnxruntime==1.8.0

############ Pure Keras Unit Tests ############
# Keras-Py36-tf1.15.0: # Failing, will enable soon.
# python.version: '3.6'
# ONNX_PATH: onnx==1.5.0
# KERAS: keras==2.2.5
# TENSORFLOW_PATH: tensorflow==1.15.0
# INSTALL_ORT: pip install onnxruntime==1.8.0
Keras-Py36-tf1.15.0:
python.version: '3.6'
ONNX_PATH: onnx==1.5.0
KERAS: keras==2.2.5
TENSORFLOW_PATH: tensorflow==1.15.0
INSTALL_ORT: pip install onnxruntime==1.8.0

Keras-Py37-tf1.15.0:
python.version: '3.7'
Expand Down
10 changes: 8 additions & 2 deletions tf2onnx/rewriter/loop_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def create_context(self):

def run(self):
logger.debug("enter loop rewriter")
return self.run_internal()
return self.run_internal(allow_ta_read_last=True)

def need_rewrite(self, context):
return True
Expand Down Expand Up @@ -93,6 +93,10 @@ def rewrite(self, context):
logger.error("failed to create loop node during rewrite")
return REWRITER_RESULT.FAIL

for unneeded_scan_variable in loop_props.unneeded_scan_variables.values():
self.g.replace_all_inputs(unneeded_scan_variable.exit_output.id,
unneeded_scan_variable.equivalent_state_variable.exit_output.id)

logger.debug("rewrite successfully")
return REWRITER_RESULT.OK

Expand Down Expand Up @@ -152,7 +156,9 @@ def _create_loop_node(self, context, loop_props, init_cond_output, branches=None
n = self.g.get_node_by_output(tensor_value_info.id)
self.g.remove_node(n.name)
else:
loop_outputs.append(utils.make_name("unused_loop_output_"))
output_name = utils.make_name("unused_loop_output_")
tensor_value_info.id = output_name
loop_outputs.append(output_name)
loop_output_shapes.append([-1])
loop_output_dtypes.append(None)

Expand Down
59 changes: 45 additions & 14 deletions tf2onnx/rewriter/loop_rewriter_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tf2onnx import utils
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
from tf2onnx.utils import is_tf_loopcond_op, is_tf_tensor_array_op
from tf2onnx.utils import is_tf_tensor_array_gather_op, is_tf_tensor_array_write_op
from tf2onnx.utils import is_tf_tensor_array_gather_op, is_tf_tensor_array_write_op, is_tf_tensor_array_read_op
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT
from tf2onnx.utils import TensorValueInfo

Expand Down Expand Up @@ -47,6 +47,7 @@ def __init__(self):
# used as initial input for more than one Enter nodes.
self.state_variables = OrderedDict()
self.scan_variables = OrderedDict()
self.unneeded_scan_variables = OrderedDict()

self.tensor_array_inputs = [] # list of type InputTensorArray

Expand All @@ -55,10 +56,14 @@ def add_variable(self, var):
"variable %s already exists as scan variable.", var.enter_name)
utils.make_sure(var.enter_name not in self.state_variables,
"variable %s already exists as state variable.", var.enter_name)
if not var.is_tensor_array:
self.state_variables[var.enter_name] = var
else:
if var.tensor_array_type == TensorArrayVariableType.READ_LAST:
# If the variable just returns the last value of the constructed tensor array, it doesn't need to be
# a scan output
self.unneeded_scan_variables[var.enter_name] = var
elif var.tensor_array_type == TensorArrayVariableType.GATHER_ALL:
self.scan_variables[var.enter_name] = var
else:
self.state_variables[var.enter_name] = var

def get_variables(self, checker):
if not checker:
Expand All @@ -69,6 +74,7 @@ def get_variables(self, checker):
def all_variables(self):
items = self.state_variables.copy()
items.update(self.scan_variables)
items.update(self.unneeded_scan_variables)
return items

# state inputs and outputs are in pairs, even though some outputs are not depending on corresponding input,
Expand Down Expand Up @@ -111,6 +117,16 @@ def scan_inputs(self):
def scan_inputs_initial_values(self):
return [i.data_input_id for i in self.tensor_array_inputs]

def has_variable_with_ta_type(self, tensor_array_type):
    """Return True if any loop variable (state, scan, or unneeded-scan) has the
    given tensor-array type, False otherwise."""
    return any(v.tensor_array_type == tensor_array_type
               for v in self.all_variables.values())

class TensorArrayVariableType:
    """Enumerates how a loop's TensorArray variable is consumed after the loop exits."""
    GATHER_ALL = "GATHER_ALL"  # exit output is consumed by a TensorArray gather op: all iterations' values are needed
    READ_LAST = "READ_LAST"  # exit output is consumed by a TensorArray read op: only the final value is needed

class LoopVariable(object):
"""In TensorFlow loop, all loop variables are listed both in iteration body graph's inputs, and outputs.
Loop (state variable 1, state variable 2) {
Expand All @@ -131,7 +147,7 @@ class LoopVariable(object):
(e.g. switch_true_identity_output.id).
"""
def __init__(self, enter_name, enter_input_id, next_iteration_input_id,
switch_true_identity_output_id, exit_output_id, is_tensor_array, ta_index_id, g):
switch_true_identity_output_id, exit_output_id, tensor_array_type, ta_index_id, g):
self.enter_name = enter_name
self.enter_input_id = enter_input_id

Expand All @@ -150,7 +166,7 @@ def __init__(self, enter_name, enter_input_id, next_iteration_input_id,
self.exit_output = TensorValueInfo(exit_output_id, g)

# only applicable for tensor array variable
self.is_tensor_array = is_tensor_array
self.tensor_array_type = tensor_array_type
# todo: need check ta's index variable is a scalar starting from 1, and increase by 1 each iteration.
# then we can be sure this is equivalent to scan output behavior.
self.ta_index_id = ta_index_id
Expand Down Expand Up @@ -189,7 +205,7 @@ def need_rewrite(self, context):
def rewrite(self, context):
return REWRITER_RESULT.FAIL

def run_internal(self):
def run_internal(self, allow_ta_read_last=False):
loopcond_ops = []
for op in self.g.get_nodes():
if is_tf_loopcond_op(op):
Expand All @@ -201,7 +217,11 @@ def run_internal(self):
context = self.create_context()
context.loop_cond = op

self._check_in_read_only_mode(context)
self._check_in_read_only_mode(context) # parses loop variables

loop_properties = context.loop_properties
if not allow_ta_read_last and loop_properties.has_variable_with_ta_type(TensorArrayVariableType.READ_LAST):
continue

if self.need_rewrite(context):
# cut off connection between cell/cond graphs and useless nodes like Merge, NextIteration.
Expand Down Expand Up @@ -241,6 +261,12 @@ def _parse_loop_variables(self, context):
loop_var = self._get_loop_var_from_switch(s)
context.loop_properties.add_variable(loop_var)

for unneeded_scan_variable in context.loop_properties.unneeded_scan_variables.values():
for state_variable in context.loop_properties.state_variables.values():
if unneeded_scan_variable.next_iteration_input.id == state_variable.next_iteration_input.id:
unneeded_scan_variable.equivalent_state_variable = state_variable
break

def _parse_input_ta(self, context):
graph_inputs = [v.switch_true_identity_output.id for v in context.loop_properties.all_variables.values()
if v.switch_true_identity_output.id]
Expand Down Expand Up @@ -313,7 +339,7 @@ def _cut_off_connection_for_cell(self, context):
n = self.g.get_node_by_output(val.switch_true_identity_output.id)
self.g.remove_node(n.name)

if val.is_tensor_array:
if val.tensor_array_type == TensorArrayVariableType.GATHER_ALL:
# connect NextIteration to an invalid node, to cut off an ending node of the cell.
ta_write_nodes = [n for n in self.g.get_nodes() if is_tf_tensor_array_write_op(n)]
self.g.replace_all_inputs(val.next_iteration_input.id, INVALID_INPUT_ID, ops=ta_write_nodes)
Expand Down Expand Up @@ -382,10 +408,9 @@ def _get_loop_var_from_switch(self, switch_node):
else:
raise ValueError("unexpected number of switch false consumers")

is_ta = False
ta_type = None
ta_index_id = None
if is_tf_tensor_array_op(self.g.get_node_by_output(target_node_input_id)):
is_ta = True

ta_write_node = self.g.get_node_by_output(last_iteration_output_id)
utils.make_sure(is_tf_tensor_array_write_op(ta_write_node), "ta nextiteration is not following ta write op")
Expand All @@ -396,13 +421,19 @@ def _get_loop_var_from_switch(self, switch_node):
# ta.write(), then ta.stack(), because this is the most frequent usage pattern.
if exit_output_id:
exit_consumers = self.g.find_output_consumers(exit_output_id)
ta_gather_node = [n for n in exit_consumers if is_tf_tensor_array_gather_op(n)][0]
ta_access_node = [n for n in exit_consumers if is_tf_tensor_array_gather_op(n) or \
is_tf_tensor_array_read_op(n)][0]

if is_tf_tensor_array_read_op(ta_access_node):
ta_type = TensorArrayVariableType.READ_LAST
else:
ta_type = TensorArrayVariableType.GATHER_ALL

# update exit output id, treat the gather output as ta's output
exit_output_id = ta_gather_node.output[0]
exit_output_id = ta_access_node.output[0]

loop_var = LoopVariable(enter_node.name, target_node_input_id, last_iteration_output_id,
switch_true_identity_output, exit_output_id, is_ta, ta_index_id, self.g)
switch_true_identity_output, exit_output_id, ta_type, ta_index_id, self.g)

return loop_var

Expand Down
3 changes: 2 additions & 1 deletion tf2onnx/rewriter/rnn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def parse_rnn_loop(graph, loop_properties, rnn_scope, while_context_scope):
1. iteration counter does not exist in tf1.4 or earlier versions
2. if dynamic_rnn's first input is not consumed, output ta does not exist.
"""
from tf2onnx.rewriter.loop_rewriter_base import TensorArrayVariableType # pylint: disable=import-outside-toplevel
time_name = rnn_scope + "time"
ta_array_name_prefix = rnn_scope + "dynamic_rnn/output_"
iteration_counter_name = while_context_scope + "iteration_counter"
Expand All @@ -319,7 +320,7 @@ def parse_rnn_loop(graph, loop_properties, rnn_scope, while_context_scope):
iteration_var = None
for val in loop_properties.all_variables.values():
enter_input_node = graph.get_node_by_output(val.enter_input_id)
if val.is_tensor_array:
if val.tensor_array_type == TensorArrayVariableType.GATHER_ALL:
ta_name = enter_input_node.get_attr("tensor_array_name").s.decode("utf-8")
if not ta_name.startswith(ta_array_name_prefix):
is_rnn_out_ta = False
Expand Down
3 changes: 3 additions & 0 deletions tf2onnx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,9 @@ def is_tf_tensor_array_gather_op(op):
def is_tf_tensor_array_write_op(op):
    """Return True when *op* is a TensorFlow TensorArray write op (V2 or V3)."""
    return op.type in {"TensorArrayWriteV2", "TensorArrayWriteV3"}

def is_tf_tensor_array_read_op(op):
    """Return True when *op* is a TensorFlow TensorArray read op (V2 or V3)."""
    return op.type == "TensorArrayReadV2" or op.type == "TensorArrayReadV3"


def is_tf_tensor_array_op(op):
    """Return True when *op* is a TensorFlow TensorArray construction op (V2 or V3)."""
    return any(op.type == candidate for candidate in ("TensorArrayV2", "TensorArrayV3"))
Expand Down