From 70aae973d2c5e05e16c09d352644d8af7e02e67a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 27 Jul 2023 13:49:36 +0200 Subject: [PATCH] Fix overly strict check in `local_pow_specialize` rewrite --- pytensor/tensor/rewriting/math.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 2f79bf20fa..c585b70096 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -2072,11 +2072,10 @@ def local_pow_specialize(fgraph, node): if np.all(y == -2): rval = [reciprocal(sqr(xsym))] if rval: + if not rval[0].type.broadcastable == node.outputs[0].type.broadcastable: + return None rval[0] = cast(rval[0], odtype) - assert rval[0].type.is_super(node.outputs[0].type), ( - rval[0].type, - node.outputs[0].type, - ) + assert rval[0].type.dtype == node.outputs[0].type.dtype return rval else: return False