diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 7438fb8a67..ad51689f3f 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -60,6 +60,7 @@ from dpctl.tensor._device import Device from dpctl.tensor._dlpack import from_dlpack from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take +from dpctl.tensor._linear_algebra_functions import matrix_transpose from dpctl.tensor._manipulation_functions import ( broadcast_arrays, broadcast_to, @@ -199,6 +200,7 @@ "tril", "triu", "where", + "matrix_transpose", "all", "any", "dtype", diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py new file mode 100644 index 0000000000..fd2c58b08a --- /dev/null +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -0,0 +1,48 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dpctl.tensor as dpt + + +def matrix_transpose(x): + """matrix_transpose(x) + + Transposes the innermost two dimensions of `x`, where `x` is a + 2-dimensional matrix or a stack of 2-dimensional matrices. + + To convert from a 1-dimensional array to a 2-dimensional column + vector, use x[:, dpt.newaxis]. + + Args: + x (usm_ndarray): + Input array with shape (..., m, n). + + Returns: + usm_ndarray: + Array with shape (..., n, m). + """ + + if not isinstance(x, dpt.usm_ndarray): + raise TypeError( + "Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x)) + ) + if x.ndim < 2: + raise ValueError( + "dpctl.tensor.matrix_transpose requires array to have" + "at least 2 dimensions" + ) + + return x.mT diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py new file mode 100644 index 0000000000..4023eb8ad7 --- /dev/null +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -0,0 +1,48 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip + + +def test_matrix_transpose(): + get_queue_or_skip() + + X = dpt.reshape(dpt.arange(2 * 3, dtype="i4"), (2, 3)) + res = dpt.matrix_transpose(X) + expected_res = X.mT + + assert expected_res.shape == res.shape + assert expected_res.flags["C"] == res.flags["C"] + assert expected_res.flags["F"] == res.flags["F"] + assert dpt.all(X.mT == res) + + +def test_matrix_transpose_arg_validation(): + get_queue_or_skip() + + X = dpt.empty(5, dtype="i4") + with pytest.raises(ValueError): + dpt.matrix_transpose(X) + + X = dict() + with pytest.raises(TypeError): + dpt.matrix_transpose(X) + + X = dpt.empty((5, 5), dtype="i4") + assert isinstance(dpt.matrix_transpose(X), dpt.usm_ndarray)