diff --git a/CHANGELOG.md b/CHANGELOG.md index bff849e27dfb..7c2851b87e25 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ This release achieves 100% compliance with Python Array API specification (revis * Added MKL functions `arg`, `copysign`, `i0`, and `inv` from VM namespace to be used by implementation of the appropriate element-wise functions [#2445](https://github.com/IntelPython/dpnp/pull/2445) * Clarified details about conda install instructions in `Quick start quide` and `README` [#2446](https://github.com/IntelPython/dpnp/pull/2446) * Bumped oneMKL version up to `0.7` [#2448](https://github.com/IntelPython/dpnp/pull/2448) +* The parameter `axis` in `dpnp.take_along_axis` function has now a default value of `-1` [#2442](https://github.com/IntelPython/dpnp/pull/2442) ### Fixed diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index 5394601b9e09..b2b07aea49aa 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -2205,7 +2205,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"): return dpnp.get_result_array(usm_res, out=out) -def take_along_axis(a, indices, axis, mode="wrap"): +def take_along_axis(a, indices, axis=-1, mode="wrap"): """ Take values from the input array by matching 1d index and data slices. @@ -2227,9 +2227,12 @@ def take_along_axis(a, indices, axis, mode="wrap"): dimension of the input array, but dimensions ``Ni`` and ``Nj`` only need to broadcast against `a`. axis : {None, int} - The axis to take 1d slices along. If axis is ``None``, the input - array is treated as if it had first been flattened to 1d, - for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`. + The axis to take 1d slices along. If axis is ``None``, the input array + is treated as if it had first been flattened to 1d. The default is + ``-1``, which takes 1d slices along the last axis. These behaviors are + consistent with :obj:`dpnp.sort` and :obj:`dpnp.argsort`. + + Default: ``-1``. mode : {"wrap", "clip"}, optional Specifies how out-of-bounds indices will be handled. Possible values are: @@ -2274,8 +2277,8 @@ def take_along_axis(a, indices, axis, mode="wrap"): array([[10, 20, 30], [40, 50, 60]]) - The same works for max and min, if you maintain the trivial dimension - with ``keepdims``: + The same works for :obj:`dpnp.max` and :obj:`dpnp.min`, if you maintain + the trivial dimension with ``keepdims``: >>> np.max(a, axis=1, keepdims=True) array([[30], diff --git a/dpnp/tests/test_indexing.py b/dpnp/tests/test_indexing.py index 5439b4dc484e..619cc6137ea6 100644 --- a/dpnp/tests/test_indexing.py +++ b/dpnp/tests/test_indexing.py @@ -804,7 +804,7 @@ def test_argequivalent(self, func, argfunc, kwargs): # a = dpnp.random.random(size=(3, 4, 5)) a = dpnp.asarray(numpy.random.random(size=(3, 4, 5))) - for axis in list(range(a.ndim)) + [None]: + for axis in list(range(a.ndim)) + [None, -1]: a_func = func(a, axis=axis, **kwargs) ai_func = argfunc(a, axis=axis, **kwargs) assert_array_equal(