Skip to content

Commit 8e61562

Browse files
Use equal computations in matrix_transpose test
1 parent e6704c6 commit 8e61562

File tree

2 files changed

+18
-23
lines changed

2 files changed

+18
-23
lines changed

pytensor/tensor/basic.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2000,18 +2000,19 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable:
20002000
20012001
Examples
20022002
--------
2003-
>>> import pytensor as ptb
2003+
>>> import pytensor as pt
20042004
>>> import numpy as np
20052005
>>> x = np.arange(24).reshape((2, 3, 4))
2006-
[[[ 0, 1, 2, 3],
2007-
[ 4, 5, 6, 7],
2008-
[ 8, 9, 10, 11]],
2006+
[[[ 0 1 2 3]
2007+
[ 4 5 6 7]
2008+
[ 8 9 10 11]]
20092009
2010-
[[12, 13, 14, 15],
2011-
[16, 17, 18, 19],
2012-
[20, 21, 22, 23]]]
2010+
[[12 13 14 15]
2011+
[16 17 18 19]
2012+
[20 21 22 23]]]
20132013
2014-
>>> ptb.matrix_transpose(x)
2014+
2015+
>>> pt.matrix_transpose(x).eval()
20152016
[[[ 0 4 8]
20162017
[ 1 5 9]
20172018
[ 2 6 10]
@@ -2022,6 +2023,7 @@ def matrix_transpose(x: "TensorLike") -> TensorVariable:
20222023
[14 18 22]
20232024
[15 19 23]]]
20242025
2026+
20252027
Notes
20262028
-----
20272029
This function transposes each 2-dimensional matrix within the input tensor along
@@ -4356,6 +4358,7 @@ def ix_(*args):
43564358
"join",
43574359
"split",
43584360
"transpose",
4361+
"matrix_transpose",
43594362
"extract_constant",
43604363
"default",
43614364
"tensor_copy",

tests/tensor/test_basic.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3841,28 +3841,20 @@ def test_transpose():
38413841

38423842
def test_matrix_transpose():
38433843
with pytest.raises(ValueError, match="Input array must be at least 2-dimensional"):
3844-
ptb.matrix_transpose(np.arange(6))
3844+
ptb.matrix_transpose(dvector("x1"))
38453845

38463846
x2 = dmatrix("x2")
38473847
x3 = dtensor3("x3")
38483848

3849-
x2v = np.arange(6).reshape((2, 3))
3850-
x3v = np.arange(12).reshape((2, 3, 2))
3849+
var1 = ptb.matrix_transpose(x2)
3850+
expected_var1 = swapaxes(x2, -1, -2)
38513851

3852-
f = pytensor.function(
3853-
[x2, x3],
3854-
[
3855-
x2.mT,
3856-
ptb.matrix_transpose(x3),
3857-
],
3858-
)
3859-
t2, t3 = f(x2v, x3v)
3852+
var2 = ptb.matrix_transpose(x3)
3853+
expected_var2 = swapaxes(x3, -1, -2)
38603854

3861-
assert equal_computations([t2], [np.transpose(x2v)])
3855+
assert equal_computations([var1], [expected_var1])
38623856
# TODO: Replace np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])]) with np.matrix_transpose(x3v) once numpy adds support for it (https://github.com/numpy/numpy/pull/24099)
3863-
assert equal_computations(
3864-
[t3], [np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])])]
3865-
)
3857+
assert equal_computations([var2], [expected_var2])
38663858

38673859

38683860
def test_stacklists():

0 commit comments

Comments
 (0)