Skip to content

Commit c1933c1

Browse files
authored
Remove mixed host\dev implementation from dpnp.all() (#1301)
* Remove mixed host\dev implementation from dpnp.all() * Reduce over group
1 parent 351f50b commit c1933c1

File tree

6 files changed

+66
-34
lines changed

6 files changed

+66
-34
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# CMake build and local install directory
22
build
33
build_cython
4+
dpnp.egg-info
45

56
# Byte-compiled / optimized / DLL files
67
__pycache__/
@@ -14,6 +15,9 @@ coverage.xml
1415
# Backup files kept after git merge/rebase
1516
*.orig
1617

18+
# Build examples
19+
example3
20+
1721
*dpnp_backend*
1822
dpnp/**/*.cpython*.so
1923
dpnp/**/*.pyd

dpnp/backend/kernels/dpnp_krnl_logic.cpp

+30-13
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ DPCTLSyclEventRef dpnp_all_c(DPCTLSyclQueueRef q_ref,
4141
const size_t size,
4242
const DPCTLEventVectorRef dep_event_vec_ref)
4343
{
44+
static_assert(std::is_same_v<_ResultType, bool>, "Boolean result type is required");
45+
4446
// avoid warning unused variable
4547
(void)dep_event_vec_ref;
4648

@@ -52,38 +54,50 @@ DPCTLSyclEventRef dpnp_all_c(DPCTLSyclQueueRef q_ref,
5254
}
5355

5456
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
55-
sycl::event event;
5657

57-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
58-
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
59-
const _DataType* array_in = input1_ptr.get_ptr();
60-
_ResultType* result = result1_ptr.get_ptr();
58+
const _DataType* array_in = static_cast<const _DataType*>(array1_in);
59+
bool* result = static_cast<bool*>(result1);
6160

62-
result[0] = true;
61+
auto fill_event = q.fill(result, true, 1);
6362

6463
if (!size)
6564
{
66-
return event_ref;
65+
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&fill_event);
66+
return DPCTLEvent_Copy(event_ref);
6767
}
6868

69-
sycl::range<1> gws(size);
70-
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
71-
size_t i = global_id[0];
69+
constexpr size_t lws = 64;
70+
constexpr size_t vec_sz = 8;
71+
72+
auto gws_range = sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
73+
auto lws_range = sycl::range<1>(lws);
74+
sycl::nd_range<1> gws(gws_range, lws_range);
7275

73-
if (!array_in[i])
76+
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
77+
auto gr = nd_it.get_group();
78+
const auto max_gr_size = gr.get_max_local_range()[0];
79+
const size_t start =
80+
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + gr.get_group_id()[0] * max_gr_size);
81+
const size_t end = sycl::min(start + vec_sz * max_gr_size, size);
82+
83+
// each work-item reduces over "vec_sz" elements in the input array
84+
bool local_reduction = sycl::joint_none_of(
85+
gr, &array_in[start], &array_in[end], [&](_DataType elem) { return elem == static_cast<_DataType>(0); });
86+
87+
if (gr.leader() && (local_reduction == false))
7488
{
7589
result[0] = false;
7690
}
7791
};
7892

7993
auto kernel_func = [&](sycl::handler& cgh) {
94+
cgh.depends_on(fill_event);
8095
cgh.parallel_for<class dpnp_all_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
8196
};
8297

83-
event = q.submit(kernel_func);
98+
auto event = q.submit(kernel_func);
8499

85100
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
86-
87101
return DPCTLEvent_Copy(event_ref);
88102
}
89103

@@ -98,6 +112,7 @@ void dpnp_all_c(const void* array1_in, void* result1, const size_t size)
98112
size,
99113
dep_event_vec_ref);
100114
DPCTLEvent_WaitAndThrow(event_ref);
115+
DPCTLEvent_Delete(event_ref);
101116
}
102117

103118
template <typename _DataType, typename _ResultType>
@@ -751,6 +766,8 @@ void func_map_init_logic(func_map_t& fmap)
751766
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_all_ext_c<int64_t, bool>};
752767
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_all_ext_c<float, bool>};
753768
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_all_ext_c<double, bool>};
769+
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_C64][eft_C64] = {eft_C64, (void*)dpnp_all_ext_c<std::complex<float>, bool>};
770+
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_all_ext_c<std::complex<double>, bool>};
754771

755772
fmap[DPNPFuncName::DPNP_FN_ALLCLOSE][eft_INT][eft_INT] = {eft_BLN,
756773
(void*)dpnp_allclose_default_c<int32_t, int32_t, bool>};

