Skip to content

Commit 04a222f

Browse files
authored
[SYCL][Matrix spec] Add joint_matrix_prefetch and overloads of load/store with annotated_ptr (#11473)
1 parent 16e06ff commit 04a222f

File tree

2 files changed

+115
-2
lines changed

2 files changed

+115
-2
lines changed

sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_intel_matrix.asciidoc

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,14 +148,28 @@ template <typename Group, typename T, size_t Rows, size_t Cols,
148148
access::decorated IsDecorated>
149149
void joint_matrix_store(Group g,
150150
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
151-
multi_ptr<T, Space, IsDecorated> src, size_t stride);
151+
multi_ptr<T, Space, IsDecorated> dest, size_t stride);
152152

153153
template <typename Group, typename T, size_t Rows, size_t Cols,
154154
layout Layout, access::address_space Space,
155155
access::decorated IsDecorated>
156156
void joint_matrix_store(Group g,
157157
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
158-
multi_ptr<T, Space, IsDecorated> src, size_t stride);
158+
multi_ptr<T, Space, IsDecorated> dest, size_t stride);
159+
160+
template <typename Group, typename T, size_t Rows, size_t Cols,
161+
layout Layout, typename PropertyListT>
162+
void joint_matrix_store(Group g,
163+
const joint_matrix<Group, T, use::a, Rows, Cols, Layout> &res,
164+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dest,
165+
size_t stride);
166+
167+
template <typename Group, typename T, size_t Rows, size_t Cols,
168+
layout Layout, typename PropertyListT>
169+
void joint_matrix_store(Group g,
170+
const joint_matrix<Group, T, use::b, Rows, Cols, Layout> &res,
171+
ext::oneapi::experimental::annotated_ptr<T, PropertyListT> dest,
172+
size_t stride);
159173

160174
} // namespace sycl::ext::intel::experimental::matrix
161175
```
@@ -327,6 +341,7 @@ q.submit([&](sycl::handler& cgh) {
327341
});
328342
q.wait();
329343
```
344+
330345
== Revision History
331346

332347
[frame="none",options="header"]

sycl/doc/extensions/experimental/sycl_ext_matrix/sycl_ext_oneapi_matrix.asciidoc

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,23 @@ void joint_matrix_load(Group g,
228228
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
229229
multi_ptr<T2, Space, IsDecorated> src, size_t stride);
230230

