File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -578,11 +578,15 @@ def specify_shape(
578
578
x = ptb .as_tensor_variable (x ) # type: ignore[arg-type,unused-ignore]
579
579
# The above is a type error in Python 3.9 but not 3.12.
580
580
# Thus we need to ignore unused-ignore on 3.12.
581
- new_shape_info = any (
582
- s != xts for (s , xts ) in zip (shape , x .type .shape , strict = False ) if s is not None
583
- )
581
+
584
582
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
585
- if not new_shape_info and len (shape ) == x .type .ndim :
583
+ if len (shape ) != x .type .ndim :
584
+ return _specify_shape (x , * shape )
585
+
586
+ new_shape_matches = all (
587
+ s == xts for (s , xts ) in zip (shape , x .type .shape , strict = True ) if s is not None
588
+ )
589
+ if new_shape_matches :
586
590
return x
587
591
588
592
return _specify_shape (x , * shape )
You can’t perform that action at this time.
0 commit comments