Skip to content

Commit 7c500ec

Browse files
committed
Adds docstrings for statistical functions
1 parent b67bcea commit 7c500ec

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

dpctl/tensor/_statistical_functions.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,35 @@ def _var_impl(x, axis, correction, keepdims):
162162

163163

164164
def mean(x, axis=None, keepdims=False):
165+
"""mean(x, axis=None, keepdims=False)
166+
167+
Calculates the arithmetic mean of elements in the input array `x`.
168+
169+
Args:
170+
x (usm_ndarray):
171+
input array.
172+
axis (Optional[int, Tuple[int, ...]]):
173+
axis or axes along which the arithmetic means must be computed. If
174+
a tuple of unique integers, the means are computed over multiple
175+
axes. If `None`, the mean is computed over the entire array.
176+
Default: `None`.
177+
keepdims (Optional[bool]):
178+
if `True`, the reduced axes (dimensions) are included in the result
179+
as singleton dimensions, so that the returned array remains
180+
compatible with the input array according to Array Broadcasting
181+
rules. Otherwise, if `False`, the reduced axes are not included in
182+
the returned array. Default: `False`.
183+
Returns:
184+
usm_ndarray:
185+
an array containing the arithmetic means. If the mean was computed
186+
over the entire array, a zero-dimensional array is returned.
187+
188+
If `x` has a floating-point data type, the returned array will have
189+
the same data type as `x`.
190+
If `x` has a boolean or integral data type, the returned array
191+
will have the default floating point data type for the device
192+
where input array `x` is allocated.
193+
"""
165194
if not isinstance(x, dpt.usm_ndarray):
166195
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
167196
nd = x.ndim
@@ -242,6 +271,40 @@ def mean(x, axis=None, keepdims=False):
242271

243272

244273
def var(x, axis=None, correction=0.0, keepdims=False):
274+
"""var(x, axis=None, correction=0.0, keepdims=False)
275+
276+
Calculates the variance of elements in the input array `x`.
277+
278+
Args:
279+
x (usm_ndarray):
280+
input array.
281+
axis (Optional[int, Tuple[int, ...]]):
282+
axis or axes along which the variances must be computed. If a tuple
283+
of unique integers, the variances are computed over multiple axes.
284+
If `None`, the variance is computed over the entire array.
285+
Default: `None`.
286+
correction (Optional[float, int]):
287+
degrees of freedom adjustment. The divisor used in calculating the
288+
variance is `N-correction`, where `N` corresponds to the total
289+
number of elements over which the variance is calculated.
290+
Default: `0.0`.
291+
keepdims (Optional[bool]):
292+
if `True`, the reduced axes (dimensions) are included in the result
293+
as singleton dimensions, so that the returned array remains
294+
compatible with the input array according to Array Broadcasting
295+
rules. Otherwise, if `False`, the reduced axes are not included in
296+
the returned array. Default: `False`.
297+
Returns:
298+
usm_ndarray:
299+
an array containing the variances. If the variance was computed
300+
over the entire array, a zero-dimensional array is returned.
301+
302+
If `x` has a real-valued floating-point data type, the returned
303+
array will have the same data type as `x`.
304+
If `x` has a boolean or integral data type, the returned array
305+
will have the default floating point data type for the device
306+
where input array `x` is allocated.
307+
"""
245308
if not isinstance(x, dpt.usm_ndarray):
246309
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
247310

@@ -260,6 +323,41 @@ def var(x, axis=None, correction=0.0, keepdims=False):
260323

261324

262325
def std(x, axis=None, correction=0.0, keepdims=False):
326+
"""std(x, axis=None, correction=0.0, keepdims=False)
327+
328+
Calculates the standard deviation of elements in the input array `x`.
329+
330+
Args:
331+
x (usm_ndarray):
332+
input array.
333+
axis (Optional[int, Tuple[int, ...]]):
334+
axis or axes along which the standard deviations must be computed.
335+
If a tuple of unique integers, the standard deviations are computed
336+
over multiple axes. If `None`, the standard deviation is computed
337+
over the entire array. Default: `None`.
338+
correction (Optional[float, int]):
339+
degrees of freedom adjustment. The divisor used in calculating the
340+
standard deviation is `N-correction`, where `N` corresponds to the
341+
total number of elements over which the standard deviation is
342+
calculated. Default: `0.0`.
343+
keepdims (Optional[bool]):
344+
if `True`, the reduced axes (dimensions) are included in the result
345+
as singleton dimensions, so that the returned array remains
346+
compatible with the input array according to Array Broadcasting
347+
rules. Otherwise, if `False`, the reduced axes are not included in
348+
the returned array. Default: `False`.
349+
Returns:
350+
usm_ndarray:
351+
an array containing the standard deviations. If the standard
352+
deviation was computed over the entire array, a zero-dimensional
353+
array is returned.
354+
355+
If `x` has a real-valued floating-point data type, the returned
356+
array will have the same data type as `x`.
357+
If `x` has a boolean or integral data type, the returned array
358+
will have the default floating point data type for the device
359+
where input array `x` is allocated.
360+
"""
263361
if not isinstance(x, dpt.usm_ndarray):
264362
raise TypeError(f"Expected dpctl.tensor.usm_ndarray, got {type(x)}")
265363

0 commit comments

Comments
 (0)