From aefbb9a908adcf211c9ba6848090ffe87e1105fd Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 14 Aug 2023 22:28:50 +0200 Subject: [PATCH 1/2] Support shift tuple when axis=None for roll --- dpctl/tensor/_manipulation_functions.py | 3 +++ dpctl/tests/test_usm_ndarray_manipulation.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 9406e386af..165de83dd1 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -426,6 +426,9 @@ def roll(X, shift, axis=None): if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") if axis is None: + # get the combined shift value for all axes + if type(shift) is tuple: + shift = sum(shift) res = dpt.empty( X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue ) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 1cee5e6c8f..eb7112462c 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -590,6 +590,8 @@ def test_roll_empty(): "data", [ [2, None], + [(0, 1), None], + [(-1, 0), None], [-2, None], [2, 0], [-2, 0], @@ -617,6 +619,8 @@ def test_roll_1d(data): "data", [ [1, None], + [(2, 1), None], + [(-1, 2), None], [1, 0], [1, 1], [1, ()], From 13708353a394f7c0b3a24cb12cb4e154c7fbb0ec Mon Sep 17 00:00:00 2001 From: Vladislav Perevezentsev Date: Mon, 14 Aug 2023 22:33:55 +0200 Subject: [PATCH 2/2] Return X when src==dst for moveaxis() --- dpctl/tensor/_manipulation_functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 165de83dd1..2dcb365baa 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -885,6 +885,9 @@ def moveaxis(X, source, destination): "the same number of elements" ) + if source == destination: + return X + ind = [n for n in range(X.ndim) if n not in source] for src, dst in sorted(zip(destination, source)):