Skip to content

Commit 12a8dcb

Browse files
authored
Merge pull request #1515 from IntelPython/use_dpctl_remainder_func
use_dpctl_remainder_func
2 parents 8ace62f + 988ec20 commit 12a8dcb

12 files changed

+268
-270
lines changed

dpnp/backend/include/dpnp_iface_fptr.hpp

+17-19
Original file line numberDiff line numberDiff line change
@@ -313,25 +313,23 @@ enum class DPNPFuncName : size_t
313313
DPNP_FN_PUT_ALONG_AXIS_EXT, /**< Used in numpy.put_along_axis() impl,
314314
requires extra parameters */
315315
DPNP_FN_QR, /**< Used in numpy.linalg.qr() impl */
316-
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
317-
parameters */
318-
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
319-
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
320-
parameters */
321-
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
322-
DPNP_FN_REMAINDER_EXT, /**< Used in numpy.remainder() impl, requires extra
323-
parameters */
324-
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
325-
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
326-
parameters */
327-
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
328-
DPNP_FN_REPEAT_EXT, /**< Used in numpy.repeat() impl, requires extra
329-
parameters */
330-
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
331-
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
332-
DPNP_FN_RNG_BETA_EXT, /**< Used in numpy.random.beta() impl, requires extra
333-
parameters */
334-
DPNP_FN_RNG_BINOMIAL, /**< Used in numpy.random.binomial() impl */
316+
DPNP_FN_QR_EXT, /**< Used in numpy.linalg.qr() impl, requires extra
317+
parameters */
318+
DPNP_FN_RADIANS, /**< Used in numpy.radians() impl */
319+
DPNP_FN_RADIANS_EXT, /**< Used in numpy.radians() impl, requires extra
320+
parameters */
321+
DPNP_FN_REMAINDER, /**< Used in numpy.remainder() impl */
322+
DPNP_FN_RECIP, /**< Used in numpy.recip() impl */
323+
DPNP_FN_RECIP_EXT, /**< Used in numpy.recip() impl, requires extra
324+
parameters */
325+
DPNP_FN_REPEAT, /**< Used in numpy.repeat() impl */
326+
DPNP_FN_REPEAT_EXT, /**< Used in numpy.repeat() impl, requires extra
327+
parameters */
328+
DPNP_FN_RIGHT_SHIFT, /**< Used in numpy.right_shift() impl */
329+
DPNP_FN_RNG_BETA, /**< Used in numpy.random.beta() impl */
330+
DPNP_FN_RNG_BETA_EXT, /**< Used in numpy.random.beta() impl, requires extra
331+
parameters */
332+
DPNP_FN_RNG_BINOMIAL, /**< Used in numpy.random.binomial() impl */
335333
DPNP_FN_RNG_BINOMIAL_EXT, /**< Used in numpy.random.binomial() impl,
336334
requires extra parameters */
337335
DPNP_FN_RNG_CHISQUARE, /**< Used in numpy.random.chisquare() impl */

dpnp/backend/kernels/dpnp_krnl_mathematical.cpp

-50
Original file line numberDiff line numberDiff line change
@@ -988,23 +988,6 @@ void (*dpnp_remainder_default_c)(void *,
988988
const size_t *) =
989989
dpnp_remainder_c<_DataType_output, _DataType_input1, _DataType_input2>;
990990

991-
template <typename _DataType_output,
992-
typename _DataType_input1,
993-
typename _DataType_input2>
994-
DPCTLSyclEventRef (*dpnp_remainder_ext_c)(DPCTLSyclQueueRef,
995-
void *,
996-
const void *,
997-
const size_t,
998-
const shape_elem_type *,
999-
const size_t,
1000-
const void *,
1001-
const size_t,
1002-
const shape_elem_type *,
1003-
const size_t,
1004-
const size_t *,
1005-
const DPCTLEventVectorRef) =
1006-
dpnp_remainder_c<_DataType_output, _DataType_input1, _DataType_input2>;
1007-
1008991
template <typename _KernelNameSpecialization1,
1009992
typename _KernelNameSpecialization2,
1010993
typename _KernelNameSpecialization3>
@@ -1385,39 +1368,6 @@ void func_map_init_mathematical(func_map_t &fmap)
13851368
fmap[DPNPFuncName::DPNP_FN_REMAINDER][eft_DBL][eft_DBL] = {
13861369
eft_DBL, (void *)dpnp_remainder_default_c<double, double, double>};
13871370

