diff --git a/sycl/include/sycl/ext/intel/esimd/common.hpp b/sycl/include/sycl/ext/intel/esimd/common.hpp index 2f144bced5a03..ce48d9c1bc86c 100644 --- a/sycl/include/sycl/ext/intel/esimd/common.hpp +++ b/sycl/include/sycl/ext/intel/esimd/common.hpp @@ -65,6 +65,16 @@ using SurfaceIndex = unsigned int; namespace detail { +template +struct is_saturation_tag { + static constexpr bool value = + std::is_same_v || + std::is_same_v; +}; + +template +inline constexpr bool is_saturation_tag_v = is_saturation_tag::value; + /// Check if a given 32 bit positive integer is a power of 2 at compile time. ESIMD_INLINE constexpr bool isPowerOf2(unsigned int n) { return (n & (n - 1)) == 0; diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp index 2d6649c9b9c40..aab438ba698fc 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/detail/math_intrin.hpp @@ -451,24 +451,15 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t *src0, __ESIMD_EMU_DNS::SetSatur::value>::set(); - constexpr __ESIMD_NS::uint ops_per_chan = - src1_precision == __ESIMD_ENS::argument_type::BF16 || - src1_precision == __ESIMD_ENS::argument_type::FP16 || - src2_precision == __ESIMD_ENS::argument_type::BF16 || - src2_precision == __ESIMD_ENS::argument_type::FP16 - ? 2 - : src1_precision == __ESIMD_ENS::argument_type::S8 || - src1_precision == __ESIMD_ENS::argument_type::U8 || - src2_precision == __ESIMD_ENS::argument_type::S8 || - src2_precision == __ESIMD_ENS::argument_type::U8 - ? 4 - : 8; - __ESIMD_NS::uint V = 0, U = 0, k = 0, temp = 0, src1_ops_per_dword = 0, p = 0; constexpr auto src1_el_bits = __esimd_dpas_bits_precision(src1_precision); constexpr auto src2_el_bits = __esimd_dpas_bits_precision(src2_precision); + constexpr auto max_el_bits = std::max(src1_el_bits, src2_el_bits); + constexpr __ESIMD_NS::uint ops_per_chan = + std::min(32 / max_el_bits, static_cast<__ESIMD_NS::uint>(8)); + uint32_t src1_signed = src1_precision == __ESIMD_ENS::argument_type::S2 || src1_precision == __ESIMD_ENS::argument_type::S4 || diff --git a/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp b/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp index 23231ce062587..26778a4ef3017 100644 --- a/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp +++ b/sycl/include/sycl/ext/intel/experimental/esimd/math.hpp @@ -1761,7 +1761,8 @@ template __ESIMD_API __ESIMD_NS::simd dpas(__ESIMD_NS::simd src0, __ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, Sat sat = {}) { + __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { // types: dst, src0, src1, src2 // ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2 constexpr bool check_integer = @@ -1894,7 +1895,8 @@ template __ESIMD_API __ESIMD_NS::simd dpas(__ESIMD_NS::simd src0, __ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, Sat sat = {}) { + __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { return dpas( src0, src1, src2, sat); } @@ -1911,9 +1913,9 @@ template -__ESIMD_API __ESIMD_NS::simd dpas(__ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, - Sat sat = {}) { +__ESIMD_API __ESIMD_NS::simd +dpas(__ESIMD_NS::simd src1, __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { static_assert(__ESIMD_DNS::is_fp_or_dword_type::value, "Dst must be FP or DWORD type"); @@ -1976,7 +1978,8 @@ template __ESIMD_API __ESIMD_NS::simd dpasw(__ESIMD_NS::simd src0, __ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, Sat sat = {}) { + __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { constexpr bool is_4xhf = std::is_same_v> && (src1_precision == src2_precision) && @@ -2048,9 +2051,9 @@ template -__ESIMD_API __ESIMD_NS::simd dpasw2(__ESIMD_NS::simd src1, - __ESIMD_NS::simd src2, - Sat sat = {}) { +__ESIMD_API __ESIMD_NS::simd +dpasw2(__ESIMD_NS::simd src1, __ESIMD_NS::simd src2, + std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v, Sat> sat = {}) { constexpr bool is_4xhf = std::is_same_v> && src1_precision == src2_precision && src1_precision == argument_type::FP16;