Skip to content

Commit 1c7b85f

Browse files
authored
Remove mixed host\dev implementation from dpnp.any() (#1302)
* Remove mixed host\dev implementation from dpnp.any() * Reduce over group
1 parent c1933c1 commit 1c7b85f

File tree

4 files changed

+62
-33
lines changed

4 files changed

+62
-33
lines changed

dpnp/backend/kernels/dpnp_krnl_logic.cpp

+30-13
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,8 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
233233
const size_t size,
234234
const DPCTLEventVectorRef dep_event_vec_ref)
235235
{
236+
static_assert(std::is_same_v<_ResultType, bool>, "Boolean result type is required");
237+
236238
// avoid warning unused variable
237239
(void)dep_event_vec_ref;
238240

@@ -244,38 +246,50 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
244246
}
245247

246248
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
247-
sycl::event event;
248249

249-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
250-
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
251-
const _DataType* array_in = input1_ptr.get_ptr();
252-
_ResultType* result = result1_ptr.get_ptr();
250+
const _DataType* array_in = static_cast<const _DataType*>(array1_in);
251+
bool* result = static_cast<bool*>(result1);
253252

254-
result[0] = false;
253+
auto fill_event = q.fill(result, false, 1);
255254

256255
if (!size)
257256
{
258-
return event_ref;
257+
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&fill_event);
258+
return DPCTLEvent_Copy(event_ref);
259259
}
260260

261-
sycl::range<1> gws(size);
262-
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
263-
size_t i = global_id[0];
261+
constexpr size_t lws = 64;
262+
constexpr size_t vec_sz = 8;
264263

265-
if (array_in[i])
264+
auto gws_range = sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
265+
auto lws_range = sycl::range<1>(lws);
266+
sycl::nd_range<1> gws(gws_range, lws_range);
267+
268+
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
269+
auto gr = nd_it.get_group();
270+
const auto max_gr_size = gr.get_max_local_range()[0];
271+
const size_t start =
272+
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + gr.get_group_id()[0] * max_gr_size);
273+
const size_t end = sycl::min(start + vec_sz * max_gr_size, size);
274+
275+
// each work-item reduces over "vec_sz" elements in the input array
276+
bool local_reduction = sycl::joint_any_of(
277+
gr, &array_in[start], &array_in[end], [&](_DataType elem) { return elem != static_cast<_DataType>(0); });
278+
279+
if (gr.leader() && (local_reduction == true))
266280
{
267281
result[0] = true;
268282
}
269283
};
270284

271285
auto kernel_func = [&](sycl::handler& cgh) {
286+
cgh.depends_on(fill_event);
272287
cgh.parallel_for<class dpnp_any_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
273288
};
274289

275-
event = q.submit(kernel_func);
290+
auto event = q.submit(kernel_func);
276291

277292
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
278-
279293
return DPCTLEvent_Copy(event_ref);
280294
}
281295

@@ -290,6 +304,7 @@ void dpnp_any_c(const void* array1_in, void* result1, const size_t size)
290304
size,
291305
dep_event_vec_ref);
292306
DPCTLEvent_WaitAndThrow(event_ref);
307+
DPCTLEvent_Delete(event_ref);
293308
}
294309

295310
template <typename _DataType, typename _ResultType>
@@ -846,6 +861,8 @@ void func_map_init_logic(func_map_t& fmap)
846861
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_any_ext_c<int64_t, bool>};
847862
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_any_ext_c<float, bool>};
848863
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_any_ext_c<double, bool>};
864+
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_C64][eft_C64] = {eft_C64, (void*)dpnp_any_ext_c<std::complex<float>, bool>};
865+
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_any_ext_c<std::complex<double>, bool>};
849866

850867
func_map_logic_1arg_1type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL>(fmap);
851868
func_map_logic_2arg_2type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL>(fmap);

