File tree 2 files changed +18
-11
lines changed
dpctl/tensor/libtensor/include/kernels/elementwise_functions 2 files changed +18
-11
lines changed Original file line number Diff line number Diff line change @@ -114,21 +114,22 @@ template <typename argT, typename resT> struct Expm1Functor
114
114
}
115
115
116
116
// x, y finite numbers
117
- realT cosY_val;
118
- auto cosY_val_multi_ptr = sycl::address_space_cast<
119
- sycl::access ::address_space::private_space,
120
- sycl::access ::decorated::yes>(&cosY_val);
121
- const realT sinY_val = sycl::sincos (y, cosY_val_multi_ptr);
122
- const realT sinhalfY_val = std::sin (y / 2 );
117
+ const realT cosY_val = std::cos (y);
118
+ const realT sinY_val = (y == 0 ) ? y : std::sin (y);
119
+ const realT sinhalfY_val = (y == 0 ) ? y : std::sin (y / 2 );
123
120
124
121
const realT res_re =
125
122
std::expm1 (x) * cosY_val - 2 * sinhalfY_val * sinhalfY_val;
126
- const realT res_im = std::exp (x) * sinY_val;
123
+ realT res_im = std::exp (x) * sinY_val;
127
124
return resT{res_re, res_im};
128
125
}
129
126
else {
130
127
static_assert (std::is_floating_point_v<argT> ||
131
128
std::is_same_v<argT, sycl::half>);
129
+ static_assert (std::is_same_v<argT, resT>);
130
+ if (in == 0 ) {
131
+ return in;
132
+ }
132
133
return std::expm1 (in);
133
134
}
134
135
}
Original file line number Diff line number Diff line change @@ -81,11 +81,15 @@ template <typename argT, typename resT> struct SinFunctor
81
81
*/
82
82
if (in_re_finite && in_im_finite) {
83
83
#ifdef USE_SYCL_FOR_COMPLEX_TYPES
84
- return exprm_ns::sin (
84
+ resT res = exprm_ns::sin (
85
85
exprm_ns::complex<realT>(in)); // std::sin(in);
86
86
#else
87
- return std::sin (in);
87
+ resT res = std::sin (in);
88
88
#endif
89
+ if (in_re == realT (0 )) {
90
+ res.real (std::copysign (realT (0 ), in_re));
91
+ }
92
+ return res;
89
93
}
90
94
91
95
/*
@@ -176,8 +180,10 @@ template <typename argT, typename resT> struct SinFunctor
176
180
return resT{sinh_im, -sinh_re};
177
181
}
178
182
else {
179
- static_assert (std::is_floating_point_v<argT> ||
180
- std::is_same_v<argT, sycl::half>);
183
+ static_assert (std::is_same_v<argT, resT>);
184
+ if (in == 0 ) {
185
+ return in;
186
+ }
181
187
return std::sin (in);
182
188
}
183
189
}
You can’t perform that action at this time.
0 commit comments