Skip to content

Commit 3dedb2f

Browse files
committed
use_dpctl_remainder_func
1 parent 0a45a54 commit 3dedb2f

9 files changed

+309
-196
lines changed
+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2023, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
//
13+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
14+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
15+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
16+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
17+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23+
// THE POSSIBILITY OF SUCH DAMAGE.
24+
//*****************************************************************************
25+
26+
#pragma once
27+
28+
#include <CL/sycl.hpp>
29+
30+
#include "common.hpp"
31+
#include "types_matrix.hpp"
32+
33+
namespace dpnp
34+
{
35+
namespace backend
36+
{
37+
namespace ext
38+
{
39+
namespace vm
40+
{
41+
template <typename T>
42+
sycl::event remainder_contig_impl(sycl::queue exec_q,
43+
const std::int64_t n,
44+
const char *in_a,
45+
const char *in_b,
46+
char *out_y,
47+
const std::vector<sycl::event> &depends)
48+
{
49+
type_utils::validate_type_for_device<T>(exec_q);
50+
51+
const T *a = reinterpret_cast<const T *>(in_a);
52+
const T *b = reinterpret_cast<const T *>(in_b);
53+
T *y = reinterpret_cast<T *>(out_y);
54+
55+
return mkl_vm::remainder(
56+
exec_q,
57+
n, // number of elements to be calculated
58+
a, // pointer `a` containing 1st input vector of size n
59+
b, // pointer `b` containing 2nd input vector of size n
60+
y, // pointer `y` to the output vector of size n
61+
depends);
62+
}
63+
64+
template <typename fnT, typename T>
65+
struct RemainderContigFactory
66+
{
67+
fnT get()
68+
{
69+
if constexpr (std::is_same_v<
70+
typename types::RemainderOutputType<T>::value_type,
71+
void>)
72+
{
73+
return nullptr;
74+
}
75+
else {
76+
return remainder_contig_impl<T>;
77+
}
78+
}
79+
};
80+
} // namespace vm
81+
} // namespace ext
82+
} // namespace backend
83+
} // namespace dpnp

dpnp/backend/extensions/vm/types_matrix.hpp

+15
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,21 @@ struct DivOutputType
6868
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
6969
};
7070

71+
/**
72+
* @brief A factory to define pairs of supported types for which
73+
* MKL VM library provides support in oneapi::mkl::vm::remainder<T> function.
74+
*
75+
* @tparam T Type of input vectors `a` and `b` and of result vector `y`.
76+
*/
77+
template <typename T>
78+
struct RemainderOutputType
79+
{
80+
using value_type = typename std::disjunction<
81+
dpctl_td_ns::BinaryTypeMapResultEntry<T, double, T, double, double>,
82+
dpctl_td_ns::BinaryTypeMapResultEntry<T, float, T, float, float>,
83+
dpctl_td_ns::DefaultResultEntry<void>>::result_type;
84+
};
85+
7186
/**
7287
* @brief A factory to define pairs of supported types for which
7388
* MKL VM library provides support in oneapi::mkl::vm::cos<T> function.

dpnp/backend/extensions/vm/vm_py.cpp

+33
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "cos.hpp"
3535
#include "div.hpp"
3636
#include "ln.hpp"
37+
#include "remainder.hpp"
3738
#include "sin.hpp"
3839
#include "sqr.hpp"
3940
#include "sqrt.hpp"
@@ -46,6 +47,7 @@ using vm_ext::binary_impl_fn_ptr_t;
4647
using vm_ext::unary_impl_fn_ptr_t;
4748

4849
static binary_impl_fn_ptr_t div_dispatch_vector[dpctl_td_ns::num_types];
50+
static binary_impl_fn_ptr_t remainder_dispatch_vector[dpctl_td_ns::num_types];
4951

5052
static unary_impl_fn_ptr_t cos_dispatch_vector[dpctl_td_ns::num_types];
5153
static unary_impl_fn_ptr_t ln_dispatch_vector[dpctl_td_ns::num_types];
@@ -88,6 +90,37 @@ PYBIND11_MODULE(_vm_impl, m)
8890
py::arg("dst"));
8991
}
9092

93+
// BinaryUfunc: ==== REMAINDER(x1, x2) ====
94+
{
95+
vm_ext::init_ufunc_dispatch_vector<binary_impl_fn_ptr_t,
96+
vm_ext::RemainderContigFactory>(
97+
remainder_dispatch_vector);
98+
99+
auto remainder_pyapi = [&](sycl::queue exec_q, arrayT src1, arrayT src2,
100+
arrayT dst, const event_vecT &depends = {}) {
101+
return vm_ext::binary_ufunc(exec_q, src1, src2, dst, depends,
102+
remainder_dispatch_vector);
103+
};
104+
m.def("_remainder", remainder_pyapi,
105+
"Call `remainder` function from OneMKL VM library to performs "
106+
"element "
107+
"by element remainder of vector `src1` by vector `src2` "
108+
"to resulting vector `dst`",
109+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
110+
py::arg("dst"), py::arg("depends") = py::list());
111+
112+
auto remainder_need_to_call_pyapi = [&](sycl::queue exec_q, arrayT src1,
113+
arrayT src2, arrayT dst) {
114+
return vm_ext::need_to_call_binary_ufunc(exec_q, src1, src2, dst,
115+
remainder_dispatch_vector);
116+
};
117+
m.def("_mkl_remainder_to_call", remainder_need_to_call_pyapi,
118+
"Check input arguments to answer if `remainder` function from "
119+
"OneMKL VM library can be used",
120+
py::arg("sycl_queue"), py::arg("src1"), py::arg("src2"),
121+
py::arg("dst"));
122+
}
123+
91124
// UnaryUfunc: ==== Cos(x) ====
92125
{
93126
vm_ext::init_ufunc_dispatch_vector<unary_impl_fn_ptr_t,

dpnp/backend/include/dpnp_iface_fptr.hpp

-2
Original file line numberDiff line numberDiff line change
@@ -331,8 +331,6 @@ enum class DPNPFuncName : size_t
331331
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
332332
parameters */
333333
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
334-
DPNP_FN_REMAINDER_EXT, /**< Used in numpy.remainder() impl, requires extra
335-
parameters */
336334
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
337335
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
338336
parameters */

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

