upper bound for sgpr - not quite working
t-vi committed Oct 28, 2017
1 parent 11d5186 commit f549776
Showing 2 changed files with 579 additions and 55 deletions.
108 changes: 53 additions & 55 deletions candlegp/models/sgpr.py
@@ -1,4 +1,5 @@
# Copyright 2016 James Hensman, alexggmatthews, Mark van der Wilk
+ # Copyright 2017 Thomas Viehmann
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,28 +55,29 @@ def compute_upper_bound(self):
num_data = self.Y.size(0)

Kdiag = self.kern.Kdiag(self.X)
- Kuu = self.kern.K(self.Z) + tf.eye(num_inducing, dtype=settings.tf_float) * settings.numerics.jitter_level
- Kuf = self.kern.K(self.Z, self.X)
+ jitter = Variable(torch.eye(num_inducing, out=self.Z.data.new())) * self.jitter_level
+ Kuu = self.kern.K(self.Z.get()) + jitter
+ Kuf = self.kern.K(self.Z.get(), self.X)

- L = tf.cholesky(Kuu, upper=False)
- LB = tf.cholesky(Kuu + self.likelihood.variance ** -1.0 * tf.matmul(Kuf, Kuf, transpose_b=True), upper=False)
+ L = torch.potrf(Kuu, upper=False)
+ LB = torch.potrf(Kuu + self.likelihood.variance.get() ** -1.0 * torch.matmul(Kuf, Kuf.t()), upper=False)

- LinvKuf = tf.matrix_triangular_solve(L, Kuf, lower=True)
+ LinvKuf, _ = torch.gesv(Kuf, L) # could use triangular solve
# Using the Trace bound, from Titsias' presentation
- c = tf.reduce_sum(Kdiag) - tf.reduce_sum(LinvKuf ** 2.0)
+ c = Kdiag.sum() - (LinvKuf ** 2.0).sum()
# Kff = self.kern.K(self.X)
# Qff = tf.matmul(Kuf, LinvKuf, transpose_a=True)

# Alternative bound on max eigenval:
# c = tf.reduce_max(tf.reduce_sum(tf.abs(Kff - Qff), 0))
- corrected_noise = self.likelihood.variance + c
+ corrected_noise = self.likelihood.variance.get() + c

- const = -0.5 * num_data * tf.log(2 * np.pi * self.likelihood.variance)
- logdet = tf.reduce_sum(tf.log(tf.diag_part(L))) - tf.reduce_sum(tf.log(tf.diag_part(LB)))
+ const = -0.5 * num_data * torch.log(2 * float(numpy.pi) * self.likelihood.variance.get())
+ logdet = torch.diag(L).log().sum() - torch.diag(LB).log().sum()

- LC = tf.cholesky(Kuu + corrected_noise ** -1.0 * tf.matmul(Kuf, Kuf, transpose_b=True), upper=True)
- v = tf.matrix_triangular_solve(LC, corrected_noise ** -1.0 * tf.matmul(Kuf, self.Y), lower=True)
- quad = -0.5 * corrected_noise ** -1.0 * tf.reduce_sum(self.Y ** 2.0) + 0.5 * tf.reduce_sum(v ** 2.0)
+ LC = torch.potrf(Kuu + corrected_noise ** -1.0 * torch.matmul(Kuf, Kuf.t()), upper=False)
+ v, _ = torch.gesv(corrected_noise ** -1.0 * torch.matmul(Kuf, self.Y), LC)
+ quad = -0.5 * corrected_noise ** -1.0 * (self.Y ** 2.0).sum() + 0.5 * (v ** 2.0).sum()

return const + logdet + quad
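
For reference, the quantity assembled above is Titsias' upper bound on the log marginal likelihood: with Qff = Kufᵀ Kuu⁻¹ Kuf and c = tr(Kff − Qff), the Gaussian log density with the noise inflated to σ² + c upper-bounds the exact marginal likelihood. Below is a minimal self-contained sketch of the same computation in modern PyTorch (torch >= 1.11); the `rbf` helper, the `upper_bound` name, and the toy data are illustrative assumptions, not candlegp API:

```python
import math
import torch

def rbf(a, b, lengthscale=1.0, variance=1.0):
    # squared-exponential kernel matrix k(a, b)
    d2 = (a.unsqueeze(1) - b.unsqueeze(0)).pow(2).sum(-1)
    return variance * torch.exp(-0.5 * d2 / lengthscale ** 2)

def upper_bound(X, Y, Z, noise=0.1, jitter=1e-6):
    num_data, num_inducing = X.size(0), Z.size(0)
    Kuu = rbf(Z, Z) + jitter * torch.eye(num_inducing)
    Kuf = rbf(Z, X)
    Kdiag = rbf(X, X).diagonal()

    L = torch.linalg.cholesky(Kuu)                                # L L^T = Kuu
    LinvKuf = torch.linalg.solve_triangular(L, Kuf, upper=False)

    # trace bound: c = tr(Kff - Qff) inflates the noise variance
    c = Kdiag.sum() - LinvKuf.pow(2).sum()
    corrected_noise = noise + c

    LB = torch.linalg.cholesky(Kuu + Kuf @ Kuf.T / noise)
    const = -0.5 * num_data * math.log(2 * math.pi * noise)
    logdet = L.diagonal().log().sum() - LB.diagonal().log().sum()

    LC = torch.linalg.cholesky(Kuu + Kuf @ Kuf.T / corrected_noise)
    v = torch.linalg.solve_triangular(LC, Kuf @ Y / corrected_noise, upper=False)
    quad = -0.5 * Y.pow(2).sum() / corrected_noise + 0.5 * v.pow(2).sum()
    return const + logdet + quad

X = torch.linspace(0, 1, 50).unsqueeze(1)
Y = torch.sin(6 * X) + 0.1 * torch.randn(50, 1)
Z = torch.linspace(0, 1, 10).unsqueeze(1)
print(upper_bound(X, Y, Z))
```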

@@ -128,7 +130,7 @@ def compute_log_likelihood(self):
err = self.Y - self.mean_function(self.X)
Kdiag = self.kern.Kdiag(self.X)
Kuf = self.kern.K(self.Z.get(), self.X)
- jitter = Variable(torch.eye(self.Z.get().size(0), out=self.Z.data.new())) * self.jitter_level
+ jitter = Variable(torch.eye(num_inducing, out=self.Z.data.new())) * self.jitter_level
Kuu = self.kern.K(self.Z.get()) + jitter
L = torch.potrf(Kuu, upper=False)
sigma = self.likelihood.variance.get()**0.5
@@ -161,7 +163,7 @@ def predict_f(self, Xnew, full_cov=False):
num_inducing = self.Z.size(0)
err = self.Y - self.mean_function(self.X)
Kuf = self.kern.K(self.Z.get(), self.X)
- jitter = Variable(torch.eye(self.Z.get().size(0), out=self.Z.data.new())) * self.jitter_level
+ jitter = Variable(torch.eye(num_inducing, out=self.Z.data.new())) * self.jitter_level
Kuu = self.kern.K(self.Z.get()) + jitter
Kus = self.kern.K(self.Z.get(), Xnew)
sigma = self.likelihood.variance.get()**0.5
@@ -176,10 +178,10 @@ def predict_f(self, Xnew, full_cov=False):
mean = torch.matmul(tmp2.t(), c)
if full_cov:
var = self.kern.K(Xnew) + torch.matmul(tmp2.t(), tmp2) - torch.matmul(tmp1.t(), tmp1)
- var = var.unsqueeze(2).expand(var.size(0), var.size(0), self.Y.size(1))
+ var = var.unsqueeze(2).expand(-1, -1, self.Y.size(1))
else:
var = self.kern.Kdiag(Xnew) + (tmp2**2).sum(0) - (tmp1**2).sum(0)
- var = var.unsqueeze(1).expand(var.size(0), self.Y.size(1))
+ var = var.unsqueeze(1).expand(-1, self.Y.size(1))
return mean + self.mean_function(Xnew), var
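
The change in this hunk swaps hard-coded sizes for `-1` in `expand`, which keeps the existing size of that dimension. A tiny illustration of the idiom with made-up tensors (not candlegp code): the per-point variance is broadcast across the output columns as a view, without copying:

```python
import torch

var = torch.rand(5)                      # per-point predictive variance, size (5,)
tiled = var.unsqueeze(1).expand(-1, 3)   # size (5, 3); -1 keeps dim 0 at 5
assert tiled.size() == (5, 3)
assert torch.equal(tiled[:, 0], var)     # expand is a view, no data is copied
```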


@@ -209,37 +211,36 @@ def __init__(self, X, Y, kern, Z, mean_function=None, **kwargs): # was mean_func
This method only works with a Gaussian likelihood.
"""
X = DataHolder(X)
Y = DataHolder(Y)
- likelihood = likelihoods.Gaussian()
- GPModel.__init__(self, X, Y, kern, likelihood, mean_function, **kwargs)
- self.Z = Parameter(Z)
- self.num_data = X.shape[0]
- self.num_latent = Y.shape[1]
-
- def _build_common_terms(self):
- num_inducing = tf.shape(self.Z)[0]
+ likelihood = likelihoods.Gaussian(ttype=type(X.data))
+ super(GPRFITC, self).__init__(X, Y, kern, likelihood, mean_function, **kwargs)
+ self.Z = parameter.Param(Z)
+ self.num_data = X.size(0)
+ self.num_latent = Y.size(1)
+
+ def _common_terms(self):
+ num_inducing = self.Z.size(0)
err = self.Y - self.mean_function(self.X) # size N x R
Kdiag = self.kern.Kdiag(self.X)
- Kuf = self.kern.K(self.Z, self.X)
- Kuu = self.kern.K(self.Z) + tf.eye(num_inducing, dtype=settings.tf_float) * settings.jitter
+ Kuf = self.kern.K(self.Z.get(), self.X)
+ jitter = Variable(torch.eye(num_inducing, out=self.Z.data.new())) * self.jitter_level
+ Kuu = self.kern.K(self.Z.get()) + jitter

- Luu = tf.cholesky(Kuu) # => Luu Luu^T = Kuu
- V = tf.matrix_triangular_solve(Luu, Kuf) # => V^T V = Qff = Kuf^T Kuu^-1 Kuf
+ Luu = torch.potrf(Kuu, upper=False) # => Luu Luu^T = Kuu
+ V, _ = torch.gesv(Kuf, Luu) # => V^T V = Qff = Kuf^T Kuu^-1 Kuf

- diagQff = tf.reduce_sum(tf.square(V), 0)
- nu = Kdiag - diagQff + self.likelihood.variance
+ diagQff = (V**2).sum(0)
+ nu = Kdiag - diagQff + self.likelihood.variance.get()

- B = tf.eye(num_inducing, dtype=settings.tf_float) + tf.matmul(V / nu, V, transpose_b=True)
- L = tf.cholesky(B)
- beta = err / tf.expand_dims(nu, 1) # size N x R
- alpha = tf.matmul(V, beta) # size N x R
+ B = torch.eye(num_inducing, out=V.data.new()) + torch.matmul(V / nu, V.t())
+ L = torch.potrf(B, upper=False)
+ beta = err / nu.unsqueeze(1) # size N x R
+ alpha = torch.matmul(V, beta) # size M x R

- gamma = tf.matrix_triangular_solve(L, alpha, lower=True) # size N x R
+ gamma, _ = torch.gesv(alpha, L) # size M x R

return err, nu, Luu, L, alpha, beta, gamma

- def _build_likelihood(self):
+ def compute_log_likelihood(self):
"""
Construct a tensorflow function to compute the bound on the marginal
likelihood.
@@ -263,10 +264,9 @@ def _build_likelihood(self):
# and let \alpha = V \beta
# then Mahalanobis term = -0.5* ( \beta^T err - \alpha^T Solve( I + V \diag( \nu^{-1} ) V^T, alpha ) )

- err, nu, Luu, L, alpha, beta, gamma = self._build_common_terms()
+ err, nu, Luu, L, alpha, beta, gamma = self._common_terms()

- mahalanobisTerm = -0.5 * tf.reduce_sum(tf.square(err) / tf.expand_dims(nu, 1)) \
- + 0.5 * tf.reduce_sum(tf.square(gamma))
+ mahalanobisTerm = -0.5 * (err**2 / nu.unsqueeze(1)).sum() + 0.5 * (gamma**2).sum()

# We need to compute the log normalizing term -N/2 \log 2 pi - 0.5 \log \det( K_fitc )

@@ -277,8 +277,8 @@ def _build_likelihood(self):
# = \log [ \det \diag( \nu ) \det( I + V \diag( \nu^{-1} ) V^T ) ]
# = \log [ \det \diag( \nu ) ] + \log [ \det( I + V \diag( \nu^{-1} ) V^T ) ]

- constantTerm = -0.5 * self.num_data * tf.log(tf.constant(2. * np.pi, settings.tf_float))
- logDeterminantTerm = -0.5 * tf.reduce_sum(tf.log(nu)) - tf.reduce_sum(tf.log(tf.matrix_diag_part(L)))
+ constantTerm = -0.5 * self.num_data * float(numpy.log(2 * numpy.pi))
+ logDeterminantTerm = -0.5 * nu.log().sum() - torch.diag(L).log().sum()
logNormalizingTerm = constantTerm + logDeterminantTerm

return mahalanobisTerm + logNormalizingTerm * self.num_latent
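
The comment derivation above stops one step short of what `logDeterminantTerm` actually computes. Filling in that step (a sketch via the matrix determinant lemma, in the notation of the comments, with B = I + V diag(ν)⁻¹ Vᵀ and L its lower Cholesky factor):

```latex
\log\det(K_\mathrm{fitc})
  = \log\det\bigl(\operatorname{diag}(\nu) + V^\top V\bigr)
  = \sum_i \log \nu_i
    + \log\det\bigl(I + V \operatorname{diag}(\nu)^{-1} V^\top\bigr)
  = \sum_i \log \nu_i + 2 \sum_j \log L_{jj}
```

so the -0.5 log-determinant is exactly `-0.5 * nu.log().sum() - torch.diag(L).log().sum()`.
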
@@ -288,22 +288,20 @@ def _build_predict(self, Xnew, full_cov=False):
Compute the mean and variance of the latent function at some new points
Xnew.
"""
- _, _, Luu, L, _, _, gamma = self._build_common_terms()
- Kus = self.kern.K(self.Z, Xnew) # size M x Xnew
+ _, _, Luu, L, _, _, gamma = self._common_terms()
+ Kus = self.kern.K(self.Z.get(), Xnew) # size M x Xnew

- w = tf.matrix_triangular_solve(Luu, Kus, lower=True) # size M x Xnew
+ w, _ = torch.gesv(Kus, Luu) # size M x Xnew

- tmp = tf.matrix_triangular_solve(tf.transpose(L), gamma, lower=False)
- mean = tf.matmul(w, tmp, transpose_a=True) + self.mean_function(Xnew)
- intermediateA = tf.matrix_triangular_solve(L, w, lower=True)
+ tmp, _ = torch.gesv(gamma, L.t())
+ mean = torch.matmul(w.t(), tmp) + self.mean_function(Xnew)
+ intermediateA, _ = torch.gesv(w, L)

if full_cov:
- var = self.kern.K(Xnew) - tf.matmul(w, w, transpose_a=True) \
- + tf.matmul(intermediateA, intermediateA, transpose_a=True)
- var = tf.tile(tf.expand_dims(var, 2), tf.stack([1, 1, tf.shape(self.Y)[1]]))
+ var = self.kern.K(Xnew) - torch.matmul(w.t(), w) + torch.matmul(intermediateA.t(), intermediateA)
+ var = var.unsqueeze(2).expand(-1, -1, self.Y.size(1))
else:
- var = self.kern.Kdiag(Xnew) - tf.reduce_sum(tf.square(w), 0) \
- + tf.reduce_sum(tf.square(intermediateA), 0) # size Xnew,
- var = tf.tile(tf.expand_dims(var, 1), tf.stack([1, tf.shape(self.Y)[1]]))
+ var = self.kern.Kdiag(Xnew) - (w**2).sum(0) + (intermediateA**2).sum(0) # size Xnew,
+ var = var.unsqueeze(1).expand(-1, self.Y.size(1))

return mean, var
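
Several of the ported lines solve triangular systems with the general solver `torch.gesv` and flag it with "could use triangular solve". A small sketch in current PyTorch (where `gesv` has since been removed) showing that the two agree on a Cholesky factor, while the triangular solve exploits the structure via substitution; the matrices here are made-up examples:

```python
import torch

A = torch.rand(6, 6)
K = A @ A.T + 6 * torch.eye(6)   # well-conditioned SPD test matrix
L = torch.linalg.cholesky(K)     # lower-triangular factor
B = torch.rand(6, 3)

x_general = torch.linalg.solve(L, B)                      # generic LU-based solve
x_tri = torch.linalg.solve_triangular(L, B, upper=False)  # forward substitution
assert torch.allclose(x_general, x_tri, atol=1e-5)
```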