aet.switch fails local_useless_switch optimizations when used with certain conditions.  #616

@kc611

MWE:

import aesara
import aesara.tensor as aet
import numpy as np

aesara.config.floatX = 'float32'

cond = aet.as_tensor_variable(np.array([True]))
ift = aet.as_tensor_variable(np.array([0]))
iff = aet.as_tensor_variable(np.array([0]))

a = aet.switch(cond, ift, iff)
b = a.eval()  # Raises an error during graph optimization

The solution also looks straightforward. The issue is probably that when cond=True, the optimization fails to go through this particular branch:

if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
    cond, np.number
):

Adding an additional condition such as "or cond in [True, False]" to that check should fix this issue.

Note that this only happens when floatX is set to float32.
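
To illustrate why a boolean condition might slip past the branch quoted above, here is a minimal NumPy-only sketch. It assumes the value reaching that check is either a NumPy boolean scalar or the one-element array from the MWE; which of the two the optimization actually sees is not confirmed here. The helper names matches_original_branch and matches_with_bool_scalars are hypothetical and not part of Aesara.

```python
import numpy as np


def matches_original_branch(cond):
    # Same test as the branch quoted above.
    return (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance(
        cond, np.number
    )


def matches_with_bool_scalars(cond):
    # Hypothetical variant: also accept Python and NumPy boolean scalars.
    return matches_original_branch(cond) or isinstance(cond, (bool, np.bool_))


# np.bool_ derives from np.generic, but not from np.number.
print(issubclass(np.bool_, np.number))
print(matches_original_branch(np.bool_(True)))
print(matches_original_branch(np.array([True])))  # ndim is 1, not 0
print(matches_with_bool_scalars(np.bool_(True)))
```

An isinstance-based test is used in the variant instead of "cond in [True, False]" because membership tests compare with ==, which raises a ValueError for arrays with more than one element. Either form would need to be checked against the real optimization before being adopted.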

Labels: bug, help wanted, important
