From e6c7f918019c1747be90e2b0c3ebe9be93b3c349 Mon Sep 17 00:00:00 2001
From: Morgan Pihl
Date: Thu, 25 Nov 2021 21:17:17 +0100
Subject: [PATCH] Adds tests and mode for dirichlet multinomial distribution

---
 pymc/distributions/multivariate.py       | 23 +++++++++++--
 pymc/tests/test_distributions_moments.py | 42 +++++++++++++++++++++---
 2 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py
index 42d9daabc4..ff4e34d68d 100644
--- a/pymc/distributions/multivariate.py
+++ b/pymc/distributions/multivariate.py
@@ -453,9 +453,7 @@ def get_moment(rv, size, a):
         norm_constant = at.sum(a, axis=-1)[..., None]
         moment = a / norm_constant
         if not rv_size_is_none(size):
-            if isinstance(size, int):
-                size = (size,)
-            moment = at.full((*size, *a.shape), moment)
+            moment = at.full(at.concatenate([size, a.shape]), moment)
         return moment
 
     def logp(value, a):
@@ -684,6 +682,25 @@ def dist(cls, n, a, *args, **kwargs):
 
         return super().dist([n, a], **kwargs)
 
+    def get_moment(rv, size, n, a):
+        # Moment is the rounded n * p (p: `a` normalised over its last axis),
+        # corrected so that each count vector sums to n.
+        # keepdims=True keeps the normaliser broadcastable against batched `a`.
+        p = a / at.sum(a, axis=-1, keepdims=True)
+        mode = at.round(n * p)
+        # Rounding can leave the total off from n; the difference is added to the
+        # first category. NOTE(review): a negative difference can push that count
+        # below zero (e.g. n=3 with five equal concentrations) -- confirm intended.
+        diff = n - at.sum(mode, axis=-1, keepdims=True)
+        inc_bool_arr = at.abs_(diff) > 0
+        mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
+        # Reshape mode according to base shape (ignoring size)
+        mode = at.reshape(mode, rv.shape[size.size :])
+        if not rv_size_is_none(size):
+            output_size = at.concatenate([size, mode.shape])
+            mode = at.full(output_size, mode)
+        return mode
+
     def logp(value, n, a):
         """
         Calculate log-probability of DirichletMultinomial distribution
diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py
index 1f0192b7c8..fd38e43b0f 100644
--- a/pymc/tests/test_distributions_moments.py
+++ b/pymc/tests/test_distributions_moments.py
@@ -21,6 +21,7 @@
     Constant,
     DensityDist,
     Dirichlet,
+    DirichletMultinomial,
     DiscreteUniform,
     DiscreteWeibull,
     ExGaussian,
@@ -112,7 +113,6 @@ def test_all_distributions_have_moments():
 
     # Distributions that have been refactored but don't yet have moments
     not_implemented |= {
-        dist_module.multivariate.DirichletMultinomial,
         dist_module.multivariate.Wishart,
     }
 
@@ -797,10 +797,7 @@ def test_discrete_weibull_moment(q, beta, size, expected):
         ),
         (
             np.full(shape=np.array([7, 3]), fill_value=np.array([13, 17, 19])),
-            (
-                11,
-                5,
-            ),
+            (11, 5),
             np.broadcast_to([13, 17, 19], shape=[11, 5, 7, 3]) / 49,
         ),
     ],
@@ -1461,3 +1458,38 @@ def test_lkjcholeskycov_moment(n, eta, size, expected):
         sd_dist = pm.Exponential.dist(1, size=(*to_tuple(size), n))
         LKJCholeskyCov("x", n=n, eta=eta, sd_dist=sd_dist, size=size, compute_corr=False)
     assert_moment_is_expected(model, expected, check_finite_logp=size is None)
+
+
+@pytest.mark.parametrize(
+    "a, n, size, expected",
+    [
+        (np.array([2, 2, 2, 2]), 1, None, np.array([1, 0, 0, 0])),
+        (np.array([3, 6, 0.5, 0.5]), 2, None, np.array([1, 1, 0, 0])),
+        (np.array([30, 60, 5, 5]), 10, None, np.array([4, 6, 0, 0])),
+        (
+            np.array([[26, 26, 26, 22], [30, 60, 5, 5]]),  # Dim: 2 x 4
+            10,
+            None,
+            np.array([[2, 3, 3, 2], [4, 6, 0, 0]]),  # Dim: 2 x 4
+        ),
+        (
+            np.array([[26, 26, 26, 22]]),  # Dim: 1 x 4
+            np.array([[1], [10]]),  # Dim: 2 x 1
+            None,
+            np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]),  # Dim: 2 x 1 x 4
+        ),
+        (
+            np.array([[26, 26, 26, 22]]),  # Dim: 1 x 4
+            np.array([[1], [10]]),  # Dim: 2 x 1
+            (2, 1),
+            np.full(
+                (2, 1, 2, 1, 4),
+                np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]),  # Dim: 2 x 1 x 4
+            ),
+        ),
+    ],
+)
+def test_dirichlet_multinomial_moment(a, n, size, expected):
+    with Model() as model:
+        DirichletMultinomial("x", n=n, a=a, size=size)
+    assert_moment_is_expected(model, expected)