Skip to content

Commit c6533b5

Browse files
authored
Merge ccf4783 into 76170e0
2 parents 76170e0 + ccf4783 commit c6533b5

File tree

12 files changed

+111
-17
lines changed

12 files changed

+111
-17
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1919
* Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550)
2020
* Added implementation of `dpnp.linalg.lu_solve` for 2D inputs (SciPy-compatible) [#2575](https://github.com/IntelPython/dpnp/pull/2575)
2121
* Added implementation of `dpnp.special.erfc` [#2588](https://github.com/IntelPython/dpnp/pull/2588)
22+
* Added implementation of `dpnp.special.erfcx` [#2596](https://github.com/IntelPython/dpnp/pull/2596)
2223

2324
### Changed
2425

doc/reference/special.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,5 @@ Error function and Fresnel integrals
1515
erf
1616
erfc
1717
erfcx
18-
erfi
1918
erfinv
2019
erfcinv

dpnp/backend/extensions/ufunc/elementwise_functions/erf_funcs.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ static void populate(py::module_ m,
209209

210210
MACRO_DEFINE_IMPL(erf, Erf);
211211
MACRO_DEFINE_IMPL(erfc, Erfc);
212+
MACRO_DEFINE_IMPL(erfcx, Erfcx);
212213
} // namespace impl
213214

214215
void init_erf_funcs(py::module_ m)
@@ -228,5 +229,9 @@ void init_erf_funcs(py::module_ m)
228229
impl::populate<impl::ErfcContigFactory, impl::ErfcStridedFactory>(
229230
m, "_erfc", "", impl::erfc_contig_dispatch_vector,
230231
impl::erfc_strided_dispatch_vector);
232+
233+
impl::populate<impl::ErfcxContigFactory, impl::ErfcxStridedFactory>(
234+
m, "_erfcx", "", impl::erfcx_contig_dispatch_vector,
235+
impl::erfcx_strided_dispatch_vector);
231236
}
232237
} // namespace dpnp::extensions::ufunc

dpnp/backend/extensions/vm/erf_funcs.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ using ew_cmn_ns::unary_contig_impl_fn_ptr_t;
130130

131131
MACRO_DEFINE_IMPL(erf, Erf);
132132
MACRO_DEFINE_IMPL(erfc, Erfc);
133+
MACRO_DEFINE_IMPL(erfcx, Erfcx);
133134

134135
template <template <typename fnT, typename T> typename factoryT>
135136
static void populate(py::module_ m,
@@ -184,5 +185,11 @@ void init_erf_funcs(py::module_ m)
184185
"Call `erfc` function from OneMKL VM library to compute the "
185186
"complementary error function value of vector elements",
186187
impl::erfc_contig_dispatch_vector);
188+
189+
impl::populate<impl::ErfcxContigFactory>(
190+
m, "_erfcx",
191+
"Call `erfcx` function from OneMKL VM library to compute the scaled "
192+
"complementary error function value of vector elements",
193+
impl::erfcx_contig_dispatch_vector);
187194
}
188195
} // namespace dpnp::extensions::vm

dpnp/backend/kernels/elementwise_functions/erf.hpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@
2929

3030
#include <sycl/sycl.hpp>
3131

32+
/**
33+
* Include <sycl/ext/intel/math.hpp> only when targeting to Intel devices.
34+
*/
35+
#if defined(__INTEL_LLVM_COMPILER)
36+
#define __SYCL_EXT_INTEL_MATH_SUPPORT
37+
#endif
38+
39+
#if defined(__SYCL_EXT_INTEL_MATH_SUPPORT)
40+
#include <sycl/ext/intel/math.hpp>
41+
#else
42+
#include <cmath>
43+
#endif
44+
3245
namespace dpnp::kernels::erfs
3346
{
3447
template <typename OpT, typename ArgT, typename ResT>
@@ -62,13 +75,20 @@ struct BaseFunctor
6275
template <typename Tp> \
6376
static Tp apply(const Tp &x) \
6477
{ \
65-
return sycl::__name__(x); \
78+
return __name__(x); \
6679
} \
6780
}; \
6881
\
6982
template <typename ArgT, typename ResT> \
7083
using __f_name__##Functor = BaseFunctor<__f_name__##Op, ArgT, ResT>;
7184

72-
MACRO_DEFINE_FUNCTOR(erf, Erf);
73-
MACRO_DEFINE_FUNCTOR(erfc, Erfc);
85+
MACRO_DEFINE_FUNCTOR(sycl::erf, Erf);
86+
MACRO_DEFINE_FUNCTOR(sycl::erfc, Erfc);
87+
MACRO_DEFINE_FUNCTOR(
88+
#if defined(__SYCL_EXT_INTEL_MATH_SUPPORT)
89+
sycl::ext::intel::math::erfcx,
90+
#else
91+
std::erfc,
92+
#endif
93+
Erfcx);
7494
} // namespace dpnp::kernels::erfs

dpnp/special/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,11 @@
4242
from ._erf import (
4343
erf,
4444
erfc,
45+
erfcx,
4546
)
4647

4748
__all__ = [
4849
"erf",
4950
"erfc",
51+
"erfcx",
5052
]

dpnp/special/_erf.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
import dpnp.backend.extensions.ufunc._ufunc_impl as ufi
4343
from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPUnaryFunc
4444

45-
__all__ = ["erf", "erfc"]
45+
__all__ = ["erf", "erfc", "erfcx"]
4646

4747

