diff --git a/dpnp/backend/extensions/vm/sqrt.hpp b/dpnp/backend/extensions/vm/sqrt.hpp new file mode 100644 index 000000000000..df20ee4b0509 --- /dev/null +++ b/dpnp/backend/extensions/vm/sqrt.hpp @@ -0,0 +1,78 @@ +//***************************************************************************** +// Copyright (c) 2023, Intel Corporation +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// - Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// - Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF +// THE POSSIBILITY OF SUCH DAMAGE. 
+//***************************************************************************** + +#pragma once + +#include + +#include "common.hpp" +#include "types_matrix.hpp" + +namespace dpnp +{ +namespace backend +{ +namespace ext +{ +namespace vm +{ +template +sycl::event sqrt_contig_impl(sycl::queue exec_q, + const std::int64_t n, + const char *in_a, + char *out_y, + const std::vector &depends) +{ + type_utils::validate_type_for_device(exec_q); + + const T *a = reinterpret_cast(in_a); + T *y = reinterpret_cast(out_y); + + return mkl_vm::sqrt(exec_q, + n, // number of elements to be calculated + a, // pointer `a` containing input vector of size n + y, // pointer `y` to the output vector of size n + depends); +} + +template +struct SqrtContigFactory +{ + fnT get() + { + if constexpr (std::is_same_v< + typename types::SqrtOutputType::value_type, void>) + { + return nullptr; + } + else { + return sqrt_contig_impl; + } + } +}; +} // namespace vm +} // namespace ext +} // namespace backend +} // namespace dpnp diff --git a/dpnp/backend/extensions/vm/types_matrix.hpp b/dpnp/backend/extensions/vm/types_matrix.hpp index d7c9f9eecdf1..019f646febdf 100644 --- a/dpnp/backend/extensions/vm/types_matrix.hpp +++ b/dpnp/backend/extensions/vm/types_matrix.hpp @@ -124,6 +124,25 @@ struct SinOutputType dpctl_td_ns::TypeMapResultEntry, dpctl_td_ns::DefaultResultEntry>::result_type; }; + +/** + * @brief A factory to define pairs of supported types for which + * MKL VM library provides support in oneapi::mkl::vm::sqrt function. + * + * @tparam T Type of input vector `a` and of result vector `y`. 
+ */ +template +struct SqrtOutputType +{ + using value_type = typename std::disjunction< + dpctl_td_ns:: + TypeMapResultEntry, std::complex>, + dpctl_td_ns:: + TypeMapResultEntry, std::complex>, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::TypeMapResultEntry, + dpctl_td_ns::DefaultResultEntry>::result_type; +}; } // namespace types } // namespace vm } // namespace ext diff --git a/dpnp/backend/extensions/vm/vm_py.cpp b/dpnp/backend/extensions/vm/vm_py.cpp index 0e26d310a2d4..aeb12ac51094 100644 --- a/dpnp/backend/extensions/vm/vm_py.cpp +++ b/dpnp/backend/extensions/vm/vm_py.cpp @@ -35,6 +35,7 @@ #include "div.hpp" #include "ln.hpp" #include "sin.hpp" +#include "sqrt.hpp" #include "types_matrix.hpp" namespace py = pybind11; @@ -48,6 +49,7 @@ static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types]; static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types]; +static unary_impl_fn_ptr_t sqrt_dispatch_vector[dpctl_td_ns::num_types]; PYBIND11_MODULE(_vm_impl, m) { @@ -167,4 +169,34 @@ PYBIND11_MODULE(_vm_impl, m) "OneMKL VM library can be used", py::arg("sycl_queue"), py::arg("src"), py::arg("dst")); } + + // UnaryUfunc: ==== Sqrt(x) ==== + { + vm_ext::init_ufunc_dispatch_vector( + sqrt_dispatch_vector); + + auto sqrt_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst, + const event_vecT &depends = {}) { + return vm_ext::unary_ufunc(exec_q, src, dst, depends, + sqrt_dispatch_vector); + }; + m.def( + "_sqrt", sqrt_pyapi, + "Call `sqrt` from OneMKL VM library to performs element by element " + "operation of extracting the square root " + "of vector `src` to resulting vector `dst`", + py::arg("sycl_queue"), py::arg("src"), py::arg("dst"), + py::arg("depends") = py::list()); + + auto sqrt_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src, + arrayT dst) { + return 
vm_ext::need_to_call_unary_ufunc(exec_q, src, dst, + sqrt_dispatch_vector); + }; + m.def("_mkl_sqrt_to_call", sqrt_need_to_call_pyapi, + "Check input arguments to answer if `sqrt` function from " + "OneMKL VM library can be used", + py::arg("sycl_queue"), py::arg("src"), py::arg("dst")); + } } diff --git a/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp b/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp index 7f6b66afc58a..bebbcc1f4252 100644 --- a/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_elemwise.cpp @@ -729,10 +729,7 @@ static void func_map_init_elemwise_1arg_2type(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_SQRT][eft_DBL][eft_DBL] = { eft_DBL, (void *)dpnp_sqrt_c_default}; - fmap[DPNPFuncName::DPNP_FN_SQRT_EXT][eft_INT][eft_INT] = { - eft_DBL, (void *)dpnp_sqrt_c_ext}; - fmap[DPNPFuncName::DPNP_FN_SQRT_EXT][eft_LNG][eft_LNG] = { - eft_DBL, (void *)dpnp_sqrt_c_ext}; + // Used in dpnp_std_c fmap[DPNPFuncName::DPNP_FN_SQRT_EXT][eft_FLT][eft_FLT] = { eft_FLT, (void *)dpnp_sqrt_c_ext}; fmap[DPNPFuncName::DPNP_FN_SQRT_EXT][eft_DBL][eft_DBL] = { diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index 433034f04c6f..8a93e5927d20 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -295,8 +295,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_SINH_EXT DPNP_FN_SORT DPNP_FN_SORT_EXT - DPNP_FN_SQRT - DPNP_FN_SQRT_EXT DPNP_FN_SQUARE DPNP_FN_SQUARE_EXT DPNP_FN_STD @@ -553,7 +551,6 @@ cpdef dpnp_descriptor dpnp_log2(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_radians(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_recip(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_sinh(dpnp_descriptor array1) -cpdef dpnp_descriptor dpnp_sqrt(dpnp_descriptor array1, dpnp_descriptor out) cpdef dpnp_descriptor dpnp_square(dpnp_descriptor array1) cpdef dpnp_descriptor dpnp_tan(dpnp_descriptor array1, dpnp_descriptor out) cpdef dpnp_descriptor 
dpnp_tanh(dpnp_descriptor array1) diff --git a/dpnp/dpnp_algo/dpnp_algo_trigonometric.pxi b/dpnp/dpnp_algo/dpnp_algo_trigonometric.pxi index 8f4ea3bc80cb..c9c59fb58837 100644 --- a/dpnp/dpnp_algo/dpnp_algo_trigonometric.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_trigonometric.pxi @@ -54,7 +54,6 @@ __all__ += [ 'dpnp_radians', 'dpnp_recip', 'dpnp_sinh', - 'dpnp_sqrt', 'dpnp_square', 'dpnp_tan', 'dpnp_tanh', @@ -134,10 +133,6 @@ cpdef utils.dpnp_descriptor dpnp_sinh(utils.dpnp_descriptor x1): return call_fptr_1in_1out_strides(DPNP_FN_SINH_EXT, x1) -cpdef utils.dpnp_descriptor dpnp_sqrt(utils.dpnp_descriptor x1, utils.dpnp_descriptor out): - return call_fptr_1in_1out_strides(DPNP_FN_SQRT_EXT, x1, dtype=None, out=out, where=True, func_name='sqrt') - - cpdef utils.dpnp_descriptor dpnp_square(utils.dpnp_descriptor x1): return call_fptr_1in_1out_strides(DPNP_FN_SQUARE_EXT, x1) diff --git a/dpnp/dpnp_algo/dpnp_elementwise_common.py b/dpnp/dpnp_algo/dpnp_elementwise_common.py index 5a2f8905d777..94f22201fc28 100644 --- a/dpnp/dpnp_algo/dpnp_elementwise_common.py +++ b/dpnp/dpnp_algo/dpnp_elementwise_common.py @@ -54,6 +54,7 @@ "dpnp_multiply", "dpnp_not_equal", "dpnp_sin", + "dpnp_sqrt", "dpnp_subtract", ] @@ -685,6 +686,57 @@ def _call_sin(src, dst, sycl_queue, depends=None): return dpnp_array._create_from_usm_ndarray(res_usm) +_sqrt_docstring_ = """ +sqrt(x, out=None, order='K') +Computes the non-negative square-root for each element `x_i` for input array `x`. +Args: + x (dpnp.ndarray): + Input array. + out ({None, dpnp.ndarray}, optional): + Output array to populate. Array must have the correct + shape and the expected data type. + order ("C","F","A","K", optional): memory layout of the new + output array, if parameter `out` is `None`. + Default: "K". +Return: + dpnp.ndarray: + An array containing the element-wise square-root results. +""" + + +def dpnp_sqrt(x, out=None, order="K"): + """ + Invokes sqrt() function from pybind11 extension of OneMKL VM if possible. 
+ + Otherwise fully relies on dpctl.tensor implementation for sqrt() function. + + """ + + def _call_sqrt(src, dst, sycl_queue, depends=None): + """A callback to register in UnaryElementwiseFunc class of dpctl.tensor""" + + if depends is None: + depends = [] + + if vmi._mkl_sqrt_to_call(sycl_queue, src, dst): + # call pybind11 extension for sqrt() function from OneMKL VM + return vmi._sqrt(sycl_queue, src, dst, depends) + return ti._sqrt(src, dst, sycl_queue, depends) + + # dpctl.tensor only works with usm_ndarray or scalar + x_usm = dpnp.get_usm_ndarray(x) + out_usm = None if out is None else dpnp.get_usm_ndarray(out) + + func = UnaryElementwiseFunc( + "sqrt", + ti._sqrt_result_type, + _call_sqrt, + _sqrt_docstring_, + ) + res_usm = func(x_usm, out=out_usm, order=order) + return dpnp_array._create_from_usm_ndarray(res_usm) + + _subtract_docstring_ = """ subtract(x1, x2, out=None, order="K") diff --git a/dpnp/dpnp_iface_bitwise.py b/dpnp/dpnp_iface_bitwise.py index ecaa68ec2f75..4bffb71417fb 100644 --- a/dpnp/dpnp_iface_bitwise.py +++ b/dpnp/dpnp_iface_bitwise.py @@ -133,7 +133,7 @@ def bitwise_and(x1, x2, dtype=None, out=None, where=True, **kwargs): Returns ------- y : dpnp.ndarray - An array containing the element-wise results. + An array containing the element-wise results. Limitations ----------- diff --git a/dpnp/dpnp_iface_trigonometric.py b/dpnp/dpnp_iface_trigonometric.py index 89f6d7c05e0f..4e020b651412 100644 --- a/dpnp/dpnp_iface_trigonometric.py +++ b/dpnp/dpnp_iface_trigonometric.py @@ -40,7 +40,6 @@ """ -import dpctl.tensor as dpt import numpy import dpnp @@ -52,6 +51,7 @@ dpnp_cos, dpnp_log, dpnp_sin, + dpnp_sqrt, ) __all__ = [ @@ -1048,51 +1048,64 @@ def sinh(x1): return call_origin(numpy.sinh, x1, **kwargs) -def sqrt(x1, /, out=None, **kwargs): +def sqrt( + x, + /, + out=None, + *, + order="K", + where=True, + dtype=None, + subok=True, + **kwargs, +): """ - Return the positive square-root of an array, element-wise. 
+ Return the non-negative square-root of an array, element-wise. For full documentation refer to :obj:`numpy.sqrt`. + Returns + ------- + y : dpnp.ndarray + An array of the same shape as `x`, containing the non-negative + square-root of each element in `x`. If any element in `x` is + complex, a complex array is returned (and the square-roots of + negative reals are calculated). If all of the elements in `x` + are real, so is `y`, with negative elements returning ``nan``. + Limitations ----------- Input array is supported as either :class:`dpnp.ndarray` or :class:`dpctl.tensor.usm_ndarray`. Parameter `out` is supported as class:`dpnp.ndarray`, class:`dpctl.tensor.usm_ndarray` or with default value ``None``. + Parameters `where`, `dtype` and `subok` are supported with their default values. Otherwise the function will be executed sequentially on CPU. - Keyword arguments ``kwargs`` are currently unsupported. Input array data types are limited by supported DPNP :ref:`Data types`. Examples -------- >>> import dpnp as np >>> x = np.array([1, 4, 9]) - >>> out = np.sqrt(x) - >>> [i for i in out] - [1.0, 2.0, 3.0] + >>> np.sqrt(x) + array([1., 2., 3.]) + + >>> x2 = np.array([4, -1, np.inf]) + >>> np.sqrt(x2) + array([ 2., nan, inf]) """ - x1_desc = ( - dpnp.get_dpnp_descriptor( - x1, copy_when_strides=False, copy_when_nondefault_queue=False - ) - if not kwargs - else None + return check_nd_call_func( + numpy.sqrt, + dpnp_sqrt, + x, + out=out, + where=where, + order=order, + dtype=dtype, + subok=subok, + **kwargs, ) - if x1_desc: - if out is not None: - if not isinstance(out, (dpnp.ndarray, dpt.usm_ndarray)): - raise TypeError("return array must be of supported array type") - out_desc = ( - dpnp.get_dpnp_descriptor(out, copy_when_nondefault_queue=False) - or None - ) - else: - out_desc = None - return dpnp_sqrt(x1_desc, out=out_desc).get_pyobj() - - return call_origin(numpy.sqrt, x1, out=out, **kwargs) def square(x1): diff --git a/dpnp/linalg/dpnp_algo_linalg.pyx 
b/dpnp/linalg/dpnp_algo_linalg.pyx index 56b18935ce42..69a3efa223f7 100644 --- a/dpnp/linalg/dpnp_algo_linalg.pyx +++ b/dpnp/linalg/dpnp_algo_linalg.pyx @@ -366,7 +366,7 @@ cpdef object dpnp_norm(object input, ord=None, axis=None): input = dpnp.ravel(input, order='K') sqnorm = dpnp.dot(input, input) - ret = dpnp.sqrt([sqnorm]) + ret = dpnp.sqrt(sqnorm) return dpnp.array(ret.reshape(1, *ret.shape), dtype=res_type) len_axis = 1 if axis is None else len(axis_) diff --git a/tests/test_umath.py b/tests/test_umath.py index 8dd98244ccb2..071f27376f83 100644 --- a/tests/test_umath.py +++ b/tests/test_umath.py @@ -11,7 +11,6 @@ from .helper import ( get_all_dtypes, get_complex_dtypes, - get_float_dtypes, has_support_aspect16, has_support_aspect64, ) @@ -550,22 +549,60 @@ def test_invalid_shape(self, shape): class TestSqrt: - @pytest.mark.parametrize("dtype", get_float_dtypes()) - def test_sqrt_ordinary(self, dtype): - array_data = numpy.arange(10) - out = numpy.empty(10, dtype=dtype) + @pytest.mark.parametrize( + "dtype", get_all_dtypes(no_bool=True, no_complex=True) + ) + def test_sqrt_int_float(self, dtype): + np_array = numpy.arange(10, dtype=dtype) + np_out = numpy.empty(10, dtype=numpy.float64) # DPNP - dp_array = dpnp.array(array_data, dtype=dtype) - dp_out = dpnp.array(out, dtype=dtype) + dp_out_dtype = dpnp.float32 + if has_support_aspect64() and dtype != dpnp.float32: + dp_out_dtype = dpnp.float64 + + dp_out = dpnp.array(np_out, dtype=dp_out_dtype) + dp_array = dpnp.array(np_array, dtype=dtype) result = dpnp.sqrt(dp_array, out=dp_out) # original - np_array = numpy.array(array_data, dtype=dtype) - expected = numpy.sqrt(np_array, out=out) + expected = numpy.sqrt(np_array, out=np_out) + assert_allclose(expected, result) + + @pytest.mark.parametrize("dtype", get_complex_dtypes()) + def test_sqrt_complex(self, dtype): + np_array = numpy.arange(10, 20, dtype=dtype) + np_out = numpy.empty(10, dtype=numpy.complex128) - numpy.testing.assert_allclose(expected, result) - 
numpy.testing.assert_allclose(out, dp_out) + # DPNP + dp_out_dtype = dpnp.complex64 + if has_support_aspect64() and dtype != dpnp.complex64: + dp_out_dtype = dpnp.complex128 + + dp_out = dpnp.array(np_out, dtype=dp_out_dtype) + dp_array = dpnp.array(np_array, dtype=dtype) + result = dpnp.sqrt(dp_array, out=dp_out) + + # original + expected = numpy.sqrt(np_array, out=np_out) + assert_allclose(expected, result) + + @pytest.mark.usefixtures("suppress_divide_numpy_warnings") + @pytest.mark.skipif( + not has_support_aspect16(), reason="No fp16 support by device" + ) + def test_sqrt_bool(self): + np_array = numpy.arange(2, dtype=numpy.bool_) + np_out = numpy.empty(2, dtype=numpy.float16) + + # DPNP + dp_array = dpnp.array(np_array, dtype=np_array.dtype) + dp_out = dpnp.array(np_out, dtype=np_out.dtype) + result = dpnp.sqrt(dp_array, out=dp_out) + + # original + expected = numpy.sqrt(np_array, out=np_out) + assert_allclose(expected, result) @pytest.mark.parametrize( "dtype", [numpy.int64, numpy.int32], ids=["numpy.int64", "numpy.int32"] @@ -574,7 +611,7 @@ def test_invalid_dtype(self, dtype): dp_array = dpnp.arange(10, dtype=dpnp.float32) dp_out = dpnp.empty(10, dtype=dtype) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dpnp.sqrt(dp_array, out=dp_out) @pytest.mark.parametrize( @@ -584,7 +621,7 @@ def test_invalid_shape(self, shape): dp_array = dpnp.arange(10, dtype=dpnp.float32) dp_out = dpnp.empty(shape, dtype=dpnp.float32) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dpnp.sqrt(dp_array, out=dp_out) @pytest.mark.parametrize(