Commit 3ad8cd6

Add test to verify pack/producer unpack/consumer fusion

1 parent 671829e commit 3ad8cd6

1 file changed: 121 additions, 0 deletions
// RUN: mlir-opt %s --transform-interpreter --split-input-file -canonicalize | FileCheck %s

// For the pack op, we use lowerPadLikeWithInsertSlice = false to ensure no insert_slice is
// generated. This allows linalg.transpose to be fused as a producer operation. Without this
// attribute, an insert_slice would be generated instead and fusion would be blocked.
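//
// For reference, transform.structured.lower_pack rewrites the pack roughly into the
// pad + expand_shape + transpose chain sketched below (illustrative only; the exact
// reassociation and permutation follow from inner_dims_pos and inner_tiles):
//
//   %padded   = tensor.pad %src ...                    // no-op padding in this test
//   %expanded = tensor.expand_shape %padded [[0, 1], [2, 3]]
//                   : tensor<128x256xf32> into tensor<1x128x1x256xf32>
//   %packed   = linalg.transpose ins(%expanded : tensor<1x128x1x256xf32>)
//                   outs(%dest : tensor<1x1x128x256xf32>) permutation = [0, 2, 1, 3]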

module {
  // CHECK-LABEL: func @fuse_pack_as_producer
  // CHECK: scf.forall {{.*}} {
  // CHECK: linalg.transpose
  // CHECK: linalg.generic
  // CHECK: scf.forall.in_parallel
  // CHECK: }
  func.func @fuse_pack_as_producer(%src: tensor<128x256xf32>, %other: tensor<4x4x128x256xf32>)
      -> tensor<4x4x128x256xf32> {
    %dest = tensor.empty() : tensor<1x1x128x256xf32>
    %pack = tensor.pack %src inner_dims_pos = [0, 1] inner_tiles = [128, 256]
        into %dest : tensor<128x256xf32> -> tensor<1x1x128x256xf32>
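    // Note: the inner tiles [128, 256] cover the whole source, so every outer dim of
    // the packed tensor is 1 and this pack is "pad-like"; that is exactly the case
    // the lowerPadLikeWithInsertSlice attribute above controls.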

    %out = tensor.empty() : tensor<4x4x128x256xf32>
    %res = linalg.generic
        {indexing_maps = [affine_map<(i, j, k, l) -> (0, 0, k, l)>,
                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
                          affine_map<(i, j, k, l) -> (i, j, k, l)>],
         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%pack, %other: tensor<1x1x128x256xf32>, tensor<4x4x128x256xf32>)
        outs(%out: tensor<4x4x128x256xf32>) {
      ^bb0(%pack_elem: f32, %other_elem: f32, %out_elem: f32):
        %r = arith.addf %pack_elem, %other_elem : f32
        linalg.yield %r : f32
    } -> tensor<4x4x128x256xf32>

    return %res : tensor<4x4x128x256xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      // Find and lower the pack operation.
      %pack = transform.structured.match ops{["tensor.pack"]} in %arg1
          : (!transform.any_op) -> !transform.op<"tensor.pack">
      %padded, %expanded, %transpose = transform.structured.lower_pack %pack {lowerPadLikeWithInsertSlice = false}
          : (!transform.op<"tensor.pack">)
          -> (!transform.op<"tensor.pad">,
              !transform.op<"tensor.expand_shape">,
              !transform.op<"linalg.transpose">)

      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
          : (!transform.any_op) -> !transform.any_op
      // Tile the linalg operation with a parallel scf.forall loop, using num_threads [4, 4].
      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
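      // After tiling, the generic sits inside a 4x4 scf.forall; roughly (a sketch):
      //   scf.forall (%i, %j) in (4, 4) shared_outs(%o = %out) {
      //     ... linalg.generic on a tensor.extract_slice'd 1x1x128x256 tile ...
      //     scf.forall.in_parallel { tensor.parallel_insert_slice ... }
      //   }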

      // Fuse the transpose operation into the tiled loop.
      %fused, %new_forall = transform.structured.fuse_into_containing_op %transpose into %forall_op
          : (!transform.op<"linalg.transpose">, !transform.any_op) -> (!transform.any_op, !transform.any_op)
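      // fuse_into_containing_op clones (and, where needed, tiles) the matched producer
      // into the scf.forall at its point of use; the CHECK lines above verify that the
      // transpose now appears inside the loop body, before the generic.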
      transform.yield
    }
  }
}

// -----
// For the unpack op, we use lowerUnpadLikeWithExtractSlice = false to ensure no extract_slice is
// generated. This allows linalg.transpose to be fused as a consumer operation. Without this
// attribute, an extract_slice would be generated instead and fusion would be blocked.
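//
// For reference, transform.structured.lower_unpack rewrites the unpack roughly into the
// empty + transpose + collapse_shape + extract_slice chain sketched below (illustrative
// only; the exact reassociation and permutation follow from inner_dims_pos and inner_tiles):
//
//   %empty      = tensor.empty() : tensor<128x256xf32>
//   %transposed = linalg.transpose ins(%res : tensor<1x1x128x256xf32>)
//                     outs(... : tensor<1x128x1x256xf32>) permutation = [0, 2, 1, 3]
//   %collapsed  = tensor.collapse_shape %transposed [[0, 1], [2, 3]]
//                     : tensor<1x128x1x256xf32> into tensor<128x256xf32>
//   %unpacked   = tensor.extract_slice %collapsed ...  // no-op slice in this test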

module {
  // CHECK-LABEL: func @fuse_unpack_as_consumer
  // CHECK: scf.forall {{.*}} {
  // CHECK: linalg.generic
  // CHECK: linalg.transpose
  // CHECK: scf.forall.in_parallel
  // CHECK: }
  func.func @fuse_unpack_as_consumer(%src: tensor<4x4x128x256xf32>, %other: tensor<4x4x128x256xf32>)
      -> tensor<128x256xf32> {
    %out = tensor.empty() : tensor<1x1x128x256xf32>
    %res = linalg.generic
        {indexing_maps = [affine_map<(i, j, k, l) -> (i, j, k, l)>,
                          affine_map<(i, j, k, l) -> (i, j, k, l)>,
                          affine_map<(i, j, k, l) -> (0, 0, k, l)>],
         iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%src, %other: tensor<4x4x128x256xf32>, tensor<4x4x128x256xf32>)
        outs(%out: tensor<1x1x128x256xf32>) {
      ^bb0(%unpack_elem: f32, %other_elem: f32, %out_elem: f32):
        %r = arith.addf %unpack_elem, %other_elem : f32
        linalg.yield %r : f32
    } -> tensor<1x1x128x256xf32>

    %dest = tensor.empty() : tensor<128x256xf32>
    %unpack = tensor.unpack %res inner_dims_pos = [0, 1] inner_tiles = [128, 256]
        into %dest : tensor<1x1x128x256xf32> -> tensor<128x256xf32>
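    // Note: the inner tiles [128, 256] cover the whole result and every outer dim of
    // the packed source is 1, so this unpack is "unpad-like"; that is exactly the case
    // the lowerUnpadLikeWithExtractSlice attribute above controls.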

    return %unpack : tensor<128x256xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
      // Find and lower the unpack operation.
      %unpack = transform.structured.match ops{["tensor.unpack"]} in %arg1
          : (!transform.any_op) -> !transform.op<"tensor.unpack">
      %empty, %transpose, %collapsed, %extracted =
          transform.structured.lower_unpack %unpack {lowerUnpadLikeWithExtractSlice = false}
          : (!transform.op<"tensor.unpack">)
          -> (!transform.op<"tensor.empty">,
              !transform.op<"linalg.transpose">,
              !transform.op<"tensor.collapse_shape">,
              !transform.op<"tensor.extract_slice">)

      %root = transform.structured.match ops{["linalg.generic"]} in %arg1
          : (!transform.any_op) -> !transform.any_op
      // Tile the linalg operation with a parallel scf.forall loop, using num_threads [4, 4].
      %tiled_op, %forall_op = transform.structured.tile_using_forall %root num_threads [4, 4]
          : (!transform.any_op) -> (!transform.any_op, !transform.any_op)

      // Fuse the consumer operation into the tiled loop.
      %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %forall_op
          : (!transform.any_op) -> !transform.op<"tensor.parallel_insert_slice">
      %fused_consumer, %fused_forall = transform.test.fuse_consumer %slice_op
          : (!transform.op<"tensor.parallel_insert_slice">) -> (!transform.any_op, !transform.any_op)
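      // transform.test.fuse_consumer appears to be a test-only transform op (note the
      // transform.test namespace, registered by the MLIR test extension); it fuses the
      // consumer of the given parallel_insert_slice into the surrounding scf.forall,
      // which is why the transpose shows up after the generic in the CHECK lines above.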
      transform.yield
    }
  }
}
