diff --git a/spec/draft/API_specification/statistical_functions.rst b/spec/draft/API_specification/statistical_functions.rst index eb5e1a5d6..b54bec5df 100644 --- a/spec/draft/API_specification/statistical_functions.rst +++ b/spec/draft/API_specification/statistical_functions.rst @@ -18,6 +18,7 @@ Objects in API :toctree: generated :template: method.rst + bincount cumulative_prod cumulative_sum max diff --git a/src/array_api_stubs/_draft/statistical_functions.py b/src/array_api_stubs/_draft/statistical_functions.py index 55c84950c..e32302904 100644 --- a/src/array_api_stubs/_draft/statistical_functions.py +++ b/src/array_api_stubs/_draft/statistical_functions.py @@ -1,4 +1,5 @@ __all__ = [ + "bincount", "cumulative_sum", "cumulative_prod", "max", @@ -14,6 +15,45 @@ from ._types import Optional, Tuple, Union, array, dtype +def bincount( + x: array, /, weights: Optional[array] = None, *, minlength: int = 0 +) -> array: + """ + Counts the number of occurrences of each element in ``x``. + + .. admonition:: Data-dependent output shape + :class: important + + The shape of the output array for this function depends on the data values in ``x``; hence, array libraries which build computation graphs (e.g., JAX, Dask, etc.) can find this function difficult to implement without knowing the values in ``x``. Accordingly, such libraries **may** choose to omit this function. See :ref:`data-dependent-output-shapes` section for more details. + + Parameters + ---------- + x: array + input array. **Should** be a one-dimensional array. **Must** have an integer data type. + weights: Optional[array] + an array of weights for each element in ``x``. **Must** have the same shape as ``x``. **Must** have a numeric data type. If not provided, each bin in the returned array **must** give the number of occurrences of its index value in ``x``. If provided, each bin in the returned array **must** be a sum of the weights corresponding to the respective index values in ``x`` (i.e., if value ``n`` is found at index ``i`` in ``x``, then ``out[n] += weights[i]``, instead of ``out[n] += 1``). Default: ``None``. + minlength: int + minimum number of bins. **Must** be a nonnegative integer. Default: ``0``. + + Returns + ------- + out: array + an array containing the number of occurrences. Let ``N`` equal ``max(xp.max(x)+1, minlength)``. The returned array **should** have shape ``(N,)``. + + If ``weights`` is not ``None``, the returned array **must** have the same data type as ``weights``. + + If ``weights`` is ``None``, the returned array **must** have the same data type as ``x``, unless ``x`` has an integer data type supporting a smaller range of values than the default integer data type (e.g., ``x`` has an ``int16`` or ``uint32`` data type and the default integer data type is ``int64``). In those latter cases: + + - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array **must** have the default integer data type. + - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array **must** have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array **must** have a ``uint32`` data type). + + Notes + ----- + + - If ``x`` contains negative values, behavior is unspecified and thus implementation-defined. + """ + + def cumulative_prod( x: array, /,