1
1
// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
2
2
3
+ //-----------------------------------------------------------------------------
4
+ // [Patterns: TransferWriteDropUnitDimsPattern, TransferReadeDropUnitDimsPattern]
5
+ //-----------------------------------------------------------------------------
6
+
3
7
func.func @transfer_read_rank_reducing (
4
8
%arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>) -> vector <3 x2 xi8 > {
5
9
%c0 = arith.constant 0 : index
@@ -14,7 +18,29 @@ func.func @transfer_read_rank_reducing(
14
18
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
15
19
// CHECK: vector.transfer_read %[[SUBVIEW]]
16
20
17
- func.func @transfer_write_rank_reducing (%arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>, %vec : vector <3 x2 xi8 >) {
21
+ func.func @transfer_read_rank_reducing_masked (
22
+ %arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>,
23
+ %mask: vector <3 x2 xi1 >) -> vector <3 x2 xi8 > {
24
+ %c0 = arith.constant 0 : index
25
+ %cst = arith.constant 0 : i8
26
+ %v = vector.mask %mask {
27
+ vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 ], %cst :
28
+ memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>, vector <3 x2 xi8 >
29
+ } : vector <3 x2 xi1 > -> vector <3 x2 xi8 >
30
+ return %v : vector <3 x2 xi8 >
31
+ }
32
+ // CHECK-LABEL: func @transfer_read_rank_reducing_masked
33
+ // CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
34
+ // CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
35
+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
36
+ // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
37
+ // CHECK: vector.mask %[[MASK]]
38
+ // CHECK-SAME: vector.transfer_read %[[SUBVIEW]]
39
+
40
+ func.func @transfer_write_rank_reducing (
41
+ %arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>,
42
+ %vec : vector <3 x2 xi8 >) {
43
+
18
44
%c0 = arith.constant 0 : index
19
45
vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 ] :
20
46
vector <3 x2 xi8 >, memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>
@@ -26,6 +52,26 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6,
26
52
// CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
27
53
// CHECK: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
28
54
55
+ func.func @transfer_write_rank_reducing_masked (
56
+ %arg : memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>,
57
+ %vec : vector <3 x2 xi8 >,
58
+ %mask: vector <3 x2 xi1 >) {
59
+ %c0 = arith.constant 0 : index
60
+ vector.mask %mask {
61
+ vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 ] :
62
+ vector <3 x2 xi8 >, memref <1 x1 x3 x2 xi8 , strided <[6 , 6 , 2 , 1 ], offset : ?>>
63
+ } : vector <3 x2 xi1 >
64
+ return
65
+ }
66
+ // CHECK-LABEL: func @transfer_write_rank_reducing_masked
67
+ // CHECK-SAME: %[[ARG:.+]]: memref<1x1x3x2xi8
68
+ // CHECK-SAME: %[[VEC:.+]]: vector<3x2xi8>
69
+ // CHECK-SAME: %[[MASK:.+]]: vector<3x2xi1>
70
+ // CHECK: %[[SUBVIEW:.+]] = memref.subview %[[ARG]][0, 0, 0, 0] [1, 1, 3, 2] [1, 1, 1, 1]
71
+ // CHECK-SAME: memref<1x1x3x2xi8, {{.*}}> to memref<3x2xi8, {{.*}}>
72
+ // CHECK: vector.mask %[[MASK]]
73
+ // CHECK-SAME: vector.transfer_write %{{.*}}, %[[SUBVIEW]]
74
+
29
75
func.func @transfer_read_and_vector_rank_reducing (
30
76
%arg : memref <1 x1 x3 x2 x1 xf32 >) -> vector <3 x2 x1 xf32 > {
31
77
%c0 = arith.constant 0 : index
@@ -68,6 +114,22 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
68
114
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
69
115
// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
70
116
117
+ func.func @transfer_read_and_vector_rank_reducing_to_0d_masked (
118
+ %arg : memref <1 x1 x1 x1 x1 xf32 >,
119
+ %mask: vector <1 x1 x1 xi1 >) -> vector <1 x1 x1 xf32 > {
120
+
121
+ %c0 = arith.constant 0 : index
122
+ %cst = arith.constant 0.0 : f32
123
+ %v = vector.mask %mask {
124
+ vector.transfer_read %arg [%c0 , %c0 , %c0 , %c0 , %c0 ], %cst
125
+ : memref <1 x1 x1 x1 x1 xf32 >, vector <1 x1 x1 xf32 >
126
+ } : vector <1 x1 x1 xi1 > -> vector <1 x1 x1 xf32 >
127
+ return %v : vector <1 x1 x1 xf32 >
128
+ }
129
+ // CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked
130
+ // CHECK-NOT: vector.shape_cast
131
+ // CHECK-NOT: memref.subview
132
+
71
133
func.func @transfer_write_and_vector_rank_reducing_to_0d (
72
134
%arg : memref <1 x1 x1 x1 x1 xf32 >,
73
135
%vec : vector <1 x1 x1 xf32 >) {
@@ -82,6 +144,23 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
82
144
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
83
145
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
84
146
147
+ func.func @transfer_write_and_vector_rank_reducing_to_0d_masked (
148
+ %arg : memref <1 x1 x1 x1 x1 xf32 >,
149
+ %vec : vector <1 x1 x1 xf32 >,
150
+ %mask: vector <1 x1 x1 xi1 >) {
151
+
152
+ %c0 = arith.constant 0 : index
153
+ %cst = arith.constant 0.0 : f32
154
+ vector.mask %mask {
155
+ vector.transfer_write %vec , %arg [%c0 , %c0 , %c0 , %c0 , %c0 ] :
156
+ vector <1 x1 x1 xf32 >, memref <1 x1 x1 x1 x1 xf32 >
157
+ } : vector <1 x1 x1 xi1 >
158
+ return
159
+ }
160
+ // CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked
161
+ // CHECK-NOT: vector.shape_cast
162
+ // CHECK-NOT: memref.subview
163
+
85
164
func.func @transfer_read_dynamic_rank_reducing (
86
165
%arg : memref <?x1 xi8 , strided <[?, ?], offset : ?>>) -> vector <[16 ]x1 xi8 > {
87
166
%c0 = arith.constant 0 : index
0 commit comments