diff --git a/tests/test_loops.py b/tests/test_loops.py index 4c8f46d8b..fdbad9392 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -188,6 +188,40 @@ def b(i, out_ta): output_names_with_port = ["i:0", "output_ta:0"] self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06) + def test_while_loop_with_multi_scan_outputs(self): + def func(i, inputs1, inputs2): + inputs1_ = tf.identity(inputs1) + inputs2_ = tf.identity(inputs2) + input_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs1_) + input_ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True).unstack(inputs2_) + output_ta = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True) + output_ta2 = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True) + + c = lambda i, *_: tf.logical_and(tf.less(i, 10), i >= 0) + + def b(i, out_ta, out_ta2): + new_i = tf.add(i, 1) + x = input_ta.read(i) + y = input_ta2.read(i) + z = x + 3 + y + p = x * y * 2 + out_ta_new = out_ta.write(i, z) + out_ta_new2 = out_ta2.write(i, p) + return new_i, out_ta_new, out_ta_new2 + + i_final, out_final, out_final2 = tf.while_loop(c, b, [i, output_ta, output_ta2]) + i_final_ = tf.identity(i_final, name="i") + out_final_ = tf.identity(out_final.stack(), name="output_ta") + out_final2_ = tf.identity(out_final2.stack(), name="output_ta2") + return i_final_, out_final_, out_final2_ + + input_names_with_port = ["input_1:0", "input_2:0", "input_3:0"] + feed_dict = {"input_1:0": np.array(0, dtype=np.int32), + "input_2:0": np.array([2.0, 16.0, 5.0, 1.6, 5.0, 6.0, 7.0, 8.0, 9.0, 10.], dtype=np.float32), + "input_3:0": np.array([1.0, 2.0, 3.0, 4.0, 5.0, 16.0, 7.0, 8.0, 9.0, 10.], dtype=np.float32)} + output_names_with_port = ["i:0", "output_ta:0", "output_ta2:0"] + self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06) + @check_onnxruntime_min_version( "0.5.0", "disable this case due to onnxruntime loop issue: https://github.com/microsoft/onnxruntime/issues/1272" diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py index ac5787f02..904cab997 100644 --- a/tf2onnx/onnx_opset/controlflow.py +++ b/tf2onnx/onnx_opset/controlflow.py @@ -421,7 +421,7 @@ def version_7(cls, ctx, node, **kwargs): del output_names[idx] del body.outputs[idx] - removed_scan_outputs = {} + scan_output_names = [] # remove tensor array that are passed in to the loop for idx, n in reversed(to_remove): ctx.remove_node(n.name) @@ -430,10 +430,8 @@ def version_7(cls, ctx, node, **kwargs): del body.func_inputs[idx] del cond_graph.func_inputs[idx] del tf_while_inputs[idx] - # save the index of the scan output - removed_scan_outputs[body.outputs[idx]] = idx + scan_output_names.append(body.outputs[idx]) del body.outputs[idx] - # FIXME: Output shapes may be in wrong order if there are multiple scan outputs output_shapes.append(output_shapes[idx]) output_dtypes.append(output_dtypes[idx]) output_names.append(output_names[idx]) @@ -441,8 +439,6 @@ def version_7(cls, ctx, node, **kwargs): del output_dtypes[idx] del output_names[idx] - utils.make_sure(len(removed_scan_outputs) <= 1, "converter only supports while loops with a single scan output") - ctx.remove_node(node.name) # In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph @@ -467,7 +463,7 @@ def version_7(cls, ctx, node, **kwargs): ctx.replace_all_inputs(k, v) # ops=ctx.get_nodes() wire_while_body(ctx, body, loop_node.inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes, - output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, removed_scan_outputs) + output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names) # if there was a tensorflow variant type, bind in a real type here # FIXME: I don't think this is needed anymore @@ -477,7 +473,7 @@ def version_7(cls, ctx, node, **kwargs): def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond_input_to_state_var, output_shapes, - output_dtypes, scope, parent, cond_graph, tf_while_inputs, removed_scan_outputs): + output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names): """Wire subgraph graph into main.""" remove_parents = [] to_remove = [] @@ -521,9 +517,10 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond g.replace_inputs(node, [node.input[2]]) scan_outputs.append(node.output[0]) - if len(scan_outputs) != len(removed_scan_outputs): + if len(scan_outputs) != len(scan_output_names): raise ValueError("While loop couldn't find scan output index for nodes") + names_to_scan_outputs = {} for output in scan_outputs: last_output = output consumers = g.find_output_consumers(last_output) @@ -533,10 +530,12 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond raise ValueError("While loop couldn't find scan output index for node " + node.name) last_output = node.output[0] consumers = g.find_output_consumers(last_output) - if last_output not in removed_scan_outputs: + if last_output not in scan_output_names: raise ValueError("While loop couldn't find scan output index for node " + node.name) - # TODO: store index to ensure scan outputs are in correct order for multiple outputs - # initial_output_index = removed_scan_outputs[last_output] + names_to_scan_outputs[last_output] = output + + # Reorder scan outputs + scan_outputs = [names_to_scan_outputs[name] for name in scan_output_names] # remove all nodes feeding to TensorListSetItem's reserved tensor while remove_parents: