Skip to content

[ESIMD] Fix DPAS implementations accepting/returning fp16/bf16 #6891

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@
#define __ESIMD_EMU_DNS sycl::ext::intel::esimd::emu::detail
#define __ESIMD_ENS sycl::ext::intel::experimental::esimd
#define __ESIMD_EDNS sycl::ext::intel::experimental::esimd::detail
#define __ESIMD_XMX_NS sycl::ext::intel::esimd::xmx
#define __ESIMD_XMX_DNS sycl::ext::intel::esimd::xmx::detail

#define __ESIMD_QUOTE1(m) #m
#define __ESIMD_QUOTE(m) __ESIMD_QUOTE1(m)
Expand Down
68 changes: 36 additions & 32 deletions sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,29 @@ namespace detail {
template <typename T> constexpr dpas_argument_type dpas_precision_from_type() {
// TODO: add support for tfloat32 here.
if constexpr (std::is_same_v<T, sycl::half>)
return dpas_argument_type::FP16;
return dpas_argument_type::fp16;
else if constexpr (std::is_same_v<T,
sycl::ext::oneapi::experimental::bfloat16>)
return dpas_argument_type::BF16;
return dpas_argument_type::bf16;
else if constexpr (std::is_same_v<T, unsigned char>)
return dpas_argument_type::U8;
return dpas_argument_type::u8;
else if constexpr (__ESIMD_DNS::is_type<T, char, signed char>())
return dpas_argument_type::S8;
return dpas_argument_type::s8;
else
return dpas_argument_type::Invalid;
}

template <dpas_argument_type T> constexpr int dpas_bitsize_from_precision() {
if constexpr (T == dpas_argument_type::U2 || T == dpas_argument_type::S2)
if constexpr (T == dpas_argument_type::u2 || T == dpas_argument_type::s2)
return 2;
else if constexpr (T == dpas_argument_type::U4 || T == dpas_argument_type::S4)
else if constexpr (T == dpas_argument_type::u4 || T == dpas_argument_type::s4)
return 4;
else if constexpr (T == dpas_argument_type::U8 || T == dpas_argument_type::S8)
else if constexpr (T == dpas_argument_type::u8 || T == dpas_argument_type::s8)
return 8;
else if constexpr (T == dpas_argument_type::BF16 ||
T == dpas_argument_type::FP16)
else if constexpr (T == dpas_argument_type::bf16 ||
T == dpas_argument_type::fp16)
return 16;
else if constexpr (T == dpas_argument_type::TF32)
else if constexpr (T == dpas_argument_type::tf32)
return 32;
else
return -1;
Expand Down Expand Up @@ -124,8 +124,8 @@ constexpr int verify_parameters_and_deduce_exec_size() {
static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16),
"Execution size must be 8 or 16 for DPAS and 8 for DPASW.");

if constexpr (APrecision == dpas_argument_type::FP16 ||
BPrecision == dpas_argument_type::FP16) {
if constexpr (APrecision == dpas_argument_type::fp16 ||
BPrecision == dpas_argument_type::fp16) {
if constexpr (ExecutionSize == 8) {
static_assert(APrecision == BPrecision &&
__ESIMD_DNS::is_type<T, float>() &&
Expand All @@ -141,8 +141,8 @@ constexpr int verify_parameters_and_deduce_exec_size() {
" Result | C | B | A \n"
" f, hf | f, hf | hf | hf \n");
}
} else if constexpr (APrecision == dpas_argument_type::BF16 ||
BPrecision == dpas_argument_type::BF16) {
} else if constexpr (APrecision == dpas_argument_type::bf16 ||
BPrecision == dpas_argument_type::bf16) {
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
if constexpr (ExecutionSize == 8) {
static_assert(APrecision == BPrecision &&
Expand All @@ -159,8 +159,8 @@ constexpr int verify_parameters_and_deduce_exec_size() {
" Result | C | B | A \n"
" f, bf | f, bf | bf | bf \n");
}
} else if constexpr (APrecision == dpas_argument_type::TF32 ||
BPrecision == dpas_argument_type::TF32) {
} else if constexpr (APrecision == dpas_argument_type::tf32 ||
BPrecision == dpas_argument_type::tf32) {
static_assert(ExecutionSize == 16,
"tf32 type can be used only with ExecutionSize=16");
static_assert(APrecision == BPrecision && std::is_same_v<T, float> &&
Expand All @@ -169,18 +169,18 @@ constexpr int verify_parameters_and_deduce_exec_size() {
" Result | C | B | A \n"
" f | f | tf32 | tf32 \n");
} else {
static_assert((APrecision == dpas_argument_type::U2 ||
APrecision == dpas_argument_type::S2 ||
APrecision == dpas_argument_type::U4 ||
APrecision == dpas_argument_type::S4 ||
APrecision == dpas_argument_type::U8 ||
APrecision == dpas_argument_type::S8) &&
(BPrecision == dpas_argument_type::U2 ||
BPrecision == dpas_argument_type::S2 ||
BPrecision == dpas_argument_type::U4 ||
BPrecision == dpas_argument_type::S4 ||
BPrecision == dpas_argument_type::U8 ||
BPrecision == dpas_argument_type::S8),
static_assert((APrecision == dpas_argument_type::u2 ||
APrecision == dpas_argument_type::s2 ||
APrecision == dpas_argument_type::u4 ||
APrecision == dpas_argument_type::s4 ||
APrecision == dpas_argument_type::u8 ||
APrecision == dpas_argument_type::s8) &&
(BPrecision == dpas_argument_type::u2 ||
BPrecision == dpas_argument_type::s2 ||
BPrecision == dpas_argument_type::u4 ||
BPrecision == dpas_argument_type::s4 ||
BPrecision == dpas_argument_type::u8 ||
BPrecision == dpas_argument_type::s8),
"Unsupported DPAS types! The supported types are:\n"
" Result | C | B | A \n"
" ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n");
Expand Down Expand Up @@ -221,7 +221,8 @@ __ESIMD_NS::simd<T, N> dpas(__ESIMD_NS::simd<CT, N> C,
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();
using CRawT = typename __ESIMD_NS::simd<CT, N>::raw_element_type;
return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, T,
using RawT = typename __ESIMD_NS::simd<T, N>::raw_element_type;
return __esimd_dpas2<BPrecision, APrecision, SystolicDepth, RepeatCount, RawT,
CRawT, int, int, N, BNCasted, ANCasted>(
C.data(), BCasted.data(), ACasted.data());
}
Expand Down Expand Up @@ -257,8 +258,9 @@ auto dpas(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {

constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
((int)APrecision << 8) + (int)BPrecision;
using RawT = typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
__ESIMD_NS::simd<T, ResultN> Result =
__esimd_dpas_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
__esimd_dpas_nosrc0<Info, RawT, int, int, ResultN, BNCasted, ANCasted>(
BCasted.data(), ACasted.data());
return Result;
}
Expand Down Expand Up @@ -289,9 +291,10 @@ __ESIMD_NS::simd<T, N> dpasw(__ESIMD_NS::simd<T, N> C,
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();

using RawT = typename __ESIMD_NS::simd<T, N>::raw_element_type;
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
((int)APrecision << 8) + (int)BPrecision;
return __esimd_dpasw<Info, T, int, int, N, BNCasted, ANCasted>(
return __esimd_dpasw<Info, RawT, int, int, N, BNCasted, ANCasted>(
C.data(), BCasted.data(), ACasted.data());
}

Expand Down Expand Up @@ -325,10 +328,11 @@ auto dpasw(__ESIMD_NS::simd<BT, BN> B, __ESIMD_NS::simd<AT, AN> A) {
__ESIMD_NS::simd<int, ANCasted> ACasted = A.template bit_cast_view<int>();
__ESIMD_NS::simd<int, BNCasted> BCasted = B.template bit_cast_view<int>();

using RawT = typename __ESIMD_NS::simd<T, ResultN>::raw_element_type;
constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) +
((int)APrecision << 8) + (int)BPrecision;
__ESIMD_NS::simd<T, ResultN> Result =
__esimd_dpasw_nosrc0<Info, T, int, int, ResultN, BNCasted, ANCasted>(
__esimd_dpasw_nosrc0<Info, RawT, int, int, ResultN, BNCasted, ANCasted>(
BCasted.data(), ACasted.data());
return Result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,25 +420,25 @@ __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ)
}

inline constexpr __ESIMD_NS::uint
__esimd_dpas_bits_precision(__ESIMD_ENS::argument_type precisionType) {
return precisionType == __ESIMD_ENS::argument_type::TF32 ? 32
: precisionType == __ESIMD_ENS::argument_type::BF16 ||
precisionType == __ESIMD_ENS::argument_type::FP16
__esimd_dpas_bits_precision(__ESIMD_XMX_NS::dpas_argument_type precisionType) {
return precisionType == __ESIMD_XMX_NS::dpas_argument_type::tf32 ? 32
: precisionType == __ESIMD_XMX_NS::dpas_argument_type::bf16 ||
precisionType == __ESIMD_XMX_NS::dpas_argument_type::fp16
? 16
: precisionType == __ESIMD_ENS::argument_type::S8 ||
precisionType == __ESIMD_ENS::argument_type::U8
: precisionType == __ESIMD_XMX_NS::dpas_argument_type::s8 ||
precisionType == __ESIMD_XMX_NS::dpas_argument_type::u8
? 8
: precisionType == __ESIMD_ENS::argument_type::S4 ||
precisionType == __ESIMD_ENS::argument_type::U4
: precisionType == __ESIMD_XMX_NS::dpas_argument_type::s4 ||
precisionType == __ESIMD_XMX_NS::dpas_argument_type::u4
? 4
: precisionType == __ESIMD_ENS::argument_type::S2 ||
precisionType == __ESIMD_ENS::argument_type::U2
: precisionType == __ESIMD_XMX_NS::dpas_argument_type::s2 ||
precisionType == __ESIMD_XMX_NS::dpas_argument_type::u2
? 2
: 1;
}

template <__ESIMD_ENS::argument_type src1_precision,
__ESIMD_ENS::argument_type src2_precision, int systolic_depth,
template <__ESIMD_XMX_NS::dpas_argument_type src1_precision,
__ESIMD_XMX_NS::dpas_argument_type src2_precision, int systolic_depth,
int repeat_count, typename RT, typename T0, typename T1, typename T2,
__ESIMD_NS::uint SZ, __ESIMD_NS::uint N1, __ESIMD_NS::uint N2>
inline __ESIMD_DNS::vector_type_t<RT, SZ>
Expand All @@ -463,16 +463,16 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
std::min(32 / max_el_bits, static_cast<__ESIMD_NS::uint>(8));

uint32_t src1_signed =
src1_precision == __ESIMD_ENS::argument_type::S2 ||
src1_precision == __ESIMD_ENS::argument_type::S4 ||
src1_precision == __ESIMD_ENS::argument_type::S8
src1_precision == __ESIMD_XMX_NS::dpas_argument_type::s2 ||
src1_precision == __ESIMD_XMX_NS::dpas_argument_type::s4 ||
src1_precision == __ESIMD_XMX_NS::dpas_argument_type::s8
? 1
: 0;

uint32_t src2_signed =
src2_precision == __ESIMD_ENS::argument_type::S2 ||
src2_precision == __ESIMD_ENS::argument_type::S4 ||
src2_precision == __ESIMD_ENS::argument_type::S8
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::s2 ||
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::s4 ||
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::s8
? 1
: 0;

Expand All @@ -484,19 +484,31 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
constexpr bool isPvc = SIMDSize == 16;

constexpr bool
pvcHfDest = isPvc && std::is_same<RT, __ESIMD_EMU_DNS::half>::value,
pvcBfDest = isPvc && std::is_same<RT, short>::value,
pvcHfDest = isPvc && std::is_same_v<RT, unsigned short> &&
src1_precision == __ESIMD_ENS::argument_type::FP16 &&
src2_precision == __ESIMD_ENS::argument_type::FP16,
pvcHfSrc0 = isPvc && std::is_same_v<T0, unsigned short> &&
src1_precision == __ESIMD_ENS::argument_type::FP16 &&
src2_precision == __ESIMD_ENS::argument_type::FP16,
pvcBfDest = isPvc && std::is_same_v<RT, unsigned short> &&
src1_precision == __ESIMD_ENS::argument_type::BF16 &&
src2_precision == __ESIMD_ENS::argument_type::BF16,
pvcBfSrc0 = isPvc && std::is_same_v<T0, unsigned short> &&
src1_precision == __ESIMD_ENS::argument_type::BF16 &&
src2_precision == __ESIMD_ENS::argument_type::BF16,
pvcBfOrHfDest = pvcBfDest || pvcHfDest,

pvcBfDestChecks = pvcBfDest &&
src1_precision == __ESIMD_ENS::argument_type::BF16 &&
src2_precision == __ESIMD_ENS::argument_type::BF16,
pvcBfDestChecks =
pvcBfDest &&
src1_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16 &&
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16,

pvcHfDestChecks =
pvcHfDest && ((src1_precision == __ESIMD_ENS::argument_type::FP16 &&
src2_precision == __ESIMD_ENS::argument_type::FP16) ||
(src1_precision == __ESIMD_ENS::argument_type::BF16 &&
src2_precision == __ESIMD_ENS::argument_type::BF16)),
pvcHfDest &&
((src1_precision == __ESIMD_XMX_NS::dpas_argument_type::fp16 &&
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::fp16) ||
(src1_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16 &&
src2_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16)),

destTypeChk =
(!pvcBfOrHfDest && __ESIMD_EMU_DNS::is_fp_or_dword_type<RT>::value) ||
Expand Down Expand Up @@ -547,9 +559,11 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
if (src0 != nullptr) {
auto src0El = src0[0][r * SIMDSize + n];

if (pvcBfDest) {
if (pvcBfSrc0) {
const auto tmp = (uint32_t)(src0El) << 16;
simdAcc[n] = reinterpret_cast<const TmpAccEl &>(tmp);
} else if (pvcHfSrc0) {
simdAcc[n] = reinterpret_cast<const __ESIMD_EMU_DNS::half &>(src0El);
} else
simdAcc[n] = src0El;
} else
Expand All @@ -566,7 +580,7 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
p = d + (s % src1_ops_per_dword) * ops_per_chan;
uint32_t extension_temp = false;

if (src2_precision == __ESIMD_ENS::argument_type::BF16) {
if (src2_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16) {
const auto s1 =
extract<uint32_t>(src1_el_bits, p * src1_el_bits,
src1[U * SIMDSize + n], extension_temp)
Expand All @@ -577,7 +591,8 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
<< 16;
simdAcc[n] += reinterpret_cast<const float &>(s2) *
reinterpret_cast<const float &>(s1);
} else if (src2_precision == __ESIMD_ENS::argument_type::FP16) {
} else if (src2_precision ==
__ESIMD_XMX_NS::dpas_argument_type::fp16) {
const auto s1 =
extract<short>(src1_el_bits, p * src1_el_bits,
src1[U * SIMDSize + n], extension_temp);
Expand Down Expand Up @@ -615,6 +630,10 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
}
retv[r * SIMDSize + n] =
static_cast<short>(reinterpret_cast<uint32_t &>(tmpUint) >> 16);
} else if constexpr (pvcHfDest) {
retv[r * SIMDSize + n] =
__ESIMD_EMU_DNS::satur<sycl::half>::saturate<TmpAccEl>(simdAcc[n],
sat1);
} else
retv[r * SIMDSize + n] =
__ESIMD_EMU_DNS::satur<RT>::template saturate<TmpAccEl>(simdAcc[n],
Expand All @@ -627,8 +646,8 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
}
#endif // #ifndef __SYCL_DEVICE_ONLY__

template <__ESIMD_ENS::argument_type src1_precision,
__ESIMD_ENS::argument_type src2_precision, int systolic_depth,
template <__ESIMD_XMX_NS::dpas_argument_type src1_precision,
__ESIMD_XMX_NS::dpas_argument_type src2_precision, int systolic_depth,
int repeat_count, typename T, typename T0, typename T1, typename T2,
int N, int N1, int N2, int res_sign = std::is_signed_v<T>,
int acc_sign = std::is_signed_v<T0>>
Expand All @@ -654,10 +673,10 @@ __esimd_dpas_nosrc0(__ESIMD_DNS::vector_type_t<T1, N1> src1,
;
#else // !__SYCL_DEVICE_ONLY__
{
constexpr __ESIMD_ENS::argument_type src1_precision =
static_cast<__ESIMD_ENS::argument_type>(Info & 0xff);
constexpr __ESIMD_ENS::argument_type src2_precision =
static_cast<__ESIMD_ENS::argument_type>((Info >> 8) & 0xff);
constexpr __ESIMD_XMX_NS::dpas_argument_type src1_precision =
static_cast<__ESIMD_XMX_NS::dpas_argument_type>(Info & 0xff);
constexpr __ESIMD_XMX_NS::dpas_argument_type src2_precision =
static_cast<__ESIMD_XMX_NS::dpas_argument_type>((Info >> 8) & 0xff);
constexpr int systolic_depth = (Info >> 16) & 0xff;
constexpr int repeat_count = (Info >> 24) & 0xff;
return __esimd_dpas_inner<src1_precision, src2_precision, systolic_depth,
Expand Down
3 changes: 0 additions & 3 deletions sycl/include/sycl/ext/intel/experimental/esimd/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1796,9 +1796,6 @@ __SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpasw()")
__ESIMD_API __ESIMD_NS::simd<T, N> dpasw2(
__ESIMD_NS::simd<T1, N1> src1, __ESIMD_NS::simd<T2, N2> src2,
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
constexpr bool is_4xhf =
std::is_same_v<T, __ESIMD_DNS::__raw_t<sycl::half>> &&
src1_precision == src2_precision && src1_precision == argument_type::FP16;

__ESIMD_NS::simd<T, N> result =
__ESIMD_NS::xmx::dpasw<systolic_depth, repeat_count, T, T1, T2,
Expand Down
Loading