From ed91950eafe54a70dbb8b0f4fdf0c8ac14d475d6 Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Thu, 20 Jul 2023 16:41:17 +0200 Subject: [PATCH 1/3] Reuse dpctl.tensort.take for dpnp.take --- dpnp/backend/include/dpnp_iface_fptr.hpp | 4 +- dpnp/backend/kernels/dpnp_krnl_indexing.cpp | 27 ------- dpnp/dpnp_algo/dpnp_algo.pxd | 2 - dpnp/dpnp_algo/dpnp_algo_indexing.pxi | 44 ----------- dpnp/dpnp_array.py | 6 +- dpnp/dpnp_iface_indexing.py | 59 ++++++++++---- tests/skipped_tests.tbl | 1 - tests/skipped_tests_gpu.tbl | 2 - tests/test_indexing.py | 78 ++++++++----------- .../cupy/indexing_tests/test_indexing.py | 3 +- 10 files changed, 85 insertions(+), 141 deletions(-) diff --git a/dpnp/backend/include/dpnp_iface_fptr.hpp b/dpnp/backend/include/dpnp_iface_fptr.hpp index 707711306613..96b858854731 100644 --- a/dpnp/backend/include/dpnp_iface_fptr.hpp +++ b/dpnp/backend/include/dpnp_iface_fptr.hpp @@ -487,9 +487,7 @@ enum class DPNPFuncName : size_t DPNP_FN_SVD_EXT, /**< Used in numpy.linalg.svd() impl, requires extra parameters */ DPNP_FN_TAKE, /**< Used in numpy.take() impl */ - DPNP_FN_TAKE_EXT, /**< Used in numpy.take() impl, requires extra parameters - */ - DPNP_FN_TAN, /**< Used in numpy.tan() impl */ + DPNP_FN_TAN, /**< Used in numpy.tan() impl */ DPNP_FN_TAN_EXT, /**< Used in numpy.tan() impl, requires extra parameters */ DPNP_FN_TANH, /**< Used in numpy.tanh() impl */ DPNP_FN_TANH_EXT, /**< Used in numpy.tanh() impl, requires extra parameters diff --git a/dpnp/backend/kernels/dpnp_krnl_indexing.cpp b/dpnp/backend/kernels/dpnp_krnl_indexing.cpp index c2f9e0d9bda7..b8fa2179b6f6 100644 --- a/dpnp/backend/kernels/dpnp_krnl_indexing.cpp +++ b/dpnp/backend/kernels/dpnp_krnl_indexing.cpp @@ -1059,32 +1059,5 @@ void func_map_init_indexing_func(func_map_t &fmap) fmap[DPNPFuncName::DPNP_FN_TAKE][eft_C128][eft_LNG] = { eft_C128, (void *)dpnp_take_default_c, int64_t>}; - // TODO: add a handling of other indexes types once DPCtl implementation of - // data copy is ready - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_BLN][eft_INT] = { - eft_BLN, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_INT][eft_INT] = { - eft_INT, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_LNG][eft_INT] = { - eft_LNG, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_FLT][eft_INT] = { - eft_FLT, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_DBL][eft_INT] = { - eft_DBL, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_C128][eft_INT] = { - eft_C128, (void *)dpnp_take_ext_c, int32_t>}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_BLN][eft_LNG] = { - eft_BLN, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_INT][eft_LNG] = { - eft_INT, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_LNG][eft_LNG] = { - eft_LNG, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_FLT][eft_LNG] = { - eft_FLT, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_DBL][eft_LNG] = { - eft_DBL, (void *)dpnp_take_ext_c}; - fmap[DPNPFuncName::DPNP_FN_TAKE_EXT][eft_C128][eft_LNG] = { - eft_C128, (void *)dpnp_take_ext_c, int64_t>}; - return; } diff --git a/dpnp/dpnp_algo/dpnp_algo.pxd b/dpnp/dpnp_algo/dpnp_algo.pxd index b56557f75c86..37d4a7d3694a 100644 --- a/dpnp/dpnp_algo/dpnp_algo.pxd +++ b/dpnp/dpnp_algo/dpnp_algo.pxd @@ -295,8 +295,6 @@ cdef extern from "dpnp_iface_fptr.hpp" namespace "DPNPFuncName": # need this na DPNP_FN_SUM_EXT DPNP_FN_SVD DPNP_FN_SVD_EXT - DPNP_FN_TAKE - DPNP_FN_TAKE_EXT DPNP_FN_TAN DPNP_FN_TAN_EXT DPNP_FN_TANH diff --git a/dpnp/dpnp_algo/dpnp_algo_indexing.pxi b/dpnp/dpnp_algo/dpnp_algo_indexing.pxi index e9dc538393c3..808961298c22 100644 --- a/dpnp/dpnp_algo/dpnp_algo_indexing.pxi +++ b/dpnp/dpnp_algo/dpnp_algo_indexing.pxi @@ -45,7 +45,6 @@ __all__ += [ "dpnp_put_along_axis", "dpnp_putmask", "dpnp_select", - "dpnp_take", "dpnp_take_along_axis", "dpnp_tril_indices", "dpnp_tril_indices_from", @@ -59,13 +58,6 @@ ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_choose_t)(c_dpctl.DPCTLSyclQueueRe ctypedef c_dpctl.DPCTLSyclEventRef(*fptr_dpnp_diag_indices)(c_dpctl.DPCTLSyclQueueRef, void * , size_t, const c_dpctl.DPCTLEventVectorRef) -ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_1out_func_ptr_t)(c_dpctl.DPCTLSyclQueueRef, - void *, - const size_t, - void * , - void * , - size_t, - const c_dpctl.DPCTLEventVectorRef) ctypedef c_dpctl.DPCTLSyclEventRef(*custom_indexing_2in_1out_func_ptr_t_)(c_dpctl.DPCTLSyclQueueRef, void * , const size_t, @@ -417,42 +409,6 @@ cpdef utils.dpnp_descriptor dpnp_select(list condlist, list choicelist, default) return res_array -cpdef utils.dpnp_descriptor dpnp_take(utils.dpnp_descriptor x1, utils.dpnp_descriptor indices): - cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype) - cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(indices.dtype) - - cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TAKE_EXT, param1_type, param2_type) - - x1_obj = x1.get_array() - - cdef utils.dpnp_descriptor result = utils.create_output_descriptor(indices.shape, - kernel_data.return_type, - None, - device=x1_obj.sycl_device, - usm_type=x1_obj.usm_type, - sycl_queue=x1_obj.sycl_queue) - - result_sycl_queue = result.get_array().sycl_queue - - cdef c_dpctl.SyclQueue q = result_sycl_queue - cdef c_dpctl.DPCTLSyclQueueRef q_ref = q.get_queue_ref() - - cdef custom_indexing_2in_1out_func_ptr_t func = kernel_data.ptr - - cdef c_dpctl.DPCTLSyclEventRef event_ref = func(q_ref, - x1.get_data(), - x1.size, - indices.get_data(), - result.get_data(), - indices.size, - NULL) # dep_events_ref - - with nogil: c_dpctl.DPCTLEvent_WaitAndThrow(event_ref) - c_dpctl.DPCTLEvent_Delete(event_ref) - - return result - - cpdef object dpnp_take_along_axis(object arr, object indices, int axis): cdef long size_arr = arr.size cdef shape_type_c shape_arr = arr.shape diff --git a/dpnp/dpnp_array.py b/dpnp/dpnp_array.py index 0ff1ca4a8e2a..75ae72a31d08 100644 --- a/dpnp/dpnp_array.py +++ b/dpnp/dpnp_array.py @@ -1033,15 +1033,15 @@ def sum( # 'swapaxes', - def take(self, indices, axis=None, out=None, mode="raise"): + def take(self, indices, /, *, axis=None, out=None, mode="wrap"): """ - Take elements from an array. + Take elements from an array along an axis. For full documentation refer to :obj:`numpy.take`. """ - return dpnp.take(self, indices, axis, out, mode) + return dpnp.take(self, indices, axis=axis, out=out, mode=mode) # 'tobytes', # 'tofile', diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index a0b352c4317a..15a0ede3964d 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -70,6 +70,13 @@ ] +def _is_supported_class(x): + # Check if the given object `x` is an instance of either as :class:`dpnp.ndarray` + # or :class:`dpctl.tensor.usm_ndarray`. + + return isinstance(x, (dpnp_array, dpt.usm_ndarray)) + + def choose(x1, choices, out=None, mode="raise"): """ Construct an array from an index array and a set of arrays to choose from. @@ -539,39 +546,63 @@ def select(condlist, choicelist, default=0): return call_origin(numpy.select, condlist, choicelist, default) -def take(x1, indices, axis=None, out=None, mode="raise"): +def take(x, indices, /, *, axis=None, out=None, mode="wrap"): """ - Take elements from an array. + Take elements from an array along an axis. For full documentation refer to :obj:`numpy.take`. + Returns + ------- + dpnp.ndarray + An array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:] + filled with elements. + Limitations ----------- - Input array is supported as :obj:`dpnp.ndarray`. - Parameters ``axis``, ``out`` and ``mode`` are supported only with default values. - Parameter ``indices`` is supported as :obj:`dpnp.ndarray`. + Parameters `x` and `indices` are supported either as :class:`dpnp.ndarray` + or :class:`dpctl.tensor.usm_ndarray`. + Parameter `indices` is supported as 1-D array of integer data type. + Parameter `out` is supported only with default value. + Parameter `mode` is supported with ``wrap``(default) and ``clip`` mode. + Providing parameter `axis` is optional when `x` is a 1-D array. + Otherwise the function will be executed sequentially on CPU. See Also -------- :obj:`dpnp.compress` : Take elements using a boolean mask. :obj:`take_along_axis` : Take elements by matching the array and the index arrays. + + Notes + ----- + How out-of-bounds indices will be handled. + "wrap" - clamps indices to (-n <= i < n), then wraps negative indices. + "clip" - clips indices to (0 <= i < n) """ - x1_desc = dpnp.get_dpnp_descriptor(x1, copy_when_nondefault_queue=False) - indices_desc = dpnp.get_dpnp_descriptor( - indices, copy_when_nondefault_queue=False - ) - if x1_desc and indices_desc: - if axis is not None: + if _is_supported_class(x) and _is_supported_class(indices): + if indices.ndim != 1 or not dpnp.issubdtype( + indices.dtype, dpnp.integer + ): + pass + elif axis is None and x.ndim > 1: pass elif out is not None: pass - elif mode != "raise": + elif mode not in ("clip", "wrap"): pass else: - return dpnp_take(x1_desc, indices_desc).get_pyobj() + dpt_array = x.get_array() if isinstance(x, dpnp_array) else x + dpt_indices = ( + indices.get_array() + if isinstance(indices, dpnp_array) + else indices + ) + return dpnp_array._create_from_usm_ndarray( + dpt.take(dpt_array, dpt_indices, axis=axis, mode=mode) + ) - return call_origin(numpy.take, x1, indices, axis, out, mode) + return call_origin(numpy.take, x, indices, axis, out, mode) def take_along_axis(x1, indices, axis): diff --git a/tests/skipped_tests.tbl b/tests/skipped_tests.tbl index d5eab8dc75cc..ff5494e5f71e 100644 --- a/tests/skipped_tests.tbl +++ b/tests/skipped_tests.tbl @@ -401,7 +401,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compr tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool -tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast diff --git a/tests/skipped_tests_gpu.tbl b/tests/skipped_tests_gpu.tbl index 9a44f4cbfce4..4fc12eff1aba 100644 --- a/tests/skipped_tests_gpu.tbl +++ b/tests/skipped_tests_gpu.tbl @@ -49,7 +49,6 @@ tests/test_sycl_queue.py::test_modf[level_zero:gpu:0] tests/test_sycl_queue.py::test_1in_1out[opencl:gpu:0-trapz-data19] tests/test_sycl_queue.py::test_1in_1out[opencl:cpu:0-trapz-data19] -tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_no_axis tests/third_party/cupy/indexing_tests/test_insert.py::TestDiagIndices_param_0_{n=2, ndim=2}::test_diag_indices tests/third_party/cupy/indexing_tests/test_insert.py::TestDiagIndices_param_1_{n=2, ndim=3}::test_diag_indices tests/third_party/cupy/indexing_tests/test_insert.py::TestDiagIndices_param_2_{n=2, ndim=1}::test_diag_indices @@ -597,7 +596,6 @@ tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compr tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_empty_1dim_no_axis tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_axis tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_compress_no_bool -tests/third_party/cupy/indexing_tests/test_indexing.py::TestIndexing::test_take_index_range_overflow tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_1D_choicelist tests/third_party/cupy/indexing_tests/test_indexing.py::TestSelect::test_select_choicelist_condlist_broadcast diff --git a/tests/test_indexing.py b/tests/test_indexing.py index 354cfa06ff13..67600264356d 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -592,61 +592,51 @@ def test_select(): assert_array_equal(expected, result) +@pytest.mark.parametrize("array_type", get_all_dtypes()) @pytest.mark.parametrize( - "array_type", - [ - numpy.bool8, - numpy.int32, - numpy.int64, - numpy.float32, - numpy.float64, - numpy.complex128, - ], - ids=["bool8", "int32", "int64", "float32", "float64", "complex128"], + "indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"] ) @pytest.mark.parametrize( - "indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"] + "indices", [[-2, 2], [-5, 4]], ids=["[-2, 2]", "[-5, 4]"] ) +@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"]) +def test_take_1d(indices, array_type, indices_type, mode): + a = numpy.array([-2, -1, 0, 1, 2], dtype=array_type) + ind = numpy.array(indices, dtype=indices_type) + ia = dpnp.array(a) + iind = dpnp.array(ind) + expected = numpy.take(a, ind, mode=mode) + result = dpnp.take(ia, iind, mode=mode) + assert_array_equal(expected, result) + + +@pytest.mark.parametrize("array_type", get_all_dtypes()) @pytest.mark.parametrize( - "indices", - [[[0, 0], [0, 0]], [[1, 2], [1, 2]], [[1, 2], [3, 4]]], - ids=["[[0, 0], [0, 0]]", "[[1, 2], [1, 2]]", "[[1, 2], [3, 4]]"], + "indices_type", [numpy.int32, numpy.int64], ids=["int32", "int64"] ) @pytest.mark.parametrize( - "array", - [ - [[0, 1, 2], [3, 4, 5], [6, 7, 8]], - [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]], - [[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]], - [ - [[[1, 2], [3, 4]], [[1, 2], [2, 1]]], - [[[1, 3], [3, 1]], [[0, 1], [1, 3]]], - ], - [ - [[[1, 2, 3], [3, 4, 5]], [[1, 2, 3], [2, 1, 0]]], - [[[1, 3, 5], [3, 1, 0]], [[0, 1, 2], [1, 3, 4]]], - ], - [ - [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], - [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]], - ], - ], - ids=[ - "[[0, 1, 2], [3, 4, 5], [6, 7, 8]]", - "[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]", - "[[[1, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]", - "[[[[1, 2], [3, 4]], [[1, 2], [2, 1]]], [[[1, 3], [3, 1]], [[0, 1], [1, 3]]]]", - "[[[[1, 2, 3], [3, 4, 5]], [[1, 2, 3], [2, 1, 0]]], [[[1, 3, 5], [3, 1, 0]], [[0, 1, 2], [1, 3, 4]]]]", - "[[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], [[[13, 14, 15], [16, 17, 18]], [[19, 20, 21], [22, 23, 24]]]]", - ], + "indices", [[-1, 0], [-3, 2]], ids=["[-1, 0]", "[-3, 2]"] ) -def test_take(array, indices, array_type, indices_type): - a = numpy.array(array, dtype=array_type) +@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"]) +@pytest.mark.parametrize("axis", [0, 1], ids=["0", "1"]) +def test_take_2d(indices, array_type, indices_type, axis, mode): + a = numpy.array([[-1, 0, 1], [-2, -3, -4], [2, 3, 4]], dtype=array_type) ind = numpy.array(indices, dtype=indices_type) ia = dpnp.array(a) iind = dpnp.array(ind) - expected = numpy.take(a, ind) - result = dpnp.take(ia, iind) + expected = numpy.take(a, ind, axis=axis, mode=mode) + result = dpnp.take(ia, iind, axis=axis, mode=mode) + assert_array_equal(expected, result) + + +@pytest.mark.parametrize("array_type", get_all_dtypes()) +@pytest.mark.parametrize("indices", [[-5, 5]], ids=["[-5, 5]"]) +@pytest.mark.parametrize("mode", ["clip", "wrap"], ids=["clip", "wrap"]) +def test_take_over_index(indices, array_type, mode): + a = dpnp.array([-2, -1, 0, 1, 2], dtype=array_type) + ind = dpnp.array(indices, dtype=dpnp.int64) + expected = dpnp.array([-2, 2], dtype=a.dtype) + result = dpnp.take(a, ind, mode=mode) assert_array_equal(expected, result) diff --git a/tests/third_party/cupy/indexing_tests/test_indexing.py b/tests/third_party/cupy/indexing_tests/test_indexing.py index ce6b1e86496d..9e323990891c 100644 --- a/tests/third_party/cupy/indexing_tests/test_indexing.py +++ b/tests/third_party/cupy/indexing_tests/test_indexing.py @@ -28,6 +28,7 @@ def test_take_by_array(self, xp): b = xp.array([[1, 3], [2, 0]]) return a.take(b, axis=1) + @pytest.mark.usefixtures("allow_fall_back_on_numpy") @testing.numpy_cupy_array_equal() def test_take_no_axis(self, xp): a = testing.shaped_arange((2, 3, 4), xp) @@ -46,7 +47,7 @@ def test_take_index_range_overflow(self, xp, dtype): if dtype in (numpy.int32, numpy.uint32): pytest.skip() iinfo = numpy.iinfo(dtype) - a = xp.broadcast_to(xp.ones(1), (iinfo.max + 1,)) + a = xp.broadcast_to(xp.ones(1, dtype=dtype), (iinfo.max + 1,)) b = xp.array([0], dtype=dtype) return a.take(b) From a7798bd438153089183ebd08c8dd1b8e3ede508f Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 26 Jul 2023 14:17:37 +0200 Subject: [PATCH 2/3] Add examples and use dpnp.is_supported_array_type --- dpnp/dpnp_iface_indexing.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index 15a0ede3964d..83164a27faa6 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -70,13 +70,6 @@ ] -def _is_supported_class(x): - # Check if the given object `x` is an instance of either as :class:`dpnp.ndarray` - # or :class:`dpctl.tensor.usm_ndarray`. - - return isinstance(x, (dpnp_array, dpt.usm_ndarray)) - - def choose(x1, choices, out=None, mode="raise"): """ Construct an array from an index array and a set of arrays to choose from. @@ -556,7 +549,7 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"): ------- dpnp.ndarray An array with shape x.shape[:axis] + indices.shape + x.shape[axis + 1:] - filled with elements. + filled with elements from `x`. Limitations ----------- @@ -578,9 +571,29 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"): How out-of-bounds indices will be handled. "wrap" - clamps indices to (-n <= i < n), then wraps negative indices. "clip" - clips indices to (0 <= i < n) + + Examples + -------- + >>> import dpnp as np + >>> x = np.array([4, 3, 5, 7, 6, 8]) + >>> indices = np.array([0, 1, 4]) + >>> np.take(x, indices) + array([4, 3, 6]) + + >>> x[indices] + array([4, 3, 6]) + + >>> indices = dpnp.array([-1, -6, -7, 5, 6]) + >>> np.take(x, indices) + array([8, 4, 4, 8, 8]) + + >>> np.take(x, indices, mode="clip") + array([4, 4, 4, 8, 8]) """ - if _is_supported_class(x) and _is_supported_class(indices): + if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type( + indices + ): if indices.ndim != 1 or not dpnp.issubdtype( indices.dtype, dpnp.integer ): From de26663325d98312f862048e04d36c0517100f5c Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Wed, 26 Jul 2023 16:38:13 +0200 Subject: [PATCH 3/3] Use dpnp.get_usm_ndarray in take and update examples --- dpnp/dpnp_iface_indexing.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/dpnp/dpnp_iface_indexing.py b/dpnp/dpnp_iface_indexing.py index 83164a27faa6..0fd3803c917f 100644 --- a/dpnp/dpnp_iface_indexing.py +++ b/dpnp/dpnp_iface_indexing.py @@ -580,6 +580,8 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"): >>> np.take(x, indices) array([4, 3, 6]) + In this example "fancy" indexing can be used. + >>> x[indices] array([4, 3, 6]) @@ -589,6 +591,7 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"): >>> np.take(x, indices, mode="clip") array([4, 4, 4, 8, 8]) + """ if dpnp.is_supported_array_type(x) and dpnp.is_supported_array_type( @@ -605,12 +608,8 @@ def take(x, indices, /, *, axis=None, out=None, mode="wrap"): elif mode not in ("clip", "wrap"): pass else: - dpt_array = x.get_array() if isinstance(x, dpnp_array) else x - dpt_indices = ( - indices.get_array() - if isinstance(indices, dpnp_array) - else indices - ) + dpt_array = dpnp.get_usm_ndarray(x) + dpt_indices = dpnp.get_usm_ndarray(indices) return dpnp_array._create_from_usm_ndarray( dpt.take(dpt_array, dpt_indices, axis=axis, mode=mode) )