diff --git a/tests/test_loops.py b/tests/test_loops.py index 1f374ac3a..0260afa2d 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -7,7 +7,7 @@ import tensorflow as tf from backend_test_base import Tf2OnnxBackendTestBase -from common import unittest_main, check_tf_min_version, check_tf_max_version, \ +from common import unittest_main, check_tf_min_version, \ check_onnxruntime_min_version, check_tfjs_max_version, skip_tflite from tf2onnx.tf_loader import is_tf2 @@ -286,15 +286,13 @@ def func(x, y): self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-5) @check_tf_min_version("1.9") - @check_tf_max_version("1.15") + @skip_tflite("infinite loop with tflite") def test_simple_while_loop_var_shape(self): # test for while_loop with variant shape variables - # may not meet ONNX Loop spec - # Note: this is not working on tf2 itself. def func(i): const = tf.constant(np.array([2], dtype=np.int32)) c = lambda i: tf.reduce_all(tf.shape(i) < 10) - b = lambda i: tf.concat([i, const], 0) + b = lambda i: [tf.concat([i, const], 0)] r = tf.while_loop(c, b, [i], shape_invariants=[tf.TensorShape([None])]) return tf.identity(r, name="output") input_names_with_port = ["input_1:0"] diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py index 03697948a..4f577facc 100644 --- a/tf2onnx/onnx_opset/controlflow.py +++ b/tf2onnx/onnx_opset/controlflow.py @@ -571,7 +571,8 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_ g.set_dtype(func_inputs[0], onnx_pb.TensorProto.INT64) g.inputs = [g.get_node_by_output(inp) for inp in func_inputs] - for p, c in zip(loop_node.input, func_inputs): + # we should use outputs shape, not inputs, since there may be shape invariants + for p, c in zip(loop_node.output, func_inputs[2:]): g.copy_shape(p, c) for i, node in enumerate(g.inputs):