4848
# pylint: disable=too-few-public-methods
@@ -96,7 +96,6 @@ def __call__(self, x, out=None): # pylint: disable=signature-differs
9696
:obj:`dpnp.special.erfinv` : Inverse of the error function.
9797
:obj:`dpnp.special.erfcinv` : Inverse of the complementary error function.
9898
:obj:`dpnp.special.erfcx` : Scaled complementary error function.
99-
:obj:`dpnp.special.erfi` : Imaginary error function.
10099
101100
Notes
102101
-----
@@ -152,7 +151,6 @@ def __call__(self, x, out=None): # pylint: disable=signature-differs
152151
:obj:`dpnp.special.erfinv` : Inverse of the error function.
153152
:obj:`dpnp.special.erfcinv` : Inverse of the complementary error function.
154153
:obj:`dpnp.special.erfcx` : Scaled complementary error function.
155-
:obj:`dpnp.special.erfi` : Imaginary error function.
156154
157155
Examples
158156
--------
@@ -171,3 +169,48 @@ def __call__(self, x, out=None): # pylint: disable=signature-differs
171169
mkl_fn_to_call="_mkl_erf_to_call",
172170
mkl_impl_fn="_erfc",
173171
)
172+
173+
_ERFCX_DOCSTRING = r"""
174+
Calculates the scaled complementary error function of a given input array.
175+
176+
It is defined as :math:`\exp(x^2) * \operatorname{erfc}(x)`.
177+
178+
For full documentation refer to :obj:`scipy.special.erfcx`.
179+
180+
Parameters
181+
----------
182+
x : {dpnp.ndarray, usm_ndarray}
183+
Input array, expected to have a real-valued floating-point data type.
184+
out : {dpnp.ndarray, usm_ndarray}, optional
185+
Optional output array for the function values.
186+
187+
Returns
188+
-------
189+
out : dpnp.ndarray
190+
The values of the scaled complementary error function at the given points
191+
`x`.
192+
193+
See Also
194+
--------
195+
:obj:`dpnp.special.erf` : Gauss error function.
196+
:obj:`dpnp.special.erfc` : Complementary error function.
197+
:obj:`dpnp.special.erfinv` : Inverse of the error function.
198+
:obj:`dpnp.special.erfcinv` : Inverse of the complementary error function.
199+
200+
Examples
201+
--------
202+
>>> import dpnp as np
203+
>>> x = np.linspace(-3, 3, num=4)
204+
>>> np.special.erfcx(x)
205+
array([1.62059889e+04, 5.00898008e+00, 4.27583576e-01, 1.79001151e-01])
206+
207+
"""
208+
209+
erfcx = DPNPErf(
210+
"erfcx",
211+
ufi._erf_result_type,
212+
ufi._erfcx,
213+
_ERFCX_DOCSTRING,
214+
mkl_fn_to_call="_mkl_erf_to_call",
215+
mkl_impl_fn="_erfcx",
216+
)

dpnp/tests/test_special.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
@with_requires("scipy")
17-
@pytest.mark.parametrize("func", ["erf", "erfc"])
17+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
1818
class TestCommon:
1919
@pytest.mark.parametrize(
2020
"dt", get_all_dtypes(no_none=True, no_float16=False, no_complex=True)
@@ -65,18 +65,36 @@ def test_complex(self, func, dt):
6565

6666
class TestConsistency:
6767

68-
def test_erfc(self):
68+
tol = 8 * dpnp.finfo(dpnp.default_float_type()).resolution
69+
70+
def _check_variant_func(self, func, other_func, rtol, atol=0):
6971
# TODO: replace with dpnp.random.RandomState, once pareto is added
7072
rng = numpy.random.RandomState(1234)
7173
n = 10000
7274
a = rng.pareto(0.02, n) * (2 * rng.randint(0, 2, n) - 1)
7375
a = dpnp.array(a)
76+
a = a[::-1]
7477

75-
res = 1 - dpnp.special.erf(a)
78+
res = other_func(a)
7679
mask = dpnp.isfinite(res)
7780
a = a[mask]
7881

79-
tol = 8 * dpnp.finfo(a).resolution
80-
assert dpnp.allclose(
81-
dpnp.special.erfc(a), res[mask], rtol=tol, atol=tol
82+
x, y = func(a), res[mask]
83+
if not dpnp.allclose(x, y, rtol=rtol, atol=atol):
84+
# calling numpy testing func, because it's more verbose
85+
assert_allclose(x.asnumpy(), y.asnumpy(), rtol=rtol, atol=atol)
86+
87+
def test_erfc(self):
88+
self._check_variant_func(
89+
dpnp.special.erfc,
90+
lambda z: 1 - dpnp.special.erf(z),
91+
rtol=self.tol,
92+
atol=self.tol,
93+
)
94+
95+
def test_erfcx(self):
96+
self._check_variant_func(
97+
dpnp.special.erfcx,
98+
lambda z: dpnp.exp(z * z) * dpnp.special.erfc(z),
99+
rtol=10 * self.tol,
82100
)

dpnp/tests/test_strides.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def test_reduce_hypot(dtype, stride):
167167

168168

169169
@with_requires("scipy")
170-
@pytest.mark.parametrize("func", ["erf", "erfc"])
170+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
171171
@pytest.mark.parametrize("stride", [2, -1, -3])
172172
def test_erf_funcs(func, stride):
173173
import scipy.special

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1489,7 +1489,7 @@ def test_interp(device, left, right, period):
14891489
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
14901490

14911491

1492-
@pytest.mark.parametrize("func", ["erf", "erfc"])
1492+
@pytest.mark.parametrize("func", ["erf", "erfc", "erfcx"])
14931493
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
14941494
def test_erf_funcs(func, device):
14951495
x = dpnp.linspace(-3, 3, num=5, device=device)

0 commit comments

Comments
 (0)