Skip to content

Commit ad90f66

Browse files
authored
implement dpnp.mean (#1632)
* implement dpnp.mean * address comments
1 parent 912bb77 commit ad90f66

7 files changed

+98
-99
lines changed

.github/workflows/conda-package.yml

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ env:
2525
test_random_state.py
2626
test_sort.py
2727
test_special.py
28+
test_statistics.py
2829
test_sycl_queue.py
2930
test_umath.py
3031
test_usm_type.py
@@ -47,6 +48,7 @@ env:
4748
third_party/cupy/math_tests/test_trigonometric.py
4849
third_party/cupy/sorting_tests/test_sort.py
4950
third_party/cupy/sorting_tests/test_count.py
51+
third_party/cupy/statistics_tests/test_meanvar.py
5052
VER_JSON_NAME: 'version.json'
5153
VER_SCRIPT1: "import json; f = open('version.json', 'r'); j = json.load(f); f.close(); "
5254
VER_SCRIPT2: "d = j['dpnp'][0]; print('='.join((d[s] for s in ('version', 'build'))))"

dpnp/dpnp_array.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -915,10 +915,16 @@ def max(
915915

916916
return dpnp.max(self, axis, out, keepdims, initial, where)
917917

918-
def mean(self, axis=None, **kwargs):
919-
"""Returns the average of the array elements."""
918+
def mean(
919+
self, axis=None, dtype=None, out=None, keepdims=False, *, where=True
920+
):
921+
"""
922+
Returns the average of the array elements.
923+
924+
Refer to :obj:`dpnp.mean` for full documentation.
925+
"""
920926

921-
return dpnp.mean(self, axis=axis, **kwargs)
927+
return dpnp.mean(self, axis, dtype, out, keepdims, where=where)
922928

923929
def min(
924930
self,

dpnp/dpnp_iface_statistics.py

+17-56
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040

4141
import dpctl.tensor as dpt
4242
import numpy
43-
from numpy.core.numeric import normalize_axis_tuple
4443

4544
import dpnp
4645
from dpnp.dpnp_algo import *
@@ -417,24 +416,24 @@ def max(a, axis=None, out=None, keepdims=False, initial=None, where=True):
417416
return dpnp.get_result_array(result, out)
418417

419418

420-
def mean(x, /, *, axis=None, dtype=None, keepdims=False, out=None, where=True):
419+
def mean(a, /, axis=None, dtype=None, out=None, keepdims=False, *, where=True):
421420
"""
422421
Compute the arithmetic mean along the specified axis.
423422
424423
For full documentation refer to :obj:`numpy.mean`.
425424
426425
Returns
427426
-------
428-
y : dpnp.ndarray
427+
out : dpnp.ndarray
429428
an array containing the mean values of the elements along the specified axis(axes).
430-
If the input array is empty, an array containing a single NaN value is returned.
429+
If the input is a zero-size array, an array containing NaN values is returned.
431430
432431
Limitations
433432
-----------
434-
Parameters `x` is supported as either :class:`dpnp.ndarray`
433+
Parameters `a` is supported as either :class:`dpnp.ndarray`
435434
or :class:`dpctl.tensor.usm_ndarray`.
436-
Parameters `keepdims`, `out` and `where` are supported with their default values.
437-
Otherwise the function will be executed sequentially on CPU.
435+
Parameter `where` is supported only with their default values.
436+
Otherwise ``NotImplementedError`` exception will be raised.
438437
Input array data types are limited by supported DPNP :ref:`Data types`.
439438
440439
See Also
@@ -459,59 +458,21 @@ def mean(x, /, *, axis=None, dtype=None, keepdims=False, out=None, where=True):
459458
array([2., 3.])
460459
>>> np.mean(a, axis=1)
461460
array([1.5, 3.5])
461+
462462
"""
463463

464-
if keepdims is not False:
465-
pass
466-
elif out is not None:
467-
pass
468-
elif where is not True:
469-
pass
464+
if where is not True:
465+
raise NotImplementedError(
466+
"where keyword argument is only supported by its default value."
467+
)
470468
else:
471-
if dtype is None and dpnp.issubdtype(x.dtype, dpnp.inexact):
472-
dtype = x.dtype
473-
474-
if axis is None:
475-
if x.size == 0:
476-
return dpnp.array(dpnp.nan, dtype=dtype)
477-
else:
478-
result = dpnp.sum(x, dtype=dtype) / x.size
479-
return result.astype(dtype) if result.dtype != dtype else result
480-
481-
if not isinstance(axis, (tuple, list)):
482-
axis = (axis,)
483-
484-
axis = normalize_axis_tuple(axis, x.ndim, "axis")
485-
res_sum = dpnp.sum(x, axis=axis, dtype=dtype)
486-
487-
del_ = 1.0
488-
for axis_value in axis:
489-
del_ *= x.shape[axis_value]
490-
491-
# performing an inplace operation on arrays of bool or integer types
492-
# is not possible due to incompatible data types because
493-
# it returns a floating value
494-
if dpnp.issubdtype(res_sum.dtype, dpnp.inexact):
495-
res_sum /= del_
496-
else:
497-
new_res_sum = res_sum / del_
498-
return (
499-
new_res_sum.astype(dtype)
500-
if new_res_sum.dtype != dtype
501-
else new_res_sum
502-
)
503-
504-
return res_sum.astype(dtype) if res_sum.dtype != dtype else res_sum
469+
dpt_array = dpnp.get_usm_ndarray(a)
470+
result = dpnp_array._create_from_usm_ndarray(
471+
dpt.mean(dpt_array, axis=axis, keepdims=keepdims)
472+
)
473+
result = result.astype(dtype) if dtype is not None else result
505474

506-
return call_origin(
507-
numpy.mean,
508-
x,
509-
axis=axis,
510-
dtype=dtype,
511-
out=out,
512-
keepdims=keepdims,
513-
where=where,
514-
)
475+
return dpnp.get_result_array(result, out)
515476

516477

517478
def median(x1, axis=None, out=None, overwrite_input=False, keepdims=False):

tests/test_mathematical.py

-39
Original file line numberDiff line numberDiff line change
@@ -1993,45 +1993,6 @@ def test_sum(shape, dtype_in, dtype_out, transpose, keepdims, order):
19931993
assert_array_equal(numpy_res, dpnp_res.asnumpy())
19941994

19951995

1996-
class TestMean:
1997-
@pytest.mark.parametrize("dtype", get_all_dtypes())
1998-
def test_mean_axis_tuple(self, dtype):
1999-
dp_array = dpnp.array([[0, 1, 2], [3, 4, 0]], dtype=dtype)
2000-
np_array = dpnp.asnumpy(dp_array)
2001-
2002-
result = dpnp.mean(dp_array, axis=(0, 1))
2003-
expected = numpy.mean(np_array, axis=(0, 1))
2004-
assert_allclose(expected, result)
2005-
2006-
def test_mean_axis_zero_size(self):
2007-
dp_array = dpnp.array([], dtype="int64")
2008-
np_array = dpnp.asnumpy(dp_array)
2009-
2010-
result = dpnp.mean(dp_array)
2011-
expected = numpy.mean(np_array)
2012-
assert_allclose(expected, result)
2013-
2014-
def test_mean_strided(self):
2015-
dp_array = dpnp.array([-2, -1, 0, 1, 0, 2], dtype="f4")
2016-
np_array = dpnp.asnumpy(dp_array)
2017-
2018-
result = dpnp.mean(dp_array[::-1])
2019-
expected = numpy.mean(np_array[::-1])
2020-
assert_allclose(expected, result)
2021-
2022-
result = dpnp.mean(dp_array[::2])
2023-
expected = numpy.mean(np_array[::2])
2024-
assert_allclose(expected, result)
2025-
2026-
def test_mean_scalar(self):
2027-
dp_array = dpnp.array(5)
2028-
np_array = dpnp.asnumpy(dp_array)
2029-
2030-
result = dp_array.mean()
2031-
expected = np_array.mean()
2032-
assert_allclose(expected, result)
2033-
2034-
20351996
@pytest.mark.parametrize(
20361997
"dtype", get_all_dtypes(no_bool=True, no_none=True, no_complex=True)
20371998
)

tests/test_statistics.py

+68-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import dpnp
1010

11-
from .helper import get_all_dtypes
11+
from .helper import assert_dtype_allclose, get_all_dtypes
1212

1313

1414
@pytest.mark.parametrize(
@@ -88,6 +88,73 @@ def test_max_min_NotImplemented(func):
8888
getattr(dpnp, func)(ia, initial=6)
8989

9090

91+
class TestMean:
92+
@pytest.mark.parametrize("dtype", get_all_dtypes())
93+
def test_mean_axis_tuple(self, dtype):
94+
dp_array = dpnp.array([[0, 1, 2], [3, 4, 0]], dtype=dtype)
95+
np_array = dpnp.asnumpy(dp_array)
96+
97+
result = dpnp.mean(dp_array, axis=(0, 1))
98+
expected = numpy.mean(np_array, axis=(0, 1))
99+
assert_allclose(expected, result)
100+
101+
@pytest.mark.parametrize("dtype", get_all_dtypes())
102+
@pytest.mark.parametrize("axis", [0, 1, (0, 1)])
103+
def test_mean_out(self, dtype, axis):
104+
dp_array = dpnp.array([[0, 1, 2], [3, 4, 0]], dtype=dtype)
105+
np_array = dpnp.asnumpy(dp_array)
106+
107+
expected = numpy.mean(np_array, axis=axis)
108+
result = dpnp.empty_like(dpnp.asarray(expected))
109+
dpnp.mean(dp_array, axis=axis, out=result)
110+
assert_dtype_allclose(result, expected)
111+
112+
@pytest.mark.parametrize("dtype", get_all_dtypes())
113+
def test_mean_dtype(self, dtype):
114+
dp_array = dpnp.array([[0, 1, 2], [3, 4, 0]], dtype="i4")
115+
np_array = dpnp.asnumpy(dp_array)
116+
117+
expected = numpy.mean(np_array, dtype=dtype)
118+
result = dpnp.mean(dp_array, dtype=dtype)
119+
assert_allclose(expected, result)
120+
121+
@pytest.mark.usefixtures("suppress_invalid_numpy_warnings")
122+
@pytest.mark.parametrize("axis", [0, 1, (0, 1)])
123+
@pytest.mark.parametrize("shape", [(2, 3), (2, 0), (0, 3)])
124+
def test_mean_empty(self, axis, shape):
125+
dp_array = dpnp.empty(shape, dtype=dpnp.int64)
126+
np_array = dpnp.asnumpy(dp_array)
127+
128+
result = dpnp.mean(dp_array, axis=axis)
129+
expected = numpy.mean(np_array, axis=axis)
130+
assert_allclose(expected, result)
131+
132+
def test_mean_strided(self):
133+
dp_array = dpnp.array([-2, -1, 0, 1, 0, 2], dtype="f4")
134+
np_array = dpnp.asnumpy(dp_array)
135+
136+
result = dpnp.mean(dp_array[::-1])
137+
expected = numpy.mean(np_array[::-1])
138+
assert_allclose(expected, result)
139+
140+
result = dpnp.mean(dp_array[::2])
141+
expected = numpy.mean(np_array[::2])
142+
assert_allclose(expected, result)
143+
144+
def test_mean_scalar(self):
145+
dp_array = dpnp.array(5)
146+
np_array = dpnp.asnumpy(dp_array)
147+
148+
result = dp_array.mean()
149+
expected = np_array.mean()
150+
assert_allclose(expected, result)
151+
152+
def test_mean_NotImplemented(func):
153+
ia = dpnp.arange(5)
154+
with pytest.raises(NotImplementedError):
155+
dpnp.mean(ia, where=False)
156+
157+
91158
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
92159
@pytest.mark.parametrize(
93160
"array",

tests/test_sycl_queue.py

+1
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ def test_meshgrid(device_x, device_y):
367367
pytest.param("log1p", [1.0e-10, 1.0, 2.0, 4.0, 7.0]),
368368
pytest.param("log2", [1.0, 2.0, 4.0, 7.0]),
369369
pytest.param("max", [1.0, 2.0, 4.0, 7.0]),
370+
pytest.param("mean", [1.0, 2.0, 4.0, 7.0]),
370371
pytest.param("min", [1.0, 2.0, 4.0, 7.0]),
371372
pytest.param("nancumprod", [1.0, dpnp.nan]),
372373
pytest.param("nancumsum", [1.0, dpnp.nan]),

tests/test_usm_type.py

+1
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ def test_meshgrid(usm_type_x, usm_type_y):
395395
pytest.param("log2", [1.0, 2.0, 4.0, 7.0]),
396396
pytest.param("nanprod", [1.0, 2.0, dp.nan]),
397397
pytest.param("max", [1.0, 2.0, 4.0, 7.0]),
398+
pytest.param("mean", [1.0, 2.0, 4.0, 7.0]),
398399
pytest.param("min", [1.0, 2.0, 4.0, 7.0]),
399400
pytest.param("negative", [1.0, 0.0, -1.0]),
400401
pytest.param("positive", [1.0, 0.0, -1.0]),

0 commit comments

Comments
 (0)