|
4 | 4 | import warnings
|
5 | 5 |
|
6 | 6 | from .dist_math import *
|
| 7 | +from . import ChiSquared, Normal |
| 8 | +from .. import Deterministic |
7 | 9 |
|
8 | 10 | import numpy as np
|
| 11 | +import scipy |
9 | 12 | from . import transforms
|
10 | 13 |
|
11 | 14 | from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
|
12 |
| -from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all |
| 15 | +from theano.tensor import dot, cast, eye, diag, eq, le, ge, gt, all, zeros, sqrt, set_subtensor |
13 | 16 | from theano.printing import Print
|
14 | 17 | from pymc3.distributions.distribution import draw_values, generate_samples
|
15 | 18 | import scipy.stats as st
|
@@ -270,18 +273,18 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False):
|
270 | 273 | tril_idx = np.tril_indices_from(S, k=-1)
|
271 | 274 | n_diag = len(diag_idx[0])
|
272 | 275 | n_tril = len(tril_idx[0])
|
273 |
| - c = T.sqrt(pm.ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag)) |
274 |
| - z = pm.Normal('z', 0, 1, shape=n_tril) |
| 276 | + c = sqrt(ChiSquared('c', nu - np.arange(2, 2+n_diag), shape=n_diag)) |
| 277 | + z = Normal('z', 0, 1, shape=n_tril) |
275 | 278 | # Construct A matrix
|
276 |
| - A = T.zeros(S.shape, dtype=np.float32) |
277 |
| - A = T.set_subtensor(A[diag_idx], c) |
278 |
| - A = T.set_subtensor(A[tril_idx], z) |
| 279 | + A = zeros(S.shape, dtype=np.float32) |
| 280 | + A = set_subtensor(A[diag_idx], c) |
| 281 | + A = set_subtensor(A[tril_idx], z) |
279 | 282 |
|
280 | 283 | # L * A * A.T * L.T ~ Wishart(L*L.T, nu)
|
281 | 284 | if return_cholesky:
|
282 |
| - return pm.Deterministic(name, T.dot(L, A)) |
| 285 | + return Deterministic(name, dot(L, A)) |
283 | 286 | else:
|
284 |
| - return pm.Deterministic(name, T.dot(T.dot(T.dot(L, A), A.T), L.T)) |
| 287 | + return Deterministic(name, dot(dot(dot(L, A), A.T), L.T)) |
285 | 288 |
|
286 | 289 |
|
287 | 290 | class LKJCorr(Continuous):
|
|
0 commit comments