diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 9406e386af..2dcb365baa 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -426,6 +426,9 @@ def roll(X, shift, axis=None): if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") if axis is None: + # get the combined shift value for all axes + if type(shift) is tuple: + shift = sum(shift) res = dpt.empty( X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=X.sycl_queue ) @@ -882,6 +885,9 @@ def moveaxis(X, source, destination): "the same number of elements" ) + if source == destination: + return X + ind = [n for n in range(X.ndim) if n not in source] for src, dst in sorted(zip(destination, source)): diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index 1cee5e6c8f..eb7112462c 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -590,6 +590,8 @@ def test_roll_empty(): "data", [ [2, None], + [(0, 1), None], + [(-1, 0), None], [-2, None], [2, 0], [-2, 0], @@ -617,6 +619,8 @@ def test_roll_1d(data): "data", [ [1, None], + [(2, 1), None], + [(-1, 2), None], [1, 0], [1, 1], [1, ()],