|
4 | 4 | import warnings
|
5 | 5 |
|
6 | 6 | from .dist_math import *
|
| 7 | +from . import ChiSquared, Normal |
| 8 | +from .. import Deterministic |
7 | 9 |
|
8 | 10 | import numpy as np
|
| 11 | +import scipy |
9 | 12 |
|
10 | 13 | from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
|
11 |
| -from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all |
| 14 | +from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all, zeros, sqrt, set_subtensor |
12 | 15 | from theano.printing import Print
|
13 | 16 |
|
14 | 17 | __all__ = ['MvNormal', 'Dirichlet', 'Multinomial', 'Wishart', 'WishartBartlett', 'LKJCorr']
|
@@ -233,18 +236,18 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
|
233 | 236 | tril_idx = np.tril_indices_from(S, k=-1)
|
234 | 237 | n_diag = len(diag_idx[0])
|
235 | 238 | n_tril = len(tril_idx[0])
|
236 |
| - c = T.sqrt(pm.ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag)) |
237 |
| - z = pm.Normal('z', 0, 1, shape=n_tril) |
| 239 | + c = sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag)) |
| 240 | + z = Normal('z', 0, 1, shape=n_tril) |
238 | 241 | # Construct A matrix
|
239 |
| - A = T.zeros(S.shape, dtype=np.float32) |
240 |
| - A = T.set_subtensor(A[diag_idx], c) |
241 |
| - A = T.set_subtensor(A[tril_idx], z) |
| 242 | + A = zeros(S.shape, dtype=np.float32) |
| 243 | + A = set_subtensor(A[diag_idx], c) |
| 244 | + A = set_subtensor(A[tril_idx], z) |
242 | 245 |
|
243 | 246 | # L * A * A.T * L.T ~ Wishart(L*L.T, nu)
|
244 | 247 | if return_cholesky:
|
245 |
| - return pm.Deterministic(name, T.dot(L, A)) |
| 248 | + return Deterministic(name, dot(L, A)) |
246 | 249 | else:
|
247 |
| - return pm.Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T)) |
| 250 | + return Deterministic(name, dot(dot(dot(L, A), A.T), L.T)) |
248 | 251 |
|
249 | 252 |
|
250 | 253 | class LKJCorr(Continuous):
|
|
0 commit comments