diff --git a/candlegp/training/hmc.py b/candlegp/training/hmc.py
index 9f1d325..3350d3b 100644
--- a/candlegp/training/hmc.py
+++ b/candlegp/training/hmc.py
@@ -56,10 +56,10 @@ def logprob_grads():
     def thinning(thin_iterations, epsilon, lmin, lmax):
         logprob, grads = logprob_grads()
         for i in range(thin):
-            xs_prev = [p.data.clone() for p in model.parameters()]
+            xs_prev = [p.data.clone() for p in model.parameters() if p.requires_grad]
             grads_prev = grads
             logprob_prev = logprob
-            ps_init = [Variable(xs_prev[0].new(*p.size()).normal_()) for p in model.parameters()]
+            ps_init = [Variable(xs_prev[0].new(*p.size()).normal_()) for p in model.parameters() if p.requires_grad]
             ps = [p + 0.5 * epsilon * grad for p,grad in zip(ps_init, grads_prev)]
 
             max_iterations = int((torch.rand(1)*(lmax+1-lmin)+lmin)[0])
@@ -68,7 +68,7 @@ def thinning(thin_iterations, epsilon, lmin, lmax):
             proceed = True
             i_ps = 0
             while proceed and i_ps < max_iterations:
-                for x, p in zip(model.parameters(), ps):
+                for x, p in zip([p for p in model.parameters() if p.requires_grad], ps):
                     x.data += epsilon*p.data
                 _, grads = logprob_grads()
                 proceed = torch.stack([is_finite(grad).prod() for grad in grads], dim=0).prod().data[0]
@@ -91,7 +91,7 @@ def thinning(thin_iterations, epsilon, lmin, lmax):
                 proceed = False
             # otherwise keep new
             if not proceed:
-                for p,x_prev in zip(model.parameters(), xs_prev):
+                for p,x_prev in zip([p for p in model.parameters() if p.requires_grad], xs_prev):
                     p.data = x_prev
                 logprob = logprob_prev
                 grads = grads_prev
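
A minimal sketch (not from the candlegp repository) of the invariant this patch enforces: every per-parameter list that thinning() keeps (saved positions, momenta, and the gradients coming back from logprob_grads()) must be built from the same filtered view of model.parameters(), otherwise a fixed parameter (requires_grad=False) shifts the zip() pairings out of alignment with the gradients. The model, the toy log-probability, and the trainable_params() helper below are hypothetical stand-ins, and the sketch uses current PyTorch idiom rather than the Variable/.data[0] API of the patched file.

import torch

model = torch.nn.Linear(3, 1)        # hypothetical stand-in for a candlegp model
model.bias.requires_grad = False     # one fixed parameter: parameters() and the gradient list now differ in length

def trainable_params(model):
    # Single filtered view that every per-parameter list below is built from.
    return [p for p in model.parameters() if p.requires_grad]

def logprob_grads(model):
    # Toy log-probability; gradients are taken only w.r.t. the trainable parameters,
    # mirroring what the patched thinning() receives from its logprob_grads().
    logprob = -(model.weight ** 2).sum()
    grads = torch.autograd.grad(logprob, trainable_params(model))
    return logprob, grads

epsilon = 0.01
params = trainable_params(model)
xs_prev = [p.data.clone() for p in params]      # saved positions, one per trainable parameter
logprob, grads = logprob_grads(model)
ps = [torch.randn_like(x) + 0.5 * epsilon * g   # momenta after the initial half step
      for x, g in zip(xs_prev, grads)]

# Leapfrog position update: both sides of the zip come from trainable_params(),
# so the pairing cannot drift when fixed parameters are present.
for x, p in zip(params, ps):
    x.data += epsilon * p

# Rejection branch: restore exactly the parameters that were moved.
for x, x_prev in zip(params, xs_prev):
    x.data = x_prev

Collecting the filtered list once, as params above, behaves the same as the repeated inline if p.requires_grad filters in the patch, because model.parameters() yields parameters in a stable registration order on every call.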