From 6593253bfe7aa4111b8440622e81ee93147538c7 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Wed, 10 Apr 2024 19:37:54 +0530 Subject: [PATCH 1/5] Added support for matrix_transpose --- pytensor/tensor/basic.py | 58 +++++++++++++++++++++++++++++ pytensor/tensor/rewriting/linalg.py | 8 ++-- pytensor/tensor/variable.py | 4 ++ scripts/mypy-failing.txt | 1 + tests/tensor/test_basic.py | 27 ++++++++++++++ 5 files changed, 94 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index f36f8888ba..2782bbf492 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1982,6 +1982,64 @@ def transpose(x, axes=None): return ret +def matrix_transpose(x, axes=None): + """ + Transposes each 2-dimensional matrix tensor along the last two dimensions of a higher-dimensional tensor. + + Parameters + ---------- + x : array_like + Input tensor with shape (..., M, N), where `M` and `N` represent the dimensions + of the matrices. Each matrix is of shape (M, N). + + axes : list of int, optional + By default, reverse the dimensions, otherwise permute the axes according + to the values given. + + Returns + ------- + out : tensor + Transposed tensor with the shape (..., N, M), where each 2-dimensional matrix + in the input tensor has been transposed along the last two dimensions. + + Examples + -------- + >>> import pytensor as ptb + >>> import numpy as np + >>> x = np.arange(24).reshape((2, 3, 4)) + [[[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11]], + + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]] + + >>> ptb.matrix_transpose(x) + [[[ 0 4 8] + [ 1 5 9] + [ 2 6 10] + [ 3 7 11]] + + [[12 16 20] + [13 17 21] + [14 18 22] + [15 19 23]]] + + Notes + ----- + This function transposes each 2-dimensional matrix within the input tensor along + the last two dimensions. If the input tensor has more than two dimensions, it + transposes each 2-dimensional matrix independently while preserving other dimensions. + """ + _x = as_tensor_variable(x) + if _x.ndim < 2: + raise ValueError( + f"Input array must be at least 2-dimensional, but it is {_x.ndim}" + ) + return swapaxes(x, -1, -2) + + def split(x, splits_size, n_splits, axis=0): the_split = Split(n_splits) return the_split(x, axis, splits_size) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 717a7af884..203606be76 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -43,7 +43,7 @@ def is_matrix_transpose(x: TensorVariable) -> bool: return False -def _T(x: TensorVariable) -> TensorVariable: +def _mT(x: TensorVariable) -> TensorVariable: """Matrix transpose for potentially higher dimensionality tensors""" return swapaxes(x, -1, -2) @@ -83,9 +83,9 @@ def inv_as_solve(fgraph, node): ): x = r.owner.inputs[0] if getattr(x.tag, "symmetric", None) is True: - return [_T(solve(x, _T(l)))] + return [_mT(solve(x, _mT(l)))] else: - return [_T(solve(_T(x), _T(l)))] + return [_mT(solve(_mT(x), _mT(l)))] @register_stabilize @@ -216,7 +216,7 @@ def psd_solve_with_chol(fgraph, node): # __if__ no other Op makes use of the L matrix during the # stabilization Li_b = solve_triangular(L, b, lower=True, b_ndim=2) - x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2) + x = solve_triangular(_mT(L), Li_b, lower=False, b_ndim=2) return [x] diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index c1dc3d2de3..6100108380 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -232,6 +232,10 @@ def __trunc__(self): def T(self): return pt.basic.transpose(self) + @property + def mT(self): + return pt.basic.matrix_transpose(self) + def transpose(self, *axes): """Transpose this array. diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 52fa8dc502..9df9211236 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -7,6 +7,7 @@ pytensor/compile/sharedvalue.py pytensor/graph/rewriting/basic.py pytensor/ifelse.py pytensor/link/basic.py +pytensor/link/c/cmodule.py pytensor/link/numba/dispatch/elemwise.py pytensor/link/numba/dispatch/random.py pytensor/link/numba/dispatch/scan.py diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index d96dc3fd0c..d7cb20c073 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3813,6 +3813,7 @@ def test_transpose(): ) t1, t2, t3, t1b, t2b, t3b, t2c, t3c, t2d, t3d = f(x1v, x2v, x3v) + assert t1.shape == np.transpose(x1v).shape assert t2.shape == np.transpose(x2v).shape assert t3.shape == np.transpose(x3v).shape @@ -3838,6 +3839,32 @@ def test_transpose(): assert ptb.transpose(dmatrix()).name is None +def test_matrix_transpose(): + with pytest.raises(ValueError, match="Input array must be at least 2-dimensional"): + ptb.matrix_transpose(np.arange(6)) + + x2 = dmatrix("x2") + x3 = dtensor3("x3") + + x2v = np.arange(6).reshape((2, 3)) + x3v = np.arange(12).reshape((2, 3, 2)) + + f = pytensor.function( + [x2, x3], + [ + ptb.matrix_transpose(x2), + ptb.matrix_transpose(x3), + ], + ) + t2, t3 = f(x2v, x3v) + + assert t2.shape == np.transpose(x2v).shape + assert np.all(t2 == np.transpose(x2v)) + # TODO: Replace np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])]) with np.matrix_transpose(x3v) once numpy adds support for it (https://github.com/numpy/numpy/pull/24099) + assert t3.shape == np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])]).shape + assert np.all(t3 == np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])])) + + def test_stacklists(): a, b, c, d = map(scalar, "abcd") X = stacklists([[a, b], [c, d]]) From 02e46abe282834ec1cd8bb9e52900c3feb944e47 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Thu, 11 Apr 2024 03:14:07 +0530 Subject: [PATCH 2/5] Support matrix_transpose and mT --- pytensor/tensor/basic.py | 12 ++++-------- pytensor/tensor/rewriting/linalg.py | 13 ++++--------- tests/tensor/test_basic.py | 10 +++++----- 3 files changed, 13 insertions(+), 22 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 2782bbf492..4107388542 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1982,7 +1982,7 @@ def transpose(x, axes=None): return ret -def matrix_transpose(x, axes=None): +def matrix_transpose(x: "TensorLike") -> TensorVariable: """ Transposes each 2-dimensional matrix tensor along the last two dimensions of a higher-dimensional tensor. @@ -1992,10 +1992,6 @@ def matrix_transpose(x, axes=None): Input tensor with shape (..., M, N), where `M` and `N` represent the dimensions of the matrices. Each matrix is of shape (M, N). - axes : list of int, optional - By default, reverse the dimensions, otherwise permute the axes according - to the values given. - Returns ------- out : tensor @@ -2032,10 +2028,10 @@ def matrix_transpose(x, axes=None): the last two dimensions. If the input tensor has more than two dimensions, it transposes each 2-dimensional matrix independently while preserving other dimensions. """ - _x = as_tensor_variable(x) - if _x.ndim < 2: + x = as_tensor_variable(x) + if x.ndim < 2: raise ValueError( - f"Input array must be at least 2-dimensional, but it is {_x.ndim}" + f"Input array must be at least 2-dimensional, but it is {x.ndim}" ) return swapaxes(x, -1, -2) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 203606be76..ea83d9356a 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,7 +2,7 @@ from typing import cast from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter -from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes +from pytensor.tensor.basic import TensorVariable, diagonal from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle @@ -43,11 +43,6 @@ def is_matrix_transpose(x: TensorVariable) -> bool: return False -def _mT(x: TensorVariable) -> TensorVariable: - """Matrix transpose for potentially higher dimensionality tensors""" - return swapaxes(x, -1, -2) - - @register_canonicalize @node_rewriter([DimShuffle]) def transinv_to_invtrans(fgraph, node): @@ -83,9 +78,9 @@ def inv_as_solve(fgraph, node): ): x = r.owner.inputs[0] if getattr(x.tag, "symmetric", None) is True: - return [_mT(solve(x, _mT(l)))] + return [solve(x, (l.mT)).mT] else: - return [_mT(solve(_mT(x), _mT(l)))] + return [solve((x.mT), (l.mT)).mT] @register_stabilize @@ -216,7 +211,7 @@ def psd_solve_with_chol(fgraph, node): # __if__ no other Op makes use of the L matrix during the # stabilization Li_b = solve_triangular(L, b, lower=True, b_ndim=2) - x = solve_triangular(_mT(L), Li_b, lower=False, b_ndim=2) + x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2) return [x] diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index d7cb20c073..7d5038cb12 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3852,17 +3852,17 @@ def test_matrix_transpose(): f = pytensor.function( [x2, x3], [ - ptb.matrix_transpose(x2), + x2.mT, ptb.matrix_transpose(x3), ], ) t2, t3 = f(x2v, x3v) - assert t2.shape == np.transpose(x2v).shape - assert np.all(t2 == np.transpose(x2v)) + assert equal_computations([t2], [np.transpose(x2v)]) # TODO: Replace np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])]) with np.matrix_transpose(x3v) once numpy adds support for it (https://github.com/numpy/numpy/pull/24099) - assert t3.shape == np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])]).shape - assert np.all(t3 == np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])])) + assert equal_computations( + [t3], [np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])])] + ) def test_stacklists(): From 388b1130db947e6dfc5c18b39de8b6b2a5268624 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Thu, 11 Apr 2024 21:44:07 +0530 Subject: [PATCH 3/5] Use equal computations in matrix_transpose test --- pytensor/tensor/basic.py | 19 +++++++++++-------- tests/tensor/test_basic.py | 22 +++++++--------------- 2 files changed, 18 insertions(+), 23 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 4107388542..b81439f960 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -2000,18 +2000,19 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable: Examples -------- - >>> import pytensor as ptb + >>> import pytensor as pt >>> import numpy as np >>> x = np.arange(24).reshape((2, 3, 4)) - [[[ 0, 1, 2, 3], - [ 4, 5, 6, 7], - [ 8, 9, 10, 11]], + [[[ 0 1 2 3] + [ 4 5 6 7] + [ 8 9 10 11]] - [[12, 13, 14, 15], - [16, 17, 18, 19], - [20, 21, 22, 23]]] + [[12 13 14 15] + [16 17 18 19] + [20 21 22 23]]] - >>> ptb.matrix_transpose(x) + + >>> pt.matrix_transpose(x).eval() [[[ 0 4 8] [ 1 5 9] [ 2 6 10] @@ -2022,6 +2023,7 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable: [14 18 22] [15 19 23]]] + Notes ----- This function transposes each 2-dimensional matrix within the input tensor along @@ -4356,6 +4358,7 @@ def ix_(*args): "join", "split", "transpose", + "matrix_transpose", "extract_constant", "default", "tensor_copy", diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 7d5038cb12..66f80703d9 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3841,28 +3841,20 @@ def test_transpose(): def test_matrix_transpose(): with pytest.raises(ValueError, match="Input array must be at least 2-dimensional"): - ptb.matrix_transpose(np.arange(6)) + ptb.matrix_transpose(dvector("x1")) x2 = dmatrix("x2") x3 = dtensor3("x3") - x2v = np.arange(6).reshape((2, 3)) - x3v = np.arange(12).reshape((2, 3, 2)) + var1 = ptb.matrix_transpose(x2) + expected_var1 = swapaxes(x2, -1, -2) - f = pytensor.function( - [x2, x3], - [ - x2.mT, - ptb.matrix_transpose(x3), - ], - ) - t2, t3 = f(x2v, x3v) + var2 = ptb.matrix_transpose(x3) + expected_var2 = swapaxes(x3, -1, -2) - assert equal_computations([t2], [np.transpose(x2v)]) + assert equal_computations([var1], [expected_var1]) # TODO: Replace np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])]) with np.matrix_transpose(x3v) once numpy adds support for it (https://github.com/numpy/numpy/pull/24099) - assert equal_computations( - [t3], [np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])])] - ) + assert equal_computations([var2], [expected_var2]) def test_stacklists(): From 404ff30b5eb561fb251a2ed69c10738512ce7c7b Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Thu, 11 Apr 2024 22:03:11 +0530 Subject: [PATCH 4/5] Added test for mT --- tests/tensor/test_basic.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 66f80703d9..0f161760bd 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3849,11 +3849,10 @@ def test_matrix_transpose(): var1 = ptb.matrix_transpose(x2) expected_var1 = swapaxes(x2, -1, -2) - var2 = ptb.matrix_transpose(x3) + var2 = x3.mT expected_var2 = swapaxes(x3, -1, -2) assert equal_computations([var1], [expected_var1]) - # TODO: Replace np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])]) with np.matrix_transpose(x3v) once numpy adds support for it (https://github.com/numpy/numpy/pull/24099) assert equal_computations([var2], [expected_var2]) From a66aa225ad314eb1c9bb1883986ccb9cacb7f5b1 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Fri, 12 Apr 2024 20:34:22 +0530 Subject: [PATCH 5/5] Matrix_transpose added after resolving mypy error --- scripts/mypy-failing.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/mypy-failing.txt b/scripts/mypy-failing.txt index 9df9211236..52fa8dc502 100644 --- a/scripts/mypy-failing.txt +++ b/scripts/mypy-failing.txt @@ -7,7 +7,6 @@ pytensor/compile/sharedvalue.py pytensor/graph/rewriting/basic.py pytensor/ifelse.py pytensor/link/basic.py -pytensor/link/c/cmodule.py pytensor/link/numba/dispatch/elemwise.py pytensor/link/numba/dispatch/random.py pytensor/link/numba/dispatch/scan.py