Skip to content

Commit 49c7d2e

Browse files
committed
[mlir][vector] Add scalable vectors to tests for vector.contract
Update the remaining tests for matrix multiplication (_matmul_) in: * vector-contract-to-outerproduct-transforms.mlir with cases for scalable vectors. Note that in order for the "vector.contract -> vector.outerproduct" patterns to work, only the non-reduction dimension can be scalable (*). For Matmul operations, that is set to be the N dimension (i.e. rows of the output matrix), which matches how matrix multiplications are normally implemented for e.g. Arm's SVE. However, making the M dimension scalable (i.e. columns of the output matrix) should work as well. Making both parallel dimensions scalable is left as a TODO for when support for 2-D scalable vectors is more established (this is work-in-progress as part of the effort to support Arm's SME in MLIR). The change in: * `UnrolledOuterProductGenerator` is a "bug fix" to make sure that the conversion pattern correctly propagates scalability when creating `arith.extf` operations. (*) The conversion tested in this file unrolls along the reduction dimension, which is not supported for scalable vectors.
1 parent ea1909f commit 49c7d2e

File tree

2 files changed

+151
-1
lines changed

2 files changed

+151
-1
lines changed

mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ struct UnrolledOuterProductGenerator
418418
return v;
419419
Type promotedType = dstElementType;
420420
if (vecType)
421-
promotedType = VectorType::get(vecType.getShape(), promotedType);
421+
promotedType = vecType.clone(promotedType);
422422
if (isa<FloatType>(dstElementType))
423423
return rewriter.create<arith::ExtFOp>(loc, promotedType, v);
424424
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);

mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,42 @@ func.func @matmul(%arg0: vector<2x4xf32>,
169169
return %0 : vector<2x3xf32>
170170
}
171171

172+
// CHECK-LABEL: func @matmul_scalable
173+
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>,
174+
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x[3]xf32>,
175+
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
176+
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
177+
// CHECK-SAME: : vector<2x4xf32> to vector<4x2xf32>
178+
//
179+
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<4x2xf32>
180+
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<4x[3]xf32>
181+
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
182+
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
183+
//
184+
// CHECK: %[[a1:.*]] = vector.extract %[[At]][1] : vector<2xf32> from vector<4x2xf32>
185+
// CHECK: %[[b1:.*]] = vector.extract %[[B]][1] : vector<[3]xf32> from vector<4x[3]xf32>
186+
// CHECK: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]]
187+
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
188+
//
189+
// CHECK: %[[a2:.*]] = vector.extract %[[At]][2] : vector<2xf32> from vector<4x2xf32>
190+
// CHECK: %[[b2:.*]] = vector.extract %[[B]][2] : vector<[3]xf32> from vector<4x[3]xf32>
191+
// CHECK: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]]
192+
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
193+
//
194+
// CHECK: %[[a3:.*]] = vector.extract %[[At]][3] : vector<2xf32> from vector<4x2xf32>
195+
// CHECK: %[[b3:.*]] = vector.extract %[[B]][3] : vector<[3]xf32> from vector<4x[3]xf32>
196+
// CHECK: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]]
197+
// CHECK-SAME: : vector<2xf32>, vector<[3]xf32>
198+
//
199+
// CHECK: return %[[c3]] : vector<2x[3]xf32>
200+
func.func @matmul_scalable(%arg0: vector<2x4xf32>,
201+
%arg1: vector<4x[3]xf32>,
202+
%arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
203+
%0 = vector.contract #matmat_trait %arg0, %arg1, %arg2
204+
: vector<2x4xf32>, vector<4x[3]xf32> into vector<2x[3]xf32>
205+
return %0 : vector<2x[3]xf32>
206+
}
207+
172208
// CHECK-LABEL: func @matmul_0
173209
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
174210
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
@@ -186,6 +222,23 @@ func.func @matmul_0(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
186222
return %0 : vector<2x3xf32>
187223
}
188224

