diff --git a/sycl/include/sycl/ext/intel/esimd/detail/defines_elementary.hpp b/sycl/include/sycl/ext/intel/esimd/detail/defines_elementary.hpp index cf8d0bfd69d3a..7c53989c5e73d 100644 --- a/sycl/include/sycl/ext/intel/esimd/detail/defines_elementary.hpp +++ b/sycl/include/sycl/ext/intel/esimd/detail/defines_elementary.hpp @@ -53,6 +53,8 @@ #define __ESIMD_EMU_DNS sycl::ext::intel::esimd::emu::detail #define __ESIMD_ENS sycl::ext::intel::experimental::esimd #define __ESIMD_EDNS sycl::ext::intel::experimental::esimd::detail +#define __ESIMD_XMX_NS sycl::ext::intel::esimd::xmx +#define __ESIMD_XMX_DNS sycl::ext::intel::esimd::xmx::detail #define __ESIMD_QUOTE1(m) #m #define __ESIMD_QUOTE(m) __ESIMD_QUOTE1(m) diff --git a/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp index 258a7393e2d34..8b16e6b6b0119 100644 --- a/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp +++ b/sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp @@ -26,29 +26,29 @@ namespace detail { template constexpr dpas_argument_type dpas_precision_from_type() { // TODO: add support for tfloat32 here. if constexpr (std::is_same_v) - return dpas_argument_type::FP16; + return dpas_argument_type::fp16; else if constexpr (std::is_same_v) - return dpas_argument_type::BF16; + return dpas_argument_type::bf16; else if constexpr (std::is_same_v) - return dpas_argument_type::U8; + return dpas_argument_type::u8; else if constexpr (__ESIMD_DNS::is_type()) - return dpas_argument_type::S8; + return dpas_argument_type::s8; else return dpas_argument_type::Invalid; } template constexpr int dpas_bitsize_from_precision() { - if constexpr (T == dpas_argument_type::U2 || T == dpas_argument_type::S2) + if constexpr (T == dpas_argument_type::u2 || T == dpas_argument_type::s2) return 2; - else if constexpr (T == dpas_argument_type::U4 || T == dpas_argument_type::S4) + else if constexpr (T == dpas_argument_type::u4 || T == dpas_argument_type::s4) return 4; - else if constexpr (T == dpas_argument_type::U8 || T == dpas_argument_type::S8) + else if constexpr (T == dpas_argument_type::u8 || T == dpas_argument_type::s8) return 8; - else if constexpr (T == dpas_argument_type::BF16 || - T == dpas_argument_type::FP16) + else if constexpr (T == dpas_argument_type::bf16 || + T == dpas_argument_type::fp16) return 16; - else if constexpr (T == dpas_argument_type::TF32) + else if constexpr (T == dpas_argument_type::tf32) return 32; else return -1; @@ -124,8 +124,8 @@ constexpr int verify_parameters_and_deduce_exec_size() { static_assert(ExecutionSize == 8 || (!IsDPASW && ExecutionSize == 16), "Execution size must be 8 or 16 for DPAS and 8 for DPASW."); - if constexpr (APrecision == dpas_argument_type::FP16 || - BPrecision == dpas_argument_type::FP16) { + if constexpr (APrecision == dpas_argument_type::fp16 || + BPrecision == dpas_argument_type::fp16) { if constexpr (ExecutionSize == 8) { static_assert(APrecision == BPrecision && __ESIMD_DNS::is_type() && @@ -141,8 +141,8 @@ constexpr int verify_parameters_and_deduce_exec_size() { " Result | C | B | A \n" " f, hf | f, hf | hf | hf \n"); } - } else if constexpr (APrecision == dpas_argument_type::BF16 || - BPrecision == dpas_argument_type::BF16) { + } else if constexpr (APrecision == dpas_argument_type::bf16 || + BPrecision == dpas_argument_type::bf16) { using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; if constexpr (ExecutionSize == 8) { static_assert(APrecision == BPrecision && @@ -159,8 +159,8 @@ constexpr int verify_parameters_and_deduce_exec_size() { " Result | C | B | A \n" " f, bf | f, bf | bf | bf \n"); } - } else if constexpr (APrecision == dpas_argument_type::TF32 || - BPrecision == dpas_argument_type::TF32) { + } else if constexpr (APrecision == dpas_argument_type::tf32 || + BPrecision == dpas_argument_type::tf32) { static_assert(ExecutionSize == 16, "tf32 type can be used only with ExecutionSize=16"); static_assert(APrecision == BPrecision && std::is_same_v && @@ -169,18 +169,18 @@ constexpr int verify_parameters_and_deduce_exec_size() { " Result | C | B | A \n" " f | f | tf32 | tf32 \n"); } else { - static_assert((APrecision == dpas_argument_type::U2 || - APrecision == dpas_argument_type::S2 || - APrecision == dpas_argument_type::U4 || - APrecision == dpas_argument_type::S4 || - APrecision == dpas_argument_type::U8 || - APrecision == dpas_argument_type::S8) && - (BPrecision == dpas_argument_type::U2 || - BPrecision == dpas_argument_type::S2 || - BPrecision == dpas_argument_type::U4 || - BPrecision == dpas_argument_type::S4 || - BPrecision == dpas_argument_type::U8 || - BPrecision == dpas_argument_type::S8), + static_assert((APrecision == dpas_argument_type::u2 || + APrecision == dpas_argument_type::s2 || + APrecision == dpas_argument_type::u4 || + APrecision == dpas_argument_type::s4 || + APrecision == dpas_argument_type::u8 || + APrecision == dpas_argument_type::s8) && + (BPrecision == dpas_argument_type::u2 || + BPrecision == dpas_argument_type::s2 || + BPrecision == dpas_argument_type::u4 || + BPrecision == dpas_argument_type::s4 || + BPrecision == dpas_argument_type::u8 || + BPrecision == dpas_argument_type::s8), "Unsupported DPAS types! The supported types are:\n" " Result | C | B | A \n" " ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 \n"); @@ -221,7 +221,8 @@ __ESIMD_NS::simd dpas(__ESIMD_NS::simd C, __ESIMD_NS::simd ACasted = A.template bit_cast_view(); __ESIMD_NS::simd BCasted = B.template bit_cast_view(); using CRawT = typename __ESIMD_NS::simd::raw_element_type; - return __esimd_dpas2::raw_element_type; + return __esimd_dpas2( C.data(), BCasted.data(), ACasted.data()); } @@ -257,8 +258,9 @@ auto dpas(__ESIMD_NS::simd B, __ESIMD_NS::simd A) { constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + ((int)APrecision << 8) + (int)BPrecision; + using RawT = typename __ESIMD_NS::simd::raw_element_type; __ESIMD_NS::simd Result = - __esimd_dpas_nosrc0( + __esimd_dpas_nosrc0( BCasted.data(), ACasted.data()); return Result; } @@ -289,9 +291,10 @@ __ESIMD_NS::simd dpasw(__ESIMD_NS::simd C, __ESIMD_NS::simd ACasted = A.template bit_cast_view(); __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + using RawT = typename __ESIMD_NS::simd::raw_element_type; constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + ((int)APrecision << 8) + (int)BPrecision; - return __esimd_dpasw( + return __esimd_dpasw( C.data(), BCasted.data(), ACasted.data()); } @@ -325,10 +328,11 @@ auto dpasw(__ESIMD_NS::simd B, __ESIMD_NS::simd A) { __ESIMD_NS::simd ACasted = A.template bit_cast_view(); __ESIMD_NS::simd BCasted = B.template bit_cast_view(); + using RawT = typename __ESIMD_NS::simd::raw_element_type; constexpr int Info = (RepeatCount << 24) + (SystolicDepth << 16) + ((int)APrecision << 8) + (int)BPrecision; __ESIMD_NS::simd Result = - __esimd_dpasw_nosrc0( + __esimd_dpasw_nosrc0( BCasted.data(), ACasted.data()); return Result; } diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp index 3498ac7c70d7b..25555964ca598 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp @@ -420,25 +420,25 @@ __ESIMD_INTRIN __ESIMD_raw_vec_t(T, SZ) } inline constexpr __ESIMD_NS::uint -__esimd_dpas_bits_precision(__ESIMD_ENS::argument_type precisionType) { - return precisionType == __ESIMD_ENS::argument_type::TF32 ? 32 - : precisionType == __ESIMD_ENS::argument_type::BF16 || - precisionType == __ESIMD_ENS::argument_type::FP16 +__esimd_dpas_bits_precision(__ESIMD_XMX_NS::dpas_argument_type precisionType) { + return precisionType == __ESIMD_XMX_NS::dpas_argument_type::tf32 ? 32 + : precisionType == __ESIMD_XMX_NS::dpas_argument_type::bf16 || + precisionType == __ESIMD_XMX_NS::dpas_argument_type::fp16 ? 16 - : precisionType == __ESIMD_ENS::argument_type::S8 || - precisionType == __ESIMD_ENS::argument_type::U8 + : precisionType == __ESIMD_XMX_NS::dpas_argument_type::s8 || + precisionType == __ESIMD_XMX_NS::dpas_argument_type::u8 ? 8 - : precisionType == __ESIMD_ENS::argument_type::S4 || - precisionType == __ESIMD_ENS::argument_type::U4 + : precisionType == __ESIMD_XMX_NS::dpas_argument_type::s4 || + precisionType == __ESIMD_XMX_NS::dpas_argument_type::u4 ? 4 - : precisionType == __ESIMD_ENS::argument_type::S2 || - precisionType == __ESIMD_ENS::argument_type::U2 + : precisionType == __ESIMD_XMX_NS::dpas_argument_type::s2 || + precisionType == __ESIMD_XMX_NS::dpas_argument_type::u2 ? 2 : 1; } -template <__ESIMD_ENS::argument_type src1_precision, - __ESIMD_ENS::argument_type src2_precision, int systolic_depth, +template <__ESIMD_XMX_NS::dpas_argument_type src1_precision, + __ESIMD_XMX_NS::dpas_argument_type src2_precision, int systolic_depth, int repeat_count, typename RT, typename T0, typename T1, typename T2, __ESIMD_NS::uint SZ, __ESIMD_NS::uint N1, __ESIMD_NS::uint N2> inline __ESIMD_DNS::vector_type_t @@ -463,16 +463,16 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, std::min(32 / max_el_bits, static_cast<__ESIMD_NS::uint>(8)); uint32_t src1_signed = - src1_precision == __ESIMD_ENS::argument_type::S2 || - src1_precision == __ESIMD_ENS::argument_type::S4 || - src1_precision == __ESIMD_ENS::argument_type::S8 + src1_precision == __ESIMD_XMX_NS::dpas_argument_type::s2 || + src1_precision == __ESIMD_XMX_NS::dpas_argument_type::s4 || + src1_precision == __ESIMD_XMX_NS::dpas_argument_type::s8 ? 1 : 0; uint32_t src2_signed = - src2_precision == __ESIMD_ENS::argument_type::S2 || - src2_precision == __ESIMD_ENS::argument_type::S4 || - src2_precision == __ESIMD_ENS::argument_type::S8 + src2_precision == __ESIMD_XMX_NS::dpas_argument_type::s2 || + src2_precision == __ESIMD_XMX_NS::dpas_argument_type::s4 || + src2_precision == __ESIMD_XMX_NS::dpas_argument_type::s8 ? 1 : 0; @@ -484,19 +484,31 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, constexpr bool isPvc = SIMDSize == 16; constexpr bool - pvcHfDest = isPvc && std::is_same::value, - pvcBfDest = isPvc && std::is_same::value, + pvcHfDest = isPvc && std::is_same_v && + src1_precision == __ESIMD_ENS::argument_type::FP16 && + src2_precision == __ESIMD_ENS::argument_type::FP16, + pvcHfSrc0 = isPvc && std::is_same_v && + src1_precision == __ESIMD_ENS::argument_type::FP16 && + src2_precision == __ESIMD_ENS::argument_type::FP16, + pvcBfDest = isPvc && std::is_same_v && + src1_precision == __ESIMD_ENS::argument_type::BF16 && + src2_precision == __ESIMD_ENS::argument_type::BF16, + pvcBfSrc0 = isPvc && std::is_same_v && + src1_precision == __ESIMD_ENS::argument_type::BF16 && + src2_precision == __ESIMD_ENS::argument_type::BF16, pvcBfOrHfDest = pvcBfDest || pvcHfDest, - pvcBfDestChecks = pvcBfDest && - src1_precision == __ESIMD_ENS::argument_type::BF16 && - src2_precision == __ESIMD_ENS::argument_type::BF16, + pvcBfDestChecks = + pvcBfDest && + src1_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16 && + src2_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16, pvcHfDestChecks = - pvcHfDest && ((src1_precision == __ESIMD_ENS::argument_type::FP16 && - src2_precision == __ESIMD_ENS::argument_type::FP16) || - (src1_precision == __ESIMD_ENS::argument_type::BF16 && - src2_precision == __ESIMD_ENS::argument_type::BF16)), + pvcHfDest && + ((src1_precision == __ESIMD_XMX_NS::dpas_argument_type::fp16 && + src2_precision == __ESIMD_XMX_NS::dpas_argument_type::fp16) || + (src1_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16 && + src2_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16)), destTypeChk = (!pvcBfOrHfDest && __ESIMD_EMU_DNS::is_fp_or_dword_type::value) || @@ -547,9 +559,11 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, if (src0 != nullptr) { auto src0El = src0[0][r * SIMDSize + n]; - if (pvcBfDest) { + if (pvcBfSrc0) { const auto tmp = (uint32_t)(src0El) << 16; simdAcc[n] = reinterpret_cast(tmp); + } else if (pvcHfSrc0) { + simdAcc[n] = reinterpret_cast(src0El); } else simdAcc[n] = src0El; } else @@ -566,7 +580,7 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, p = d + (s % src1_ops_per_dword) * ops_per_chan; uint32_t extension_temp = false; - if (src2_precision == __ESIMD_ENS::argument_type::BF16) { + if (src2_precision == __ESIMD_XMX_NS::dpas_argument_type::bf16) { const auto s1 = extract(src1_el_bits, p * src1_el_bits, src1[U * SIMDSize + n], extension_temp) @@ -577,7 +591,8 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, << 16; simdAcc[n] += reinterpret_cast(s2) * reinterpret_cast(s1); - } else if (src2_precision == __ESIMD_ENS::argument_type::FP16) { + } else if (src2_precision == + __ESIMD_XMX_NS::dpas_argument_type::fp16) { const auto s1 = extract(src1_el_bits, p * src1_el_bits, src1[U * SIMDSize + n], extension_temp); @@ -615,6 +630,10 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, } retv[r * SIMDSize + n] = static_cast(reinterpret_cast(tmpUint) >> 16); + } else if constexpr (pvcHfDest) { + retv[r * SIMDSize + n] = + __ESIMD_EMU_DNS::satur::saturate(simdAcc[n], + sat1); } else retv[r * SIMDSize + n] = __ESIMD_EMU_DNS::satur::template saturate(simdAcc[n], @@ -627,8 +646,8 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, } #endif // #ifndef __SYCL_DEVICE_ONLY__ -template <__ESIMD_ENS::argument_type src1_precision, - __ESIMD_ENS::argument_type src2_precision, int systolic_depth, +template <__ESIMD_XMX_NS::dpas_argument_type src1_precision, + __ESIMD_XMX_NS::dpas_argument_type src2_precision, int systolic_depth, int repeat_count, typename T, typename T0, typename T1, typename T2, int N, int N1, int N2, int res_sign = std::is_signed_v, int acc_sign = std::is_signed_v> @@ -654,10 +673,10 @@ __esimd_dpas_nosrc0(__ESIMD_DNS::vector_type_t src1, ; #else // !__SYCL_DEVICE_ONLY__ { - constexpr __ESIMD_ENS::argument_type src1_precision = - static_cast<__ESIMD_ENS::argument_type>(Info & 0xff); - constexpr __ESIMD_ENS::argument_type src2_precision = - static_cast<__ESIMD_ENS::argument_type>((Info >> 8) & 0xff); + constexpr __ESIMD_XMX_NS::dpas_argument_type src1_precision = + static_cast<__ESIMD_XMX_NS::dpas_argument_type>(Info & 0xff); + constexpr __ESIMD_XMX_NS::dpas_argument_type src2_precision = + static_cast<__ESIMD_XMX_NS::dpas_argument_type>((Info >> 8) & 0xff); constexpr int systolic_depth = (Info >> 16) & 0xff; constexpr int repeat_count = (Info >> 24) & 0xff; return __esimd_dpas_inner dpasw2( __ESIMD_NS::simd src1, __ESIMD_NS::simd src2, std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { - constexpr bool is_4xhf = - std::is_same_v> && - src1_precision == src2_precision && src1_precision == argument_type::FP16; __ESIMD_NS::simd result = __ESIMD_NS::xmx::dpasw #include using namespace sycl::ext::intel::esimd; -using namespace sycl::ext::intel::experimental::esimd; -SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void foo(); +namespace old = sycl::ext::intel::experimental::esimd; +namespace xmx = sycl::ext::intel::esimd::xmx; + +using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; +using half = sycl::half; + +constexpr auto bf16 = xmx::dpas_argument_type::bf16; +constexpr auto fp16 = xmx::dpas_argument_type::fp16; +constexpr auto s2 = xmx::dpas_argument_type::s2; +constexpr auto s8 = xmx::dpas_argument_type::s8; + +SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void xmx_func(); +SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void old_func(); + +SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void old_func_end(); +SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void xmx_func_end(); class EsimdFunctor { public: - void operator()() __attribute__((sycl_explicit_simd)) { foo(); } + void operator()() __attribute__((sycl_explicit_simd)) { + old_func(); + xmx_func(); + } }; template @@ -26,46 +43,301 @@ void bar() { kernel(esimdf); } -SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void foo() { - simd A_ACC = 7; - simd A_ISRC1 = 0; - simd A_ISRC2 = 0; - simd A_DST = - dpas( - A_ACC, A_ISRC1, A_ISRC2); - // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 0) - - simd B_ACC = 7; - simd B_ISRC1 = 0; - simd B_ISRC2 = 0; - simd B_DST = dpas( - B_ACC, B_ISRC1, B_ISRC2); - // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 1) - - simd C_ISRC1 = 0; - simd C_ISRC2 = 0; - simd C_DST = - dpas( - C_ISRC1, C_ISRC2); - // CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 {{[^,]+}}) - - simd D_ACC = - 7; // MxN: 1x8 floats (M=RepeatCount=1, N=ExecutionSize=8) - simd D_ISRC1 = - 0; // KxN: 16x8 bf16: (K=SysDepth*OpsPerChan=8*2, N=ExecutionSize=8) - simd D_ISRC2 = - 0; // MxK/2: 1x8 bf16: (M=RepeatCount=1, K=SysDepth*OpsPerChan=8*2) - // Result is MxN: 1x8 floats - simd D_DST = dpasw( - D_ACC, D_ISRC1, D_ISRC2); - // CHECK: call <8 x float> @llvm.genx.dpasw.v8f32.v64i32.v4i32(<8 x float> {{[^,]+}}, <64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 {{[^,]+}}) - - simd E_ISRC1 = - 0; // KxN: 16x8 bf16: K=SysDepth*OPC=8*2, N=ExecutionSize=8 - simd E_ISRC2 = - 0; // MxK/2: 1x16/2 bf16: M=RepeatCount, K=SysDepth*OPC=8*2 - // Result is MxN: 1x8 floats - simd E_DST = dpasw2(E_ISRC1, E_ISRC2); - // CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 {{[^,]+}}) +template SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void zoo(T... A); + +SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void old_func() { + // DPAS: Result(M x N) = A(M x K) * B(K x N) + // where: + // M = RepeatCount; + // K = SystolicDepth * OpsPerChannel; + // N = ExecutionSize, must be 16 on PVC and 8 on DG2. + constexpr int M_one = 1; + constexpr int K_half = 8 * 2; + constexpr int K_bf16 = 8 * 2; + constexpr int K_int8x2 = 8 * 4; + constexpr int N_pvc = 16; + constexpr int N_dg2 = 8; + + // CHECK: define dso_local spir_func void @_Z8old_funcv() + + { // ======= DPAS BF16 ======================================================= + simd R_bf = 0; + simd R_f = 0; + + simd C_bf = 0; + simd C_f = 0; + + simd B_int = 0; // 2 bf16 per 1 int32 + simd A_int = 0; // 2 bf16 per 1 int32 + + // ------------ DPAS BF16: WITH THE ACCUMULATOR OPERAND -------------------- + R_f = old::dpas(C_f, B_int, A_int); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 1) + + R_f = old::dpas(C_bf, B_int, A_int); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 0) + + R_bf = old::dpas(C_f, B_int, A_int); + zoo(R_bf); + // CHECK: call <16 x i16> @llvm.genx.dpas2.v16i16.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 0, i32 1) + + R_bf = old::dpas(C_bf, B_int, A_int); + zoo(R_bf); + // CHECK: call <16 x i16> @llvm.genx.dpas2.v16i16.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 0, i32 0) + + // ------------ DPAS BF16: WITHOUT THE ACCUMULATOR OPERAND ----------------- + R_f = old::dpas(B_int, + A_int); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17303817) + + R_bf = old::dpas( + B_int, A_int); + zoo(R_bf); + // CHECK: call <16 x i16> @llvm.genx.dpas.nosrc0.v16i16.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17303817) + } + + { // ======= DPAS FP16 ======================================================= + simd R_hf = 0; + simd R_f = 0; + + simd C_hf = 0; + simd C_f = 0; + + simd B_int = 0; // 2 fp16 per 1 int32 + simd A_int = 0; // 2 fp16 per 1 int32 + + // ------------ DPAS FP16: WITH THE ACCUMULATOR OPERAND -------------------- + R_f = old::dpas(C_f, B_int, A_int); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 10, i32 10, i32 8, i32 1, i32 1, i32 1) + + R_f = old::dpas(C_hf, B_int, A_int); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f16.v128i32.v8i32(<16 x half> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 10, i32 10, i32 8, i32 1, i32 1, i32 0) + + R_hf = old::dpas(C_f, B_int, A_int); + zoo(R_hf); + // CHECK: call <16 x half> @llvm.genx.dpas2.v16f16.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 10, i32 10, i32 8, i32 1, i32 0, i32 1) + + R_hf = old::dpas(C_hf, B_int, A_int); + zoo(R_hf); + // CHECK: call <16 x half> @llvm.genx.dpas2.v16f16.v16f16.v128i32.v8i32(<16 x half> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 10, i32 10, i32 8, i32 1, i32 0, i32 0) + + // ------------ DPAS FP16: WITHOUT THE ACCUMULATOR OPERAND ----------------- + R_f = old::dpas(B_int, + A_int); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17304074) + + R_hf = old::dpas(B_int, + A_int); + zoo(R_hf); + // CHECK: call <16 x half> @llvm.genx.dpas.nosrc0.v16f16.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17304074) + } + + { // ======= DPAS 8-BIT x 2-BIT INT ========================================== + simd R_d = 0; + simd C_d = 0; + simd B_int2 = 0; // 16 2-bit integers per int32 + simd A_int8 = 0; // 4 8-bit integers per int32 + + // ------------ DPAS s8 x s2: WITH THE ACCUMULATOR OPERAND ----------------- + R_d = old::dpas(C_d, B_int2, A_int8); + zoo(R_d); + // CHECK: call <16 x i32> @llvm.genx.dpas2.v16i32.v16i32.v32i32.v8i32(<16 x i32> {{[^,]+}}, <32 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 4, i32 8, i32 8, i32 1, i32 1, i32 1) + + // ------------ DPAS s8 x s2: WITHOUT THE ACCUMULATOR OPERAND -------------- + R_d = old::dpas(B_int2, A_int8); + zoo(R_d); + // CHECK: call <16 x i32> @llvm.genx.dpas.nosrc0.v16i32.v32i32.v8i32(<32 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17303556) + } + + { // ======= DPASW BF16 ====================================================== + simd R_f = 0; + simd C_f = 0; + + simd B_int = 0; // 2 bf16 per 1 int32 + simd A_int = 0; // 2 bf16 per 1 int32 + + // ------------ DPASW BF16: WITH THE ACCUMULATOR OPERAND ------------------- + R_f = old::dpasw(C_f, B_int, A_int); + zoo(R_f); + // CHECK: call <8 x float> @llvm.genx.dpasw.v8f32.v64i32.v4i32(<8 x float> {{[^,]+}}, <64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17303817) + + // ------------ DPASW BF16: WITHOUT ACC OPERAND ---------------------------- + R_f = old::dpasw2(B_int, + A_int); + zoo(R_f); + // CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17303817) + } + + { // ======= DPASW FP16 ====================================================== + simd R_f = 0; + simd C_f = 0; + + simd B_int = 0; // 2 fp16 per 1 int32 + simd A_int = 0; // 2 fp16 per 1 int32 + + // ------------ DPASW FP16: WITH THE ACCUMULATOR OPERAND ------------------- + R_f = old::dpasw(C_f, B_int, A_int); + zoo(R_f); + // CHECK: call <8 x float> @llvm.genx.dpasw.v8f32.v64i32.v4i32(<8 x float> {{[^,]+}}, <64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17304074) + + // ------------ DPASW FP16: WITHOUT ACC OPERAND ---------------------------- + R_f = old::dpasw2(B_int, + A_int); + zoo(R_f); + // CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17304074) + } + + old_func_end(); + // CHECK: call spir_func void @_Z12old_func_endv() +} + +SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void xmx_func() { + // DPAS: Result(M x N) = A(M x K) * B(K x N) + // where: + // M = RepeatCount; + // K = SystolicDepth * OpsPerChannel; + // N = ExecutionSize, must be 16 on PVC and 8 on DG2. + constexpr int M_one = 1; + constexpr int K_half = 8 * 2; + constexpr int K_bf16 = 8 * 2; + constexpr int K_int8x2 = 8 * 4; + constexpr int N_pvc = 16; + constexpr int N_dg2 = 8; + + // CHECK: define dso_local spir_func void @_Z8xmx_funcv() + + { // ======= DPAS BF16 ======================================================= + simd R_bf = 0; + simd R_f = 0; + + simd C_bf = 0; + simd C_f = 0; + + simd B_bf = 0; + simd A_bf = 0; + + R_f = xmx::dpas<8, 1, float>(C_f, B_bf, A_bf); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 1) + + R_f = xmx::dpas<8, 1, float>(C_bf, B_bf, A_bf); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 1, i32 0) + + R_bf = xmx::dpas<8, 1, bfloat16>(C_f, B_bf, A_bf); + zoo(R_bf); + // CHECK: call <16 x i16> @llvm.genx.dpas2.v16i16.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 0, i32 1) + + R_bf = xmx::dpas<8, 1, bfloat16>(C_bf, B_bf, A_bf); + zoo(R_bf); + // CHECK: call <16 x i16> @llvm.genx.dpas2.v16i16.v16i16.v128i32.v8i32(<16 x i16> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 9, i32 9, i32 8, i32 1, i32 0, i32 0) + + R_f = xmx::dpas<8, 1, float>(B_bf, A_bf); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17303817) + + R_bf = xmx::dpas<8, 1, bfloat16>(B_bf, A_bf); + zoo(R_bf); + // CHECK: call <16 x i16> @llvm.genx.dpas.nosrc0.v16i16.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17303817) + } + + { // ======= DPAS FP16 ======================================================= + simd R_hf = 0; + simd R_f = 0; + + simd C_hf = 0; + simd C_f = 0; + + simd B_hf = 0; + simd A_hf = 0; + + // ------------------- FP16: WITH ACC OPERAND ----------------------- + R_f = xmx::dpas<8, 1, float>(C_f, B_hf, A_hf); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 10, i32 10, i32 8, i32 1, i32 1, i32 1) + + R_f = xmx::dpas<8, 1, float>(C_hf, B_hf, A_hf); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f16.v128i32.v8i32(<16 x half> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 10, i32 10, i32 8, i32 1, i32 1, i32 0) + + R_hf = xmx::dpas<8, 1, half>(C_f, B_hf, A_hf); + zoo(R_hf); + // CHECK: call <16 x half> @llvm.genx.dpas2.v16f16.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 10, i32 10, i32 8, i32 1, i32 0, i32 1) + + R_hf = xmx::dpas<8, 1, half>(C_hf, B_hf, A_hf); + zoo(R_hf); + // CHECK: call <16 x half> @llvm.genx.dpas2.v16f16.v16f16.v128i32.v8i32(<16 x half> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 10, i32 10, i32 8, i32 1, i32 0, i32 0) + + // ------------------- FP16: NO ACC OPERAND ----------------------- + R_f = xmx::dpas<8, 1, float>(B_hf, A_hf); + zoo(R_f); + // CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17304074) + + R_hf = xmx::dpas<8, 1, half>(B_hf, A_hf); + zoo(R_hf); + // CHECK: call <16 x half> @llvm.genx.dpas.nosrc0.v16f16.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17304074) + } + + { // ======= DPAS 8-BIT x 2-BIT INT ========================================== + simd R_d = 0; + simd C_d = 0; + simd B_int2 = 0; // 16 2-bit integers per int32 + simd A_int8 = 0; + + // ------------ DPAS s8 x s2: WITH THE ACCUMULATOR OPERAND ----------------- + R_d = xmx::dpas<8, 1, int, int, int, signed char, s2, s8>(C_d, B_int2, + A_int8); + zoo(R_d); + // CHECK: call <16 x i32> @llvm.genx.dpas2.v16i32.v16i32.v32i32.v8i32(<16 x i32> {{[^,]+}}, <32 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 4, i32 8, i32 8, i32 1, i32 1, i32 1) + + // ------------ DPAS s8 x s2: WITHOUT THE ACCUMULATOR OPERAND -------------- + R_d = xmx::dpas<8, 1, int, int, signed char, s2, s8>(B_int2, A_int8); + zoo(R_d); + // CHECK: call <16 x i32> @llvm.genx.dpas.nosrc0.v16i32.v32i32.v8i32(<32 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17303556) + } + + { // ======= DPASW BF16 ====================================================== + simd R_f = 0; + simd C_f = 0; + + simd B_bf = 0; + simd A_bf = 0; + + // ------------ DPASW BF16: WITH THE ACCUMULATOR OPERAND ------------------- + R_f = xmx::dpasw<8, 1, float>(C_f, B_bf, A_bf); + zoo(R_f); + // CHECK: call <8 x float> @llvm.genx.dpasw.v8f32.v64i32.v4i32(<8 x float> {{[^,]+}}, <64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17303817) + + // ------------ DPASW BF16: WITHOUT ACC OPERAND ---------------------------- + R_f = xmx::dpasw<8, 1, float>(B_bf, A_bf); + zoo(R_f); + // CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17303817) + } + + { // ======= DPASW FP16 ====================================================== + simd R_f = 0; + simd C_f = 0; + + simd B_hf = 0; + simd A_hf = 0; + + // ------------ DPASW FP16: WITH THE ACCUMULATOR OPERAND ------------------- + R_f = xmx::dpasw<8, 1, float>(C_f, B_hf, A_hf); + zoo(R_f); + // CHECK: call <8 x float> @llvm.genx.dpasw.v8f32.v64i32.v4i32(<8 x float> {{[^,]+}}, <64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17304074) + + // ------------ DPASW FP16: WITHOUT ACC OPERAND ---------------------------- + R_f = xmx::dpasw<8, 1, float>(B_hf, A_hf); + zoo(R_f); + // CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17304074) + } + + xmx_func_end(); + // CHECK: call spir_func void @_Z12xmx_func_endv() }