Skip to content

Commit cfd0d41

Browse files
authored
[SYCL][Joint Matrix] Test stores A and B for bfloat16 16x16x16, 32x64x16, 1x64x16 (#13572)
1 parent 3756fd1 commit cfd0d41

File tree

1 file changed: +67 −44 lines changed

sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp

Lines changed: 67 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> mat,
2222
}
2323

2424
template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
25-
size_t SUB_COLS, class kernel_name, typename OP>
26-
void verify_op_a(const T l, const T r, const float ref, OP op) {
27-
T mat[NUM_ROWS][NUM_COLS];
28-
big_matrix<T, NUM_ROWS, NUM_COLS> big_mat((T *)&mat);
25+
size_t SUB_COLS, use Use, layout Layout, size_t VF, class kernel_name,
26+
typename OP>
27+
void verify_op_ab(const T l, const T r, const float ref, OP op) {
28+
T mat[NUM_ROWS / VF][NUM_COLS * VF];
29+
big_matrix<T, NUM_ROWS / VF, NUM_COLS * VF> big_mat((T *)&mat);
2930

30-
buffer<T, 2> bufMat(big_mat.get_data(), range<2>(NUM_ROWS, NUM_COLS));
31+
buffer<T, 2> bufMat(big_mat.get_data(),
32+
range<2>(NUM_ROWS / VF, NUM_COLS * VF));
3133

3234
queue q;
3335
size_t sg_size = get_sg_size<kernel_name>(q);
@@ -47,20 +49,19 @@ void verify_op_a(const T l, const T r, const float ref, OP op) {
4749
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
4850

4951
sub_group sg = spmd_item.get_sub_group();
50-
joint_matrix<sub_group, T, use::a, SUB_ROWS, SUB_COLS,
51-
layout::row_major>
52-
sub_mat;
52+
joint_matrix<sub_group, T, Use, SUB_ROWS, SUB_COLS, Layout> sub_mat;
5353
joint_matrix_fill(sg, sub_mat, l);
5454
joint_matrix_apply(sg, sub_mat, [=](T &x) { x = op(x, r); });
5555
ext::intel::experimental::matrix::joint_matrix_store(
5656
sg, sub_mat,
5757
accessMat.template get_multi_ptr<access::decorated::no>() +
58-
(sg_startx * SUB_ROWS) * NUM_COLS +
59-
sg_starty / sg_size * SUB_COLS,
60-
NUM_COLS);
58+
(sg_startx * SUB_ROWS / VF) * NUM_COLS * VF +
59+
sg_starty / sg_size * SUB_COLS * VF,
60+
NUM_COLS * VF);
6161
}); // parallel for
6262
}).wait();
63-
assert_ops_ref<T, NUM_ROWS, NUM_COLS>(bufMat.get_host_access(read_only), ref);
63+
assert_ops_ref<T, NUM_ROWS / VF, NUM_COLS * VF>(
64+
bufMat.get_host_access(read_only), ref);
6465
}
6566

6667
template <typename T, size_t NUM_ROWS, size_t NUM_COLS, size_t SUB_ROWS,
@@ -105,37 +106,55 @@ void verify_op_c(const T l, const T r, const float ref, OP op) {
105106
}
106107

