Skip to content

Commit d8437e6

Browse files
Implement logaddexp in numerically stable way
1 parent 22081fb commit d8437e6

File tree

1 file changed

+12
-9
lines changed
  • dpctl/tensor/libtensor/include/kernels/elementwise_functions

1 file changed

+12
-9
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/logaddexp.hpp

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -62,22 +62,25 @@ template <typename argT1, typename argT2, typename resT> struct LogAddExpFunctor
6262

6363
/// @brief Scalar log(exp(in1) + exp(in2)) evaluated without forming the
///        exponentials of the inputs themselves.
///
/// Uses the identity
///     logaddexp(a, b) = max(a, b) + log1p(exp(-|a - b|)),
/// so the argument of exp is never positive and cannot overflow.
///
/// @param in1 First operand (converted to resT).
/// @param in2 Second operand (converted to resT).
/// @return The result as resT. Equal operands (including equal infinities)
///         yield operand + log(2); a NaN operand yields NaN.
resT operator()(const argT1 &in1, const argT2 &in2)
{
    const resT a = static_cast<resT>(in1);
    const resT b = static_cast<resT>(in2);

    // Equal operands are handled first: for a == b == +/-inf the difference
    // below would be inf - inf = NaN, although the result is well defined.
    if (a == b) {
        return a + std::log(resT(2));
    }

    const resT delta = a - b;
    if (delta > 0) {
        // a is the larger operand; exp(-delta) lies in [0, 1).
        return a + std::log1p(std::exp(-delta));
    }
    if (delta <= 0) {
        // b is the larger operand; exp(delta) lies in [0, 1).
        return b + std::log1p(std::exp(delta));
    }

    // Neither comparison held, so delta is NaN (an operand was NaN);
    // propagate it rather than letting std::max/std::min silently drop it.
    return delta;
}
6769

6870
template <int vec_sz>
6971
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
7072
const sycl::vec<argT2, vec_sz> &in2)
7173
{
72-
auto res = sycl::log(sycl::exp(in1) + sycl::exp(in2));
73-
if constexpr (std::is_same_v<resT,
74-
typename decltype(res)::element_type>) {
75-
return res;
76-
}
77-
else {
78-
return vec_cast<resT, typename decltype(res)::element_type, vec_sz>(
79-
res);
74+
sycl::vec<resT, vec_sz> res;
75+
auto diff = in1 - in2;
76+
77+
#pragma unroll
78+
for (int i = 0; i < vec_sz; ++i) {
79+
resT max = std::max<resT>(in1[i], in2[i]);
80+
res[i] = max + std::log(resT(1) + std::exp(std::abs(diff[i])));
8081
}
82+
83+
return res;
8184
}
8285
};
8386

0 commit comments

Comments
 (0)