Skip to content

Commit 005f0f6

Browse files
committed
Add WishartBartlett example.
1 parent 8b48451 commit 005f0f6

File tree

2 files changed

+50
-8
lines changed

2 files changed

+50
-8
lines changed

pymc3/distributions/multivariate.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
import warnings
55

66
from .dist_math import *
7+
from . import ChiSquared, Normal
8+
from .. import Deterministic
79

810
import numpy as np
11+
import scipy
912
from . import transforms
1013

1114
from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
12-
from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all
15+
from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all, zeros, sqrt, set_subtensor
1316
from theano.printing import Print
1417
from pymc3.distributions.distribution import draw_values, generate_samples
1518
import scipy.stats as st
@@ -270,18 +273,18 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
270273
tril_idx = np.tril_indices_from(S, k=-1)
271274
n_diag = len(diag_idx[0])
272275
n_tril = len(tril_idx[0])
273-
c = T.sqrt(pm.ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag))
274-
z = pm.Normal('z', 0, 1, shape=n_tril)
276+
c = sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag))
277+
z = Normal('z', 0, 1, shape=n_tril)
275278
# Construct A matrix
276-
A = T.zeros(S.shape, dtype=np.float32)
277-
A = T.set_subtensor(A[diag_idx], c)
278-
A = T.set_subtensor(A[tril_idx], z)
279+
A = zeros(S.shape, dtype=np.float32)
280+
A = set_subtensor(A[diag_idx], c)
281+
A = set_subtensor(A[tril_idx], z)
279282

280283
# L * A * A.T * L.T ~ Wishart(L*L.T, nu)
281284
if return_cholesky:
282-
return pm.Deterministic(name, T.dot(L, A))
285+
return Deterministic(name, dot(L, A))
283286
else:
284-
return pm.Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T))
287+
return Deterministic(name, dot(dot(dot(L, A), A.T), L.T))
285288

286289

287290
class LKJCorr(Continuous):

pymc3/examples/wishart.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pymc3 as pm
2+
import numpy as np
3+
import theano
4+
import theano.tensor as T
5+
import scipy.stats
6+
import matplotlib.pyplot as plt
7+
8+
# Covariance matrix we want to recover
9+
covariance = np.matrix([[2, .5, -.5],
10+
[.5, 1., 0.],
11+
[-.5, 0., 0.5]])
12+
13+
prec = np.linalg.inv(covariance)
14+
15+
mean = [.5, 1, .2]
16+
data = scipy.stats.multivariate_normal(mean, covariance).rvs(5000)
17+
18+
plt.scatter(data[:, 0], data[:, 1])
19+
20+
with pm.Model() as model:
21+
S = np.eye(3)
22+
nu = 5
23+
mu = pm.Normal('mu', mu=0, sd=1, shape=3)
24+
25+
# Use the transformed Wishart distribution
26+
# Under the hood this will do a Cholesky decomposition
27+
# of S and add two RVs to the sampler: c and z
28+
prec = pm.WishartBartlett('prec', S, nu)
29+
30+
# To be able to compare it to truth, convert precision to covariance
31+
cov = pm.Deterministic('cov', T.nlinalg.matrix_inverse(prec))
32+
33+
lp = pm.MvNormal('likelihood', mu=mu, tau=prec, observed=data)
34+
35+
start = pm.find_MAP()
36+
step = pm.NUTS(scaling=start)
37+
trace = pm.sample(500, step)
38+
39+
pm.traceplot(trace[100:]);

0 commit comments

Comments
 (0)