@@ -36,6 +36,7 @@
     register_useless,
     topo_constant_folding,
 )
+from pytensor.tensor.rewriting.elemwise import apply_local_dimshuffle_lift
 from pytensor.tensor.shape import (
     Reshape,
     Shape,
@@ -757,40 +758,36 @@ def apply(self, fgraph):
 pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)


+@register_useless
 @register_canonicalize
 @node_rewriter([Reshape])
-def local_useless_dimshuffle_in_reshape(fgraph, node):
+def local_useless_expand_dims_in_reshape(fgraph, node):
     """
-    Removes useless DimShuffle operation inside Reshape:
-
-        reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
-        reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
-        reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
-        reshape(col.dimshuffle(0), shp) => reshape(col, shp)
+    Removes useless expand_dims `DimShuffle` operations inside Reshape:
+        reshape(expand_dims(vector, axis=0), shp) => reshape(vector, shp)
+        reshape(expand_dims(matrix, axis=(0, 2)), shp) => reshape(matrix, shp)

+    Implicit (and useless) squeezes are kept in the graph, as they are
+    part of the canonical form of the graph.
     """
-    dimshuffled_x, new_shape = node.inputs
+    expanded_x, new_shape = node.inputs

     if not (
-        dimshuffled_x.owner is not None
-        and isinstance(dimshuffled_x.owner.op, DimShuffle)
+        expanded_x.owner is not None
+        and isinstance(expanded_x.owner.op, DimShuffle)
+        and expanded_x.owner.op.augment
     ):
         return False

-    [inp] = dimshuffled_x.owner.inputs
-    new_order = dimshuffled_x.owner.op.new_order
-    new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
-        if s != 1:
-            new_order_of_nonbroadcast.append(i)
-    no_change_in_order = all(
-        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
-        for i in range(len(new_order_of_nonbroadcast) - 1)
-    )
-    if no_change_in_order:
-        ret = inp.reshape(new_shape)
-        copy_stack_trace(node.outputs[0], ret)
-        return [ret]
+    [x] = expanded_x.owner.inputs
+
+    new_order = tuple(o for o in expanded_x.owner.op.new_order if o != "x")
+    if new_order != tuple(range(x.type.ndim)):
+        x = x.dimshuffle(new_order)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]


 @register_canonicalize("shape_unsafe")
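
To make the behavior concrete, here is a minimal sketch of the new rewrite in action (not part of the diff; `pt.expand_dims` and the `rewrite_graph` helper are assumptions about the surrounding PyTensor API, and the shapes are made up):

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.vector("x")
# The leading broadcastable dimension added by expand_dims is discarded by
# the reshape anyway, so the rewrite can drop the DimShuffle entirely.
out = pt.expand_dims(x, axis=0).reshape((2, -1))

# After canonicalization the graph should reshape `x` directly,
# i.e. reshape(x, (2, -1)) with no intermediate DimShuffle.
rewritten = rewrite_graph(out, include=("canonicalize",))
print(rewritten.eval({x: [1.0, 2.0, 3.0, 4.0]}))  # [[1. 2.] [3. 4.]]
```
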
@@ -944,39 +941,113 @@ def local_useless_reshape(fgraph, node):
         return False


-@register_canonicalize
+@register_canonicalize("shape_unsafe")
 @node_rewriter([Reshape])
 def local_reshape_to_dimshuffle(fgraph, node):
-    r"""Replace broadcastable dimensions in `Reshape` nodes with `DimShuffle`\s.
-
-    The goal is to avoid using `Reshape` to add or remove broadcastable
-    dimensions, and to use `DimShuffle` instead, since `DimShuffle`\s can
-    cancel out and/or be removed later on.
+    r"""Remove `Reshape` operations over length-1 (broadcastable) dimensions.

     For example:
-    - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,)))
-    - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
+    - reshape(col, (m, n)) -> reshape(squeeze(col, axis=1), (m, n))
+    - reshape(col, (1, m, n)) -> expand_dims(reshape(squeeze(col, axis=1), (m, n)), axis=0)
+    - reshape(x, (1, m, 1, n, 1, 1)) -> expand_dims(reshape(x, (m, n)), axis=(0, 2, 4, 5))
+
     """
     inp, output_shape = node.inputs
     [output] = node.outputs

-    unpacked_shape = _unpack_shape_vector(output_shape)
-    expand_axes = []
-    new_output_shape = []
-    for i, dim in enumerate(unpacked_shape):
-        if isinstance(dim, Constant) and dim.data == 1:
-            expand_axes.append(i)
-        else:
-            new_output_shape.append(dim)
+    # Remove any broadcastable dimensions from the input
+    squeeze_axes = [i for i, bcast in enumerate(inp.type.broadcastable) if bcast]
+
+    # Trivial case, all dimensions of input/output are known to be broadcastable:
+    # there's nothing to reshape
+    if all(inp.type.broadcastable) or all(output.type.broadcastable):
+        new_output_shape = []
+        expand_axes = tuple(range(output.type.ndim))
+
+    else:
+        unpacked_shape = _unpack_shape_vector(output_shape)
+        new_output_shape = []
+        expand_axes = []
+        for i, dim in enumerate(unpacked_shape):
+            if isinstance(dim, Constant) and (
+                dim.data == 1
+                # -1 can be an implicit expand_dims, but it's tricky to prove
+                # as we would need to check whether all other dimensions
+                # already explain the full size of the array. We rely on the output
+                # static shape, which will have figured it out for some (but not all) cases
+                or (dim.data == -1 and output.type.shape[i] == 1)
+            ):
+                expand_axes.append(i)
+            else:
+                new_output_shape.append(dim)
+
+    if squeeze_axes or expand_axes:
+        new_out = inp.squeeze(squeeze_axes)
+
+        if new_output_shape:
+            new_out = new_out.reshape(new_output_shape)
+            copy_stack_trace(output, new_out)
+
+        new_out = expand_dims(new_out, expand_axes)
+
+        if not new_output_shape:
+            # Eagerly merge consecutive squeeze and expand_dims
+            new_out = apply_local_dimshuffle_lift(fgraph, new_out)

-    if len(new_output_shape) != output.type.ndim:
-        inner = inp.reshape(new_output_shape)
-        copy_stack_trace(output, inner)
-        new_out = expand_dims(inner, expand_axes)
         copy_stack_trace(output, new_out)
         return [new_out]


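
For illustration, a hedged sketch of the canonical form this rewrite now produces, mirroring the docstring examples above (again assuming `rewrite_graph` and `pytensor.dprint` are available; the `col` variable and shapes are illustrative only):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

col = pt.col("col")  # static shape (None, 1)
out = col.reshape((1, 5, 2))

# Expected canonical structure: the broadcastable column axis is squeezed,
# the data is reshaped to (5, 2), and the leading length-1 axis is re-added
# with an expand_dims DimShuffle, per the docstring examples above.
canonical = rewrite_graph(out, include=("canonicalize",))
pytensor.dprint(canonical)
```
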
+@register_specialize
+@node_rewriter([Reshape])
+def local_fuse_squeeze_reshape(fgraph, node):
+    r"""If there is a squeeze right before a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    x, new_shape = node.inputs
+
+    if (
+        x.owner is not None
+        and isinstance(x.owner.op, DimShuffle)
+        and x.owner.op.is_squeeze
+    ):
+        # A reshape can always subsume a squeeze.
+        x = x.owner.inputs[0]
+        return [x.reshape(new_shape)]
+
+
+@register_specialize
+@node_rewriter([DimShuffle])
+def local_fuse_expand_dims_reshape(fgraph, node):
+    r"""If there is an expand_dims right after a reshape, merge them.
+
+    This undoes the effect of `local_reshape_to_dimshuffle` that is applied during canonicalization.
+    """
+    if not node.op.is_expand_dims:
+        return None
+
+    reshaped_x = node.inputs[0]
+
+    if not (reshaped_x.owner and isinstance(reshaped_x.owner.op, Reshape)):
+        return None
+
+    if len(fgraph.clients[reshaped_x]) > 1:
+        # The reshape is used elsewhere, don't fuse as it can sometimes require a copy.
+        return None
+
+    x, new_shape = reshaped_x.owner.inputs
+
+    # Add expand_dims to shape
+    new_shape = list(_unpack_shape_vector(new_shape))
+    for i in node.op.augment:
+        new_shape.insert(i, 1)
+
+    new_reshaped_x = x.reshape(new_shape)
+    copy_stack_trace(node.outputs[0], new_reshaped_x)
+    return [new_reshaped_x]
+
+
 @register_canonicalize
 @register_specialize
 @node_rewriter([Reshape])
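
Finally, a sketch of the round trip the two new specialize rewrites enable, under the same API assumptions as the sketches above: canonicalization decomposes the reshape, then specialization fuses it back.

```python
import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

col = pt.col("col")
out = col.reshape((1, 5, 2))

# Canonicalization decomposes the reshape into squeeze/reshape/expand_dims;
# the specialize-phase fuse rewrites should then collapse the decomposition
# back into a single Reshape, so no extra DimShuffle survives in the
# compiled graph.
specialized = rewrite_graph(out, include=("canonicalize", "specialize"))
pytensor.dprint(specialized)
```
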