107108
// Avoid same kernel name for different types
108-
template <typename T, class name> class ewops_a {};
109-
template <typename T, size_t SROWS, size_t SCOLS> void test_ewops_a() {
110-
std::cout << "Test A " << SROWS << "x" << SCOLS << "\n";
109+
template <typename T, size_t SROWS, size_t SCOLS, use Use, class name>
110+
class ewops_ab {};
111+
template <typename T, size_t SROWS, size_t SCOLS, use Use, layout Layout,
112+
size_t VF>
113+
void test_ewops_ab() {
114+
if constexpr (Use == use::a)
115+
std::cout << "Test A ";
116+
else
117+
std::cout << "Test B ";
118+
std::cout << SROWS << "x" << SCOLS << "\n";
111119

112120
static constexpr size_t NROWS = SROWS * 2;
113121
static constexpr size_t NCOLS = SCOLS * 2;
114122

115-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_add>>(
123+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
124+
ewops_ab<T, SROWS, SCOLS, Use, class ab_add>>(
116125
T(5.0), T(2.0), 7.0, [](auto l, auto r) { return l + r; });
117-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_sub>>(
126+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
127+
ewops_ab<T, SROWS, SCOLS, Use, class ab_sub>>(
118128
T(5.0), T(2.0), 3.0, [](auto l, auto r) { return l - r; });
119-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_mul>>(
129+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
130+
ewops_ab<T, SROWS, SCOLS, Use, class ab_mul>>(
120131
T(5.0), T(2.0), 10.0, [](auto l, auto r) { return l * r; });
121-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_div>>(
132+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
133+
ewops_ab<T, SROWS, SCOLS, Use, class ab_div>>(
122134
T(5.0), T(2.0), 2.5, [](auto l, auto r) { return l / r; });
123-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_logical>>(
135+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
136+
ewops_ab<T, SROWS, SCOLS, Use, class ab_logical>>(
124137
T(5.0), T(5.0), 5.0, [](auto l, auto r) { return l == r ? l : T(1.0); });
125-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_eq>>(
138+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
139+
ewops_ab<T, SROWS, SCOLS, Use, class ab_eq>>(
126140
T(5.0), T(4.0), 4.0, [](auto l, auto r) { return l == r ? l : r; });
127-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ne>>(
141+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
142+
ewops_ab<T, SROWS, SCOLS, Use, class ab_ne>>(
128143
T(5.0), T(5.0), 1.0, [](auto l, auto r) { return l != r ? l : T(1.0); });
129-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_gt>>(
144+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
145+
ewops_ab<T, SROWS, SCOLS, Use, class ab_gt>>(
130146
T(5.0), T(2.0), 3.0,
131147
[](auto l, auto r) { return l > r ? T(3.0) : T(2.0); });
132-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_lt>>(
148+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
149+
ewops_ab<T, SROWS, SCOLS, Use, class ab_lt>>(
133150
T(5.0), T(2.0), 2.0,
134151
[](auto l, auto r) { return l < r ? T(3.0) : T(2.0); });
135-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_ge>>(
152+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
153+
ewops_ab<T, SROWS, SCOLS, Use, class ab_ge>>(
136154
T(5.0), T(2.0), 3.0,
137155
[](auto l, auto r) { return l >= r ? T(3.0) : T(2.0); });
138-
verify_op_a<T, NROWS, NCOLS, SROWS, SCOLS, ewops_a<T, class a_le>>(
156+
verify_op_ab<T, NROWS, NCOLS, SROWS, SCOLS, Use, Layout, VF,
157+
ewops_ab<T, SROWS, SCOLS, Use, class ab_le>>(
139158
T(5.0), T(2.0), 2.0,
140159
[](auto l, auto r) { return l <= r ? T(3.0) : T(2.0); });
141160
}
@@ -194,30 +213,34 @@ int main() {
194213
.get_info<sycl::ext::oneapi::experimental::info::device::
195214
matrix_combinations>();
196215

197-
for (unsigned int i = 0; i < combinations.size(); i++) {
198-
if (combinations[i].atype == matrix_type::bf16) {
199-
200-
if (combinations[i].nsize == 0 ||
201-
(combinations[i].msize == 0 && combinations[i].nsize == 16)) {
202-
test_ewops_a<bfloat16, 8, 16>();
203-
test_ewops_c<float, 8, 16>();
204-
}
205-
206-
if (combinations[i].msize == 16 && combinations[i].nsize == 16) {
216+
for (auto &combination : combinations) {
217+
if (combination.nsize == 0 ||
218+
combination.nsize == 16) { // Intel AMX or architecture::intel_gpu_pvc
219+
test_ewops_ab<bfloat16, 1, 16, use::a, layout::row_major, 1>();
220+
test_ewops_ab<bfloat16, 8, 16, use::a, layout::row_major, 1>();
221+
test_ewops_ab<bfloat16, 16, 16, use::b, layout::ext_intel_packed, 2>();
222+
test_ewops_c<float, 1, 16>();
223+
test_ewops_c<float, 8, 16>();
224+
225+
if (combination.nsize == 16) { // architecture::intel_gpu_pvc
226+
test_ewops_ab<bfloat16, 16, 16, use::a, layout::row_major, 1>();
207227
test_ewops_c<float, 16, 16>();
208-
}
209-
210228
// This combination is not currently supported for sub group size = 32 in IGC
211229
#if (!defined(SG_SZ) || SG_SZ != 32)
212-
if (combinations[i].msize == 32 && combinations[i].nsize == 64) {
230+
test_ewops_ab<bfloat16, 32, 16, use::a, layout::row_major, 1>();
231+
test_ewops_ab<bfloat16, 16, 64, use::b, layout::ext_intel_packed, 2>();
232+
test_ewops_c<float, 1, 64>();
213233
test_ewops_c<float, 32, 64>();
214-
}
215234
#endif
216-
217-
if (combinations[i].nsize == 8) {
218-
test_ewops_a<bfloat16, 8, 16>();
219-
test_ewops_c<float, 8, 8>();
220235
}
236+
break;
237+
}
238+
239+
if (combination.nsize == 8) { // architecture::intel_gpu_dg2*
240+
test_ewops_ab<bfloat16, 8, 16, use::a, layout::row_major, 1>();
241+
test_ewops_ab<bfloat16, 16, 8, use::b, layout::ext_intel_packed, 2>();
242+
test_ewops_c<float, 8, 8>();
243+
break;
221244
}
222245
}
223246

0 commit comments

Comments
 (0)