Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,9 +322,14 @@ def version_7(cls, ctx, node, **kwargs):
class TensorListStack:
@classmethod
def version_7(cls, ctx, node, **kwargs):
if node.inputs[0].is_while():
ctx.remove_node(node.name)
ctx.replace_all_inputs(node.output[0], node.input[0]) # ops=ctx.get_nodes()
inp_node = node.inputs[0]
inp = node.input[0]
while inp_node.type == "Identity":
inp = inp_node.input[0]
inp_node = inp_node.inputs[0]
utils.make_sure(inp_node.is_while(), "Can only convert TensorListStack that is part of a While loop")
ctx.remove_node(node.name)
ctx.replace_all_inputs(node.output[0], inp)


@tf_op(["While", "StatelessWhile"])
Expand Down Expand Up @@ -463,7 +468,7 @@ def version_7(cls, ctx, node, **kwargs):
for k, v in output_map.items():
ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes()

wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
wire_while_body(ctx, body, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names)

# if there was a tensorflow variant type, bind in a real type here
Expand All @@ -473,7 +478,7 @@ def version_7(cls, ctx, node, **kwargs):
body.set_dtype(n.output[0], ctx.get_dtype(loop_node.input[i]))


def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names):
"""Wire subgraph graph into main."""
remove_parents = []
Expand All @@ -496,7 +501,7 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64)
g.inputs = [g.get_node_by_output(inp) for inp in func_inputs]

for p, c in zip(loop_node_inputs, func_inputs):
for p, c in zip(loop_node.inputs, func_inputs):
shape = p.output_shapes[0]
g.set_shape(c, shape)

Expand Down Expand Up @@ -534,6 +539,12 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond

# Reorder scan outputs
scan_outputs = [names_to_scan_outputs[name] for name in scan_output_names]
for i in range(-len(scan_output_names), 0):
# Use shapes from subgraph if loop node shapes for scan outputs are missing
if loop_node.output_shapes[i] is None:
shape = g.get_shape(scan_outputs[i])
if shape is not None:
parent_g.set_shape(loop_node.output[i], [-1] + shape)

# remove all nodes feeding to TensorListSetItem's reserved tensor
while remove_parents:
Expand Down
8 changes: 4 additions & 4 deletions tf2onnx/onnx_opset/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,14 +856,14 @@ def version_1(cls, ctx, node, **kwargs):
# insert_new_node_on_output(self, op_type, output_name=None, name=None, inputs=None, domain=None, **kwargs)
# ctx.insert_new_node_on_output("Squeeze", node.output[0], name)
name = utils.make_name(node.name)
shape = ctx.get_shape(node.output[0])
dtype = ctx.get_dtype(node.output[0])
squeeze_node = GraphBuilder(ctx).make_squeeze(
{"axes": needs_squeeze, 'data': node.output[0]}, name=name, return_node=True)
{"axes": needs_squeeze, 'data': node.output[0]}, name=name,
dtypes=[dtype], shapes=[shape], return_node=True)
ctx.insert_node_on_output(squeeze_node)

nodes.append(squeeze_node)
input_dtype = ctx.get_dtype(node.output[0])
ctx.set_dtype(squeeze_node.output[0], input_dtype)
ctx.copy_shape(node.output[0], squeeze_node.output[0])
ctx.update_node_shape_dtype(node, override=True)

# onnx slice as of opset 7 does only take float tensors ... cast if needed
Expand Down