@@ -26,29 +26,29 @@ namespace detail {
26
26
template <typename T> constexpr dpas_argument_type dpas_precision_from_type () {
27
27
// TODO: add support for tfloat32 here.
28
28
if constexpr (std::is_same_v<T, sycl::half>)
29
- return dpas_argument_type::FP16 ;
29
+ return dpas_argument_type::fp16 ;
30
30
else if constexpr (std::is_same_v<T,
31
31
sycl::ext::oneapi::experimental::bfloat16>)
32
- return dpas_argument_type::BF16 ;
32
+ return dpas_argument_type::bf16 ;
33
33
else if constexpr (std::is_same_v<T, unsigned char >)
34
- return dpas_argument_type::U8 ;
34
+ return dpas_argument_type::u8 ;
35
35
else if constexpr (__ESIMD_DNS::is_type<T, char , signed char >())
36
- return dpas_argument_type::S8 ;
36
+ return dpas_argument_type::s8 ;
37
37
else
38
38
return dpas_argument_type::Invalid;
39
39
}
40
40
41
41
template <dpas_argument_type T> constexpr int dpas_bitsize_from_precision () {
42
- if constexpr (T == dpas_argument_type::U2 || T == dpas_argument_type::S2 )
42
+ if constexpr (T == dpas_argument_type::u2 || T == dpas_argument_type::s2 )
43
43
return 2 ;
44
- else if constexpr (T == dpas_argument_type::U4 || T == dpas_argument_type::S4 )
44
+ else if constexpr (T == dpas_argument_type::u4 || T == dpas_argument_type::s4 )
45
45
return 4 ;
46
- else if constexpr (T == dpas_argument_type::U8 || T == dpas_argument_type::S8 )
46
+ else if constexpr (T == dpas_argument_type::u8 || T == dpas_argument_type::s8 )
47
47
return 8 ;
48
- else if constexpr (T == dpas_argument_type::BF16 ||
49
- T == dpas_argument_type::FP16 )
48
+ else if constexpr (T == dpas_argument_type::bf16 ||
49
+ T == dpas_argument_type::fp16 )
50
50
return 16 ;
51
- else if constexpr (T == dpas_argument_type::TF32 )
51
+ else if constexpr (T == dpas_argument_type::tf32 )
52
52
return 32 ;
53
53
else
54
54
return -1 ;
@@ -124,8 +124,8 @@ constexpr int verify_parameters_and_deduce_exec_size() {
124
124
static_assert (ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16 ),
125
125
" Execution size must be 8 or 16 for DPAS and 8 for DPASW." );
126
126
127
- if constexpr (APrecision == dpas_argument_type::FP16 ||
128
- BPrecision == dpas_argument_type::FP16 ) {
127
+ if constexpr (APrecision == dpas_argument_type::fp16 ||
128
+ BPrecision == dpas_argument_type::fp16 ) {
129
129
if constexpr (ExecutionSize == 8 ) {
130
130
static_assert (APrecision == BPrecision &&
131
131
__ESIMD_DNS::is_type<T, float >() &&
@@ -141,8 +141,8 @@ constexpr int verify_parameters_and_deduce_exec_size() {
141
141
" Result | C | B | A \n "
142
142
" f, hf | f, hf | hf | hf \n " );
143
143
}
144
- } else if constexpr (APrecision == dpas_argument_type::BF16 ||
145
- BPrecision == dpas_argument_type::BF16 ) {
144
+ } else if constexpr (APrecision == dpas_argument_type::bf16 ||
145
+ BPrecision == dpas_argument_type::bf16 ) {
146
146
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
147
147
if constexpr (ExecutionSize == 8 ) {
148
148
static_assert (APrecision == BPrecision &&
@@ -159,8 +159,8 @@ constexpr int verify_parameters_and_deduce_exec_size() {
159
159
" Result | C | B | A \n "
160
160
" f, bf | f, bf | bf | bf \n " );
161
161
}
162
- } else if constexpr (APrecision == dpas_argument_type::TF32 ||
163
- BPrecision == dpas_argument_type::TF32 ) {
162
+ } else if constexpr (APrecision == dpas_argument_type::tf32 ||
163
+ BPrecision == dpas_argument_type::tf32 ) {
164
164
static_assert (ExecutionSize == 16 ,
165
165
" tf32 type can be used only with ExecutionSize=16" );
166
166
static_assert (APrecision == BPrecision && std::is_same_v<T, float > &&
@@ -169,18 +169,18 @@ constexpr int verify_parameters_and_deduce_exec_size() {
169
169
" Result | C | B | A \n "
170
170
" f | f | tf32 | tf32 \n " );
171
171
} else {
172
- static_assert ((APrecision == dpas_argument_type::U2 ||
173
- APrecision == dpas_argument_type::S2 ||
174
- APrecision == dpas_argument_type::U4 ||
175
- APrecision == dpas_argument_type::S4 ||
176
- APrecision == dpas_argument_type::U8 ||
177
- APrecision == dpas_argument_type::S8 ) &&
178
- (BPrecision == dpas_argument_type::U2 ||
179
- BPrecision == dpas_argument_type::S2 ||
180
- BPrecision == dpas_argument_type::U4 ||
181
- BPrecision == dpas_argument_type::S4 ||
182
- BPrecision == dpas_argument_type::U8 ||
183
- BPrecision == dpas_argument_type::S8 ),
172
+ static_assert ((APrecision == dpas_argument_type::u2 ||
173
+ APrecision == dpas_argument_type::s2 ||
174
+ APrecision == dpas_argument_type::u4 ||
175
+ APrecision == dpas_argument_type::s4 ||
176
+ APrecision == dpas_argument_type::u8 ||
177
+ APrecision == dpas_argument_type::s8 ) &&
178
+ (BPrecision == dpas_argument_type::u2 ||
179
+ BPrecision == dpas_argument_type::s2 ||
180
+ BPrecision == dpas_argument_type::u4 ||
181
+ BPrecision == dpas_argument_type::s4 ||
182
+ BPrecision == dpas_argument_type::u8 ||
183
+ BPrecision == dpas_argument_type::s8 ),
184
184
" Unsupported DPAS types! The supported types are:\n "
185
185
" Result | C | B | A \n "
186
186
" ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n " );
@@ -221,7 +221,8 @@ __ESIMD_NS::simd<T, N> dpas(__ESIMD_NS::simd<CT, N> C,
221
221
__ESIMD_NS::simd<int , ANCasted> ACasted = A.template bit_cast_view <int >();
222
222
__ESIMD_NS::simd<int , BNCasted> BCasted = B.template bit_cast_view <int >();
223
223
using CRawT = typename __ESIMD_NS::simd<CT, N>::raw_element_type;
224
- return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, T,
224
+ using RawT = typename __ESIMD_NS::simd<T, N>::raw_element_type;
225
+ return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, RawT,
225
226
CRawT, int , int , N, BNCasted, ANCasted>(
226
227
C.data (), BCasted.data (), ACasted.data ());
227
228
}
@@ -257,8 +258,9 @@ auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
257
258
258
259
constexpr int Info = (RepeatCount << 24 ) + (SystolicDepth << 16 ) +
259
260
((int )APrecision << 8 ) + (int )BPrecision;
261
+ using RawT = typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
260
262
__ESIMD_NS::simd<T, ResultN> Result =
261
- __esimd_dpas_nosrc0<Info, T , int , int , ResultN, BNCasted, ANCasted>(
263
+ __esimd_dpas_nosrc0<Info, RawT , int , int , ResultN, BNCasted, ANCasted>(
262
264
BCasted.data (), ACasted.data ());
263
265
return Result;
264
266
}
@@ -289,9 +291,10 @@ __ESIMD_NS::simd<T, N> dpasw(__ESIMD_NS::simd<T, N> C,
289
291
__ESIMD_NS::simd<int , ANCasted> ACasted = A.template bit_cast_view <int >();
290
292
__ESIMD_NS::simd<int , BNCasted> BCasted = B.template bit_cast_view <int >();
291
293
294
+ using RawT = typename __ESIMD_NS::simd<T, N>::raw_element_type;
292
295
constexpr int Info = (RepeatCount << 24 ) + (SystolicDepth << 16 ) +
293
296
((int )APrecision << 8 ) + (int )BPrecision;
294
- return __esimd_dpasw<Info, T , int , int , N, BNCasted, ANCasted>(
297
+ return __esimd_dpasw<Info, RawT , int , int , N, BNCasted, ANCasted>(
295
298
C.data (), BCasted.data (), ACasted.data ());
296
299
}
297
300
@@ -325,10 +328,11 @@ auto dpasw(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
325
328
__ESIMD_NS::simd<int , ANCasted> ACasted = A.template bit_cast_view <int >();
326
329
__ESIMD_NS::simd<int , BNCasted> BCasted = B.template bit_cast_view <int >();
327
330
331
+ using RawT = typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
328
332
constexpr int Info = (RepeatCount << 24 ) + (SystolicDepth << 16 ) +
329
333
((int )APrecision << 8 ) + (int )BPrecision;
330
334
__ESIMD_NS::simd<T, ResultN> Result =
331
- __esimd_dpasw_nosrc0<Info, T , int , int , ResultN, BNCasted, ANCasted>(
335
+ __esimd_dpasw_nosrc0<Info, RawT , int , int , ResultN, BNCasted, ANCasted>(
332
336
BCasted.data (), ACasted.data ());
333
337
return Result;
334
338
}
0 commit comments