Skip to content

Generalize multinomial moment to arbitrary dimensions #5476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 18, 2022

Conversation

markvrma
Copy link
Contributor

@markvrma markvrma commented Feb 15, 2022

Resolves #5393

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @markvrma thanks for opening a PR!

Besides the comment below, we also need to add test conditions to make sure the moment is indeed correct for higher dimensions. This PR is a good template: https://github.com/pymc-devs/pymc/pull/5225/files

Comment on lines 568 to 570
if (p.ndim == 1) and (n.ndim > 0):
n = at.shape_padright(n)
p = at.shape_padleft(p)
Copy link
Member

@ricardoV94 ricardoV94 Feb 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we shouldn't need any these shape_padright/left anymore

Copy link
Contributor Author

@markvrma markvrma Feb 16, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried removing:

 if p.ndim > 1:
            n = at.shape_padright(n)
 if (p.ndim == 1) and (n.ndim > 0):
            n = at.shape_padright(n)
            p = at.shape_padleft(p)

but existing tests fail in test_distributions_moments.py

FAILED test_distributions_moments.py::test_multinomial_moment[p4-n4-None-expected4]
FAILED test_distributions_moments.py::test_multinomial_moment[p5-n5-None-expected5]
FAILED test_distributions_moments.py::test_multinomial_moment[p6-n6-2-expected6]

@codecov
Copy link

codecov bot commented Feb 20, 2022

Codecov Report

Merging #5476 (d82c05e) into main (2db28f0) will decrease coverage by 6.15%.
The diff coverage is 100.00%.

❗ Current head d82c05e differs from pull request most recent head 1785167. Consider uploading reports for the commit 1785167 to get more accurate results

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5476      +/-   ##
==========================================
- Coverage   87.60%   81.45%   -6.16%     
==========================================
  Files          76       81       +5     
  Lines       13745    14205     +460     
==========================================
- Hits        12041    11570     -471     
- Misses       1704     2635     +931     
Impacted Files Coverage Δ
pymc/distributions/multivariate.py 91.53% <100.00%> (-0.77%) ⬇️
pymc/distributions/mixture.py 21.17% <0.00%> (-71.97%) ⬇️
pymc/variational/test_functions.py 38.09% <0.00%> (-61.91%) ⬇️
pymc/variational/inference.py 29.37% <0.00%> (-57.87%) ⬇️
pymc/variational/stein.py 48.33% <0.00%> (-51.67%) ⬇️
pymc/variational/callbacks.py 48.00% <0.00%> (-48.00%) ⬇️
pymc/variational/flows.py 36.79% <0.00%> (-46.76%) ⬇️
pymc/variational/opvi.py 36.21% <0.00%> (-46.03%) ⬇️
pymc/variational/operators.py 50.00% <0.00%> (-42.86%) ⬇️
pymc/variational/approximations.py 36.30% <0.00%> (-29.76%) ⬇️
... and 36 more

@markvrma
Copy link
Contributor Author

Added new test for higher dimensions.

I couldn't remove:

 if p.ndim > 1:
            n = at.shape_padright(n)
 if (p.ndim == 1) and (n.ndim > 0):
            n = at.shape_padright(n)
            p = at.shape_padleft(p)

as existing tests fail in test_distributions_moments.py.

@ricardoV94
Copy link
Member

Added new test for higher dimensions.

I couldn't remove:

 if p.ndim > 1:
            n = at.shape_padright(n)
 if (p.ndim == 1) and (n.ndim > 0):
            n = at.shape_padright(n)
            p = at.shape_padleft(p)

as existing tests fail in test_distributions_moments.py.

We might need something else, instead of just removing these lines. Or perhaps those lines also give the right result for higher dimensions that are now supported, but I would be surprised if that was the case, since they seem to have specialized logic for <= 2D parameters.

I would brute-force with local tests where we try different combinations of p and n with shapes ranging from scalar to 3D and see if this function returns the correct output. One way might be to compare with a np.vectorized version of get_moment that works for the base case (1d vector of p and scalar n). There are some examples where we use this approach for testing the logp here:

@pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
@pytest.mark.parametrize(
"p",
[
([0.2, 0.3, 0.5]),
([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]]),
(np.abs(np.random.randn(2, 2, 4))),
],
)
@pytest.mark.parametrize("size", [1, 2, (2, 3)])
def test_multinomial_vectorized(self, n, p, size):
n = intX(np.array(n))
p = floatX(np.array(p))
p /= p.sum(axis=-1, keepdims=True)
mn = pm.Multinomial.dist(n=n, p=p, size=size)
vals = mn.eval()
assert_almost_equal(
scipy.stats.multinomial.logpmf(vals, n, p),
pm.logp(mn, vals).eval(),
decimal=4,
err_msg=f"vals={vals}",
)

and here:

def _dirichlet_multinomial_logpmf(value, n, a):
if value.sum() == n and (0 <= value).all() and (value <= n).all():
sum_a = a.sum()
const = gammaln(n + 1) + gammaln(sum_a) - gammaln(n + sum_a)
series = gammaln(value + a) - gammaln(value + 1) - gammaln(a)
return const + series.sum()
else:
return -inf
dirichlet_multinomial_logpmf = np.vectorize(
_dirichlet_multinomial_logpmf, signature="(n),(),(n)->()"
)

@pytest.mark.parametrize("n", [(10), ([10, 11]), ([[5, 6], [10, 11]])])
@pytest.mark.parametrize(
"a",
[
([0.2, 0.3, 0.5]),
([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]]),
(np.abs(np.random.randn(2, 2, 4))),
],
)
@pytest.mark.parametrize("size", [1, 2, (2, 3)])
def test_dirichlet_multinomial_vectorized(self, n, a, size):
n = intX(np.array(n))
a = floatX(np.array(a))
dm = pm.DirichletMultinomial.dist(n=n, a=a, size=size)
vals = dm.eval()
assert_almost_equal(
dirichlet_multinomial_logpmf(vals, n, a),
pm.logp(dm, vals).eval(),
decimal=4,
err_msg=f"vals={vals}",
)

@ricardoV94
Copy link
Member

Also did you figure out what exactly is failing? Looking at the DirichletMultinomial mode it seems like it should be exactly the same, except that here we already have the p, and can skip the first step of p = a / at.sum(a, axis=-1) in https://github.com/pymc-devs/pymc/pull/5225/files Unless there's an error in the implementation of the DirichletMultinomial moment as well

@ricardoV94 ricardoV94 force-pushed the multinomial_moment_to_arb_dims branch 6 times, most recently from 9271395 to adeb799 Compare March 18, 2022 09:32
@ricardoV94 ricardoV94 force-pushed the multinomial_moment_to_arb_dims branch from adeb799 to 1785167 Compare March 18, 2022 09:42
@ricardoV94
Copy link
Member

ricardoV94 commented Mar 18, 2022

@markvrma It seems than it was enough to add one dimension to the right of n. The old and new tests were quite helpful to figure this out, and also show that the DirichletMultinomial moment would fail for some configurations.

Thanks for the help!

@ricardoV94 ricardoV94 merged commit a0cff37 into pymc-devs:main Mar 18, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Generalize multinomial moment to arbitrary dimensions
2 participants