Commit: allow fixed params in hmc
t-vi committed Nov 24, 2017
1 parent bd48929 commit 3b4c2f4
Showing 1 changed file with 4 additions and 4 deletions.
candlegp/training/hmc.py: 4 additions & 4 deletions
@@ -56,10 +56,10 @@ def logprob_grads():
     def thinning(thin_iterations, epsilon, lmin, lmax):
         logprob, grads = logprob_grads()
         for i in range(thin):
-            xs_prev = [p.data.clone() for p in model.parameters()]
+            xs_prev = [p.data.clone() for p in model.parameters() if p.requires_grad]
             grads_prev = grads
             logprob_prev = logprob
-            ps_init = [Variable(xs_prev[0].new(*p.size()).normal_()) for p in model.parameters()]
+            ps_init = [Variable(xs_prev[0].new(*p.size()).normal_()) for p in model.parameters() if p.requires_grad]
             ps = [p + 0.5 * epsilon * grad for p,grad in zip(ps_init, grads_prev)]

             max_iterations = int((torch.rand(1)*(lmax+1-lmin)+lmin)[0])
@@ -68,7 +68,7 @@ def thinning(thin_iterations, epsilon, lmin, lmax):
             proceed = True
             i_ps = 0
             while proceed and i_ps < max_iterations:
-                for x, p in zip(model.parameters(), ps):
+                for x, p in zip([p for p in model.parameters() if p.requires_grad], ps):
                     x.data += epsilon*p.data
                 _, grads = logprob_grads()
                 proceed = torch.stack([is_finite(grad).prod() for grad in grads], dim=0).prod().data[0]
@@ -91,7 +91,7 @@ def thinning(thin_iterations, epsilon, lmin, lmax):
                     proceed = False
                 # otherwise keep new
                 if not proceed:
-                    for p,x_prev in zip(model.parameters(), xs_prev):
+                    for p,x_prev in zip([p for p in model.parameters() if p.requires_grad], xs_prev):
                         p.data = x_prev
                     logprob = logprob_prev
                     grads = grads_prev
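The filtering on p.requires_grad means the sampler now treats only free parameters as part of the HMC state; any parameter marked as not requiring gradients is left untouched across samples. Below is a minimal sketch (not part of the commit) of how a caller could use this, with a plain torch.nn.Module stand-in; the model and parameter names are hypothetical, not candlegp API:

    import torch

    class TinyModel(torch.nn.Module):
        def __init__(self):
            super(TinyModel, self).__init__()
            self.lengthscale = torch.nn.Parameter(torch.ones(1))
            self.variance = torch.nn.Parameter(torch.ones(1))

    model = TinyModel()
    model.variance.requires_grad = False   # treat this parameter as fixed

    # Same comprehension as in the patched hmc.py: only free parameters
    # enter the sampler's position/momentum lists.
    free = [p for p in model.parameters() if p.requires_grad]
    assert len(free) == 1   # only lengthscale will be sampled

Note that all four sites filter with the same comprehension, which keeps the positional zip() pairings between positions, momenta, and gradients consistent: model.parameters() yields parameters in registration order, so the filtered lists line up across calls.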
