Skip to content

Commit e69f1a2

Browse files
committed
split_hmc now looks like it works
1 parent 3d47b94 commit e69f1a2

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

examples/logistic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
def invlogit(x):
99
return np.exp(x)/(1 + np.exp(x))
1010

11-
npred = 20
11+
npred = 4
1212
n = 4000
1313

1414
effects_a = np.random.normal(size = npred)
@@ -39,8 +39,8 @@ def tinvlogit(x):
3939
chain = find_MAP(model, chain)
4040
hmc_cov = approx_cov(model, chain) #find a good orientation using the hessian at the MAP
4141

42-
step_method = hmc_step(model, model.vars, hmc_cov)
43-
#step_method = split_hmc_step(model, model.vars, hmc_cov, chain, hmc_cov)
42+
#step_method = hmc_step(model, model.vars, hmc_cov)
43+
step_method = split_hmc_step(model, model.vars, hmc_cov, chain, hmc_cov)
4444

4545
ndraw = 3e3
4646

mcex/step_methods/hmc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def step(logp_d, state, q0):
5454

5555
p = -p
5656

57-
mr = logp - logp0 + K(C, p0) - K(C, p)
57+
# - H(q*, p*) + H(q, p) = -H(q, p) + H(q0, p0) = -(- logp(q) + K(p)) + (-logp(q0) + K(p0))
58+
mr = (-logp0) + K(C, p0) - ((-logp) + K(C, p))
5859
state.metrops.append(mr)
5960

6061
return state, metrop_select(mr, q, q0)

mcex/step_methods/split_hmc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def split_hmc_step(model, vars, C, approx_loc, approx_C, step_size_scaling = .25
3131
D, gamma = eig(A)
3232

3333
e = step_size
34-
R = real(gamma.dot(diag(exp(D* e))).dot(gamma.T))
34+
R = real(gamma.dot(diag(exp(D * e))).dot(gamma.conj().T))
3535
def step(logp_d, state, q0):
3636

3737
if state is None:
@@ -56,7 +56,7 @@ def step(logp_d, state, q0):
5656
x = concatenate((q - approx_loc, p))
5757
x = dot(R, x)
5858
q = x[:n] + approx_loc
59-
59+
p = x[n:]
6060
logp, dlogp = logp_d(q)
6161
dlogp = dlogp + dot(approx_C, q - approx_loc)
6262

@@ -65,7 +65,8 @@ def step(logp_d, state, q0):
6565

6666
p = -p
6767

68-
mr = logp - logp0 + K(C, p0) - K(C, p)
68+
# - H(q*, p*) + H(q, p) = -H(q, p) + H(q0, p0) = -(- logp(q) + K(p)) + (-logp(q0) + K(p0))
69+
mr = (-logp0) + K(C, p0) - ((-logp) + K(C, p))
6970
state.metrops.append(mr)
7071

7172
return state, metrop_select(mr, q, q0)

0 commit comments

Comments
 (0)