diff --git a/candlegp/training/hmc.py b/candlegp/training/hmc.py index 747d189..9f1d325 100644 --- a/candlegp/training/hmc.py +++ b/candlegp/training/hmc.py @@ -50,7 +50,7 @@ def hmc_sample(model, num_samples, epsilon, lmin=1, lmax=2, thin=1, burn=0): def logprob_grads(): logprob = -model.objective() - grads = torch.autograd.grad(logprob, model.parameters()) + grads = torch.autograd.grad(logprob, [p for p in model.parameters() if p.requires_grad]) return logprob, grads def thinning(thin_iterations, epsilon, lmin, lmax):