diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index 7e476aba04..143d6b1bcb 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -5,4 +5,5 @@ import pytensor.link.pytorch.dispatch.scalar import pytensor.link.pytorch.dispatch.elemwise import pytensor.link.pytorch.dispatch.extra_ops +import pytensor.link.pytorch.dispatch.sort # isort: on diff --git a/pytensor/link/pytorch/dispatch/sort.py b/pytensor/link/pytorch/dispatch/sort.py new file mode 100644 index 0000000000..95e24c4fe3 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/sort.py @@ -0,0 +1,25 @@ +import torch + +from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.tensor.sort import ArgSortOp, SortOp + + +@pytorch_funcify.register(SortOp) +def pytorch_funcify_Sort(op, **kwargs): + stable = op.kind == "stable" + + def sort(arr, axis): + sorted, _ = torch.sort(arr, dim=axis, stable=stable) + return sorted + + return sort + + +@pytorch_funcify.register(ArgSortOp) +def pytorch_funcify_ArgSort(op, **kwargs): + stable = op.kind == "stable" + + def argsort(arr, axis): + return torch.argsort(arr, dim=axis, stable=stable) + + return argsort diff --git a/tests/link/pytorch/test_sort.py b/tests/link/pytorch/test_sort.py new file mode 100644 index 0000000000..386a974cf4 --- /dev/null +++ b/tests/link/pytorch/test_sort.py @@ -0,0 +1,26 @@ +import numpy as np +import pytest + +from pytensor.graph import FunctionGraph +from pytensor.tensor import matrix +from pytensor.tensor.sort import argsort, sort +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +@pytest.mark.parametrize("func", (sort, argsort)) +@pytest.mark.parametrize( + "axis", + [ + pytest.param(0), + pytest.param(1), + pytest.param( + None, marks=pytest.mark.xfail(reason="Reshape Op not implemented") + ), + ], +) +def test_sort(func, axis): + x = matrix("x", shape=(2, 2), dtype="float64") + out = func(x, axis=axis) + fgraph = FunctionGraph([x], [out]) + arr = np.array([[1.0, 4.0], [5.0, 2.0]]) + compare_pytorch_and_py(fgraph, [arr])