From 7bf32d0f26f4a8ca14947083269329f705e4646b Mon Sep 17 00:00:00 2001
From: Tom Wildenhain
Date: Thu, 8 Apr 2021 17:01:39 -0400
Subject: [PATCH] Fix TensorListStack and Loop shapes

Signed-off-by: Tom Wildenhain
---
 tf2onnx/onnx_opset/controlflow.py | 23 +++++++++++++++++------
 tf2onnx/onnx_opset/tensor.py      |  8 ++++----
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py
index 81794e3b8..533d1c846 100644
--- a/tf2onnx/onnx_opset/controlflow.py
+++ b/tf2onnx/onnx_opset/controlflow.py
@@ -322,9 +322,14 @@ def version_7(cls, ctx, node, **kwargs):
 class TensorListStack:
     @classmethod
     def version_7(cls, ctx, node, **kwargs):
-        if node.inputs[0].is_while():
-            ctx.remove_node(node.name)
-            ctx.replace_all_inputs(node.output[0], node.input[0])  # ops=ctx.get_nodes()
+        inp_node = node.inputs[0]
+        inp = node.input[0]
+        while inp_node.type == "Identity":
+            inp = inp_node.input[0]
+            inp_node = inp_node.inputs[0]
+        utils.make_sure(inp_node.is_while(), "Can only convert TensorListStack that is part of a While loop")
+        ctx.remove_node(node.name)
+        ctx.replace_all_inputs(node.output[0], inp)
 
 
 @tf_op(["While", "StatelessWhile"])
@@ -463,7 +468,7 @@ def version_7(cls, ctx, node, **kwargs):
         for k, v in output_map.items():
            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()
 
-        wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
+        wire_while_body(ctx, body, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
                         output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names)
 
         # if there was a tensorflow variant type, bind in a real type here
@@ -473,7 +478,7 @@ def version_7(cls, ctx, node, **kwargs):
                 body.set_dtype(n.output[0], ctx.get_dtype(loop_node.input[i]))
 
 
-def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
+def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
                     output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names):
     """Wire subgraph graph into main."""
     remove_parents = []
@@ -496,7 +501,7 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
     g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64)
     g.inputs = [g.get_node_by_output(inp) for inp in func_inputs]
 
-    for p, c in zip(loop_node_inputs, func_inputs):
+    for p, c in zip(loop_node.inputs, func_inputs):
         shape = p.output_shapes[0]
         g.set_shape(c, shape)
 
@@ -534,6 +539,12 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
 
     # Reorder scan outputs
     scan_outputs = [names_to_scan_outputs[name] for name in scan_output_names]
+    for i in range(-len(scan_output_names), 0):
+        # Use shapes from subgraph if loop node shapes for scan outputs are missing
+        if loop_node.output_shapes[i] is None:
+            shape = g.get_shape(scan_outputs[i])
+            if shape is not None:
+                parent_g.set_shape(loop_node.output[i], [-1] + shape)
 
     # remove all nodes feeding to TensorListSetItem's reserved tensor
     while remove_parents:
diff --git a/tf2onnx/onnx_opset/tensor.py b/tf2onnx/onnx_opset/tensor.py
index ebb7e6eb1..d5e75bbe3 100644
--- a/tf2onnx/onnx_opset/tensor.py
+++ b/tf2onnx/onnx_opset/tensor.py
@@ -856,14 +856,14 @@ def version_1(cls, ctx, node, **kwargs):
             # insert_new_node_on_output(self, op_type, output_name=None, name=None, inputs=None, domain=None, **kwargs)
             # ctx.insert_new_node_on_output("Squeeze", node.output[0], name)
             name = utils.make_name(node.name)
+            shape = ctx.get_shape(node.output[0])
+            dtype = ctx.get_dtype(node.output[0])
             squeeze_node = GraphBuilder(ctx).make_squeeze(
-                {"axes": needs_squeeze, 'data': node.output[0]}, name=name, return_node=True)
+                {"axes": needs_squeeze, 'data': node.output[0]}, name=name,
+                dtypes=[dtype], shapes=[shape], return_node=True)
             ctx.insert_node_on_output(squeeze_node)
             nodes.append(squeeze_node)
-            input_dtype = ctx.get_dtype(node.output[0])
-            ctx.set_dtype(squeeze_node.output[0], input_dtype)
-            ctx.copy_shape(node.output[0], squeeze_node.output[0])
             ctx.update_node_shape_dtype(node, override=True)
 
         # onnx slice as of opset 7 does only take float tensors ... cast if needed