from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import (
    GraphRewriter,
-    check_chain,
    copy_stack_trace,
    node_rewriter,
)
from pytensor.graph.utils import InconsistencyError, get_variable_trace_string
+from pytensor.scalar import ScalarType
from pytensor.tensor.basic import (
    MakeVector,
    as_tensor_variable,
    cast,
    constant,
+    expand_dims,
    get_scalar_constant_value,
    register_infer_shape,
    stack,

from pytensor.tensor.subtensor import Subtensor, get_idx_list
from pytensor.tensor.type import TensorType, discrete_dtypes, integer_dtypes
from pytensor.tensor.type_other import NoneConst, NoneTypeT
+from pytensor.tensor.variable import TensorVariable


class ShapeFeature(Feature):
@@ -755,6 +757,42 @@ def apply(self, fgraph):
pytensor.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), position=10)


+@register_canonicalize
+@node_rewriter([Reshape])
+def local_useless_dimshuffle_in_reshape(fgraph, node):
+    """
+    Removes useless DimShuffle operation inside Reshape:
+
+      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
+      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
+      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
+      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
+
+    """
+    dimshuffled_x, new_shape = node.inputs
+
+    if not (
+        dimshuffled_x.owner is not None
+        and isinstance(dimshuffled_x.owner.op, DimShuffle)
+    ):
+        return False
+
+    [inp] = dimshuffled_x.owner.inputs
+    new_order = dimshuffled_x.owner.op.new_order
+    new_order_of_nonbroadcast = []
+    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
+        if s != 1:
+            new_order_of_nonbroadcast.append(i)
+    no_change_in_order = all(
+        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
+        for i in range(len(new_order_of_nonbroadcast) - 1)
+    )
+    if no_change_in_order:
+        ret = inp.reshape(new_shape)
+        copy_stack_trace(node.outputs[0], ret)
+        return [ret]
+
+
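Not part of the patch, but as a hedged illustration of the rewrite above, assuming only PyTensor's public API (the exact printed graph may differ between versions):

```python
# Sketch: exercise local_useless_dimshuffle_in_reshape on a toy graph.
import pytensor
import pytensor.tensor as pt
from pytensor.printing import debugprint

x = pt.vector("x")
# dimshuffle('x', 0) only prepends a broadcastable dim, so it is redundant
# once a Reshape follows it.
out = x.dimshuffle("x", 0).reshape((2, 3))

fn = pytensor.function([x], out)
# After canonicalization the compiled graph is expected to contain a single
# Reshape applied directly to `x`, with no DimShuffle node in between.
debugprint(fn)
```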
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Reshape])
@@ -763,30 +801,89 @@ def local_reshape_chain(fgraph, node):
    Reshape(Reshape(x, shape1),shape2) -> Reshape(x, shape2)

    """
-    if not check_chain(node, Reshape, Reshape):
+    inner_reshape, final_shape = node.inputs
+
+    if not (inner_reshape.owner and isinstance(inner_reshape.owner.op, Reshape)):
+        return None
+
+    x, _ = inner_reshape.owner.inputs
+    new_reshape = node.op(x, final_shape)
+
+    copy_stack_trace(node.outputs, new_reshape)
+    return [new_reshape]
+
+
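A minimal sketch of what this reshape-chain rewrite does in practice (illustrative only, public API):

```python
import pytensor.tensor as pt

x = pt.matrix("x")
# Two chained reshapes; only the outermost target shape matters.
y = x.reshape((6,)).reshape((2, 3))
# After the rewrite, the graph behind `y` is expected to be a single
# Reshape(x, (2, 3)); the intermediate Reshape(x, (6,)) node is dropped.
```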
+def _is_shape_i_of_x(
+    var: TensorVariable,
+    x: TensorVariable,
+    i: int,
+    shape_feature: ShapeFeature | None = None,
+) -> bool:
+    if var.type.ndim != 0:
        return False

-    rval = node.op(node.inputs[0].owner.inputs[0], node.inputs[1])
-
-    # Copy over stacktrace from previous output node, as any error
-    # in new computational graph would have been caused by last op
-    # in the old computational graph.
-    copy_stack_trace(node.outputs, rval)
-
-    # It might happen that the desired output of this node has a
-    # broadcastable pattern that does not match that of 'rval'. This is
-    # when originally, we were able to figure out that one of the
-    # dimensions of the reshape is one, but some other transformation
-    # replaced the shape by one for which this cannot be guessed.
-    # We should try to figure out why we lost the information about this
-    # constant value... but in the meantime, better not apply this
-    # rewrite.
-    if rval.type.ndim == node.outputs[0].type.ndim and all(
-        s1 == s2
-        for s1, s2 in zip(rval.type.shape, node.outputs[0].type.shape, strict=True)
-        if s1 == 1 or s2 == 1
-    ):
-        return [rval]
+    constant_var = get_scalar_constant_value(
+        var,
+        only_process_constants=False,
+        # Don't go through Elemwise to keep things fast
+        elemwise=False,
+        raise_not_constant=False,
+    )
+
+    # Check var is a constant expression with the same value as x.type.shape[i]
+    if constant_var == x.type.shape[i]:
+        return True
+
+    # Match shape_of[x][i] or its constant equivalent
+    if shape_feature is not None:
+        i_shape_of_x = shape_feature.get_shape(x, i)
+        if i_shape_of_x == var or (
+            isinstance(i_shape_of_x, Constant) and (i_shape_of_x.data == constant_var)
+        ):
+            return True
+
+    if var.owner is None:
+        # No more constant possibilities
+        return False
+
+    # Match Shape_i{i}(x)
+    if isinstance(var.owner.op, Shape_i):
+        return (var.owner.op.i == i) and (var.owner.inputs[0] == x)
+
+    # Match Subtensor((ScalarType,))(Shape(input), i)
+    if isinstance(var.owner.op, Subtensor):
+        return (
+            # Check we have integer indexing operation
+            # (and not slice or multiple indexing)
+            len(var.owner.op.idx_list) == 1
+            and isinstance(var.owner.op.idx_list[0], ScalarType)
+            # Check we are indexing on the shape of x
+            and var.owner.inputs[0].owner is not None
+            and isinstance(var.owner.inputs[0].owner.op, Shape)
+            and var.owner.inputs[0].owner.inputs[0] == x
+            # Check that index == i
+            and (
+                get_scalar_constant_value(var.owner.inputs[1], raise_not_constant=False)
+                == i
+            )
+        )
+
+    return False
+
+
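To make the matching rules concrete, here is a hedged sketch of the expressions `_is_shape_i_of_x` is meant to recognize; the exact graph structure of `x.shape[1]` can vary between versions and with other rewrites:

```python
import pytensor.tensor as pt

x = pt.tensor3("x")

# The symbolic length of dimension 1 is typically built as Subtensor(Shape(x), 1)
# (or a Shape_i{1}(x) node after other rewrites); both forms are matched here.
dim1 = x.shape[1]
# _is_shape_i_of_x(dim1, x, 1)  -> expected True
# _is_shape_i_of_x(dim1, x, 0)  -> expected False (wrong index)

# A dimension that is statically known on the type matches via the constant branch.
y = pt.tensor(dtype="float64", shape=(None, 5), name="y")
# _is_shape_i_of_x(pt.constant(5), y, 1)  -> expected True
```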
+def _unpack_shape_vector(shape: TensorVariable) -> tuple[TensorVariable, ...]:
+    """Return the elements of a symbolic vector representing a shape.
+
+    Handles the most common constant vector or make_vector cases.
+
+    Returns tuple(shape) as fallback.
+    """
+    if isinstance(shape, Constant):
+        return tuple(as_tensor_variable(dim, ndim=0) for dim in shape.data)
+    elif shape.owner and isinstance(shape.owner.op, MakeVector):
+        return tuple(shape.owner.inputs)
+    else:
+        return tuple(shape)


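A small illustrative sketch of the helper's three branches (the `stack`-to-MakeVector detail is an assumption about graph construction, not something this patch guarantees):

```python
import pytensor.tensor as pt

# Constant branch: a literal shape unpacks into 0-d constant tensors.
const_shape = pt.as_tensor_variable((2, 3))
# _unpack_shape_vector(const_shape) -> (constant(2), constant(3))

# MakeVector branch: a shape assembled from scalar expressions is split back
# into those expressions (assuming `stack` of scalars builds a MakeVector node).
x = pt.matrix("x")
sym_shape = pt.stack([x.shape[0], x.shape[1]])
# _unpack_shape_vector(sym_shape) -> (x.shape[0], x.shape[1])

# Any other symbolic vector falls back to tuple(shape), as the docstring says.
```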
@register_useless("shape_unsafe")
@@ -821,86 +918,29 @@ def local_useless_reshape(fgraph, node):
    if shape_input == inp:
        return [inp]

-    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for
-    # broadcastable and constant dimensions
-    if isinstance(output_shape, Constant) or (
-        output_shape.owner and isinstance(output_shape.owner.op, MakeVector)
-    ):
-        if isinstance(output_shape, Constant):
-            output_shape_is = [
-                as_tensor_variable(dim, ndim=0) for dim in output_shape.data
-            ]
-        else:
-            output_shape_is = output_shape.owner.inputs
-
-        shape_feature = getattr(fgraph, "shape_feature", None)
-
-        nb_m1 = 0
-        shape_match = [False] * inp.type.ndim
-        for dim in range(inp.type.ndim):
-            outshp_i = output_shape_is[dim]
-            # Match Shape_i{dim}(input)
-            if (
-                outshp_i.owner
-                and isinstance(outshp_i.owner.op, Shape_i)
-                and outshp_i.owner.op.i == dim
-                and outshp_i.owner.inputs[0] == inp
-            ):
-                shape_match[dim] = True
-                continue
+    shape_feature = getattr(fgraph, "shape_feature", None)

-            # Match Shape(input)[dim]
-            if (
-                outshp_i.owner
-                and isinstance(outshp_i.owner.op, Subtensor)
-                and len(outshp_i.owner.inputs) == 2
-                and get_scalar_constant_value(
-                    outshp_i.owner.inputs[1], raise_not_constant=False
-                )
-                == dim
-            ):
-                subtensor_inp = outshp_i.owner.inputs[0]
-                if subtensor_inp.owner and isinstance(subtensor_inp.owner.op, Shape):
-                    shape_input_i = subtensor_inp.owner.inputs[0]
-                    if shape_input_i == inp:
-                        shape_match[dim] = True
-                        continue
-
-            # Match constant if input.type.shape[dim] == constant
-            cst_outshp_i = get_scalar_constant_value(
-                outshp_i, only_process_constants=True, raise_not_constant=False
-            )
-            if inp.type.shape[dim] == cst_outshp_i:
-                shape_match[dim] = True
-                continue
-
-            # Match -1
-            if cst_outshp_i == -1:
-                shape_match[dim] = True
-                nb_m1 += 1
-                continue
+    # Match Reshape(x, [x.shape[0], ..., x.shape[-1]]), accounting for -1
+    # or cases where all but one dimension are provably preserved
+    output_shape_is = _unpack_shape_vector(output_shape)

-            # Match shape_of[input][dim] or its constant equivalent
-            if shape_feature:
-                inpshp_i = shape_feature.get_shape(inp, dim)
-                if inpshp_i == outshp_i or (
-                    get_scalar_constant_value(
-                        inpshp_i, only_process_constants=True, raise_not_constant=False
-                    )
-                    == get_scalar_constant_value(
-                        outshp_i, only_process_constants=True, raise_not_constant=False
-                    )
-                ):
-                    shape_match[dim] = True
-                    continue
+    nb_m1 = 0
+    shape_match = [False] * inp.type.ndim
+    for dim in range(inp.type.ndim):
+        outshp_i = output_shape_is[dim]
+        if _is_shape_i_of_x(outshp_i, inp, dim, shape_feature=shape_feature):
+            shape_match[dim] = True
+        elif isinstance(outshp_i, Constant) and outshp_i.data == -1:
+            shape_match[dim] = True
+            nb_m1 += 1

-        if nb_m1 <= 1 and all(shape_match):
-            return [inp]
+    if nb_m1 <= 1 and all(shape_match):
+        return [inp]

-        if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
-            return [inp]
+    if (nb_m1 == 0) and (shape_match.count(False) == output.type.ndim - 1):
+        return [inp]

-    return False
+    return False


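For intuition, a hedged example of graphs the simplified matching above should catch (illustrative only, public API):

```python
import pytensor.tensor as pt

x = pt.matrix("x")

# Reshaping to the input's own symbolic shape is a no-op.
out1 = x.reshape(x.shape)

# A mix of matched dimensions and a single -1 is also a no-op for a 2-d input.
out2 = x.reshape((x.shape[0], -1))

# In both cases the rewrite is expected to replace the Reshape output with `x`.
```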
@register_canonicalize
@@ -914,39 +954,26 @@ def local_reshape_to_dimshuffle(fgraph, node):

    For example:
        - reshape(x, (1, n)) -> DimShuffle{x,0}(Reshape(x, (n,)))
-        - reshape(x, (1, m, 1, n, 1, 1))
-          -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
+        - reshape(x, (1, m, 1, n, 1, 1)) -> DimShuffle{x,0,x,1,x,x}(Reshape(x, (m, n)))
    """
-    op = node.op
    inp, output_shape = node.inputs
    [output] = node.outputs

-    dimshuffle_new_order = []
+    unpacked_shape = _unpack_shape_vector(output_shape)
+    expand_axes = []
    new_output_shape = []
-    index = 0  # index over the output of the new reshape
-    for i in range(output.ndim):
-        # Since output_shape is a symbolic vector, we trust get_scalar_constant_value
-        # to go through however it is formed to see if its i-th element is 1.
-        # We need only_process_constants=False for that.
-        dim = get_scalar_constant_value(
-            output_shape[i],
-            only_process_constants=False,
-            elemwise=False,
-            raise_not_constant=False,
-        )
-        if dim == 1:
-            dimshuffle_new_order.append("x")
+    for i, dim in enumerate(unpacked_shape):
+        if isinstance(dim, Constant) and dim.data == 1:
+            expand_axes.append(i)
        else:
-            dimshuffle_new_order.append(index)
            new_output_shape.append(dim)
-            index = index + 1

-    if index != output.type.ndim:
-        inner = op.__class__(len(new_output_shape))(inp, new_output_shape)
+    if len(new_output_shape) != output.type.ndim:
+        inner = inp.reshape(new_output_shape)
        copy_stack_trace(output, inner)
-        new_node = [inner.dimshuffle(dimshuffle_new_order)]
-        copy_stack_trace(output, new_node)
-        return new_node
+        new_out = expand_dims(inner, expand_axes)
+        copy_stack_trace(output, new_out)
+        return [new_out]


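A rough sketch of the canonical form this rewrite produces (illustrative only):

```python
import pytensor.tensor as pt

x = pt.matrix("x")
# Unit dimensions requested through Reshape...
out = x.reshape((1, x.shape[0], 1, x.shape[1]))
# ...are expected to be moved out of the Reshape and re-added with expand_dims,
# roughly: expand_dims(reshape(x, (x.shape[0], x.shape[1])), (0, 2)).
# Later passes (e.g. local_useless_reshape above) can then drop the inner Reshape.
```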
@register_canonicalize
@@ -1186,44 +1213,6 @@ def local_track_shape_i(fgraph, node):
    return [shape_feature.shape_of[replacement][node.op.i]]


-@register_canonicalize
-@node_rewriter([Reshape])
-def local_useless_dimshuffle_in_reshape(fgraph, node):
-    """
-    Removes useless DimShuffle operation inside Reshape:
-
-      reshape(vector.dimshuffle('x', 0), shp) => reshape(vector, shp)
-      reshape(matrix.dimshuffle('x', 0, 'x', 1), shp) => reshape(matrix, shp)
-      reshape(row.dimshuffle(1, 'x'), shp) => reshape(row, shp)
-      reshape(col.dimshuffle(0), shp) => reshape(col, shp)
-
-    """
-    op = node.op
-    if not isinstance(op, Reshape):
-        return False
-    if not (
-        node.inputs[0].owner is not None
-        and isinstance(node.inputs[0].owner.op, DimShuffle)
-    ):
-        return False
-
-    new_order = node.inputs[0].owner.op.new_order
-    inp = node.inputs[0].owner.inputs[0]
-    new_order_of_nonbroadcast = []
-    for i, s in zip(new_order, node.inputs[0].type.shape, strict=True):
-        if s != 1:
-            new_order_of_nonbroadcast.append(i)
-    no_change_in_order = all(
-        new_order_of_nonbroadcast[i] <= new_order_of_nonbroadcast[i + 1]
-        for i in range(len(new_order_of_nonbroadcast) - 1)
-    )
-    if no_change_in_order:
-        shape = node.inputs[1]
-        ret = op.__class__(node.outputs[0].ndim)(inp, shape)
-        copy_stack_trace(node.outputs[0], ret)
-        return [ret]
-
-
@register_useless
@register_canonicalize
@register_specialize