1388-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_INT] = {
1389-
eft_INT, (void *)dpnp_remainder_ext_c<int32_t, int32_t, int32_t>};
1390-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_LNG] = {
1391-
eft_LNG, (void *)dpnp_remainder_ext_c<int64_t, int32_t, int64_t>};
1392-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_FLT] = {
1393-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int32_t, float>};
1394-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_INT][eft_DBL] = {
1395-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int32_t, double>};
1396-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_INT] = {
1397-
eft_LNG, (void *)dpnp_remainder_ext_c<int64_t, int64_t, int32_t>};
1398-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_LNG] = {
1399-
eft_LNG, (void *)dpnp_remainder_ext_c<int64_t, int64_t, int64_t>};
1400-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_FLT] = {
1401-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int64_t, float>};
1402-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_LNG][eft_DBL] = {
1403-
eft_DBL, (void *)dpnp_remainder_ext_c<double, int64_t, double>};
1404-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_INT] = {
1405-
eft_DBL, (void *)dpnp_remainder_ext_c<double, float, int32_t>};
1406-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_LNG] = {
1407-
eft_DBL, (void *)dpnp_remainder_ext_c<double, float, int64_t>};
1408-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_FLT] = {
1409-
eft_FLT, (void *)dpnp_remainder_ext_c<float, float, float>};
1410-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_FLT][eft_DBL] = {
1411-
eft_DBL, (void *)dpnp_remainder_ext_c<double, float, double>};
1412-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_INT] = {
1413-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, int32_t>};
1414-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_LNG] = {
1415-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, int64_t>};
1416-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_FLT] = {
1417-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, float>};
1418-
fmap[DPNPFuncName::DPNP_FN_REMAINDER_EXT][eft_DBL][eft_DBL] = {
1419-
eft_DBL, (void *)dpnp_remainder_ext_c<double, double, double>};
1420-
14211371
fmap[DPNPFuncName::DPNP_FN_TRAPZ][eft_INT][eft_INT] = {
14221372
eft_DBL, (void *)dpnp_trapz_default_c<int32_t, int32_t, double>};
14231373
fmap[DPNPFuncName::DPNP_FN_TRAPZ][eft_INT][eft_LNG] = {

dpnp/dpnp_algo/dpnp_algo.pxd

-5
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
188188
DPNP_FN_QR_EXT
189189
DPNP_FN_RADIANS
190190
DPNP_FN_RADIANS_EXT
191-
DPNP_FN_REMAINDER
192-
DPNP_FN_REMAINDER_EXT
193191
DPNP_FN_RECIP
194192
DPNP_FN_RECIP_EXT
195193
DPNP_FN_REPEAT
@@ -442,9 +440,6 @@ cpdef dpnp_descriptor dpnp_minimum(dpnp_descriptor x1_obj, dpnp_descriptor x2_ob
442440
cpdef dpnp_descriptor dpnp_negative(dpnp_descriptor array1)
443441
cpdef dpnp_descriptor dpnp_power(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
444442
dpnp_descriptor out=*, object where=*)
445-
cpdef dpnp_descriptor dpnp_remainder(dpnp_descriptor x1_obj, dpnp_descriptor x2_obj, object dtype=*,
446-
dpnp_descriptor out=*, object where=*)
447-
448443

449444
"""
450445
Array manipulation routines

dpnp/dpnp_algo/dpnp_algo_mathematical.pxi

-9
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ __all__ += [
6262
"dpnp_negative",
6363
"dpnp_power",
6464
"dpnp_prod",
65-
"dpnp_remainder",
6665
"dpnp_sign",
6766
"dpnp_sum",
6867
"dpnp_trapz",
@@ -546,14 +545,6 @@ cpdef utils.dpnp_descriptor dpnp_prod(utils.dpnp_descriptor x1,
546545
return result
547546

548547

549-
cpdef utils.dpnp_descriptor dpnp_remainder(utils.dpnp_descriptor x1_obj,
550-
utils.dpnp_descriptor x2_obj,
551-
object dtype=None,
552-
utils.dpnp_descriptor out=None,
553-
object where=True):
554-
return call_fptr_2in_1out(DPNP_FN_REMAINDER_EXT, x1_obj, x2_obj, dtype, out, where)
555-
556-
557548
cpdef utils.dpnp_descriptor dpnp_sign(utils.dpnp_descriptor x1):
558549
return call_fptr_1in_1out_strides(DPNP_FN_SIGN_EXT, x1)
559550

dpnp/dpnp_algo/dpnp_elementwise_common.py

+45-2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
"dpnp_logical_xor",
6767
"dpnp_multiply",
6868
"dpnp_not_equal",
69+
"dpnp_remainder",
6970
"dpnp_right_shift",
7071
"dpnp_sin",
7172
"dpnp_sqrt",
@@ -86,7 +87,7 @@ def check_nd_call_func(
8687
**kwargs,
8788
):
8889
"""
89-
Checks arguments and calls function with a single input array.
90+
Checks arguments and calls a function.
9091
9192
Chooses a common internal elementwise function to call in DPNP based on input arguments
9293
or to fallback on NumPy call if any passed argument is not currently supported.
@@ -127,7 +128,6 @@ def check_nd_call_func(
127128
order
128129
)
129130
)
130-
131131
return dpnp_func(*x_args, out=out, order=order)
132132
return call_origin(
133133
origin_func,
@@ -1174,6 +1174,49 @@ def dpnp_not_equal(x1, x2, out=None, order="K"):
11741174
return dpnp_array._create_from_usm_ndarray(res_usm)
11751175

11761176

1177+
_remainder_docstring_ = """
1178+
remainder(x1, x2, out=None, order='K')
1179+
Calculates the remainder of division for each element `x1_i` of the input array
1180+
`x1` with the respective element `x2_i` of the input array `x2`.
1181+
This function is equivalent to the Python modulus operator.
1182+
Args:
1183+
x1 (dpnp.ndarray):
1184+
First input array, expected to have a real-valued data type.
1185+
x2 (dpnp.ndarray):
1186+
Second input array, also expected to have a real-valued data type.
1187+
out ({None, usm_ndarray}, optional):
1188+
Output array to populate.
1189+
Array have the correct shape and the expected data type.
1190+
order ("C","F","A","K", optional):
1191+
Memory layout of the newly output array, if parameter `out` is `None`.
1192+
Default: "K".
1193+
Returns:
1194+
dpnp.ndarray:
1195+
an array containing the element-wise remainders. The data type of
1196+
the returned array is determined by the Type Promotion Rules.
1197+
"""
1198+
1199+
1200+
remainder_func = BinaryElementwiseFunc(
1201+
"remainder",
1202+
ti._remainder_result_type,
1203+
ti._remainder,
1204+
_remainder_docstring_,
1205+
)
1206+
1207+
1208+
def dpnp_remainder(x1, x2, out=None, order="K"):
1209+
# dpctl.tensor only works with usm_ndarray or scalar
1210+
x1_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x1)
1211+
x2_usm_or_scalar = dpnp.get_usm_ndarray_or_scalar(x2)
1212+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
1213+
1214+
res_usm = remainder_func(
1215+
x1_usm_or_scalar, x2_usm_or_scalar, out=out_usm, order=order
1216+
)
1217+
return dpnp_array._create_from_usm_ndarray(res_usm)
1218+
1219+
11771220
_right_shift_docstring_ = """
11781221
right_shift(x1, x2, out=None, order='K')
11791222

0 commit comments

Comments
 (0)