dpnp/dpnp_array.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -339,19 +339,24 @@ def _create_from_usm_ndarray(usm_ary : dpt.usm_ndarray):
339339
res._array_obj = usm_ary
340340
return res
341341

342-
def all(self, axis=None, out=None, keepdims=False):
342+
def all(self,
343+
axis=None,
344+
out=None,
345+
keepdims=False,
346+
*,
347+
where=True):
343348
"""
344349
Returns True if all elements evaluate to True.
345350
346-
Refer to `numpy.all` for full documentation.
351+
Refer to :obj:`dpnp.all` for full documentation.
347352
348353
See Also
349354
--------
350-
:obj:`numpy.all` : equivalent function
355+
:obj:`dpnp.all` : equivalent function
351356
352357
"""
353358

354-
return dpnp.all(self, axis, out, keepdims)
359+
return dpnp.all(self, axis=axis, out=out, keepdims=keepdims, where=where)
355360

356361
def any(self, axis=None, out=None, keepdims=False):
357362
"""

dpnp/dpnp_iface_logic.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,13 @@
6969
]
7070

7171

72-
def all(x1, axis=None, out=None, keepdims=False):
72+
def all(x1,
73+
/,
74+
axis=None,
75+
out=None,
76+
keepdims=False,
77+
*,
78+
where=True):
7379
"""
7480
Test whether all array elements along a given axis evaluate to True.
7581
@@ -80,9 +86,10 @@ def all(x1, axis=None, out=None, keepdims=False):
8086
Input array is supported as :obj:`dpnp.ndarray`.
8187
Otherwise the function will be executed sequentially on CPU.
8288
Input array data types are limited by supported DPNP :ref:`Data types`.
83-
Parameter ``axis`` is supported only with default value ``None``.
84-
Parameter ``out`` is supported only with default value ``None``.
85-
Parameter ``keepdims`` is supported only with default value ``False``.
89+
Parameter `axis` is supported only with default value `None`.
90+
Parameter `out` is supported only with default value `None`.
91+
Parameter `keepdims` is supported only with default value `False`.
92+
Parameter `where` is supported only with default value `True`.
8693
8794
See Also
8895
--------
@@ -95,15 +102,15 @@ def all(x1, axis=None, out=None, keepdims=False):
95102
96103
Examples
97104
--------
98-
>>> import dpnp as np
99-
>>> x = np.array([[True, False], [True, True]])
100-
>>> np.all(x)
105+
>>> import dpnp as dp
106+
>>> x = dp.array([[True, False], [True, True]])
107+
>>> dp.all(x)
101108
False
102-
>>> x2 = np.array([-1, 4, 5])
103-
>>> np.all(x2)
109+
>>> x2 = dp.array([-1, 4, 5])
110+
>>> dp.all(x2)
104111
True
105-
>>> x3 = np.array([1.0, np.nan])
106-
>>> np.all(x3)
112+
>>> x3 = dp.array([1.0, dp.nan])
113+
>>> dp.all(x3)
107114
True
108115
109116
"""
@@ -116,13 +123,13 @@ def all(x1, axis=None, out=None, keepdims=False):
116123
pass
117124
elif keepdims is not False:
118125
pass
126+
elif where is not True:
127+
pass
119128
else:
120129
result_obj = dpnp_all(x1_desc).get_pyobj()
121-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
122-
123-
return result
130+
return dpnp.convert_single_elem_array_to_scalar(result_obj)
124131

125-
return call_origin(numpy.all, x1, axis, out, keepdims)
132+
return call_origin(numpy.all, x1, axis=axis, out=out, keepdims=keepdims, where=where)
126133

127134

128135
def allclose(x1, x2, rtol=1.e-5, atol=1.e-8, **kwargs):

tests/test_logic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
)
1111

1212

13-
@pytest.mark.parametrize("type", get_all_dtypes(no_complex=True))
13+
@pytest.mark.parametrize("type", get_all_dtypes())
1414
@pytest.mark.parametrize("shape",
1515
[(0,), (4,), (2, 3), (2, 2, 2)],
1616
ids=['(0,)', '(4,)', '(2,3)', '(2,2,2)'])

tests/test_usm_type.py

-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def test_linspace_arrays(usm_type_start, usm_type_stop):
103103
assert res.usm_type == du.get_coerced_usm_type([usm_type_start, usm_type_stop])
104104

105105

106-
@pytest.mark.skip()
107106
@pytest.mark.parametrize("func", ["tril", "triu"], ids=["tril", "triu"])
108107
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
109108
def test_tril_triu(func, usm_type):

0 commit comments

Comments
 (0)