@@ -1490,7 +1490,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
14901490 return dpnp .get_result_array (result , out )
14911491
14921492
1493- def take_along_axis (a , indices , axis ):
1493+ def take_along_axis (a , indices , axis , mode = "wrap" ):
14941494 """
14951495 Take values from the input array by matching 1d index and data slices.
14961496
@@ -1511,15 +1511,24 @@ def take_along_axis(a, indices, axis):
15111511 Indices to take along each 1d slice of `a`. This must match the
15121512 dimension of the input array, but dimensions ``Ni`` and ``Nj``
15131513 only need to broadcast against `a`.
1514- axis : int
1514+ axis : {None, int}
15151515 The axis to take 1d slices along. If axis is ``None``, the input
15161516 array is treated as if it had first been flattened to 1d,
15171517 for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
1518+ mode : {"wrap", "clip"}, optional
1519+ Specifies how out-of-bounds indices will be handled. Possible values
1520+ are:
1521+
1522+ - ``"wrap"``: clamps indices to (``-n <= i < n``), then wraps
1523+ negative indices.
1524+ - ``"clip"``: clips indices to (``0 <= i < n``).
1525+
1526+ Default: ``"wrap"``.
15181527
15191528 Returns
15201529 -------
15211530 out : dpnp.ndarray
1522- The indexed result.
1531+ The indexed result of the same data type as `a` .
15231532
15241533 See Also
15251534 --------
@@ -1579,12 +1588,21 @@ def take_along_axis(a, indices, axis):
15791588
15801589 """
15811590
1582- dpnp .check_supported_arrays_type (a , indices )
1583-
15841591 if axis is None :
1585- a = a .ravel ()
1592+ dpnp .check_supported_arrays_type (indices )
1593+ if indices .ndim != 1 :
1594+ raise ValueError (
1595+ "when axis=None, `indices` must have a single dimension."
1596+ )
15861597
1587- return a [_build_along_axis_index (a , indices , axis )]
1598+ a = dpnp .ravel (a )
1599+ axis = 0
1600+
1601+ usm_a = dpnp .get_usm_ndarray (a )
1602+ usm_ind = dpnp .get_usm_ndarray (indices )
1603+
1604+ usm_res = dpt .take_along_axis (usm_a , usm_ind , axis = axis , mode = mode )
1605+ return dpnp_array ._create_from_usm_ndarray (usm_res )
15881606
15891607
15901608def tril_indices (
0 commit comments