
Commit 5f3d19a

Canonicalize squeeze out of reshape and specialize back
1 parent f4ad2f3 commit 5f3d19a
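
The gist of the commit: during canonicalization, implicit expand_dims (and squeezes) are pulled out of Reshape nodes; during specialization, the surrounding squeeze/expand_dims are fused back into a single Reshape. A minimal sketch of the canonicalize direction, mirroring the new test_expand_dims test below; it assumes rewrite_graph is importable from pytensor.graph.rewriting.utils (the tests' own import is not shown in this diff):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.scalar("x")
out = x.reshape((1, -1))  # a Reshape that only adds length-1 dimensions
canonical = rewrite_graph(out, include=("canonicalize",))
# The Reshape is gone entirely after canonicalization:
# canonical is equivalent to pt.expand_dims(x, (0, 1))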

File tree

4 files changed: +182 -54 lines changed

pytensor/tensor/rewriting/shape.py
tests/tensor/rewriting/test_basic.py
tests/tensor/rewriting/test_elemwise.py
tests/tensor/rewriting/test_shape.py

pytensor/tensor/rewriting/shape.py

Lines changed: 115 additions & 44 deletions
@@ -36,6 +36,7 @@
     register_useless,
     topo_constant_folding,
 )
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import (
     Reshape,
     Shape,
@@ -757,40 +758,36 @@ def apply(self, fgraph):
 pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)


+@register_useless
 @register_canonicalize
 @node_rewriter([Reshape])
-def local_useless_dimshuffle_in_reshape(fgraph, node):
+def local_useless_expand_dims_in_reshape(fgraph, node):
     """
-    Removes useless DimShuffle operation inside Reshape:
-
-      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
-      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
-      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
-      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
+    Removes useless expand_dims `DimShuffle` operations inside Reshape:
+      reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp)
+      reshape(expand_dims(matrix, axis=(0, 2)), shp) => reshape(matrix, shp)

+    Implicit (and useless) squeezes are kept in the graph, as they are
+    part of the canonical form of the graph.
     """
-    dimshuffled_x, new_shape = node.inputs
+    expanded_x, new_shape = node.inputs

     if not (
-        dimshuffled_x.owner is not None
-        and isinstance(dimshuffled_x.owner.op, DimShuffle)
+        expanded_x.owner is not None
+        and isinstance(expanded_x.owner.op, DimShuffle)
+        and expanded_x.owner.op.augment
     ):
         return False

-    [inp] = dimshuffled_x.owner.inputs
-    new_order = dimshuffled_x.owner.op.new_order
-    new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
-        if s != 1:
-            new_order_of_nonbroadcast.append(i)
-    no_change_in_order = all(
-        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
-        for i in range(len(new_order_of_nonbroadcast) - 1)
-    )
-    if no_change_in_order:
-        ret = inp.reshape(new_shape)
-        copy_stack_trace(node.outputs[0], ret)
-        return [ret]
+    [x] = expanded_x.owner.inputs
+
+    new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x")
+    if new_order != tuple(range(x.type.ndim)):
+        x = x.dimshuffle(new_order)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]


 @register_canonicalize("shape_unsafe")
@@ -943,39 +940,113 @@ def local_useless_reshape(fgraph, node):
     return False


-@register_canonicalize
+@register_canonicalize("shape_unsafe")
 @node_rewriter([Reshape])
 def local_reshape_to_dimshuffle(fgraph, node):
-    r"""Replace broadcastable dimensions in `Reshape` nodes with `DimShuffle`\s.
-
-    The goal is to avoid using `Reshape` to add or remove broadcastable
-    dimensions, and to use `DimShuffle` instead, since `DimShuffle`\s can
-    cancel out and/or be removed later on.
+    r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions.

     For example:
-        - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,))
-        - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
+        - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n))
+        - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0)
+        - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5))
+
     """
     inp, output_shape = node.inputs
     [output] = node.outputs

-    unpacked_shape = _unpack_shape_vector(output_shape)
-    expand_axes = []
-    new_output_shape = []
-    for i, dim in enumerate(unpacked_shape):
-        if isinstance(dim, Constant) and dim.data == 1:
-            expand_axes.append(i)
-        else:
-            new_output_shape.append(dim)
+    # Remove any broadcastable dimensions from the input
+    squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
+
+    # Trivial case: all dimensions of the input/output are known to be broadcastable,
+    # so there's nothing to reshape
+    if all(inp.type.broadcastable) or all(output.type.broadcastable):
+        new_output_shape = []
+        expand_axes = tuple(range(output.type.ndim))
+
+    else:
+        unpacked_shape = _unpack_shape_vector(output_shape)
+        new_output_shape = []
+        expand_axes = []
+        for i, dim in enumerate(unpacked_shape):
+            if isinstance(dim, Constant) and (
+                dim.data == 1
+                # -1 can be an implicit expand_dims, but it's tricky to prove,
+                # as we would need to check whether all other dimensions
+                # already explain the full size of the array. We rely on the output
+                # static shape, which will have figured it out for some (but not all) cases
+                or (dim.data == -1 and output.type.shape[i] == 1)
+            ):
+                expand_axes.append(i)
+            else:
+                new_output_shape.append(dim)
+
+    if squeeze_axes or expand_axes:
+        new_out = inp.squeeze(squeeze_axes)
+
+        if new_output_shape:
+            new_out = new_out.reshape(new_output_shape)
+            copy_stack_trace(output, new_out)
+
+        new_out = expand_dims(new_out, expand_axes)
+
+        if not new_output_shape:
+            # Eagerly merge consecutive squeeze and expand_dims
+            new_out = apply_local_dimshuffle_lift(fgraph, new_out)

-    if len(new_output_shape) != output.type.ndim:
-        inner = inp.reshape(new_output_shape)
-        copy_stack_trace(output, inner)
-        new_out = expand_dims(inner, expand_axes)
         copy_stack_trace(output, new_out)
         return [new_out]


+@register_specialize
+@node_rewriter([Reshape])
+def local_fuse_squeeze_reshape(fgraph, node):
+    r"""If there is a squeeze right before a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    x, new_shape = node.inputs
+
+    if (
+        x.owner is not None
+        and isinstance(x.owner.op, DimShuffle)
+        and x.owner.op.is_squeeze
+    ):
+        # A reshape can always subsume a squeeze.
+        x = x.owner.inputs[0]
+        return [x.reshape(new_shape)]
+
+
+@register_specialize
+@node_rewriter([DimShuffle])
+def local_fuse_expand_dims_reshape(fgraph, node):
+    r"""If there is an expand_dims right after a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    if not node.op.is_expand_dims:
+        return None
+
+    reshaped_x = node.inputs[0]
+
+    if not (reshaped_x.owner and isinstance(reshaped_x.owner.op, Reshape)):
+        return None
+
+    if len(fgraph.clients[reshaped_x]) > 1:
+        # The reshape is used elsewhere; don't fuse, as it can sometimes require a copy.
+        return None
+
+    x, new_shape = reshaped_x.owner.inputs
+
+    # Add the expand_dims axes to the shape
+    new_shape = list(_unpack_shape_vector(new_shape))
+    for i in node.op.augment:
+        new_shape.insert(i, 1)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]
+
+
 @register_canonicalize
 @register_specialize
 @node_rewriter([Reshape])
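
Taken together, these rewrites form a round trip: canonicalization splits a reshape over broadcastable dimensions into squeeze -> reshape -> expand_dims, and the two new specialize rewrites fuse the pieces back into one Reshape when nothing else consumes the intermediates. A sketch of the specialize direction, mirroring test_expand_dims_squeeze_reshape_fusion below (same rewrite_graph import assumption as above):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.tensor("x", shape=(1, 9))
# squeeze -> reshape -> expand_dims, none of which is useless on its own
out = x.squeeze(0).reshape((3, 3))[..., None]
fused = rewrite_graph(out, include=("specialize",))
# local_fuse_squeeze_reshape and local_fuse_expand_dims_reshape collapse
# all three ops into a single Reshape:
# fused is equivalent to x.reshape((3, 3, 1))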

tests/tensor/rewriting/test_basic.py

Lines changed: 0 additions & 1 deletion
@@ -332,7 +332,6 @@ def test_basic_tile(self):

         mode = rewrite_mode.including(
             "local_dimshuffle_lift",
-            "local_useless_dimshuffle_in_reshape",
             "local_alloc_sink_dimshuffle",
         )
         f = function([x], [y], mode=mode)

tests/tensor/rewriting/test_elemwise.py

Lines changed: 12 additions & 8 deletions
@@ -56,7 +56,10 @@
 from pytensor.tensor.math import round as pt_round
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.elemwise import FusionOptimizer, local_dimshuffle_lift
-from pytensor.tensor.rewriting.shape import local_useless_dimshuffle_in_reshape
+from pytensor.tensor.rewriting.shape import (
+    local_fuse_squeeze_reshape,
+    local_useless_expand_dims_in_reshape,
+)
 from pytensor.tensor.shape import reshape
 from pytensor.tensor.type import (
     TensorType,
@@ -182,7 +185,7 @@ def test_dimshuffle_lift_multi_out_elemwise(self):
         assert not local_dimshuffle_lift.transform(g, g.outputs[0].owner)


-def test_local_useless_dimshuffle_in_reshape():
+def test_local_useless_expand_dims_in_reshape():
     vec = TensorType(dtype="float64", shape=(None,))("vector")
     mat = TensorType(dtype="float64", shape=(None, None))("mat")
     row = TensorType(dtype="float64", shape=(1, None))("row")
@@ -204,7 +207,11 @@ def test_local_useless_dimshuffle_in_reshape():
         clone=False,
     )
     assert len(g.apply_nodes) == 4 * 3
-    useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
+    useless_dimshuffle_in_reshape = out2in(
+        local_useless_expand_dims_in_reshape,
+        # Useless squeeze in reshape is not a canonicalization anymore
+        local_fuse_squeeze_reshape,
+    )
     useless_dimshuffle_in_reshape.rewrite(g)
     assert equal_computations(
         g.outputs,
@@ -218,15 +225,12 @@ def test_local_useless_dimshuffle_in_reshape():
     # Check stacktrace was copied over correctly after rewrite was applied
     assert check_stack_trace(g, ops_to_check="all")

-    # Check that the rewrite does not get applied when the order
-    # of dimensions has changed.
+    # Check that the rewrite does not mess with meaningful transpositions before the reshape
     reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
     h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False)
     assert len(h.apply_nodes) == 3
     useless_dimshuffle_in_reshape.rewrite(h)
-    assert equal_computations(
-        h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)]
-    )
+    assert equal_computations(h.outputs, [reshape(mat.dimshuffle(1, 0), mat.shape)])


 class TestFusion:
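
The updated assertion pins down the new behavior: expand_dims axes are stripped from the DimShuffle feeding the reshape, while a genuine transposition is preserved. A small illustration (graph construction only; mat is any (None, None) float64 matrix as in the test):

import pytensor.tensor as pt

mat = pt.matrix("mat")
before = mat.dimshuffle("x", 1, "x", 0).reshape(mat.shape)
# After the rewrite pair used in the test, the graph is equivalent to:
after = mat.dimshuffle(1, 0).reshape(mat.shape)
# The two "x" (expand_dims) axes are useless to the Reshape and are dropped;
# the (1, 0) transposition changes element order, so it must stay.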

tests/tensor/rewriting/test_shape.py

Lines changed: 55 additions & 1 deletion
@@ -6,7 +6,7 @@
 import pytensor.tensor as pt
 from pytensor import shared
 from pytensor.compile.function import function
-from pytensor.compile.mode import get_default_mode, get_mode
+from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import deep_copy_op
 from pytensor.configdefaults import config
 from pytensor.graph.basic import Apply, Variable, equal_computations
@@ -419,6 +419,60 @@ def test_basic(self):

         assert check_stack_trace(g, ops_to_check=(DimShuffle, Reshape))

+    def test_expand_dims(self):
+        x = pt.scalar()
+        # This reshape does an implicit expand_dims
+        out = x.reshape((1, -1))
+        assert isinstance(out.owner.op, Reshape)
+        new_out = rewrite_graph(out, include=("canonicalize",))
+        assert equal_computations([new_out], [pt.expand_dims(x, (0, 1))])
+
+    def test_squeeze_of_alloc(self):
+        # This shows up in the graph of repeat
+        x = pt.vector("x", shape=(9,))
+        bcast_x = pt.alloc(x, 1, 12, x.shape[0])
+
+        # This reshape does an implicit squeeze
+        out = bcast_x.reshape((12, x.shape[0]))
+
+        new_out = rewrite_graph(out, include=("canonicalize", "ShapeOpt"))
+        assert equal_computations([new_out], [pt.alloc(x, 12, 9)], strict_dtype=False)
+
+
+def test_expand_dims_squeeze_reshape_fusion():
+    x = pt.tensor("x", shape=(1, 9))
+    reshape_x = x.squeeze(0).reshape((3, 3))[..., None]
+
+    assert isinstance(reshape_x.owner.op, DimShuffle)
+    assert isinstance(reshape_x.owner.inputs[0].owner.op, Reshape)
+    assert isinstance(reshape_x.owner.inputs[0].owner.inputs[0].owner.op, DimShuffle)
+
+    out = rewrite_graph(reshape_x, include=("specialize",))
+
+    # In this case we cannot get rid of the reshape, squeeze or expand_dims,
+    # so we fuse them all in one reshape
+    assert equal_computations([out], [x.reshape((3, 3, 1))])
+
+
+def test_implicit_broadcasting_via_repeat():
+    x = pt.vector("x", shape=(3,), dtype=int)
+    y = pt.vector("y", shape=(9,), dtype=int)
+    out = x[None, :].repeat(9, axis=0) <= y[:, None].repeat(3, axis=1)
+    # There are two Reshapes in the graph
+    assert isinstance(out.owner.inputs[0].owner.op, Reshape)
+    assert isinstance(out.owner.inputs[1].owner.op, Reshape)
+
+    new_out = rewrite_graph(out, include=("canonicalize", "specialize"))
+    assert equal_computations([new_out], [x[None] <= y[:, None]])
+
+    no_rewrite_mode = Mode(linker="py", optimizer=None)
+    x_test = np.arange(3) + 1
+    y_test = np.arange(9)
+    np.testing.assert_allclose(
+        new_out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode),
+        out.eval({x: x_test, y: y_test}, mode=no_rewrite_mode),
+    )
+

 def test_local_reshape_lift():
     x = tensor4()
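
One case the new tests don't exercise in isolation is local_fuse_expand_dims_reshape on its own. A hypothetical sketch of the expected behavior, based on the rewrite's logic above (the expand_dims axes get folded into the target shape; same rewrite_graph import assumption as earlier):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.vector("x", shape=(9,))
out = x.reshape((3, 3))[None]  # expand_dims(reshape(x, (3, 3)), axis=0)
fused = rewrite_graph(out, include=("specialize",))
# Expected: the DimShuffle is absorbed into the Reshape, i.e.
# fused is equivalent to x.reshape((1, 3, 3))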
