diff --git a/pytensor/link/jax/dispatch/__init__.py b/pytensor/link/jax/dispatch/__init__.py
index 00976f221c..5da81bf80c 100644
--- a/pytensor/link/jax/dispatch/__init__.py
+++ b/pytensor/link/jax/dispatch/__init__.py
@@ -14,6 +14,7 @@
 import pytensor.link.jax.dispatch.scalar
 import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.shape
+import pytensor.link.jax.dispatch.signal
 import pytensor.link.jax.dispatch.slinalg
 import pytensor.link.jax.dispatch.sort
 import pytensor.link.jax.dispatch.sparse
diff --git a/pytensor/link/jax/dispatch/signal/__init__.py b/pytensor/link/jax/dispatch/signal/__init__.py
new file mode 100644
index 0000000000..9264ff44bd
--- /dev/null
+++ b/pytensor/link/jax/dispatch/signal/__init__.py
@@ -0,0 +1 @@
+import pytensor.link.jax.dispatch.signal.conv
diff --git a/pytensor/link/jax/dispatch/signal/conv.py b/pytensor/link/jax/dispatch/signal/conv.py
new file mode 100644
index 0000000000..1c124065e2
--- /dev/null
+++ b/pytensor/link/jax/dispatch/signal/conv.py
@@ -0,0 +1,14 @@
+import jax
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.signal.conv import Conv1d
+
+
+@jax_funcify.register(Conv1d)
+def jax_funcify_Conv1d(op, node, **kwargs):
+    mode = op.mode
+
+    def conv1d(data, kernel):
+        return jax.numpy.convolve(data, kernel, mode=mode)
+
+    return conv1d
diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py
index 56a3e2c9b2..1fefb1d06d 100644
--- a/pytensor/link/numba/dispatch/__init__.py
+++ b/pytensor/link/numba/dispatch/__init__.py
@@ -9,9 +9,11 @@
 import pytensor.link.numba.dispatch.random
 import pytensor.link.numba.dispatch.scan
 import pytensor.link.numba.dispatch.scalar
+import pytensor.link.numba.dispatch.signal
 import pytensor.link.numba.dispatch.slinalg
 import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
 import pytensor.link.numba.dispatch.tensor_basic
+
 
 # isort: on
diff --git a/pytensor/link/numba/dispatch/signal/__init__.py b/pytensor/link/numba/dispatch/signal/__init__.py
new file mode 100644
index 0000000000..db4834d67d
--- /dev/null
+++ b/pytensor/link/numba/dispatch/signal/__init__.py
@@ -0,0 +1 @@
+import pytensor.link.numba.dispatch.signal.conv
diff --git a/pytensor/link/numba/dispatch/signal/conv.py b/pytensor/link/numba/dispatch/signal/conv.py
new file mode 100644
index 0000000000..b1c63a440c
--- /dev/null
+++ b/pytensor/link/numba/dispatch/signal/conv.py
@@ -0,0 +1,16 @@
+import numpy as np
+
+from pytensor.link.numba.dispatch import numba_funcify
+from pytensor.link.numba.dispatch.basic import numba_njit
+from pytensor.tensor.signal.conv import Conv1d
+
+
+@numba_funcify.register(Conv1d)
+def numba_funcify_Conv1d(op, node, **kwargs):
+    mode = op.mode
+
+    @numba_njit
+    def conv1d(data, kernel):
+        return np.convolve(data, kernel, mode=mode)
+
+    return conv1d
diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py
index 88d3f33199..c6b421d003 100644
--- a/pytensor/tensor/__init__.py
+++ b/pytensor/tensor/__init__.py
@@ -116,6 +116,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 # isort: off
 from pytensor.tensor import linalg
 from pytensor.tensor import special
+from pytensor.tensor import signal
 
 # For backward compatibility
 from pytensor.tensor import nlinalg
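Note (not part of the patch): with the dispatch registrations above, a convolve1d graph compiles directly to either backend; `op.mode` is a fixed Op property, so both funcify closures bake it in as a Python constant. A minimal usage sketch, assuming jax (or numba) is installed:

    import numpy as np
    from pytensor import function
    from pytensor.tensor import vector
    from pytensor.tensor.signal import convolve1d

    x, k = vector("x"), vector("k")
    # The `mode` kwarg of `function` selects the backend; the `mode` kwarg of
    # `convolve1d` selects the convolution mode.
    fn = function([x, k], convolve1d(x, k, mode="valid"), mode="JAX")  # or mode="NUMBA"
    fn(np.ones(5), np.ones(3))  # -> array([3., 3., 3.])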
diff --git a/pytensor/tensor/signal/__init__.py b/pytensor/tensor/signal/__init__.py
new file mode 100644
index 0000000000..577976184f
--- /dev/null
+++ b/pytensor/tensor/signal/__init__.py
@@ -0,0 +1,4 @@
+from pytensor.tensor.signal.conv import convolve1d
+
+
+__all__ = ("convolve1d",)
diff --git a/pytensor/tensor/signal/conv.py b/pytensor/tensor/signal/conv.py
new file mode 100644
index 0000000000..1152f02d8a
--- /dev/null
+++ b/pytensor/tensor/signal/conv.py
@@ -0,0 +1,132 @@
+from typing import TYPE_CHECKING, Literal, cast
+
+from numpy import convolve as numpy_convolve
+
+from pytensor.graph import Apply, Op
+from pytensor.scalar.basic import upcast
+from pytensor.tensor.basic import as_tensor_variable, join, zeros
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.math import maximum, minimum
+from pytensor.tensor.type import vector
+from pytensor.tensor.variable import TensorVariable
+
+
+if TYPE_CHECKING:
+    from pytensor.tensor import TensorLike
+
+
+class Conv1d(Op):
+    __props__ = ("mode",)
+    gufunc_signature = "(n),(k)->(o)"
+
+    def __init__(self, mode: Literal["full", "valid"] = "full"):
+        if mode not in ("full", "valid"):
+            raise ValueError(f"Invalid mode: {mode}")
+        self.mode = mode
+
+    def make_node(self, in1, in2):
+        in1 = as_tensor_variable(in1)
+        in2 = as_tensor_variable(in2)
+
+        assert in1.ndim == 1
+        assert in2.ndim == 1
+
+        dtype = upcast(in1.dtype, in2.dtype)
+
+        n = in1.type.shape[0]
+        k = in2.type.shape[0]
+
+        if n is None or k is None:
+            out_shape = (None,)
+        elif self.mode == "full":
+            out_shape = (n + k - 1,)
+        else:  # mode == "valid"
+            out_shape = (max(n, k) - min(n, k) + 1,)
+
+        out = vector(dtype=dtype, shape=out_shape)
+        return Apply(self, [in1, in2], [out])
+
+    def perform(self, node, inputs, outputs):
+        # We use numpy_convolve, which is what scipy.signal.convolve would use for
+        # method="direct" and mode != "same" ("same" isn't covered by this Op anyway).
+        outputs[0][0] = numpy_convolve(*inputs, mode=self.mode)
+
+    def infer_shape(self, fgraph, node, shapes):
+        in1_shape, in2_shape = shapes
+        n = in1_shape[0]
+        k = in2_shape[0]
+        if self.mode == "full":
+            shape = n + k - 1
+        else:  # mode == "valid"
+            shape = maximum(n, k) - minimum(n, k) + 1
+        return [[shape]]
+
+    def L_op(self, inputs, outputs, output_grads):
+        in1, in2 = inputs
+        [grad] = output_grads
+
+        if self.mode == "full":
+            valid_conv = type(self)(mode="valid")
+            in1_bar = valid_conv(grad, in2[::-1])
+            in2_bar = valid_conv(grad, in1[::-1])
+
+        else:  # mode == "valid"
+            full_conv = type(self)(mode="full")
+            n = in1.shape[0]
+            k = in2.shape[0]
+            kmn = maximum(0, k - n)
+            nkm = maximum(0, n - k)
+            # We need mode="full" if k >= n else "valid" for `in1_bar` (opposite for `in2_bar`), but mode is not symbolic.
+            # Instead, we always use mode="full" and slice the result so it behaves like "valid" for the input that's shorter.
+            in1_bar = full_conv(grad, in2[::-1])
+            in1_bar = in1_bar[kmn : in1_bar.shape[0] - kmn]
+            in2_bar = full_conv(grad, in1[::-1])
+            in2_bar = in2_bar[nkm : in2_bar.shape[0] - nkm]
+
+        return [in1_bar, in2_bar]
+
+
+def convolve1d(
+    in1: "TensorLike",
+    in2: "TensorLike",
+    mode: Literal["full", "valid", "same"] = "full",
+) -> TensorVariable:
+    """Convolve two one-dimensional arrays.
+
+    Convolve in1 and in2, with the output size determined by the mode argument.
+
+    Parameters
+    ----------
+    in1 : (..., N,) tensor_like
+        First input.
+    in2 : (..., M,) tensor_like
+        Second input.
+    mode : {'full', 'valid', 'same'}, optional
+        A string indicating the size of the output:
+        - 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+M-1,).
+        - 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, M) - min(N, M) + 1,).
+        - 'same': The output is the same size as in1, centered with respect to the 'full' output.
+
+    Returns
+    -------
+    out: tensor_variable
+        The discrete linear convolution of in1 with in2.
+
+    """
+    in1 = as_tensor_variable(in1)
+    in2 = as_tensor_variable(in2)
+
+    if mode == "same":
+        # We implement "same" as "valid" with padded `in1`.
+        in1_batch_shape = tuple(in1.shape)[:-1]
+        zeros_left = in2.shape[0] // 2
+        zeros_right = (in2.shape[0] - 1) // 2
+        in1 = join(
+            -1,
+            zeros((*in1_batch_shape, zeros_left), dtype=in2.dtype),
+            in1,
+            zeros((*in1_batch_shape, zeros_right), dtype=in2.dtype),
+        )
+        mode = "valid"
+
+    return cast(TensorVariable, Blockwise(Conv1d(mode=mode))(in1, in2))
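Note (not part of the patch): a quick NumPy sanity check of the identity L_op relies on for mode="full" — the cotangent of in1 is the "valid" convolution of the output cotangent with the flipped kernel. Convolution is linear in in1, so central differences are exact up to float error:

    import numpy as np

    rng = np.random.default_rng(0)
    n, k = 7, 3
    in1, in2 = rng.normal(size=n), rng.normal(size=k)
    g = rng.normal(size=n + k - 1)  # cotangent of the "full" output

    # Closed form used by L_op
    analytic = np.convolve(g, in2[::-1], mode="valid")

    # Central finite differences of sum(g * convolve(in1, in2, "full")) wrt in1
    eps = 1e-6
    numeric = np.array([
        (g @ np.convolve(in1 + eps * np.eye(n)[i], in2)
         - g @ np.convolve(in1 - eps * np.eye(n)[i], in2)) / (2 * eps)
        for i in range(n)
    ])
    np.testing.assert_allclose(numeric, analytic, rtol=1e-6, atol=1e-8)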
diff --git a/tests/link/jax/signal/__init__.py b/tests/link/jax/signal/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/link/jax/signal/test_conv.py b/tests/link/jax/signal/test_conv.py
new file mode 100644
index 0000000000..7f448fc3e8
--- /dev/null
+++ b/tests/link/jax/signal/test_conv.py
@@ -0,0 +1,18 @@
+import numpy as np
+import pytest
+
+from pytensor.tensor import dmatrix
+from pytensor.tensor.signal import convolve1d
+from tests.link.jax.test_basic import compare_jax_and_py
+
+
+@pytest.mark.parametrize("mode", ["full", "valid", "same"])
+def test_convolve1d(mode):
+    x = dmatrix("x")
+    y = dmatrix("y")
+    out = convolve1d(x[None], y[:, None], mode=mode)
+
+    rng = np.random.default_rng()
+    test_x = rng.normal(size=(3, 5))
+    test_y = rng.normal(size=(7, 11))
+    compare_jax_and_py([x, y], out, [test_x, test_y])
diff --git a/tests/link/numba/signal/test_conv.py b/tests/link/numba/signal/test_conv.py
new file mode 100644
index 0000000000..1a72c2df0b
--- /dev/null
+++ b/tests/link/numba/signal/test_conv.py
@@ -0,0 +1,22 @@
+import numpy as np
+import pytest
+
+from pytensor.tensor import dmatrix
+from pytensor.tensor.signal import convolve1d
+from tests.link.numba.test_basic import compare_numba_and_py
+
+
+pytestmark = pytest.mark.filterwarnings("error")
+
+
+@pytest.mark.parametrize("mode", ["full", "valid", "same"])
+def test_convolve1d(mode):
+    x = dmatrix("x")
+    y = dmatrix("y")
+    out = convolve1d(x[None], y[:, None], mode=mode)
+
+    rng = np.random.default_rng()
+    test_x = rng.normal(size=(3, 5))
+    test_y = rng.normal(size=(7, 11))
+    # Blockwise dispatch for numba can't be run in object mode
+    compare_numba_and_py([x, y], out, [test_x, test_y], eval_obj_mode=False)
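Note (not part of the patch): the NumPy equivalent of what the two backend tests above exercise — Blockwise broadcasts the batch dimensions of x[None] (shape (1, 3, 5)) and y[:, None] (shape (7, 1, 11)) to (7, 3) and applies the core (n),(k)->(o) convolution pairwise:

    import numpy as np

    rng = np.random.default_rng()
    test_x, test_y = rng.normal(size=(3, 5)), rng.normal(size=(7, 11))
    ref = np.array([
        [np.convolve(test_x[j], test_y[i], mode="full") for j in range(3)]
        for i in range(7)
    ])  # shape (7, 3, 15) for mode="full"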
diff --git a/tests/tensor/signal/__init__.py b/tests/tensor/signal/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/tensor/signal/test_conv.py b/tests/tensor/signal/test_conv.py
new file mode 100644
index 0000000000..968e408485
--- /dev/null
+++ b/tests/tensor/signal/test_conv.py
@@ -0,0 +1,49 @@
+from functools import partial
+
+import numpy as np
+import pytest
+from scipy.signal import convolve as scipy_convolve
+
+from pytensor import config, function
+from pytensor.tensor import matrix, vector
+from pytensor.tensor.signal.conv import convolve1d
+from tests import unittest_tools as utt
+
+
+@pytest.mark.parametrize("kernel_shape", [3, 5, 8], ids=lambda x: f"kernel_shape={x}")
+@pytest.mark.parametrize("data_shape", [3, 5, 8], ids=lambda x: f"data_shape={x}")
+@pytest.mark.parametrize("mode", ["full", "valid", "same"])
+def test_convolve1d(mode, data_shape, kernel_shape):
+    data = vector("data")
+    kernel = vector("kernel")
+    op = partial(convolve1d, mode=mode)
+
+    rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
+    data_val = rng.normal(size=data_shape).astype(data.dtype)
+    kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)
+
+    fn = function([data, kernel], op(data, kernel))
+    np.testing.assert_allclose(
+        fn(data_val, kernel_val),
+        scipy_convolve(data_val, kernel_val, mode=mode),
+        rtol=1e-6 if config.floatX == "float32" else 1e-15,
+    )
+    utt.verify_grad(op=lambda x: op(x, kernel_val), pt=[data_val])
+
+
+def test_convolve1d_batch():
+    x = matrix("data")
+    y = matrix("kernel")
+    out = convolve1d(x, y)
+
+    rng = np.random.default_rng(38)
+    x_test = rng.normal(size=(2, 8)).astype(x.dtype)
+    y_test = x_test[::-1]
+
+    res = out.eval({x: x_test, y: y_test})
+    # The rows of y_test are the rows of x_test swapped, and convolution is
+    # commutative, so res[0] and res[1] should be identical.
+    rtol = 1e-6 if config.floatX == "float32" else 1e-15
+    res_np = np.convolve(x_test[0], y_test[0])
+    np.testing.assert_allclose(res[0], res_np, rtol=rtol)
+    np.testing.assert_allclose(res[1], res_np, rtol=rtol)
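Note (not part of the patch): a NumPy check of the "same"-as-padded-"valid" trick used in convolve1d. Since len(x) >= len(k) here, numpy's "same" matches scipy's "same size as in1" convention, which is what the docstring promises:

    import numpy as np

    rng = np.random.default_rng(1)
    x, k = rng.normal(size=8), rng.normal(size=4)
    # Pad with k//2 zeros on the left and (k-1)//2 on the right, as convolve1d does
    padded = np.concatenate([np.zeros(len(k) // 2), x, np.zeros((len(k) - 1) // 2)])
    np.testing.assert_allclose(
        np.convolve(padded, k, mode="valid"),  # length len(x), centered in "full"
        np.convolve(x, k, mode="same"),
    )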