Skip to content

Commit 57ef758

Browse files
Fix TensorListStack and Loop shapes (#1448)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent eadc613 commit 57ef758

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

tf2onnx/onnx_opset/controlflow.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -322,9 +322,14 @@ def version_7(cls, ctx, node, **kwargs):
322322
class TensorListStack:
323323
@classmethod
324324
def version_7(cls, ctx, node, **kwargs):
325-
if node.inputs[0].is_while():
326-
ctx.remove_node(node.name)
327-
ctx.replace_all_inputs(node.output[0], node.input[0]) # ops=ctx.get_nodes()
325+
inp_node = node.inputs[0]
326+
inp = node.input[0]
327+
while inp_node.type == "Identity":
328+
inp = inp_node.input[0]
329+
inp_node = inp_node.inputs[0]
330+
utils.make_sure(inp_node.is_while(), "Can only convert TensorListStack that is part of a While loop")
331+
ctx.remove_node(node.name)
332+
ctx.replace_all_inputs(node.output[0], inp)
328333

329334

330335
@tf_op(["While", "StatelessWhile"])
@@ -463,7 +468,7 @@ def version_7(cls, ctx, node, **kwargs):
463468
for k, v in output_map.items():
464469
ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes()
465470

466-
wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
471+
wire_while_body(ctx, body, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
467472
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names)
468473

469474
# if there was a tensorflow variant type, bind in a real type here
@@ -473,7 +478,7 @@ def version_7(cls, ctx, node, **kwargs):
473478
body.set_dtype(n.output[0], ctx.get_dtype(loop_node.input[i]))
474479

475480

476-
def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
481+
def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
477482
output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names):
478483
"""Wire subgraph graph into main."""
479484
remove_parents = []
@@ -496,7 +501,7 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
496501
g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64)
497502
g.inputs = [g.get_node_by_output(inp) for inp in func_inputs]
498503

499-
for p, c in zip(loop_node_inputs, func_inputs):
504+
for p, c in zip(loop_node.inputs, func_inputs):
500505
shape = p.output_shapes[0]
501506
g.set_shape(c, shape)
502507

@@ -534,6 +539,12 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
534539

535540
# Reorder scan outputs
536541
scan_outputs = [names_to_scan_outputs[name] for name in scan_output_names]
542+
for i in range(-len(scan_output_names), 0):
543+
# Use shapes from subgraph if loop node shapes for scan outputs are missing
544+
if loop_node.output_shapes[i] is None:
545+
shape = g.get_shape(scan_outputs[i])
546+
if shape is not None:
547+
parent_g.set_shape(loop_node.output[i], [-1] + shape)
537548

538549
# remove all nodes feeding to TensorListSetItem's reserved tensor
539550
while remove_parents:

tf2onnx/onnx_opset/tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -856,14 +856,14 @@ def version_1(cls, ctx, node, **kwargs):
856856
# insert_new_node_on_output(self, op_type, output_name=None, name=None, inputs=None, domain=None, **kwargs)
857857
# ctx.insert_new_node_on_output("Squeeze", node.output[0], name)
858858
name = utils.make_name(node.name)
859+
shape = ctx.get_shape(node.output[0])
860+
dtype = ctx.get_dtype(node.output[0])
859861
squeeze_node = GraphBuilder(ctx).make_squeeze(
860-
{"axes": needs_squeeze, 'data': node.output[0]}, name=name, return_node=True)
862+
{"axes": needs_squeeze, 'data': node.output[0]}, name=name,
863+
dtypes=[dtype], shapes=[shape], return_node=True)
861864
ctx.insert_node_on_output(squeeze_node)
862865

863866
nodes.append(squeeze_node)
864-
input_dtype = ctx.get_dtype(node.output[0])
865-
ctx.set_dtype(squeeze_node.output[0], input_dtype)
866-
ctx.copy_shape(node.output[0], squeeze_node.output[0])
867867
ctx.update_node_shape_dtype(node, override=True)
868868

869869
# onnx slice as of opset 7 does only take float tensors ... cast if needed

0 commit comments

Comments
 (0)