While loop wrong input shape with TF shape invariance  #2202

@f-salvetti

Describe the bug
When converting a TF While node, if the while loop uses shape invariants and the initial tensor has a fully defined shape, the converter treats that initial shape as fixed, which causes wrong inference behavior (an infinite loop).

To Reproduce

import numpy as np
import tensorflow as tf

class Model(tf.Module):
    @tf.function(input_signature=[
        tf.TensorSpec(shape=(1,), dtype=tf.int32, name="x"),
    ])
    def call(self, x):
        const = tf.constant(np.array([2], dtype=np.int32))
        # Loop while the (dynamic) length of x is below 10.
        c = lambda x: tf.reduce_all(tf.shape(x) < 10)
        # Each iteration appends `const`, so x grows by one element.
        b = lambda x: tf.concat([x, const], 0)
        # The shape invariant marks the loop variable's length as unknown.
        r = tf.while_loop(c, b, [x], shape_invariants=[tf.TensorShape([None])])
        return tf.identity(r, name="output")

    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)

model = Model()
tf.saved_model.save(model, export_dir="test_model", signatures=model.call.get_concrete_function())
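
For reference, the loop in the reproduction should stop once the loop variable holds 10 elements. Below is a minimal plain-Python sketch of that intended behavior. The helper name `expected_output` and its parameters are hypothetical and not part of TensorFlow or tf2onnx. It models only the loop variable, not the exact tensor returned by the model.

```python
def expected_output(x, const=(2,), limit=10):
    """Plain-Python model of the tf.while_loop in the reproduction (hypothetical helper)."""
    x = list(x)
    # Condition: tf.reduce_all(tf.shape(x) < 10), i.e. len(x) < 10 for a 1-D tensor.
    while len(x) < limit:
        # Body: tf.concat([x, const], 0) appends the constant on every iteration.
        x = x + list(const)
    return x


print(expected_output([1]))  # [1, 2, 2, 2, 2, 2, 2, 2, 2, 2]
```

With an input of one element, the loop runs 9 times and then exits, so a correct conversion should not loop forever.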

When the converted model is run with onnxruntime, it enters an infinite loop and warns:

2023-07-10 11:01:43.855159751 [W:onnxruntime:, execution_frame.cc:835 VerifyOutputSizes] Expected shape from model of {2} does not match actual shape of {3} for output while/concat:0
2023-07-10 11:01:43.855182551 [W:onnxruntime:, execution_frame.cc:835 VerifyOutputSizes] Expected shape from model of {2} does not match actual shape of {4} for output while/concat:0

and so on

Screenshots
ONNX assumes the input to be of fixed (1,) shape instead of (-1,)
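
One way this symptom could arise is sketched below in plain Python. It assumes that the loop condition ends up evaluated against the statically recorded shape rather than the runtime shape. That assumption is not confirmed by this issue or by the tf2onnx source. The function `run_loop` and its `static_shape` and `max_iters` parameters are hypothetical names used only for illustration.

```python
def run_loop(x, static_shape=None, const=(2,), limit=10, max_iters=50):
    """Model the loop, optionally with a frozen static shape (hypothetical helper).

    Returns (final_x, iterations, terminated).
    """
    x = list(x)
    iters = 0
    while iters < max_iters:
        # Assumption: with a frozen static shape, the condition ignores the real length.
        length = static_shape[0] if static_shape is not None else len(x)
        if not length < limit:
            return x, iters, True
        # Body: the tensor keeps growing at runtime either way.
        x = x + list(const)
        iters += 1
    # The guard was hit: the loop would otherwise never end.
    return x, iters, False


_, n_dynamic, done_dynamic = run_loop([1])
_, n_static, done_static = run_loop([1], static_shape=(1,))
print(n_dynamic, done_dynamic)  # 9 True
print(n_static, done_static)    # 50 False
```

In the frozen-shape case the data still grows on each iteration, which matches the pattern of the warnings above (actual shapes {3}, {4}, ...), while the condition never becomes false.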
