@@ -169,6 +169,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
169
169
return %0 : vector <2 x3 xf32 >
170
170
}
171
171
172
+ // CHECK-LABEL: func @matmul_scalable
173
+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
174
+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
175
+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
176
+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
177
+ // CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
178
+ //
179
+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
180
+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
181
+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
182
+ // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
183
+ //
184
+ // CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
185
+ // CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
186
+ // CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
187
+ // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
188
+ //
189
+ // CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
190
+ // CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
191
+ // CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
192
+ // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
193
+ //
194
+ // CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
195
+ // CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
196
+ // CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
197
+ // CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
198
+ //
199
+ // CHECK: return %[[c3]] : vector<2x[3]xf32>
200
+ func.func @matmul_scalable (%arg0: vector <2 x4 xf32 >,
201
+ %arg1: vector <4 x[3 ]xf32 >,
202
+ %arg2: vector <2 x[3 ]xf32 >) -> vector <2 x[3 ]xf32 > {
203
+ %0 = vector.contract #matmat_trait %arg0 , %arg1 , %arg2
204
+ : vector <2 x4 xf32 >, vector <4 x[3 ]xf32 > into vector <2 x[3 ]xf32 >
205
+ return %0 : vector <2 x[3 ]xf32 >
206
+ }
207
+
172
208
// CHECK-LABEL: func @matmul_0
173
209
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
174
210
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -186,6 +222,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
186
222
return %0 : vector <2 x3 xf32 >
187
223
}
188
224
225
+ // CHECK-LABEL: func @matmul_0_scalable
226
+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
227
+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
228
+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
229
+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
230
+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
231
+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
232
+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
233
+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
234
+ func.func @matmul_0_scalable (%arg0: vector <2 x1 xf32 >, %arg1: vector <1 x[3 ]xf32 >, %arg2: vector <2 x[3 ]xf32 >)
235
+ -> vector <2 x[3 ]xf32 >
236
+ {
237
+ %0 = vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
238
+ : vector <2 x1 xf32 >, vector <1 x[3 ]xf32 > into vector <2 x[3 ]xf32 >
239
+ return %0 : vector <2 x[3 ]xf32 >
240
+ }
241
+
189
242
// CHECK-LABEL: func @matmul_0_mixed
190
243
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
191
244
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -205,6 +258,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
205
258
return %0 : vector <2 x3 xf32 >
206
259
}
207
260
261
+ // CHECK-LABEL: func @matmul_0_mixed_scalable
262
+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
263
+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
264
+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
265
+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
266
+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
267
+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
268
+ // CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
269
+ // CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
270
+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
271
+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
272
+ func.func @matmul_0_mixed_scalable (%arg0: vector <2 x1 xf16 >, %arg1: vector <1 x[3 ]xf16 >, %arg2: vector <2 x[3 ]xf32 >)
273
+ -> vector <2 x[3 ]xf32 >
274
+ {
275
+ %0 = vector.contract #matmat_trait_0 %arg0 , %arg1 , %arg2
276
+ : vector <2 x1 xf16 >, vector <1 x[3 ]xf16 > into vector <2 x[3 ]xf32 >
277
+ return %0 : vector <2 x[3 ]xf32 >
278
+ }
279
+
208
280
#matmat_accesses_1 = [
209
281
affine_map <(m , n , k ) -> (m , k )>,
210
282
affine_map <(m , n , k ) -> (n , k )>,
@@ -233,6 +305,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
233
305
return %0 : vector <2 x3 xf32 >
234
306
}
235
307
308
+ // CHECK-LABEL: func @matmul_1_scalable
309
+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
310
+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
311
+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
312
+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
313
+ // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
314
+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
315
+ // CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
316
+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
317
+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
318
+ func.func @matmul_1_scalable (%arg0: vector <2 x1 xf32 >, %arg1: vector <[3 ]x1 xf32 >, %arg2: vector <2 x[3 ]xf32 >)
319
+ -> vector <2 x[3 ]xf32 >
320
+ {
321
+ %0 = vector.contract #matmat_trait_1 %arg0 , %arg1 , %arg2
322
+ : vector <2 x1 xf32 >, vector <[3 ]x1 xf32 > into vector <2 x[3 ]xf32 >
323
+ return %0 : vector <2 x[3 ]xf32 >
324
+ }
325
+
236
326
#matmat_accesses_2 = [
237
327
affine_map <(m , n , k ) -> (k , m )>,
238
328
affine_map <(m , n , k ) -> (k , n )>,
@@ -259,6 +349,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
259
349
return %0 : vector <2 x3 xf32 >
260
350
}
261
351
352
+ // CHECK-LABEL: func @matmul_2_scalable
353
+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
354
+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
355
+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
356
+ // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
357
+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
358
+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
359
+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
360
+ func.func @matmul_2_scalable (%arg0: vector <1 x2 xf32 >, %arg1: vector <1 x[3 ]xf32 >, %arg2: vector <2 x[3 ]xf32 >)
361
+ -> vector <2 x[3 ]xf32 >
362
+ {
363
+ %0 = vector.contract #matmat_trait_2 %arg0 , %arg1 , %arg2
364
+ : vector <1 x2 xf32 >, vector <1 x[3 ]xf32 > into vector <2 x[3 ]xf32 >
365
+ return %0 : vector <2 x[3 ]xf32 >
366
+ }
367
+
262
368
#matmat_accesses_3 = [
263
369
affine_map <(m , n , k ) -> (k , m )>,
264
370
affine_map <(m , n , k ) -> (n , k )>,
@@ -286,6 +392,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
286
392
return %0 : vector <2 x3 xf32 >
287
393
}
288
394
395
+ // CHECK-LABEL: func @matmul_3_scalable
396
+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
397
+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
398
+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
399
+ // CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
400
+ // CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
401
+ // CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
402
+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
403
+ // CHECK: return %[[c0]] : vector<2x[3]xf32>
404
+ func.func @matmul_3_scalable (%arg0: vector <1 x2 xf32 >, %arg1: vector <[3 ]x1 xf32 >, %arg2: vector <2 x[3 ]xf32 >)
405
+ -> vector <2 x[3 ]xf32 >
406
+ {
407
+ %0 = vector.contract #matmat_trait_3 %arg0 , %arg1 , %arg2
408
+ : vector <1 x2 xf32 >, vector <[3 ]x1 xf32 > into vector <2 x[3 ]xf32 >
409
+ return %0 : vector <2 x[3 ]xf32 >
410
+ }
411
+
289
412
#matmat_accesses_4 = [
290
413
affine_map <(m , n , k ) -> (m , k )>,
291
414
affine_map <(m , n , k ) -> (k , n )>,
@@ -313,6 +436,33 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
313
436
return %0 : vector <3 x2 xf32 >
314
437
}
315
438
439
+ // CHECK-LABEL: func @matmul_4_scalable
440
+ // CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
441
+ // CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
442
+ // CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
443
+ // CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
444
+ // CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
445
+ // CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
446
+ // CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
447
+ // CHECK: return %[[c0]] : vector<3x[2]xf32>
448
+ func.func @matmul_4_scalable (%arg0: vector <[2 ]x1 xf32 >, %arg1: vector <1 x3 xf32 >, %arg2: vector <3 x[2 ]xf32 >)
449
+ -> vector <3 x[2 ]xf32 >
450
+ {
451
+ %0 = vector.contract #matmat_trait_4 %arg0 , %arg1 , %arg2
452
+ : vector <[2 ]x1 xf32 >, vector <1 x3 xf32 > into vector <3 x[2 ]xf32 >
453
+ return %0 : vector <3 x[2 ]xf32 >
454
+ }
455
+
456
+ #matmat_accesses_5 = [
457
+ affine_map <(m , n , k ) -> (m , k )>,
458
+ affine_map <(m , n , k ) -> (k , n )>,
459
+ affine_map <(m , n , k ) -> (n , m )>
460
+ ]
461
+ #matmat_trait_5 = {
462
+ indexing_maps = #matmat_accesses_5 ,
463
+ iterator_types = [" parallel" , " parallel" , " reduction" ]
464
+ }
465
+
316
466
// CHECK-LABEL: @masked_matvec_mk_k_m
317
467
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
318
468
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>
0 commit comments