-50
Original file line numberDiff line numberDiff line change
@@ -988,23 +988,6 @@ void (*dpnp_remainder_default_c)(void *,
988988
const size_t *) =
989989
dpnp_remainder_c<_DataType_output, _DataType_input1, _DataType_input2>;
990990

991-
template <typename _DataType_output,
992-
typename _DataType_input1,
993-
typename _DataType_input2>
994-
DPCTLSyclEventRef (*dpnp_remainder_ext_c)(DPCTLSyclQueueRef,
995-
void *,
996-
const void *,
997-
const size_t,
998-
const shape_elem_type *,
999-
const size_t,
1000-
const void *,
1001-
const size_t,
1002-
const shape_elem_type *,
1003-
const size_t,
1004-
const size_t *,
1005-
const DPCTLEventVectorRef) =
1006-
dpnp_remainder_c<_DataType_output, _DataType_input1, _DataType_input2>;
1007-
1008991
template <typename _KernelNameSpecialization1,
1009992
typename _KernelNameSpecialization2,
1010993
typename _KernelNameSpecialization3>
@@ -1385,39 +1368,6 @@ void func_map_init_mathematical(func_map_t &fmap)
13851368
fmap[DPNPFuncName::DPNP_FN_REMAINDER][eft_DBL][eft_DBL] = {
13861369
eft_DBL, (void *)dpnp_remainder_default_c<double, double, double>};
13871370

1388-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_INT] = {
1389-
eft_INT, (void *)dpnp_remainder_ext_c<int32_t, int32_t, int32_t>};
1390-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_LNG] = {
1391-
eft_LNG, (void *)dpnp_remainder_ext_c<int64_t, int32_t, int64_t>};
1392-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_FLT] = {
1393-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int32_t, float>};
1394-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_DBL] = {
1395-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int32_t, double>};
1396-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_INT] = {
1397-
eft_LNG, (void *)dpnp_remainder_ext_c<int64_t, int64_t, int32_t>};
1398-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_LNG] = {
1399-
eft_LNG, (void *)dpnp_remainder_ext_c<int64_t, int64_t, int64_t>};
1400-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_FLT] = {
1401-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int64_t, float>};
1402-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_DBL] = {
1403-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int64_t, double>};
1404-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_INT] = {
1405-
eft_DBL, (void *)dpnp_remainder_ext_c<double, float, int32_t>};
1406-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_LNG] = {
1407-
eft_DBL, (void *)dpnp_remainder_ext_c<double, float, int64_t>};
1408-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_FLT] = {
1409-
eft_FLT, (void *)dpnp_remainder_ext_c<float, float, float>};
1410-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_DBL] = {
1411-
eft_DBL, (void *)dpnp_remainder_ext_c<double, float, double>};
1412-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_INT] = {
1413-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, int32_t>};
1414-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_LNG] = {
1415-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, int64_t>};
1416-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_FLT] = {
1417-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, float>};
1418-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_DBL] = {
1419-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, double>};
1420-
14211371
fmap[DPNPFuncName::DPNP_FN_TRAPZ][eft_INT][eft_INT] = {
14221372
eft_DBL, (void *)dpnp_trapz_default_c<int32_t, int32_t, double>};
14231373
fmap[DPNPFuncName::DPNP_FN_TRAPZ][eft_INT][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

-5
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
202202
DPNP_FN_QR_EXT
203203
DPNP_FN_RADIANS
204204
DPNP_FN_RADIANS_EXT
205-
DPNP_FN_REMAINDER
206-
DPNP_FN_REMAINDER_EXT
207205
DPNP_FN_RECIP
208206
DPNP_FN_RECIP_EXT
209207
DPNP_FN_REPEAT
@@ -490,9 +488,6 @@ cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
490488
cpdef dpnp_descriptor dpnp_negative(dpnp_descriptor array1)
491489
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
492490
dpnp_descriptor out=*, object where=*)
493-
cpdef dpnp_descriptor dpnp_remainder(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
494-
dpnp_descriptor out=*, object where=*)
495-
496491

