Closed
Labels: bug, help wanted, important
Description
MWE:
```python
import aesara
import aesara.tensor as aet
import numpy as np

aesara.config.floatX = 'float32'

cond = aet.as_tensor_variable(np.array([True]))
ift = aet.as_tensor_variable(np.array([0]))
iff = aet.as_tensor_variable(np.array([0]))
a = aet.switch(cond, ift, iff)
b = a.eval()  # raises an error during graph optimization
```
The solution also looks straightforward: the issue is probably that when `cond=True`, it fails to go through this particular branch:
`aesara/aesara/tensor/basic_opt.py`, lines 2538 to 2540 in 29032f3:
```python
if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
    cond, np.number
):
```
Adding an additional condition such as `or cond in [True, False]` should fix this issue.
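A likely reason the quoted check misses this case is that NumPy's boolean scalar type `np.bool_` is not a subclass of `np.number`, so a boolean scalar that is not a 0-d array fails both clauses. The sketch below assumes that the constant condition reaches the check as either an `np.bool_` or a plain Python `bool`; that assumption is not confirmed here. It compares the quoted predicate with the version proposed above. `original_check` and `patched_check` are hypothetical helper names used only for illustration.

```python
import numpy as np


def original_check(cond):
    # Hypothetical helper: same predicate as the excerpt from basic_opt.py
    return (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
        cond, np.number
    )


def patched_check(cond):
    # Hypothetical helper: the excerpt plus the extra clause proposed in this issue
    return original_check(cond) or cond in [True, False]


# np.bool_ derives from np.generic, not from np.number
print(issubclass(np.bool_, np.number))  # False

for cond in (np.array(True), np.float32(1.0), np.True_, True):
    print(repr(cond), original_check(cond), patched_check(cond))
```

Note that `cond in [True, False]` relies on `==`, so it also accepts values like the Python int `1` or float `0.0`. A stricter alternative (a suggestion, not part of this issue) would be `isinstance(cond, (np.number, np.bool_, bool))`.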
Note that this only happens when `floatX` is set to `float32`.
brandonwillard