diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index 7b59285884..b996a6d0ec 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -73,6 +74,46 @@ template struct Expm1Functor const realT x = std::real(in); const realT y = std::imag(in); + // special cases + if (std::isinf(x)) { + if (x > realT(0)) { + // positive infinity cases + if (!std::isfinite(y)) { + return resT{x, std::numeric_limits::quiet_NaN()}; + } + else if (y == realT(0)) { + return in; + } + else { + return (resT{std::copysign(x, std::cos(y)), + std::copysign(x, std::sin(y))}); + } + } + else { + // negative infinity cases + if (!std::isfinite(y)) { + // copy sign of y to guarantee + // conj(expm1(x)) == expm1(conj(x)) + return resT{realT(-1), std::copysign(realT(0), y)}; + } + else { + return resT{realT(-1), + std::copysign(realT(0), std::sin(y))}; + } + } + } + + if (std::isnan(x)) { + if (y == realT(0)) { + return in; + } + else { + return resT{std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}; + } + } + + // x, y finite numbers realT cosY_val; const realT sinY_val = sycl::sincos(y, &cosY_val); const realT sinhalfY_val = std::sin(y / 2); diff --git a/dpctl/tests/elementwise/test_expm1.py b/dpctl/tests/elementwise/test_expm1.py index 20dc421c77..ba95d2b96d 100644 --- a/dpctl/tests/elementwise/test_expm1.py +++ b/dpctl/tests/elementwise/test_expm1.py @@ -116,29 +116,53 @@ def test_expm1_order(dtype): def test_expm1_special_cases(): - q = get_queue_or_skip() + get_queue_or_skip() - X = dpt.asarray( - [dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q - ) - Xnp = dpt.asnumpy(X) + X = dpt.asarray([dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4") + res = np.asarray([np.nan, 0.0, -0.0, np.inf, -1.0], dtype="f4") tol = dpt.finfo(X.dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.expm1(X)), np.expm1(Xnp), atol=tol, rtol=tol - ) + assert_allclose(dpt.asnumpy(dpt.expm1(X)), res, atol=tol, rtol=tol) # special cases for complex variant + num_finite = 1.0 vals = [ - complex(*val) - for val in itertools.permutations( - [dpt.nan, dpt.inf, -dpt.inf, 0.0, -0.0, 1.0], 2 - ) + complex(0.0, 0.0), + complex(num_finite, dpt.inf), + complex(num_finite, dpt.nan), + complex(dpt.inf, 0.0), + complex(-dpt.inf, num_finite), + complex(dpt.inf, num_finite), + complex(-dpt.inf, dpt.inf), + complex(dpt.inf, dpt.inf), + complex(-dpt.inf, dpt.nan), + complex(dpt.inf, dpt.nan), + complex(dpt.nan, 0.0), + complex(dpt.nan, num_finite), + complex(dpt.nan, dpt.nan), ] X = dpt.asarray(vals, dtype=dpt.complex64) - Xnp = dpt.asnumpy(X) + cis_1 = complex(np.cos(num_finite), np.sin(num_finite)) + c_nan = complex(np.nan, np.nan) + res = np.asarray( + [ + complex(0.0, 0.0), + c_nan, + c_nan, + complex(np.inf, 0.0), + 0.0 * cis_1 - 1.0, + np.inf * cis_1 - 1.0, + complex(-1.0, 0.0), + complex(np.inf, np.nan), + complex(-1.0, 0.0), + complex(np.inf, np.nan), + complex(np.nan, 0.0), + c_nan, + c_nan, + ], + dtype=np.complex64, + ) tol = dpt.finfo(X.dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.expm1(X)), np.expm1(Xnp), atol=tol, rtol=tol - ) + with np.errstate(invalid="ignore"): + assert_allclose(dpt.asnumpy(dpt.expm1(X)), res, atol=tol, rtol=tol) diff --git a/dpctl/tests/elementwise/test_log1p.py b/dpctl/tests/elementwise/test_log1p.py index 2820b4f8b5..d40574c415 100644 --- a/dpctl/tests/elementwise/test_log1p.py +++ b/dpctl/tests/elementwise/test_log1p.py @@ -119,28 +119,49 @@ def test_log1p_special_cases(): q = get_queue_or_skip() X = dpt.asarray( - [dpt.nan, -1.0, -2.0, 0.0, -0.0, dpt.inf, -dpt.inf], + [dpt.nan, -2.0, -1.0, -0.0, 0.0, dpt.inf], dtype="f4", sycl_queue=q, ) - Xnp = dpt.asnumpy(X) + res = np.asarray([np.nan, np.nan, -np.inf, -0.0, 0.0, np.inf]) tol = dpt.finfo(X.dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.log1p(X)), np.log1p(Xnp), atol=tol, rtol=tol - ) + with np.errstate(divide="ignore", invalid="ignore"): + assert_allclose(dpt.asnumpy(dpt.log1p(X)), res, atol=tol, rtol=tol) # special cases for complex vals = [ - complex(*val) - for val in itertools.permutations( - [dpt.nan, dpt.inf, -dpt.inf, 0.0, -0.0, 1.0, -1.0, -2.0], 2 - ) + complex(-1.0, 0.0), + complex(2.0, dpt.inf), + complex(2.0, dpt.nan), + complex(-dpt.inf, 1.0), + complex(dpt.inf, 1.0), + complex(-dpt.inf, dpt.inf), + complex(dpt.inf, dpt.inf), + complex(dpt.inf, dpt.nan), + complex(dpt.nan, 1.0), + complex(dpt.nan, dpt.inf), + complex(dpt.nan, dpt.nan), ] X = dpt.asarray(vals, dtype=dpt.complex64) - Xnp = dpt.asnumpy(X) + c_nan = complex(np.nan, np.nan) + res = np.asarray( + [ + complex(-np.inf, 0.0), + complex(np.inf, np.pi / 2), + c_nan, + complex(np.inf, np.pi), + complex(np.inf, 0.0), + complex(np.inf, 3 * np.pi / 4), + complex(np.inf, np.pi / 4), + complex(np.inf, np.nan), + c_nan, + complex(np.inf, np.nan), + c_nan, + ], + dtype=np.complex64, + ) tol = dpt.finfo(X.dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.log1p(X)), np.log1p(Xnp), atol=tol, rtol=tol - ) + with np.errstate(invalid="ignore"): + assert_allclose(dpt.asnumpy(dpt.log1p(X)), res, atol=tol, rtol=tol)