dpnp/dpnp_array.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -358,19 +358,24 @@ def all(self,
358358

359359
return dpnp.all(self, axis=axis, out=out, keepdims=keepdims, where=where)
360360

361-
def any(self, axis=None, out=None, keepdims=False):
361+
def any(self,
362+
axis=None,
363+
out=None,
364+
keepdims=False,
365+
*,
366+
where=True):
362367
"""
363368
Returns True if any of the elements of `a` evaluate to True.
364369
365-
Refer to `numpy.any` for full documentation.
370+
Refer to :obj:`dpnp.any` for full documentation.
366371
367372
See Also
368373
--------
369-
:obj:`numpy.any` : equivalent function
374+
:obj:`dpnp.any` : equivalent function
370375
371376
"""
372377

373-
return dpnp.any(self, axis, out, keepdims)
378+
return dpnp.any(self, axis=axis, out=out, keepdims=keepdims, where=where)
374379

375380
def argmax(self, axis=None, out=None):
376381
"""

dpnp/dpnp_iface_logic.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,13 @@ def allclose(x1, x2, rtol=1.e-5, atol=1.e-8, **kwargs):
170170
return call_origin(numpy.allclose, x1, x2, rtol=rtol, atol=atol, **kwargs)
171171

172172

173-
def any(x1, axis=None, out=None, keepdims=False):
173+
def any(x1,
174+
/,
175+
axis=None,
176+
out=None,
177+
keepdims=False,
178+
*,
179+
where=True):
174180
"""
175181
Test whether any array element along a given axis evaluates to True.
176182
@@ -181,9 +187,10 @@ def any(x1, axis=None, out=None, keepdims=False):
181187
Input array is supported as :obj:`dpnp.ndarray`.
182188
Otherwise the function will be executed sequentially on CPU.
183189
Input array data types are limited by supported DPNP :ref:`Data types`.
184-
Parameter ``axis`` is supported only with default value ``None``.
185-
Parameter ``out`` is supported only with default value ``None``.
186-
Parameter ``keepdims`` is supported only with default value ``False``.
190+
Parameter `axis` is supported only with default value `None`.
191+
Parameter `out` is supported only with default value `None`.
192+
Parameter `keepdims` is supported only with default value `False`.
193+
Parameter `where` is supported only with default value `True`.
187194
188195
See Also
189196
--------
@@ -196,15 +203,15 @@ def any(x1, axis=None, out=None, keepdims=False):
196203
197204
Examples
198205
--------
199-
>>> import dpnp as np
200-
>>> x = np.array([[True, False], [True, True]])
201-
>>> np.any(x)
206+
>>> import dpnp as dp
207+
>>> x = dp.array([[True, False], [True, True]])
208+
>>> dp.any(x)
202209
True
203-
>>> x2 = np.array([0, 0, 0])
204-
>>> np.any(x2)
210+
>>> x2 = dp.array([0, 0, 0])
211+
>>> dp.any(x2)
205212
False
206-
>>> x3 = np.array([1.0, np.nan])
207-
>>> np.any(x3)
213+
>>> x3 = dp.array([1.0, dp.nan])
214+
>>> dp.any(x3)
208215
True
209216
210217
"""
@@ -217,13 +224,13 @@ def any(x1, axis=None, out=None, keepdims=False):
217224
pass
218225
elif keepdims is not False:
219226
pass
227+
elif where is not True:
228+
pass
220229
else:
221230
result_obj = dpnp_any(x1_desc).get_pyobj()
222-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
223-
224-
return result
231+
return dpnp.convert_single_elem_array_to_scalar(result_obj)
225232

226-
return call_origin(numpy.any, x1, axis, out, keepdims)
233+
return call_origin(numpy.any, x1, axis=axis, out=out, keepdims=keepdims, where=where)
227234

228235

229236
def equal(x1,

tests/test_logic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_allclose(type):
6363
assert_allclose(dpnp_res, np_res)
6464

6565

66-
@pytest.mark.parametrize("type", get_all_dtypes(no_complex=True))
66+
@pytest.mark.parametrize("type", get_all_dtypes())
6767
@pytest.mark.parametrize("shape",
6868
[(0,), (4,), (2, 3), (2, 2, 2)],
6969
ids=['(0,)', '(4,)', '(2,3)', '(2,2,2)'])

0 commit comments

Comments
 (0)