diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 61ebc63a9..7ab159d0e 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -116,6 +116,52 @@ def test_transpose_with_concat(self, input_shape, perm, inner_perm): } self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=1) + @parameterized.expand([ + ((2, 3, 4, 5), [0, 3, 1, 2], [0, 2, 3, 1]), + ((2, 3, 4, 5, 6), [0, 4, 1, 2, 3], [0, 2, 3, 4, 1]), + ((2, 3, 4, 5, 6), [0, 2, 3, 4, 1], [0, 4, 1, 2, 3]), + ]) + def test_transpose_with_split(self, input_shape, perm, inner_perm): + input_shape_with_trans = [input_shape[i] for i in perm] + output_before_trans = list(input_shape) + output_shape = [output_before_trans[i] for i in perm] + for axis in range(len(input_shape)): + node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=inner_perm, name="trans1") + node2 = helper.make_node("Split", ["Y"], ["Z"], axis=axis, name="split") + node3 = helper.make_node("Transpose", ["Z"], ["res"], perm=perm, name="trans2") + + graph = helper.make_graph( + [node1, node2, node3], + "test_transpose_with_split", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape_with_trans)], + [helper.make_tensor_value_info("res", TensorProto.FLOAT, output_shape)], + ) + + model_proto = self.make_model(graph, producer_name="onnx-tests") + feed_dict = {"X": np.random.randn(*input_shape_with_trans).astype(np.float32)} + self.run_transpose_compare(["res"], feed_dict, model_proto, remaining_transpose_num=0) + + @parameterized.expand([ + ((1, -1), (1, 1710), (1710,), [1, 0]), + ((3, 1, 1, 5, -1), (3, 1, 1, 5, 6), (3, 5, 6), [0, 2, 3, 4, 1]), + ]) + @check_opset_max_version(12, "split attribute changed to input in opset 13") + def test_transpose_with_split_dynamic_shape(self, input_shape, specific_input, output_shape, perm): + node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans") + node2 = helper.make_node("Split", ["Y"], ["Z"], axis=1, split=[1], name="split") + node3 = helper.make_node("Squeeze", ["Z"], ["B"], name="squeeze") + + graph = helper.make_graph( + [node1, node2, node3], + "test_transpose_with_split_dynamic_shape", + [helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)], + [helper.make_tensor_value_info("B", TensorProto.FLOAT, output_shape)], + ) + + model_proto = self.make_model(graph, producer_name="onnx-tests") + self.run_transpose_compare(["B"], {"X": np.random.randn(*specific_input).astype(np.float32)}, + model_proto, remaining_transpose_num=0) + @parameterized.expand([ ((2, 3, 4), [2, 0, 1], [1, 2, 0]), ((2, 3, 4, 5), [0, 2, 3, 1], [0, 3, 1, 2]),