Skip to content

Commit 8e6b8dc

Browse files
committed
Remove mixed host\dev implementation from dpnp.any()
1 parent 439f2b5 commit 8e6b8dc

File tree

7 files changed

+67
-35
lines changed

7 files changed

+67
-35
lines changed

.gitignore

+4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# CMake build and local install directory
22
build
33
build_cython
4+
dpnp.egg-info
45

56
# Byte-compiled / optimized / DLL files
67
__pycache__/
@@ -14,6 +15,9 @@ coverage.xml
1415
# Backup files kept after git merge/rebase
1516
*.orig
1617

18+
# Build examples
19+
example3
20+
1721
*dpnp_backend*
1822
dpnp/**/*.cpython*.so
1923
dpnp/**/*.pyd

dpnp/backend/kernels/dpnp_krnl_logic.cpp

+30-13
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,8 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
218218
const size_t size,
219219
const DPCTLEventVectorRef dep_event_vec_ref)
220220
{
221+
static_assert(std::is_same_v<_ResultType, bool>, "Boolean result type is required");
222+
221223
// avoid warning unused variable
222224
(void)dep_event_vec_ref;
223225

@@ -229,38 +231,50 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
229231
}
230232

231233
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
232-
sycl::event event;
233234

234-
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
235-
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
236-
const _DataType* array_in = input1_ptr.get_ptr();
237-
_ResultType* result = result1_ptr.get_ptr();
235+
const _DataType* array_in = static_cast<const _DataType*>(array1_in);
236+
bool* result = static_cast<bool*>(result1);
238237

239-
result[0] = false;
238+
auto fill_event = q.fill(result, false, 1);
240239

241240
if (!size)
242241
{
243-
return event_ref;
242+
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&fill_event);
243+
return DPCTLEvent_Copy(event_ref);
244244
}
245245

246-
sycl::range<1> gws(size);
247-
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
248-
size_t i = global_id[0];
246+
constexpr size_t lws = 64;
247+
constexpr size_t vec_sz = 8;
248+
249+
auto gws_range = sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
250+
auto lws_range = sycl::range<1>(lws);
251+
sycl::nd_range<1> gws(gws_range, lws_range);
252+
253+
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
254+
auto sg = nd_it.get_sub_group();
255+
const auto max_sg_size = sg.get_max_local_range()[0];
256+
const size_t start =
257+
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + sg.get_group_id()[0] * max_sg_size);
258+
const size_t end = sycl::min(start + vec_sz * max_sg_size, size);
249259

250-
if (array_in[i])
260+
// each work-item reduces over "vec_sz" elements in the input array
261+
bool local_reduction = sycl::joint_any_of(
262+
sg, &array_in[start], &array_in[end], [&](_DataType elem) { return elem != static_cast<_DataType>(0); });
263+
264+
if (sg.leader() && (local_reduction == true))
251265
{
252266
result[0] = true;
253267
}
254268
};
255269

256270
auto kernel_func = [&](sycl::handler& cgh) {
271+
cgh.depends_on(fill_event);
257272
cgh.parallel_for<class dpnp_any_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
258273
};
259274

260-
event = q.submit(kernel_func);
275+
auto event = q.submit(kernel_func);
261276

262277
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);
263-
264278
return DPCTLEvent_Copy(event_ref);
265279
}
266280

@@ -275,6 +289,7 @@ void dpnp_any_c(const void* array1_in, void* result1, const size_t size)
275289
size,
276290
dep_event_vec_ref);
277291
DPCTLEvent_WaitAndThrow(event_ref);
292+
DPCTLEvent_Delete(event_ref);
278293
}
279294

280295
template <typename _DataType, typename _ResultType>
@@ -829,6 +844,8 @@ void func_map_init_logic(func_map_t& fmap)
829844
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_any_ext_c<int64_t, bool>};
830845
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_any_ext_c<float, bool>};
831846
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_any_ext_c<double, bool>};
847+
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_C64][eft_C64] = {eft_C64, (void*)dpnp_any_ext_c<std::complex<float>, bool>};
848+
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_any_ext_c<std::complex<double>, bool>};
832849

833850
func_map_logic_1arg_1type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL>(fmap);
834851
func_map_logic_2arg_2type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL>(fmap);

dpnp/dpnp_array.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -327,19 +327,24 @@ def all(self, axis=None, out=None, keepdims=False):
327327

328328
return dpnp.all(self, axis, out, keepdims)
329329

