Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sycl/include/sycl/ext/intel/esimd/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ using SurfaceIndex = unsigned int;

namespace detail {

template <typename T>
struct is_saturation_tag {
static constexpr bool value =
std::is_same_v<T, __ESIMD_NS::saturation_on_tag> ||
std::is_same_v<T, __ESIMD_NS::saturation_off_tag>;
};

/// Shorthand variable template for is_saturation_tag<T>::value.
template <typename T>
inline constexpr bool is_saturation_tag_v{is_saturation_tag<T>::value};

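/// Check at compile time whether a given positive 32-bit integer is a power
/// of 2. Note: the expression below also evaluates to true for n == 0, so
/// callers must pass a strictly positive value.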
ESIMD_INLINE constexpr bool isPowerOf2(unsigned int n) {
return (n & (n - 1)) == 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,24 +451,15 @@ __esimd_dpas_inner(const __ESIMD_DNS::vector_type_t<T0, SZ> *src0,
__ESIMD_EMU_DNS::SetSatur<T2,
__ESIMD_EMU_DNS::is_inttype<RT>::value>::set();

constexpr __ESIMD_NS::uint ops_per_chan =
src1_precision == __ESIMD_ENS::argument_type::BF16 ||
src1_precision == __ESIMD_ENS::argument_type::FP16 ||
src2_precision == __ESIMD_ENS::argument_type::BF16 ||
src2_precision == __ESIMD_ENS::argument_type::FP16
? 2
: src1_precision == __ESIMD_ENS::argument_type::S8 ||
src1_precision == __ESIMD_ENS::argument_type::U8 ||
src2_precision == __ESIMD_ENS::argument_type::S8 ||
src2_precision == __ESIMD_ENS::argument_type::U8
? 4
: 8;

__ESIMD_NS::uint V = 0, U = 0, k = 0, temp = 0, src1_ops_per_dword = 0, p = 0;

constexpr auto src1_el_bits = __esimd_dpas_bits_precision(src1_precision);
constexpr auto src2_el_bits = __esimd_dpas_bits_precision(src2_precision);

constexpr auto max_el_bits = std::max(src1_el_bits, src2_el_bits);
constexpr __ESIMD_NS::uint ops_per_chan =
std::min(32 / max_el_bits, static_cast<__ESIMD_NS::uint>(8));

uint32_t src1_signed =
src1_precision == __ESIMD_ENS::argument_type::S2 ||
src1_precision == __ESIMD_ENS::argument_type::S4 ||
Expand Down
21 changes: 12 additions & 9 deletions sycl/include/sycl/ext/intel/experimental/esimd/math.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1761,7 +1761,8 @@ template <argument_type src1_precision, argument_type src2_precision,
typename Sat = __ESIMD_NS::saturation_off_tag>
__ESIMD_API __ESIMD_NS::simd<T, N>
dpas(__ESIMD_NS::simd<T0, N> src0, __ESIMD_NS::simd<T1, N1> src1,
__ESIMD_NS::simd<T2, N2> src2, Sat sat = {}) {
__ESIMD_NS::simd<T2, N2> src2,
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
// types: dst, src0, src1, src2
// ud, d | ud, d | ub,b,u4,s4,u2,s2 | ub,b,u4,s4,u2,s2
constexpr bool check_integer =
Expand Down Expand Up @@ -1894,7 +1895,8 @@ template <argument_type src1_precision, argument_type src2_precision,
typename Sat = __ESIMD_NS::saturation_off_tag>
__ESIMD_API __ESIMD_NS::simd<T, N>
dpas(__ESIMD_NS::simd<T, N> src0, __ESIMD_NS::simd<T1, N1> src1,
__ESIMD_NS::simd<T2, N2> src2, Sat sat = {}) {
__ESIMD_NS::simd<T2, N2> src2,
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
return dpas<src1_precision, src2_precision, T, systolic_depth, repeat_count>(
src0, src1, src2, sat);
}
Expand All @@ -1911,9 +1913,9 @@ template <argument_type src1_precision, argument_type src2_precision,
int systolic_depth, int repeat_count, typename T, typename T1,
typename T2, int N, int N1, int N2,
typename Sat = __ESIMD_NS::saturation_off_tag>
__ESIMD_API __ESIMD_NS::simd<T, N> dpas(__ESIMD_NS::simd<T1, N1> src1,
__ESIMD_NS::simd<T2, N2> src2,
Sat sat = {}) {
__ESIMD_API __ESIMD_NS::simd<T, N>
dpas(__ESIMD_NS::simd<T1, N1> src1, __ESIMD_NS::simd<T2, N2> src2,
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {

static_assert(__ESIMD_DNS::is_fp_or_dword_type<T>::value,
"Dst must be FP or DWORD type");
Expand Down Expand Up @@ -1976,7 +1978,8 @@ template <argument_type src1_precision, argument_type src2_precision,
typename Sat = __ESIMD_NS::saturation_off_tag>
__ESIMD_API __ESIMD_NS::simd<T, N>
dpasw(__ESIMD_NS::simd<T, N> src0, __ESIMD_NS::simd<T1, N1> src1,
__ESIMD_NS::simd<T2, N2> src2, Sat sat = {}) {
__ESIMD_NS::simd<T2, N2> src2,
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
constexpr bool is_4xhf =
std::is_same_v<T, __ESIMD_DNS::__raw_t<sycl::half>> &&
(src1_precision == src2_precision) &&
Expand Down Expand Up @@ -2048,9 +2051,9 @@ template <argument_type src1_precision, argument_type src2_precision,
int systolic_depth, int repeat_count, typename T, typename T1,
typename T2, int N, int N1, int N2,
typename Sat = __ESIMD_NS::saturation_off_tag>
__ESIMD_API __ESIMD_NS::simd<T, N> dpasw2(__ESIMD_NS::simd<T1, N1> src1,
__ESIMD_NS::simd<T2, N2> src2,
Sat sat = {}) {
__ESIMD_API __ESIMD_NS::simd<T, N>
dpasw2(__ESIMD_NS::simd<T1, N1> src1, __ESIMD_NS::simd<T2, N2> src2,
std::enable_if_t<__ESIMD_DNS::is_saturation_tag_v<Sat>, Sat> sat = {}) {
constexpr bool is_4xhf =
std::is_same_v<T, __ESIMD_DNS::__raw_t<sycl::half>> &&
src1_precision == src2_precision && src1_precision == argument_type::FP16;
Expand Down