Skip to content

Fix implementation of dpctl.tensor.moveaxis() #1174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,21 +785,21 @@ def unstack(X, axis=0):
return tuple(Y[i] for i in range(Y.shape[0]))


def moveaxis(X, src, dst):
"""moveaxis(x, src, dst)
def moveaxis(X, source, destination):
"""moveaxis(x, source, destination)

Moves axes of an array to new positions.

Args:
x (usm_ndarray): input array

src (int or a sequence of int):
source (int or a sequence of int):
Original positions of the axes to move.
These must be unique. If `x` has rank (i.e., number of
dimensions) `N`, a valid `axis` must be in the
half-open interval `[-N, N)`.

dst (int or a sequence of int):
destination (int or a sequence of int):
Destination positions for each of the original axes.
These must also be unique. If `x` has rank
(i.e., number of dimensions) `N`, a valid `axis` must be
Expand All @@ -814,22 +814,30 @@ def moveaxis(X, src, dst):

Raises:
AxisError: if `axis` value is invalid.
ValueError: if `src` and `dst` have not equal number of elements.
"""
if not isinstance(X, dpt.usm_ndarray):
raise TypeError(f"Expected usm_ndarray type, got {type(X)}.")

if not isinstance(src, (tuple, list)):
src = (src,)
if not isinstance(source, (tuple, list)):
source = (source,)

if not isinstance(dst, (tuple, list)):
dst = (dst,)
if not isinstance(destination, (tuple, list)):
destination = (destination,)

src = normalize_axis_tuple(src, X.ndim, "src")
dst = normalize_axis_tuple(dst, X.ndim, "dst")
ind = list(range(0, X.ndim))
for i in range(len(src)):
ind.remove(src[i]) # using the value here which is the same as index
ind.insert(dst[i], src[i])
source = normalize_axis_tuple(source, X.ndim, "source")
destination = normalize_axis_tuple(destination, X.ndim, "destination")

if len(source) != len(destination):
raise ValueError(
"`source` and `destination` arguments must have "
"the same number of elements"
)

ind = [n for n in range(X.ndim) if n not in source]

for src, dst in sorted(zip(destination, source)):
ind.insert(src, dst)

return dpt.permute_dims(X, tuple(ind))

Expand Down
103 changes: 82 additions & 21 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np
import pytest
from numpy.testing import assert_array_equal
from numpy.testing import assert_, assert_array_equal, assert_raises_regex

import dpctl
import dpctl.tensor as dpt
Expand Down Expand Up @@ -1068,34 +1068,95 @@ def test_swapaxes_2d():
assert_array_equal(exp, dpt.asnumpy(res))


def test_moveaxis_1axis():
x = np.arange(60).reshape((3, 4, 5))
exp = np.moveaxis(x, 0, -1)

y = dpt.reshape(dpt.arange(60), (3, 4, 5))
res = dpt.moveaxis(y, 0, -1)

assert_array_equal(exp, dpt.asnumpy(res))
@pytest.mark.parametrize(
"source, expected",
[
(0, (6, 7, 5)),
(1, (5, 7, 6)),
(2, (5, 6, 7)),
(-1, (5, 6, 7)),
],
)
def test_moveaxis_move_to_end(source, expected):
x = dpt.reshape(dpt.arange(5 * 6 * 7), (5, 6, 7))
actual = dpt.moveaxis(x, source, -1).shape
assert_(actual, expected)


def test_moveaxis_2axes():
x = np.arange(60).reshape((3, 4, 5))
exp = np.moveaxis(x, [0, 1], [-1, -2])
@pytest.mark.parametrize(
"source, destination, expected",
[
(0, 1, (2, 1, 3, 4)),
(1, 2, (1, 3, 2, 4)),
(1, -1, (1, 3, 4, 2)),
],
)
def test_moveaxis_new_position(source, destination, expected):
x = dpt.reshape(dpt.arange(24), (1, 2, 3, 4))
actual = dpt.moveaxis(x, source, destination).shape
assert_(actual, expected)

y = dpt.reshape(dpt.arange(60), (3, 4, 5))
res = dpt.moveaxis(y, [0, 1], [-1, -2])

assert_array_equal(exp, dpt.asnumpy(res))
@pytest.mark.parametrize(
"source, destination",
[
(0, 0),
(3, -1),
(-1, 3),
([0, -1], [0, -1]),
([2, 0], [2, 0]),
],
)
def test_moveaxis_preserve_order(source, destination):
x = dpt.zeros((1, 2, 3, 4))
actual = dpt.moveaxis(x, source, destination).shape
assert_(actual, (1, 2, 3, 4))


def test_moveaxis_3axes():
x = np.arange(60).reshape((3, 4, 5))
exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3])
@pytest.mark.parametrize(
"source, destination, expected",
[
([0, 1], [2, 3], (2, 3, 0, 1)),
([2, 3], [0, 1], (2, 3, 0, 1)),
([0, 1, 2], [2, 3, 0], (2, 3, 0, 1)),
([3, 0], [1, 0], (0, 3, 1, 2)),
([0, 3], [0, 1], (0, 3, 1, 2)),
],
)
def test_moveaxis_move_multiples(source, destination, expected):
x = dpt.zeros((0, 1, 2, 3))
actual = dpt.moveaxis(x, source, destination).shape
assert_(actual, expected)

y = dpt.reshape(dpt.arange(60), (3, 4, 5))
res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3])

assert_array_equal(exp, dpt.asnumpy(res))
def test_moveaxis_errors():
x = dpt.reshape(dpt.arange(6), (1, 2, 3))
assert_raises_regex(
np.AxisError, "source.*out of bounds", dpt.moveaxis, x, 3, 0
)
assert_raises_regex(
np.AxisError, "source.*out of bounds", dpt.moveaxis, x, -4, 0
)
assert_raises_regex(
np.AxisError, "destination.*out of bounds", dpt.moveaxis, x, 0, 5
)
assert_raises_regex(
ValueError, "repeated axis in `source`", dpt.moveaxis, x, [0, 0], [0, 1]
)
assert_raises_regex(
ValueError,
"repeated axis in `destination`",
dpt.moveaxis,
x,
[0, 1],
[1, 1],
)
assert_raises_regex(
ValueError, "must have the same number", dpt.moveaxis, x, 0, [0, 1]
)
assert_raises_regex(
ValueError, "must have the same number", dpt.moveaxis, x, [0, 1], [0]
)


def test_unstack_axis0():
Expand Down