330-
def any(self, axis=None, out=None, keepdims=False):
330+
def any(self,
331+
axis=None,
332+
out=None,
333+
keepdims=False,
334+
*,
335+
where=True):
331336
"""
332337
Returns True if any of the elements of `a` evaluate to True.
333338
334-
Refer to `numpy.any` for full documentation.
339+
Refer to :obj:`dpnp.any` for full documentation.
335340
336341
See Also
337342
--------
338-
:obj:`numpy.any` : equivalent function
343+
:obj:`dpnp.any` : equivalent function
339344
340345
"""
341346

342-
return dpnp.any(self, axis, out, keepdims)
347+
return dpnp.any(self, axis=axis, out=out, keepdims=keepdims, where=where)
343348

344349
def argmax(self, axis=None, out=None):
345350
"""

dpnp/dpnp_iface_logic.py

+22-15
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,13 @@ def allclose(x1, x2, rtol=1.e-5, atol=1.e-8, **kwargs):
163163
return call_origin(numpy.allclose, x1, x2, rtol=rtol, atol=atol, **kwargs)
164164

165165

166-
def any(x1, axis=None, out=None, keepdims=False):
166+
def any(x1,
167+
/,
168+
axis=None,
169+
out=None,
170+
keepdims=False,
171+
*,
172+
where=True):
167173
"""
168174
Test whether any array element along a given axis evaluates to True.
169175
@@ -174,9 +180,10 @@ def any(x1, axis=None, out=None, keepdims=False):
174180
Input array is supported as :obj:`dpnp.ndarray`.
175181
Otherwise the function will be executed sequentially on CPU.
176182
Input array data types are limited by supported DPNP :ref:`Data types`.
177-
Parameter ``axis`` is supported only with default value ``None``.
178-
Parameter ``out`` is supported only with default value ``None``.
179-
Parameter ``keepdims`` is supported only with default value ``False``.
183+
Parameter `axis` is supported only with default value `None`.
184+
Parameter `out` is supported only with default value `None`.
185+
Parameter `keepdims` is supported only with default value `False`.
186+
Parameter `where` is supported only with default value `True`.
180187
181188
See Also
182189
--------
@@ -189,15 +196,15 @@ def any(x1, axis=None, out=None, keepdims=False):
189196
190197
Examples
191198
--------
192-
>>> import dpnp as np
193-
>>> x = np.array([[True, False], [True, True]])
194-
>>> np.any(x)
199+
>>> import dpnp as dp
200+
>>> x = dp.array([[True, False], [True, True]])
201+
>>> dp.any(x)
195202
True
196-
>>> x2 = np.array([0, 0, 0])
197-
>>> np.any(x2)
203+
>>> x2 = dp.array([0, 0, 0])
204+
>>> dp.any(x2)
198205
False
199-
>>> x3 = np.array([1.0, np.nan])
200-
>>> np.any(x3)
206+
>>> x3 = dp.array([1.0, dp.nan])
207+
>>> dp.any(x3)
201208
True
202209
203210
"""
@@ -210,13 +217,13 @@ def any(x1, axis=None, out=None, keepdims=False):
210217
pass
211218
elif keepdims is not False:
212219
pass
220+
elif where is not True:
221+
pass
213222
else:
214223
result_obj = dpnp_any(x1_desc).get_pyobj()
215-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
216-
217-
return result
224+
return dpnp.convert_single_elem_array_to_scalar(result_obj)
218225

219-
return call_origin(numpy.any, x1, axis, out, keepdims)
226+
return call_origin(numpy.any, x1, axis=axis, out=out, keepdims=keepdims, where=where)
220227

221228

222229
def equal(x1,

tests/test_arraycreation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def test_dpctl_tensor_input(func, args):
505505
new_args = [eval(val, {'x0' : x0}) for val in args]
506506
X = getattr(dpt, func)(*new_args)
507507
Y = getattr(dpnp, func)(*new_args)
508-
if func is 'empty_like':
508+
if func == 'empty_like':
509509
assert X.shape == Y.shape
510510
else:
511511
assert_array_equal(X, Y)

tests/test_logic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_allclose(type):
6363
assert_allclose(dpnp_res, np_res)
6464

6565

66-
@pytest.mark.parametrize("type", get_all_dtypes(no_complex=True))
66+
@pytest.mark.parametrize("type", get_all_dtypes())
6767
@pytest.mark.parametrize("shape",
6868
[(0,), (4,), (2, 3), (2, 2, 2)],
6969
ids=['(0,)', '(4,)', '(2,3)', '(2,2,2)'])

tests/test_usm_type.py

-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def test_array_creation(func, args, usm_type_x, usm_type_y):
6464
assert y.usm_type == usm_type_y
6565

6666

67-
@pytest.mark.skip()
6867
@pytest.mark.parametrize("func", ["tril", "triu"], ids=["tril", "triu"])
6968
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
7069
def test_tril_triu(func, usm_type):

0 commit comments

Comments
 (0)