Skip to content

Commit ebf94b0

Browse files
Merge pull request #1477 from IntelPython/resolve-gh-1439
Resolves gh-1439
2 parents a4369ac + 1c4b0ab commit ebf94b0

File tree

2 files changed

+18
-11
lines changed

2 files changed

+18
-11
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -114,21 +114,22 @@ template <typename argT, typename resT> struct Expm1Functor
114114
}
115115

116116
// x, y finite numbers
117-
realT cosY_val;
118-
auto cosY_val_multi_ptr = sycl::address_space_cast<
119-
sycl::access::address_space::private_space,
120-
sycl::access::decorated::yes>(&cosY_val);
121-
const realT sinY_val = sycl::sincos(y, cosY_val_multi_ptr);
122-
const realT sinhalfY_val = std::sin(y / 2);
117+
const realT cosY_val = std::cos(y);
118+
const realT sinY_val = (y == 0) ? y : std::sin(y);
119+
const realT sinhalfY_val = (y == 0) ? y : std::sin(y / 2);
123120

124121
const realT res_re =
125122
std::expm1(x) * cosY_val - 2 * sinhalfY_val * sinhalfY_val;
126-
const realT res_im = std::exp(x) * sinY_val;
123+
realT res_im = std::exp(x) * sinY_val;
127124
return resT{res_re, res_im};
128125
}
129126
else {
130127
static_assert(std::is_floating_point_v<argT> ||
131128
std::is_same_v<argT, sycl::half>);
129+
static_assert(std::is_same_v<argT, resT>);
130+
if (in == 0) {
131+
return in;
132+
}
132133
return std::expm1(in);
133134
}
134135
}

dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,15 @@ template <typename argT, typename resT> struct SinFunctor
8181
*/
8282
if (in_re_finite && in_im_finite) {
8383
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
84-
return exprm_ns::sin(
84+
resT res = exprm_ns::sin(
8585
exprm_ns::complex<realT>(in)); // std::sin(in);
8686
#else
87-
return std::sin(in);
87+
resT res = std::sin(in);
8888
#endif
89+
if (in_re == realT(0)) {
90+
res.real(std::copysign(realT(0), in_re));
91+
}
92+
return res;
8993
}
9094

9195
/*
@@ -176,8 +180,10 @@ template <typename argT, typename resT> struct SinFunctor
176180
return resT{sinh_im, -sinh_re};
177181
}
178182
else {
179-
static_assert(std::is_floating_point_v<argT> ||
180-
std::is_same_v<argT, sycl::half>);
183+
static_assert(std::is_same_v<argT, resT>);
184+
if (in == 0) {
185+
return in;
186+
}
181187
return std::sin(in);
182188
}
183189
}

0 commit comments

Comments
 (0)