Skip to content

Commit 1595ab7

Browse files
authored
Merge pull request #1520 from IntelPython/use_dpctl_round_func
use_dpctl_round_func_in_dpnp
2 parents 3a8ba50 + a4cab86 commit 1595ab7

16 files changed

+267
-213
lines changed

dpnp/backend/extensions/vm/round.hpp

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event round_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
char *out_y,
46+
const std::vector<sycl::event> &depends)
47+
{
48+
type_utils::validate_type_for_device<T>(exec_q);
49+
50+
const T *a = reinterpret_cast<const T *>(in_a);
51+
T *y = reinterpret_cast<T *>(out_y);
52+
53+
return mkl_vm::rint(exec_q,
54+
n, // number of elements to be calculated
55+
a, // pointer `a` containing input vector of size n
56+
y, // pointer `y` to the output vector of size n
57+
depends);
58+
}
59+
60+
template <typename fnT, typename T>
61+
struct RoundContigFactory
62+
{
63+
fnT get()
64+
{
65+
if constexpr (std::is_same_v<
66+
typename types::RoundOutputType<T>::value_type, void>)
67+
{
68+
return nullptr;
69+
}
70+
else {
71+
return round_contig_impl<T>;
72+
}
73+
}
74+
};
75+
} // namespace vm
76+
} // namespace ext
77+
} // namespace backend
78+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

+15
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,21 @@ struct MulOutputType
203203
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
204204
};
205205

206+
/**
207+
* @brief A factory to define pairs of supported types for which
208+
* MKL VM library provides support in oneapi::mkl::vm::rint<T> function.
209+
*
210+
* @tparam T Type of input vector `a` and of result vector `y`.
211+
*/
212+
template <typename T>
213+
struct RoundOutputType
214+
{
215+
using value_type = typename std::disjunction<
216+
dpctl_td_ns::TypeMapResultEntry<T, double, double>,
217+
dpctl_td_ns::TypeMapResultEntry<T, float, float>,
218+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
219+
};
220+
206221
/**
207222
* @brief A factory to define pairs of supported types for which
208223
* MKL VM library provides support in oneapi::mkl::vm::sin<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "floor.hpp"
4040
#include "ln.hpp"
4141
#include "mul.hpp"
42+
#include "round.hpp"
4243
#include "sin.hpp"
4344
#include "sqr.hpp"
4445
#include "sqrt.hpp"
@@ -60,6 +61,7 @@ static unary_impl_fn_ptr_t floor_dispatch_vector[dpctl_td_ns::num_types];
6061
static unary_impl_fn_ptr_t conj_dispatch_vector[dpctl_td_ns::num_types];
6162
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
6263
static binary_impl_fn_ptr_t mul_dispatch_vector[dpctl_td_ns::num_types];
64+
static unary_impl_fn_ptr_t round_dispatch_vector[dpctl_td_ns::num_types];
6365
static unary_impl_fn_ptr_t sin_dispatch_vector[dpctl_td_ns::num_types];
6466
static unary_impl_fn_ptr_t sqr_dispatch_vector[dpctl_td_ns::num_types];
6567
static unary_impl_fn_ptr_t sqrt_dispatch_vector[dpctl_td_ns::num_types];
@@ -301,6 +303,34 @@ PYBIND11_MODULE(_vm_impl, m)
301303
py::arg("dst"));
302304
}
303305

306+
// UnaryUfunc: ==== Round(x) ====
307+
{
308+
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,
309+
vm_ext::RoundContigFactory>(
310+
round_dispatch_vector);
311+
312+
auto round_pyapi = [&](sycl::queue exec_q, arrayT src, arrayT dst,
313+
const event_vecT &depends = {}) {
314+
return vm_ext::unary_ufunc(exec_q, src, dst, depends,
315+
round_dispatch_vector);
316+
};
317+
m.def("_round", round_pyapi,
318+
"Call `rint` function from OneMKL VM library to compute "
319+
"the rounded value of vector elements",
320+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"),
321+
py::arg("depends") = py::list());
322+
323+
auto round_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src,
324+
arrayT dst) {
325+
return vm_ext::need_to_call_unary_ufunc(exec_q, src, dst,
326+
round_dispatch_vector);
327+
};
328+
m.def("_mkl_round_to_call", round_need_to_call_pyapi,
329+
"Check input arguments to answer if `rint` function from "
330+
"OneMKL VM library can be used",
331+
py::arg("sycl_queue"), py::arg("src"), py::arg("dst"));
332+
}
333+
304334
// UnaryUfunc: ==== Sin(x) ====
305335
{
306336
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,6 @@ enum class DPNPFuncName : size_t
101101
DPNP_FN_ARGSORT_EXT, /**< Used in numpy.argsort() impl, requires extra
102102
parameters */
103103
DPNP_FN_AROUND, /**< Used in numpy.around() impl */
104-
DPNP_FN_AROUND_EXT, /**< Used in numpy.around() impl, requires extra
105-
parameters */
106104
DPNP_FN_ASTYPE, /**< Used in numpy.astype() impl */
107105
DPNP_FN_ASTYPE_EXT, /**< Used in numpy.astype() impl, requires extra
108106
parameters */

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

