Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion dpnp/dpnp_iface.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def get_result_array(a, out=None, casting="safe"):

Parameters
----------
a : {dpnp_array}
a : {dpnp.ndarray, usm_ndarray}
Input array.
out : {dpnp.ndarray, usm_ndarray}
If provided, value of `a` array will be copied into it
Expand All @@ -671,6 +671,8 @@ def get_result_array(a, out=None, casting="safe"):
"""

if out is None:
if isinstance(a, dpt.usm_ndarray):
return dpnp_array._create_from_usm_ndarray(a)
return a

if isinstance(out, dpt.usm_ndarray):
Expand Down
43 changes: 29 additions & 14 deletions dpnp/dpnp_iface_counting.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,38 @@
__all__ = ["count_nonzero"]


def count_nonzero(a, axis=None, *, keepdims=False):
def count_nonzero(a, axis=None, *, keepdims=False, out=None):
"""
Counts the number of non-zero values in the array `a`.

For full documentation refer to :obj:`numpy.count_nonzero`.

Parameters
----------
a : {dpnp.ndarray, usm_ndarray}
The array for which to count non-zeros.
axis : {None, int, tuple}, optional
Axis or tuple of axes along which to count non-zeros.
Default value means that non-zeros will be counted along a flattened
version of `a`.
Default: ``None``.
keepdims : bool, optional
If this is set to ``True``, the axes that are counted are left in the
result as dimensions with size one. With this option, the result will
broadcast correctly against the input array.
Default: ``False``.
out : {None, dpnp.ndarray, usm_ndarray}, optional
The array into which the result is written. The data type of `out` must
match the expected shape and the expected data type of the result.
If ``None`` then a new array is returned.
Default: ``None``.

Returns
-------
out : dpnp.ndarray
Number of non-zero values in the array along a given axis.
Otherwise, a zero-dimensional array with the total number of
non-zero values in the array is returned.

Limitations
-----------
Parameters `a` is supported as either :class:`dpnp.ndarray`
or :class:`dpctl.tensor.usm_ndarray`.
Otherwise ``TypeError`` exception will be raised.
Input array data types are limited by supported DPNP :ref:`Data types`.
Otherwise, a zero-dimensional array with the total number of non-zero
values in the array is returned.

See Also
--------
Expand All @@ -87,8 +100,10 @@ def count_nonzero(a, axis=None, *, keepdims=False):

"""

# TODO: might be improved by implementing an extension
# with `count_nonzero` kernel
usm_a = dpnp.get_usm_ndarray(a)
usm_a = dpt.astype(usm_a, dpnp.bool, copy=False)
return dpnp.sum(usm_a, axis=axis, dtype=dpnp.intp, keepdims=keepdims)
usm_out = None if out is None else dpnp.get_usm_ndarray(out)

usm_res = dpt.count_nonzero(
usm_a, axis=axis, keepdims=keepdims, out=usm_out
)
return dpnp.get_result_array(usm_res, out)
21 changes: 8 additions & 13 deletions dpnp/dpnp_iface_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@

import dpnp
from dpnp.dpnp_algo.dpnp_elementwise_common import DPNPBinaryFunc, DPNPUnaryFunc
from dpnp.dpnp_array import dpnp_array

__all__ = [
"all",
Expand Down Expand Up @@ -167,13 +166,11 @@ def all(a, /, axis=None, out=None, keepdims=False, *, where=True):

dpnp.check_limitations(where=where)

dpt_array = dpnp.get_usm_ndarray(a)
result = dpnp_array._create_from_usm_ndarray(
dpt.all(dpt_array, axis=axis, keepdims=keepdims)
)
usm_a = dpnp.get_usm_ndarray(a)
usm_res = dpt.all(usm_a, axis=axis, keepdims=keepdims)

# TODO: temporary solution until dpt.all supports out parameter
result = dpnp.get_result_array(result, out)
return result
return dpnp.get_result_array(usm_res, out)


def allclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
Expand Down Expand Up @@ -333,13 +330,11 @@ def any(a, /, axis=None, out=None, keepdims=False, *, where=True):

dpnp.check_limitations(where=where)

dpt_array = dpnp.get_usm_ndarray(a)
result = dpnp_array._create_from_usm_ndarray(
dpt.any(dpt_array, axis=axis, keepdims=keepdims)
)
usm_a = dpnp.get_usm_ndarray(a)
usm_res = dpt.any(usm_a, axis=axis, keepdims=keepdims)

# TODO: temporary solution until dpt.any supports out parameter
result = dpnp.get_result_array(result, out)
return result
return dpnp.get_result_array(usm_res, out)


_EQUAL_DOCSTRING = """
Expand Down
6 changes: 2 additions & 4 deletions dpnp/dpnp_iface_searching.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,5 @@ def where(condition, x=None, y=None, /, *, order="K", out=None):
usm_condition = dpnp.get_usm_ndarray(condition)

usm_out = None if out is None else dpnp.get_usm_ndarray(out)
result = dpnp_array._create_from_usm_ndarray(
dpt.where(usm_condition, usm_x, usm_y, order=order, out=usm_out)
)
return dpnp.get_result_array(result, out)
usm_res = dpt.where(usm_condition, usm_x, usm_y, order=order, out=usm_out)
return dpnp.get_result_array(usm_res, out)
16 changes: 6 additions & 10 deletions dpnp/dpnp_iface_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,11 +904,9 @@ def std(
)
dpnp.sqrt(result, out=result)
else:
dpt_array = dpnp.get_usm_ndarray(a)
result = dpnp_array._create_from_usm_ndarray(
dpt.std(dpt_array, axis=axis, correction=ddof, keepdims=keepdims)
)
result = dpnp.get_result_array(result, out)
usm_a = dpnp.get_usm_ndarray(a)
usm_res = dpt.std(usm_a, axis=axis, correction=ddof, keepdims=keepdims)
result = dpnp.get_result_array(usm_res, out)

if dtype is not None and out is None:
result = result.astype(dtype, casting="same_kind")
Expand Down Expand Up @@ -1028,11 +1026,9 @@ def var(

dpnp.divide(result, cnt, out=result)
else:
dpt_array = dpnp.get_usm_ndarray(a)
result = dpnp_array._create_from_usm_ndarray(
dpt.var(dpt_array, axis=axis, correction=ddof, keepdims=keepdims)
)
result = dpnp.get_result_array(result, out)
usm_a = dpnp.get_usm_ndarray(a)
usm_res = dpt.var(usm_a, axis=axis, correction=ddof, keepdims=keepdims)
result = dpnp.get_result_array(usm_res, out)

if out is None and dtype is not None:
result = result.astype(dtype, casting="same_kind")
Expand Down
4 changes: 1 addition & 3 deletions dpnp/dpnp_utils/dpnp_utils_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@


import dpnp
from dpnp.dpnp_array import dpnp_array

__all__ = ["dpnp_wrap_reduction_call"]

Expand Down Expand Up @@ -53,5 +52,4 @@ def dpnp_wrap_reduction_call(

kwargs["out"] = usm_out
res_usm = _reduction_fn(*args, **kwargs)
res = dpnp_array._create_from_usm_ndarray(res_usm)
return dpnp.get_result_array(res, input_out, casting="unsafe")
return dpnp.get_result_array(res_usm, input_out, casting="unsafe")