From a7d7fb5428f9f6991d03347520706bfada46a5d6 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Fri, 18 Mar 2022 10:27:20 +0100 Subject: [PATCH 1/2] Make rv_size_is_none more robust --- pymc/distributions/shape_utils.py | 4 ++-- pymc/tests/test_distributions_moments.py | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/pymc/distributions/shape_utils.py b/pymc/distributions/shape_utils.py index c761a2caf1..5932f57d88 100644 --- a/pymc/distributions/shape_utils.py +++ b/pymc/distributions/shape_utils.py @@ -22,7 +22,7 @@ import numpy as np -from aesara.graph.basic import Constant, Variable +from aesara.graph.basic import Variable from aesara.tensor.var import TensorVariable from typing_extensions import TypeAlias @@ -618,4 +618,4 @@ def find_size( def rv_size_is_none(size: Variable) -> bool: """Check wether an rv size is None (ie., at.Constant([]))""" - return isinstance(size, Constant) and size.data.size == 0 + return size.type.shape == (0,) diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 0727963a92..7971480285 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -133,6 +133,9 @@ def test_rv_size_is_none(): rv = Normal.dist(0, 1, size=None) assert rv_size_is_none(rv.owner.inputs[1]) + rv = Normal.dist(0, 1, size=()) + assert rv_size_is_none(rv.owner.inputs[1]) + rv = Normal.dist(0, 1, size=1) assert not rv_size_is_none(rv.owner.inputs[1]) From 1785167e783a7d9d94e69cba7d3ff92bba2e58f3 Mon Sep 17 00:00:00 2001 From: markvrma Date: Tue, 15 Feb 2022 19:34:38 +0400 Subject: [PATCH 2/2] Generalize Multinomial moment to arbitrary dimensions --- pymc/distributions/multivariate.py | 25 ++++------------------- pymc/tests/test_distributions_moments.py | 26 +++++++++++++++--------- 2 files changed, 20 insertions(+), 31 deletions(-) diff --git a/pymc/distributions/multivariate.py b/pymc/distributions/multivariate.py index de500a3353..d8d5645337 100644 --- a/pymc/distributions/multivariate.py +++ b/pymc/distributions/multivariate.py @@ -56,7 +56,7 @@ logpow, multigammaln, ) -from pymc.distributions.distribution import Continuous, Discrete +from pymc.distributions.distribution import Continuous, Discrete, get_moment from pymc.distributions.shape_utils import ( broadcast_dist_samples_to, rv_size_is_none, @@ -558,11 +558,7 @@ def dist(cls, n, p, *args, **kwargs): return super().dist([n, p], *args, **kwargs) def get_moment(rv, size, n, p): - if p.ndim > 1: - n = at.shape_padright(n) - if (p.ndim == 1) & (n.ndim > 0): - n = at.shape_padright(n) - p = at.shape_padleft(p) + n = at.shape_padright(n) mode = at.round(n * p) diff = n - at.sum(mode, axis=-1, keepdims=True) inc_bool_arr = at.abs_(diff) > 0 @@ -682,21 +678,8 @@ def dist(cls, n, a, *args, **kwargs): return super().dist([n, a], **kwargs) def get_moment(rv, size, n, a): - p = a / at.sum(a, axis=-1) - mode = at.round(n * p) - diff = n - at.sum(mode, axis=-1, keepdims=True) - inc_bool_arr = at.abs_(diff) > 0 - mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()]) - - # Reshape mode according to dimensions implied by the parameters - # This can include axes of length 1 - _, p_bcast = broadcast_params([n, p], ndims_params=[0, 1]) - mode = at.reshape(mode, p_bcast.shape) - - if not rv_size_is_none(size): - output_size = at.concatenate([size, [p.shape[-1]]]) - mode = at.full(output_size, mode) - return mode + p = a / at.sum(a, axis=-1, keepdims=True) + return get_moment(Multinomial.dist(n=n, p=p, size=size)) def logp(value, n, a): """ diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 7971480285..433dda5a31 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -1308,13 +1308,13 @@ def test_polyagamma_moment(h, z, size, expected): np.array([[4, 6, 0, 0], [4, 2, 2, 2]]), ), ( - np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]), - np.array([1, 10]), - None, - np.array([[1, 0, 0, 0], [2, 3, 3, 2]]), + np.array([0.3, 0.6, 0.05, 0.05]), + np.array([2, 10]), + (1, 2), + np.array([[[1, 1, 0, 0], [4, 6, 0, 0]]]), ), ( - np.array([0.26, 0.26, 0.26, 0.22]), + np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]), np.array([1, 10]), None, np.array([[1, 0, 0, 0], [2, 3, 3, 2]]), @@ -1322,8 +1322,8 @@ def test_polyagamma_moment(h, z, size, expected): ( np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]), np.array([1, 10]), - (2, 2), - np.full((2, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]), + (3, 2), + np.full((3, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]), ), ], ) @@ -1470,10 +1470,16 @@ def test_lkjcholeskycov_moment(n, eta, size, expected): (np.array([3, 6, 0.5, 0.5]), 2, None, np.array([1, 1, 0, 0])), (np.array([30, 60, 5, 5]), 10, None, np.array([4, 6, 0, 0])), ( - np.array([[26, 26, 26, 22]]), # Dim: 1 x 4 - np.array([[1], [10]]), # Dim: 2 x 1 + np.array([[30, 60, 5, 5], [26, 26, 26, 22]]), + 10, + (1, 2), + np.array([[[4, 6, 0, 0], [2, 3, 3, 2]]]), + ), + ( + np.array([26, 26, 26, 22]), + np.array([1, 10]), None, - np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]), # Dim: 2 x 1 x 4 + np.array([[1, 0, 0, 0], [2, 3, 3, 2]]), ), ( np.array([[26, 26, 26, 22]]), # Dim: 1 x 4