225+
// CHECK-LABEL: func @matmul_0_scalable
226+
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
227+
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
228+
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
229+
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
230+
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
231+
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
232+
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
233+
// CHECK: return %[[c0]] : vector<2x[3]xf32>
234+
func.func @matmul_0_scalable(%arg0: vector<2x1xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
235+
-> vector<2x[3]xf32>
236+
{
237+
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
238+
: vector<2x1xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
239+
return %0 : vector<2x[3]xf32>
240+
}
241+
189242
// CHECK-LABEL: func @matmul_0_mixed
190243
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
191244
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf16>,
@@ -205,6 +258,25 @@ func.func @matmul_0_mixed(%arg0: vector<2x1xf16>, %arg1: vector<1x3xf16>, %arg2:
205258
return %0 : vector<2x3xf32>
206259
}
207260

261+
// CHECK-LABEL: func @matmul_0_mixed_scalable
262+
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf16>,
263+
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf16>,
264+
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
265+
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
266+
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf16> from vector<1x2xf16>
267+
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf16> from vector<1x[3]xf16>
268+
// CHECK: %[[a1:.*]] = arith.extf %[[a0]] : vector<2xf16> to vector<2xf32>
269+
// CHECK: %[[b1:.*]] = arith.extf %[[b0]] : vector<[3]xf16> to vector<[3]xf32>
270+
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[C]]
271+
// CHECK: return %[[c0]] : vector<2x[3]xf32>
272+
func.func @matmul_0_mixed_scalable(%arg0: vector<2x1xf16>, %arg1: vector<1x[3]xf16>, %arg2: vector<2x[3]xf32>)
273+
-> vector<2x[3]xf32>
274+
{
275+
%0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
276+
: vector<2x1xf16>, vector<1x[3]xf16> into vector<2x[3]xf32>
277+
return %0 : vector<2x[3]xf32>
278+
}
279+
208280
#matmat_accesses_1 = [
209281
affine_map<(m, n, k) -> (m, k)>,
210282
affine_map<(m, n, k) -> (n, k)>,
@@ -233,6 +305,24 @@ func.func @matmul_1(%arg0: vector<2x1xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
233305
return %0 : vector<2x3xf32>
234306
}
235307

