Skip to content

Commit b08610c

Browse files
committed
Fix DiscreteMarkovChain logp
1 parent e163a6f commit b08610c

File tree

3 files changed: +349 −330 lines changed

notebooks/discrete_markov_chain.ipynb

Lines changed: 330 additions & 325 deletions
Large diffs are not rendered by default.

pymc_experimental/distributions/timeseries.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from pymc.logprob.abstract import _logprob
2121
from pymc.logprob.basic import logp
22-
from pymc.pytensorf import intX
22+
from pymc.pytensorf import constant_fold, intX
2323
from pymc.util import check_dist_not_registered
2424
from pytensor.graph.basic import Node
2525
from pytensor.tensor import TensorVariable
@@ -252,11 +252,16 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
252252
mc_logprob = logp(init_dist, value[..., :n_lags]).sum(axis=-1)
253253
mc_logprob += pt.log(P[tuple(indexes)]).sum(axis=-1)
254254

255+
# We cannot leave any RV in the logp graph, even if just for an assert
256+
[init_dist_leading_dim] = constant_fold(
257+
[pt.atleast_1d(init_dist).shape[0]], raise_not_constant=False
258+
)
259+
255260
return check_parameters(
256261
mc_logprob,
257262
pt.all(pt.eq(P.shape[-(n_lags + 1) :], P.shape[-1])),
258263
pt.all(pt.allclose(P.sum(axis=-1), 1.0)),
259-
pt.eq(pt.atleast_1d(init_dist).shape[0], n_lags),
264+
pt.eq(init_dist_leading_dim, n_lags),
260265
msg="Last (n_lags + 1) dimensions of P must be square, "
261266
"P must sum to 1 along the last axis, "
262267
"First dimension of init_dist must be n_lags",

pymc_experimental/tests/distributions/test_discrete_markov_chain.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,19 @@ def test_logp_with_default_init_dist(self):
9090
P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))
9191
x0 = pm.Categorical.dist(p=np.ones(3) / 3)
9292

93-
chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)
93+
value = np.array([0, 1, 2])
94+
logp_expected = np.log((1 / 3) * 0.5 * 0.3)
9495

95-
logp = pm.logp(chain, [0, 1, 2]).eval()
96-
assert logp == pytest.approx(np.log((1 / 3) * 0.5 * 0.3), rel=1e-6)
96+
# Test dist directly
97+
chain = DiscreteMarkovChain.dist(P=P, init_dist=x0, steps=3)
98+
logp_eval = pm.logp(chain, value).eval()
99+
np.testing.assert_allclose(logp_eval, logp_expected, rtol=1e-6)
100+
101+
# Test via Model
102+
with pm.Model() as m:
103+
DiscreteMarkovChain("chain", P=P, init_dist=x0, steps=3)
104+
model_logp_eval = m.compile_logp()({"chain": value})
105+
np.testing.assert_allclose(model_logp_eval, logp_expected, rtol=1e-6)
97106

98107
def test_logp_with_user_defined_init_dist(self):
99108
P = pt.as_tensor_variable(np.array([[0.1, 0.5, 0.4], [0.3, 0.4, 0.3], [0.9, 0.05, 0.05]]))

0 commit comments

Comments (0)