Skip to content

Remove mixed host\dev implementation from dpnp.any() #1302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions dpnp/backend/kernels/dpnp_krnl_logic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
const size_t size,
const DPCTLEventVectorRef dep_event_vec_ref)
{
static_assert(std::is_same_v<_ResultType, bool>, "Boolean result type is required");

// avoid warning unused variable
(void)dep_event_vec_ref;

Expand All @@ -244,38 +246,50 @@ DPCTLSyclEventRef dpnp_any_c(DPCTLSyclQueueRef q_ref,
}

sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
sycl::event event;

DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
const _DataType* array_in = input1_ptr.get_ptr();
_ResultType* result = result1_ptr.get_ptr();
const _DataType* array_in = static_cast<const _DataType*>(array1_in);
bool* result = static_cast<bool*>(result1);

result[0] = false;
auto fill_event = q.fill(result, false, 1);

if (!size)
{
return event_ref;
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&fill_event);
return DPCTLEvent_Copy(event_ref);
}

sycl::range<1> gws(size);
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
size_t i = global_id[0];
constexpr size_t lws = 64;
constexpr size_t vec_sz = 8;

if (array_in[i])
auto gws_range = sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
auto lws_range = sycl::range<1>(lws);
sycl::nd_range<1> gws(gws_range, lws_range);

auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
auto gr = nd_it.get_group();
const auto max_gr_size = gr.get_max_local_range()[0];
const size_t start =
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + gr.get_group_id()[0] * max_gr_size);
const size_t end = sycl::min(start + vec_sz * max_gr_size, size);

// each work-item reduces over "vec_sz" elements in the input array
bool local_reduction = sycl::joint_any_of(
gr, &array_in[start], &array_in[end], [&](_DataType elem) { return elem != static_cast<_DataType>(0); });

if (gr.leader() && (local_reduction == true))
{
result[0] = true;
}
};

auto kernel_func = [&](sycl::handler& cgh) {
cgh.depends_on(fill_event);
cgh.parallel_for<class dpnp_any_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
};

event = q.submit(kernel_func);
auto event = q.submit(kernel_func);

event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

return DPCTLEvent_Copy(event_ref);
}

Expand All @@ -290,6 +304,7 @@ void dpnp_any_c(const void* array1_in, void* result1, const size_t size)
size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType, typename _ResultType>
Expand Down Expand Up @@ -846,6 +861,8 @@ void func_map_init_logic(func_map_t& fmap)
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_any_ext_c<int64_t, bool>};
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_any_ext_c<float, bool>};
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_any_ext_c<double, bool>};
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_C64][eft_C64] = {eft_C64, (void*)dpnp_any_ext_c<std::complex<float>, bool>};
fmap[DPNPFuncName::DPNP_FN_ANY_EXT][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_any_ext_c<std::complex<double>, bool>};

func_map_logic_1arg_1type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL>(fmap);
func_map_logic_2arg_2type_helper<eft_BLN, eft_INT, eft_LNG, eft_FLT, eft_DBL>(fmap);
Expand Down
13 changes: 9 additions & 4 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,19 +358,24 @@ def all(self,

return dpnp.all(self, axis=axis, out=out, keepdims=keepdims, where=where)

def any(self, axis=None, out=None, keepdims=False):
def any(self,
axis=None,
out=None,
keepdims=False,
*,
where=True):
"""
Returns True if any of the elements of `a` evaluate to True.

Refer to `numpy.any` for full documentation.
Refer to :obj:`dpnp.any` for full documentation.

See Also
--------
:obj:`numpy.any` : equivalent function
:obj:`dpnp.any` : equivalent function

"""

return dpnp.any(self, axis, out, keepdims)
return dpnp.any(self, axis=axis, out=out, keepdims=keepdims, where=where)

def argmax(self, axis=None, out=None):
"""
Expand Down
37 changes: 22 additions & 15 deletions dpnp/dpnp_iface_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,13 @@ def allclose(x1, x2, rtol=1.e-5, atol=1.e-8, **kwargs):
return call_origin(numpy.allclose, x1, x2, rtol=rtol, atol=atol, **kwargs)


def any(x1, axis=None, out=None, keepdims=False):
def any(x1,
/,
axis=None,
out=None,
keepdims=False,
*,
where=True):
"""
Test whether any array element along a given axis evaluates to True.

Expand All @@ -181,9 +187,10 @@ def any(x1, axis=None, out=None, keepdims=False):
Input array is supported as :obj:`dpnp.ndarray`.
Otherwise the function will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
Parameter ``axis`` is supported only with default value ``None``.
Parameter ``out`` is supported only with default value ``None``.
Parameter ``keepdims`` is supported only with default value ``False``.
Parameter `axis` is supported only with default value `None`.
Parameter `out` is supported only with default value `None`.
Parameter `keepdims` is supported only with default value `False`.
Parameter `where` is supported only with default value `True`.

See Also
--------
Expand All @@ -196,15 +203,15 @@ def any(x1, axis=None, out=None, keepdims=False):

Examples
--------
>>> import dpnp as np
>>> x = np.array([[True, False], [True, True]])
>>> np.any(x)
>>> import dpnp as dp
>>> x = dp.array([[True, False], [True, True]])
>>> dp.any(x)
True
>>> x2 = np.array([0, 0, 0])
>>> np.any(x2)
>>> x2 = dp.array([0, 0, 0])
>>> dp.any(x2)
False
>>> x3 = np.array([1.0, np.nan])
>>> np.any(x3)
>>> x3 = dp.array([1.0, dp.nan])
>>> dp.any(x3)
True

"""
Expand All @@ -217,13 +224,13 @@ def any(x1, axis=None, out=None, keepdims=False):
pass
elif keepdims is not False:
pass
elif where is not True:
pass
else:
result_obj = dpnp_any(x1_desc).get_pyobj()
result = dpnp.convert_single_elem_array_to_scalar(result_obj)

return result
return dpnp.convert_single_elem_array_to_scalar(result_obj)

return call_origin(numpy.any, x1, axis, out, keepdims)
return call_origin(numpy.any, x1, axis=axis, out=out, keepdims=keepdims, where=where)


def equal(x1,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_allclose(type):
assert_allclose(dpnp_res, np_res)


@pytest.mark.parametrize("type", get_all_dtypes(no_complex=True))
@pytest.mark.parametrize("type", get_all_dtypes())
@pytest.mark.parametrize("shape",
[(0,), (4,), (2, 3), (2, 2, 2)],
ids=['(0,)', '(4,)', '(2,3)', '(2,2,2)'])
Expand Down