diff --git a/sycl/include/sycl/ext/intel/esimd.hpp b/sycl/include/sycl/ext/intel/esimd.hpp index 19398ec0abd0d..021a4fc26120f 100644 --- a/sycl/include/sycl/ext/intel/esimd.hpp +++ b/sycl/include/sycl/ext/intel/esimd.hpp @@ -85,6 +85,7 @@ #include #include #include +#include #include #include #include diff --git a/sycl/include/sycl/ext/intel/esimd/xmx/common.hpp b/sycl/include/sycl/ext/intel/esimd/xmx/common.hpp new file mode 100644 index 0000000000000..45b1f7d5155b5 --- /dev/null +++ b/sycl/include/sycl/ext/intel/esimd/xmx/common.hpp @@ -0,0 +1,47 @@ +//==-------------- xmx/common.hpp - DPC++ Explicit SIMD API ----------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Explicit SIMD API types used in ESIMD Intel Xe Matrix eXtension. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { +namespace ext::intel::esimd::xmx { + +enum class dpas_argument_type { + Invalid = 0, + u1 = 1, // unsigned 1 bit + U1 __SYCL_DEPRECATED("use u1") = u1, + s1 = 2, // signed 1 bit + S1 __SYCL_DEPRECATED("use s1") = s1, + u2 = 3, // unsigned 2 bits + U2 __SYCL_DEPRECATED("use u2") = u2, + s2 = 4, // signed 2 bits + S2 __SYCL_DEPRECATED("use s2") = s2, + u4 = 5, // unsigned 4 bits + U4 __SYCL_DEPRECATED("use u4") = u4, + s4 = 6, // signed 4 bits + S4 __SYCL_DEPRECATED("use s4") = s4, + u8 = 7, // unsigned 8 bits + U8 __SYCL_DEPRECATED("use u8") = u8, + s8 = 8, // signed 8 bits + S8 __SYCL_DEPRECATED("use s8") = s8, + bf16 = 9, // bfloat 16 + BF16 __SYCL_DEPRECATED("use bf16") = bf16, + fp16 = 10, // half float + FP16 __SYCL_DEPRECATED("use fp16") = fp16, + tf32 = 12, // tensorfloat 32 + TF32 __SYCL_DEPRECATED("use tf32") = tf32 +}; + +} // namespace ext::intel::esimd::xmx +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp new file mode 100644 index 0000000000000..258a7393e2d34 --- /dev/null +++ b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp @@ -0,0 +1,340 @@ +//==----------------- xmx/dpas.hpp - DPC++ Explicit SIMD API ---------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// Explicit SIMD API for DPAS Intel Xe Matrix eXtension. 
+//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include +#include +#include +#include + +namespace sycl { +__SYCL_INLINE_VER_NAMESPACE(_V1) { + +namespace ext::intel::esimd::xmx { + +namespace detail { + +template constexpr dpas_argument_type dpas_precision_from_type() { + // TODO: add support for tfloat32 here. + if constexpr (std::is_same_v) + return dpas_argument_type::FP16; + else if constexpr (std::is_same_v) + return dpas_argument_type::BF16; + else if constexpr (std::is_same_v) + return dpas_argument_type::U8; + else if constexpr (__ESIMD_DNS::is_type()) + return dpas_argument_type::S8; + else + return dpas_argument_type::Invalid; +} + +template constexpr int dpas_bitsize_from_precision() { + if constexpr (T == dpas_argument_type::U2 || T == dpas_argument_type::S2) + return 2; + else if constexpr (T == dpas_argument_type::U4 || T == dpas_argument_type::S4) + return 4; + else if constexpr (T == dpas_argument_type::U8 || T == dpas_argument_type::S8) + return 8; + else if constexpr (T == dpas_argument_type::BF16 || + T == dpas_argument_type::FP16) + return 16; + else if constexpr (T == dpas_argument_type::TF32) + return 32; + else + return -1; +} + +template +constexpr void verify_repeat_count() { + static_assert(RepeatCount >= 1 && RepeatCount <= 8, + "Repeat count must be within 1 to 8 range"); + + if constexpr (IsDPASW && RepeatCount != 8) { + static_assert(!(AElemBitSize == 2 && BElemBitSize > 4), + "Unsupported repeat count for DPASW operation"); + + static_assert( + RepeatCount == 4 || + (AElemBitSize != 2 && (AElemBitSize != 4 || BElemBitSize <= 4)), + "Unsupported repeat count for DPASW operation"); + } +} + +template +constexpr int verify_parameters_and_deduce_exec_size() { + + static_assert(SystolicDepth == 8, "Systolic depth must be equal to 8"); + static_assert( + APrecision != dpas_argument_type::Invalid && + BPrecision != dpas_argument_type::Invalid, + "The types of dpas arguments are 
either incorrect or cannot be deduced. " + "Fix the types and/or explicitly specify them."); + + constexpr int AElemBitSize = dpas_bitsize_from_precision(); + constexpr int BElemBitSize = dpas_bitsize_from_precision(); + static_assert(AElemBitSize != -1 && BElemBitSize != -1, + "Cannot deduce element size of input arguments"); + verify_repeat_count(); + + constexpr int OpsPerChannel = + std::min(32 / std::max(AElemBitSize, BElemBitSize), 8); + + // A(_Mx_K) * B(_Kx_N) + C(_Mx_N) + // where: + // _M = RepeatCount; + // _K = SystolicDepth * OpsPerChannel; + // _N = ExecutionSize (unknown, but deducible), must be 8 or 16. + constexpr int _M = RepeatCount; + constexpr int _K = SystolicDepth * OpsPerChannel; + + // Compute _N (aka ExecutionSize) from the matrix B. + // It has _K*_N elements of BPrecision type, and BN elements of BT type + // hold those _K*_N*BPrecision bits, which lets us compute _N. + constexpr int BMatrixBitSize = sizeof(BT) * BN * 8; + constexpr int BNumElems = BMatrixBitSize / BElemBitSize; + constexpr int _N = BNumElems / _K; + static_assert(_K * _N == BNumElems, "Cannot deduce the execution size."); + + // Now verify that AN elements of AT type hold exactly _M*_K elements + // of APrecision type/size. Similarly for B: BN elements of BT type must + // hold _K*_N elements of BPrecision type/size. + // DPASW accepts 2x less expected AN elements than regular DPAS. + constexpr int AFactorForDPASW = IsDPASW ? 2 : 1; + static_assert(_M * _K * AElemBitSize == AN * sizeof(AT) * 8 * AFactorForDPASW, + "The first matrix multiplier has wrong size."); + static_assert(_K * _N * BElemBitSize == BN * sizeof(BT) * 8, + "The second matrix multiplier has wrong size."); + + // Execution size may be 8 or 16 depending on the target device. + // User must check if used execution size is supported before calling DPAS. 
+ constexpr int ExecutionSize = _N; + + static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16), + "Execution size must be 8 or 16 for DPAS and 8 for DPASW."); + + if constexpr (APrecision == dpas_argument_type::FP16 || + BPrecision == dpas_argument_type::FP16) { + if constexpr (ExecutionSize == 8) { + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f | f | hf | hf \n"); + } else { + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f, hf | f, hf | hf | hf \n"); + } + } else if constexpr (APrecision == dpas_argument_type::BF16 || + BPrecision == dpas_argument_type::BF16) { + using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; + if constexpr (ExecutionSize == 8) { + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f | f | bf | bf \n"); + } else { + static_assert(APrecision == BPrecision && + __ESIMD_DNS::is_type() && + __ESIMD_DNS::is_type(), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " f, bf | f, bf | bf | bf \n"); + } + } else if constexpr (APrecision == dpas_argument_type::TF32 || + BPrecision == dpas_argument_type::TF32) { + static_assert(ExecutionSize == 16, + "tf32 type can be used only with ExecutionSize=16"); + static_assert(APrecision == BPrecision && std::is_same_v && + std::is_same_v, + "Unsupported DPAS types! 
The supported types are:\n" + " Result | C | B | A \n" + " f | f | tf32 | tf32 \n"); + } else { + static_assert((APrecision == dpas_argument_type::U2 || + APrecision == dpas_argument_type::S2 || + APrecision == dpas_argument_type::U4 || + APrecision == dpas_argument_type::S4 || + APrecision == dpas_argument_type::U8 || + APrecision == dpas_argument_type::S8) && + (BPrecision == dpas_argument_type::U2 || + BPrecision == dpas_argument_type::S2 || + BPrecision == dpas_argument_type::U4 || + BPrecision == dpas_argument_type::S4 || + BPrecision == dpas_argument_type::U8 || + BPrecision == dpas_argument_type::S8), + "Unsupported DPAS types! The supported types are:\n" + " Result | C | B | A \n" + " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n"); + } + return ExecutionSize; +} + +} // namespace detail + +/// @defgroup sycl_esimd_xmx_systolic_array_api Systolic Array APIs. +/// APIs below are used to implement dot product accumulate systolic functions +/// @ingroup sycl_esimd + +/// @addtogroup sycl_esimd_xmx_systolic_array_api +/// @{ +/// DPAS (Dot Product Accumulate Systolic) +/// Computes the result of matrix operations: Result = C + A x B; +/// @param C represents DPAS accumulator operand. +/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded +/// layout. +/// @param A represents the 1st matrix multiplier. +/// @return the vector value of DPAS computation result. 
+template < + int SystolicDepth, int RepeatCount, typename T, typename CT, typename BT, + typename AT, + dpas_argument_type BPrecision = detail::dpas_precision_from_type(), + dpas_argument_type APrecision = detail::dpas_precision_from_type(), + int N, int BN, int AN> +__ESIMD_NS::simd dpas(__ESIMD_NS::simd C, + __ESIMD_NS::simd B, + __ESIMD_NS::simd A) { + (void)detail::verify_parameters_and_deduce_exec_size< + SystolicDepth, RepeatCount, T, CT, BT, AT, BPrecision, APrecision, BN, + AN>(); + + constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); + constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); + __ESIMD_NS::simd ACasted = A.template bit_cast_view(); + __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + using CRawT = typename __ESIMD_NS::simd::raw_element_type; + return __esimd_dpas2( + C.data(), BCasted.data(), ACasted.data()); +} + +/// DPAS (Dot Product Accumulate Systolic) +/// Computes the result of matrix operations: Result = A x B; +/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded +/// layout. +/// @param A represents the 1st matrix multiplier. +/// @return the vector value of DPAS computation result. +template < + int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, + dpas_argument_type BPrecision = detail::dpas_precision_from_type(), + dpas_argument_type APrecision = detail::dpas_precision_from_type(), + int BN, int AN> +auto dpas(__ESIMD_NS::simd B, __ESIMD_NS::simd A) { + + constexpr int ExecutionSize = + detail::verify_parameters_and_deduce_exec_size(); + // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) + // where: + // _M = RepeatCount; + // _K = SystolicDepth * OpsPerChannel; + // _N = ExecutionSize (unknown, but deducible), must be 8 or 16. 
+ constexpr int ResultN = RepeatCount * ExecutionSize; + + constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); + constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); + __ESIMD_NS::simd ACasted = A.template bit_cast_view(); + __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + + constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + + ((int)APrecision << 8) + (int)BPrecision; + __ESIMD_NS::simd Result = + __esimd_dpas_nosrc0( + BCasted.data(), ACasted.data()); + return Result; +} + +/// DPAS (Dot Product Accumulate Systolic) +/// Computes the result of matrix operations: Result = C + A x B; +/// @param C represents DPAS accumulator operand. +/// @param B represents the 2nd matrix multiplier. It must have the VNNI encoded +/// layout. +/// @param A represents the 1st matrix multiplier. +/// @return the vector value of DPAS computation result. +template < + int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, + dpas_argument_type BPrecision = detail::dpas_precision_from_type(), + dpas_argument_type APrecision = detail::dpas_precision_from_type(), + int N, int BN, int AN> +__ESIMD_NS::simd dpasw(__ESIMD_NS::simd C, + __ESIMD_NS::simd B, + __ESIMD_NS::simd A) { + + constexpr bool IsDPASW = true; + (void)detail::verify_parameters_and_deduce_exec_size< + SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN, + IsDPASW>(); + + constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); + constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); + __ESIMD_NS::simd ACasted = A.template bit_cast_view(); + __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + + constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + + ((int)APrecision << 8) + (int)BPrecision; + return __esimd_dpasw( + C.data(), BCasted.data(), ACasted.data()); +} + +/// DPAS (Dot Product Accumulate Systolic) +/// Computes the result of matrix operations: Result = A x B; +/// @param B represents the 2nd matrix multiplier. 
It must have the VNNI encoded +/// layout. +/// @param A represents the 1st matrix multiplier. +/// @return the vector value of DPAS computation result. +template < + int SystolicDepth, int RepeatCount, typename T, typename BT, typename AT, + dpas_argument_type BPrecision = detail::dpas_precision_from_type(), + dpas_argument_type APrecision = detail::dpas_precision_from_type(), + int BN, int AN> +auto dpasw(__ESIMD_NS::simd B, __ESIMD_NS::simd A) { + + constexpr bool IsDPASW = true; + constexpr int ExecutionSize = detail::verify_parameters_and_deduce_exec_size< + SystolicDepth, RepeatCount, T, T, BT, AT, BPrecision, APrecision, BN, AN, + IsDPASW>(); + + // Result(_Mx_N) = A(_Mx_K) * B(_Kx_N) + // where: + // _M = RepeatCount; + // _K = SystolicDepth * OpsPerChannel; + // _N = ExecutionSize (unknown, but deducible), must be 8 or 16. + constexpr int ResultN = RepeatCount * ExecutionSize; + + constexpr int ANCasted = AN / (sizeof(int) / sizeof(AT)); + constexpr int BNCasted = BN / (sizeof(int) / sizeof(BT)); + __ESIMD_NS::simd ACasted = A.template bit_cast_view(); + __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + + constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + + ((int)APrecision << 8) + (int)BPrecision; + __ESIMD_NS::simd Result = + __esimd_dpasw_nosrc0( + BCasted.data(), ACasted.data()); + return Result; +} + +/// @} sycl_esimd_xmx_systolic_array_api + +} // namespace ext::intel::esimd::xmx +} // __SYCL_INLINE_VER_NAMESPACE(_V1) +} // namespace sycl diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/common.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/common.hpp index 28fef416acbbc..1acbc9788502e 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/common.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/common.hpp @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -23,19 +24,9 @@ namespace ext::intel::experimental::esimd { /// @addtogroup sycl_esimd_core /// @{ -enum class argument_type { - 
U1 = 1, // unsigned 1 bit - S1 = 2, // signed 1 bit - U2 = 3, // unsigned 2 bits - S2 = 4, // signed 2 bits - U4 = 5, // unsigned 4 bits - S4 = 6, // signed 4 bits - U8 = 7, // unsigned 8 bits - S8 = 8, // signed 8 bits - BF16 = 9, // bfloat 16 - FP16 = 10, // half float - TF32 = 12 // tensorfloat 32 -}; +using argument_type + __SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpas_argument_type") = + __ESIMD_NS::xmx::dpas_argument_type; /// The scope that lsc_fence operation should apply to /// Supported platforms: DG2, PVC diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp index aab438ba698fc..3498ac7c70d7b 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp @@ -13,6 +13,8 @@ /// @cond ESIMD_DETAIL #include +#include +#include #include #define __ESIMD_raw_vec_t(T, SZ) \ @@ -474,13 +476,12 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, ? 
1 : 0; -#if defined(ESIMD_XE_HPC) || defined(ESIMD_XE_HPG) - constexpr bool isPvc = true; - constexpr size_t SIMDSize = 16; -#else - constexpr bool isPvc = false; - constexpr size_t SIMDSize = 8; -#endif + constexpr uint32_t src1_vec_bit_size = sizeof(T1) * N1 * 8; + constexpr uint32_t src1_num_elem = src1_vec_bit_size / src1_el_bits; + constexpr size_t SIMDSize = src1_num_elem / (systolic_depth * ops_per_chan); + static_assert(SIMDSize == 8 || SIMDSize == 16, + "Execution size must be 8 or 16"); + constexpr bool isPvc = SIMDSize == 16; constexpr bool pvcHfDest = isPvc && std::is_same::value, diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp index 26778a4ef3017..ad2bc78578259 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp @@ -1667,78 +1667,6 @@ __ESIMD_NS::simd dp4(__ESIMD_NS::simd v1, /// @} sycl_esimd_math -/// @cond ESIMD_DETAIL -// dpas helpers -namespace detail { - -enum class dpas_ops_per_channel : unsigned { - OP1 = 1u, - OP2 = 2u, - OP4 = 4u, - OP8 = 8u, - INVALID = 0xffffffffu -}; -constexpr dpas_ops_per_channel -get_ops_per_channel(argument_type src1_precision, - argument_type src2_precision) { - if ((src1_precision == argument_type::U8) || - (src1_precision == argument_type::S8)) { - if ((src2_precision == argument_type::U8) || - (src2_precision == argument_type::S8) || - (src2_precision == argument_type::U4) || - (src2_precision == argument_type::S4) || - (src2_precision == argument_type::U2) || - (src2_precision == argument_type::S2)) { - return dpas_ops_per_channel::OP4; - } - } else if ((src1_precision == argument_type::U4) || - (src1_precision == argument_type::S4) || - (src1_precision == argument_type::U2) || - (src1_precision == argument_type::S2)) { - if ((src2_precision == argument_type::U8) || - (src2_precision == argument_type::S8)) { - return dpas_ops_per_channel::OP4; - } else if 
((src2_precision == argument_type::U4) || - (src2_precision == argument_type::S4) || - (src2_precision == argument_type::U2) || - (src2_precision == argument_type::S2)) { - return dpas_ops_per_channel::OP8; - } - } else if ((src1_precision == argument_type::BF16) && - (src2_precision == argument_type::BF16)) { - return dpas_ops_per_channel::OP2; - } else if ((src1_precision == argument_type::FP16) && - (src2_precision == argument_type::FP16)) { - return dpas_ops_per_channel::OP2; - } else if ((src1_precision == argument_type::TF32) && - (src2_precision == argument_type::TF32)) { - return dpas_ops_per_channel::OP1; - } - return dpas_ops_per_channel::INVALID; -} - -constexpr unsigned get_precision_bits(argument_type src_precision) { - if ((src_precision == argument_type::U8) || - (src_precision == argument_type::S8)) { - return 8; - } else if ((src_precision == argument_type::U4) || - (src_precision == argument_type::S4)) { - return 4; - } else if ((src_precision == argument_type::U2) || - (src_precision == argument_type::S2)) { - return 2; - } else if ((src_precision == argument_type::BF16) || - (src_precision == argument_type::FP16)) { - return 16; - } else if (src_precision == argument_type::TF32) { - return 32; - } - return 0; -} - -} // namespace detail -/// @endcond ESIMD_DETAIL - /// @defgroup sycl_esimd_systolic_array_api Systolic Array APIs. 
/// APIs below are used to implement dot product accumulate systolic functions /// @ingroup sycl_esimd @@ -1759,120 +1687,14 @@ template -__ESIMD_API __ESIMD_NS::simd -dpas(__ESIMD_NS::simd src0, __ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { - // types: dst, src0, src1, src2 - // ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 - constexpr bool check_integer = - detail::is_one_of_v && - detail::is_one_of_v && - detail::is_one_of_enum_v && - detail::is_one_of_enum_v; - - // TODO: Allow the maximum possible combination of types and also allow - // both execution sizes (8 and 16). That will give user all the control - // over the DPAS functionality without the need to define macros like - // ESIMD_XE_HPC. In this case it's user's responsibility to dispatch the code - // and use DPAS with types supported by the target device. - // From ESIMD compiler side the additional compile-time help/convenience may - // be provided via using optional target-specific macros to enforce - // verification of arguments and returns at compilation time. -#if defined(ESIMD_XE_HPC) - // f, bf | f, bf | bf | bf - constexpr bool check_bf16 = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::BF16 && - src2_precision == argument_type::BF16; - - // f,hf | f, hf | hf | hf - constexpr bool check_hf = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::FP16 && - src1_precision == argument_type::FP16; - - // f | f | tf32 | tf32 - constexpr bool check_tf32 = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::TF32 && - src2_precision == argument_type::TF32; - - constexpr bool check_passed = - (check_integer || check_hf || check_bf16 || check_tf32); - static_assert(check_passed, - "unsupported dpas type! 
The supported types are:\n" - " dst | src0 | src1 | src2 \n" - " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n" - " f, bf | f, bf | bf | bf \n" - " f, hf | f, hf | hf | hf \n" - " f | f | tf32 | tf32 \n"); - - static_assert((N == 16 * repeat_count), "Execution size on PVC must be 16"); -#else // else defined(ESIMD_XE_HPC) - // f | f | bf | bf - constexpr bool check_bf16 = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::BF16 && - src2_precision == argument_type::BF16; - - // f | f | hf | hf - constexpr bool check_hf = detail::is_one_of_v && - detail::is_one_of_v && - src1_precision == argument_type::FP16 && - src1_precision == argument_type::FP16; - - constexpr bool check_passed = (check_integer || check_hf || check_bf16); - static_assert(check_passed, - "unsupported dpas type! The supported types are:\n" - " dst | src0 | src1 | src2 \n" - " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n" - " f | f | bf | bf \n" - " f | f | hf | hf \n"); - - static_assert((N == 8 * repeat_count), "Execution size must be 8"); -#endif // end else defined(ESIMD_XE_HPC) - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src1 must be DWORD type"); - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src2 must be DWORD type"); - - static_assert(systolic_depth == 8, "systolic_depth must be 8"); - static_assert((repeat_count >= 1) && (repeat_count <= 8), - "repeat_count must be within 1 to 8"); - - constexpr auto en_ops_per_channel = - detail::get_ops_per_channel(src1_precision, src2_precision); - static_assert(en_ops_per_channel != detail::dpas_ops_per_channel::INVALID, - "invalid combination of Src1/Src2 precision"); - constexpr auto ops_per_channel = static_cast(en_ops_per_channel); - - constexpr auto src1_precision_bits = - detail::get_precision_bits(src1_precision); - static_assert( - N1 == ((src1_precision_bits * systolic_depth * ops_per_channel * N) / - (repeat_count * sizeof(T1) * 8)), - "invalid size for Src1"); - - 
constexpr auto src2_precision_bits = - detail::get_precision_bits(src2_precision); - static_assert(N2 == ((src2_precision_bits * systolic_depth * ops_per_channel * - repeat_count) / - (sizeof(T2) * 8)), - "invalid size for Src2"); - - __ESIMD_NS::simd result = - __esimd_dpas2( - src0.data(), src1.data(), src2.data()); - +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpas()") +__ESIMD_API __ESIMD_NS::simd dpas( + __ESIMD_NS::simd src0, __ESIMD_NS::simd src1, + __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { + auto result = + __ESIMD_NS::xmx::dpas(src0, src1, src2); if constexpr (std::is_same_v) return result; else @@ -1893,10 +1715,11 @@ template -__ESIMD_API __ESIMD_NS::simd -dpas(__ESIMD_NS::simd src0, __ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpas()") +__ESIMD_API __ESIMD_NS::simd dpas( + __ESIMD_NS::simd src0, __ESIMD_NS::simd src1, + __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { return dpas( src0, src1, src2, sat); } @@ -1913,48 +1736,14 @@ template -__ESIMD_API __ESIMD_NS::simd -dpas(__ESIMD_NS::simd src1, __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { - - static_assert(__ESIMD_DNS::is_fp_or_dword_type::value, - "Dst must be FP or DWORD type"); - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src1 must be DWORD type"); - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src2 must be DWORD type"); - - static_assert((N == 8 * repeat_count) || (N == 16 * repeat_count), - "Execution size must be 8 or 16"); - static_assert(systolic_depth == 8, "systolic_depth must be 8"); - static_assert((repeat_count >= 1) && (repeat_count <= 8), - "repeat_count must be within 1 to 8"); - - constexpr auto en_ops_per_channel = - detail::get_ops_per_channel(src1_precision, src2_precision); - 
static_assert(en_ops_per_channel != detail::dpas_ops_per_channel::INVALID, - "invalid combination of Src1/Src2 precision"); - constexpr auto ops_per_channel = static_cast(en_ops_per_channel); - - constexpr auto src1_precision_bits = - detail::get_precision_bits(src1_precision); - static_assert( - N1 == ((src1_precision_bits * systolic_depth * ops_per_channel * N) / - (repeat_count * sizeof(T1) * 8)), - "invalid size for Src1"); - - constexpr auto src2_precision_bits = - detail::get_precision_bits(src2_precision); - static_assert(N2 == ((src2_precision_bits * systolic_depth * ops_per_channel * - repeat_count) / - (sizeof(T2) * 8)), - "invalid size for Src2"); - - constexpr int dpas_info = (repeat_count << 24) + (systolic_depth << 16) + - (((int)src2_precision) << 8) + (int)src1_precision; +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpas()") +__ESIMD_API __ESIMD_NS::simd dpas( + __ESIMD_NS::simd src1, __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { + __ESIMD_NS::simd result = - __esimd_dpas_nosrc0(src1.data(), - src2.data()); + __ESIMD_NS::xmx::dpas(src1, src2); if constexpr (std::is_same_v) return result; @@ -1976,63 +1765,15 @@ template -__ESIMD_API __ESIMD_NS::simd -dpasw(__ESIMD_NS::simd src0, __ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { - constexpr bool is_4xhf = - std::is_same_v> && - (src1_precision == src2_precision) && - (src1_precision == argument_type::FP16); - - constexpr bool is_4xbf = __ESIMD_DNS::is_word_type::value && - src1_precision == src2_precision && - src1_precision == argument_type::BF16; - - constexpr bool is_common_dpas = __ESIMD_DNS::is_fp_or_dword_type::value; +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpasw()") +__ESIMD_API __ESIMD_NS::simd dpasw( + __ESIMD_NS::simd src0, __ESIMD_NS::simd src1, + __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { - 
static_assert((is_4xhf || is_4xbf || is_common_dpas), - "unsupported dpas type"); - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src1 must be DWORD type"); - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src2 must be DWORD type"); - - static_assert((N == 8 * repeat_count) || (N == 16 * repeat_count), - "Execution size must be 8 or 16"); - - static_assert(systolic_depth == 8, "systolic_depth must be 8"); - static_assert((repeat_count >= 1) && (repeat_count <= 8), - "repeat_count must be within 1 to 8"); - - constexpr auto en_ops_per_channel = - detail::get_ops_per_channel(src1_precision, src2_precision); - static_assert(en_ops_per_channel != detail::dpas_ops_per_channel::INVALID, - "invalid combination of Src1/Src2 precision"); - constexpr auto ops_per_channel = static_cast(en_ops_per_channel); - - constexpr auto src1_precision_bits = - detail::get_precision_bits(src1_precision); - static_assert( - N1 == ((src1_precision_bits * systolic_depth * ops_per_channel * N) / - (repeat_count * sizeof(T1) * 8)), - "invalid size for Src1"); - - constexpr auto src2_precision_bits = - detail::get_precision_bits(src2_precision); - static_assert(N2 == ((src2_precision_bits * systolic_depth * ops_per_channel * - ((repeat_count + 1) / 2)) / - (sizeof(T2) * 8)), - "invalid size for Src2"); - - constexpr int dpas_info = (repeat_count << 24) + (systolic_depth << 16) + - (((int)src2_precision) << 8) + (int)src1_precision; __ESIMD_NS::simd result = - __esimd_dpasw(src0.data(), src1.data(), - src2.data()); - + __ESIMD_NS::xmx::dpasw(src0, src1, src2); if constexpr (std::is_same_v) return result; else @@ -2051,60 +1792,17 @@ template -__ESIMD_API __ESIMD_NS::simd -dpasw2(__ESIMD_NS::simd src1, __ESIMD_NS::simd src2, - std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { +__SYCL_DEPRECATED("use sycl::ext::intel::esimd::xmx::dpasw()") +__ESIMD_API __ESIMD_NS::simd dpasw2( + __ESIMD_NS::simd src1, __ESIMD_NS::simd src2, + 
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { constexpr bool is_4xhf = std::is_same_v> && src1_precision == src2_precision && src1_precision == argument_type::FP16; - constexpr bool is_4xbf = __ESIMD_DNS::is_word_type::value && - src1_precision == src2_precision && - src1_precision == argument_type::BF16; - - constexpr bool is_common_dpas = __ESIMD_DNS::is_fp_or_dword_type::value; - - static_assert((is_4xhf || is_4xbf || is_common_dpas), - "unsupported dpas type"); - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src1 must be DWORD type"); - - static_assert(__ESIMD_DNS::is_dword_type::value, - "Src2 must be DWORD type"); - - static_assert((N == 8 * repeat_count) || (N == 16 * repeat_count), - "Execution size must be 8 or 16"); - - static_assert(systolic_depth == 8, "systolic_depth must be 8"); - static_assert((repeat_count >= 1) && (repeat_count <= 8), - "repeat_count must be within 1 to 8"); - - constexpr auto en_ops_per_channel = - detail::get_ops_per_channel(src1_precision, src2_precision); - static_assert(en_ops_per_channel != detail::dpas_ops_per_channel::INVALID, - "invalid combination of Src1/Src2 precision"); - constexpr auto ops_per_channel = static_cast(en_ops_per_channel); - - constexpr auto src1_precision_bits = - detail::get_precision_bits(src1_precision); - static_assert( - N1 == ((src1_precision_bits * systolic_depth * ops_per_channel * N) / - (repeat_count * sizeof(T1) * 8)), - "invalid size for Src1"); - - constexpr auto src2_precision_bits = - detail::get_precision_bits(src2_precision); - static_assert(N2 == ((src2_precision_bits * systolic_depth * ops_per_channel * - ((repeat_count + 1) / 2)) / - (sizeof(T2) * 8)), - "invalid size for Src2"); - - constexpr int dpas_info = (repeat_count << 24) + (systolic_depth << 16) + - (((int)src2_precision) << 8) + (int)src1_precision; __ESIMD_NS::simd result = - __esimd_dpasw_nosrc0(src1.data(), - src2.data()); + __ESIMD_NS::xmx::dpasw(src1, src2); if constexpr (std::is_same_v) 
return result; diff --git a/sycl/test/esimd/dpas.cpp b/sycl/test/esimd/dpas.cpp index 835a94c726f28..a86af83c74cf7 100644 --- a/sycl/test/esimd/dpas.cpp +++ b/sycl/test/esimd/dpas.cpp @@ -1,5 +1,5 @@ -// RUN: %clangxx -DESIMD_XE_HPC -O0 -fsycl -c -Xclang -emit-llvm %s -o %t -// RUN: %clangxx -DESIMD_XE_HPC -O0 -fsycl -c -fsycl-device-only -Xclang -emit-llvm %s -o %t +// RUN: %clangxx -O0 -fsycl -c -Xclang -emit-llvm %s -o %t +// RUN: %clangxx -O0 -fsycl -c -fsycl-device-only -Xclang -emit-llvm %s -o %t // RUN: sycl-post-link -split-esimd -lower-esimd -O0 -S %t -o %t.table // RUN: FileCheck %s -input-file=%t_esimd_0.ll @@ -27,13 +27,13 @@ void bar() { } SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void foo() { - simd A_ACC = 7; + simd A_ACC = 7; simd A_ISRC1 = 0; simd A_ISRC2 = 0; simd A_DST = dpas( A_ACC, A_ISRC1, A_ISRC2); - // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 1) + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 0) simd B_ACC = 7; simd B_ISRC1 = 0; @@ -49,16 +49,23 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void foo() { C_ISRC1, C_ISRC2); // CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 {{[^,]+}}) - simd D_ACC = 7; - simd D_ISRC1 = 0; - simd D_ISRC2 = 0; - simd D_DST = dpasw( + simd D_ACC = + 7; // MxN: 1x8 floats (M=RepeatCount=1, N=ExecutionSize=8) + simd D_ISRC1 = + 0; // KxN: 16x8 bf16: (K=SysDepth*OpsPerChan=8*2, N=ExecutionSize=8) + simd D_ISRC2 = + 0; // MxK/2: 1x8 bf16: (M=RepeatCount=1, K=SysDepth*OpsPerChan=8*2) + // Result is MxN: 1x8 floats + simd D_DST = dpasw( D_ACC, D_ISRC1, D_ISRC2); - // CHECK: call <16 x float> @llvm.genx.dpasw.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 
{{[^,]+}}) + // CHECK: call <8 x float> @llvm.genx.dpasw.v8f32.v64i32.v4i32(<8 x float> {{[^,]+}}, <64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 {{[^,]+}}) - simd E_ISRC1 = 0; - simd E_ISRC2 = 0; - simd E_DST = dpasw2(E_ISRC1, E_ISRC2); - // CHECK: call <16 x float> @llvm.genx.dpasw.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 {{[^,]+}}) + simd E_ISRC1 = + 0; // KxN: 16x8 bf16: K=SysDepth*OPC=8*2, N=ExecutionSize=8 + simd E_ISRC2 = + 0; // MxK/2: 1x16/2 bf16: M=RepeatCount, K=SysDepth*OPC=8*2 + // Result is MxN: 1x8 floats + simd E_DST = dpasw2(E_ISRC1, E_ISRC2); + // CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 {{[^,]+}}) }