@@ -1222,6 +1222,73 @@ func.func @batch_reduce_matmul(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32
1222
1222
1223
1223
// -----
1224
1224
1225
// Round-trip test for linalg.matmul with user-supplied indexing maps in which
// the first input is accessed as (d2, d0) instead of (d0, d2). d2 is absent
// from the output map (d0, d1), so it is the dimension being reduced; the
// first input is therefore the 5x3 operand and the result is 3x7.
// CHECK-LABEL: func @matmul_transpose_a_explicit
// CHECK: linalg.matmul
// CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<5x3xf32>, memref<5x7xf32>)
// CHECK-SAME: outs(%{{.+}} : memref<3x7xf32>)
func.func @matmul_transpose_a_explicit(%arg0: memref<5x3xf32>, %arg1: memref<5x7xf32>, %arg2: memref<3x7xf32>) {
  // Map order follows operand order: first input, second input, output.
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d2, d1)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>)
                outs(%arg2 : memref<3x7xf32>)
  return
}
1239
+
1240
+ // -----
1241
+
1242
// Round-trip test for linalg.matmul with user-supplied indexing maps in which
// the second input is accessed as (d1, d2), so it is the 7x5 operand. The
// directives after the function verify that the three maps are printed as
// attribute aliases and referenced, in operand order, from the op.
func.func @matmul_transpose_b_explicit(%arg0: memref<3x5xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  // Map order follows operand order: first input, second input, output.
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d0, d2)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<3x5xf32>, memref<7x5xf32>)
                outs(%arg2 : memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_b_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<3x5xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<3x5xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
1264
+
1265
+ // -----
1266
+
1267
// Round-trip test for linalg.matmul with user-supplied indexing maps in which
// both inputs use a non-default access: the first as (d2, d0) (5x3 operand)
// and the second as (d1, d2) (7x5 operand). The directives after the function
// verify that the three maps are printed as attribute aliases and referenced,
// in operand order, from the op.
func.func @matmul_transpose_a_b_explicit(%arg0: memref<5x3xf32>, %arg1: memref<7x5xf32>, %arg2: memref<3x7xf32>) {
  // Map order follows operand order: first input, second input, output.
  linalg.matmul indexing_maps = [
                  affine_map<(d0, d1, d2) -> (d2, d0)>,
                  affine_map<(d0, d1, d2) -> (d1, d2)>,
                  affine_map<(d0, d1, d2) -> (d0, d1)>
                ]
                ins(%arg0, %arg1 : memref<5x3xf32>, memref<7x5xf32>)
                outs(%arg2 : memref<3x7xf32>)
  return
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0)>
// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1, d2) -> (d1, d2)>
// CHECK: #[[$ATTR_2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)>

// CHECK-LABEL: func.func @matmul_transpose_a_b_explicit(
// CHECK-SAME: %[[VAL_0:.*]]: memref<5x3xf32>,
// CHECK-SAME: %[[VAL_1:.*]]: memref<7x5xf32>,
// CHECK-SAME: %[[VAL_2:.*]]: memref<3x7xf32>) {
// CHECK: linalg.matmul indexing_maps = [#[[$ATTR_0]], #[[$ATTR_1]], #[[$ATTR_2]]] ins(%[[VAL_0]], %[[VAL_1]] : memref<5x3xf32>, memref<7x5xf32>) outs(%[[VAL_2]] : memref<3x7xf32>)
// CHECK: return
// CHECK: }
1289
+
1290
+ // -----
1291
+
1225
1292
func.func @matmul_bcast_a (%arg0: memref <5 xf32 >, %arg1: memref <5 x7 xf32 >, %arg2: memref <3 x7 xf32 >) {
1226
1293
linalg.matmul index ing_maps = [
1227
1294
affine_map <(d0 , d1 , d2 ) -> (d2 )>,
0 commit comments