Skip to content

Commit 1650774

Browse files
Leverage dpctl.tensor.all() and dpctl.tensor.any() implementations (#1512)
* Leverage dpctl.tensor.all implementation * Leverage dpctl.tensor.any implementation * Add new test_truth and test_logic files and expand scope of tests
1 parent 1f15ccc commit 1650774

File tree

7 files changed

+169
-159
lines changed

7 files changed

+169
-159
lines changed

.github/workflows/conda-package.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ env:
1717
test_dparray.py
1818
test_fft.py
1919
test_linalg.py
20+
test_logic.py
2021
test_mathematical.py
2122
test_random_state.py
2223
test_sort.py
2324
test_special.py
2425
test_umath.py
2526
test_usm_type.py
2627
third_party/cupy/linalg_tests/test_product.py
28+
third_party/cupy/logic_tests/test_truth.py
2729
third_party/cupy/manipulation_tests/test_join.py
2830
third_party/cupy/math_tests/test_explog.py
2931
third_party/cupy/math_tests/test_misc.py

dpnp/backend/include/dpnp_iface_fptr.hpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,12 @@ enum class DPNPFuncName : size_t
6565
DPNP_FN_ADD, /**< Used in numpy.add() impl */
6666
DPNP_FN_ADD_EXT, /**< Used in numpy.add() impl, requires extra parameters */
6767
DPNP_FN_ALL, /**< Used in numpy.all() impl */
68-
DPNP_FN_ALL_EXT, /**< Used in numpy.all() impl, requires extra parameters */
69-
DPNP_FN_ALLCLOSE, /**< Used in numpy.allclose() impl */
70-
DPNP_FN_ALLCLOSE_EXT, /**< Used in numpy.allclose() impl, requires extra
71-
parameters */
72-
DPNP_FN_ANY, /**< Used in numpy.any() impl */
73-
DPNP_FN_ANY_EXT, /**< Used in numpy.any() impl, requires extra parameters */
74-
DPNP_FN_ARANGE, /**< Used in numpy.arange() impl */
75-
DPNP_FN_ARCCOS, /**< Used in numpy.arccos() impl */
68+
DPNP_FN_ALLCLOSE, /**< Used in numpy.allclose() impl */
69+
DPNP_FN_ALLCLOSE_EXT, /**< Used in numpy.allclose() impl, requires extra
70+
parameters */
71+
DPNP_FN_ANY, /**< Used in numpy.any() impl */
72+
DPNP_FN_ARANGE, /**< Used in numpy.arange() impl */
73+
DPNP_FN_ARCCOS, /**< Used in numpy.arccos() impl */
7674
DPNP_FN_ARCCOS_EXT, /**< Used in numpy.arccos() impl, requires extra
7775
parameters */
7876
DPNP_FN_ARCCOSH, /**< Used in numpy.arccosh() impl */

dpnp/backend/kernels/dpnp_krnl_logic.cpp

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -640,21 +640,6 @@ void func_map_init_logic(func_map_t &fmap)
640640
fmap[DPNPFuncName::DPNP_FN_ALL][eft_DBL][eft_DBL] = {
641641
eft_DBL, (void *)dpnp_all_default_c<double, bool>};
642642

643-
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_BLN][eft_BLN] = {
644-
eft_BLN, (void *)dpnp_all_ext_c<bool, bool>};
645-
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_INT][eft_INT] = {
646-
eft_INT, (void *)dpnp_all_ext_c<int32_t, bool>};
647-
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_LNG][eft_LNG] = {
648-
eft_LNG, (void *)dpnp_all_ext_c<int64_t, bool>};
649-
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_FLT][eft_FLT] = {
650-
eft_FLT, (void *)dpnp_all_ext_c<float, bool>};
651-
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_DBL][eft_DBL] = {
652-
eft_DBL, (void *)dpnp_all_ext_c<double, bool>};
653-
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_C64][eft_C64] = {
654-
eft_C64, (void *)dpnp_all_ext_c<std::complex<float>, bool>};
655-
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_C128][eft_C128] = {
656-
eft_C128, (void *)dpnp_all_ext_c<std::complex<double>, bool>};
657-
658643
fmap[DPNPFuncName::DPNP_FN_ALLCLOSE][eft_INT][eft_INT] = {
659644
eft_BLN, (void *)dpnp_allclose_default_c<int32_t, int32_t, bool>};
660645
fmap[DPNPFuncName::DPNP_FN_ALLCLOSE][eft_LNG][eft_INT] = {
@@ -732,21 +717,6 @@ void func_map_init_logic(func_map_t &fmap)
732717
fmap[DPNPFuncName::DPNP_FN_ANY][eft_DBL][eft_DBL] = {
733718
eft_DBL, (void *)dpnp_any_default_c<double, bool>};
734719

735-
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_BLN][eft_BLN] = {
736-
eft_BLN, (void *)dpnp_any_ext_c<bool, bool>};
737-
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_INT][eft_INT] = {
738-
eft_INT, (void *)dpnp_any_ext_c<int32_t, bool>};
739-
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_LNG][eft_LNG] = {
740-
eft_LNG, (void *)dpnp_any_ext_c<int64_t, bool>};
741-
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_FLT][eft_FLT] = {
742-
eft_FLT, (void *)dpnp_any_ext_c<float, bool>};
743-
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_DBL][eft_DBL] = {
744-
eft_DBL, (void *)dpnp_any_ext_c<double, bool>};
745-
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_C64][eft_C64] = {
746-
eft_C64, (void *)dpnp_any_ext_c<std::complex<float>, bool>};
747-
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_C128][eft_C128] = {
748-
eft_C128, (void *)dpnp_any_ext_c<std::complex<double>, bool>};
749-
750720
func_map_logic_2arg_2type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT,
751721
eft_DBL>(fmap);
752722

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,8 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na
3535
cdef enum DPNPFuncName "DPNPFuncName":
3636
DPNP_FN_ABSOLUTE
3737
DPNP_FN_ABSOLUTE_EXT
38-
DPNP_FN_ALL
39-
DPNP_FN_ALL_EXT
4038
DPNP_FN_ALLCLOSE
4139
DPNP_FN_ALLCLOSE_EXT
42-
DPNP_FN_ANY
43-
DPNP_FN_ANY_EXT
4440
DPNP_FN_ARANGE
4541
DPNP_FN_ARCCOS
4642
DPNP_FN_ARCCOS_EXT

