From ecf4e9b3d4502306a7162bf389705900712f177d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 5 Dec 2023 20:09:21 +0100 Subject: [PATCH 01/14] Handle no-op Subtensors in rewrites --- pytensor/tensor/rewriting/subtensor.py | 73 +++++++++++++++--------- tests/tensor/rewriting/test_subtensor.py | 13 +++++ tests/tensor/test_subtensor.py | 3 +- 3 files changed, 62 insertions(+), 27 deletions(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 4e80d3bb30..e1174a7e8d 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -336,35 +336,46 @@ def local_subtensor_of_dot(fgraph, node): @node_rewriter([Subtensor]) def local_useless_slice(fgraph, node): """ - Remove Subtensor of the form X[0, :] -> X[0] + Remove Subtensor of the form: + 1. X[0, :] -> X[0] + 2. X[:] -> X + """ - if isinstance(node.op, Subtensor): - slices = get_idx_list(node.inputs, node.op.idx_list) - last_slice = len(slices) - for s in slices[::-1]: - # check if slice and then check slice indices - if ( - isinstance(s, slice) - and s.start is None - and s.stop is None - and ( - s.step is None - or extract_constant(s.step, only_process_constants=True) == 1 - ) - ): - last_slice -= 1 - else: - break - # check if we removed something - if last_slice < len(slices): - subtens = Subtensor(slices[:last_slice]) - sl_ins = get_slice_elements( - slices[:last_slice], lambda x: isinstance(x, Variable) + idxs = get_idx_list(node.inputs, node.op.idx_list) + + if not idxs: + return [node.inputs[0]] + + last_useless_slice = len(idxs) + for s in idxs[::-1]: + # check if slice and then check slice indices + if ( + isinstance(s, slice) + and s.start is None + and s.stop is None + and ( + s.step is None + or extract_constant(s.step, only_process_constants=True) == 1 + ) + ): + last_useless_slice -= 1 + else: + break + # check if we removed something + if last_useless_slice < len(idxs): + new_idxs = idxs[:last_useless_slice] + if new_idxs: + new_subtensor = Subtensor(new_idxs) + new_subtensor_inputs = get_slice_elements( + new_idxs, lambda x: isinstance(x, Variable) ) - out = subtens(node.inputs[0], *sl_ins) + out = new_subtensor(node.inputs[0], *new_subtensor_inputs) # Copy over previous output stacktrace copy_stack_trace(node.outputs, out) return [out] + else: + # Subtensor is not needed at all + return [node.inputs[0]] # fast_compile to allow opt subtensor(cast{float32}(make_vector)) @@ -747,7 +758,13 @@ def local_subtensor_make_vector(fgraph, node): make_vector_op = x.owner.op if isinstance(node.op, Subtensor): - (idx,) = node.op.idx_list + idxs = node.op.idx_list + + # Subtensor has no indexes, return make_vector + if not idxs: + return [x] + + (idx,) = idxs if isinstance(idx, (aes.ScalarType, TensorType)): old_idx, idx = idx, node.inputs[1] @@ -903,7 +920,11 @@ def local_set_to_inc_subtensor(fgraph, node): @node_rewriter([Subtensor]) def local_useless_subtensor(fgraph, node): """Remove `Subtensor` if it takes the full input.""" - # This optimization needs ShapeOpt and fgraph.shape_feature + + if not node.op.idx_list: + return [node.inputs[0]] + + # The more elaborate optimization needs ShapeOpt and fgraph.shape_feature if not hasattr(fgraph, "shape_feature"): return diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index a5a643d0da..4a0016136f 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -9,6 +9,7 @@ from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp from pytensor.configdefaults import config +from pytensor.graph import FunctionGraph from pytensor.graph.basic import Constant, Variable, ancestors from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.db import RewriteDatabaseQuery @@ -21,6 +22,7 @@ from pytensor.tensor.math import Dot, add, dot, exp, sqr from pytensor.tensor.rewriting.subtensor import ( local_replace_AdvancedSubtensor, + local_subtensor_make_vector, local_subtensor_shape_constant, ) from pytensor.tensor.shape import ( @@ -764,6 +766,17 @@ def test_stack_trace(self): f = function([x, y, z], v_subtensor, mode=mode) assert check_stack_trace(f, ops_to_check="all") + def test_empty_subtensor(self): + x, y = lscalars("xy") + v = make_vector(x, y) + out = v[()] + + fgraph = FunctionGraph(outputs=[out], clone=False) + node = fgraph.outputs[0].owner + assert isinstance(node.op, Subtensor) + + assert local_subtensor_make_vector.transform(fgraph, node) == [v] + class TestLocalSubtensorLift: def test_basic(self): diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 188c959bbc..b5952cc49e 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -389,7 +389,8 @@ def test_0_dims(self): t = Subtensor([])(n) assert isinstance(t.owner.op, Subtensor) self.eval_output_and_check( - t, mode=self.mode.excluding("local_useless_subtensor") + t, + mode=self.mode.excluding("local_useless_subtensor", "local_useless_slice"), ) def test_err_invalid_2(self): From 0f1abc77ff2f0279584bed127257cde41d4dd3fa Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 6 Dec 2023 10:18:33 +0000 Subject: [PATCH 02/14] Do not pickle functions of Blockwise --- pytensor/tensor/blockwise.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 96357f59f8..e7c34769e3 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from copy import copy from typing import Any, Optional, cast import numpy as np @@ -87,6 +88,11 @@ def __init__( self._gufunc = None super().__init__(**kwargs) + def __getstate__(self): + d = copy(self.__dict__) + d["_gufunc"] = None + return d + def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply: core_input_types = [] for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): From b155a00f2543af5924b3374ee56d3ba30a98b642 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 5 Dec 2023 20:45:54 +0100 Subject: [PATCH 03/14] Avoid creating useless squeezes and expand_dims --- pytensor/tensor/extra_ops.py | 4 ++++ pytensor/tensor/shape.py | 6 ++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 8f4158696a..b9bcceb2db 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -603,6 +603,10 @@ def squeeze(x, axis=None): except np.AxisError: raise np.AxisError(axis, ndim=_x.ndim) + if not axis: + # Nothing to do + return _x + return _x.dimshuffle([i for i in range(_x.ndim) if i not in axis]) diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 1d8efa02c5..0d8dea8a2e 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -868,7 +868,8 @@ def shape_padleft(t, n_ones=1): """ _t = at.as_tensor_variable(t) - + if n_ones == 0: + return _t pattern = ["x"] * n_ones + list(range(_t.type.ndim)) return _t.dimshuffle(pattern) @@ -884,7 +885,8 @@ def shape_padright(t, n_ones=1): """ _t = at.as_tensor_variable(t) - + if n_ones == 0: + return _t pattern = list(range(_t.type.ndim)) + ["x"] * n_ones return _t.dimshuffle(pattern) From d0b66a9f85557756d44099b492b43b178963eb71 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 5 Dec 2023 20:48:56 +0100 Subject: [PATCH 04/14] Remove assert in local_useless_alloc Rewrite was already tagged as "shape_unsafe" --- pytensor/tensor/rewriting/basic.py | 14 +++---------- tests/tensor/rewriting/test_basic.py | 31 +++++++++++++++++++++------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 021660d8e0..98f6d68dab 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -67,9 +67,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_arrays -from pytensor.tensor.math import Sum, add -from pytensor.tensor.math import all as at_all -from pytensor.tensor.math import eq +from pytensor.tensor.math import Sum, add, eq from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.sort import TopKOp from pytensor.tensor.type import DenseTensorType, TensorType @@ -266,6 +264,7 @@ def local_elemwise_alloc(fgraph, node): introduces them as a canonicalization of `Alloc`'s with leading broadcastable dimensions. """ + # This is handled by local_alloc_unary if len(node.inputs) == 1: return None @@ -465,14 +464,7 @@ def local_useless_alloc(fgraph, node): inp.type.dtype == output.type.dtype and inp.type.broadcastable == output.type.broadcastable ): - if inp.ndim == 0: - return [inp] - else: - return [ - Assert("Shapes must be equal")( - inp, at_all(eq(inp.shape, node.inputs[1:])) - ) - ] + return [inp] @register_specialize diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 0e5c618ba0..cd5d3cc255 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -272,21 +272,36 @@ class TestLocalCanonicalizeAlloc: def setup_method(self): self.rng = np.random.default_rng(utt.fetch_seed()) - def test_inconsistent_shared(self): + @pytest.mark.parametrize("shape_unsafe", (True, False)) + def test_inconsistent_shared(self, shape_unsafe): # These shapes don't match! x = shared(self.rng.standard_normal((3, 7))) a = at.alloc(x, 6, 7) assert a.owner and isinstance(a.owner.op, Alloc) - f = function([], a, mode=rewrite_mode) + mode = rewrite_mode if shape_unsafe else rewrite_mode.excluding("shape_unsafe") + f = function([], a, mode=mode) - # The rewrite should then be applied, and remove Alloc - assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) - assert any(isinstance(node.op, Assert) for node in f.maker.fgraph.toposort()) - - with pytest.raises(AssertionError): - f() + has_alloc = any( + isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort() + ) + if shape_unsafe: + assert not has_alloc + # Error raised by SpecifyShape that is introduced due to static shape inference + with pytest.raises( + AssertionError, + match="SpecifyShape: dim 0 of input has shape 3, expected 6.", + ): + f() + else: + assert has_alloc + # Error raised by Alloc Op + with pytest.raises( + ValueError, + match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)", + ): + f() good_x_val = self.rng.standard_normal((6, 7)) x.set_value(good_x_val) From d15747d89e412ad1ca59d3e71c20d733eaff9881 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 6 Dec 2023 16:29:19 +0000 Subject: [PATCH 05/14] Apply useless blockwise rewrite when there are only dummy batch dims Also extend eager rewrite to more Ops The Blockwise MatrixInverse grad test became more sensitive in float32, because desired stabilization rewrites (mainly `inv_as_solve`) that target Dot of Blockwise{MatrixInverse} are now triggered in the default blockwise grad but not in the non-default non-blockwise grad --- pytensor/tensor/blockwise.py | 10 +++--- pytensor/tensor/rewriting/blockwise.py | 43 ++++++++++++++++++++++---- tests/tensor/test_blockwise.py | 2 +- 3 files changed, 43 insertions(+), 12 deletions(-) diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index e7c34769e3..8c0c1587a9 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -163,8 +163,8 @@ def make_node(self, *inputs): return Apply(self, batched_inputs, batched_outputs) - def _batch_ndim_from_outputs(self, outputs: Sequence[TensorVariable]) -> int: - return cast(int, outputs[0].type.ndim - len(self.outputs_sig[0])) + def batch_ndim(self, node: Apply) -> int: + return cast(int, node.outputs[0].type.ndim - len(self.outputs_sig[0])) def infer_shape( self, fgraph, node, input_shapes @@ -172,7 +172,7 @@ def infer_shape( from pytensor.tensor import broadcast_shape from pytensor.tensor.shape import Shape_i - batch_ndims = self._batch_ndim_from_outputs(node.outputs) + batch_ndims = self.batch_ndim(node) core_dims: dict[str, Any] = {} batch_shapes = [] for input_shape, sig in zip(input_shapes, self.inputs_sig): @@ -278,7 +278,7 @@ def L_op(self, inputs, outs, ograds): return new_rval # Sum out the broadcasted dimensions - batch_ndims = self._batch_ndim_from_outputs(outs) + batch_ndims = self.batch_ndim(outs[0].owner) batch_shape = outs[0].type.shape[:batch_ndims] for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)): if isinstance(rval[i].type, (NullType, DisconnectedType)): @@ -320,7 +320,7 @@ def core_func(*inner_inputs): return self._gufunc def _check_runtime_broadcast(self, node, inputs): - batch_ndim = self._batch_ndim_from_outputs(node.outputs) + batch_ndim = self.batch_ndim(node) for dims_and_bcast in zip( *[ diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 101aeec368..69cddc595d 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -2,9 +2,15 @@ from pytensor.graph import node_rewriter from pytensor.graph.replace import vectorize_node from pytensor.graph.rewriting.basic import copy_stack_trace, out2in +from pytensor.tensor.basic import Alloc, ARange, shape_padleft from pytensor.tensor.blockwise import Blockwise -from pytensor.tensor.math import _matrix_matrix_matmul -from pytensor.tensor.rewriting.basic import register_canonicalize +from pytensor.tensor.math import Dot +from pytensor.tensor.rewriting.basic import ( + register_canonicalize, + register_specialize, + register_stabilize, +) +from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedSubtensor, Subtensor @node_rewriter([Blockwise]) @@ -29,8 +35,17 @@ def local_useless_unbatched_blockwise(fgraph, node): op = node.op inputs = node.inputs - if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0: - return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs) + batch_ndims = node.op.batch_ndim(node) + if all(all(inp.type.broadcastable[:batch_ndims]) for inp in inputs): + if batch_ndims: + # Remove dummy batch dims + axis = tuple(range(batch_ndims)) + inputs = [inp.squeeze(axis) for inp in inputs] + new_outs = op.core_op.make_node(*inputs).outputs + if batch_ndims: + # Reintroduce dummy batch dims + new_outs = [shape_padleft(out, batch_ndims) for out in new_outs] + return copy_stack_trace(node.outputs, new_outs) # We register this rewrite late, so that other rewrites need only target Blockwise Ops @@ -46,6 +61,22 @@ def local_useless_unbatched_blockwise(fgraph, node): # Avoid redundant cases early on for Ops whose default form is not Blockwised @register_canonicalize -@node_rewriter(tracks=[_matrix_matrix_matmul]) +@register_stabilize +@register_specialize +@node_rewriter(tracks=[Blockwise]) def local_eager_useless_unbatched_blockwise(fgraph, node): - return local_useless_unbatched_blockwise.fn(fgraph, node) + if isinstance( + node.op.core_op, + ( + # Many Dot-related rewrites (e.g., all of BlasOpt) happen before specialize + Dot, + # These Ops can't always be trivially vectorized at runtime, + # Since their inputs may imply non-rectangular shapes. + Alloc, + ARange, + Subtensor, + AdvancedSubtensor, + AdvancedIncSubtensor, + ), + ): + return local_useless_unbatched_blockwise.fn(fgraph, node) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index a0533143ed..06045f11f6 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -293,7 +293,7 @@ def test_grad(self): pt_out, np_out, rtol=1e-7 if config.floatX == "float64" else 1e-5, - atol=1e-6 if config.floatX == "float64" else 1e-5, + atol=1e-6 if config.floatX == "float64" else 1e-4, ) From 54310803dd208bff174b2dfb9b2e0d995d6ed3f1 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 5 Dec 2023 17:17:50 +0100 Subject: [PATCH 06/14] Faster perform method for matmul Also return matmul for respective vectorize of dot, to avoid creating redundant Blockwise Ops --- pytensor/tensor/blockwise.py | 17 ++++++++++++++--- pytensor/tensor/math.py | 18 ++++++++++++++++-- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index 8c0c1587a9..d4cd5152e7 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -58,6 +58,7 @@ def __init__( core_op: Op, signature: Optional[str] = None, name: Optional[str] = None, + gufunc_spec: Optional[tuple[str, int, int]] = None, **kwargs, ): """ @@ -69,7 +70,12 @@ def __init__( signature Generalized universal function signature, e.g., (m,n),(n)->(m) for vectorized matrix-vector multiplication - + gufunc: tuple, Optional + Tuple containing: + 1. String import path for a numpy/scipy function (e.g., "numpy.matmul", "scipy.special.softmax") + that implements the blockwised operation of the scalar op. + 2 Number of inputs of the function + 3 Number of outputs of the function """ if isinstance(core_op, Blockwise): raise TypeError("Core Op is already a Blockwise") @@ -85,6 +91,7 @@ def __init__( self.signature = signature self.name = name self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature) + self.gufunc_spec = gufunc_spec self._gufunc = None super().__init__(**kwargs) @@ -297,10 +304,14 @@ def L_op(self, inputs, outs, ograds): return rval def _create_gufunc(self, node): - if hasattr(self.core_op, "gufunc_spec"): - self._gufunc = import_func_from_string(self.core_op.gufunc_spec[0]) + gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None) + + if gufunc_spec is not None: + self._gufunc = import_func_from_string(gufunc_spec[0]) if self._gufunc: return self._gufunc + else: + raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}") n_outs = len(self.outputs_sig) core_node = self._create_dummy_core_node(node.inputs) diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 0f035272af..7d1e32ba21 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -9,6 +9,7 @@ from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply, Variable from pytensor.graph.op import Op +from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.link.c.type import Generic @@ -25,7 +26,7 @@ stack, switch, ) -from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.shape import shape, specify_broadcastable from pytensor.tensor.type import ( @@ -2873,7 +2874,11 @@ def logsumexp(x, axis=None, keepdims=False): return log(sum(exp(x), axis=axis, keepdims=keepdims)) -_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)") +_matrix_matrix_matmul = Blockwise( + _dot, + signature="(m,k),(k,n)->(m,n)", + gufunc_spec=("numpy.matmul", 2, 1), +) def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None): @@ -2937,6 +2942,15 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None return out +@_vectorize_node.register(Dot) +def vectorize_node_to_matmul(op, node, batched_x, batched_y): + old_x, old_y = node.inputs + if old_x.type.ndim == 2 and old_y.type.ndim == 2: + return matmul(batched_x, batched_y).owner + else: + return vectorize_node_fallback(op, node, batched_x, batched_y) + + __all__ = [ "max_and_argmax", "max", From 1e687adb711eb777d006946e963cfbbdd25a1288 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 6 Dec 2023 09:55:02 +0000 Subject: [PATCH 07/14] Expand batched_vector_b_solve_to_matrix rewrite It now supports an arbitrary number of batched dimensions of b, by raveling them together --- pytensor/tensor/rewriting/linalg.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 475e454037..bc3eef6fca 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -138,14 +138,12 @@ def generic_solve_to_solve_triangular(fgraph, node): ] -@register_stabilize @register_specialize @node_rewriter([Blockwise]) def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T `a` must have no batched dimensions, while `b` can have arbitrary batched dimensions. - Only the last two dimensions of `b` and the output are swapped. """ core_op = node.op.core_op @@ -175,8 +173,17 @@ def batched_vector_b_solve_to_matrix_b_solve(fgraph, node): new_core_op = type(core_op)(**props) matrix_b_solve = Blockwise(new_core_op) + # Ravel any batched dims + original_b_shape = tuple(b.shape) + if len(original_b_shape) > 2: + b = b.reshape((-1, original_b_shape[-1])) + # Apply the rewrite - new_solve = _T(matrix_b_solve(a, _T(b))) + new_solve = matrix_b_solve(a, b.T).T + + # Unravel any batched dims + if len(original_b_shape) > 2: + new_solve = new_solve.reshape(original_b_shape) old_solve = node.outputs[0] copy_stack_trace(old_solve, new_solve) From 820928ffb463488eb3eca4d76bb06a906c5064ce Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 5 Dec 2023 15:19:01 +0100 Subject: [PATCH 08/14] Simplify BatchedDot implementation The Op now always expects rank 3 inputs, and any dimshuffles are added explicitly by the helper function --- pytensor/link/jax/dispatch/nlinalg.py | 4 +- pytensor/link/numba/dispatch/basic.py | 2 + pytensor/tensor/blas.py | 225 +++++++++----------------- tests/link/jax/test_nlinalg.py | 9 -- tests/link/numba/test_basic.py | 16 +- 5 files changed, 88 insertions(+), 168 deletions(-) diff --git a/pytensor/link/jax/dispatch/nlinalg.py b/pytensor/link/jax/dispatch/nlinalg.py index 6f6467cff7..567b4407ab 100644 --- a/pytensor/link/jax/dispatch/nlinalg.py +++ b/pytensor/link/jax/dispatch/nlinalg.py @@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs): def batched_dot(a, b): if a.shape[0] != b.shape[0]: raise TypeError("Shapes must match in the 0-th dimension") - if a.ndim == 2 or b.ndim == 2: - return jnp.einsum("n...j,nj...->n...", a, b) - return jnp.einsum("nij,njk->nik", a, b) + return jnp.matmul(a, b) return batched_dot diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index ab7054ccaf..9c9c800b92 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs): @numba_njit def batched_dot(x, y): + # Numba does not support 3D matmul + # https://github.com/numba/numba/issues/3804 shape = x.shape[:-1] + y.shape[2:] z0 = np.empty(shape, dtype=dtype) for i in range(z0.shape[0]): diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index 78a80bd323..301cc5d199 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -98,10 +98,11 @@ from pytensor.printing import FunctionPrinter, pprint from pytensor.scalar import bool as bool_t from pytensor.tensor import basic as at +from pytensor.tensor.basic import expand_dims from pytensor.tensor.blas_headers import blas_header_text, blas_header_version from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import add, mul, neg, sub -from pytensor.tensor.shape import specify_broadcastable +from pytensor.tensor.shape import shape_padright, specify_broadcastable from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor from pytensor.utils import memoize @@ -1637,48 +1638,53 @@ def c_code_cache_version(self): class BatchedDot(COp): """ - Computes the batched dot product of two variables: + Computes a batch matrix-matrix dot with tensor3 variables batched_dot(a, b)[i] = dot(a[i], b[i]) """ __props__ = () + gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)" - def make_node(self, *inputs): - inputs = list(map(at.as_tensor_variable, inputs)) + def make_node(self, x, y): + x = at.as_tensor_variable(x) + y = at.as_tensor_variable(y) - if any(not isinstance(i.type, DenseTensorType) for i in inputs): + if not ( + isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType) + ): raise NotImplementedError("Only dense tensor types are supported") - if len(inputs) != 2: - raise TypeError(f"Two arguments required, but {len(inputs)} given.") - if inputs[0].ndim not in (2, 3): + if not (x.type.ndim == 3 and y.type.ndim == 3): raise TypeError( - "Input 0 (0-indexed)" - f" must have ndim of 2 or 3, {int(inputs[0].ndim)} given. Consider" - " calling batched_dot instead." - ) - if inputs[1].ndim not in (2, 3): - raise TypeError( - "Input 1 (0-indexed)" - f" must have ndim of 2 or 3, {int(inputs[1].ndim)} given. Consider" - " calling batched_dot instead." + f"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. " + "Consider calling batched_dot instead." ) - dtype = pytensor.scalar.upcast(*[input.type.dtype for input in inputs]) - # upcast inputs to common dtype if needed - upcasted_inputs = [at.cast(input, dtype) for input in inputs] - out_shape = ( - ( - 1 - if inputs[0].type.shape[0] == 1 or inputs[1].type.shape[0] == 1 - else None, - ) - + inputs[0].type.shape[1:-1] - + inputs[1].type.shape[2:] - ) - out_shape = tuple(1 if s == 1 else None for s in out_shape) - return Apply(self, upcasted_inputs, [tensor(dtype=dtype, shape=out_shape)]) + def extract_static_dim(dim_x, dim_y): + dims = {dim_x, dim_y} - {None} + if len(dims) > 1: + # BatchedDot doesn't allow broadcasting + raise ValueError( + f"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}" + ) + elif not dims: + return None + else: + return dims.pop() + + x_batch_dim, x_row_dim, x_sum_dim = x.type.shape + y_batch_dim, y_sum_dim, y_col_dim = y.type.shape + batch_dim = extract_static_dim(x_batch_dim, y_batch_dim) + # Raise if static sum dimensions do not match + _ = extract_static_dim(x_sum_dim, y_sum_dim) + out_shape = (batch_dim, x_row_dim, y_col_dim) + + # Change dtype if needed + dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype) + x, y = at.cast(x, dtype), at.cast(y, dtype) + out = tensor(dtype=dtype, shape=out_shape) + return Apply(self, [x, y], [out]) def perform(self, node, inp, out): x, y = inp @@ -1690,11 +1696,7 @@ def perform(self, node, inp, out): f" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]." ) - shape = self.infer_shape(None, node, [i.shape for i in inp])[0] - dtype = node.outputs[0].dtype - z0 = z[0] = np.empty(shape, dtype=dtype) - for i in range(z0.shape[0]): - z0[i] = np.dot(x[i], y[i]) + z[0] = np.matmul(x, y) def c_support_code(self, **kwargs): batch_gemm_defn = """ @@ -1792,14 +1794,6 @@ def c_lib_dirs(self, **kwargs): def c_header_dirs(self, **kwargs): return ldflags(libs=False, include_dir=True) - def c_code_cleanup(self, node, name, inputs, outputs, sub): - return """ - // clean up views - Py_XDECREF(xs); xs = 0; - Py_XDECREF(ys); ys = 0; - Py_XDECREF(zs); zs = 0; - """ - def c_code(self, node, name, inp, out, sub): _x, _y = inp (_z,) = out @@ -1832,12 +1826,11 @@ def contiguous(var, ndim): ) # generate code to allocate output based on runtime input shapes - z_dims = [f"PyArray_DIMS({_x})[0]"] - if x_ndim == 3: - z_dims.append(f"PyArray_DIMS({_x})[1]") - if y_ndim == 3: - z_dims.append(f"PyArray_DIMS({_y})[2]") - assert len(z_dims) == z_ndim + z_dims = [ + f"PyArray_DIMS({_x})[0]", + f"PyArray_DIMS({_x})[1]", + f"PyArray_DIMS({_y})[2]", + ] z_shape_correct = " && ".join( "PyArray_DIMS(%s)[%i] == %s" % (_z, i, dim) for i, dim in enumerate(z_dims) @@ -1880,76 +1873,26 @@ def contiguous(var, ndim): ) contiguate = "\n".join(contiguate) - def c_dimshuffle(newname, oldname, shape): - _fail = fail - _shape = ", ".join( - "1" if axis is None else "PyArray_DIMS(%s)[%i]" % (oldname, axis) - for axis in shape - ) - return ( - """{ - npy_intp dims[3] = {%(_shape)s}; - PyArray_Dims newshape = {dims, 3}; - %(newname)s = (PyArrayObject*)PyArray_Newshape(%(oldname)s, &newshape, NPY_ANYORDER); - if (!%(newname)s) - %(_fail)s - // make sure we didn't accidentally copy - assert(PyArray_DATA(%(oldname)s) == PyArray_DATA(%(newname)s)); - }""" - % locals() - ) - - # create tensor3 views for any of x, y, z that are not tensor3, so that - # we only need to implement the tensor3-tensor3 batched dot product. - # xs, ys and zs will point to these views, or to the original array if - # it was already tensor3. - # in the latter case, we artificially increase the reference count of - # the original array so that the c_code_cleanup method can decref them - # all indiscriminately. - upcast = [] - if x_ndim == 3: - upcast.append("xs = %(_x)s; Py_XINCREF(xs);") - elif x_ndim == 2: - upcast.append(c_dimshuffle("xs", _x, (0, None, 1))) - if y_ndim == 3: - upcast.append("ys = %(_y)s; Py_XINCREF(ys);") - elif y_ndim == 2: - upcast.append(c_dimshuffle("ys", _y, (0, 1, None))) - if z_ndim == 3: - upcast.append("zs = %(_z)s; Py_XINCREF(zs);") - else: - upcast.append( - c_dimshuffle( - "zs", - _z, - (0, None if x_ndim == 2 else 1, None if y_ndim == 2 else 1), - ) - ) - upcast = "\n".join(upcast) % locals() - return ( """ int type_num = PyArray_DESCR(%(_x)s)->type_num; int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes - // xs, ys, zs will point to views onto %(_x)s, %(_y)s, %(_z)s - PyArrayObject *xs = 0, *ys = 0, *zs = 0; - - if (PyArray_NDIM(%(_x)s) != %(x_ndim)s) { + if (PyArray_NDIM(%(_x)s) != 3) { PyErr_Format(PyExc_NotImplementedError, - "rank(x) != %(x_ndim)s. rank(x) is %%d.", + "rank(x) != 3. rank(x) is %%d.", PyArray_NDIM(%(_x)s)); %(fail)s; } - if (PyArray_NDIM(%(_y)s) != %(y_ndim)s) { + if (PyArray_NDIM(%(_y)s) != 3) { PyErr_Format(PyExc_NotImplementedError, - "rank(y) != %(y_ndim)s. rank(y) is %%d.", + "rank(y) != 3. rank(y) is %%d.", PyArray_NDIM(%(_y)s)); %(fail)s; } - if (%(_z)s && PyArray_NDIM(%(_z)s) != %(z_ndim)s) { + if (%(_z)s && PyArray_NDIM(%(_z)s) != 3) { PyErr_Format(PyExc_NotImplementedError, - "rank(z) != %(z_ndim)s. rank(z) is %%d.", + "rank(z) != 3. rank(z) is %%d.", PyArray_NDIM(%(_z)s)); %(fail)s; } @@ -1958,36 +1901,32 @@ def c_dimshuffle(newname, oldname, shape): %(allocate)s // reallocate any noncontiguous arrays or arrays with invalid strides %(contiguate)s - // add dims to make sure everything is tensor3 - %(upcast)s - // from here on, use xs, ys and zs as they are tensor3 and share memory - // with the original %(_x)s, %(_y)s and %(_z)s arrays. - if ((PyArray_DESCR(xs)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(xs)->type_num != NPY_FLOAT)) + if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT)) {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;} - if ((PyArray_DESCR(ys)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(ys)->type_num != NPY_FLOAT)) + if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT)) {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;} - if ((PyArray_DESCR(zs)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(zs)->type_num != NPY_FLOAT)) + if ((PyArray_DESCR(%(_z)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_z)s)->type_num != NPY_FLOAT)) {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} - if ((PyArray_DESCR(xs)->type_num != PyArray_DESCR(ys)->type_num) - ||(PyArray_DESCR(xs)->type_num != PyArray_DESCR(zs)->type_num)) + if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num) + ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_z)s)->type_num)) { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; } switch (type_num) { case NPY_FLOAT: - if (batch_gemm(sgemm_, type_size, xs, ys, zs)) { + if (batch_gemm(sgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) { %(fail)s; } break; case NPY_DOUBLE: - if (batch_gemm(dgemm_, type_size, xs, ys, zs)) { + if (batch_gemm(dgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) { %(fail)s; } break; @@ -1999,32 +1938,14 @@ def c_dimshuffle(newname, oldname, shape): def c_code_cache_version(self): from pytensor.tensor.blas_headers import blas_header_version - return (4, blas_header_version()) + return (5, blas_header_version()) def grad(self, inp, grads): x, y = inp (gz,) = grads - xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim - # grad is a vector, so x is a matrix and y is a matrix - if gdim == 1: - xgrad = gz.dimshuffle(0, "x") * y - ygrad = gz.dimshuffle(0, "x") * x - - # x is a matrix, y is a tensor3, grad is a matrix - elif xdim == 2 and ydim == 3: - xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1)) - ygrad = x.dimshuffle(0, 1, "x") * gz.dimshuffle(0, "x", 1) - - # x is a tensor3, y is a matrix, grad is a matrix - elif xdim == 3 and ydim == 2: - xgrad = gz.dimshuffle(0, 1, "x") * y.dimshuffle(0, "x", 1) - ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) - - # x is a tensor3, y is a tensor3, grad is a tensor3 - elif xdim == ydim == 3: - xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1)) - ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) + xgrad = batched_dot(gz, y.dimshuffle(0, 2, 1)) + ygrad = batched_dot(x.dimshuffle(0, 2, 1), gz) # If x or y contain broadcastable dimensions but only one of # them know that a matching dimensions is broadcastable, the @@ -2105,6 +2026,7 @@ def R_op(self, inputs, eval_points): + " to BatchedDot.R_op should have the same shape, but " f"their shapes are {input_values[i].shape} and {eval_point_values[i].shape}, respectively" ) + if eval_points[0]: t1 = self(eval_points[0], inputs[1]) if eval_points[1]: @@ -2118,9 +2040,6 @@ def R_op(self, inputs, eval_points): return [t2] def infer_shape(self, fgraph, node, shapes): - for shape_ in shapes: - if len(shape_) not in (2, 3): - raise NotImplementedError() xshp, yshp = shapes return [xshp[:-1] + yshp[2:]] @@ -2157,14 +2076,24 @@ def batched_dot(a, b): elif b.ndim == 0: raise TypeError("b must have at least one (batch) axis") elif a.ndim == 1: - return a.dimshuffle(*([0] + ["x"] * (b.ndim - 1))) * b + return shape_padright(a, (b.ndim - 1)) * b elif b.ndim == 1: - return a * b.dimshuffle(*([0] + ["x"] * (a.ndim - 1))) + return a * shape_padright(b, (a.ndim - 1)) elif a.ndim > 3 or b.ndim > 3: return batched_tensordot(a, b, [[a.ndim - 1], [np.maximum(1, b.ndim - 2)]]) else: - # avoid circular import - return _batched_dot(a, b) + # If either a or b is a batched vector, expand dims and later squeeze them + expanded_axis = [] + if a.ndim == 2: + a = expand_dims(a, axis=1) + expanded_axis.append(1) + if b.ndim == 2: + b = expand_dims(b, axis=2) + expanded_axis.append(2) + out = _batched_dot(a, b) + if expanded_axis: + out = out.squeeze(axis=expanded_axis) + return out def batched_tensordot(x, y, axes=2): diff --git a/tests/link/jax/test_nlinalg.py b/tests/link/jax/test_nlinalg.py index 98bfbb610c..5aec8f88c2 100644 --- a/tests/link/jax/test_nlinalg.py +++ b/tests/link/jax/test_nlinalg.py @@ -43,15 +43,6 @@ def test_jax_BatchedDot(): with pytest.raises(TypeError): pytensor_jax_fn(*inputs) - # matrix . matrix - a = matrix("a") - a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3)) - b = matrix("b") - b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3)) - out = at_blas.BatchedDot()(a, b) - fgraph = FunctionGraph([a, b], [out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - def test_jax_basic_multiout(): rng = np.random.default_rng(213234) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 7632aa6d33..92ab879e5c 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -843,23 +843,23 @@ def test_Softplus(x, exc): [ ( set_test_value( - at.dmatrix(), - rng.random(size=(3, 3)).astype("float64"), + at.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), ), set_test_value( - at.dmatrix(), - rng.random(size=(3, 3)).astype("float64"), + at.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), ), None, ), ( set_test_value( - at.dmatrix(), - rng.random(size=(3, 3)).astype("float64"), + at.dtensor3(), + rng.random(size=(2, 3, 3)).astype("float64"), ), set_test_value( - at.lmatrix(), - rng.poisson(size=(3, 3)).astype("int64"), + at.ltensor3(), + rng.poisson(size=(2, 3, 3)).astype("int64"), ), None, ), From 8c58a2e407046c57dc695528283b0fd95d74cc22 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 5 Dec 2023 17:08:08 +0100 Subject: [PATCH 09/14] Specialize matmul to batched dot --- pytensor/tensor/rewriting/blas.py | 41 +++++++++++++++++++----- tests/tensor/rewriting/test_blas.py | 48 +++++++++++++++++++++++++++++ tests/tensor/test_blockwise.py | 8 ++++- 3 files changed, 88 insertions(+), 9 deletions(-) create mode 100644 tests/tensor/rewriting/test_blas.py diff --git a/pytensor/tensor/rewriting/blas.py b/pytensor/tensor/rewriting/blas.py index a310cb5837..7434fd7e1c 100644 --- a/pytensor/tensor/rewriting/blas.py +++ b/pytensor/tensor/rewriting/blas.py @@ -59,6 +59,8 @@ import numpy as np +from pytensor.tensor.rewriting.basic import register_specialize + try: import numpy.__config__ # noqa @@ -79,12 +81,12 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError -from pytensor.printing import debugprint from pytensor.tensor import basic as at from pytensor.tensor.blas import ( Dot22, _dot22, _dot22scalar, + batched_dot, gemm_inplace, gemm_no_inplace, gemv_inplace, @@ -94,7 +96,7 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.math import Dot, add, mul, neg, sub +from pytensor.tensor.math import Dot, _matrix_matrix_matmul, add, mul, neg, sub from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift from pytensor.tensor.type import ( DenseTensorType, @@ -899,9 +901,32 @@ def local_dot22_to_dot22scalar(fgraph, node): ) -# from opt import register_specialize, register_canonicalize -# @register_specialize -@node_rewriter([sub, add]) -def local_print_as_we_go_along(fgraph, node): - if node.op in (sub, add): - debugprint(node) +@register_specialize +@node_rewriter([_matrix_matrix_matmul]) +def specialize_matmul_to_batched_dot(fgraph, node): + """Rewrite Matmul (Blockwise matrix-matrix) without implicit broadcasted batched dimension as BatchedDot. + + TODO: Do the same for Blockwise BatchedDot + """ + x, y = node.inputs + + # BatchedDot does not allow implicit broadcasting of the batch dimensions + # We do not want to explicitly broadcast as it may result in huge arrays + if x.type.broadcastable[:-2] != y.type.broadcastable[:-2]: + return None + + x_shape = tuple(x.shape) + y_shape = tuple(y.shape) + if len(x_shape) > 3: + # If we have more than one batch dim, ravel it + x = x.reshape((-1, x_shape[-2], x_shape[-1])) + y = y.reshape((-1, y_shape[-2], y_shape[-1])) + + new_out = batched_dot(x, y) + + if len(x_shape) > 3: + # And then unravel it + new_out = new_out.reshape((*x_shape[:-2], x_shape[-2], y_shape[-1])) + + copy_stack_trace(node.outputs, [new_out]) + return [new_out] diff --git a/tests/tensor/rewriting/test_blas.py b/tests/tensor/rewriting/test_blas.py new file mode 100644 index 0000000000..efd18c3831 --- /dev/null +++ b/tests/tensor/rewriting/test_blas.py @@ -0,0 +1,48 @@ +import numpy as np +import pytest + +from pytensor import function +from pytensor.compile import get_default_mode +from pytensor.tensor import matmul, tensor, vectorize +from pytensor.tensor.blas import BatchedDot +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot + + +@pytest.mark.parametrize("valid_case", (True, False)) +def test_specialize_matmul_to_batched_dot(valid_case): + signature = BatchedDot.gufunc_signature + rewrite = specialize_matmul_to_batched_dot.__name__ + + def core_pt(x, y): + return matmul(x, y) + + def core_np(x, y): + return np.matmul(x, y) + + x = tensor(shape=(7, 5, 3, 3)) + if valid_case: + y = tensor(shape=(7, 5, 3, 3)) + else: + y = tensor(shape=(5, 3, 3)) + + vectorize_pt = function( + [x, y], + vectorize(core_pt, signature=signature)(x, y), + mode=get_default_mode().including(rewrite), + ) + blocwkise_node = any( + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes + ) + if valid_case: + assert not blocwkise_node + else: + assert blocwkise_node + + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) + y_test = np.random.normal(size=y.type.shape).astype(y.type.dtype) + vectorize_np = np.vectorize(core_np, signature=signature) + np.testing.assert_allclose( + vectorize_pt(x_test, y_test), + vectorize_np(x_test, y_test), + ) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 06045f11f6..9500fec5f8 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -6,6 +6,7 @@ import pytensor from pytensor import config, function +from pytensor.compile import get_mode from pytensor.gradient import grad from pytensor.graph import Apply, Op from pytensor.graph.replace import vectorize_node @@ -13,6 +14,7 @@ from pytensor.tensor import diagonal, log, tensor from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.nlinalg import MatrixInverse +from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular from pytensor.tensor.utils import _parse_gufunc_signature @@ -45,7 +47,11 @@ def check_blockwise_runtime_broadcasting(mode): b = tensor("b", shape=(None, 5, 3)) out = a @ b - fn = function([a, b], out, mode=mode) + fn = function( + [a, b], + out, + mode=get_mode(mode).excluding(specialize_matmul_to_batched_dot.__name__), + ) assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise) for valid_test_values in [ From b9f6d1b87de6340c0b7a863241586c955269511f Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 5 Dec 2023 20:45:44 +0100 Subject: [PATCH 10/14] Vectorize ExtractDiag Also adds better static shapes --- pytensor/tensor/basic.py | 21 ++++++++++++++++++++- tests/tensor/test_basic.py | 26 +++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 434f8b85e7..498803c0f7 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -26,6 +26,7 @@ from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import Op +from pytensor.graph.replace import _vectorize_node from pytensor.graph.rewriting.db import EquilibriumDB from pytensor.graph.type import HasShape, Type from pytensor.link.c.op import COp @@ -3497,10 +3498,17 @@ def make_node(self, x): if x.ndim < 2: raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x) + + out_shape = [ + st_dim + for i, st_dim in enumerate(x.type.shape) + if i not in (self.axis1, self.axis2) + ] + [None] + return Apply( self, [x], - [x.type.clone(dtype=x.dtype, shape=(None,) * (x.ndim - 1))()], + [x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()], ) def perform(self, node, inputs, outputs): @@ -3601,6 +3609,17 @@ def diagonal(a, offset=0, axis1=0, axis2=1): return ExtractDiag(offset, axis1, axis2)(a) +@_vectorize_node.register(ExtractDiag) +def vectorize_extract_diag(op: ExtractDiag, node, batched_x): + batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim + return diagonal( + batched_x, + offset=op.offset, + axis1=op.axis1 + batched_ndims, + axis2=op.axis2 + batched_ndims, + ).owner + + def trace(a, offset=0, axis1=0, axis2=1): """ Returns the sum along diagonals of the array. diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 2c2b82d1b5..3ce5ffce63 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -20,7 +20,7 @@ from pytensor.misc.safe_asarray import _asarray from pytensor.raise_op import Assert from pytensor.scalar import autocast_float, autocast_float_as -from pytensor.tensor import NoneConst +from pytensor.tensor import NoneConst, vectorize from pytensor.tensor.basic import ( Alloc, AllocEmpty, @@ -88,6 +88,7 @@ vertical_stack, zeros_like, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import dense_dot @@ -4517,3 +4518,26 @@ def test_trace(): trace(x, offset=-1, axis1=0, axis2=-1).eval(), np.trace(x_val, offset=-1, axis1=0, axis2=-1), ) + + +def test_vectorize_extract_diag(): + signature = "(a1,b,a2)->(b,a)" + + def core_pt(x): + return at.diagonal(x, offset=1, axis1=0, axis2=2) + + def core_np(x): + return np.diagonal(x, offset=1, axis1=0, axis2=2) + + x = tensor(shape=(5, 5, 5, 5)) + vectorize_pt = function([x], vectorize(core_pt, signature=signature)(x)) + assert not any( + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes + ) + + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) + vectorize_np = np.vectorize(core_np, signature=signature) + np.testing.assert_allclose( + vectorize_pt(x_test), + vectorize_np(x_test), + ) From 020cb46488b8e29ed6f81caa853078c65c5c7231 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 5 Dec 2023 20:47:21 +0100 Subject: [PATCH 11/14] Vectorize Subtensor without batched indices --- pytensor/tensor/subtensor.py | 17 +++++++++++++ tests/tensor/test_subtensor.py | 44 +++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index c05e965bf8..de0862f443 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -13,6 +13,7 @@ from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply, Constant, Variable from pytensor.graph.op import Op +from pytensor.graph.replace import _vectorize_node from pytensor.graph.type import Type from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp @@ -22,6 +23,7 @@ from pytensor.scalar.basic import ScalarConstant from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero +from pytensor.tensor.blockwise import vectorize_node_fallback from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError from pytensor.tensor.math import clip @@ -1283,6 +1285,21 @@ def _process(self, idxs, op_inputs, pstate): pprint.assign(Subtensor, SubtensorPrinter()) +# TODO: Implement similar vectorize for Inc/SetSubtensor +@_vectorize_node.register(Subtensor) +def vectorize_subtensor(op: Subtensor, node, batch_x, *batch_idxs): + """Rewrite subtensor with non-batched indexes as another Subtensor with prepended empty slices.""" + + # TODO: Vectorize Subtensor with non-slice batched indexes as AdvancedSubtensor + if any(batch_inp.type.ndim > 0 for batch_inp in batch_idxs): + return vectorize_node_fallback(op, node, batch_x, *batch_idxs) + + old_x, *_ = node.inputs + batch_ndims = batch_x.type.ndim - old_x.type.ndim + new_idx_list = (slice(None),) * batch_ndims + op.idx_list + return Subtensor(new_idx_list).make_node(batch_x, *batch_idxs) + + def set_subtensor(x, y, inplace=False, tolerate_inplace_aliasing=False): """ Return x with the given subtensor overwritten by y. diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index b5952cc49e..9ee39a4a98 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -9,6 +9,7 @@ import pytensor import pytensor.scalar as scal import pytensor.tensor.basic as at +from pytensor import function from pytensor.compile import DeepCopyOp, shared from pytensor.compile.io import In from pytensor.configdefaults import config @@ -16,7 +17,8 @@ from pytensor.graph.rewriting.utils import is_same_graph from pytensor.printing import pprint from pytensor.scalar.basic import as_scalar -from pytensor.tensor import get_vector_length +from pytensor.tensor import get_vector_length, vectorize +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.math import exp, isinf from pytensor.tensor.math import sum as at_sum @@ -2709,3 +2711,43 @@ def test_static_shapes(x_shape, indices, expected): x = at.tensor(dtype="float64", shape=x_shape) y = x[indices] assert y.type.shape == expected + + +def test_vectorize_subtensor_without_batch_indices(): + signature = "(t1,t2,t3),()->(t1,t3)" + + def core_fn(x, start): + return x[:, start, :] + + x = tensor(shape=(11, 7, 5, 3)) + start = tensor(shape=(), dtype="int") + vectorize_pt = function( + [x, start], vectorize(core_fn, signature=signature)(x, start) + ) + assert not any( + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes + ) + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) + start_test = np.random.randint(0, x.type.shape[-2]) + vectorize_np = np.vectorize(core_fn, signature=signature) + np.testing.assert_allclose( + vectorize_pt(x_test, start_test), + vectorize_np(x_test, start_test), + ) + + # If we vectorize start, we should get a Blockwise that still works + x = tensor(shape=(11, 7, 5, 3)) + start = tensor(shape=(11,), dtype="int") + vectorize_pt = function( + [x, start], vectorize(core_fn, signature=signature)(x, start) + ) + assert any( + isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes + ) + x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype) + start_test = np.random.randint(0, x.type.shape[-2], size=start.type.shape[0]) + vectorize_np = np.vectorize(core_fn, signature=signature) + np.testing.assert_allclose( + vectorize_pt(x_test, start_test), + vectorize_np(x_test, start_test), + ) From 50997a740b80dfe33faf2bb232ef106c94f236ea Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 5 Dec 2023 20:49:48 +0100 Subject: [PATCH 12/14] Add rewrite for Blockwise with Alloc inputs Also prevent Alloc from constant_folding when it's used by Elemwise and Blockwise to avoid creating useless large arrays --- pytensor/graph/basic.py | 6 +- pytensor/tensor/basic.py | 15 ++- pytensor/tensor/rewriting/blockwise.py | 123 ++++++++++++++++++++++- tests/tensor/rewriting/test_blockwise.py | 86 +++++++++++++++- 4 files changed, 221 insertions(+), 9 deletions(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index a11aa57bdf..9b1399b72f 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1777,6 +1777,7 @@ def equal_computations( ys: list[Union[np.ndarray, Variable]], in_xs: Optional[list[Variable]] = None, in_ys: Optional[list[Variable]] = None, + strict_dtype=True, ) -> bool: """Checks if PyTensor graphs represent the same computations. @@ -1908,7 +1909,10 @@ def compare_nodes(nd_x, nd_y, common, different): if dx != dy: if isinstance(dx, Constant) and isinstance(dy, Constant): if not dx.equals(dy): - return False + if strict_dtype: + return False + elif not np.array_equal(dx.data, dy.data): + return False else: return False diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 498803c0f7..946660e431 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -42,6 +42,7 @@ as_tensor_variable, get_vector_length, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.shape import ( @@ -1658,16 +1659,22 @@ def do_constant_folding(self, fgraph, node): if not clients: return False - for client in clients: - if client[0] == "output": + for client, idx in clients: + if client == "output": # If the output is a constant, it will have to be deepcopied # each time the function is called. So we do not fold. return False + # Allow alloc to be lifted out of Elemwise before constant folding it + elif isinstance(client.op, Elemwise): + return None + # Same for Blockwise, unless it has no batch_dims + elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client): + return None elif ( # The following ops work inplace of their input id 0. - client[1] == 0 + idx == 0 and isinstance( - client[0].op, + client.op, ( # Ops that will work inplace on the Alloc. So if they # get constant_folded, they would copy the diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 69cddc595d..4cbfcdaa32 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -1,8 +1,10 @@ +from typing import Optional + from pytensor.compile.mode import optdb -from pytensor.graph import node_rewriter +from pytensor.graph import Constant, node_rewriter from pytensor.graph.replace import vectorize_node from pytensor.graph.rewriting.basic import copy_stack_trace, out2in -from pytensor.tensor.basic import Alloc, ARange, shape_padleft +from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.math import Dot from pytensor.tensor.rewriting.basic import ( @@ -80,3 +82,120 @@ def local_eager_useless_unbatched_blockwise(fgraph, node): ), ): return local_useless_unbatched_blockwise.fn(fgraph, node) + + +def _squeeze_left(x, stop_at_dim: Optional[int] = None): + """Squeeze any leading dims of `x` until a real dim or `stop_at_dim` (if not None) is reached.""" + x_dims = x.type.broadcastable + squeeze_ndim = len(x_dims) if all(x_dims) else x_dims.index(False) + if stop_at_dim is not None: + squeeze_ndim = min(squeeze_ndim, stop_at_dim) + if squeeze_ndim == 0: + return x + return x.squeeze(axis=tuple(range(squeeze_ndim))) + + +@register_specialize("shape_unsafe") +@node_rewriter([Blockwise]) +def local_blockwise_alloc(fgraph, node): + """Push Allocs from the inputs to the output of Blockwise Ops. + + BOp = Blockwise(Op, signature="(x),(x)->(x)") + BOp(vector, alloc(vector, 10, 5)) -> alloc(BOp)(vector, vector), 10, 5) + BOp(vector, alloc(scalar, 10, 5)) -> alloc(BOp)(vector, alloc(scalar, 5), 10, 5) + BOp(matrix, alloc(vector, 10, 5)) -> BOp(matrix, vector) + """ + + if not any(isinstance(inp.owner.op, Alloc) for inp in node.inputs if inp.owner): + return None + + op: Blockwise = node.op # type: ignore + + batch_ndim = op.batch_ndim(node) + if not batch_ndim: + return None + + new_inputs = [] + batch_shapes = [] + can_push_any_alloc = False + for inp, inp_sig in zip(node.inputs, op.inputs_sig): + if inp.owner and isinstance(inp.owner.op, Alloc): + # Push batch dims from Alloc + value, *shape = inp.owner.inputs + + # Check what to do with the value of the Alloc + squeezed_value = _squeeze_left(value, batch_ndim) + missing_ndim = len(shape) - value.type.ndim + if ( + ((1,) * missing_ndim + value.type.broadcastable)[batch_ndim:] + ) != inp.type.broadcastable[batch_ndim:]: + # We still need an Alloc for the core dims + core_shape = shape[batch_ndim:] + # And the batch dims of the squeezed value + squeezed_value_batch_ndim = squeezed_value.type.ndim - len(core_shape) + batch_shape = [ + 1 if broadcastable else dim + for broadcastable, dim in zip( + squeezed_value.type.broadcastable[:squeezed_value_batch_ndim], + tuple(squeezed_value.shape)[:squeezed_value_batch_ndim], + ) + ] + squeezed_value = alloc(squeezed_value, *batch_shape, *core_shape) + if squeezed_value.type.broadcastable == inp.type.broadcastable: + # We can't change anything about this Alloc input + new_inputs.append(inp) + continue + + # We can push batch dims of this Alloc input + batch_shapes.append( + tuple( + 1 if broadcastable else dim + for broadcastable, dim in zip( + inp.type.broadcastable, shape[:batch_ndim] + ) + ) + ) + new_inputs.append(squeezed_value) + can_push_any_alloc = True + + else: + # Nothing to do with this input other than removing dummy batch dims + new_inputs.append(_squeeze_left(inp, batch_ndim)) + + if not can_push_any_alloc: + return None + + new_outs = node.op.make_node(*new_inputs).outputs + + new_out_type = new_outs[0].type + old_out_type = node.outputs[0].type + if new_out_type.broadcastable != old_out_type.broadcastable: + # An Alloc is still needed to broadcast the new output to the original shape + # We pick the most parsimonious batch dim from the pushed Alloc + missing_ndim = old_out_type.ndim - new_out_type.ndim + batch_shape = ([1] * missing_ndim + list(new_outs[0].shape))[:batch_ndim] + for i, batch_dims in enumerate(zip(*batch_shapes)): # Transpose shape tuples + for batch_dim in batch_dims: + if batch_dim == 1: + continue + if isinstance(batch_dim, Constant): + # Give preference to Constants + batch_shape[i] = batch_dim + break + elif old_out_type.broadcastable[i]: + # Only use non Constant shapes if absolutely necessary + # Otherwise, we use the shape of the non-alloc output + batch_shape[i] = batch_dim + + copy_stack_trace(node.outputs, new_outs) + new_outs = [ + alloc( + new_out, + *batch_shape, + *tuple(new_out.shape)[batch_ndim - missing_ndim :], + ) + for new_out in new_outs + ] + assert new_outs[0].type.broadcastable == old_out_type.broadcastable + copy_stack_trace(node.outputs, new_outs) + return new_outs diff --git a/tests/tensor/rewriting/test_blockwise.py b/tests/tensor/rewriting/test_blockwise.py index 0b67eba197..d5ea6e2b4e 100644 --- a/tests/tensor/rewriting/test_blockwise.py +++ b/tests/tensor/rewriting/test_blockwise.py @@ -1,7 +1,10 @@ +from functools import partial + from pytensor import function -from pytensor.graph import FunctionGraph +from pytensor.graph import FunctionGraph, rewrite_graph +from pytensor.graph.basic import equal_computations from pytensor.scalar import log as scalar_log -from pytensor.tensor import matrix, tensor3 +from pytensor.tensor import add, alloc, matrix, tensor, tensor3 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.nlinalg import MatrixPinv @@ -36,3 +39,82 @@ def test_useless_unbatched_blockwise(): fn = function([x], out, mode="FAST_COMPILE") assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise) assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv) + + +def test_blockwise_alloc(): + rewrite = partial( + rewrite_graph, + include=("ShapeOpt", "specialize"), + exclude=("local_useless_unbatched_blockwise",), + ) + + vector_add = Blockwise(core_op=add, signature="(x),(x)->(x)") + + # Depending on the rewrites the Alloc shape may be upcast to int64 or not + # We do not care about that for the purposes of this test + equal = partial(equal_computations, strict_dtype=False) + + # Case where Alloc is not necessary + x = tensor("x", shape=(7, 5)) + y = tensor("y", shape=(5,)) + out = vector_add(x, alloc(y, 7, 5)) + expected_out = vector_add(x, y) + assert equal([rewrite(out)], [expected_out]) + + # Cases where Alloc can be fully pushed + x = tensor("x", shape=(5,)) + y = tensor("y", shape=(5,)) + out = vector_add(x, alloc(y, 7, 5)) + expected_out = alloc(vector_add(x, y), 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(1, 5)) + y = tensor("y", shape=(5,)) + out = vector_add(x, alloc(y, 7, 5)) + expected_out = alloc(vector_add(x.squeeze(0), y), 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(7, 5)) + y = tensor("y", shape=(7, 5)) + out = vector_add(x, alloc(y, 3, 7, 5)) + expected_out = alloc(vector_add(x, y), 3, 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(5,)) + y = tensor("y", shape=(7, 1, 5)) + out = vector_add(x, alloc(y, 7, 2, 5)) + expected_out = alloc(vector_add(x, y), 7, 2, 5) + assert equal([rewrite(out)], [expected_out]) + + # Case where Alloc can be partially pushed + x = tensor("x", shape=(5,)) + y = tensor("y", shape=()) + out = vector_add(x, alloc(y, 7, 5)) + expected_out = alloc(vector_add(x, alloc(y, 5)), 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(5,)) + y = tensor("y", shape=(7, 1, 1)) + out = vector_add(x, alloc(y, 7, 2, 5)) + expected_out = alloc(vector_add(x, alloc(y, 7, 1, 5)), 7, 2, 5) + assert equal([rewrite(out)], [expected_out], strict_dtype=False) + + # Cases involving multiple Allocs being pushed + x = tensor("x", shape=()) + y = tensor("y", shape=()) + out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5)) + expected_out = alloc(vector_add(alloc(x, 5), alloc(y, 5)), 3, 7, 5) + assert equal([rewrite(out)], [expected_out]) + + x = tensor("x", shape=(5,)) + y = tensor("y", shape=()) + out = vector_add(alloc(x, 3, 1, 5), alloc(y, 7, 5)) + expected_out = alloc(vector_add(x, alloc(y, 5)), 3, 7, 5) + assert equal([rewrite(out)], [expected_out]) + + # Case where Alloc cannot be pushed + x = tensor("x", shape=(5,)) + y = tensor("y", shape=(1,)) + out = vector_add(x, alloc(y, 5)) + expected_out = out + assert equal([rewrite(out)], [expected_out]) From 048a20995b67e2736779e8522bf35aba54cd57ea Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 6 Dec 2023 14:49:03 +0000 Subject: [PATCH 13/14] Add rewrite to remove Blockwise of AdvancedIncSubtensor --- pytensor/tensor/rewriting/subtensor.py | 56 ++++++++++++++ tests/tensor/rewriting/test_subtensor.py | 98 +++++++++++++++++++++++- 2 files changed, 153 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index e1174a7e8d..e860034235 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -29,6 +29,7 @@ register_infer_shape, switch, ) +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import Dot, add @@ -1880,3 +1881,58 @@ def local_uint_constant_indices(fgraph, node): copy_stack_trace(node.outputs, new_outs) return new_outs + + +@register_canonicalize("shape_unsafe") +@register_stabilize("shape_unsafe") +@register_specialize("shape_unsafe") +@node_rewriter([Blockwise]) +def local_blockwise_advanced_inc_subtensor(fgraph, node): + """Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices.""" + if not isinstance(node.op.core_op, AdvancedIncSubtensor): + return None + + x, y, *idxs = node.inputs + + # It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case + if any( + ( + isinstance(idx, (SliceType, NoneTypeT)) + or (idx.type.dtype == "bool" and idx.type.ndim > 0) + ) + for idx in idxs + ): + return None + + op: Blockwise = node.op # type: ignore + batch_ndim = op.batch_ndim(node) + + new_idxs = [] + for idx in idxs: + if all(idx.type.broadcastable[:batch_ndim]): + new_idxs.append(idx.squeeze(tuple(range(batch_ndim)))) + else: + # Rewrite does not apply + return None + + x_batch_bcast = x.type.broadcastable[:batch_ndim] + y_batch_bcast = y.type.broadcastable[:batch_ndim] + if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast)): + # Need to broadcast batch x dims + batch_shape = tuple( + x_dim if (not xb or yb) else y_dim + for xb, x_dim, yb, y_dim in zip( + x_batch_bcast, + tuple(x.shape)[:batch_ndim], + y_batch_bcast, + tuple(y.shape)[:batch_ndim], + ) + ) + core_shape = tuple(x.shape)[batch_ndim:] + x = alloc(x, *batch_shape, *core_shape) + + new_idxs = [slice(None)] * batch_ndim + new_idxs + symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:] + new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs + copy_stack_trace(node.outputs, new_out) + return new_out diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py index 4a0016136f..b77cdbe315 100644 --- a/tests/tensor/rewriting/test_subtensor.py +++ b/tests/tensor/rewriting/test_subtensor.py @@ -9,7 +9,7 @@ from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp from pytensor.configdefaults import config -from pytensor.graph import FunctionGraph +from pytensor.graph import FunctionGraph, vectorize_graph from pytensor.graph.basic import Constant, Variable, ancestors from pytensor.graph.rewriting.basic import check_stack_trace from pytensor.graph.rewriting.db import RewriteDatabaseQuery @@ -18,6 +18,7 @@ from pytensor.raise_op import Assert from pytensor.tensor import inplace from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector +from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot, add, dot, exp, sqr from pytensor.tensor.rewriting.subtensor import ( @@ -2314,3 +2315,98 @@ def test_local_uint_constant_indices(): new_index = subtensor_node.inputs[1] assert isinstance(new_index, Constant) assert new_index.type.dtype == "uint8" + + +@pytest.mark.parametrize("set_instead_of_inc", (True, False)) +def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc): + core_x = tensor("x", shape=(6,)) + core_y = tensor("y", shape=(3,)) + core_idxs = [0, 2, 4] + if set_instead_of_inc: + core_graph = set_subtensor(core_x[core_idxs], core_y) + else: + core_graph = inc_subtensor(core_x[core_idxs], core_y) + + # Only x is batched + x = tensor("x", shape=(5, 2, 6)) + y = tensor("y", shape=(3,)) + out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out, mode="FAST_RUN") + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + test_x = np.ones(x.type.shape, dtype=x.type.dtype) + test_y = np.array([5, 6, 7]).astype(dtype=core_y.type.dtype) + expected_out = test_x.copy() + if set_instead_of_inc: + expected_out[:, :, core_idxs] = test_y + else: + expected_out[:, :, core_idxs] += test_y + np.testing.assert_allclose(fn(test_x, test_y), expected_out) + + # Only y is batched + x = tensor("y", shape=(6,)) + y = tensor("y", shape=(2, 3)) + out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out, mode="FAST_RUN") + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + test_x = np.ones(x.type.shape, dtype=x.type.dtype) + test_y = np.array([[3, 3, 3], [5, 6, 7]]).astype(dtype=core_y.type.dtype) + expected_out = np.ones((2, *x.type.shape)) + if set_instead_of_inc: + expected_out[:, core_idxs] = test_y + else: + expected_out[:, core_idxs] += test_y + np.testing.assert_allclose(fn(test_x, test_y), expected_out) + + # Both x and y are batched, and do not need to be broadcasted + x = tensor("y", shape=(2, 6)) + y = tensor("y", shape=(2, 3)) + out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out, mode="FAST_RUN") + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + test_x = np.ones(x.type.shape, dtype=x.type.dtype) + test_y = np.array([[5, 6, 7], [3, 3, 3]]).astype(dtype=core_y.type.dtype) + expected_out = test_x.copy() + if set_instead_of_inc: + expected_out[:, core_idxs] = test_y + else: + expected_out[:, core_idxs] += test_y + np.testing.assert_allclose(fn(test_x, test_y), expected_out) + + # Both x and y are batched, but must be broadcasted + x = tensor("y", shape=(5, 1, 6)) + y = tensor("y", shape=(1, 2, 3)) + out = vectorize_graph(core_graph, replace={core_x: x, core_y: y}) + assert isinstance(out.owner.op, Blockwise) + + fn = pytensor.function([x, y], out, mode="FAST_RUN") + assert not any( + isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes + ) + + test_x = np.ones(x.type.shape, dtype=x.type.dtype) + test_y = np.array([[[5, 6, 7], [3, 3, 3]]]).astype(dtype=core_y.type.dtype) + final_shape = ( + *np.broadcast_shapes(x.type.shape[:-1], y.type.shape[:-1]), + x.type.shape[-1], + ) + expected_out = np.broadcast_to(test_x, final_shape).copy() + if set_instead_of_inc: + expected_out[:, :, core_idxs] = test_y + else: + expected_out[:, :, core_idxs] += test_y + np.testing.assert_allclose(fn(test_x, test_y), expected_out) From 505882e26bb17a88c9ce161c969e1e88c09ce9e3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 9 Dec 2023 13:47:03 +0100 Subject: [PATCH 14/14] Better error for fallback of vectorize_node with non-tensor types --- pytensor/tensor/blockwise.py | 9 ++++++++- tests/tensor/test_blockwise.py | 16 +++++++++++++++- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py index d4cd5152e7..d21af2f651 100644 --- a/pytensor/tensor/blockwise.py +++ b/pytensor/tensor/blockwise.py @@ -14,9 +14,10 @@ _vectorize_not_needed, vectorize_graph, ) +from pytensor.scalar import ScalarType from pytensor.tensor import as_tensor_variable from pytensor.tensor.shape import shape_padleft -from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor +from pytensor.tensor.type import TensorType, continuous_dtypes, discrete_dtypes, tensor from pytensor.tensor.utils import ( _parse_gufunc_signature, broadcast_static_dim_lengths, @@ -373,6 +374,12 @@ def __str__(self): @_vectorize_node.register(Op) def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply: + for inp in node.inputs: + if not isinstance(inp.type, (TensorType, ScalarType)): + raise NotImplementedError( + f"Cannot vectorize node {node} with input {inp} of type {inp.type}" + ) + if hasattr(op, "gufunc_signature"): signature = op.gufunc_signature else: diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 9500fec5f8..ac0a3c542e 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -1,3 +1,4 @@ +import re from itertools import product from typing import Optional, Union @@ -12,7 +13,7 @@ from pytensor.graph.replace import vectorize_node from pytensor.raise_op import assert_op from pytensor.tensor import diagonal, log, tensor -from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular @@ -42,6 +43,19 @@ def test_vectorize_blockwise(): assert new_vect_node.inputs[0] is tns4 +def test_vectorize_node_fallback_unsupported_type(): + x = tensor("x", shape=(2, 6)) + node = x[:, [0, 2, 4]].owner + + with pytest.raises( + NotImplementedError, + match=re.escape( + "Cannot vectorize node AdvancedSubtensor(x, MakeSlice.0, [0 2 4]) with input MakeSlice.0 of type slice" + ), + ): + vectorize_node_fallback(node.op, node, node.inputs) + + def check_blockwise_runtime_broadcasting(mode): a = tensor("a", shape=(None, 3, 5)) b = tensor("b", shape=(None, 5, 3))