@@ -22,12 +22,14 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
22
22
}
23
23
24
24
template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
25
- size_t SUB_COLS, class kernel_name , typename OP>
26
- void verify_op_a (const T l, const T r, const float ref, OP op) {
27
- T mat[NUM_ROWS][NUM_COLS];
28
- big_matrix<T, NUM_ROWS, NUM_COLS> big_mat ((T *)&mat);
25
+ size_t SUB_COLS, use Use, layout Layout, size_t VF, class kernel_name ,
26
+ typename OP>
27
+ void verify_op_ab (const T l, const T r, const float ref, OP op) {
28
+ T mat[NUM_ROWS / VF][NUM_COLS * VF];
29
+ big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> big_mat ((T *)&mat);
29
30
30
- buffer<T, 2 > bufMat (big_mat.get_data (), range<2 >(NUM_ROWS, NUM_COLS));
31
+ buffer<T, 2 > bufMat (big_mat.get_data (),
32
+ range<2 >(NUM_ROWS / VF, NUM_COLS * VF));
31
33
32
34
queue q;
33
35
size_t sg_size = get_sg_size<kernel_name>(q);
@@ -47,20 +49,19 @@ void verify_op_a(const T l, const T r, const float ref, OP op) {
47
49
const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
48
50
49
51
sub_group sg = spmd_item.get_sub_group ();
50
- joint_matrix<sub_group, T, use::a, SUB_ROWS, SUB_COLS,
51
- layout::row_major>
52
- sub_mat;
52
+ joint_matrix<sub_group, T, Use, SUB_ROWS, SUB_COLS, Layout> sub_mat;
53
53
joint_matrix_fill (sg, sub_mat, l);
54
54
joint_matrix_apply (sg, sub_mat, [=](T &x) { x = op (x, r); });
55
55
ext::intel::experimental::matrix::joint_matrix_store (
56
56
sg, sub_mat,
57
57
accessMat.template get_multi_ptr <access ::decorated::no>() +
58
- (sg_startx * SUB_ROWS) * NUM_COLS +
59
- sg_starty / sg_size * SUB_COLS,
60
- NUM_COLS);
58
+ (sg_startx * SUB_ROWS / VF ) * NUM_COLS * VF +
59
+ sg_starty / sg_size * SUB_COLS * VF ,
60
+ NUM_COLS * VF );
61
61
}); // parallel for
62
62
}).wait ();
63
- assert_ops_ref<T, NUM_ROWS, NUM_COLS>(bufMat.get_host_access (read_only), ref);
63
+ assert_ops_ref<T, NUM_ROWS / VF, NUM_COLS * VF>(
64
+ bufMat.get_host_access (read_only), ref);
64
65
}
65
66
66
67
template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
@@ -105,37 +106,55 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {
105
106
}
106
107
107
108
// Avoid same kernel name for different types
108
- template <typename T, class name > class ewops_a {};
109
- template <typename T, size_t SROWS, size_t SCOLS> void test_ewops_a () {
110
- std::cout << " Test A " << SROWS << " x" << SCOLS << " \n " ;
109
+ template <typename T, size_t SROWS, size_t SCOLS, use Use, class name >
110
+ class ewops_ab {};
111
+ template <typename T, size_t SROWS, size_t SCOLS, use Use, layout Layout,
112
+ size_t VF>
113
+ void test_ewops_ab () {
114
+ if constexpr (Use == use::a)
115
+ std::cout << " Test A " ;
116
+ else
117
+ std::cout << " Test B " ;
118
+ std::cout << SROWS << " x" << SCOLS << " \n " ;
111
119
112
120
static constexpr size_t NROWS = SROWS * 2 ;
113
121
static constexpr size_t NCOLS = SCOLS * 2 ;
114
122
115
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_add >>(
123
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
124
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_add >>(
116
125
T (5.0 ), T (2.0 ), 7.0 , [](auto l, auto r) { return l + r; });
117
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_sub >>(
126
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
127
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_sub >>(
118
128
T (5.0 ), T (2.0 ), 3.0 , [](auto l, auto r) { return l - r; });
119
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_mul >>(
129
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
130
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_mul >>(
120
131
T (5.0 ), T (2.0 ), 10.0 , [](auto l, auto r) { return l * r; });
121
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_div >>(
132
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
133
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_div >>(
122
134
T (5.0 ), T (2.0 ), 2.5 , [](auto l, auto r) { return l / r; });
123
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_logical >>(
135
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
136
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_logical >>(
124
137
T (5.0 ), T (5.0 ), 5.0 , [](auto l, auto r) { return l == r ? l : T (1.0 ); });
125
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_eq >>(
138
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
139
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_eq >>(
126
140
T (5.0 ), T (4.0 ), 4.0 , [](auto l, auto r) { return l == r ? l : r; });
127
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ne >>(
141
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
142
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_ne >>(
128
143
T (5.0 ), T (5.0 ), 1.0 , [](auto l, auto r) { return l != r ? l : T (1.0 ); });
129
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_gt >>(
144
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
145
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_gt >>(
130
146
T (5.0 ), T (2.0 ), 3.0 ,
131
147
[](auto l, auto r) { return l > r ? T (3.0 ) : T (2.0 ); });
132
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_lt >>(
148
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
149
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_lt >>(
133
150
T (5.0 ), T (2.0 ), 2.0 ,
134
151
[](auto l, auto r) { return l < r ? T (3.0 ) : T (2.0 ); });
135
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ge >>(
152
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
153
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_ge >>(
136
154
T (5.0 ), T (2.0 ), 3.0 ,
137
155
[](auto l, auto r) { return l >= r ? T (3.0 ) : T (2.0 ); });
138
- verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_le >>(
156
+ verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
157
+ ewops_ab<T, SROWS, SCOLS, Use, class ab_le >>(
139
158
T (5.0 ), T (2.0 ), 2.0 ,
140
159
[](auto l, auto r) { return l <= r ? T (3.0 ) : T (2.0 ); });
141
160
}
@@ -194,30 +213,34 @@ int main() {
194
213
.get_info <sycl::ext::oneapi::experimental::info::device::
195
214
matrix_combinations>();
196
215
197
- for (unsigned int i = 0 ; i < combinations.size (); i++) {
198
- if (combinations[i].atype == matrix_type::bf16) {
199
-
200
- if (combinations[i].nsize == 0 ||
201
- (combinations[i].msize == 0 && combinations[i].nsize == 16 )) {
202
- test_ewops_a<bfloat16, 8 , 16 >();
203
- test_ewops_c<float , 8 , 16 >();
204
- }
205
-
206
- if (combinations[i].msize == 16 && combinations[i].nsize == 16 ) {
216
+ for (auto &combination : combinations) {
217
+ if (combination.nsize == 0 ||
218
+ combination.nsize == 16 ) { // Intel AMX or architecture::intel_gpu_pvc
219
+ test_ewops_ab<bfloat16, 1 , 16 , use::a, layout::row_major, 1 >();
220
+ test_ewops_ab<bfloat16, 8 , 16 , use::a, layout::row_major, 1 >();
221
+ test_ewops_ab<bfloat16, 16 , 16 , use::b, layout::ext_intel_packed, 2 >();
222
+ test_ewops_c<float , 1 , 16 >();
223
+ test_ewops_c<float , 8 , 16 >();
224
+
225
+ if (combination.nsize == 16 ) { // architecture::intel_gpu_pvc
226
+ test_ewops_ab<bfloat16, 16 , 16 , use::a, layout::row_major, 1 >();
207
227
test_ewops_c<float , 16 , 16 >();
208
- }
209
-
210
228
// This combination is not currently supported for sub group size = 32 in IGC
211
229
#if (!defined(SG_SZ) || SG_SZ != 32)
212
- if (combinations[i].msize == 32 && combinations[i].nsize == 64 ) {
230
+ test_ewops_ab<bfloat16, 32 , 16 , use::a, layout::row_major, 1 >();
231
+ test_ewops_ab<bfloat16, 16 , 64 , use::b, layout::ext_intel_packed, 2 >();
232
+ test_ewops_c<float , 1 , 64 >();
213
233
test_ewops_c<float , 32 , 64 >();
214
- }
215
234
#endif
216
-
217
- if (combinations[i].nsize == 8 ) {
218
- test_ewops_a<bfloat16, 8 , 16 >();
219
- test_ewops_c<float , 8 , 8 >();
220
235
}
236
+ break ;
237
+ }
238
+
239
+ if (combination.nsize == 8 ) { // architecture::intel_gpu_dg2*
240
+ test_ewops_ab<bfloat16, 8 , 16 , use::a, layout::row_major, 1 >();
241
+ test_ewops_ab<bfloat16, 16 , 8 , use::b, layout::ext_intel_packed, 2 >();
242
+ test_ewops_c<float , 8 , 8 >();
243
+ break ;
221
244
}
222
245
}
223
246
0 commit comments