@@ -228,6 +228,23 @@ void joint_matrix_load(Group g,
228
228
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
229
229
multi_ptr<T2, Space, IsDecorated> src, size_t stride);
230
230
231
+ // Only available when std::is_same_v<T1, std::remove_const_t<T2>>
232
+ template <typename Group, typename T1, typename T2,
233
+ size_t Rows, size_t Cols,
234
+ typename PropertyListT>
235
+ void joint_matrix_load(Group g,
236
+ joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
237
+ annotated_ptr<T2, PropertyListT> src, size_t stride, layout Layout);
238
+
239
+ // Only available when Layout != layout::dynamic
240
+ // and when std::is_same_v<T1, std::remove_const_t<T2>>
241
+ template <typename Group, typename T1, typename T2,
242
+ size_t Rows, size_t Cols, use Use, layout Layout,
243
+ typename PropertyListT>
244
+ void joint_matrix_load(Group g,
245
+ joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
246
+ annotated_ptr<T2, PropertyListT> src, size_t stride);
247
+
231
248
} // namespace sycl::ext::oneapi::experimental::matrix
232
249
```
233
250
@@ -248,6 +265,33 @@ fashion. `stride` describes the number of elements between consecutive
248
265
rows for the row major layout, or between columns for the column major
249
266
layout.
250
267
268
+ The two last overloads of `joint_matrix_load` take
269
+ `sycl::ext::oneapi::experimental::annotated_ptr` as argument instead
270
+ of `sycl::multi_ptr`. The property list associated with the
271
+ `annotated_ptr` argument represents the compile-time constant
272
+ properties for cache control included in the SYCL extenion
273
+ link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls]
274
+ as illustrated in the example below.
275
+
276
+ ```c++
277
+ using syclex = sycl::ext::oneapi::experimental;
278
+ using syclintelex = sycl::ext::intel::experimental;
279
+
280
+ auto A_ptr = syclex::annotated_ptr{A,
281
+ syclex::properties{syclintelex::read_hint<
282
+ syclintelex::cache_control<syclintelex::cache_mode::cached,
283
+ syclex::cache_level::L2>>}};
284
+ q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
285
+ sub_group sg = it.get_sub_group();
286
+ joint_matrix<sub_group, bfloat16, use::a, tM, tK, layout::row_major> tA;
287
+ for (int k = 0; k < K; k += tileK) {
288
+ // User specifies that this load will be cached to L2
289
+ joint_matrix_load(sg, tA, A_ptr + sg_startx * tM * K + k, K);
290
+ ...
291
+ }
292
+ });
293
+ ```
294
+
251
295
==== Store
252
296
```c++
253
297
namespace sycl::ext::oneapi::experimental::matrix {
@@ -259,6 +303,12 @@ void joint_matrix_store(Group g,
259
303
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
260
304
multi_ptr<T2, Space, IsDecorated> dest, size_t stride, layout Layout);
261
305
306
+ template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
307
+ typename PropertyListT>
308
+ void joint_matrix_store(Group g,
309
+ const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
310
+ annotated_ptr<T2, PropertyListT> dest, size_t stride, layout Layout);
311
+
262
312
} // namespace sycl::ext::oneapi::experimental::matrix
263
313
```
264
314
This function stores the data in the accumulator matrix from the
@@ -270,6 +320,11 @@ written in a row (`row_major`), column major (`col_major`)
270
320
fashion. `stride` describes the number of elements between consecutive
271
321
rows for the row major layout, or between columns for the column major layout.
272
322
323
+ The second overload of `joint_matrix_store` takes
324
+ `sycl::ext::oneapi::experimental::annotated_ptr` as argument instead
325
+ of `sycl::multi_ptr`. The property list associated with the
326
+ `annotated_ptr` argument represents the compile-time constant
327
+ properties for cache control included in the SYCL extenion link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls]
273
328
274
329
==== Multiply and Add
275
330
@@ -372,6 +427,47 @@ joint_matrix_apply(sg, C, [=](T &x) {
372
427
});
373
428
```
374
429
430
+ ==== Prefetch
431
+
432
+ ```c++
433
+ namespace sycl::ext::oneapi::experimental::matrix {
434
+
435
+ template <size_t Rows, size_t Cols, typename Group, typename T,
436
+ typename Properties = empty_properties_t>
437
+ void joint_matrix_prefetch(Group g, T* ptr, size_t stride, layout Layout,
438
+ Properties properties = {});
439
+
440
+ } // namespace sycl::ext::oneapi::experimental::matrix
441
+ ```
442
+
443
+ `joint_matrix_prefetch` allows groups of work-items to cooperatively
444
+ prefetch `Rows x Cols` elements in a 2d manner. This function is a group
445
+ function, as defined in Section 4.17.3 of the core SYCL
446
+ specification.
447
+
448
+ The level of cache targeted by `joint_matrix_prefetch` in the last
449
+ argument is specified using the compile-time properties defined in the
450
+ SYCL extension
451
+ link:../../proposed/sycl_ext_oneapi_prefetch.asciidoc[sycl_ext_oneapi_prefetch]
452
+ as illustrated in the example below. When no cache levels are
453
+ specified, the default behavior is to prefetch into the lowest level
454
+ cache (i.e. L1).
455
+
456
+ ```c++
457
+ using syclex = sycl::ext::oneapi::experimental;
458
+
459
+ bfloat16 *memA = malloc_shared<bfloat16>(M*K, q);
460
+ q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
461
+ sub_group sg = it.get_sub_group();
462
+ for (int k = 0; k < K; k += tileK) {
463
+ syclex::joint_matrix_prefetch<tM, tK>(sg, memA + tM * K + tK, K,
464
+ layout::row_major,
465
+ syclex::properties{syclex::prefetch_hint_L2});
466
+ ...
467
+ }
468
+ });
469
+ ```
470
+
375
471
=== Support for Machine Learning Types
376
472
Some devices support special matrix element types that are commonly
377
473
used in machine learning algorithms.
@@ -1035,4 +1131,6 @@ and Intel XMX
1035
1131
|8 |2023-10-05 |Mahmoud Moadeli |Add AMD Matrix Core supported combinations
1036
1132
|9 |2023-11-13 |Dounia Khaldi |Add Granite Rapids Intel AMX
1037
1133
supported combinations
1134
+ |9 |2023-12-04 |Dounia Khaldi |Add prefetch and `annotated_ptr`
1135
+ load/store overloads
1038
1136
|======================
0 commit comments