Skip to content

Commit e9478b1

Browse files
authored
[mlir][SVE] Add more e2e test for vector.contract (#70367)
Adds basic integration tests for `vector.contract` for the dot product and matvec operations. These tests exercise scalable vectors. Depends on #69845
1 parent c131455 commit e9478b1

File tree

1 file changed

+97
-1
lines changed

1 file changed

+97
-1
lines changed

mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/test-contraction.mlir

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,38 @@
1313
// REDEFINE: %{entry} = matmul_f32
1414
// RUN: %{run} | FileCheck %s --check-prefix=F32
1515

16+
// REDEFINE: %{entry} = dot_product_i32
17+
// RUN: %{run} | FileCheck %s --check-prefix=DP
18+
19+
// REDEFINE: %{entry} = matvec_i32
20+
// RUN: %{run} | FileCheck %s --check-prefix=MV
21+
1622
// NOTE: These tests are meant to complement the integration tests from:
1723
// * ../test-contraction.mlir
1824
// (tests with fixed width vectors). Rather than duplicating those tests, this
1925
// file focuses on excercissing scalable vectors in a few most common cases.
2026

21-
// TODO: Masks + matvec + dot product
27+
// TODO: Masks
28+
29+
#dotp_accesses = [
30+
affine_map<(i) -> (i)>,
31+
affine_map<(i) -> (i)>,
32+
affine_map<(i) -> ()>
33+
]
34+
#dotp_trait = {
35+
indexing_maps = #dotp_accesses,
36+
iterator_types = ["reduction"]
37+
}
38+
39+
#matvec_accesses = [
40+
affine_map<(i, j) -> (i, j)>,
41+
affine_map<(i, j) -> (j)>,
42+
affine_map<(i, j) -> (i)>
43+
]
44+
#matvec_trait = {
45+
indexing_maps = #matvec_accesses,
46+
iterator_types = ["parallel", "reduction"]
47+
}
2248

2349
#matmat_accesses = [
2450
affine_map<(i, j, k) -> (i, k)>,
@@ -30,6 +56,76 @@
3056
iterator_types = ["parallel", "parallel", "reduction"]
3157
}
3258

59+
// Contraction: dot-product a x b.
60+
func.func @dot_product_i32() {
61+
%acc = arith.constant 0: i32
62+
63+
%vector_a = arith.constant dense<123> : vector<[4]xi32>
64+
%vector_b = arith.constant dense<314> : vector<[4]xi32>
65+
%vector_c = arith.constant dense<0> : vector<[4]xi32>
66+
67+
// DOT PRODUCT 1
68+
%dp1 = vector.contract #dotp_trait %vector_a, %vector_b, %acc
69+
: vector<[4]xi32>, vector<[4]xi32> into i32
70+
// Dot product should be:
71+
// * val = (123 * 314) * 4 * vscale,
72+
// so ...
73+
%vscale = vector.vscale
74+
%vscale_i32 = arith.index_cast %vscale : index to i32
75+
%dp1_div = arith.divui %dp1, %vscale_i32 : i32
76+
// ... val / vscale = 123 * 314 * 4 = 154488
77+
// DP: 154488
78+
vector.print %dp1_div : i32
79+
80+
// DOT PRODUCT 2
81+
// The result of this dot-product should be 0.
82+
%dp2 = vector.contract #dotp_trait %vector_a, %vector_c, %acc
83+
: vector<[4]xi32>, vector<[4]xi32> into i32
84+
// DP: 0
85+
vector.print %dp2 : i32
86+
87+
// DP: SVE: END OF TEST OUTPUT
88+
vector.print str "SVE: END OF TEST OUTPUT"
89+
90+
return
91+
}
92+
93+
// Contraction: matrix-vector A x c
94+
func.func @matvec_i32() {
95+
%acc = arith.constant dense<0>: vector<3xi32>
96+
97+
%vector_a = arith.constant dense<123> : vector<3x[4]xi32>
98+
%vector_b = arith.constant dense<314> : vector<[4]xi32>
99+
%vector_c = arith.constant dense<0> : vector<[4]xi32>
100+
101+
// MATVEC 1
102+
%mv1 = vector.contract #matvec_trait %vector_a, %vector_b, %acc
103+
: vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
104+
// Every element in the output vector is a result of a dot product, for
105+
// which:
106+
// val = (123 * 314) * 4 * vscale
107+
// so ...
108+
%vscale = vector.vscale
109+
%vscale_v = vector.splat %vscale : vector<3xindex>
110+
%vscale_i32 = arith.index_cast %vscale_v : vector<3xindex> to vector<3xi32>
111+
%mv1_div = arith.divui %mv1, %vscale_i32 : vector<3xi32>
112+
// ... val / vscale = 123 * 314 * 4 = 154488
113+
// MV: 154488, 154488, 154488
114+
vector.print %mv1_div : vector<3xi32>
115+
116+
// MATVEC 2
117+
// The result of this matvec should be a vector of 0s.
118+
%mv2 = vector.contract #matvec_trait %vector_a, %vector_c, %acc
119+
: vector<3x[4]xi32>, vector<[4]xi32> into vector<3xi32>
120+
// MV: 0, 0, 0
121+
vector.print %mv2 : vector<3xi32>
122+
123+
// MV: SVE: END OF TEST OUTPUT
124+
vector.print str "SVE: END OF TEST OUTPUT"
125+
126+
return
127+
}
128+
33129
func.func @matmul_i32() {
34130
// Setup vector A:
35131
%vector_a = arith.constant dense<123> : vector<3x5xi32>

0 commit comments

Comments
 (0)