diff --git a/dpctl/tensor/_manipulation_functions.py b/dpctl/tensor/_manipulation_functions.py index e62ebce236..a099ddc031 100644 --- a/dpctl/tensor/_manipulation_functions.py +++ b/dpctl/tensor/_manipulation_functions.py @@ -15,23 +15,11 @@ # limitations under the License. -import numpy as np from numpy.core.numeric import normalize_axis_tuple import dpctl.tensor as dpt -def _check_value_of_axes(axes): - axes_len = len(axes) - check_array = np.zeros(axes_len) - for i in axes: - ii = i.__index__() - if ii < 0 or ii > axes_len or check_array[ii] != 0: - return False - check_array[ii] = 1 - return True - - def permute_dims(X, axes): """ permute_dims(X: usm_ndarray, axes: tuple or list) -> usm_ndarray @@ -48,11 +36,7 @@ def permute_dims(X, axes): "The length of the passed axes does not match " "to the number of usm_ndarray dimensions." ) - if not _check_value_of_axes(axes): - raise ValueError( - "The values of the axes must be in the range " - f"from 0 to {X.ndim} and have no duplicates." - ) + axes = normalize_axis_tuple(axes, X.ndim, "axes") newstrides = tuple(X.strides[i] for i in axes) newshape = tuple(X.shape[i] for i in axes) return dpt.usm_ndarray( diff --git a/dpctl/tests/test_usm_ndarray_manipulation.py b/dpctl/tests/test_usm_ndarray_manipulation.py index bcbe093b23..37af01f9ef 100644 --- a/dpctl/tests/test_usm_ndarray_manipulation.py +++ b/dpctl/tests/test_usm_ndarray_manipulation.py @@ -63,7 +63,7 @@ def test_permute_dims_0d_1d(): assert_array_equal(dpt.asnumpy(Y_1d), dpt.asnumpy(X_1d)) pytest.raises(ValueError, dpt.permute_dims, X_1d, ()) - pytest.raises(IndexError, dpt.permute_dims, X_1d, (1)) + pytest.raises(np.AxisError, dpt.permute_dims, X_1d, (1)) pytest.raises(ValueError, dpt.permute_dims, X_1d, (1, 0)) pytest.raises( ValueError, dpt.permute_dims, dpt.reshape(X_1d, (2, 3)), (1, 1)