Skip to content

Commit 8f0960e

Browse files
authored
[ESIMD] Fix DPAS implementations accepting/returning fp16/bf16 (#6891)
Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent cc51a64 commit 8f0960e

File tree

5 files changed, with 413 additions and 119 deletions.

sycl/include/sycl/ext/intel/esimd/detail/defines_elementary.hpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,8 @@
5353
#define __ESIMD_EMU_DNS sycl::ext::intel::esimd::emu::detail
5454
#define __ESIMD_ENS sycl::ext::intel::experimental::esimd
5555
#define __ESIMD_EDNS sycl::ext::intel::experimental::esimd::detail
56+
#define __ESIMD_XMX_NS sycl::ext::intel::esimd::xmx
57+
#define __ESIMD_XMX_DNS sycl::ext::intel::esimd::xmx::detail
5658

5759
#define __ESIMD_QUOTE1(m) #m
5860
#define __ESIMD_QUOTE(m) __ESIMD_QUOTE1(m)

sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp

Lines changed: 36 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -26,29 +26,29 @@ namespace detail {
2626
template <typename T> constexpr dpas_argument_type dpas_precision_from_type() {
2727
// TODO: add support for tfloat32 here.
2828
if constexpr (std::is_same_v<T, sycl::half>)
29-
return dpas_argument_type::FP16;
29+
return dpas_argument_type::fp16;
3030
else if constexpr (std::is_same_v<T,
3131
sycl::ext::oneapi::experimental::bfloat16>)
32-
return dpas_argument_type::BF16;
32+
return dpas_argument_type::bf16;
3333
else if constexpr (std::is_same_v<T, unsigned char>)
34-
return dpas_argument_type::U8;
34+
return dpas_argument_type::u8;
3535
else if constexpr (__ESIMD_DNS::is_type<T, char, signed char>())
36-
return dpas_argument_type::S8;
36+
return dpas_argument_type::s8;
3737
else
3838
return dpas_argument_type::Invalid;
3939
}
4040

4141
template <dpas_argument_type T> constexpr int dpas_bitsize_from_precision() {
42-
if constexpr (T == dpas_argument_type::U2 || T == dpas_argument_type::S2)
42+
if constexpr (T == dpas_argument_type::u2 || T == dpas_argument_type::s2)
4343
return 2;
44-
else if constexpr (T == dpas_argument_type::U4 || T == dpas_argument_type::S4)
44+
else if constexpr (T == dpas_argument_type::u4 || T == dpas_argument_type::s4)
4545
return 4;
46-
else if constexpr (T == dpas_argument_type::U8 || T == dpas_argument_type::S8)
46+
else if constexpr (T == dpas_argument_type::u8 || T == dpas_argument_type::s8)
4747
return 8;
48-
else if constexpr (T == dpas_argument_type::BF16 ||
49-
T == dpas_argument_type::FP16)
48+
else if constexpr (T == dpas_argument_type::bf16 ||
49+
T == dpas_argument_type::fp16)
5050
return 16;
51-
else if constexpr (T == dpas_argument_type::TF32)
51+
else if constexpr (T == dpas_argument_type::tf32)
5252
return 32;
5353
else
5454
return -1;
@@ -124,8 +124,8 @@ constexpr int verify_parameters_and_deduce_exec_size() {
124124
static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16),
125125
"Execution size must be 8 or 16 for DPAS and 8 for DPASW.");
126126

127-
if constexpr (APrecision == dpas_argument_type::FP16 ||
128-
BPrecision == dpas_argument_type::FP16) {
127+
if constexpr (APrecision == dpas_argument_type::fp16 ||
128+
BPrecision == dpas_argument_type::fp16) {
129129
if constexpr (ExecutionSize == 8) {
130130
static_assert(APrecision == BPrecision &&
131131
__ESIMD_DNS::is_type<T, float>() &&
@@ -141,8 +141,8 @@ constexpr int verify_parameters_and_deduce_exec_size() {
141141
" Result | C | B | A \n"
142142
" f, hf | f, hf | hf | hf \n");
143143
}
144-
} else if constexpr (APrecision == dpas_argument_type::BF16 ||
145-
BPrecision == dpas_argument_type::BF16) {
144+
} else if constexpr (APrecision == dpas_argument_type::bf16 ||
145+
BPrecision == dpas_argument_type::bf16) {
146146
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
147147
if constexpr (ExecutionSize == 8) {
148148
static_assert(APrecision == BPrecision &&
@@ -159,8 +159,8 @@ constexpr int verify_parameters_and_deduce_exec_size() {
159159
" Result | C | B | A \n"
160160
" f, bf | f, bf | bf | bf \n");
161161
}
162-
} else if constexpr (APrecision == dpas_argument_type::TF32 ||
163-
BPrecision == dpas_argument_type::TF32) {
162+
} else if constexpr (APrecision == dpas_argument_type::tf32 ||
163+
BPrecision == dpas_argument_type::tf32) {
164164
static_assert(ExecutionSize == 16,
165165
"tf32 type can be used only with ExecutionSize=16");
166166
static_assert(APrecision == BPrecision && std::is_same_v<T, float> &&
@@ -169,18 +169,18 @@ constexpr int verify_parameters_and_deduce_exec_size() {
169169
" Result | C | B | A \n"
170170
" f | f | tf32 | tf32 \n");
171171
} else {
172-
static_assert((APrecision == dpas_argument_type::U2 ||
173-
APrecision == dpas_argument_type::S2 ||
174-
APrecision == dpas_argument_type::U4 ||
175-
APrecision == dpas_argument_type::S4 ||
176-
APrecision == dpas_argument_type::U8 ||
177-
APrecision == dpas_argument_type::S8) &&
178-
(BPrecision == dpas_argument_type::U2 ||
179-
BPrecision == dpas_argument_type::S2 ||
180-
BPrecision == dpas_argument_type::U4 ||
181-
BPrecision == dpas_argument_type::S4 ||
182-
BPrecision == dpas_argument_type::U8 ||
183-
BPrecision == dpas_argument_type::S8),
172+
static_assert((APrecision == dpas_argument_type::u2 ||
173+
APrecision == dpas_argument_type::s2 ||
174+
APrecision == dpas_argument_type::u4 ||
175+
APrecision == dpas_argument_type::s4 ||
176+
APrecision == dpas_argument_type::u8 ||
177+
APrecision == dpas_argument_type::s8) &&
178+
(BPrecision == dpas_argument_type::u2 ||
179+
BPrecision == dpas_argument_type::s2 ||
180+
BPrecision == dpas_argument_type::u4 ||
181+
BPrecision == dpas_argument_type::s4 ||
182+
BPrecision == dpas_argument_type::u8 ||
183+
BPrecision == dpas_argument_type::s8),
184184
"Unsupported DPAS types! The supported types are:\n"
185185
" Result | C | B | A \n"
186186
" ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n");
@@ -221,7 +221,8 @@ __ESIMD_NS::simd<T, N> dpas(__ESIMD_NS::simd<CT, N> C,
221221
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
222222
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();
223223
using CRawT = typename __ESIMD_NS::simd<CT, N>::raw_element_type;
224-
return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, T,
224+
using RawT = typename __ESIMD_NS::simd<T, N>::raw_element_type;
225+
return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, RawT,
225226
CRawT, int, int, N, BNCasted, ANCasted>(
226227
C.data(), BCasted.data(), ACasted.data());
227228
}
@@ -257,8 +258,9 @@ auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
257258

258259
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
259260
((int)APrecision << 8) + (int)BPrecision;
261+
using RawT = typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
260262
__ESIMD_NS::simd<T, ResultN> Result =
261-
__esimd_dpas_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
263+
__esimd_dpas_nosrc0<Info, RawT, int, int, ResultN, BNCasted, ANCasted>(
262264
BCasted.data(), ACasted.data());
263265
return Result;
264266
}
@@ -289,9 +291,10 @@ __ESIMD_NS::simd<T, N> dpasw(__ESIMD_NS::simd<T, N> C,
289291
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
290292
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();
291293

294+
using RawT = typename __ESIMD_NS::simd<T, N>::raw_element_type;
292295
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
293296
((int)APrecision << 8) + (int)BPrecision;
294-
return __esimd_dpasw<Info, T, int, int, N, BNCasted, ANCasted>(
297+
return __esimd_dpasw<Info, RawT, int, int, N, BNCasted, ANCasted>(
295298
C.data(), BCasted.data(), ACasted.data());
296299
}
297300

@@ -325,10 +328,11 @@ auto dpasw(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
325328
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
326329
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();
327330

331+
using RawT = typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
328332
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
329333
((int)APrecision << 8) + (int)BPrecision;
330334
__ESIMD_NS::simd<T, ResultN> Result =
331-
__esimd_dpasw_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
335+
__esimd_dpasw_nosrc0<Info, RawT, int, int, ResultN, BNCasted, ANCasted>(
332336
BCasted.data(), ACasted.data());
333337
return Result;
334338
}

sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp

Lines changed: 55 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -420,25 +420,25 @@ __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
420420
}
421421

422422
inline constexpr __ESIMD_NS::uint
423-
__esimd_dpas_bits_precision(__ESIMD_ENS::argument_type precisionType) {
424-
return precisionType == __ESIMD_ENS::argument_type::TF32 ? 32
425-
: precisionType == __ESIMD_ENS::argument_type::BF16 ||
426-
precisionType == __ESIMD_ENS::argument_type::FP16
423+
__esimd_dpas_bits_precision(__ESIMD_XMX_NS::dpas_argument_type precisionType) {
424+
return precisionType == __ESIMD_XMX_NS::dpas_argument_type::tf32 ? 32
425+
: precisionType == __ESIMD_XMX_NS::dpas_argument_type::bf16 ||
426+
precisionType == __ESIMD_XMX_NS::dpas_argument_type::fp16
427427
? 16
428-
: precisionType == __ESIMD_ENS::argument_type::S8 ||
429-
precisionType == __ESIMD_ENS::argument_type::U8
428+
: precisionType == __ESIMD_XMX_NS::dpas_argument_type::s8 ||
429+
precisionType == __ESIMD_XMX_NS::dpas_argument_type::u8
430430
? 8
431-
: precisionType == __ESIMD_ENS::argument_type::S4 ||
432-
precisionType == __ESIMD_ENS::argument_type::U4
431+
: precisionType == __ESIMD_XMX_NS::dpas_argument_type::s4 ||
432+
precisionType == __ESIMD_XMX_NS::dpas_argument_type::u4
433433
? 4
434-
: precisionType == __ESIMD_ENS::argument_type::S2 ||
435-
precisionType == __ESIMD_ENS::argument_type::U2
434+
: precisionType == __ESIMD_XMX_NS::dpas_argument_type::s2 ||
435+
precisionType == __ESIMD_XMX_NS::dpas_argument_type::u2
436436
? 2
437437
: 1;
438438
}
439439

440-
template <__ESIMD_ENS::argument_type src1_precision,
441-
__ESIMD_ENS::argument_type src2_precision, int systolic_depth,
440+
template <__ESIMD_XMX_NS::dpas_argument_type src1_precision,
441+
__ESIMD_XMX_NS::dpas_argument_type src2_precision, int systolic_depth,
442442
int repeat_count, typename RT, typename T0, typename T1, typename T2,
443443
__ESIMD_NS::uint SZ, __ESIMD_NS::uint N1, __ESIMD_NS::uint N2>
444444
inline __ESIMD_DNS::vector_type_t<RT, SZ>
@@ -463,16 +463,16 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
463463
std::min(32 / max_el_bits, static_cast<__ESIMD_NS::uint>(8));
464464

465465
uint32_t src1_signed =
466-
src1_precision == __ESIMD_ENS::argument_type::S2 ||
467-
src1_precision == __ESIMD_ENS::argument_type::S4 ||
468-
src1_precision == __ESIMD_ENS::argument_type::S8
466+
src1_precision == __ESIMD_XMX_NS::dpas_argument_type::s2 ||
467+
src1_precision == __ESIMD_XMX_NS::dpas_argument_type::s4 ||
468+
src1_precision == __ESIMD_XMX_NS::dpas_argument_type::s8
469469
? 1
470470
: 0;
471471

472472
uint32_t src2_signed =
473-
src2_precision == __ESIMD_ENS::argument_type::S2 ||
474-
src2_precision == __ESIMD_ENS::argument_type::S4 ||
475-
src2_precision == __ESIMD_ENS::argument_type::S8
473+
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::s2 ||
474+
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::s4 ||
475+
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::s8
476476
? 1
477477
: 0;
478478

@@ -484,19 +484,31 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
484484
constexpr bool isPvc = SIMDSize == 16;
485485

486486
constexpr bool
487-
pvcHfDest = isPvc && std::is_same<RT, __ESIMD_EMU_DNS::half>::value,
488-
pvcBfDest = isPvc && std::is_same<RT, short>::value,
487+
pvcHfDest = isPvc && std::is_same_v<RT, unsigned short> &&
488+
src1_precision == __ESIMD_ENS::argument_type::FP16 &&
489+
src2_precision == __ESIMD_ENS::argument_type::FP16,
490+
pvcHfSrc0 = isPvc && std::is_same_v<T0, unsigned short> &&
491+
src1_precision == __ESIMD_ENS::argument_type::FP16 &&
492+
src2_precision == __ESIMD_ENS::argument_type::FP16,
493+
pvcBfDest = isPvc && std::is_same_v<RT, unsigned short> &&
494+
src1_precision == __ESIMD_ENS::argument_type::BF16 &&
495+
src2_precision == __ESIMD_ENS::argument_type::BF16,
496+
pvcBfSrc0 = isPvc && std::is_same_v<T0, unsigned short> &&
497+
src1_precision == __ESIMD_ENS::argument_type::BF16 &&
498+
src2_precision == __ESIMD_ENS::argument_type::BF16,
489499
pvcBfOrHfDest = pvcBfDest || pvcHfDest,
490500

491-
pvcBfDestChecks = pvcBfDest &&
492-
src1_precision == __ESIMD_ENS::argument_type::BF16 &&
493-
src2_precision == __ESIMD_ENS::argument_type::BF16,
501+
pvcBfDestChecks =
502+
pvcBfDest &&
503+
src1_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16 &&
504+
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16,
494505

495506
pvcHfDestChecks =
496-
pvcHfDest && ((src1_precision == __ESIMD_ENS::argument_type::FP16 &&
497-
src2_precision == __ESIMD_ENS::argument_type::FP16) ||
498-
(src1_precision == __ESIMD_ENS::argument_type::BF16 &&
499-
src2_precision == __ESIMD_ENS::argument_type::BF16)),
507+
pvcHfDest &&
508+
((src1_precision == __ESIMD_XMX_NS::dpas_argument_type::fp16 &&
509+
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::fp16) ||
510+
(src1_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16 &&
511+
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16)),
500512

501513
destTypeChk =
502514
(!pvcBfOrHfDest && __ESIMD_EMU_DNS::is_fp_or_dword_type<RT>::value) ||
@@ -547,9 +559,11 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
547559
if (src0 != nullptr) {
548560
auto src0El = src0[0][r * SIMDSize + n];
549561

550-
if (pvcBfDest) {
562+
if (pvcBfSrc0) {
551563
const auto tmp = (uint32_t)(src0El) << 16;
552564
simdAcc[n] = reinterpret_cast<const TmpAccEl &>(tmp);
565+
} else if (pvcHfSrc0) {
566+
simdAcc[n] = reinterpret_cast<const __ESIMD_EMU_DNS::half &>(src0El);
553567
} else
554568
simdAcc[n] = src0El;
555569
} else
@@ -566,7 +580,7 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
566580
p = d + (s % src1_ops_per_dword) * ops_per_chan;
567581
uint32_t extension_temp = false;
568582

569-
if (src2_precision == __ESIMD_ENS::argument_type::BF16) {
583+
if (src2_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16) {
570584
const auto s1 =
571585
extract<uint32_t>(src1_el_bits, p * src1_el_bits,
572586
src1[U * SIMDSize + n], extension_temp)
@@ -577,7 +591,8 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
577591
<< 16;
578592
simdAcc[n] += reinterpret_cast<const float &>(s2) *
579593
reinterpret_cast<const float &>(s1);
580-
} else if (src2_precision == __ESIMD_ENS::argument_type::FP16) {
594+
} else if (src2_precision ==
595+
__ESIMD_XMX_NS::dpas_argument_type::fp16) {
581596
const auto s1 =
582597
extract<short>(src1_el_bits, p * src1_el_bits,
583598
src1[U * SIMDSize + n], extension_temp);
@@ -615,6 +630,10 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
615630
}
616631
retv[r * SIMDSize + n] =
617632
static_cast<short>(reinterpret_cast<uint32_t &>(tmpUint) >> 16);
633+
} else if constexpr (pvcHfDest) {
634+
retv[r * SIMDSize + n] =
635+
__ESIMD_EMU_DNS::satur<sycl::half>::saturate<TmpAccEl>(simdAcc[n],
636+
sat1);
618637
} else
619638
retv[r * SIMDSize + n] =
620639
__ESIMD_EMU_DNS::satur<RT>::template saturate<TmpAccEl>(simdAcc[n],
@@ -627,8 +646,8 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
627646
}
628647
#endif // #ifndef __SYCL_DEVICE_ONLY__
629648

630-
template <__ESIMD_ENS::argument_type src1_precision,
631-
__ESIMD_ENS::argument_type src2_precision, int systolic_depth,
649+
template <__ESIMD_XMX_NS::dpas_argument_type src1_precision,
650+
__ESIMD_XMX_NS::dpas_argument_type src2_precision, int systolic_depth,
632651
int repeat_count, typename T, typename T0, typename T1, typename T2,
633652
int N, int N1, int N2, int res_sign = std::is_signed_v<T>,
634653
int acc_sign = std::is_signed_v<T0>>
@@ -654,10 +673,10 @@ __esimd_dpas_nosrc0(__ESIMD_DNS::vector_type_t<T1, N1> src1,
654673
;
655674
#else // !__SYCL_DEVICE_ONLY__
656675
{
657-
constexpr __ESIMD_ENS::argument_type src1_precision =
658-
static_cast<__ESIMD_ENS::argument_type>(Info & 0xff);
659-
constexpr __ESIMD_ENS::argument_type src2_precision =
660-
static_cast<__ESIMD_ENS::argument_type>((Info >> 8) & 0xff);
676+
constexpr __ESIMD_XMX_NS::dpas_argument_type src1_precision =
677+
static_cast<__ESIMD_XMX_NS::dpas_argument_type>(Info & 0xff);
678+
constexpr __ESIMD_XMX_NS::dpas_argument_type src2_precision =
679+
static_cast<__ESIMD_XMX_NS::dpas_argument_type>((Info >> 8) & 0xff);
661680
constexpr int systolic_depth = (Info >> 16) & 0xff;
662681
constexpr int repeat_count = (Info >> 24) & 0xff;
663682
return __esimd_dpas_inner<src1_precision, src2_precision, systolic_depth,

sycl/include/sycl/ext/intel/experimental/esimd/math.hpp

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1796,9 +1796,6 @@ __SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpasw()")
17961796
__ESIMD_API __ESIMD_NS::simd<T, N> dpasw2(
17971797
__ESIMD_NS::simd<T1, N1> src1, __ESIMD_NS::simd<T2, N2> src2,
17981798
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
1799-
constexpr bool is_4xhf =
1800-
std::is_same_v<T, __ESIMD_DNS::__raw_t<sycl::half>> &&
1801-
src1_precision == src2_precision && src1_precision == argument_type::FP16;
18021799

18031800
__ESIMD_NS::simd<T, N> result =
18041801
__ESIMD_NS::xmx::dpasw<systolic_depth, repeat_count, T, T1, T2,

0 commit comments

Comments
 (0)