From 7e9875adcd58f3f2133cb7275dc400f0ce3ef276 Mon Sep 17 00:00:00 2001 From: Vyacheslav N Klochkov Date: Wed, 14 Sep 2022 23:22:33 -0700 Subject: [PATCH 1/2] [ESIMD] Implement the new non-experimental low-level API for DPAS The new DPAS APIs are added to the new esimd::xmx (Xe Matrix eXtension) namespace. The old/experimental DPAS API is marked as deprecated and now it simply calls the new DPAS API. The DPAS emulation sequences now detect the execution size automatically instead of having it defined through the ESIMD_XE_HPC macro. Signed-off-by: Vyacheslav N Klochkov --- sycl/include/sycl/ext/intel/esimd.hpp | 1 + .../sycl/ext/intel/esimd/xmx/common.hpp | 47 +++ .../include/sycl/ext/intel/esimd/xmx/dpas.hpp | 310 +++++++++++++++ .../ext/intel/experimental/esimd/common.hpp | 17 +- .../experimental/esimd/detail/math_intrin.hpp | 15 +- .../ext/intel/experimental/esimd/math.hpp | 368 ++---------------- 6 files changed, 403 insertions(+), 355 deletions(-) create mode 100644 sycl/include/sycl/ext/intel/esimd/xmx/common.hpp create mode 100644 sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp diff --git a/sycl/include/sycl/ext/intel/esimd.hpp b/sycl/include/sycl/ext/intel/esimd.hpp index 19398ec0abd0d..021a4fc26120f 100644 --- a/sycl/include/sycl/ext/intel/esimd.hpp +++ b/sycl/include/sycl/ext/intel/esimd.hpp @@ -85,6 +85,7 @@ #include #include #include +#include #include #include #include diff --git a/sycl/include/sycl/ext/intel/esimd/xmx/common.hpp b/sycl/include/sycl/ext/intel/esimd/xmx/common.hpp new file mode 100644 index 0000000000000..45b1f7d5155b5 --- /dev/null +++ b/sycl/include/sycl/ext/intel/esimd/xmx/common.hpp @@ -0,0 +1,47 @@ +//==-------------- xmx/common.hpp - DPC++ Explicit SIMD API ----------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Explicit SIMD API types used in ESIMD Intel Xe Matrix eXtension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext::intel::esimd::xmx { + +enum class dpas_argument_type { + Invalid = 0, + u1 = 1, // unsigned 1 bit + U1 __SYCL_DEPRECATED("use u1") = u1, + s1 = 2, // signed 1 bit + S1 __SYCL_DEPRECATED("use s1") = s1, + u2 = 3, // unsigned 2 bits + U2 __SYCL_DEPRECATED("use u2") = u2, + s2 = 4, // signed 2 bits + S2 __SYCL_DEPRECATED("use s2") = s2, + u4 = 5, // unsigned 4 bits + U4 __SYCL_DEPRECATED("use u4") = u4, + s4 = 6, // signed 4 bits + S4 __SYCL_DEPRECATED("use s4") = s4, + u8 = 7, // unsigned 8 bits + U8 __SYCL_DEPRECATED("use u8") = u8, + s8 = 8, // signed 8 bits + S8 __SYCL_DEPRECATED("use s8") = s8, + bf16 = 9, // bfloat 16 + BF16 __SYCL_DEPRECATED("use bf16") = bf16, + fp16 = 10, // half float + FP16 __SYCL_DEPRECATED("use fp16") = fp16, + tf32 = 12, // tensorfloat 32 + TF32 __SYCL_DEPRECATED("use tf32") = tf32 +}; + +} // namespace ext::intel::esimd::xmx +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp new file mode 100644 index 0000000000000..ec98544ad0be6 --- /dev/null +++ b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp @@ -0,0 +1,310 @@ +//==----------------- xmx/dpas.hpp - DPC++ Explicit SIMD API ---------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Explicit SIMD API for DPAS Intel Xe Matrix eXtension. +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { + +namespace ext::intel::esimd::xmx { + +namespace detail { + +template constexpr dpas_argument_type dpas_precision_from_type() { + // TODO: add support for tfloat32 here. + if constexpr (std::is_same_v) + return dpas_argument_type::FP16; + else if constexpr (std::is_same_v) + return dpas_argument_type::BF16; + else if constexpr (std::is_same_v) + return dpas_argument_type::U8; + else if constexpr (__ESIMD_DNS::is_type()) + return dpas_argument_type::S8; + else + return dpas_argument_type::Invalid; +} + +template constexpr int dpas_bitsize_from_precision() { + if constexpr (T == dpas_argument_type::U2 || T == dpas_argument_type::S2) + return 2; + else if constexpr (T == dpas_argument_type::U4 || T == dpas_argument_type::S4) + return 4; + else if constexpr (T == dpas_argument_type::U8 || T == dpas_argument_type::S8) + return 8; + else if constexpr (T == dpas_argument_type::BF16 || + T == dpas_argument_type::FP16) + return 16; + else if constexpr (T == dpas_argument_type::TF32) + return 32; + else + return -1; +} + +template +constexpr void verify_repeat_count() { + static_assert(RepeatCount >= 1 && RepeatCount <= 8, + "Repeat count must be within 1 to 8 range"); + + if constexpr (IsDPASW && RepeatCount != 8) { + static_assert(!(AElemBitSize == 2 && BElemBitSize > 4), + "Unsupported repeat count for DPASW operation"); + + static_assert( + RepeatCount == 4 || + (AElemBitSize != 2 && (AElemBitSize != 4 || BElemBitSize <= 4)), + "Unsupported repeat count for DPASW operation"); + } +} + +template +constexpr int verify_parameters_and_deduce_exec_size() { + + 
static_assert(SystolicDepth == 8, "Systolic depth must be equal to 8"); + static_assert( + APrecision != dpas_argument_type::Invalid && + BPrecision != dpas_argument_type::Invalid, + "The types of dpas arguments are either incorrect or cannot be deduced. " + "Fix the types and/or explicitly specify them."); + + constexpr int AElemBitSize = dpas_bitsize_from_precision(); + constexpr int BElemBitSize = dpas_bitsize_from_precision(); + static_assert(AElemBitSize != -1 && BElemBitSize != -1, + "Cannot deduce element size of input arguments"); + verify_repeat_count(); + + constexpr int OpsPerChannel = + std::min(32 / std::max(AElemBitSize, BElemBitSize), 8); + + // A(_Mx_K) * B(_Kx_N) + C(_Mx_N) + // where: + // _M = RepeatCount; + // _K = SystolicDepth * OpsPerChannel; + // _N = ExecutionSize (unknown, but deducible), must be 8 or 16. + constexpr int _M = RepeatCount; + constexpr int _K = SystolicDepth * OpsPerChannel; + + // Compute _N (aka ExecutionSize) from the matrix B. + // It has _K*_N elements of BPrecision type, and BN elements of BT type + // hold those _K*_N*BPrecision bits, which lets us compute _N. + constexpr int BMatrixBitSize = sizeof(BT) * BN * 8; + constexpr int BNumElems = BMatrixBitSize / BElemBitSize; + constexpr int _N = BNumElems / _K; + static_assert(_K * _N == BNumElems, "Cannot deduce the execution size."); + + // Now verify that AN elements of AT type hold exactly _M*_K elements + // of APrecision type/size. Similarly for B: BN elements of BT type must + // hold _K*_N elements of BPrecision type/size. + // DPASW expects half as many AN elements as regular DPAS. + constexpr int AFactorForDPASW = IsDPASW ? 2 : 1; + static_assert(_M * _K * AElemBitSize == AN * sizeof(AT) * 8 * AFactorForDPASW, + "The first matrix multiplier has wrong size."); + static_assert(_K * _N * BElemBitSize == BN * sizeof(BT) * 8, + "The second matrix multiplier has wrong size."); + + // Execution size may be 8 or 16 depending on the target device. 
+ // User must check if used execution size is supported before calling DPAS. + constexpr int ExecutionSize = _N; + + static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16), + "Execution size must be 8 or 16 for DPAS and 8 for DPASW."); + + if constexpr (APrecision == dpas_argument_type::FP16 || + BPrecision == dpas_argument_type::FP16) { + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f, hf | f, hf | hf | hf \n"); + } else if constexpr (APrecision == dpas_argument_type::BF16 || + BPrecision == dpas_argument_type::BF16) { + using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f, bf | f, bf | bf | bf \n"); + } else if constexpr (APrecision == dpas_argument_type::TF32 || + BPrecision == dpas_argument_type::TF32) { + static_assert(APrecision == BPrecision && std::is_same_v && + std::is_same_v, + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f | f | tf32 | tf32 \n"); + } else { + static_assert((APrecision == dpas_argument_type::U2 || + APrecision == dpas_argument_type::S2 || + APrecision == dpas_argument_type::U4 || + APrecision == dpas_argument_type::S4 || + APrecision == dpas_argument_type::U8 || + APrecision == dpas_argument_type::S8) && + (BPrecision == dpas_argument_type::U2 || + BPrecision == dpas_argument_type::S2 || + BPrecision == dpas_argument_type::U4 || + BPrecision == dpas_argument_type::S4 || + BPrecision == dpas_argument_type::U8 || + BPrecision == dpas_argument_type::S8), + "Unsupported DPAS types! 
The supported types are:\n" + " Result | C | B | A \n" + " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n"); + } + return ExecutionSize; +} + +} // namespace detail + +/// @defgroup sycl_esimd_xmx_systolic_array_api Systolic Array APIs. +/// APIs below are used to implement dot product accumulate systolic functions +/// @ingroup sycl_esimd + +/// @addtogroup sycl_esimd_xmx_systolic_array_api +/// @{ +/// DPAS (Dot Product Accumulate Systolic) +/// Computes the result of matrix operations: Result = C + A x B; +/// @param C represents DPAS accumulator operand. +/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded +/// layout. +/// @param A represents the 1st matrix multiplier. +/// @return the vector value of DPAS computation result. +template < + int SystolicDepth, int RepeatCount, typename T, typename CT, typename BT, + typename AT, + dpas_argument_type BPrecision = detail::dpas_precision_from_type(), + dpas_argument_type APrecision = detail::dpas_precision_from_type(), + int N, int BN, int AN> +__ESIMD_NS::simd dpas(__ESIMD_NS::simd C, + __ESIMD_NS::simd B, + __ESIMD_NS::simd A) { + (void)detail::verify_parameters_and_deduce_exec_size< + SystolicDepth, RepeatCount, T, CT, BT, AT, BPrecision, APrecision, BN, + AN>(); + + constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); + constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); + __ESIMD_NS::simd ACasted = A.template bit_cast_view(); + __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + using CRawT = typename __ESIMD_NS::simd::raw_element_type; + return __esimd_dpas2( + C.data(), BCasted.data(), ACasted.data()); +} + +/// DPAS (Dot Product Accumulate Systolic) +/// Computes the result of matrix operations: Result = A x B; +/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded +/// layout. +/// @param A represents the 1st matrix multiplier. +/// @return the vector value of DPAS computation result. 
+template < + int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, + dpas_argument_type BPrecision = detail::dpas_precision_from_type(), + dpas_argument_type APrecision = detail::dpas_precision_from_type(), + int BN, int AN> +auto dpas(__ESIMD_NS::simd B, __ESIMD_NS::simd A) { + + constexpr int ExecutionSize = + detail::verify_parameters_and_deduce_exec_size(); + // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) + C(_Mx_N) + // where: + // _M = RepeatCount; + // _K = SystolicDepth * OpsPerChannel; + // _N = ExecutionSize (unknown, but deducible), must be 8 or 16. + constexpr int ResultN = RepeatCount * ExecutionSize; + + constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); + constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); + __ESIMD_NS::simd ACasted = A.template bit_cast_view(); + __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + + constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + + ((int)APrecision << 8) + (int)BPrecision; + return __esimd_dpas_nosrc0( + BCasted.data(), ACasted.data()); +} + +/// DPAS (Dot Product Accumulate Systolic) +/// Computes the result of matrix operations: Result = C + A x B; +/// @param C represents DPAS accumulator operand. +/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded +/// layout. +/// @param A represents the 1st matrix multiplier. +/// @return the vector value of DPAS computation result. 
+template < + int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, + dpas_argument_type BPrecision = detail::dpas_precision_from_type(), + dpas_argument_type APrecision = detail::dpas_precision_from_type(), + int N, int BN, int AN> +__ESIMD_NS::simd dpasw(__ESIMD_NS::simd C, + __ESIMD_NS::simd B, + __ESIMD_NS::simd A) { + + constexpr bool IsDPASW = true; + (void)detail::verify_parameters_and_deduce_exec_size< + SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN, + IsDPASW>(); + + constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); + constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); + __ESIMD_NS::simd ACasted = A.template bit_cast_view(); + __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + + constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + + ((int)APrecision << 8) + (int)BPrecision; + return __esimd_dpasw( + C.data(), BCasted.data(), ACasted.data()); +} + +/// DPAS (Dot Product Accumulate Systolic) +/// Computes the result of matrix operations: Result = A x B; +/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded +/// layout. +/// @param A represents the 1st matrix multiplier. +/// @return the vector value of DPAS computation result. 
+template < + int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, + dpas_argument_type BPrecision = detail::dpas_precision_from_type(), + dpas_argument_type APrecision = detail::dpas_precision_from_type(), + int N, int BN, int AN> +__ESIMD_NS::simd dpasw(__ESIMD_NS::simd B, + __ESIMD_NS::simd A) { + + constexpr bool IsDPASW = true; + (void)detail::verify_parameters_and_deduce_exec_size< + SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN, + IsDPASW>(); + + constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); + constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); + __ESIMD_NS::simd ACasted = A.template bit_cast_view(); + __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + + constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + + ((int)APrecision << 8) + (int)BPrecision; + return __esimd_dpasw_nosrc0( + BCasted.data(), ACasted.data()); +} + +/// @} sycl_esimd_xmx_systolic_array_api + +} // namespace ext::intel::esimd::xmx +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/common.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/common.hpp index 28fef416acbbc..1acbc9788502e 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/common.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/common.hpp @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -23,19 +24,9 @@ namespace ext::intel::experimental::esimd { /// @addtogroup sycl_esimd_core /// @{ -enum class argument_type { - U1 = 1, // unsigned 1 bit - S1 = 2, // signed 1 bit - U2 = 3, // unsigned 2 bits - S2 = 4, // signed 2 bits - U4 = 5, // unsigned 4 bits - S4 = 6, // signed 4 bits - U8 = 7, // unsigned 8 bits - S8 = 8, // signed 8 bits - BF16 = 9, // bfloat 16 - FP16 = 10, // half float - TF32 = 12 // tensorfloat 32 -}; +using argument_type + __SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpas_argument_type") = + 
__ESIMD_NS::xmx::dpas_argument_type; /// The scope that lsc_fence operation should apply to /// Supported platforms: DG2, PVC diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp index aab438ba698fc..3498ac7c70d7b 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp @@ -13,6 +13,8 @@ /// @cond ESIMD_DETAIL #include +#include +#include #include #define __ESIMD_raw_vec_t(T, SZ) \ @@ -474,13 +476,12 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, ? 1 : 0; -#if defined(ESIMD_XE_HPC) || defined(ESIMD_XE_HPG) - constexpr bool isPvc = true; - constexpr size_t SIMDSize = 16; -#else - constexpr bool isPvc = false; - constexpr size_t SIMDSize = 8; -#endif + constexpr uint32_t src1_vec_bit_size = sizeof(T1) * N1 * 8; + constexpr uint32_t src1_num_elem = src1_vec_bit_size / src1_el_bits; + constexpr size_t SIMDSize = src1_num_elem / (systolic_depth * ops_per_chan); + static_assert(SIMDSize == 8 || SIMDSize == 16, + "Execution size must be 8 or 16"); + constexpr bool isPvc = SIMDSize == 16; constexpr bool pvcHfDest = isPvc && std::is_same::value, diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp index 26778a4ef3017..ad2bc78578259 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp @@ -1667,78 +1667,6 @@ __ESIMD_NS::simd dp4(__ESIMD_NS::simd v1, /// @} sycl_esimd_math -/// @cond ESIMD_DETAIL -// dpas helpers -namespace detail { - -enum class dpas_ops_per_channel : unsigned { - OP1 = 1u, - OP2 = 2u, - OP4 = 4u, - OP8 = 8u, - INVALID = 0xffffffffu -}; -constexpr dpas_ops_per_channel -get_ops_per_channel(argument_type src1_precision, - argument_type src2_precision) { - if ((src1_precision == argument_type::U8) 
|| - (src1_precision == argument_type::S8)) { - if ((src2_precision == argument_type::U8) || - (src2_precision == argument_type::S8) || - (src2_precision == argument_type::U4) || - (src2_precision == argument_type::S4) || - (src2_precision == argument_type::U2) || - (src2_precision == argument_type::S2)) { - return dpas_ops_per_channel::OP4; - } - } else if ((src1_precision == argument_type::U4) || - (src1_precision == argument_type::S4) || - (src1_precision == argument_type::U2) || - (src1_precision == argument_type::S2)) { - if ((src2_precision == argument_type::U8) || - (src2_precision == argument_type::S8)) { - return dpas_ops_per_channel::OP4; - } else if ((src2_precision == argument_type::U4) || - (src2_precision == argument_type::S4) || - (src2_precision == argument_type::U2) || - (src2_precision == argument_type::S2)) { - return dpas_ops_per_channel::OP8; - } - } else if ((src1_precision == argument_type::BF16) && - (src2_precision == argument_type::BF16)) { - return dpas_ops_per_channel::OP2; - } else if ((src1_precision == argument_type::FP16) && - (src2_precision == argument_type::FP16)) { - return dpas_ops_per_channel::OP2; - } else if ((src1_precision == argument_type::TF32) && - (src2_precision == argument_type::TF32)) { - return dpas_ops_per_channel::OP1; - } - return dpas_ops_per_channel::INVALID; -} - -constexpr unsigned get_precision_bits(argument_type src_precision) { - if ((src_precision == argument_type::U8) || - (src_precision == argument_type::S8)) { - return 8; - } else if ((src_precision == argument_type::U4) || - (src_precision == argument_type::S4)) { - return 4; - } else if ((src_precision == argument_type::U2) || - (src_precision == argument_type::S2)) { - return 2; - } else if ((src_precision == argument_type::BF16) || - (src_precision == argument_type::FP16)) { - return 16; - } else if (src_precision == argument_type::TF32) { - return 32; - } - return 0; -} - -} // namespace detail -/// @endcond ESIMD_DETAIL - /// @defgroup 
sycl_esimd_systolic_array_api Systolic Array APIs. /// APIs below are used to implement dot product accumulate systolic functions /// @ingroup sycl_esimd @@ -1759,120 +1687,14 @@ template -__ESIMD_API __ESIMD_NS::simd -dpas(__ESIMD_NS::simd src0, __ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { - // types: dst, src0, src1, src2 - // ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 - constexpr bool check_integer = - detail::is_one_of_v && - detail::is_one_of_v && - detail::is_one_of_enum_v && - detail::is_one_of_enum_v; - - // TODO: Allow the maximum possible combination of types and also allow - // both execution sizes (8 and 16). That will give user all the control - // over the DPAS functionality without the need to define macros like - // ESIMD_XE_HPC. In this case it's user's responsibility to dispatch the code - // and use DPAS with types supported by the target device. - // From ESIMD compiler side the additional compile-time help/convenience may - // be provided via using optional target-specific macros to enforce - // verification of arguments and returns at compilation time. -#if defined(ESIMD_XE_HPC) - // f, bf | f, bf | bf | bf - constexpr bool check_bf16 = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::BF16 && - src2_precision == argument_type::BF16; - - // f,hf | f, hf | hf | hf - constexpr bool check_hf = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::FP16 && - src1_precision == argument_type::FP16; - - // f | f | tf32 | tf32 - constexpr bool check_tf32 = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::TF32 && - src2_precision == argument_type::TF32; - - constexpr bool check_passed = - (check_integer || check_hf || check_bf16 || check_tf32); - static_assert(check_passed, - "unsupported dpas type! 
The supported types are:\n" - " dst | src0 | src1 | src2 \n" - " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n" - " f, bf | f, bf | bf | bf \n" - " f, hf | f, hf | hf | hf \n" - " f | f | tf32 | tf32 \n"); - - static_assert((N == 16 * repeat_count), "Execution size on PVC must be 16"); -#else // else defined(ESIMD_XE_HPC) - // f | f | bf | bf - constexpr bool check_bf16 = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::BF16 && - src2_precision == argument_type::BF16; - - // f | f | hf | hf - constexpr bool check_hf = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::FP16 && - src1_precision == argument_type::FP16; - - constexpr bool check_passed = (check_integer || check_hf || check_bf16); - static_assert(check_passed, - "unsupported dpas type! The supported types are:\n" - " dst | src0 | src1 | src2 \n" - " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n" - " f | f | bf | bf \n" - " f | f | hf | hf \n"); - - static_assert((N == 8 * repeat_count), "Execution size must be 8"); -#endif // end else defined(ESIMD_XE_HPC) - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src1 must be DWORD type"); - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src2 must be DWORD type"); - - static_assert(systolic_depth == 8, "systolic_depth must be 8"); - static_assert((repeat_count >= 1) && (repeat_count <= 8), - "repeat_count must be within 1 to 8"); - - constexpr auto en_ops_per_channel = - detail::get_ops_per_channel(src1_precision, src2_precision); - static_assert(en_ops_per_channel != detail::dpas_ops_per_channel::INVALID, - "invalid combination of Src1/Src2 precision"); - constexpr auto ops_per_channel = static_cast(en_ops_per_channel); - - constexpr auto src1_precision_bits = - detail::get_precision_bits(src1_precision); - static_assert( - N1 == ((src1_precision_bits * systolic_depth * ops_per_channel * N) / - (repeat_count * sizeof(T1) * 8)), - "invalid size for Src1"); - - 
constexpr auto src2_precision_bits = - detail::get_precision_bits(src2_precision); - static_assert(N2 == ((src2_precision_bits * systolic_depth * ops_per_channel * - repeat_count) / - (sizeof(T2) * 8)), - "invalid size for Src2"); - - __ESIMD_NS::simd result = - __esimd_dpas2( - src0.data(), src1.data(), src2.data()); - +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpas()") +__ESIMD_API __ESIMD_NS::simd dpas( + __ESIMD_NS::simd src0, __ESIMD_NS::simd src1, + __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { + auto result = + __ESIMD_NS::xmx::dpas(src0, src1, src2); if constexpr (std::is_same_v) return result; else @@ -1893,10 +1715,11 @@ template -__ESIMD_API __ESIMD_NS::simd -dpas(__ESIMD_NS::simd src0, __ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpas()") +__ESIMD_API __ESIMD_NS::simd dpas( + __ESIMD_NS::simd src0, __ESIMD_NS::simd src1, + __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { return dpas( src0, src1, src2, sat); } @@ -1913,48 +1736,14 @@ template -__ESIMD_API __ESIMD_NS::simd -dpas(__ESIMD_NS::simd src1, __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { - - static_assert(__ESIMD_DNS::is_fp_or_dword_type::value, - "Dst must be FP or DWORD type"); - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src1 must be DWORD type"); - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src2 must be DWORD type"); - - static_assert((N == 8 * repeat_count) || (N == 16 * repeat_count), - "Execution size must be 8 or 16"); - static_assert(systolic_depth == 8, "systolic_depth must be 8"); - static_assert((repeat_count >= 1) && (repeat_count <= 8), - "repeat_count must be within 1 to 8"); - - constexpr auto en_ops_per_channel = - detail::get_ops_per_channel(src1_precision, src2_precision); - 
static_assert(en_ops_per_channel != detail::dpas_ops_per_channel::INVALID, - "invalid combination of Src1/Src2 precision"); - constexpr auto ops_per_channel = static_cast(en_ops_per_channel); - - constexpr auto src1_precision_bits = - detail::get_precision_bits(src1_precision); - static_assert( - N1 == ((src1_precision_bits * systolic_depth * ops_per_channel * N) / - (repeat_count * sizeof(T1) * 8)), - "invalid size for Src1"); - - constexpr auto src2_precision_bits = - detail::get_precision_bits(src2_precision); - static_assert(N2 == ((src2_precision_bits * systolic_depth * ops_per_channel * - repeat_count) / - (sizeof(T2) * 8)), - "invalid size for Src2"); - - constexpr int dpas_info = (repeat_count << 24) + (systolic_depth << 16) + - (((int)src2_precision) << 8) + (int)src1_precision; +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpas()") +__ESIMD_API __ESIMD_NS::simd dpas( + __ESIMD_NS::simd src1, __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { + __ESIMD_NS::simd result = - __esimd_dpas_nosrc0(src1.data(), - src2.data()); + __ESIMD_NS::xmx::dpas(src1, src2); if constexpr (std::is_same_v) return result; @@ -1976,63 +1765,15 @@ template -__ESIMD_API __ESIMD_NS::simd -dpasw(__ESIMD_NS::simd src0, __ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { - constexpr bool is_4xhf = - std::is_same_v> && - (src1_precision == src2_precision) && - (src1_precision == argument_type::FP16); - - constexpr bool is_4xbf = __ESIMD_DNS::is_word_type::value && - src1_precision == src2_precision && - src1_precision == argument_type::BF16; - - constexpr bool is_common_dpas = __ESIMD_DNS::is_fp_or_dword_type::value; +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpasw()") +__ESIMD_API __ESIMD_NS::simd dpasw( + __ESIMD_NS::simd src0, __ESIMD_NS::simd src1, + __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { - 
static_assert((is_4xhf || is_4xbf || is_common_dpas), - "unsupported dpas type"); - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src1 must be DWORD type"); - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src2 must be DWORD type"); - - static_assert((N == 8 * repeat_count) || (N == 16 * repeat_count), - "Execution size must be 8 or 16"); - - static_assert(systolic_depth == 8, "systolic_depth must be 8"); - static_assert((repeat_count >= 1) && (repeat_count <= 8), - "repeat_count must be within 1 to 8"); - - constexpr auto en_ops_per_channel = - detail::get_ops_per_channel(src1_precision, src2_precision); - static_assert(en_ops_per_channel != detail::dpas_ops_per_channel::INVALID, - "invalid combination of Src1/Src2 precision"); - constexpr auto ops_per_channel = static_cast(en_ops_per_channel); - - constexpr auto src1_precision_bits = - detail::get_precision_bits(src1_precision); - static_assert( - N1 == ((src1_precision_bits * systolic_depth * ops_per_channel * N) / - (repeat_count * sizeof(T1) * 8)), - "invalid size for Src1"); - - constexpr auto src2_precision_bits = - detail::get_precision_bits(src2_precision); - static_assert(N2 == ((src2_precision_bits * systolic_depth * ops_per_channel * - ((repeat_count + 1) / 2)) / - (sizeof(T2) * 8)), - "invalid size for Src2"); - - constexpr int dpas_info = (repeat_count << 24) + (systolic_depth << 16) + - (((int)src2_precision) << 8) + (int)src1_precision; __ESIMD_NS::simd result = - __esimd_dpasw(src0.data(), src1.data(), - src2.data()); - + __ESIMD_NS::xmx::dpasw(src0, src1, src2); if constexpr (std::is_same_v) return result; else @@ -2051,60 +1792,17 @@ template -__ESIMD_API __ESIMD_NS::simd -dpasw2(__ESIMD_NS::simd src1, __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpasw()") +__ESIMD_API __ESIMD_NS::simd dpasw2( + __ESIMD_NS::simd src1, __ESIMD_NS::simd src2, + 
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { constexpr bool is_4xhf = std::is_same_v> && src1_precision == src2_precision && src1_precision == argument_type::FP16; - constexpr bool is_4xbf = __ESIMD_DNS::is_word_type::value && - src1_precision == src2_precision && - src1_precision == argument_type::BF16; - - constexpr bool is_common_dpas = __ESIMD_DNS::is_fp_or_dword_type::value; - - static_assert((is_4xhf || is_4xbf || is_common_dpas), - "unsupported dpas type"); - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src1 must be DWORD type"); - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src2 must be DWORD type"); - - static_assert((N == 8 * repeat_count) || (N == 16 * repeat_count), - "Execution size must be 8 or 16"); - - static_assert(systolic_depth == 8, "systolic_depth must be 8"); - static_assert((repeat_count >= 1) && (repeat_count <= 8), - "repeat_count must be within 1 to 8"); - - constexpr auto en_ops_per_channel = - detail::get_ops_per_channel(src1_precision, src2_precision); - static_assert(en_ops_per_channel != detail::dpas_ops_per_channel::INVALID, - "invalid combination of Src1/Src2 precision"); - constexpr auto ops_per_channel = static_cast(en_ops_per_channel); - - constexpr auto src1_precision_bits = - detail::get_precision_bits(src1_precision); - static_assert( - N1 == ((src1_precision_bits * systolic_depth * ops_per_channel * N) / - (repeat_count * sizeof(T1) * 8)), - "invalid size for Src1"); - - constexpr auto src2_precision_bits = - detail::get_precision_bits(src2_precision); - static_assert(N2 == ((src2_precision_bits * systolic_depth * ops_per_channel * - ((repeat_count + 1) / 2)) / - (sizeof(T2) * 8)), - "invalid size for Src2"); - - constexpr int dpas_info = (repeat_count << 24) + (systolic_depth << 16) + - (((int)src2_precision) << 8) + (int)src1_precision; __ESIMD_NS::simd result = - __esimd_dpasw_nosrc0(src1.data(), - src2.data()); + __ESIMD_NS::xmx::dpasw(src1, src2); if constexpr (std::is_same_v) 
return result; From 25db74d16600cb0acafb0cb939b3feb44c117e05 Mon Sep 17 00:00:00 2001 From: Vyacheslav N Klochkov Date: Wed, 21 Sep 2022 15:22:27 -0700 Subject: [PATCH 2/2] Enforce dpas template arg checks, Fix dpasw(), Fix in-tree LIT test. Signed-off-by: Vyacheslav N Klochkov --- .../include/sycl/ext/intel/esimd/xmx/dpas.hpp | 72 +++++++++++++------ sycl/test/esimd/dpas.cpp | 35 +++++---- 2 files changed, 72 insertions(+), 35 deletions(-) diff --git a/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp index ec98544ad0be6..258a7393e2d34 100644 --- a/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp +++ b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp @@ -126,23 +126,43 @@ constexpr int verify_parameters_and_deduce_exec_size() { if constexpr (APrecision == dpas_argument_type::FP16 || BPrecision == dpas_argument_type::FP16) { - static_assert(APrecision == BPrecision && - __ESIMD_DNS::is_type() && - __ESIMD_DNS::is_type(), - "Unsupported DPAS types! The supported types are:\n" - " Result | C | B | A \n" - " f, hf | f, hf | hf | hf \n"); + if constexpr (ExecutionSize == 8) { + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f | f | hf | hf \n"); + } else { + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f, hf | f, hf | hf | hf \n"); + } } else if constexpr (APrecision == dpas_argument_type::BF16 || BPrecision == dpas_argument_type::BF16) { using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; - static_assert(APrecision == BPrecision && - __ESIMD_DNS::is_type() && - __ESIMD_DNS::is_type(), - "Unsupported DPAS types! 
The supported types are:\n" - " Result | C | B | A \n" - " f, bf | f, bf | bf | bf \n"); + if constexpr (ExecutionSize == 8) { + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f | f | bf | bf \n"); + } else { + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f, bf | f, bf | bf | bf \n"); + } } else if constexpr (APrecision == dpas_argument_type::TF32 || BPrecision == dpas_argument_type::TF32) { + static_assert(ExecutionSize == 16, + "tf32 type can be used only with ExecutionSize=16"); static_assert(APrecision == BPrecision && std::is_same_v && std::is_same_v, "Unsupported DPAS types! The supported types are:\n" @@ -223,7 +243,7 @@ auto dpas(__ESIMD_NS::simd B, __ESIMD_NS::simd A) { detail::verify_parameters_and_deduce_exec_size(); - // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) + C(_Mx_N) + // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) // where: // _M = RepeatCount; // _K = SystolicDepth * OpsPerChannel; @@ -237,8 +257,10 @@ auto dpas(__ESIMD_NS::simd B, __ESIMD_NS::simd A) { constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + ((int)APrecision << 8) + (int)BPrecision; - return __esimd_dpas_nosrc0( - BCasted.data(), ACasted.data()); + __ESIMD_NS::simd Result = + __esimd_dpas_nosrc0( + BCasted.data(), ACasted.data()); + return Result; } /// DPAS (Dot Product Accumulate Systolic) @@ -283,15 +305,21 @@ template < int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, dpas_argument_type BPrecision = detail::dpas_precision_from_type(), dpas_argument_type APrecision = detail::dpas_precision_from_type(), - int N, int BN, int AN> -__ESIMD_NS::simd dpasw(__ESIMD_NS::simd B, - __ESIMD_NS::simd A) { + int BN, int AN> +auto dpasw(__ESIMD_NS::simd B, __ESIMD_NS::simd A) { constexpr bool IsDPASW 
= true; - (void)detail::verify_parameters_and_deduce_exec_size< + constexpr int ExecutionSize = detail::verify_parameters_and_deduce_exec_size< SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN, IsDPASW>(); + // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) + // where: + // _M = RepeatCount; + // _K = SystolicDepth * OpsPerChannel; + // _N = ExecutionSize (unknown, but deducible), must be 8 or 16. + constexpr int ResultN = RepeatCount * ExecutionSize; + constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); __ESIMD_NS::simd ACasted = A.template bit_cast_view(); @@ -299,8 +327,10 @@ __ESIMD_NS::simd dpasw(__ESIMD_NS::simd B, constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + ((int)APrecision << 8) + (int)BPrecision; - return __esimd_dpasw_nosrc0( - BCasted.data(), ACasted.data()); + __ESIMD_NS::simd Result = + __esimd_dpasw_nosrc0( + BCasted.data(), ACasted.data()); + return Result; } /// @} sycl_esimd_xmx_systolic_array_api diff --git a/sycl/test/esimd/dpas.cpp b/sycl/test/esimd/dpas.cpp index 835a94c726f28..a86af83c74cf7 100644 --- a/sycl/test/esimd/dpas.cpp +++ b/sycl/test/esimd/dpas.cpp @@ -1,5 +1,5 @@ -// RUN: %clangxx -DESIMD_XE_HPC -O0 -fsycl -c -Xclang -emit-llvm %s -o %t -// RUN: %clangxx -DESIMD_XE_HPC -O0 -fsycl -c -fsycl-device-only -Xclang -emit-llvm %s -o %t +// RUN: %clangxx -O0 -fsycl -c -Xclang -emit-llvm %s -o %t +// RUN: %clangxx -O0 -fsycl -c -fsycl-device-only -Xclang -emit-llvm %s -o %t // RUN: sycl-post-link -split-esimd -lower-esimd -O0 -S %t -o %t.table // RUN: FileCheck %s -input-file=%t_esimd_0.ll @@ -27,13 +27,13 @@ void bar() { } SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void foo() { - simd A_ACC = 7; + simd A_ACC = 7; simd A_ISRC1 = 0; simd A_ISRC2 = 0; simd A_DST = dpas( A_ACC, A_ISRC1, A_ISRC2); - // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, 
i32 1, i32 1, i32 1) + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 0) simd B_ACC = 7; simd B_ISRC1 = 0; @@ -49,16 +49,23 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void foo() { C_ISRC1, C_ISRC2); // CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 {{[^,]+}}) - simd D_ACC = 7; - simd D_ISRC1 = 0; - simd D_ISRC2 = 0; - simd D_DST = dpasw( + simd D_ACC = + 7; // MxN: 1x8 floats (M=RepeatCount=1, N=ExecutionSize=8) + simd D_ISRC1 = + 0; // KxN: 16x8 bf16: (K=SysDepth*OpsPerChan=8*2, N=ExecutionSize=8) + simd D_ISRC2 = + 0; // MxK/2: 1x8 bf16: (M=RepeatCount=1, K=SysDepth*OpsPerChan=8*2) + // Result is MxN: 1x8 floats + simd D_DST = dpasw( D_ACC, D_ISRC1, D_ISRC2); - // CHECK: call <16 x float> @llvm.genx.dpasw.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 {{[^,]+}}) + // CHECK: call <8 x float> @llvm.genx.dpasw.v8f32.v64i32.v4i32(<8 x float> {{[^,]+}}, <64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 {{[^,]+}}) - simd E_ISRC1 = 0; - simd E_ISRC2 = 0; - simd E_DST = dpasw2(E_ISRC1, E_ISRC2); - // CHECK: call <16 x float> @llvm.genx.dpasw.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 {{[^,]+}}) + simd E_ISRC1 = + 0; // KxN: 16x8 bf16: K=SysDepth*OPC=8*2, N=ExecutionSize=8 + simd E_ISRC2 = + 0; // MxK/2: 1x16/2 bf16: M=RepeatCount, K=SysDepth*OPC=8*2 + // Result is MxN: 1x8 floats + simd E_DST = dpasw2(E_ISRC1, E_ISRC2); + // CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 {{[^,]+}}) }