Skip to content

Commit 9afb742

Browse files
authored
Implements statistical functions mean, std, var (#1465)
* Resolves gh-1456.
  Tree reductions now populate the destination with the identity when reducing over zero-size axes. As a result, the logic for handling zero-size axes was removed. ``argmax``, ``argmin``, ``max``, and ``min`` still raise an error for zero-size axes. Reductions now return a copy when provided an empty axis tuple. Adds additional supported dtype combinations to ``prod`` and ``sum``, specifically for integer inputs with an inexact output type.
* Implements mean, var, and std.
* Adds more tests for statistical functions.
* Adds docstrings for statistical functions.
* Adds more supported types to arithmetic reductions.
  Permits the `float` accumulation type with 64-bit integer and unsigned integer inputs to prevent unnecessary copies on devices that don't support double precision.
* Changes mean reduction to use the output data type as the sum accumulation type.
  Mean in-place division now uses the real type for the denominator.
1 parent f2af753 commit 9afb742

File tree

6 files changed

+797
-28
lines changed

6 files changed

+797
-28
lines changed

dpctl/tensor/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@
9090
)
9191
from dpctl.tensor._reshape import reshape
9292
from dpctl.tensor._search_functions import where
93+
from dpctl.tensor._statistical_functions import mean, std, var
9394
from dpctl.tensor._usmarray import usm_ndarray
9495
from dpctl.tensor._utility_functions import all, any
9596

@@ -336,6 +337,9 @@
336337
"clip",
337338
"logsumexp",
338339
"reduce_hypot",
340+
"mean",
341+
"std",
342+
"var",
339343
"__array_api_version__",
340344
"__array_namespace_info__",
341345
]

dpctl/tensor/_reduction.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def _reduction_over_axis(
8383
_reduction_fn,
8484
_dtype_supported,
8585
_default_reduction_type_fn,
86-
_identity=None,
8786
):
8887
if not isinstance(x, dpt.usm_ndarray):
8988
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
@@ -106,23 +105,8 @@ def _reduction_over_axis(
106105
res_dt = _to_device_supported_dtype(res_dt, q.sycl_device)
107106

108107
res_usm_type = x.usm_type
109-
if x.size == 0:
110-
if _identity is None:
111-
raise ValueError("reduction does not support zero-size arrays")
112-
else:
113-
if keepdims:
114-
res_shape = res_shape + (1,) * red_nd
115-
inv_perm = sorted(range(nd), key=lambda d: perm[d])
116-
res_shape = tuple(res_shape[i] for i in inv_perm)
117-
return dpt.full(
118-
res_shape,
119-
_identity,
120-
dtype=res_dt,
121-
usm_type=res_usm_type,
122-
sycl_queue=q,
123-
)
124108
if red_nd == 0:
125-
return dpt.astype(x, res_dt, copy=False)
109+
return dpt.astype(x, res_dt, copy=True)
126110

127111
host_tasks_list = []
128112
if _dtype_supported(inp_dt, res_dt, res_usm_type, q):
@@ -251,7 +235,6 @@ def sum(x, axis=None, dtype=None, keepdims=False):
251235
tri._sum_over_axis,
252236
tri._sum_over_axis_dtype_supported,
253237
_default_reduction_dtype,
254-
_identity=0,
255238
)
256239

257240

@@ -312,7 +295,6 @@ def prod(x, axis=None, dtype=None, keepdims=False):
312295
tri._prod_over_axis,
313296
tri._prod_over_axis_dtype_supported,
314297
_default_reduction_dtype,
315-
_identity=1,
316298
)
317299

318300

@@ -368,7 +350,6 @@ def logsumexp(x, axis=None, dtype=None, keepdims=False):
368350
inp_dt, res_dt
369351
),
370352
_default_reduction_dtype_fp_types,
371-
_identity=-dpt.inf,
372353
)
373354

374355

@@ -424,7 +405,6 @@ def reduce_hypot(x, axis=None, dtype=None, keepdims=False):
424405
inp_dt, res_dt
425406
),
426407
_default_reduction_dtype_fp_types,
427-
_identity=0,
428408
)
429409

430410

@@ -446,9 +426,19 @@ def _comparison_over_axis(x, axis, keepdims, _reduction_fn):
446426
res_dt = x.dtype
447427
res_usm_type = x.usm_type
448428
if x.size == 0:
449-
raise ValueError("reduction does not support zero-size arrays")
429+
if any([x.shape[i] == 0 for i in axis]):
430+
raise ValueError(
431+
"reduction cannot be performed over zero-size axes"
432+
)
433+
else:
434+
return dpt.empty(
435+
res_shape,
436+
dtype=res_dt,
437+
usm_type=res_usm_type,
438+
sycl_queue=exec_q,
439+
)
450440
if red_nd == 0:
451-
return x
441+
return dpt.copy(x)
452442

453443
res = dpt.empty(
454444
res_shape,
@@ -549,7 +539,17 @@ def _search_over_axis(x, axis, keepdims, _reduction_fn):
549539
res_dt = ti.default_device_index_type(exec_q.sycl_device)
550540
res_usm_type = x.usm_type
551541
if x.size == 0:
552-
raise ValueError("reduction does not support zero-size arrays")
542+
if any([x.shape[i] == 0 for i in axis]):
543+
raise ValueError(
544+
"reduction cannot be performed over zero-size axes"
545+
)
546+
else:
547+
return dpt.empty(
548+
res_shape,
549+
dtype=res_dt,
550+
usm_type=res_usm_type,
551+
sycl_queue=exec_q,
552+
)
553553
if red_nd == 0:
554554
return dpt.zeros(
555555
res_shape, dtype=res_dt, usm_type=res_usm_type, sycl_queue=exec_q

0 commit comments

Comments
 (0)