diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 7f9b299165..62aa7fc3f0 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -47,6 +47,7 @@ from aesara.graph.fg import FunctionGraph from aesara.graph.op import Op, compute_test_value from aesara.sandbox.rng_mrg import MRG_RandomStream as RandomStream +from aesara.scalar.basic import Cast from aesara.tensor.elemwise import Elemwise from aesara.tensor.random.op import RandomVariable from aesara.tensor.shape import SpecifyShape @@ -232,6 +233,9 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray: return x.data if isinstance(x, SharedVariable): return x.get_value() + if x.owner and isinstance(x.owner.op, Elemwise) and isinstance(x.owner.op.scalar_op, Cast): + array_data = extract_obs_data(x.owner.inputs[0]) + return array_data.astype(x.type.dtype) if x.owner and isinstance(x.owner.op, (AdvancedIncSubtensor, AdvancedIncSubtensor1)): array_data = extract_obs_data(x.owner.inputs[0]) mask_idx = tuple(extract_obs_data(i) for i in x.owner.inputs[2:]) diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index 74622a16a3..4d87adbc81 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -380,6 +380,14 @@ def test_extract_obs_data(): assert isinstance(res, np.ndarray) assert np.ma.allequal(res, data_m) + # Cast check + data = np.array(5) + t = at.cast(at.as_tensor(5.0), np.int64) + res = extract_obs_data(t) + + assert isinstance(res, np.ndarray) + assert np.array_equal(res, data) + @pytest.mark.parametrize("input_dtype", ["int32", "int64", "float32", "float64"]) def test_pandas_to_array(input_dtype): diff --git a/pymc/tests/test_sampling.py b/pymc/tests/test_sampling.py index c53d5f1478..d534553061 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/test_sampling.py @@ -1078,7 +1078,7 @@ def test_shared(self): o = pm.Deterministic("o", obs) gen1 = pm.sample_prior_predictive(draws) - assert gen1.prior["y"].shape == (1, draws, n1) + assert gen1.prior_predictive["y"].shape == (1, draws, n1) assert gen1.prior["o"].shape == (1, draws, n1) n2 = 20 @@ -1086,7 +1086,7 @@ def test_shared(self): with m: gen2 = pm.sample_prior_predictive(draws) - assert gen2.prior["y"].shape == (1, draws, n2) + assert gen2.prior_predictive["y"].shape == (1, draws, n2) assert gen2.prior["o"].shape == (1, draws, n2) def test_density_dist(self):