Skip to content

Commit 4f266c6

Browse files
committed
Add WishartBartlett example.
1 parent 6d3e7ba commit 4f266c6

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

pymc3/distributions/multivariate.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44
import warnings
55

66
from .dist_math import *
7+
from . import ChiSquared, Normal
8+
from .. import Deterministic
79

810
import numpy as np
11+
import scipy
912

1013
from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
11-
from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all
14+
from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all, zeros, sqrt, set_subtensor
1215
from theano.printing import Print
1316

1417
__all__ = ['MvNormal', 'Dirichlet', 'Multinomial', 'Wishart', 'WishartBartlett', 'LKJCorr']
@@ -233,18 +236,18 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
233236
tril_idx = np.tril_indices_from(S, k=-1)
234237
n_diag = len(diag_idx[0])
235238
n_tril = len(tril_idx[0])
236-
c = T.sqrt(pm.ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag))
237-
z = pm.Normal('z', 0, 1, shape=n_tril)
239+
c = sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag))
240+
z = Normal('z', 0, 1, shape=n_tril)
238241
# Construct A matrix
239-
A = T.zeros(S.shape, dtype=np.float32)
240-
A = T.set_subtensor(A[diag_idx], c)
241-
A = T.set_subtensor(A[tril_idx], z)
242+
A = zeros(S.shape, dtype=np.float32)
243+
A = set_subtensor(A[diag_idx], c)
244+
A = set_subtensor(A[tril_idx], z)
242245

243246
# L * A * A.T * L.T ~ Wishart(L*L.T, nu)
244247
if return_cholesky:
245-
return pm.Deterministic(name, T.dot(L, A))
248+
return Deterministic(name, dot(L, A))
246249
else:
247-
return pm.Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T))
250+
return Deterministic(name, dot(dot(dot(L, A), A.T), L.T))
248251

249252

250253
class LKJCorr(Continuous):

pymc3/examples/wishart.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pymc3 as pm
2+
import numpy as np
3+
import theano
4+
import theano.tensor as T
5+
import scipy.stats
6+
import matplotlib.pyplot as plt
7+
8+
# Covariance matrix we want to recover
9+
covariance = np.matrix([[2, .5, -.5],
10+
[.5, 1., 0.],
11+
[-.5, 0., 0.5]])
12+
13+
prec = np.linalg.inv(covariance)
14+
15+
mean = [.5, 1, .2]
16+
data = scipy.stats.multivariate_normal(mean, covariance).rvs(5000)
17+
18+
plt.scatter(data[:, 0], data[:, 1])
19+
20+
with pm.Model() as model:
21+
S = np.eye(3)
22+
nu = 5
23+
mu = pm.Normal('mu', mu=0, sd=1, shape=3)
24+
25+
# Use the transformed Wishart distribution
26+
# Under the hood this will do a Cholesky decomposition
27+
# of S and add two RVs to the sampler: c and z
28+
prec = pm.WishartBartlett('prec', S, nu)
29+
30+
# To be able to compare it to truth, convert precision to covariance
31+
cov = pm.Deterministic('cov', T.nlinalg.matrix_inverse(prec))
32+
33+
lp = pm.MvNormal('likelihood', mu=mu, tau=prec, observed=data)
34+
35+
start = pm.find_MAP()
36+
step = pm.NUTS(scaling=start)
37+
trace = pm.sample(500, step)
38+
39+
pm.traceplot(trace[100:]);

0 commit comments

Comments
 (0)