From 1c4fedecc01f00da4c0713bed755174d22805f71 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 18 Aug 2023 23:09:44 -0700 Subject: [PATCH 1/3] Implements matrix_transpose - Function wrapper for call to dpctl.tensor.usm_ndarray.mT attribute --- dpctl/tensor/__init__.py | 2 + dpctl/tensor/_linear_algebra_functions.py | 48 +++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 dpctl/tensor/_linear_algebra_functions.py diff --git a/dpctl/tensor/__init__.py b/dpctl/tensor/__init__.py index 7438fb8a67..ad51689f3f 100644 --- a/dpctl/tensor/__init__.py +++ b/dpctl/tensor/__init__.py @@ -60,6 +60,7 @@ from dpctl.tensor._device import Device from dpctl.tensor._dlpack import from_dlpack from dpctl.tensor._indexing_functions import extract, nonzero, place, put, take +from dpctl.tensor._linear_algebra_functions import matrix_transpose from dpctl.tensor._manipulation_functions import ( broadcast_arrays, broadcast_to, @@ -199,6 +200,7 @@ "tril", "triu", "where", + "matrix_transpose", "all", "any", "dtype", diff --git a/dpctl/tensor/_linear_algebra_functions.py b/dpctl/tensor/_linear_algebra_functions.py new file mode 100644 index 0000000000..fd2c58b08a --- /dev/null +++ b/dpctl/tensor/_linear_algebra_functions.py @@ -0,0 +1,48 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dpctl.tensor as dpt + + +def matrix_transpose(x): + """matrix_transpose(x) + + Transposes the innermost two dimensions of `x`, where `x` is a + 2-dimensional matrix or a stack of 2-dimensional matrices. + + To convert from a 1-dimensional array to a 2-dimensional column + vector, use x[:, dpt.newaxis]. + + Args: + x (usm_ndarray): + Input array with shape (..., m, n). + + Returns: + usm_ndarray: + Array with shape (..., n, m). + """ + + if not isinstance(x, dpt.usm_ndarray): + raise TypeError( + "Expected instance of `dpt.usm_ndarray`, got `{}`.".format(type(x)) + ) + if x.ndim < 2: + raise ValueError( + "dpctl.tensor.matrix_transpose requires array to have" + "at least 2 dimensions" + ) + + return x.mT From 460d5e528404d1d5afaabbde2b78622a858b4ede Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 18 Aug 2023 23:09:48 -0700 Subject: [PATCH 2/3] Add arg validation tests for matrix_transpose --- dpctl/tests/test_usm_ndarray_linalg.py | 33 ++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 dpctl/tests/test_usm_ndarray_linalg.py diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py new file mode 100644 index 0000000000..87e1e425eb --- /dev/null +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -0,0 +1,33 @@ +# Data Parallel Control (dpctl) +# +# Copyright 2020-2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +import dpctl.tensor as dpt +from dpctl.tests.helper import get_queue_or_skip + + +def test_matrix_transpose_arg_validation(): + get_queue_or_skip() + + X = dpt.ones(5, dtype="i4") + + with pytest.raises(ValueError): + dpt.matrix_transpose(X) + + X = dict() + with pytest.raises(TypeError): + dpt.matrix_transpose(X) From b3ab3da75fd2136309812117e9fc429e82cd31e3 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Sat, 19 Aug 2023 14:34:24 -0700 Subject: [PATCH 3/3] Added a test for matrix_transpose for coverage --- dpctl/tests/test_usm_ndarray_linalg.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/dpctl/tests/test_usm_ndarray_linalg.py b/dpctl/tests/test_usm_ndarray_linalg.py index 87e1e425eb..4023eb8ad7 100644 --- a/dpctl/tests/test_usm_ndarray_linalg.py +++ b/dpctl/tests/test_usm_ndarray_linalg.py @@ -20,14 +20,29 @@ from dpctl.tests.helper import get_queue_or_skip -def test_matrix_transpose_arg_validation(): +def test_matrix_transpose(): get_queue_or_skip() - X = dpt.ones(5, dtype="i4") + X = dpt.reshape(dpt.arange(2 * 3, dtype="i4"), (2, 3)) + res = dpt.matrix_transpose(X) + expected_res = X.mT + + assert expected_res.shape == res.shape + assert expected_res.flags["C"] == res.flags["C"] + assert expected_res.flags["F"] == res.flags["F"] + assert dpt.all(X.mT == res) + +def test_matrix_transpose_arg_validation(): + get_queue_or_skip() + + X = dpt.empty(5, dtype="i4") with pytest.raises(ValueError): dpt.matrix_transpose(X) X = dict() with pytest.raises(TypeError): dpt.matrix_transpose(X) + + X = dpt.empty((5, 5), dtype="i4") + assert isinstance(dpt.matrix_transpose(X), dpt.usm_ndarray)