Commit bbc64fd
fixes for hmc progress - see notebook
t-vi committed Oct 25, 2017
1 parent 9630de8 commit bbc64fd
Showing 9 changed files with 194 additions and 132 deletions.
candlegp/densities.py (2 changes: 1 addition & 1 deletion)
@@ -80,7 +80,7 @@ def multivariate_normal(x, mu, L):
     """
     d = x - mu
     if d.dim()==1:
-        d = d.unsqeeze(1)
+        d = d.unsqueeze(1)
     alpha,_ = torch.gesv(d, L)
     alpha = alpha.squeeze(1)
     num_col = 1 if x.dim() == 1 else x.size(1)
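For context, the patched function evaluates the Cholesky-parameterised multivariate normal log-density. The sketch below restates that computation against current PyTorch (torch.linalg.solve_triangular instead of the torch.gesv call above); it is an illustration, not the repository's code.

import math
import torch

def mvn_logpdf_chol(x, mu, L):
    # log N(x | mu, L L^T) with L lower triangular, summed over the columns
    # of x, mirroring the shape handling fixed in the hunk above.
    d = x - mu
    if d.dim() == 1:
        d = d.unsqueeze(1)                                         # promote N to N x 1
    alpha = torch.linalg.solve_triangular(L, d, upper=False)       # solve L alpha = d
    n, k = d.size(0), d.size(1)
    return (-0.5 * (alpha ** 2).sum()
            - 0.5 * n * k * math.log(2 * math.pi)
            - k * torch.log(torch.diag(L)).sum())

# e.g. mvn_logpdf_chol(torch.randn(3), torch.zeros(3), torch.eye(3))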
candlegp/kernels.py (4 changes: 2 additions & 2 deletions)
@@ -109,8 +109,8 @@ def exKxz(self, Z, Xmu, Xcov):
         fXcovt = tf.concat((Xcov[0, :-1, :, :], Xcov[1, :-1, :, :]), 2)  # NxDx2D
         fXcovb = tf.concat((tf.transpose(Xcov[1, :-1, :, :], (0, 2, 1)), Xcov[0, 1:, :, :]), 2)
         fXcov = tf.concat((fXcovt, fXcovb), 1)
-        return mvnquad(lambda x: self.K(x[:, :D], Z).unsqeeze(2) *
-                       x[:, D:].unsqueeze(1),
+        return mvnquad(lambda x: self.K(x[:, :D], Z).unsqueeze(2) *
+                       x[:, D:].unsqueeze(1),
                        fXmu, fXcov, self.num_gauss_hermite_points,
                        2 * D, Dout=(M, D))

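The two corrected lines build a batched outer product by broadcasting. A small standalone shape check (illustrative sizes only):

import torch

N, M, D = 5, 3, 2                              # placeholder sizes
Kxz = torch.randn(N, M)                        # plays the role of self.K(x[:, :D], Z)
xtail = torch.randn(N, D)                      # plays the role of x[:, D:]
out = Kxz.unsqueeze(2) * xtail.unsqueeze(1)    # (N, M, 1) * (N, 1, D) -> (N, M, D)
assert out.shape == (N, M, D)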
candlegp/likelihoods.py (4 changes: 2 additions & 2 deletions)
@@ -143,9 +143,9 @@ def _check_targets(self, Y_np):  # pylint: disable=R0201


 class Gaussian(Likelihood):
-    def __init__(self):
+    def __init__(self, ttype=torch.FloatTensor):
         Likelihood.__init__(self)
-        self.variance = parameter.PositiveParam(1.0)
+        self.variance = parameter.PositiveParam(1.0, ttype=ttype)

     def logp(self, F, Y):
         return densities.gaussian(F, Y, self.variance.get())
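The logp shown as context delegates to densities.gaussian. For reference, a minimal sketch of that elementwise Gaussian log-density (an illustration, not the candlegp.densities implementation):

import math
import torch

def gaussian_logp(F, Y, variance):
    # Elementwise log N(Y | F, variance).
    return -0.5 * (math.log(2 * math.pi) + torch.log(variance) + (Y - F) ** 2 / variance)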
candlegp/models/gpr.py (8 changes: 4 additions & 4 deletions)
@@ -40,7 +40,7 @@ def __init__(self, X, Y, kern, mean_function=None, **kwargs):
         Y is a data matrix, size N x R
         kern, mean_function are appropriate GPflow objects
         """
-        likelihood = likelihoods.Gaussian()
+        likelihood = likelihoods.Gaussian(ttype=type(X.data))
         #X = DataHolder(X)
         #Y = DataHolder(Y)
         super(GPR,self).__init__(X, Y, kern, likelihood, mean_function, **kwargs)
@@ -55,7 +55,7 @@ def compute_log_likelihood(self):
             \log p(Y | theta).
         """
-        K = self.kern.K(self.X) + Variable(torch.eye(self.X.size(0))) * self.likelihood.variance.get()
+        K = self.kern.K(self.X) + Variable(torch.eye(self.X.size(0),out=self.X.data.new())) * self.likelihood.variance.get()
         L = torch.potrf(K, upper=False)
         m = self.mean_function(self.X)
         return densities.multivariate_normal(self.Y, m, L)
@@ -74,14 +74,14 @@ def predict_f(self, Xnew, full_cov=False):
         """
         Kx = self.kern.K(self.X, Xnew)
-        K = self.kern.K(self.X) + Variable(torch.eye(self.X.size(0))) * self.likelihood.variance.get()
+        K = self.kern.K(self.X) + Variable(torch.eye(self.X.size(0),out=self.X.data.new())) * self.likelihood.variance.get()
         L = torch.potrf(K, upper=False)
         A,_ = torch.gesv(Kx, L)  # could use triangular solve, note gesv has B first, then A in AX=B
         V,_ = torch.gesv(self.Y - self.mean_function(self.X),L)  # could use triangular solve
         fmean = torch.mm(A.t(), V) + self.mean_function(Xnew)
         if full_cov:
             fvar = self.kern.K(Xnew) - torch.mm(A.t(), A)
-            fvar = fvar.unsqeeze(2).expand(fvar.size(0), fvar.size(1), self.Y.size(1))
+            fvar = fvar.unsqueeze(2).expand(fvar.size(0), fvar.size(1), self.Y.size(1))
         else:
             fvar = self.kern.Kdiag(Xnew) - (A**2).sum(0)
             fvar = fvar.view(-1,1)
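Both K lines above use the same idiom: torch.eye(n, out=X.data.new()) allocates the identity with the type (and device) of X instead of the FloatTensor default. A minimal sketch of the idiom on its own, assuming the 0.3-era Variable/.new() API this diff is written against:

import torch
from torch.autograd import Variable

X = Variable(torch.DoubleTensor(4, 1).zero_())         # stand-in for training inputs
noise = 1e-2                                            # stand-in for the likelihood variance
I = Variable(torch.eye(X.size(0), out=X.data.new()))    # DoubleTensor identity, not FloatTensor
K = I * noise                                           # noise term stays in double precision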
candlegp/models/model.py (13 changes: 7 additions & 6 deletions)
@@ -52,12 +52,13 @@ class GPModel(torch.nn.Module):
     >>> m.Y = Ynew
     """

-    def __init__(self, X, Y, kern, likelihood, mean_function, name=None):
+    def __init__(self, X, Y, kern, likelihood, mean_function, name=None, jitter_level=1e-6):
         super(GPModel, self).__init__()
         self.name = name
         self.mean_function = mean_function or mean_functions.Zero()
         self.kern = kern
         self.likelihood = likelihood
+        self.jitter_level = jitter_level

         if isinstance(X, numpy.ndarray):
             # X is a data matrix; each row represents one instance
@@ -109,14 +110,14 @@ def predict_f_samples(self, Xnew, num_samples):
         Produce samples from the posterior latent function(s) at the points
         Xnew.
         """
-        mu, var = self.predict_f_full_cov(Xnew, full_cov=True)
-        jitter = Variable(torch.eye(tf.shape(mu)[0])) * settings.numerics.jitter_level  # TV-Todo: GPU-friendly
+        mu, var = self.predict_f(Xnew, full_cov=True)
+        jitter = Variable(torch.eye(mu.size(0), out=mu.data.new())) * self.jitter_level  # TV-Todo: GPU-friendly
         samples = []
         for i in range(self.num_latent):  # TV-Todo: batch??
-            L = torch.potrf(var[:, :, i] + jitter)
-            V = Variable(torch.randn(L.size(0), num_samples))
+            L = torch.potrf(var[:, :, i] + jitter, upper=False)
+            V = Variable(mu.data.new(L.size(0), num_samples).normal_())
             samples.append(mu[:, i:i + 1] + torch.matmul(L, V))
-        return torch.stack(samples, axis=0)  # TV-Todo: transpose?
+        return torch.stack(samples, dim=0)  # TV-Todo: transpose?

     def predict_y(self, Xnew):
         """
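predict_f_samples now draws posterior samples per latent function by jittering the covariance, taking a lower Cholesky factor, and pushing standard-normal draws through it. A standalone sketch of that sampling step, using modern torch.linalg.cholesky in place of the torch.potrf call above (names are illustrative):

import torch

def sample_mvn(mu, Sigma, num_samples, jitter_level=1e-6):
    # Draw num_samples samples from N(mu, Sigma) via a jittered Cholesky factor.
    n = mu.size(0)
    jitter = jitter_level * torch.eye(n, dtype=mu.dtype, device=mu.device)
    L = torch.linalg.cholesky(Sigma + jitter)            # lower triangular
    eps = torch.randn(n, num_samples, dtype=mu.dtype, device=mu.device)
    return mu.unsqueeze(1) + L @ eps                     # n x num_samples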
candlegp/parameter.py (8 changes: 5 additions & 3 deletions)
@@ -32,11 +32,13 @@ def log_jacobian_tensor(self):
     @abc.abstractstaticmethod
     def untransform(t, out=None):
         pass
-    def __new__(cls, val, prior=None):
+    def __init__(self, val, prior=None, ttype=torch.FloatTensor):
+        pass
+    def __new__(cls, val, prior=None, ttype=torch.FloatTensor):  # for some reaosn unknown to me, it is impossible to pass a different tensor Type as ttype...
         if isinstance(val, torch.autograd.Variable):
             val = val.data
         elif numpy.isscalar(val):
-            val = torch.FloatTensor([val])
+            val = ttype([val])
         raw = cls.untransform(val)
         obj = super(ParamWithPrior, cls).__new__(cls, raw)
         obj.prior = prior
@@ -45,7 +47,7 @@ def set(self, t):
         if isinstance(t, torch.autograd.Variable):
             t = t.data
         elif numpy.isscalar(t):
-            t = torch.FloatTensor([t])
+            t = self.data.new(1).fill_(t)
         self.untransform(t, out=self.data)
     def get_prior(self):
         if self.prior is None:
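The replacement in set() is the same type-preserving idiom as elsewhere in this commit: allocate new storage from the existing tensor rather than hard-coding FloatTensor. In isolation (illustrative values):

import torch

data = torch.DoubleTensor([0.1])      # stand-in for the parameter's existing storage
t = data.new(1).fill_(0.5)            # one-element DoubleTensor, same type as data
assert type(t) is type(data)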
candlegp/priors.py (50 changes: 25 additions & 25 deletions)
@@ -19,9 +19,9 @@

 from . import densities

-def wrap(x, **argd):
+def wrap(x, ttype=torch.Tensor, **argd):
     if numpy.isscalar(x):
-        x = Variable(torch.Tensor([x]),**argd)
+        x = Variable(ttype([x]),**argd)
     elif isinstance(x, [torch.Tensor, torch.DoubleTensor]):
         x = Variable(x, **argd)
     return x
@@ -31,24 +31,24 @@ class Prior:

 # these should be ocnverted to use torch.densities
 class Gaussian(Prior):
-    def __init__(self, mu, var):
+    def __init__(self, mu, var, ttype=torch.Tensor):
         Prior.__init__(self)
-        self.mu = wrap(mu)
-        self.var = wrap(var)
+        self.mu = wrap(mu, ttype=ttype)
+        self.var = wrap(var, ttype=ttype)

     def logp(self, x):
         return densities.gaussian(x, self.mu, self.var).sum()

     def sample(self, shape=(1,)):
-        return self.mu + (self.var**0.5) * Variable(torch.randn(*shape))
+        return self.mu + (self.var**0.5) * Variable(self.mu.data.new(*shape).normal_())

     def __str__(self):
         return "N("+str(self.mu.data.cpu().numpy()) + "," + str(self.var.data.cpu().numpy()) + ")"


 class LogNormal(Prior):
-    def __init__(self, mu, var):
-        Prior.__init__(self)
+    def __init__(self, mu, var, ttype=torch.Tensor):
+        Prior.__init__(self)
         self.mu = wrap(mu)
         self.var = wrap(var)

@@ -63,58 +63,59 @@ def __str__(self):


 class Gamma(Prior):
-    def __init__(self, shape, scale):
+    def __init__(self, shape, scale, ttype=torch.Tensor):
         Prior.__init__(self)
-        self.shape = wrap(shape)
-        self.scale = wrap(scale)
+        self.shape = wrap(shape, ttype=ttype)
+        self.scale = wrap(scale, ttype=ttype)

     def logp(self, x):
         return densities.gamma(self.shape, self.scale, x).sum()

     def sample(self, shape=(1,)):
-        return Variable(torch.Tensor(numpy.random.gamma(self.shape, self.scale, size=shape)))
+        return Variable(type(self.shape.data)(numpy.random.gamma(self.shape, self.scale, size=shape)))

     def __str__(self):
         return "Ga("+str(self.shape.data.cpu().numpy()) + "," + str(self.scale.data.cpu().numpy()) + ")"


 class Laplace(Prior):
-    def __init__(self, mu, sigma):
+    def __init__(self, mu, sigma, ttype=torch.Tensor):
         Prior.__init__(self)
-        self.mu = wrap(mu)
-        self.sigma = wrap(sigma)
+        self.mu = wrap(mu, ttype=ttype)
+        self.sigma = wrap(sigma, ttype=ttype)

     def logp(self, x):
         return densities.laplace(self.mu, self.sigma, x).sum()

     def sample(self, shape=(1,)):
-        return Variable(torch.Tensor(numpy.random.laplace(self.mu, self.sigma, size=shape)))
+        return Variable(type(self.shape.data)(numpy.random.laplace(self.mu, self.sigma, size=shape)))

     def __str__(self):
         return "Lap.("+str(self.mu.data.cpu().numpy()) + "," + str(self.sigma.data.cpu().numpy()) + ")"


 class Beta(Prior):
-    def __init__(self, a, b):
+    def __init__(self, a, b, ttype=torch.Tensor):
         Prior.__init__(self)
-        self.a = wrap(a)
-        self.b = wrap(b)
+        self.a = wrap(a, ttype=ttype)
+        self.b = wrap(b, ttype=ttype)

     def logp(self, x):
         return tf.reduce_sum(densities.beta(self.a, self.b, x))

     def sample(self, shape=(1,)):
-        return Variable(torch.Tensor(self.a, self.b, size=shape))
+        BROKEN
+        return Variable(type(self.shape.data)(self.a, self.b, size=shape))

     def __str__(self):
         return "Beta(" + str(self.a.data.cpu().numpy()) + "," + str(self.b.data.cpu().numpy()) + ")"


 class Uniform(Prior):
-    def __init__(self, lower=0., upper=1.):
+    def __init__(self, lower=0., upper=1., ttype=torch.Tensor):
         Prior.__init__(self)
-        lower = wrap(lower)
-        upper = wrap(upper)
+        lower = wrap(lower, ttype=ttype)
+        upper = wrap(upper, ttype=ttype)
         self.log_height = - torch.log(upper - lower)
         self.lower, self.upper = lower, upper

@@ -123,8 +124,7 @@ def logp(self, x):
         return self.log_height * x.size(0)

     def sample(self, shape=(1,)):
-        return (self.lower +
-                (self.upper - self.lower)*torch.rand(*shape))
+        return self.lower +(self.upper - self.lower)*self.lower.new(*shape).normal_()

     def __str__(self):
         return "U("+str(self.lower.data.cpu().numpy()) + "," + str(self.upper.data.cpu().numpy()) + ")"
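A hypothetical usage of the ttype-aware priors, assuming the candlegp package from this repository: the hyperparameters are wrapped with the requested tensor type, and sample() now allocates its noise from self.mu.data.new(...), so draws inherit that type as well.

import torch
from candlegp import priors

p = priors.Gaussian(0.0, 1.0, ttype=torch.DoubleTensor)
draws = p.sample((5,))                # 5 draws from N(0, 1), DoubleTensor-backed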
candlegp/training/hmc.py (4 changes: 2 additions & 2 deletions)
@@ -18,8 +18,8 @@
 from torch.autograd import Variable
 import numpy

-INF = torch.Tensor([100000]).exp()
 def is_finite(x):
+    INF = x.data.new(1).fill_(numpy.inf)
     if isinstance(x, Variable):
         return Variable((x.data<INF) & (x.data > -INF))
     else:
@@ -59,7 +59,7 @@ def thinning(thin_iterations, epsilon, lmin, lmax):
         xs_prev = [p.data.clone() for p in model.parameters()]
         grads_prev = grads
         logprob_prev = logprob
-        ps_init = [Variable(torch.randn(*p.size())) for p in model.parameters()]
+        ps_init = [Variable(xs_prev[0].new(*p.size()).normal_()) for p in model.parameters()]
         ps = [p + 0.5 * epsilon * grad for p,grad in zip(ps_init, grads_prev)]

         max_iterations = int((torch.rand(1)*(lmax+1-lmin)+lmin)[0])
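The momenta are now drawn with the same storage type as the parameters, and the 0.5 * epsilon * grad update above is the first half-step of the leapfrog integrator HMC uses. A self-contained sketch of one leapfrog step on a standard-normal target (illustrative names, not candlegp API):

import torch

def leapfrog_step(x, p, grad_logprob, epsilon):
    # Half-step the momentum, full-step the position, half-step the momentum.
    p = p + 0.5 * epsilon * grad_logprob(x)
    x = x + epsilon * p
    p = p + 0.5 * epsilon * grad_logprob(x)
    return x, p

x = torch.randn(3, dtype=torch.float64)
p = x.new(*x.size()).normal_()                 # momentum with x's dtype, as in the diff
x, p = leapfrog_step(x, p, lambda z: -z, 0.1)  # grad log N(z; 0, I) = -z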
notebooks/gp_regression.ipynb (233 changes: 146 additions & 87 deletions)

Large diffs are not rendered by default.
