diff --git a/pytensor/graph/rewriting/basic.py b/pytensor/graph/rewriting/basic.py index 16b5b65a0e..b91e743bb6 100644 --- a/pytensor/graph/rewriting/basic.py +++ b/pytensor/graph/rewriting/basic.py @@ -2800,16 +2800,6 @@ def _check_chain(r, chain): return r is not None -def check_chain(r, *chain): - """ - WRITEME - - """ - if isinstance(r, Apply): - r = r.outputs[0] - return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain))) - - def pre_greedy_node_rewriter( fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable ) -> Variable: diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index cb60427ba0..c37597906a 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -166,15 +166,20 @@ def __init__(self, *, input_ndim: int, new_order: Sequence[int | Literal["x"]]): self.transposition = self.shuffle + drop # List of dimensions of the output that are broadcastable and were not # in the original input - self.augment = sorted(i for i, x in enumerate(new_order) if x == "x") + self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x") self.drop = drop - self.is_left_expand_dims = self.augment and ( + dims_are_shuffled = sorted(self.shuffle) != self.shuffle + + self.is_transpose = dims_are_shuffled and not augment and not drop + self.is_squeeze = drop and not dims_are_shuffled and not augment + self.is_expand_dims = augment and not dims_are_shuffled and not drop + self.is_left_expand_dims = self.is_expand_dims and ( input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim)) ) - self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list( - range(input_ndim) - ) + self.is_right_expand_dims = self.is_expand_dims and new_order[ + :input_ndim + ] == list(range(input_ndim)) if self.inplace: self.view_map = {0: [0]} @@ -215,16 +220,15 @@ def make_node(self, inp): return Apply(self, [input], [output]) def __str__(self): - shuffle = sorted(self.shuffle) != self.shuffle - if self.augment and not (shuffle or self.drop): + if self.is_expand_dims: if len(self.augment) == 1: return f"ExpandDims{{axis={self.augment[0]}}}" return f"ExpandDims{{axes={self.augment}}}" - if self.drop and not (self.augment or shuffle): + if self.is_squeeze: if len(self.drop) == 1: - return f"DropDims{{axis={self.drop[0]}}}" - return f"DropDims{{axes={self.drop}}}" - if shuffle and not (self.augment or self.drop): + return f"Squeeze{{axis={self.drop[0]}}}" + return f"Squeeze{{axes={self.drop}}}" + if self.is_transpose: return f"Transpose{{axes={self.shuffle}}}" return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index e277772ad4..e86411dd9c 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -12,16 +12,17 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import ( GraphRewriter, - check_chain, copy_stack_trace, node_rewriter, ) from pytensor.graph.utils import InconsistencyError, get_variable_trace_string +from pytensor.scalar import ScalarType from pytensor.tensor.basic import ( MakeVector, as_tensor_variable, cast, constant, + expand_dims, get_scalar_constant_value, register_infer_shape, stack, @@ -35,6 +36,7 @@ register_useless, topo_constant_folding, ) +from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift from pytensor.tensor.shape import ( Reshape, Shape, @@ -47,6 +49,7 @@ from pytensor.tensor.subtensor import Subtensor, 
get_idx_list from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes from pytensor.tensor.type_other import NoneConst, NoneTypeT +from pytensor.tensor.variable import TensorVariable class ShapeFeature(Feature): @@ -755,6 +758,38 @@ def apply(self, fgraph): pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10) +@register_useless +@register_canonicalize +@node_rewriter([Reshape]) +def local_useless_expand_dims_in_reshape(fgraph, node): + """ + Removes useless expand_dims `DimShuffle` operations inside Reshape: + reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp) + reshape(expand_dims(matrix, axis=(0, 2), shp) => reshape(matrix, shp) + + Implicit (and useless) squeezes are kept in the graph, as they are + part of the canonical form of the graph. + """ + expanded_x, new_shape = node.inputs + + if not ( + expanded_x.owner is not None + and isinstance(expanded_x.owner.op, DimShuffle) + and expanded_x.owner.op.augment + ): + return False + + [x] = expanded_x.owner.inputs + + new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x") + if new_order != tuple(range(x.type.ndim)): + x = x.dimshuffle(new_order) + + new_reshaped_x = x.reshape(new_shape) + copy_stack_trace(node.outputs[0], new_reshaped_x) + return [new_reshaped_x] + + @register_canonicalize("shape_unsafe") @register_specialize("shape_unsafe") @node_rewriter([Reshape]) @@ -763,30 +798,89 @@ def local_reshape_chain(fgraph, node): Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2) """ - if not check_chain(node, Reshape, Reshape): + inner_reshape, final_shape = node.inputs + + if not (inner_reshape.owner and isinstance(inner_reshape.owner.op, Reshape)): + return None + + x, _ = inner_reshape.owner.inputs + new_reshape = node.op(x, final_shape) + + copy_stack_trace(node.outputs, new_reshape) + return [new_reshape] + + +def _is_shape_i_of_x( + var: TensorVariable, + x: TensorVariable, + i: int, + shape_feature: ShapeFeature | None = None, +) -> bool: + if var.type.ndim != 0: return False - rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1]) - - # Copy over stacktrace from previous output node, as any error - # in new computational graph would have been caused by last op - # in the old computational graph. - copy_stack_trace(node.outputs, rval) - - # It might happen that the desired output of this node has a - # broadcastable pattern that does not match that of 'rval'. This is - # when originally, we were able to figure out that one of the - # dimensions of the reshape is one, but some other transformation - # replaced the shape by one for which this cannot be guessed. - # We should try to figure out why we lost the information about this - # constant value... but in the meantime, better not apply this - # rewrite. 
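# Sketch of what the two rewrites above aim for, using the `rewrite_graph`
# helper that the test suite relies on (illustrative only, not part of the patch):
#
#     import pytensor.tensor as pt
#     from pytensor.graph.rewriting.utils import rewrite_graph
#
#     x = pt.vector("x")
#     out1 = pt.expand_dims(x, 0).reshape((2, -1))    # useless expand_dims before a reshape
#     out2 = x.reshape((2, -1)).reshape((-1, 2))      # chained reshapes
#
#     rewrite_graph(out1, include=("canonicalize",))  # ~ x.reshape((2, -1))
#     rewrite_graph(out2, include=("canonicalize",))  # ~ x.reshape((-1, 2))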
- if rval.type.ndim == node.outputs[0].type.ndim and all( - s1 == s2 - for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape, strict=True) - if s1 == 1 or s2 == 1 - ): - return [rval] + constant_var = get_scalar_constant_value( + var, + only_process_constants=False, + # Don't go through Elemwise to keep things fast + elemwise=False, + raise_not_constant=False, + ) + + # Check var is a constant expression with the same value as x.type.shape[i] + if constant_var == x.type.shape[i]: + return True + + # Match shape_of[x][i] or its constant equivalent + if shape_feature is not None: + i_shape_of_x = shape_feature.get_shape(x, i) + if i_shape_of_x == var or ( + isinstance(i_shape_of_x, Constant) and (i_shape_of_x.data == constant_var) + ): + return True + + if var.owner is None: + # No more constant possibilities + return False + + # Match Shape_i{i}(x) + if isinstance(var.owner.op, Shape_i): + return (var.owner.op.i == i) and (var.owner.inputs[0] == x) # type: ignore + + # Match Subtensor((ScalarType,))(Shape(input), i) + if isinstance(var.owner.op, Subtensor): + return ( + # Check we have integer indexing operation + # (and not slice or multiple indexing) + len(var.owner.op.idx_list) == 1 + and isinstance(var.owner.op.idx_list[0], ScalarType) + # Check we are indexing on the shape of x + and var.owner.inputs[0].owner is not None + and isinstance(var.owner.inputs[0].owner.op, Shape) + and var.owner.inputs[0].owner.inputs[0] == x + # Check that index == i + and ( + get_scalar_constant_value(var.owner.inputs[1], raise_not_constant=False) + == i + ) + ) + + return False + + +def _unpack_shape_vector(shape: TensorVariable) -> tuple[TensorVariable, ...]: + """Return the elements of a symbolic vector representing a shape. + + Handles the most common constant vector or make_vector cases. + + Returns tuple(shape) as fallback. 
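    For illustration (a sketch; `a`, `b` and `some_symbolic_shape` are
    placeholders and the exact constant wrapping may differ):

        _unpack_shape_vector(as_tensor_variable((2, 3)))
        # -> (constant 2, constant 3) as 0-d tensors
        _unpack_shape_vector(stack([a, b]))
        # -> (a, b), the inputs of the underlying MakeVector
        _unpack_shape_vector(some_symbolic_shape)
        # -> (some_symbolic_shape[0], ..., some_symbolic_shape[n - 1])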
+ """ + if isinstance(shape, Constant): + return tuple(as_tensor_variable(dim, ndim=0) for dim in shape.data) + elif shape.owner and isinstance(shape.owner.op, MakeVector): + return tuple(shape.owner.inputs) + else: + return tuple(shape) @register_useless("shape_unsafe") @@ -821,132 +915,151 @@ def local_useless_reshape(fgraph, node): if shape_input == inp: return [inp] - # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for - # broadcastable and constant dimensions - if isinstance(output_shape, Constant) or ( - output_shape.owner and isinstance(output_shape.owner.op, MakeVector) - ): - if isinstance(output_shape, Constant): - output_shape_is = [ - as_tensor_variable(dim, ndim=0) for dim in output_shape.data - ] - else: - output_shape_is = output_shape.owner.inputs - - shape_feature = getattr(fgraph, "shape_feature", None) - - nb_m1 = 0 - shape_match = [False] * inp.type.ndim - for dim in range(inp.type.ndim): - outshp_i = output_shape_is[dim] - # Match Shape_i{dim}(input) - if ( - outshp_i.owner - and isinstance(outshp_i.owner.op, Shape_i) - and outshp_i.owner.op.i == dim - and outshp_i.owner.inputs[0] == inp - ): - shape_match[dim] = True - continue + shape_feature = getattr(fgraph, "shape_feature", None) - # Match Shape(input)[dim] - if ( - outshp_i.owner - and isinstance(outshp_i.owner.op, Subtensor) - and len(outshp_i.owner.inputs) == 2 - and get_scalar_constant_value( - outshp_i.owner.inputs[1], raise_not_constant=False - ) - == dim - ): - subtensor_inp = outshp_i.owner.inputs[0] - if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape): - shape_input_i = subtensor_inp.owner.inputs[0] - if shape_input_i == inp: - shape_match[dim] = True - continue - - # Match constant if input.type.shape[dim] == constant - cst_outshp_i = get_scalar_constant_value( - outshp_i, only_process_constants=True, raise_not_constant=False - ) - if inp.type.shape[dim] == cst_outshp_i: - shape_match[dim] = True - continue + # Match case where at least (n-1) entries correspond to the original shape: + # Reshape(x, [x.shape[0], ..., x.shape[-1]]), or Reshape(x, [x.shape[0], y, x.shape[2], ... x.shape[-1]]) + # Where y can be -1 or anything with an unknown value, since the only valid reshape is still a no reshape. 
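# For instance (illustrative, with a hypothetical x of static shape (3, 5, None)):
#     x.reshape((x.shape[0], 5, -1))  # every entry matches, single -1 -> removed, provably correct
#     x.reshape((x.shape[0], 5, y))   # only the last entry is unmatched -> removed, could mask a shape error
#     x.reshape((x.shape[0], 7, -1))  # the 7 does not match -> the rewrite does not apply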
+ output_shape_is = _unpack_shape_vector(output_shape) + nb_m1 = 0 + shape_match = [False] * inp.type.ndim + for dim in range(inp.type.ndim): + outshp_i = output_shape_is[dim] + if _is_shape_i_of_x(outshp_i, inp, dim, shape_feature=shape_feature): + shape_match[dim] = True + elif isinstance(outshp_i, Constant) and outshp_i.data == -1: + shape_match[dim] = True + nb_m1 += 1 - # Match -1 - if cst_outshp_i == -1: - shape_match[dim] = True - nb_m1 += 1 - continue + if nb_m1 <= 1 and all(shape_match): + return [inp] # This is provably correct - # Match shape_of[input][dim] or its constant equivalent - if shape_feature: - inpshp_i = shape_feature.get_shape(inp, dim) - if inpshp_i == outshp_i or ( - get_scalar_constant_value( - inpshp_i, only_process_constants=True, raise_not_constant=False - ) - == get_scalar_constant_value( - outshp_i, only_process_constants=True, raise_not_constant=False - ) - ): - shape_match[dim] = True - continue + # There is one missing match, but all other dimensions match + # Such as x.type.shape == (3, 5, None) and output_shape == (3, 5, y) + if (nb_m1 == 0) and (shape_match.count(False) == 1): + return [inp] # This could mask a shape error - if nb_m1 <= 1 and all(shape_match): - return [inp] + return False - if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1): - return [inp] - return False - - -@register_canonicalize +@register_canonicalize("shape_unsafe") @node_rewriter([Reshape]) def local_reshape_to_dimshuffle(fgraph, node): - r"""Replace broadcastable dimensions in `Reshape` nodes with `DimShuffle`\s. + r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions. - The goal is to avoid using `Reshape` to add or remove broadcastable - dimensions, and to use `DimShuffle` instead, since `DimShuffle`\s can - cancel out and/or be removed later on. + It's always valid to squeeze an input before doing the same reshape operation. + Equivalently, it's always valid to remove `1` entries from the reshape shape + and replace them by an expand_dims after the rewritten reshape operation. + + We chose to canonicalize the graph in this way as it allows isolating + operations that are unique to the reshaping operation (mixing dimensions) + from those that can be more legibly encoded by DimShuffle (squeeze and expand_dims). + This can allow further simplifications by other rewrites that target + DimShuffle but not Reshape, as well as facilitate the removal of useless reshape operations. For example: - - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,)) - - reshape(x, (1, m, 1, n, 1, 1)) - -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n))) + - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n)) + - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0) + - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5)) + """ - op = node.op inp, output_shape = node.inputs [output] = node.outputs - dimshuffle_new_order = [] - new_output_shape = [] - index = 0 # index over the output of the new reshape - for i in range(output.ndim): - # Since output_shape is a symbolic vector, we trust get_scalar_constant_value - # to go through however it is formed to see if its i-th element is 1. - # We need only_process_constants=False for that. 
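# Concretely (a sketch mirroring the new test_expand_dims test added further below):
#     x = pt.scalar("x")
#     out = x.reshape((1, -1))                       # reshape doing an implicit expand_dims
#     rewrite_graph(out, include=("canonicalize",))  # -> equivalent to pt.expand_dims(x, (0, 1))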
- dim = get_scalar_constant_value( - output_shape[i], - only_process_constants=False, - elemwise=False, - raise_not_constant=False, - ) - if dim == 1: - dimshuffle_new_order.append("x") - else: - dimshuffle_new_order.append(index) - new_output_shape.append(dim) - index = index + 1 + # Remove any broadcastable dimensions from the input + squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast] + + # Trivial case, all dimensions of input/output are known to be broadcastable: + # there's nothing to reshape + if all(inp.type.broadcastable) or all(output.type.broadcastable): + new_output_shape = [] + expand_axes = tuple(range(output.type.ndim)) + + else: + unpacked_shape = _unpack_shape_vector(output_shape) + new_output_shape = [] + expand_axes = [] + for i, dim_length in enumerate(unpacked_shape): + if isinstance(dim_length, Constant) and ( + dim_length.data == 1 + # -1 can be an implicit expand_dims, but it's tricky to prove + # as we would need to check whether all other dimensions + # already explain the full size of the array. + # Example: np.zeros((2, 2, 2)).reshape((8, -1)) + # We rely on the output static shape which will already have figured + # it out for some (but not all) cases + or (dim_length.data == -1 and output.type.shape[i] == 1) + ): + expand_axes.append(i) + else: + new_output_shape.append(dim_length) + + if squeeze_axes or expand_axes: + new_out = inp.squeeze(squeeze_axes) + + if new_output_shape: + new_out = new_out.reshape(new_output_shape) + copy_stack_trace(output, new_out) + + new_out = expand_dims(new_out, expand_axes) + + if not new_output_shape: + # Eagerly merge consecutive squeeze and expand_dims + new_out = apply_local_dimshuffle_lift(fgraph, new_out) + + copy_stack_trace(output, new_out) + return [new_out] + + +@register_specialize +@node_rewriter([Reshape]) +def local_fuse_squeeze_reshape(fgraph, node): + r"""If there is a squeeze right before a reshape, merge them. + + This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization. + """ + x, new_shape = node.inputs + + if ( + x.owner is not None + and isinstance(x.owner.op, DimShuffle) + and x.owner.op.is_squeeze + ): + # A reshape can always subsume a squeeze. + x = x.owner.inputs[0] + return [x.reshape(new_shape)] + - if index != output.type.ndim: - inner = op.__class__(len(new_output_shape))(inp, new_output_shape) - copy_stack_trace(output, inner) - new_node = [inner.dimshuffle(dimshuffle_new_order)] - copy_stack_trace(output, new_node) - return new_node +@register_specialize +@node_rewriter([DimShuffle]) +def local_fuse_expand_dims_reshape(fgraph, node): + r"""If there is an expand_dims right after a reshape, merge them. + + This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization. + """ + if not node.op.is_expand_dims: + return None + + reshaped_x = node.inputs[0] + + if not (reshaped_x.owner and isinstance(reshaped_x.owner.op, Reshape)): + return None + + if len(fgraph.clients[reshaped_x]) > 1: + # The reshape is used elsewhere, don't fuse as it can sometimes require a copy. 
+ # Example: `x = pt.matrix(); y = x.T.reshape(-1); out = y[: None] * y[None, :]` + return None + + x, new_shape = reshaped_x.owner.inputs + + # Add expand_dims to shape + new_shape = list(_unpack_shape_vector(new_shape)) + for i in node.op.augment: + new_shape.insert(i, 1) + + new_reshaped_x = x.reshape(new_shape) + copy_stack_trace(node.outputs[0], new_reshaped_x) + return [new_reshaped_x] @register_canonicalize @@ -1186,44 +1299,6 @@ def local_track_shape_i(fgraph, node): return [shape_feature.shape_of[replacement][node.op.i]] -@register_canonicalize -@node_rewriter([Reshape]) -def local_useless_dimshuffle_in_reshape(fgraph, node): - """ - Removes useless DimShuffle operation inside Reshape: - - reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp) - reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp) - reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp) - reshape(col.dimshuffle(0), shp) => reshape(col, shp) - - """ - op = node.op - if not isinstance(op, Reshape): - return False - if not ( - node.inputs[0].owner is not None - and isinstance(node.inputs[0].owner.op, DimShuffle) - ): - return False - - new_order = node.inputs[0].owner.op.new_order - inp = node.inputs[0].owner.inputs[0] - new_order_of_nonbroadcast = [] - for i, s in zip(new_order, node.inputs[0].type.shape, strict=True): - if s != 1: - new_order_of_nonbroadcast.append(i) - no_change_in_order = all( - new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1] - for i in range(len(new_order_of_nonbroadcast) - 1) - ) - if no_change_in_order: - shape = node.inputs[1] - ret = op.__class__(node.outputs[0].ndim)(inp, shape) - copy_stack_trace(node.outputs[0], ret) - return [ret] - - @register_useless @register_canonicalize @register_specialize diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 8913d6fb4d..1c23a21347 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -1,7 +1,9 @@ import warnings +from collections.abc import Sequence from numbers import Number from textwrap import dedent -from typing import cast +from typing import TYPE_CHECKING, Union, cast +from typing import cast as typing_cast import numpy as np from numpy.core.numeric import normalize_axis_tuple # type: ignore @@ -24,6 +26,9 @@ from pytensor.tensor.variable import TensorConstant, TensorVariable +if TYPE_CHECKING: + from pytensor.tensor import TensorLike + ShapeValueType = None | np.integer | int | Variable @@ -639,6 +644,8 @@ def make_node(self, x, shp): x = ptb.as_tensor_variable(x) shp_orig = shp shp = ptb.as_tensor_variable(shp, ndim=1) + if shp.type.shape == (None,): + shp = specify_shape(shp, self.ndim) if not ( shp.dtype in int_dtypes or (isinstance(shp, TensorConstant) and shp.data.size == 0) @@ -842,9 +849,14 @@ def _vectorize_reshape(op, node, x, shape): return reshape(x, new_shape, ndim=len(new_shape)).owner -def reshape(x, newshape, ndim=None): +def reshape( + x: "TensorLike", + newshape: Union["TensorLike", Sequence["TensorLike"]], + *, + ndim: int | None = None, +) -> TensorVariable: if ndim is None: - newshape = ptb.as_tensor_variable(newshape) + newshape = ptb.as_tensor_variable(newshape) # type: ignore if newshape.type.ndim != 1: raise TypeError( "New shape in reshape must be a vector or a list/tuple of" @@ -862,7 +874,7 @@ def reshape(x, newshape, ndim=None): ) op = Reshape(ndim) rval = op(x, newshape) - return rval + return typing_cast(TensorVariable, rval) def shape_padleft(t, n_ones=1): diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py 
index 325567918a..7f0be47656 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -918,7 +918,7 @@ def _direct_solve_discrete_lyapunov( vec_Q = Q.ravel() vec_X = solve(eye - AxA, vec_Q, b_ndim=1) - return cast(TensorVariable, reshape(vec_X, A.shape)) + return reshape(vec_X, A.shape) def solve_discrete_lyapunov( diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 8911f56630..ac8576a8a1 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -332,7 +332,6 @@ def test_basic_tile(self): mode = rewrite_mode.including( "local_dimshuffle_lift", - "local_useless_dimshuffle_in_reshape", "local_alloc_sink_dimshuffle", ) f = function([x], [y], mode=mode) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index f1b71949d1..6fb0594ed5 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -56,7 +56,10 @@ from pytensor.tensor.math import round as pt_round from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift -from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape +from pytensor.tensor.rewriting.shape import ( + local_fuse_squeeze_reshape, + local_useless_expand_dims_in_reshape, +) from pytensor.tensor.shape import reshape from pytensor.tensor.type import ( TensorType, @@ -182,7 +185,7 @@ def test_dimshuffle_lift_multi_out_elemwise(self): assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner) -def test_local_useless_dimshuffle_in_reshape(): +def test_local_useless_expand_dims_in_reshape(): vec = TensorType(dtype="float64", shape=(None,))("vector") mat = TensorType(dtype="float64", shape=(None, None))("mat") row = TensorType(dtype="float64", shape=(1, None))("row") @@ -204,7 +207,11 @@ def test_local_useless_dimshuffle_in_reshape(): clone=False, ) assert len(g.apply_nodes) == 4 * 3 - useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape) + useless_dimshuffle_in_reshape = out2in( + local_useless_expand_dims_in_reshape, + # Useless squeeze in reshape is not a canonicalization anymore + local_fuse_squeeze_reshape, + ) useless_dimshuffle_in_reshape.rewrite(g) assert equal_computations( g.outputs, @@ -218,15 +225,12 @@ def test_local_useless_dimshuffle_in_reshape(): # Check stacktrace was copied over correctly after rewrite was applied assert check_stack_trace(g, ops_to_check="all") - # Check that the rewrite does not get applied when the order - # of dimensions has changed. 
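# (Behaviour change in the assertion below: previously the rewrite left this graph
#  untouched; now the useless expand_dims axes are stripped while the genuine
#  transposition is kept, i.e. mat.dimshuffle("x", 1, "x", 0) becomes
#  mat.dimshuffle(1, 0) inside the reshape.)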
+ # Check that the rewrite does not mess meaningful transpositions before the reshape reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False) assert len(h.apply_nodes) == 3 useless_dimshuffle_in_reshape.rewrite(h) - assert equal_computations( - h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)] - ) + assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)]) class TestFusion: diff --git a/tests/tensor/rewriting/test_shape.py b/tests/tensor/rewriting/test_shape.py index bbfd829070..27678bd630 100644 --- a/tests/tensor/rewriting/test_shape.py +++ b/tests/tensor/rewriting/test_shape.py @@ -6,7 +6,7 @@ import pytensor.tensor as pt from pytensor import shared from pytensor.compile.function import function -from pytensor.compile.mode import get_default_mode, get_mode +from pytensor.compile.mode import Mode, get_default_mode, get_mode from pytensor.compile.ops import deep_copy_op from pytensor.configdefaults import config from pytensor.graph.basic import Apply, Variable, equal_computations @@ -383,6 +383,13 @@ def test_all_but_one_match(self): new_out = rewrite_graph(out) assert new_out is out + # Or if more than one dimension cannot be matched + x = tensor(shape=(None, None, None)) + shape = [x.shape[0], 3, 3] + out = reshape(x, shape) + new_out = rewrite_graph(out) + assert new_out is out + class TestLocalReshapeToDimshuffle: def setup_method(self): @@ -419,6 +426,60 @@ def test_basic(self): assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape)) + def test_expand_dims(self): + x = pt.scalar() + # This reshape does an implicit expand_dims + out = x.reshape((1, -1)) + assert isinstance(out.owner.op, Reshape) + new_out = rewrite_graph(out, include=("canonicalize",)) + assert equal_computations([new_out], [pt.expand_dims(x, (0, 1))]) + + def test_squeeze_of_alloc(self): + # This shows up in the graph of repeat + x = pt.vector("x", shape=(9,)) + bcast_x = pt.alloc(x, 1, 12, x.shape[0]) + + # This reshape does an implicit squeeze + out = bcast_x.reshape((12, x.shape[0])) + + new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt")) + assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False) + + +def test_expand_dims_squeeze_reshape_fusion(): + x = pt.tensor("x", shape=(1, 9)) + reshape_x = x.squeeze(0).reshape((3, 3))[..., None] + + assert isinstance(reshape_x.owner.op, DimShuffle) + assert isinstance(reshape_x.owner.inputs[0].owner.op, Reshape) + assert isinstance(reshape_x.owner.inputs[0].owner.inputs[0].owner.op, DimShuffle) + + out = rewrite_graph(reshape_x, include=("specialize",)) + + # In this case we cannot get rid of the reshape, squeeze or expand_dims, + # so we fuse them all in one reshape + assert equal_computations([out], [x.reshape((3, 3, 1))]) + + +def test_implicit_broadcasting_via_repeat(): + x = pt.vector("x", shape=(3,), dtype=int) + y = pt.vector("y", shape=(9,), dtype=int) + out = x[None, :].repeat(9, axis=0) <= y[:, None].repeat(3, axis=1) + # There are two Reshapes in the graph + assert isinstance(out.owner.inputs[0].owner.op, Reshape) + assert isinstance(out.owner.inputs[1].owner.op, Reshape) + + new_out = rewrite_graph(out, include=("canonicalize", "specialize")) + assert equal_computations([new_out], [x[None] <= y[:, None]]) + + no_rewrite_mode = Mode(linker="py", optimizer=None) + x_test = np.arange(3) + 1 + y_test = np.arange(9) + np.testing.assert_allclose( + new_out.eval({x: x_test, y: y_test}, 
mode=no_rewrite_mode), + out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode), + ) + def test_local_reshape_lift(): x = tensor4() diff --git a/tests/tensor/test_shape.py b/tests/tensor/test_shape.py index 7700d2b14b..3f0b04d45d 100644 --- a/tests/tensor/test_shape.py +++ b/tests/tensor/test_shape.py @@ -98,6 +98,7 @@ def setup_method(self): Shape_i, DimShuffle, Elemwise, + SpecifyShape, ) super().setup_method() @@ -253,9 +254,7 @@ def test_bad_shape(self): f(a_val, [7, 5]) with pytest.raises(ValueError): f(a_val, [-1, -1]) - with pytest.raises( - ValueError, match=".*Shape argument to Reshape has incorrect length.*" - ): + with pytest.raises(AssertionError): f(a_val, [3, 4, 1]) def test_0(self):
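# Note on the final test_shape.py hunk: Reshape.make_node now wraps a shape vector of
# unknown static length in specify_shape(shp, self.ndim) (see the pytensor/tensor/shape.py
# hunk above), so a shape argument of the wrong length now fails SpecifyShape's runtime
# check (an AssertionError) rather than Reshape's own length validation (the old ValueError).
# Rough sketch of the behaviour being asserted (variable names illustrative):
#     a = pt.matrix("a"); shp = pt.lvector("shp")
#     f = pytensor.function([a, shp], a.reshape(shp, ndim=2))
#     f(np.ones((6, 2)), [3, 4, 1])   # raises AssertionError instead of the old ValueError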