Skip to content

Remove mixed host/device implementation from dpnp.all() #1301

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# CMake build and local install directory
build
build_cython
dpnp.egg-info

# Byte-compiled / optimized / DLL files
__pycache__/
Expand All @@ -14,6 +15,9 @@ coverage.xml
# Backup files kept after git merge/rebase
*.orig

# Build examples
example3

*dpnp_backend*
dpnp/**/*.cpython*.so
dpnp/**/*.pyd
Expand Down
43 changes: 30 additions & 13 deletions dpnp/backend/kernels/dpnp_krnl_logic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ DPCTLSyclEventRef dpnp_all_c(DPCTLSyclQueueRef q_ref,
const size_t size,
const DPCTLEventVectorRef dep_event_vec_ref)
{
static_assert(std::is_same_v<_ResultType, bool>, "Boolean result type is required");

// avoid warning unused variable
(void)dep_event_vec_ref;

Expand All @@ -52,38 +54,50 @@ DPCTLSyclEventRef dpnp_all_c(DPCTLSyclQueueRef q_ref,
}

sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));
sycl::event event;

DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, size);
DPNPC_ptr_adapter<_ResultType> result1_ptr(q_ref, result1, 1, true, true);
const _DataType* array_in = input1_ptr.get_ptr();
_ResultType* result = result1_ptr.get_ptr();
const _DataType* array_in = static_cast<const _DataType*>(array1_in);
bool* result = static_cast<bool*>(result1);

result[0] = true;
auto fill_event = q.fill(result, true, 1);

if (!size)
{
return event_ref;
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&fill_event);
return DPCTLEvent_Copy(event_ref);
}

sycl::range<1> gws(size);
auto kernel_parallel_for_func = [=](sycl::id<1> global_id) {
size_t i = global_id[0];
constexpr size_t lws = 64;
constexpr size_t vec_sz = 8;

auto gws_range = sycl::range<1>(((size + lws * vec_sz - 1) / (lws * vec_sz)) * lws);
auto lws_range = sycl::range<1>(lws);
sycl::nd_range<1> gws(gws_range, lws_range);

if (!array_in[i])
auto kernel_parallel_for_func = [=](sycl::nd_item<1> nd_it) {
auto gr = nd_it.get_group();
const auto max_gr_size = gr.get_max_local_range()[0];
const size_t start =
vec_sz * (nd_it.get_group(0) * nd_it.get_local_range(0) + gr.get_group_id()[0] * max_gr_size);
const size_t end = sycl::min(start + vec_sz * max_gr_size, size);

// each work-item reduces over "vec_sz" elements in the input array
bool local_reduction = sycl::joint_none_of(
gr, &array_in[start], &array_in[end], [&](_DataType elem) { return elem == static_cast<_DataType>(0); });

if (gr.leader() && (local_reduction == false))
{
result[0] = false;
}
};

auto kernel_func = [&](sycl::handler& cgh) {
cgh.depends_on(fill_event);
cgh.parallel_for<class dpnp_all_c_kernel<_DataType, _ResultType>>(gws, kernel_parallel_for_func);
};

event = q.submit(kernel_func);
auto event = q.submit(kernel_func);

event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

return DPCTLEvent_Copy(event_ref);
}

Expand All @@ -98,6 +112,7 @@ void dpnp_all_c(const void* array1_in, void* result1, const size_t size)
size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType, typename _ResultType>
Expand Down Expand Up @@ -751,6 +766,8 @@ void func_map_init_logic(func_map_t& fmap)
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_LNG][eft_LNG] = {eft_LNG, (void*)dpnp_all_ext_c<int64_t, bool>};
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_FLT][eft_FLT] = {eft_FLT, (void*)dpnp_all_ext_c<float, bool>};
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_DBL][eft_DBL] = {eft_DBL, (void*)dpnp_all_ext_c<double, bool>};
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_C64][eft_C64] = {eft_C64, (void*)dpnp_all_ext_c<std::complex<float>, bool>};
fmap[DPNPFuncName::DPNP_FN_ALL_EXT][eft_C128][eft_C128] = {eft_C128, (void*)dpnp_all_ext_c<std::complex<double>, bool>};

