Skip to content

Commit e0349c5

Browse files
committed
Resolves gh-1439
Adjusts logic in expm1 and sin for negative 0s inputs in real and complex cases
1 parent 5ec9fd5 commit e0349c5

File tree

2 files changed

+22
-11
lines changed

2 files changed

+22
-11
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,21 +114,22 @@ template <typename argT, typename resT> struct Expm1Functor
114114
}
115115

116116
// x, y finite numbers
117-
realT cosY_val;
118-
auto cosY_val_multi_ptr = sycl::address_space_cast<
119-
sycl::access::address_space::private_space,
120-
sycl::access::decorated::yes>(&cosY_val);
121-
const realT sinY_val = sycl::sincos(y, cosY_val_multi_ptr);
122-
const realT sinhalfY_val = std::sin(y / 2);
117+
const realT cosY_val = std::cos(y);
118+
const realT sinY_val = (y == 0) ? y : std::sin(y);
119+
const realT sinhalfY_val = (y == 0) ? y : std::sin(y / 2);
123120

124121
const realT res_re =
125122
std::expm1(x) * cosY_val - 2 * sinhalfY_val * sinhalfY_val;
126-
const realT res_im = std::exp(x) * sinY_val;
123+
realT res_im = std::exp(x) * sinY_val;
127124
return resT{res_re, res_im};
128125
}
129126
else {
130127
static_assert(std::is_floating_point_v<argT> ||
131128
std::is_same_v<argT, sycl::half>);
129+
static_assert(std::is_same_v<argT, resT>);
130+
if (in == 0) {
131+
return in;
132+
}
132133
return std::expm1(in);
133134
}
134135
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,18 @@ template <typename argT, typename resT> struct SinFunctor
8181
*/
8282
if (in_re_finite && in_im_finite) {
8383
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
84-
return exprm_ns::sin(
84+
resT res = exprm_ns::sin(
8585
exprm_ns::complex<realT>(in)); // std::sin(in);
86+
if (in_re == realT(0)) {
87+
res.real(std::copysign(realT(0), in_re));
88+
}
89+
return res;
8690
#else
87-
return std::sin(in);
91+
resT res = std::sin(in);
92+
if (in_re == realT(0)) {
93+
res.real(std::copysign(realT(0), in_re));
94+
}
95+
return res;
8896
#endif
8997
}
9098

@@ -176,8 +184,10 @@ template <typename argT, typename resT> struct SinFunctor
176184
return resT{sinh_im, -sinh_re};
177185
}
178186
else {
179-
static_assert(std::is_floating_point_v<argT> ||
180-
std::is_same_v<argT, sycl::half>);
187+
static_assert(std::is_same_v<argT, resT>);
188+
if (in == 0) {
189+
return in;
190+
}
181191
return std::sin(in);
182192
}
183193
}

0 commit comments

Comments
 (0)