497492
"""
498493
Array manipulation routines

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

-9
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ __all__ += [
6262
"dpnp_negative",
6363
"dpnp_power",
6464
"dpnp_prod",
65-
"dpnp_remainder",
6665
"dpnp_sign",
6766
"dpnp_sum",
6867
"dpnp_trapz",
@@ -546,14 +545,6 @@ cpdef utils.dpnp_descriptor dpnp_prod(utils.dpnp_descriptor x1,
546545
return result
547546

548547

549-
cpdef utils.dpnp_descriptor dpnp_remainder(utils.dpnp_descriptor x1_obj,
550-
utils.dpnp_descriptor x2_obj,
551-
object dtype=None,
552-
utils.dpnp_descriptor out=None,
553-
object where=True):
554-
return call_fptr_2in_1out(DPNP_FN_REMAINDER_EXT, x1_obj, x2_obj, dtype, out, where)
555-
556-
557548
cpdef utils.dpnp_descriptor dpnp_sign(utils.dpnp_descriptor x1):
558549
return call_fptr_1in_1out_strides(DPNP_FN_SIGN_EXT, x1)
559550

dpnp/dpnp_algo/dpnp_elementwise_common.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
"dpnp_logical_xor",
6262
"dpnp_multiply",
6363
"dpnp_not_equal",
64+
"dpnp_remainder",
6465
"dpnp_sin",
6566
"dpnp_sqrt",
6667
"dpnp_square",
@@ -80,7 +81,7 @@ def check_nd_call_func(
8081
**kwargs,
8182
):
8283
"""
83-
Checks arguments and calls function with a single input array.
84+
Checks arguments and calls a function.
8485
8586
Chooses a common internal elementwise function to call in DPNP based on input arguments
8687
or to fallback on NumPy call if any passed argument is not currently supported.
@@ -121,7 +122,6 @@ def check_nd_call_func(
121122
order
122123
)
123124
)
124-
125125
return dpnp_func(*x_args, out=out, order=order)
126126
return call_origin(
127127
origin_func,
@@ -953,6 +953,66 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
953953
return dpnp_array._create_from_usm_ndarray(res_usm)
954954

955955

956+
_remainder_docstring_ = """
957+
remainder(x1, x2, out=None, order='K')
958+
959+
Calculates the remainder of division for each element `x1_i` of the input array
960+
`x1` with the respective element `x2_i` of the input array `x2`.
961+
962+
This function is equivalent to the Python modulus operator.
963+
964+
Args:
965+
x1 (dpnp.ndarray):
966+
First input array, expected to have a real-valued data type.
967+
x2 (dpnp.ndarray):
968+
Second input array, also expected to have a real-valued data type.
969+
out ({None, usm_ndarray}, optional):
970+
Output array to populate.
971+
Array have the correct shape and the expected data type.
972+
order ("C","F","A","K", optional):
973+
Memory layout of the newly output array, if parameter `out` is `None`.
974+
Default: "K".
975+
Returns:
976+
dpnp.ndarray:
977+
an array containing the element-wise remainders. The data type of
978+
the returned array is determined by the Type Promotion Rules.
979+
"""
980+
981+
982+
def dpnp_remainder(x1, x2, out=None, order="K"):
983+
"""
984+
Invokes remainder() function from pybind11 extension of OneMKL VM if possible.
985+
986+
Otherwise fully relies on dpctl.tensor implementation for remainder() function.
987+
988+
"""
989+
990+
def _call_remainder(src1, src2, dst, sycl_queue, depends=None):
991+
"""A callback to register in BinaryElementwiseFunc class of dpctl.tensor"""
992+
993+
if depends is None:
994+
depends = []
995+
996+
if vmi._mkl_remainder_to_call(sycl_queue, src1, src2, dst):
997+
# call pybind11 extension for remainder() function from OneMKL VM
998+
return vmi._remainder(sycl_queue, src1, src2, dst, depends)
999+
return ti._remainder(src1, src2, dst, sycl_queue, depends)
1000+
1001+
# dpctl.tensor only works with usm_ndarray or scalar
1002+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
1003+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
1004+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
1005+
1006+
func = BinaryElementwiseFunc(
1007+
"remainder",
1008+
ti._remainder_result_type,
1009+
_call_remainder,
1010+
_remainder_docstring_,
1011+
)
1012+
res_usm = func(x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order)
1013+
return dpnp_array._create_from_usm_ndarray(res_usm)
1014+
1015+
9561016
_sin_docstring = """
9571017
sin(x, out=None, order='K')
9581018
Computes sine for each element `x_i` of input array `x`.

0 commit comments

Comments
 (0)