From 15e5e7312a2479494c1a74d7250315d7be386578 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Mon, 17 Apr 2023 23:05:48 -0500 Subject: [PATCH] fix an error for moveaxis --- dpctl/tensor/_manipulation_functions.py | 36 ++++--- dpctl/tests/test_usm_ndarray_manipulation.py | 103 +++++++++++++++---- 2 files changed, 104 insertions(+), 35 deletions(-) diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index 845f2d6f80..07c7e4b6ac 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -785,21 +785,21 @@ def unstack(X, axis=0): return tuple(Y[i] for i in range(Y.shape[0])) -def moveaxis(X, src, dst): - """moveaxis(x, src, dst) +def moveaxis(X, source, destination): + """moveaxis(x, source, destination) Moves axes of an array to new positions. Args: x (usm_ndarray): input array - src (int or a sequence of int): + source (int or a sequence of int): Original positions of the axes to move. These must be unique. If `x` has rank (i.e., number of dimensions) `N`, a valid `axis` must be in the half-open interval `[-N, N)`. - dst (int or a sequence of int): + destination (int or a sequence of int): Destination positions for each of the original axes. These must also be unique. If `x` has rank (i.e., number of dimensions) `N`, a valid `axis` must be @@ -814,22 +814,30 @@ def moveaxis(X, src, dst): Raises: AxisError: if `axis` value is invalid. + ValueError: if `src` and `dst` have not equal number of elements. """ if not isinstance(X, dpt.usm_ndarray): raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") - if not isinstance(src, (tuple, list)): - src = (src,) + if not isinstance(source, (tuple, list)): + source = (source,) - if not isinstance(dst, (tuple, list)): - dst = (dst,) + if not isinstance(destination, (tuple, list)): + destination = (destination,) - src = normalize_axis_tuple(src, X.ndim, "src") - dst = normalize_axis_tuple(dst, X.ndim, "dst") - ind = list(range(0, X.ndim)) - for i in range(len(src)): - ind.remove(src[i]) # using the value here which is the same as index - ind.insert(dst[i], src[i]) + source = normalize_axis_tuple(source, X.ndim, "source") + destination = normalize_axis_tuple(destination, X.ndim, "destination") + + if len(source) != len(destination): + raise ValueError( + "`source` and `destination` arguments must have " + "the same number of elements" + ) + + ind = [n for n in range(X.ndim) if n not in source] + + for src, dst in sorted(zip(destination, source)): + ind.insert(src, dst) return dpt.permute_dims(X, tuple(ind)) diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index a6e8ec244c..7bbf9d6ce6 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -17,7 +17,7 @@ import numpy as np import pytest -from numpy.testing import assert_array_equal +from numpy.testing import assert_, assert_array_equal, assert_raises_regex import dpctl import dpctl.tensor as dpt @@ -1068,34 +1068,95 @@ def test_swapaxes_2d(): assert_array_equal(exp, dpt.asnumpy(res)) -def test_moveaxis_1axis(): - x = np.arange(60).reshape((3, 4, 5)) - exp = np.moveaxis(x, 0, -1) - - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) - res = dpt.moveaxis(y, 0, -1) - - assert_array_equal(exp, dpt.asnumpy(res)) +@pytest.mark.parametrize( + "source, expected", + [ + (0, (6, 7, 5)), + (1, (5, 7, 6)), + (2, (5, 6, 7)), + (-1, (5, 6, 7)), + ], +) +def test_moveaxis_move_to_end(source, expected): + x = dpt.reshape(dpt.arange(5 * 6 * 7), (5, 6, 7)) + actual = dpt.moveaxis(x, source, -1).shape + assert_(actual, expected) -def test_moveaxis_2axes(): - x = np.arange(60).reshape((3, 4, 5)) - exp = np.moveaxis(x, [0, 1], [-1, -2]) +@pytest.mark.parametrize( + "source, destination, expected", + [ + (0, 1, (2, 1, 3, 4)), + (1, 2, (1, 3, 2, 4)), + (1, -1, (1, 3, 4, 2)), + ], +) +def test_moveaxis_new_position(source, destination, expected): + x = dpt.reshape(dpt.arange(24), (1, 2, 3, 4)) + actual = dpt.moveaxis(x, source, destination).shape + assert_(actual, expected) - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) - res = dpt.moveaxis(y, [0, 1], [-1, -2]) - assert_array_equal(exp, dpt.asnumpy(res)) +@pytest.mark.parametrize( + "source, destination", + [ + (0, 0), + (3, -1), + (-1, 3), + ([0, -1], [0, -1]), + ([2, 0], [2, 0]), + ], +) +def test_moveaxis_preserve_order(source, destination): + x = dpt.zeros((1, 2, 3, 4)) + actual = dpt.moveaxis(x, source, destination).shape + assert_(actual, (1, 2, 3, 4)) -def test_moveaxis_3axes(): - x = np.arange(60).reshape((3, 4, 5)) - exp = np.moveaxis(x, [0, 1, 2], [-1, -2, -3]) +@pytest.mark.parametrize( + "source, destination, expected", + [ + ([0, 1], [2, 3], (2, 3, 0, 1)), + ([2, 3], [0, 1], (2, 3, 0, 1)), + ([0, 1, 2], [2, 3, 0], (2, 3, 0, 1)), + ([3, 0], [1, 0], (0, 3, 1, 2)), + ([0, 3], [0, 1], (0, 3, 1, 2)), + ], +) +def test_moveaxis_move_multiples(source, destination, expected): + x = dpt.zeros((0, 1, 2, 3)) + actual = dpt.moveaxis(x, source, destination).shape + assert_(actual, expected) - y = dpt.reshape(dpt.arange(60), (3, 4, 5)) - res = dpt.moveaxis(y, [0, 1, 2], [-1, -2, -3]) - assert_array_equal(exp, dpt.asnumpy(res)) +def test_moveaxis_errors(): + x = dpt.reshape(dpt.arange(6), (1, 2, 3)) + assert_raises_regex( + np.AxisError, "source.*out of bounds", dpt.moveaxis, x, 3, 0 + ) + assert_raises_regex( + np.AxisError, "source.*out of bounds", dpt.moveaxis, x, -4, 0 + ) + assert_raises_regex( + np.AxisError, "destination.*out of bounds", dpt.moveaxis, x, 0, 5 + ) + assert_raises_regex( + ValueError, "repeated axis in `source`", dpt.moveaxis, x, [0, 0], [0, 1] + ) + assert_raises_regex( + ValueError, + "repeated axis in `destination`", + dpt.moveaxis, + x, + [0, 1], + [1, 1], + ) + assert_raises_regex( + ValueError, "must have the same number", dpt.moveaxis, x, 0, [0, 1] + ) + assert_raises_regex( + ValueError, "must have the same number", dpt.moveaxis, x, [0, 1], [0] + ) def test_unstack_axis0():