308+
// CHECK-LABEL: func @matmul_1_scalable
309+
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x1xf32>,
310+
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
311+
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
312+
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
313+
// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
314+
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<2xf32> from vector<1x2xf32>
315+
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
316+
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
317+
// CHECK: return %[[c0]] : vector<2x[3]xf32>
318+
func.func @matmul_1_scalable(%arg0: vector<2x1xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
319+
-> vector<2x[3]xf32>
320+
{
321+
%0 = vector.contract #matmat_trait_1 %arg0, %arg1, %arg2
322+
: vector<2x1xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
323+
return %0 : vector<2x[3]xf32>
324+
}
325+
236326
#matmat_accesses_2 = [
237327
affine_map<(m, n, k) -> (k, m)>,
238328
affine_map<(m, n, k) -> (k, n)>,
@@ -259,6 +349,22 @@ func.func @matmul_2(%arg0: vector<1x2xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
259349
return %0 : vector<2x3xf32>
260350
}
261351

352+
// CHECK-LABEL: func @matmul_2_scalable
353+
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
354+
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x[3]xf32>,
355+
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
356+
// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
357+
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<[3]xf32> from vector<1x[3]xf32>
358+
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
359+
// CHECK: return %[[c0]] : vector<2x[3]xf32>
360+
func.func @matmul_2_scalable(%arg0: vector<1x2xf32>, %arg1: vector<1x[3]xf32>, %arg2: vector<2x[3]xf32>)
361+
-> vector<2x[3]xf32>
362+
{
363+
%0 = vector.contract #matmat_trait_2 %arg0, %arg1, %arg2
364+
: vector<1x2xf32>, vector<1x[3]xf32> into vector<2x[3]xf32>
365+
return %0 : vector<2x[3]xf32>
366+
}
367+
262368
#matmat_accesses_3 = [
263369
affine_map<(m, n, k) -> (k, m)>,
264370
affine_map<(m, n, k) -> (n, k)>,
@@ -286,6 +392,23 @@ func.func @matmul_3(%arg0: vector<1x2xf32>, %arg1: vector<3x1xf32>, %arg2: vecto
286392
return %0 : vector<2x3xf32>
287393
}
288394

395+
// CHECK-LABEL: func @matmul_3_scalable
396+
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<1x2xf32>,
397+
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<[3]x1xf32>,
398+
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x[3]xf32>
399+
// CHECK: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0]
400+
// CHECK: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2xf32> from vector<1x2xf32>
401+
// CHECK: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<[3]xf32> from vector<1x[3]xf32>
402+
// CHECK: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]]
403+
// CHECK: return %[[c0]] : vector<2x[3]xf32>
404+
func.func @matmul_3_scalable(%arg0: vector<1x2xf32>, %arg1: vector<[3]x1xf32>, %arg2: vector<2x[3]xf32>)
405+
-> vector<2x[3]xf32>
406+
{
407+
%0 = vector.contract #matmat_trait_3 %arg0, %arg1, %arg2
408+
: vector<1x2xf32>, vector<[3]x1xf32> into vector<2x[3]xf32>
409+
return %0 : vector<2x[3]xf32>
410+
}
411+
289412
#matmat_accesses_4 = [
290413
affine_map<(m, n, k) -> (m, k)>,
291414
affine_map<(m, n, k) -> (k, n)>,
@@ -313,6 +436,33 @@ func.func @matmul_4(%arg0: vector<2x1xf32>, %arg1: vector<1x3xf32>, %arg2: vecto
313436
return %0 : vector<3x2xf32>
314437
}
315438

439+
// CHECK-LABEL: func @matmul_4_scalable
440+
// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: vector<[2]x1xf32>,
441+
// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: vector<1x3xf32>,
442+
// CHECK-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x[2]xf32>
443+
// CHECK: %[[At:.*]] = vector.transpose %[[A]], [1, 0]
444+
// CHECK: %[[b0:.*]] = vector.extract %[[B]][0] : vector<3xf32> from vector<1x3xf32>
445+
// CHECK: %[[a0:.*]] = vector.extract %[[At]][0] : vector<[2]xf32> from vector<1x[2]xf32>
446+
// CHECK: %[[c0:.*]] = vector.outerproduct %[[b0]], %[[a0]], %[[C]]
447+
// CHECK: return %[[c0]] : vector<3x[2]xf32>
448+
func.func @matmul_4_scalable(%arg0: vector<[2]x1xf32>, %arg1: vector<1x3xf32>, %arg2: vector<3x[2]xf32>)
449+
-> vector<3x[2]xf32>
450+
{
451+
%0 = vector.contract #matmat_trait_4 %arg0, %arg1, %arg2
452+
: vector<[2]x1xf32>, vector<1x3xf32> into vector<3x[2]xf32>
453+
return %0 : vector<3x[2]xf32>
454+
}
455+
456+
#matmat_accesses_5 = [
457+
affine_map<(m, n, k) -> (m, k)>,
458+
affine_map<(m, n, k) -> (k, n)>,
459+
affine_map<(m, n, k) -> (n, m)>
460+
]
461+
#matmat_trait_5 = {
462+
indexing_maps = #matmat_accesses_5,
463+
iterator_types = ["parallel", "parallel", "reduction"]
464+
}
465+
316466
// CHECK-LABEL: @masked_matvec_mk_k_m
317467
// CHECK-SAME: %[[MAT:.+]]: vector<4x2xf32>
318468
// CHECK-SAME: %[[VEC:.+]]: vector<2xf32>

0 commit comments

Comments
 (0)