Skip to content

Commit a804dd7

Browse files
authored
set a default value for axis parameter in dpnp.take_along_axis (#2442)
Set a default value of `-1` for axis parameter in `dpnp.take_along_axis`
1 parent a79d1ca commit a804dd7

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ This release achieves 100% compliance with Python Array API specification (revis
3333
* Added MKL functions `arg`, `copysign`, `i0`, and `inv` from VM namespace to be used by implementation of the appropriate element-wise functions [#2445](https://github.com/IntelPython/dpnp/pull/2445)
3434
* Clarified details about conda install instructions in `Quick start quide` and `README` [#2446](https://github.com/IntelPython/dpnp/pull/2446)
3535
* Bumped oneMKL version up to `0.7` [#2448](https://github.com/IntelPython/dpnp/pull/2448)
36+
* The parameter `axis` in `dpnp.take_along_axis` function has now a default value of `-1` [#2442](https://github.com/IntelPython/dpnp/pull/2442)
3637

3738
### Fixed
3839

dpnp/dpnp_iface_indexing.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,7 +2205,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
22052205
return dpnp.get_result_array(usm_res, out=out)
22062206

22072207

2208-
def take_along_axis(a, indices, axis, mode="wrap"):
2208+
def take_along_axis(a, indices, axis=-1, mode="wrap"):
22092209
"""
22102210
Take values from the input array by matching 1d index and data slices.
22112211
@@ -2227,9 +2227,12 @@ def take_along_axis(a, indices, axis, mode="wrap"):
22272227
dimension of the input array, but dimensions ``Ni`` and ``Nj``
22282228
only need to broadcast against `a`.
22292229
axis : {None, int}
2230-
The axis to take 1d slices along. If axis is ``None``, the input
2231-
array is treated as if it had first been flattened to 1d,
2232-
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
2230+
The axis to take 1d slices along. If axis is ``None``, the input array
2231+
is treated as if it had first been flattened to 1d. The default is
2232+
``-1``, which takes 1d slices along the last axis. These behaviors are
2233+
consistent with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
2234+
2235+
Default: ``-1``.
22332236
mode : {"wrap", "clip"}, optional
22342237
Specifies how out-of-bounds indices will be handled. Possible values
22352238
are:
@@ -2274,8 +2277,8 @@ def take_along_axis(a, indices, axis, mode="wrap"):
22742277
array([[10, 20, 30],
22752278
[40, 50, 60]])
22762279
2277-
The same works for max and min, if you maintain the trivial dimension
2278-
with ``keepdims``:
2280+
The same works for :obj:`dpnp.max` and :obj:`dpnp.min`, if you maintain
2281+
the trivial dimension with ``keepdims``:
22792282
22802283
>>> np.max(a, axis=1, keepdims=True)
22812284
array([[30],

dpnp/tests/test_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ def test_argequivalent(self, func, argfunc, kwargs):
804804
# a = dpnp.random.random(size=(3, 4, 5))
805805
a = dpnp.asarray(numpy.random.random(size=(3, 4, 5)))
806806

807-
for axis in list(range(a.ndim)) + [None]:
807+
for axis in list(range(a.ndim)) + [None, -1]:
808808
a_func = func(a, axis=axis, **kwargs)
809809
ai_func = argfunc(a, axis=axis, **kwargs)
810810
assert_array_equal(

0 commit comments

Comments
 (0)