Commit f4ad2f3

Refactor reshape + dimshuffle rewrites
1 parent 1ea8cb8 commit f4ad2f3

2 files changed: +149 −170 lines


pytensor/graph/rewriting/basic.py

Lines changed: 0 additions & 10 deletions
@@ -2800,16 +2800,6 @@ def _check_chain(r, chain):
     return r is not None
 
 
-def check_chain(r, *chain):
-    """
-    WRITEME
-
-    """
-    if isinstance(r, Apply):
-        r = r.outputs[0]
-    return _check_chain(r, reduce(list.__iadd__, ([x, 0] for x in chain)))
-
-
 def pre_greedy_node_rewriter(
     fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable
 ) -> Variable:
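
Note on the removal: check_chain was an undocumented helper used to test for a chain of Ops such as Reshape(Reshape(x)); its only caller now does the owner/op checks inline (see local_reshape_chain below). A minimal sketch of the equivalent inline test, assuming PyTensor's public Reshape op; the helper name _reshape_of_reshape is hypothetical and not part of this commit:

from pytensor.graph.basic import Variable
from pytensor.tensor.shape import Reshape


def _reshape_of_reshape(var: Variable) -> bool:
    """Return True if `var` is produced by Reshape applied to another Reshape."""
    node = var.owner
    if node is None or not isinstance(node.op, Reshape):
        return False
    inner = node.inputs[0]
    return inner.owner is not None and isinstance(inner.owner.op, Reshape)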

pytensor/tensor/rewriting/shape.py

Lines changed: 149 additions & 160 deletions
@@ -12,16 +12,17 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import (
     GraphRewriter,
-    check_chain,
     copy_stack_trace,
     node_rewriter,
 )
 from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
+from pytensor.scalar import ScalarType
 from pytensor.tensor.basic import (
     MakeVector,
     as_tensor_variable,
     cast,
     constant,
+    expand_dims,
     get_scalar_constant_value,
     register_infer_shape,
     stack,
@@ -47,6 +48,7 @@
 from pytensor.tensor.subtensor import Subtensor, get_idx_list
 from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
 from pytensor.tensor.type_other import NoneConst, NoneTypeT
+from pytensor.tensor.variable import TensorVariable
 
 
 class ShapeFeature(Feature):
@@ -755,6 +757,42 @@ def apply(self, fgraph):
 pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)
 
 
+@register_canonicalize
+@node_rewriter([Reshape])
+def local_useless_dimshuffle_in_reshape(fgraph, node):
+    """
+    Removes useless DimShuffle operation inside Reshape:
+
+      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
+      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
+      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
+      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
+
+    """
+    dimshuffled_x, new_shape = node.inputs
+
+    if not (
+        dimshuffled_x.owner is not None
+        and isinstance(dimshuffled_x.owner.op, DimShuffle)
+    ):
+        return False
+
+    [inp] = dimshuffled_x.owner.inputs
+    new_order = dimshuffled_x.owner.op.new_order
+    new_order_of_nonbroadcast = []
+    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
+        if s != 1:
+            new_order_of_nonbroadcast.append(i)
+    no_change_in_order = all(
+        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
+        for i in range(len(new_order_of_nonbroadcast) - 1)
+    )
+    if no_change_in_order:
+        ret = inp.reshape(new_shape)
+        copy_stack_trace(node.outputs[0], ret)
+        return [ret]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Reshape])
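
Quick illustration of the pattern this new rewrite targets (a sketch, not part of the diff; variable names are illustrative and assume the standard pytensor.tensor API):

import pytensor.tensor as pt

v = pt.vector("v")
# The DimShuffle only inserts a broadcastable dimension, so it is redundant
# once the result is reshaped anyway:
before = v.dimshuffle("x", 0).reshape((2, -1))
# After canonicalization the graph is equivalent to reshaping v directly:
after = v.reshape((2, -1))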
@@ -763,30 +801,89 @@ def local_reshape_chain(fgraph, node):
     Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)
 
     """
-    if not check_chain(node, Reshape, Reshape):
+    inner_reshape, final_shape = node.inputs
+
+    if not (inner_reshape.owner and isinstance(inner_reshape.owner.op, Reshape)):
+        return None
+
+    x, _ = inner_reshape.owner.inputs
+    new_reshape = node.op(x, final_shape)
+
+    copy_stack_trace(node.outputs, new_reshape)
+    return [new_reshape]
+
+
+def _is_shape_i_of_x(
+    var: TensorVariable,
+    x: TensorVariable,
+    i: int,
+    shape_feature: ShapeFeature | None = None,
+) -> bool:
+    if var.type.ndim != 0:
         return False
 
-    rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
-
-    # Copy over stacktrace from previous output node, as any error
-    # in new computational graph would have been caused by last op
-    # in the old computational graph.
-    copy_stack_trace(node.outputs, rval)
-
-    # It might happen that the desired output of this node has a
-    # broadcastable pattern that does not match that of 'rval'. This is
-    # when originally, we were able to figure out that one of the
-    # dimensions of the reshape is one, but some other transformation
-    # replaced the shape by one for which this cannot be guessed.
-    # We should try to figure out why we lost the information about this
-    # constant value... but in the meantime, better not apply this
-    # rewrite.
-    if rval.type.ndim == node.outputs[0].type.ndim and all(
-        s1 == s2
-        for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape, strict=True)
-        if s1 == 1 or s2 == 1
-    ):
-        return [rval]
+    constant_var = get_scalar_constant_value(
+        var,
+        only_process_constants=False,
+        # Don't go through Elemwise to keep things fast
+        elemwise=False,
+        raise_not_constant=False,
+    )
+
+    # Check var is a constant expression with the same value as x.type.shape[i]
+    if constant_var == x.type.shape[i]:
+        return True
+
+    # Match shape_of[x][i] or its constant equivalent
+    if shape_feature is not None:
+        i_shape_of_x = shape_feature.get_shape(x, i)
+        if i_shape_of_x == var or (
+            isinstance(i_shape_of_x, Constant) and (i_shape_of_x.data == constant_var)
+        ):
+            return True
+
+    if var.owner is None:
+        # No more constant possibilities
+        return False
+
+    # Match Shape_i{i}(x)
+    if isinstance(var.owner.op, Shape_i):
+        return (var.owner.op.i == i) and (var.owner.inputs[0] == x)
+
+    # Match Subtensor((ScalarType,))(Shape(input), i)
+    if isinstance(var.owner.op, Subtensor):
+        return (
+            # Check we have integer indexing operation
+            # (and not slice or multiple indexing)
+            len(var.owner.op.idx_list) == 1
+            and isinstance(var.owner.op.idx_list[0], ScalarType)
+            # Check we are indexing on the shape of x
+            and var.owner.inputs[0].owner is not None
+            and isinstance(var.owner.inputs[0].owner.op, Shape)
+            and var.owner.inputs[0].owner.inputs[0] == x
+            # Check that index == i
+            and (
+                get_scalar_constant_value(var.owner.inputs[1], raise_not_constant=False)
+                == i
+            )
+        )
+
+    return False
+
+
+def _unpack_shape_vector(shape: TensorVariable) -> tuple[TensorVariable, ...] | None:
+    """Return the elements of a symbolic vector representing a shape.
+
+    Handles the most common constant vector or make_vector cases.
+
+    Returns tuple(shape) as fallback.
+    """
+    if isinstance(shape, Constant):
+        return tuple(as_tensor_variable(dim, ndim=0) for dim in shape.data)
+    elif shape.owner and isinstance(shape.owner.op, MakeVector):
+        return tuple(shape.owner.inputs)
+    else:
+        return tuple(shape)
 
 
 @register_useless("shape_unsafe")
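
For context on the two new helpers (a sketch, not part of the diff; variable names are illustrative): _is_shape_i_of_x answers whether a 0-d variable provably equals x.shape[i], however that shape was spelled, and _unpack_shape_vector splits a symbolic shape vector into per-dimension entries.

import pytensor.tensor as pt

x = pt.matrix("x")
# Different spellings of "the length of dimension 0 of x" that the helper is
# written to recognize, per the matching branches above:
dim0_subtensor = pt.shape(x)[0]  # Subtensor((ScalarType,))(Shape(x), 0)
dim0_property = x.shape[0]       # same graph built via the .shape property
# With a ShapeFeature attached, Shape_i{0}(x) and its constant-folded value
# are also matched through shape_feature.get_shape(x, 0).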
@@ -821,86 +918,29 @@ def local_useless_reshape(fgraph, node):
     if shape_input == inp:
         return [inp]
 
-    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
-    # broadcastable and constant dimensions
-    if isinstance(output_shape, Constant) or (
-        output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
-    ):
-        if isinstance(output_shape, Constant):
-            output_shape_is = [
-                as_tensor_variable(dim, ndim=0) for dim in output_shape.data
-            ]
-        else:
-            output_shape_is = output_shape.owner.inputs
-
-        shape_feature = getattr(fgraph, "shape_feature", None)
-
-        nb_m1 = 0
-        shape_match = [False] * inp.type.ndim
-        for dim in range(inp.type.ndim):
-            outshp_i = output_shape_is[dim]
-            # Match Shape_i{dim}(input)
-            if (
-                outshp_i.owner
-                and isinstance(outshp_i.owner.op, Shape_i)
-                and outshp_i.owner.op.i == dim
-                and outshp_i.owner.inputs[0] == inp
-            ):
-                shape_match[dim] = True
-                continue
+    shape_feature = getattr(fgraph, "shape_feature", None)
 
-            # Match Shape(input)[dim]
-            if (
-                outshp_i.owner
-                and isinstance(outshp_i.owner.op, Subtensor)
-                and len(outshp_i.owner.inputs) == 2
-                and get_scalar_constant_value(
-                    outshp_i.owner.inputs[1], raise_not_constant=False
-                )
-                == dim
-            ):
-                subtensor_inp = outshp_i.owner.inputs[0]
-                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
-                    shape_input_i = subtensor_inp.owner.inputs[0]
-                    if shape_input_i == inp:
-                        shape_match[dim] = True
-                        continue
-
-            # Match constant if input.type.shape[dim] == constant
-            cst_outshp_i = get_scalar_constant_value(
-                outshp_i, only_process_constants=True, raise_not_constant=False
-            )
-            if inp.type.shape[dim] == cst_outshp_i:
-                shape_match[dim] = True
-                continue
-
-            # Match -1
-            if cst_outshp_i == -1:
-                shape_match[dim] = True
-                nb_m1 += 1
-                continue
+    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1
+    # or cases where all but one dimension are provably preserved
+    output_shape_is = _unpack_shape_vector(output_shape)
 
-            # Match shape_of[input][dim] or its constant equivalent
-            if shape_feature:
-                inpshp_i = shape_feature.get_shape(inp, dim)
-                if inpshp_i == outshp_i or (
-                    get_scalar_constant_value(
-                        inpshp_i, only_process_constants=True, raise_not_constant=False
-                    )
-                    == get_scalar_constant_value(
-                        outshp_i, only_process_constants=True, raise_not_constant=False
-                    )
-                ):
-                    shape_match[dim] = True
-                    continue
+    nb_m1 = 0
+    shape_match = [False] * inp.type.ndim
+    for dim in range(inp.type.ndim):
+        outshp_i = output_shape_is[dim]
+        if _is_shape_i_of_x(outshp_i, inp, dim, shape_feature=shape_feature):
+            shape_match[dim] = True
+        elif isinstance(outshp_i, Constant) and outshp_i.data == -1:
+            shape_match[dim] = True
+            nb_m1 += 1
 
-        if nb_m1 <= 1 and all(shape_match):
-            return [inp]
+    if nb_m1 <= 1 and all(shape_match):
+        return [inp]
 
-        if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
-            return [inp]
+    if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
+        return [inp]
 
-    return False
+    return False
 
 
 @register_canonicalize
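
A small sketch (not part of the diff; assumes the standard pytensor.tensor API) of reshapes the simplified matcher removes:

import pytensor.tensor as pt

x = pt.matrix("x")
# Reshaping a tensor to its own symbolic shape is a no-op:
out1 = x.reshape(x.shape)
# A single -1 among otherwise provably preserved dimensions is also accepted
# (nb_m1 <= 1), so this reshape is removed as well:
out2 = x.reshape((x.shape[0], -1))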
@@ -914,39 +954,26 @@ def local_reshape_to_dimshuffle(fgraph, node):
 
     For example:
       - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,))
-      - reshape(x, (1, m, 1, n, 1, 1))
-          -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
+      - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
     """
-    op = node.op
     inp, output_shape = node.inputs
     [output] = node.outputs
 
-    dimshuffle_new_order = []
+    unpacked_shape = _unpack_shape_vector(output_shape)
+    expand_axes = []
     new_output_shape = []
-    index = 0  # index over the output of the new reshape
-    for i in range(output.ndim):
-        # Since output_shape is a symbolic vector, we trust get_scalar_constant_value
-        # to go through however it is formed to see if its i-th element is 1.
-        # We need only_process_constants=False for that.
-        dim = get_scalar_constant_value(
-            output_shape[i],
-            only_process_constants=False,
-            elemwise=False,
-            raise_not_constant=False,
-        )
-        if dim == 1:
-            dimshuffle_new_order.append("x")
+    for i, dim in enumerate(unpacked_shape):
+        if isinstance(dim, Constant) and dim.data == 1:
+            expand_axes.append(i)
         else:
-            dimshuffle_new_order.append(index)
             new_output_shape.append(dim)
-            index = index + 1
 
-    if index != output.type.ndim:
-        inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
+    if len(new_output_shape) != output.type.ndim:
+        inner = inp.reshape(new_output_shape)
         copy_stack_trace(output, inner)
-        new_node = [inner.dimshuffle(dimshuffle_new_order)]
-        copy_stack_trace(output, new_node)
-        return new_node
+        new_out = expand_dims(inner, expand_axes)
+        copy_stack_trace(output, new_out)
+        return [new_out]
 
 
 @register_canonicalize
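
Illustration of the rewritten expand_dims path (a sketch, not part of the diff; variable names are illustrative):

import pytensor.tensor as pt

x = pt.matrix("x")
n = x.shape[0] * x.shape[1]
# The constant-1 entries of the target shape are peeled off into expand_dims,
# so this reshape becomes roughly expand_dims(x.reshape((n,)), (0, 2)):
out = x.reshape((1, n, 1))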
@@ -1186,44 +1213,6 @@ def local_track_shape_i(fgraph, node):
     return [shape_feature.shape_of[replacement][node.op.i]]
 
 
-@register_canonicalize
-@node_rewriter([Reshape])
-def local_useless_dimshuffle_in_reshape(fgraph, node):
-    """
-    Removes useless DimShuffle operation inside Reshape:
-
-      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
-      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
-      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
-      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
-
-    """
-    op = node.op
-    if not isinstance(op, Reshape):
-        return False
-    if not (
-        node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, DimShuffle)
-    ):
-        return False
-
-    new_order = node.inputs[0].owner.op.new_order
-    inp = node.inputs[0].owner.inputs[0]
-    new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
-        if s != 1:
-            new_order_of_nonbroadcast.append(i)
-    no_change_in_order = all(
-        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
-        for i in range(len(new_order_of_nonbroadcast) - 1)
-    )
-    if no_change_in_order:
-        shape = node.inputs[1]
-        ret = op.__class__(node.outputs[0].ndim)(inp, shape)
-        copy_stack_trace(node.outputs[0], ret)
-        return [ret]
-
-
 @register_useless
 @register_canonicalize
 @register_specialize
