diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 2f79bf20fa..c585b70096 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -2072,11 +2072,10 @@ def local_pow_specialize(fgraph, node): if np.all(y == -2): rval = [reciprocal(sqr(xsym))] if rval: + if not rval[0].type.broadcastable == node.outputs[0].type.broadcastable: + return None rval[0] = cast(rval[0], odtype) - assert rval[0].type.is_super(node.outputs[0].type), ( - rval[0].type, - node.outputs[0].type, - ) + assert rval[0].type.dtype == node.outputs[0].type.dtype return rval else: return False