Skip to content

Commit 202d8a0

Browse files
committed
Fix casting bug
1 parent d322b81 commit 202d8a0

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

pymc/aesaraf.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from aesara.graph.fg import FunctionGraph
4848
from aesara.graph.op import Op, compute_test_value
4949
from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream
50+
from aesara.scalar.basic import Cast
5051
from aesara.tensor.elemwise import Elemwise
5152
from aesara.tensor.random.op import RandomVariable
5253
from aesara.tensor.shape import SpecifyShape
@@ -232,6 +233,9 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
232233
return x.data
233234
if isinstance(x, SharedVariable):
234235
return x.get_value()
236+
if x.owner and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast):
237+
array_data = extract_obs_data(x.owner.inputs[0])
238+
return array_data.astype(x.type.dtype)
235239
if x.owner and isinstance(x.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)):
236240
array_data = extract_obs_data(x.owner.inputs[0])
237241
mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:])

pymc/tests/test_aesaraf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,14 @@ def test_extract_obs_data():
380380
assert isinstance(res, np.ndarray)
381381
assert np.ma.allequal(res, data_m)
382382

383+
# Cast check
384+
data = np.array(5)
385+
t = at.cast(at.as_tensor(5.0), np.int64)
386+
res = extract_obs_data(t)
387+
388+
assert isinstance(res, np.ndarray)
389+
assert np.array_equal(res, data)
390+
383391

384392
@pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"])
385393
def test_pandas_to_array(input_dtype):

pymc/tests/test_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,15 +1078,15 @@ def test_shared(self):
10781078
o = pm.Deterministic("o", obs)
10791079
gen1 = pm.sample_prior_predictive(draws)
10801080

1081-
assert gen1.prior["y"].shape == (1, draws, n1)
1081+
assert gen1.prior_predictive["y"].shape == (1, draws, n1)
10821082
assert gen1.prior["o"].shape == (1, draws, n1)
10831083

10841084
n2 = 20
10851085
obs.set_value(np.random.rand(n2) < 0.5)
10861086
with m:
10871087
gen2 = pm.sample_prior_predictive(draws)
10881088

1089-
assert gen2.prior["y"].shape == (1, draws, n2)
1089+
assert gen2.prior_predictive["y"].shape == (1, draws, n2)
10901090
assert gen2.prior["o"].shape == (1, draws, n2)
10911091

10921092
def test_density_dist(self):

0 commit comments

Comments
 (0)