-18
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,6 @@ template <typename _DataType>
109109
void (*dpnp_around_default_c)(const void *, void *, const size_t, const int) =
110110
dpnp_around_c<_DataType>;
111111

112-
template <typename _DataType>
113-
DPCTLSyclEventRef (*dpnp_around_ext_c)(DPCTLSyclQueueRef,
114-
const void *,
115-
void *,
116-
const size_t,
117-
const int,
118-
const DPCTLEventVectorRef) =
119-
dpnp_around_c<_DataType>;
120-
121112
template <typename _KernelNameSpecialization1,
122113
typename _KernelNameSpecialization2>
123114
class dpnp_elemwise_absolute_c_kernel;
@@ -1184,15 +1175,6 @@ void func_map_init_mathematical(func_map_t &fmap)
11841175
fmap[DPNPFuncName::DPNP_FN_AROUND][eft_DBL][eft_DBL] = {
11851176
eft_DBL, (void *)dpnp_around_default_c<double>};
11861177

1187-
fmap[DPNPFuncName::DPNP_FN_AROUND_EXT][eft_INT][eft_INT] = {
1188-
eft_INT, (void *)dpnp_around_ext_c<int32_t>};
1189-
fmap[DPNPFuncName::DPNP_FN_AROUND_EXT][eft_LNG][eft_LNG] = {
1190-
eft_LNG, (void *)dpnp_around_ext_c<int64_t>};
1191-
fmap[DPNPFuncName::DPNP_FN_AROUND_EXT][eft_FLT][eft_FLT] = {
1192-
eft_FLT, (void *)dpnp_around_ext_c<float>};
1193-
fmap[DPNPFuncName::DPNP_FN_AROUND_EXT][eft_DBL][eft_DBL] = {
1194-
eft_DBL, (void *)dpnp_around_ext_c<double>};
1195-
11961178
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_INT][eft_INT] = {
11971179
eft_INT, (void *)dpnp_cross_default_c<int32_t, int32_t, int32_t>};
11981180
fmap[DPNPFuncName::DPNP_FN_CROSS][eft_INT][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

-2
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
5858
DPNP_FN_ARGMIN_EXT
5959
DPNP_FN_ARGSORT
6060
DPNP_FN_ARGSORT_EXT
61-
DPNP_FN_AROUND
62-
DPNP_FN_AROUND_EXT
6361
DPNP_FN_ASTYPE
6462
DPNP_FN_ASTYPE_EXT
6563
DPNP_FN_CBRT

dpnp/dpnp_algo/dpnp_algo.pyx

+1-1
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ cpdef utils.dpnp_descriptor dpnp_flatten(utils.dpnp_descriptor x1):
170170

171171
cpdef utils.dpnp_descriptor dpnp_init_val(shape, dtype, value):
172172
"""
173-
same as dpnp_full(). TODO remove code dumplication
173+
same as dpnp_full(). TODO remove code duplication
174174
"""
175175
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(dtype)
176176

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

-36
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ and the rest of the library
3838
__all__ += [
3939
"dpnp_absolute",
4040
"dpnp_arctan2",
41-
"dpnp_around",
4241
"dpnp_copysign",
4342
"dpnp_cross",
4443
"dpnp_cumprod",
@@ -72,9 +71,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_1in_2out_t)(c_dpctl.DPCTLSyclQueueRef,
7271
ctypedef c_dpctl.DPCTLSyclEventRef(*ftpr_custom_trapz_2in_1out_with_2size_t)(c_dpctl.DPCTLSyclQueueRef,
7372
void *, void * , void * , double, size_t, size_t,
7473
const c_dpctl.DPCTLEventVectorRef)
75-
ctypedef c_dpctl.DPCTLSyclEventRef(*ftpr_custom_around_1in_1out_t)(c_dpctl.DPCTLSyclQueueRef,
76-
const void * , void * , const size_t, const int,
77-
const c_dpctl.DPCTLEventVectorRef)
7874

7975

8076
cpdef utils.dpnp_descriptor dpnp_absolute(utils.dpnp_descriptor x1):
@@ -120,38 +116,6 @@ cpdef utils.dpnp_descriptor dpnp_arctan2(utils.dpnp_descriptor x1_obj,
120116
return call_fptr_2in_1out_strides(DPNP_FN_ARCTAN2_EXT, x1_obj, x2_obj, dtype, out, where, func_name="arctan2")
121117

122118

123-
cpdef utils.dpnp_descriptor dpnp_around(utils.dpnp_descriptor x1, int decimals):
124-
125-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
126-
127-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_AROUND_EXT, param1_type, param1_type)
128-
129-
x1_obj = x1.get_array()
130-
131-
# create result array with type given by FPTR data
132-
cdef shape_type_c result_shape = x1.shape
133-
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape,
134-
kernel_data.return_type,
135-
None,
136-
device=x1_obj.sycl_device,
137-
usm_type=x1_obj.usm_type,
138-
sycl_queue=x1_obj.sycl_queue)
139-
140-
result_sycl_queue = result.get_array().sycl_queue
141-
142-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
143-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
144-
145-
cdef ftpr_custom_around_1in_1out_t func = <ftpr_custom_around_1in_1out_t > kernel_data.ptr
146-
147-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, x1.get_data(), result.get_data(), x1.size, decimals, NULL)
148-
149-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
150-
c_dpctl.DPCTLEvent_Delete(event_ref)
151-
152-
return result
153-
154-
155119
cpdef utils.dpnp_descriptor dpnp_copysign(utils.dpnp_descriptor x1_obj,
156120
utils.dpnp_descriptor x2_obj,
157121
object dtype=None,

dpnp/dpnp_algo/dpnp_elementwise_common.py

+52
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
"dpnp_not_equal",
7070
"dpnp_remainder",
7171
"dpnp_right_shift",
72+
"dpnp_round",
7273
"dpnp_sin",
7374
"dpnp_sqrt",
7475
"dpnp_square",
@@ -1545,6 +1546,57 @@ def dpnp_right_shift(x1, x2, out=None, order="K"):
15451546
return dpnp_array._create_from_usm_ndarray(res_usm)
15461547

15471548

1549+
_round_docstring = """
1550+
round(x, out=None, order='K')
1551+
Rounds each element `x_i` of the input array `x` to
1552+
the nearest integer-valued number.
1553+
Args:
1554+
x (dpnp.ndarray):
1555+
Input array, expected to have numeric data type.
1556+
out ({None, dpnp.ndarray}, optional):
1557+
Output array to populate. Array must have the correct
1558+
shape and the expected data type.
1559+
order ("C","F","A","K", optional): memory layout of the new
1560+
output array, if parameter `out` is `None`.
1561+
Default: "K".
1562+
Return:
1563+
dpnp.ndarray:
1564+
An array containing the element-wise rounded value. The data type
1565+
of the returned array is determined by the Type Promotion Rules.
1566+
"""
1567+
1568+
1569+
def _call_round(src, dst, sycl_queue, depends=None):
1570+
"""A callback to register in UnaryElementwiseFunc class of dpctl.tensor"""
1571+
1572+
if depends is None:
1573+
depends = []
1574+
1575+
if vmi._mkl_round_to_call(sycl_queue, src, dst):
1576+
# call pybind11 extension for round() function from OneMKL VM
1577+
return vmi._round(sycl_queue, src, dst, depends)
1578+
return ti._round(src, dst, sycl_queue, depends)
1579+
1580+
1581+
round_func = UnaryElementwiseFunc(
1582+
"round", ti._round_result_type, _call_round, _round_docstring
1583+
)
1584+
1585+
1586+
def dpnp_round(x, out=None, order="K"):
1587+
"""
1588+
Invokes round() function from pybind11 extension of OneMKL VM if possible.
1589+
1590+
Otherwise fully relies on dpctl.tensor implementation for round() function.
1591+
"""
1592+
# dpctl.tensor only works with usm_ndarray
1593+
x1_usm = dpnp.get_usm_ndarray(x)
1594+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
1595+
1596+
res_usm = round_func(x1_usm, out=out_usm, order=order)
1597+
return dpnp_array._create_from_usm_ndarray(res_usm)
1598+
1599+
15481600
_sign_docstring = """
15491601
sign(x, out=None, order="K")
15501602

0 commit comments

Comments
 (0)