231+
// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
232+
template <typename Group, typename T1, typename T2,
233+
size_t Rows, size_t Cols,
234+
typename PropertyListT>
235+
void joint_matrix_load(Group g,
236+
joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
237+
annotated_ptr<T2, PropertyListT> src, size_t stride, layout Layout);
238+
239+
// Only available when Layout != layout::dynamic
240+
// and when std::is_same_v<T1, std::remove_const_t<T2>>
241+
template <typename Group, typename T1, typename T2,
242+
size_t Rows, size_t Cols, use Use, layout Layout,
243+
typename PropertyListT>
244+
void joint_matrix_load(Group g,
245+
joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
246+
annotated_ptr<T2, PropertyListT> src, size_t stride);
247+
231248
} // namespace sycl::ext::oneapi::experimental::matrix
232249
```
233250

@@ -248,6 +265,33 @@ fashion. `stride` describes the number of elements between consecutive
248265
rows for the row major layout, or between columns for the column major
249266
layout.
250267

268+
The two last overloads of `joint_matrix_load` take
269+
`sycl::ext::oneapi::experimental::annotated_ptr` as argument instead
270+
of `sycl::multi_ptr`. The property list associated with the
271+
`annotated_ptr` argument represents the compile-time constant
272+
properties for cache control included in the SYCL extenion
273+
link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls]
274+
as illustrated in the example below.
275+
276+
```c++
277+
using syclex = sycl::ext::oneapi::experimental;
278+
using syclintelex = sycl::ext::intel::experimental;
279+
280+
auto A_ptr = syclex::annotated_ptr{A,
281+
syclex::properties{syclintelex::read_hint<
282+
syclintelex::cache_control<syclintelex::cache_mode::cached,
283+
syclex::cache_level::L2>>}};
284+
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
285+
sub_group sg = it.get_sub_group();
286+
joint_matrix<sub_group, bfloat16, use::a, tM, tK, layout::row_major> tA;
287+
for (int k = 0; k < K; k += tileK) {
288+
// User specifies that this load will be cached to L2
289+
joint_matrix_load(sg, tA, A_ptr + sg_startx * tM * K + k, K);
290+
...
291+
}
292+
});
293+
```
294+
251295
==== Store
252296
```c++
253297
namespace sycl::ext::oneapi::experimental::matrix {
@@ -259,6 +303,12 @@ void joint_matrix_store(Group g,
259303
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
260304
multi_ptr<T2, Space, IsDecorated> dest, size_t stride, layout Layout);
261305

306+
template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
307+
typename PropertyListT>
308+
void joint_matrix_store(Group g,
309+
const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
310+
annotated_ptr<T2, PropertyListT> dest, size_t stride, layout Layout);
311+
262312
} // namespace sycl::ext::oneapi::experimental::matrix
263313
```
264314
This function stores the data in the accumulator matrix from the
@@ -270,6 +320,11 @@ written in a row (`row_major`), column major (`col_major`)
270320
fashion. `stride` describes the number of elements between consecutive
271321
rows for the row major layout, or between columns for the column major layout.
272322

323+
The second overload of `joint_matrix_store` takes
324+
`sycl::ext::oneapi::experimental::annotated_ptr` as argument instead
325+
of `sycl::multi_ptr`. The property list associated with the
326+
`annotated_ptr` argument represents the compile-time constant
327+
properties for cache control included in the SYCL extenion link:../../proposed/sycl_ext_intel_cache_controls.asciidoc[sycl_ext_intel_cache_controls]
273328

274329
==== Multiply and Add
275330

@@ -372,6 +427,47 @@ joint_matrix_apply(sg, C, [=](T &x) {
372427
});
373428
```
374429

430+
==== Prefetch
431+
432+
```c++
433+
namespace sycl::ext::oneapi::experimental::matrix {
434+
435+
template <size_t Rows, size_t Cols, typename Group, typename T,
436+
typename Properties = empty_properties_t>
437+
void joint_matrix_prefetch(Group g, T* ptr, size_t stride, layout Layout,
438+
Properties properties = {});
439+
440+
} // namespace sycl::ext::oneapi::experimental::matrix
441+
```
442+
443+
`joint_matrix_prefetch` allows groups of work-items to cooperatively
444+
prefetch `Rows x Cols` elements in a 2d manner. This function is a group
445+
function, as defined in Section 4.17.3 of the core SYCL
446+
specification.
447+
448+
The level of cache targeted by `joint_matrix_prefetch` in the last
449+
argument is specified using the compile-time properties defined in the
450+
SYCL extension
451+
link:../../proposed/sycl_ext_oneapi_prefetch.asciidoc[sycl_ext_oneapi_prefetch]
452+
as illustrated in the example below. When no cache levels are
453+
specified, the default behavior is to prefetch into the lowest level
454+
cache (i.e. L1).
455+
456+
```c++
457+
using syclex = sycl::ext::oneapi::experimental;
458+
459+
bfloat16 *memA = malloc_shared<bfloat16>(M*K, q);
460+
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
461+
sub_group sg = it.get_sub_group();
462+
for (int k = 0; k < K; k += tileK) {
463+
syclex::joint_matrix_prefetch<tM, tK>(sg, memA + tM * K + tK, K,
464+
layout::row_major,
465+
syclex::properties{syclex::prefetch_hint_L2});
466+
...
467+
}
468+
});
469+
```
470+
375471
=== Support for Machine Learning Types
376472
Some devices support special matrix element types that are commonly
377473
used in machine learning algorithms.
@@ -1035,4 +1131,6 @@ and Intel XMX
10351131
|8 |2023-10-05 |Mahmoud Moadeli |Add AMD Matrix Core supported combinations
10361132
|9 |2023-11-13 |Dounia Khaldi |Add Granite Rapids Intel AMX
10371133
supported combinations
1134+
|9 |2023-12-04 |Dounia Khaldi |Add prefetch and `annotated_ptr`
1135+
load/store overloads
10381136
|======================

0 commit comments

Comments
 (0)