dpnp/dpnp_algo/dpnp_algo_logic.pxi

Lines changed: 0 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,11 @@ and the rest of the library
3636
# NO IMPORTs here. All imports must be placed into main "dpnp_algo.pyx" file
3737

3838
__all__ += [
39-
"dpnp_all",
4039
"dpnp_allclose",
41-
"dpnp_any",
4240
"dpnp_isclose",
4341
]
4442

4543

46-
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_logic_1in_1out_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
47-
void *, void * , const size_t,
48-
const c_dpctl.DPCTLEventVectorRef)
4944
ctypedef c_dpctl.DPCTLSyclEventRef(*custom_allclose_1in_1out_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef,
5045
void * ,
5146
void * ,
@@ -56,35 +51,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*custom_allclose_1in_1out_func_ptr_t)(c_dpctl
5651
const c_dpctl.DPCTLEventVectorRef)
5752

5853

59-
cpdef utils.dpnp_descriptor dpnp_all(utils.dpnp_descriptor array1):
60-
array1_obj = array1.get_array()
61-
62-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py((1,),
63-
dpnp.bool,
64-
None,
65-
device=array1_obj.sycl_device,
66-
usm_type=array1_obj.usm_type,
67-
sycl_queue=array1_obj.sycl_queue)
68-
69-
result_sycl_queue = result.get_array().sycl_queue
70-
71-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
72-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
73-
74-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype)
75-
76-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ALL_EXT, param1_type, param1_type)
77-
78-
cdef custom_logic_1in_1out_func_ptr_t func = <custom_logic_1in_1out_func_ptr_t > kernel_data.ptr
79-
80-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, array1.get_data(), result.get_data(), array1.size, NULL)
81-
82-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
83-
c_dpctl.DPCTLEvent_Delete(event_ref)
84-
85-
return result
86-
87-
8854
cpdef utils.dpnp_descriptor dpnp_allclose(utils.dpnp_descriptor array1,
8955
utils.dpnp_descriptor array2,
9056
double rtol_val,
@@ -125,35 +91,6 @@ cpdef utils.dpnp_descriptor dpnp_allclose(utils.dpnp_descriptor array1,
12591
return result
12692

12793

128-
cpdef utils.dpnp_descriptor dpnp_any(utils.dpnp_descriptor array1):
129-
array1_obj = array1.get_array()
130-
131-
cdef utils.dpnp_descriptor result = utils_py.create_output_descriptor_py((1,),
132-
dpnp.bool,
133-
None,
134-
device=array1_obj.sycl_device,
135-
usm_type=array1_obj.usm_type,
136-
sycl_queue=array1_obj.sycl_queue)
137-
138-
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype)
139-
140-
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ANY_EXT, param1_type, param1_type)
141-
142-
result_sycl_queue = result.get_array().sycl_queue
143-
144-
cdef c_dpctl.SyclQueue q = <c_dpctl.SyclQueue> result_sycl_queue
145-
cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref()
146-
147-
cdef custom_logic_1in_1out_func_ptr_t func = <custom_logic_1in_1out_func_ptr_t > kernel_data.ptr
148-
149-
cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, array1.get_data(), result.get_data(), array1.size, NULL)
150-
151-
with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref)
152-
c_dpctl.DPCTLEvent_Delete(event_ref)
153-
154-
return result
155-
156-
15794
cpdef utils.dpnp_descriptor dpnp_isclose(utils.dpnp_descriptor input1,
15895
utils.dpnp_descriptor input2,
15996
double rtol=1e-05,

dpnp/dpnp_iface_logic.py

Lines changed: 70 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@
4040
"""
4141

4242

43+
import dpctl.tensor as dpt
4344
import numpy
4445

4546
import dpnp
4647
from dpnp.dpnp_algo import *
48+
from dpnp.dpnp_array import dpnp_array
4749
from dpnp.dpnp_utils import *
4850

4951
from .dpnp_algo.dpnp_elementwise_common import (
@@ -84,24 +86,29 @@
8486
]
8587

8688

87-
def all(x1, /, axis=None, out=None, keepdims=False, *, where=True):
89+
def all(x, /, axis=None, out=None, keepdims=False, *, where=True):
8890
"""
8991
Test whether all array elements along a given axis evaluate to True.
9092
9193
For full documentation refer to :obj:`numpy.all`.
9294
95+
Returns
96+
-------
97+
dpnp.ndarray
98+
An array with a data type of `bool`
99+
containing the results of the logical AND reduction.
100+
93101
Limitations
94102
-----------
95-
Input array is supported as :obj:`dpnp.ndarray`.
96-
Otherwise the function will be executed sequentially on CPU.
103+
Parameters `x` is supported either as :class:`dpnp.ndarray`
104+
or :class:`dpctl.tensor.usm_ndarray`.
105+
Parameters `out` and `where` are supported with default value.
97106
Input array data types are limited by supported DPNP :ref:`Data types`.
98-
Parameter `axis` is supported only with default value `None`.
99-
Parameter `out` is supported only with default value `None`.
100-
Parameter `keepdims` is supported only with default value `False`.
101-
Parameter `where` is supported only with default value `True`.
107+
Otherwise the function will be executed sequentially on CPU.
102108
103109
See Also
104110
--------
111+
:obj:`dpnp.ndarray.all` : equivalent method
105112
:obj:`dpnp.any` : Test whether any element along a given axis evaluates to True.
106113
107114
Notes
@@ -111,35 +118,37 @@ def all(x1, /, axis=None, out=None, keepdims=False, *, where=True):
111118
112119
Examples
113120
--------
114-
>>> import dpnp as dp
115-
>>> x = dp.array([[True, False], [True, True]])
116-
>>> dp.all(x)
117-
False
118-
>>> x2 = dp.array([-1, 4, 5])
119-
>>> dp.all(x2)
120-
True
121-
>>> x3 = dp.array([1.0, dp.nan])
122-
>>> dp.all(x3)
123-
True
121+
>>> import dpnp as np
122+
>>> x = np.array([[True, False], [True, True]])
123+
>>> np.all(x)
124+
array(False)
125+
126+
>>> np.all(x, axis=0)
127+
array([ True, False])
128+
129+
>>> x2 = np.array([-1, 4, 5])
130+
>>> np.all(x2)
131+
array(True)
132+
133+
>>> x3 = np.array([1.0, np.nan])
134+
>>> np.all(x3)
135+
array(True)
124136
125137
"""
126138

127-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
128-
if x1_desc:
129-
if axis is not None:
130-
pass
131-
elif out is not None:
132-
pass
133-
elif keepdims is not False:
139+
if dpnp.is_supported_array_type(x):
140+
if out is not None:
134141
pass
135142
elif where is not True:
136143
pass
137144
else:
138-
result_obj = dpnp_all(x1_desc).get_pyobj()
139-
return dpnp.convert_single_elem_array_to_scalar(result_obj)
145+
dpt_array = dpnp.get_usm_ndarray(x)
146+
return dpnp_array._create_from_usm_ndarray(
147+
dpt.all(dpt_array, axis=axis, keepdims=keepdims)
148+
)
140149

141150
return call_origin(
142-
numpy.all, x1, axis=axis, out=out, keepdims=keepdims, where=where
151+
numpy.all, x, axis=axis, out=out, keepdims=keepdims, where=where
143152
)
144153

145154

@@ -181,24 +190,29 @@ def allclose(x1, x2, rtol=1.0e-5, atol=1.0e-8, **kwargs):
181190
return call_origin(numpy.allclose, x1, x2, rtol=rtol, atol=atol, **kwargs)
182191

183192

184-
def any(x1, /, axis=None, out=None, keepdims=False, *, where=True):
193+
def any(x, /, axis=None, out=None, keepdims=False, *, where=True):
185194
"""
186195
Test whether any array element along a given axis evaluates to True.
187196
188197
For full documentation refer to :obj:`numpy.any`.
189198
199+
Returns
200+
-------
201+
dpnp.ndarray
202+
An array with a data type of `bool`
203+
containing the results of the logical OR reduction.
204+
190205
Limitations
191206
-----------
192-
Input array is supported as :obj:`dpnp.ndarray`.
193-
Otherwise the function will be executed sequentially on CPU.
207+
Parameters `x` is supported either as :class:`dpnp.ndarray`
208+
or :class:`dpctl.tensor.usm_ndarray`.
209+
Parameters `out` and `where` are supported with default value.
194210
Input array data types are limited by supported DPNP :ref:`Data types`.
195-
Parameter `axis` is supported only with default value `None`.
196-
Parameter `out` is supported only with default value `None`.
197-
Parameter `keepdims` is supported only with default value `False`.
198-
Parameter `where` is supported only with default value `True`.
211+
Otherwise the function will be executed sequentially on CPU.
199212
200213
See Also
201214
--------
215+
:obj:`dpnp.ndarray.any` : equivalent method
202216
:obj:`dpnp.all` : Test whether all elements along a given axis evaluate to True.
203217
204218
Notes
@@ -208,35 +222,37 @@ def any(x1, /, axis=None, out=None, keepdims=False, *, where=True):
208222
209223
Examples
210224
--------
211-
>>> import dpnp as dp
212-
>>> x = dp.array([[True, False], [True, True]])
213-
>>> dp.any(x)
214-
True
215-
>>> x2 = dp.array([0, 0, 0])
216-
>>> dp.any(x2)
217-
False
218-
>>> x3 = dp.array([1.0, dp.nan])
219-
>>> dp.any(x3)
220-
True
225+
>>> import dpnp as np
226+
>>> x = np.array([[True, False], [True, True]])
227+
>>> np.any(x)
228+
array(True)
229+
230+
>>> np.any(x, axis=0)
231+
array([ True, True])
232+
233+
>>> x2 = np.array([0, 0, 0])
234+
>>> np.any(x2)
235+
array(False)
236+
237+
>>> x3 = np.array([1.0, np.nan])
238+
>>> np.any(x3)
239+
array(True)
221240
222241
"""
223242

224-
x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False)
225-
if x1_desc:
226-
if axis is not None:
227-
pass
228-
elif out is not None:
229-
pass
230-
elif keepdims is not False:
243+
if dpnp.is_supported_array_type(x):
244+
if out is not None:
231245
pass
232246
elif where is not True:
233247
pass
234248
else:
235-
result_obj = dpnp_any(x1_desc).get_pyobj()
236-
return dpnp.convert_single_elem_array_to_scalar(result_obj)
249+
dpt_array = dpnp.get_usm_ndarray(x)
250+
return dpnp_array._create_from_usm_ndarray(
251+
dpt.any(dpt_array, axis=axis, keepdims=keepdims)
252+
)
237253

238254
return call_origin(
239-
numpy.any, x1, axis=axis, out=out, keepdims=keepdims, where=where
255+
numpy.any, x, axis=axis, out=out, keepdims=keepdims, where=where
240256
)
241257

242258

0 commit comments

Comments
 (0)