From 19179e6555bba96ba99121019465f90bfa21567d Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 9 Aug 2023 12:28:22 -0700 Subject: [PATCH 1/4] Fixed some complex special cases for expm1 --- .../kernels/elementwise_functions/expm1.hpp | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index 7b59285884..9a31b90ea1 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include "kernels/elementwise_functions/common.hpp" @@ -73,6 +74,45 @@ template struct Expm1Functor const realT x = std::real(in); const realT y = std::imag(in); + // special cases + if (std::isinf(x)) { + if (x > realT(0)) { + // positive infinity cases + if (!std::isfinite(y)) { + return resT{x, std::numeric_limits::quiet_NaN()}; + } + else if (y == realT(0)) { + return in; + } + else { + return (std::numeric_limits::infinity() * + resT{std::cos(y), std::sin(y)} - + realT(1)); + } + } + else { + // negative infinity cases + if (!std::isfinite(y)) { + return resT{-1, 0}; + } + else { + return (realT(0) * resT{std::cos(y), std::sin(y)} - + realT(1)); + } + } + } + + if (std::isnan(x)) { + if (y == realT(0)) { + return in; + } + else { + return resT{std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}; + } + } + + // x, y finite numbers realT cosY_val; const realT sinY_val = sycl::sincos(y, &cosY_val); const realT sinhalfY_val = std::sin(y / 2); From ccc405097ae90a5c96c7017effbc37769d23dacf Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 9 Aug 2023 13:27:50 -0700 Subject: [PATCH 2/4] Silence remaining warnings in elementwise tests --- dpctl/tests/elementwise/test_expm1.py | 56 +++++++++++++++++++-------- dpctl/tests/elementwise/test_log1p.py | 47 +++++++++++++++------- 2 files changed, 74 insertions(+), 29 deletions(-) diff --git a/dpctl/tests/elementwise/test_expm1.py b/dpctl/tests/elementwise/test_expm1.py index 20dc421c77..ba95d2b96d 100644 --- a/dpctl/tests/elementwise/test_expm1.py +++ b/dpctl/tests/elementwise/test_expm1.py @@ -116,29 +116,53 @@ def test_expm1_order(dtype): def test_expm1_special_cases(): - q = get_queue_or_skip() + get_queue_or_skip() - X = dpt.asarray( - [dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4", sycl_queue=q - ) - Xnp = dpt.asnumpy(X) + X = dpt.asarray([dpt.nan, 0.0, -0.0, dpt.inf, -dpt.inf], dtype="f4") + res = np.asarray([np.nan, 0.0, -0.0, np.inf, -1.0], dtype="f4") tol = dpt.finfo(X.dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.expm1(X)), np.expm1(Xnp), atol=tol, rtol=tol - ) + assert_allclose(dpt.asnumpy(dpt.expm1(X)), res, atol=tol, rtol=tol) # special cases for complex variant + num_finite = 1.0 vals = [ - complex(*val) - for val in itertools.permutations( - [dpt.nan, dpt.inf, -dpt.inf, 0.0, -0.0, 1.0], 2 - ) + complex(0.0, 0.0), + complex(num_finite, dpt.inf), + complex(num_finite, dpt.nan), + complex(dpt.inf, 0.0), + complex(-dpt.inf, num_finite), + complex(dpt.inf, num_finite), + complex(-dpt.inf, dpt.inf), + complex(dpt.inf, dpt.inf), + complex(-dpt.inf, dpt.nan), + complex(dpt.inf, dpt.nan), + complex(dpt.nan, 0.0), + complex(dpt.nan, num_finite), + complex(dpt.nan, dpt.nan), ] X = dpt.asarray(vals, dtype=dpt.complex64) - Xnp = dpt.asnumpy(X) + cis_1 = complex(np.cos(num_finite), np.sin(num_finite)) + c_nan = complex(np.nan, np.nan) + res = np.asarray( + [ + complex(0.0, 0.0), + c_nan, + c_nan, + complex(np.inf, 0.0), + 0.0 * cis_1 - 1.0, + np.inf * cis_1 - 1.0, + complex(-1.0, 0.0), + complex(np.inf, np.nan), + complex(-1.0, 0.0), + complex(np.inf, np.nan), + complex(np.nan, 0.0), + c_nan, + c_nan, + ], + dtype=np.complex64, + ) tol = dpt.finfo(X.dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.expm1(X)), np.expm1(Xnp), atol=tol, rtol=tol - ) + with np.errstate(invalid="ignore"): + assert_allclose(dpt.asnumpy(dpt.expm1(X)), res, atol=tol, rtol=tol) diff --git a/dpctl/tests/elementwise/test_log1p.py b/dpctl/tests/elementwise/test_log1p.py index 2820b4f8b5..d40574c415 100644 --- a/dpctl/tests/elementwise/test_log1p.py +++ b/dpctl/tests/elementwise/test_log1p.py @@ -119,28 +119,49 @@ def test_log1p_special_cases(): q = get_queue_or_skip() X = dpt.asarray( - [dpt.nan, -1.0, -2.0, 0.0, -0.0, dpt.inf, -dpt.inf], + [dpt.nan, -2.0, -1.0, -0.0, 0.0, dpt.inf], dtype="f4", sycl_queue=q, ) - Xnp = dpt.asnumpy(X) + res = np.asarray([np.nan, np.nan, -np.inf, -0.0, 0.0, np.inf]) tol = dpt.finfo(X.dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.log1p(X)), np.log1p(Xnp), atol=tol, rtol=tol - ) + with np.errstate(divide="ignore", invalid="ignore"): + assert_allclose(dpt.asnumpy(dpt.log1p(X)), res, atol=tol, rtol=tol) # special cases for complex vals = [ - complex(*val) - for val in itertools.permutations( - [dpt.nan, dpt.inf, -dpt.inf, 0.0, -0.0, 1.0, -1.0, -2.0], 2 - ) + complex(-1.0, 0.0), + complex(2.0, dpt.inf), + complex(2.0, dpt.nan), + complex(-dpt.inf, 1.0), + complex(dpt.inf, 1.0), + complex(-dpt.inf, dpt.inf), + complex(dpt.inf, dpt.inf), + complex(dpt.inf, dpt.nan), + complex(dpt.nan, 1.0), + complex(dpt.nan, dpt.inf), + complex(dpt.nan, dpt.nan), ] X = dpt.asarray(vals, dtype=dpt.complex64) - Xnp = dpt.asnumpy(X) + c_nan = complex(np.nan, np.nan) + res = np.asarray( + [ + complex(-np.inf, 0.0), + complex(np.inf, np.pi / 2), + c_nan, + complex(np.inf, np.pi), + complex(np.inf, 0.0), + complex(np.inf, 3 * np.pi / 4), + complex(np.inf, np.pi / 4), + complex(np.inf, np.nan), + c_nan, + complex(np.inf, np.nan), + c_nan, + ], + dtype=np.complex64, + ) tol = dpt.finfo(X.dtype).resolution - assert_allclose( - dpt.asnumpy(dpt.log1p(X)), np.log1p(Xnp), atol=tol, rtol=tol - ) + with np.errstate(invalid="ignore"): + assert_allclose(dpt.asnumpy(dpt.log1p(X)), res, atol=tol, rtol=tol) From 64282d6ecd1de25e3504178b3b971af417589d80 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 9 Aug 2023 14:39:51 -0700 Subject: [PATCH 3/4] Tweaks to expm1 complex special case logic --- .../include/kernels/elementwise_functions/expm1.hpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index 9a31b90ea1..a3cb05f664 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -85,19 +85,19 @@ template struct Expm1Functor return in; } else { - return (std::numeric_limits::infinity() * - resT{std::cos(y), std::sin(y)} - - realT(1)); + return (x * resT{std::cos(y), std::sin(y)}); } } else { // negative infinity cases if (!std::isfinite(y)) { - return resT{-1, 0}; + // copy sign of y to guarantee + // conj(expm1(x)) == expm1(conj(x)) + return resT{realT(-1), std::copysign(realT(0), y)}; } else { - return (realT(0) * resT{std::cos(y), std::sin(y)} - - realT(1)); + return resT{realT(-1), + std::copysign(realT(0), std::sin(y))}; } } } From 27ea9c26513cc2aa7bc4ace6392f207df507036e Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Thu, 10 Aug 2023 14:12:18 -0700 Subject: [PATCH 4/4] expm1 special case change For `inf` real part and finite, nonzero imaginary part, now guaranteed to be (+/-`inf`, +/-`inf`), with cosine and sine of the imaginary part determining the sign, respectively --- .../libtensor/include/kernels/elementwise_functions/expm1.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp index a3cb05f664..b996a6d0ec 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/expm1.hpp @@ -85,7 +85,8 @@ template struct Expm1Functor return in; } else { - return (x * resT{std::cos(y), std::sin(y)}); + return (resT{std::copysign(x, std::cos(y)), + std::copysign(x, std::sin(y))}); } } else {