Skip to content

Commit b10c739

Browse files
yubingex007-a11yDoyleLi
authored andcommitted
[SYCL][Matrix] Add support for more types to joint_matrix_mad (intel#4486)
With this patch, more cases for C=A*B+C can be realized: 1. A is uint8, B is int8 2. A is int8, B is uint8 3. A is uint8, B is uint8
1 parent fe0ec06 commit b10c739

File tree

6 files changed

+58
-11
lines changed

6 files changed

+58
-11
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,42 @@ __spirv_JointMatrixMadINTEL(
5050
__spv::__spirv_JointMatrixINTEL<T2, M, N, LC, S> *C,
5151
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
5252

53+
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
54+
std::size_t N, __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
55+
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
56+
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
57+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
58+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *
59+
__spirv_JointMatrixUUMadINTEL(
60+
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
61+
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S> *B,
62+
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
63+
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
64+
65+
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
66+
std::size_t N, __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
67+
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
68+
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
69+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
70+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *
71+
__spirv_JointMatrixUSMadINTEL(
72+
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
73+
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S> *B,
74+
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
75+
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
76+
77+
template <typename T1, typename T2, typename T3, std::size_t M, std::size_t K,
78+
std::size_t N, __spv::MatrixLayout LA = __spv::MatrixLayout::RowMajor,
79+
__spv::MatrixLayout LB = __spv::MatrixLayout::RowMajor,
80+
__spv::MatrixLayout LC = __spv::MatrixLayout::RowMajor,
81+
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
82+
extern SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *
83+
__spirv_JointMatrixSUMadINTEL(
84+
__spv::__spirv_JointMatrixINTEL<T1, M, K, LA, S> *A,
85+
__spv::__spirv_JointMatrixINTEL<T2, K, N, LB, S> *B,
86+
__spv::__spirv_JointMatrixINTEL<T3, M, N, LC, S> *C,
87+
__spv::Scope::Flag Sc = __spv::Scope::Flag::Subgroup);
88+
5389
#ifndef __SPIRV_BUILTIN_DECLARATIONS__
5490
#error \
5591
"SPIR-V built-ins are not available. Please set -fdeclare-spirv-builtins flag."

sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -160,16 +160,27 @@ joint_matrix_store(Group sg,
160160
#endif // __SYCL_DEVICE_ONLY__
161161
}
162162

163-
template <typename Group, typename T1, typename T2, size_t M, size_t K,
164-
size_t N, matrix_layout LayoutA, matrix_layout LayoutB,
163+
template <typename Group, typename T1, typename T2, typename T3, size_t M,
164+
size_t K, size_t N, matrix_layout LayoutA, matrix_layout LayoutB,
165165
matrix_layout LayoutC>
166-
inline __SYCL_ALWAYS_INLINE joint_matrix<T2, M, N, LayoutC, Group>
166+
inline __SYCL_ALWAYS_INLINE joint_matrix<T3, M, N, LayoutC, Group>
167167
joint_matrix_mad(Group sg, joint_matrix<T1, M, K, LayoutA, Group> &mA,
168-
joint_matrix<T1, K, N, LayoutB, Group> &mB,
169-
joint_matrix<T2, M, N, LayoutC, Group> &mC) {
168+
joint_matrix<T2, K, N, LayoutB, Group> &mB,
169+
joint_matrix<T3, M, N, LayoutC, Group> &mC) {
170170
#ifdef __SYCL_DEVICE_ONLY__
171-
joint_matrix<T2, M, N, LayoutC, Group> res(sg);
172-
res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
171+
joint_matrix<T3, M, N, LayoutC, Group> res(sg);
172+
if constexpr (std::is_same<T1, uint16_t>::value &&
173+
std::is_same<T2, uint16_t>::value &&
174+
std::is_same<T3, float>::value)
175+
res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
176+
else if constexpr (std::is_unsigned<T1>::value && std::is_unsigned<T2>::value)
177+
res.spvm = __spirv_JointMatrixUUMadINTEL(mA.spvm, mB.spvm, mC.spvm);
178+
else if constexpr (std::is_signed<T1>::value && std::is_unsigned<T2>::value)
179+
res.spvm = __spirv_JointMatrixSUMadINTEL(mA.spvm, mB.spvm, mC.spvm);
180+
else if constexpr (std::is_unsigned<T1>::value && std::is_signed<T2>::value)
181+
res.spvm = __spirv_JointMatrixUSMadINTEL(mA.spvm, mB.spvm, mC.spvm);
182+
else
183+
res.spvm = __spirv_JointMatrixMadINTEL(mA.spvm, mB.spvm, mC.spvm);
173184
return res;
174185
#else
175186
(void)sg;

sycl/test/matrix/matrix-bf16-test-SG-16.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
22
#include <CL/sycl.hpp>
33
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
44
#include <iostream>

sycl/test/matrix/matrix-bf16-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
22
#include <CL/sycl.hpp>
33
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
44
#include <iostream>

sycl/test/matrix/matrix-int8-test-SG-16.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
22
#include <CL/sycl.hpp>
33
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
44
#include <iostream>

sycl/test/matrix/matrix-int8-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: %clangxx -march=sapphirerapids -fsycl -O2 %s -o %t.out
1+
// RUN: %clangxx -fsycl -O2 %s -o %t.out
22
#include <CL/sycl.hpp>
33
#if (SYCL_EXT_ONEAPI_MATRIX == 2)
44
#include <iostream>

0 commit comments

Comments
 (0)