Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions tests/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,40 @@ def b(i, out_ta):
output_names_with_port = ["i:0", "output_ta:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

def test_while_loop_with_multi_scan_outputs(self):
def func(i, inputs1, inputs2):
inputs1_ = tf.identity(inputs1)
inputs2_ = tf.identity(inputs2)
input_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs1_)
input_ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs2_)
output_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
output_ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0)

def b(i, out_ta, out_ta2):
new_i = tf.add(i, 1)
x = input_ta.read(i)
y = input_ta2.read(i)
z = x + 3 + y
p = x * y * 2
out_ta_new = out_ta.write(i, z)
out_ta_new2 = out_ta2.write(i, p)
return new_i, out_ta_new, out_ta_new2

i_final, out_final, out_final2 = tf.while_loop(c, b, [i, output_ta, output_ta2])
i_final_ = tf.identity(i_final, name="i")
out_final_ = tf.identity(out_final.stack(), name="output_ta")
out_final2_ = tf.identity(out_final2.stack(), name="output_ta2")
return i_final_, out_final_, out_final2_

input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"]
feed_dict = {"input_1:0": np.array(0, dtype=np.int32),
"input_2:0": np.array([2.0, 16.0, 5.0, 1.6, 5.0, 6.0, 7.0, 8.0, 9.0, 10.], dtype=np.float32),
"input_3:0": np.array([1.0, 2.0, 3.0, 4.0, 5.0, 16.0, 7.0, 8.0, 9.0, 10.], dtype=np.float32)}
output_names_with_port = ["i:0", "output_ta:0", "output_ta2:0"]
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)

@check_onnxruntime_min_version(
"0.5.0",
"disable this case due to onnxruntime loop issue: https://github.com/microsoft/onnxruntime/issues/1272"
Expand Down
23 changes: 11 additions & 12 deletions tf2onnx/onnx_opset/controlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def version_7(cls, ctx, node, **kwargs):
del output_names[idx]
del body.outputs[idx]

removed_scan_outputs = {}
scan_output_names = []
# remove tensor array that are passed in to the loop
for idx, n in reversed(to_remove):
ctx.remove_node(n.name)
Expand All @@ -430,19 +430,15 @@ def version_7(cls, ctx, node, **kwargs):
del body.func_inputs[idx]
del cond_graph.func_inputs[idx]
del tf_while_inputs[idx]
# save the index of the scan output
removed_scan_outputs[body.outputs[idx]] = idx
scan_output_names.append(body.outputs[idx])
del body.outputs[idx]
# FIXME: Output shapes may be in wrong order if there are multiple scan outputs
output_shapes.append(output_shapes[idx])
output_dtypes.append(output_dtypes[idx])
output_names.append(output_names[idx])
del output_shapes[idx]
del output_dtypes[idx]
del output_names[idx]

utils.make_sure(len(removed_scan_outputs) <= 1, "converter only supports while loops with a single scan output")

ctx.remove_node(node.name)

# In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
Expand All @@ -467,7 +463,7 @@ def version_7(cls, ctx, node, **kwargs):
ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes()

wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, removed_scan_outputs)
output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names)

# if there was a tensorflow variant type, bind in a real type here
# FIXME: I don't think this is needed anymore
Expand All @@ -477,7 +473,7 @@ def version_7(cls, ctx, node, **kwargs):


def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes,
output_dtypes, scope, parent, cond_graph, tf_while_inputs, removed_scan_outputs):
output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names):
"""Wire subgraph graph into main."""
remove_parents = []
to_remove = []
Expand Down Expand Up @@ -521,9 +517,10 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
g.replace_inputs(node, [node.input[2]])
scan_outputs.append(node.output[0])

if len(scan_outputs) != len(removed_scan_outputs):
if len(scan_outputs) != len(scan_output_names):
raise ValueError("While loop couldn't find scan output index for nodes")

names_to_scan_outputs = {}
for output in scan_outputs:
last_output = output
consumers = g.find_output_consumers(last_output)
Expand All @@ -533,10 +530,12 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
raise ValueError("While loop couldn't find scan output index for node " + node.name)
last_output = node.output[0]
consumers = g.find_output_consumers(last_output)
if last_output not in removed_scan_outputs:
if last_output not in scan_output_names:
raise ValueError("While loop couldn't find scan output index for node " + node.name)
# TODO: store index to ensure scan outputs are in correct order for multiple outputs
# initial_output_index = removed_scan_outputs[last_output]
names_to_scan_outputs[last_output] = output

# Reorder scan outputs
scan_outputs = [names_to_scan_outputs[name] for name in scan_output_names]

# remove all nodes feeding to TensorListSetItem's reserved tensor
while remove_parents:
Expand Down