Skip to content

set a default value for axis parameter in dpnp.take_along_axis #2442

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ This release achieves 100% compliance with Python Array API specification (revis
* Added MKL functions `arg`, `copysign`, `i0`, and `inv` from VM namespace to be used by implementation of the appropriate element-wise functions [#2445](https://github.com/IntelPython/dpnp/pull/2445)
* Clarified details about conda install instructions in `Quick start quide` and `README` [#2446](https://github.com/IntelPython/dpnp/pull/2446)
* Bumped oneMKL version up to `0.7` [#2448](https://github.com/IntelPython/dpnp/pull/2448)
* The parameter `axis` in `dpnp.take_along_axis` function has now a default value of `-1` [#2442](https://github.com/IntelPython/dpnp/pull/2442)

### Fixed

Expand Down
15 changes: 9 additions & 6 deletions dpnp/dpnp_iface_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2205,7 +2205,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
return dpnp.get_result_array(usm_res, out=out)


def take_along_axis(a, indices, axis, mode="wrap"):
def take_along_axis(a, indices, axis=-1, mode="wrap"):
"""
Take values from the input array by matching 1d index and data slices.

Expand All @@ -2227,9 +2227,12 @@ def take_along_axis(a, indices, axis, mode="wrap"):
dimension of the input array, but dimensions ``Ni`` and ``Nj``
only need to broadcast against `a`.
axis : {None, int}
The axis to take 1d slices along. If axis is ``None``, the input
array is treated as if it had first been flattened to 1d,
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
The axis to take 1d slices along. If axis is ``None``, the input array
is treated as if it had first been flattened to 1d. The default is
``-1``, which takes 1d slices along the last axis. These behaviors are
consistent with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.

Default: ``-1``.
mode : {"wrap", "clip"}, optional
Specifies how out-of-bounds indices will be handled. Possible values
are:
Expand Down Expand Up @@ -2274,8 +2277,8 @@ def take_along_axis(a, indices, axis, mode="wrap"):
array([[10, 20, 30],
[40, 50, 60]])

The same works for max and min, if you maintain the trivial dimension
with ``keepdims``:
The same works for :obj:`dpnp.max` and :obj:`dpnp.min`, if you maintain
the trivial dimension with ``keepdims``:

>>> np.max(a, axis=1, keepdims=True)
array([[30],
Expand Down
2 changes: 1 addition & 1 deletion dpnp/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,7 +804,7 @@ def test_argequivalent(self, func, argfunc, kwargs):
# a = dpnp.random.random(size=(3, 4, 5))
a = dpnp.asarray(numpy.random.random(size=(3, 4, 5)))

for axis in list(range(a.ndim)) + [None]:
for axis in list(range(a.ndim)) + [None, -1]:
a_func = func(a, axis=axis, **kwargs)
ai_func = argfunc(a, axis=axis, **kwargs)
assert_array_equal(
Expand Down
Loading