Skip to content

Commit 6124822

Browse files
committed
Make non-strict zip strict in tensor/shape.py
1 parent fe7018c commit 6124822

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

pytensor/tensor/shape.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -578,11 +578,15 @@ def specify_shape(
578578
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
579579
# The above is a type error in Python 3.9 but not 3.12.
580580
# Thus we need to ignore unused-ignore on 3.12.
581-
new_shape_info = any(
582-
s != xts for (s, xts) in zip(shape, x.type.shape, strict=False) if s is not None
583-
)
581+
584582
# If shape does not match x.ndim, we rely on the `Op` to raise a ValueError
585-
if not new_shape_info and len(shape) == x.type.ndim:
583+
if len(shape) != x.type.ndim:
584+
return _specify_shape(x, *shape)
585+
586+
new_shape_matches = all(
587+
s == xts for (s, xts) in zip(shape, x.type.shape, strict=True) if s is not None
588+
)
589+
if new_shape_matches:
586590
return x
587591

588592
return _specify_shape(x, *shape)

0 commit comments

Comments
 (0)