fmap[DPNPFuncName::DPNP_FN_ALLCLOSE][eft_INT][eft_INT] = {eft_BLN,
(void*)dpnp_allclose_default_c<int32_t, int32_t, bool>};
Expand Down
13 changes: 9 additions & 4 deletions dpnp/dpnp_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,19 +339,24 @@ def _create_from_usm_ndarray(usm_ary : dpt.usm_ndarray):
res._array_obj = usm_ary
return res

def all(self, axis=None, out=None, keepdims=False):
def all(self,
axis=None,
out=None,
keepdims=False,
*,
where=True):
"""
Returns True if all elements evaluate to True.

Refer to `numpy.all` for full documentation.
Refer to :obj:`dpnp.all` for full documentation.

See Also
--------
:obj:`numpy.all` : equivalent function
:obj:`dpnp.all` : equivalent function

"""

return dpnp.all(self, axis, out, keepdims)
return dpnp.all(self, axis=axis, out=out, keepdims=keepdims, where=where)

def any(self, axis=None, out=None, keepdims=False):
"""
Expand Down
37 changes: 22 additions & 15 deletions dpnp/dpnp_iface_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,13 @@
]


def all(x1, axis=None, out=None, keepdims=False):
def all(x1,
/,
axis=None,
out=None,
keepdims=False,
*,
where=True):
"""
Test whether all array elements along a given axis evaluate to True.

Expand All @@ -80,9 +86,10 @@ def all(x1, axis=None, out=None, keepdims=False):
Input array is supported as :obj:`dpnp.ndarray`.
Otherwise the function will be executed sequentially on CPU.
Input array data types are limited by supported DPNP :ref:`Data types`.
Parameter ``axis`` is supported only with default value ``None``.
Parameter ``out`` is supported only with default value ``None``.
Parameter ``keepdims`` is supported only with default value ``False``.
Parameter `axis` is supported only with default value `None`.
Parameter `out` is supported only with default value `None`.
Parameter `keepdims` is supported only with default value `False`.
Parameter `where` is supported only with default value `True`.

See Also
--------
Expand All @@ -95,15 +102,15 @@ def all(x1, axis=None, out=None, keepdims=False):

Examples
--------
>>> import dpnp as np
>>> x = np.array([[True, False], [True, True]])
>>> np.all(x)
>>> import dpnp as dp
>>> x = dp.array([[True, False], [True, True]])
>>> dp.all(x)
False
>>> x2 = np.array([-1, 4, 5])
>>> np.all(x2)
>>> x2 = dp.array([-1, 4, 5])
>>> dp.all(x2)
True
>>> x3 = np.array([1.0, np.nan])
>>> np.all(x3)
>>> x3 = dp.array([1.0, dp.nan])
>>> dp.all(x3)
True

"""
Expand All @@ -116,13 +123,13 @@ def all(x1, axis=None, out=None, keepdims=False):
pass
elif keepdims is not False:
pass
elif where is not True:
pass
else:
result_obj = dpnp_all(x1_desc).get_pyobj()
result = dpnp.convert_single_elem_array_to_scalar(result_obj)

return result
return dpnp.convert_single_elem_array_to_scalar(result_obj)

return call_origin(numpy.all, x1, axis, out, keepdims)
return call_origin(numpy.all, x1, axis=axis, out=out, keepdims=keepdims, where=where)


def allclose(x1, x2, rtol=1.e-5, atol=1.e-8, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)


@pytest.mark.parametrize("type", get_all_dtypes(no_complex=True))
@pytest.mark.parametrize("type", get_all_dtypes())
@pytest.mark.parametrize("shape",
[(0,), (4,), (2, 3), (2, 2, 2)],
ids=['(0,)', '(4,)', '(2,3)', '(2,2,2)'])
Expand Down
1 change: 0 additions & 1 deletion tests/test_usm_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def test_linspace_arrays(usm_type_start, usm_type_stop):
assert res.usm_type == du.get_coerced_usm_type([usm_type_start, usm_type_stop])


@pytest.mark.skip()
@pytest.mark.parametrize("func", ["tril", "triu"], ids=["tril", "triu"])
@pytest.mark.parametrize("usm_type", list_of_usm_types, ids=list_of_usm_types)
def test_tril_triu(func, usm_type):
Expand Down