kernels and exp likelihood
t-vi committed Nov 5, 2017
1 parent dc1e26e commit 13ff9a4
Showing 4 changed files with 113 additions and 15 deletions.
75 changes: 74 additions & 1 deletion candlegp/kernels.py
@@ -15,7 +15,7 @@

import torch
from torch.autograd import Variable

import numpy
from . import parameter

class Kern(torch.nn.Module):
@@ -366,3 +366,76 @@ def K(self, X, X2=None, presliced=False):
r = self.euclid_dist(X, X2)
return self.variance.get() * torch.cos(r)


class Combination(Kern):
"""
Combine a list of kernels, e.g. by adding or multiplying (see inheriting
classes).
The names of the kernels to be combined are generated from their class
names.
"""

def __init__(self, kern_list, name=None):
for k in kern_list:
assert isinstance(k, Kern), "can only add/multiply Kern instances"

input_dim = numpy.max([k.input_dim if type(k.active_dims) is slice else numpy.max(k.active_dims) + 1 for k in kern_list])
super(Combination, self).__init__(input_dim=input_dim, name=name)

# add kernels to a list, flattening out instances of this class therein
self.kern_list = torch.nn.ModuleList()
for k in kern_list:
if isinstance(k, self.__class__):
self.kern_list.extend(k.kern_list)
else:
self.kern_list.append(k)

@property
def on_separate_dimensions(self):
"""
Checks whether the kernels in the combination act on disjoint subsets
of dimensions. Currently, it is hard to assess whether two slice objects
will overlap, so this will always return False.
:return: Boolean indicator.
"""
if numpy.any([isinstance(k.active_dims, slice) for k in self.kern_list]):
# Be conservative in the case of a slice object
return False
else:
dimlist = [k.active_dims for k in self.kern_list]
overlapping = False
for i, dims_i in enumerate(dimlist):
for dims_j in dimlist[i + 1:]:
if numpy.any(dims_i.reshape(-1, 1) == dims_j.reshape(1, -1)):
overlapping = True
return not overlapping


class Add(Combination):
def K(self, X, X2=None, presliced=False):
res = 0.0
for k in self.kern_list:
res += k.K(X, X2, presliced=presliced)
return res

def Kdiag(self, X, presliced=False):
res = 0.0
for k in self.kern_list:
res += k.Kdiag(X, presliced=presliced)
return res


class Prod(Combination):
def K(self, X, X2=None, presliced=False):
res = 1.0
for k in self.kern_list:
res *= k.K(X, X2, presliced=presliced)
return res

def Kdiag(self, X, presliced=False):
res = 1.0
for k in self.kern_list:
res *= k.Kdiag(X, presliced=presliced)
return res
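
A quick usage sketch of the new combination kernels (assuming candlegp.kernels also ports GPflow's RBF and Linear kernels with their usual input_dim constructors; the kernel names here are illustrative):

    import torch
    from torch.autograd import Variable
    from candlegp import kernels

    # Add sums the parts' Gram matrices; nested Add/Prod instances are
    # flattened into a single kern_list by Combination.__init__.
    k = kernels.Add([kernels.RBF(1), kernels.Linear(1)])
    X = Variable(torch.randn(5, 1))
    K = k.K(X)                 # 5x5: RBF Gram matrix + Linear Gram matrix
    assert K.size() == (5, 5)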

48 changes: 36 additions & 12 deletions candlegp/likelihoods.py
@@ -85,17 +85,15 @@ def predict_density(self, Fmu, Fvar, Y):
Here, we implement a default Gauss-Hermite quadrature routine, but some
likelihoods (Gaussian, Poisson) will implement specific cases.
"""
gh_x, gh_w = hermgauss(self.num_gauss_hermite_points)

gh_w = gh_w.reshape(-1, 1) / np.sqrt(np.pi)
shape = tf.shape(Fmu)
Fmu, Fvar, Y = [tf.reshape(e, (-1, 1)) for e in (Fmu, Fvar, Y)]
X = gh_x[None, :] * tf.sqrt(2.0 * Fvar) + Fmu

Y = tf.tile(Y, [1, self.num_gauss_hermite_points]) # broadcast Y to match X
gh_x, gh_w = quadrature.hermgauss(self.num_gauss_hermite_points, ttype=type(Fmu.data))

gh_w = gh_w.reshape(-1, 1) / float(numpy.sqrt(numpy.pi))
shape = Fmu.size()
Fmu, Fvar, Y = [e.view(-1, 1) for e in (Fmu, Fvar, Y)]
X = gh_x * (2.0 * Fvar)**0.5 + Fmu
Y = Y.expand(-1, self.num_gauss_hermite_points) # broadcast Y to match X
logp = self.logp(X, Y)
return tf.reshape(tf.log(tf.matmul(tf.exp(logp), gh_w)), shape)
return torch.matmul(logp.exp(), gh_w).log().view(*shape)

def variational_expectations(self, Fmu, Fvar, Y):
"""
@@ -137,10 +135,10 @@ def _check_targets(self, Y_np): # pylint: disable=R0201
and consists only of floats. The float requirement is so that AutoFlow
can work with Model.predict_density.
"""
if not len(Y_np.shape) == 2:
if not Y.dim() == 2:
raise ValueError('targets must be shape N x D')
if np.array(list(Y_np)).dtype != settings.np_float:
raise ValueError('use {}, even for discrete variables'.format(settings.np_float))
#if np.array(list(Y_np)).dtype != settings.np_float:
# raise ValueError('use {}, even for discrete variables'.format(settings.np_float))


class Gaussian(Likelihood):
@@ -204,3 +202,29 @@ def conditional_mean(self, F):
def conditional_variance(self, F):
p = self.invlink(F)
return p - p**2


class Exponential(Likelihood):
def __init__(self, invlink=torch.exp):
Likelihood.__init__(self)
self.invlink = invlink

def _check_targets(self, Y):
super(Exponential, self)._check_targets(Y)
if (Y < 0).any():
raise ValueError('exponential variables must be positive')

def logp(self, F, Y):
return densities.exponential(self.invlink(F), Y)

def conditional_mean(self, F):
return self.invlink(F)

def conditional_variance(self, F):
return (self.invlink(F))**2

def variational_expectations(self, Fmu, Fvar, Y):
if self.invlink is torch.exp:
return - torch.exp(-Fmu + Fvar / 2) * Y - Fmu
return super(Exponential, self).variational_expectations(Fmu, Fvar, Y)
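
A standalone sanity check for the new Exponential likelihood: when invlink is torch.exp, the closed form in variational_expectations should agree with the Gauss-Hermite quadrature that the generic fallback (and predict_density above) performs. A numpy-only sketch:

    import numpy
    from numpy.polynomial.hermite import hermgauss

    mu, var, y = 0.3, 0.5, 2.0
    gh_x, gh_w = hermgauss(20)
    f = gh_x * numpy.sqrt(2.0 * var) + mu           # quadrature nodes for N(mu, var)
    logp = -y * numpy.exp(-f) - f                   # log p(y | scale=exp(f)) for the Exponential
    quad = numpy.dot(gh_w, logp) / numpy.sqrt(numpy.pi)
    closed = -numpy.exp(-mu + var / 2.0) * y - mu   # the closed form used above
    assert abs(quad - closed) < 1e-8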

1 change: 1 addition & 0 deletions candlegp/models/__init__.py
@@ -3,3 +3,4 @@
from .sgpr import SGPR
from .svgp import SVGP
from .vgp import VGP
from .gpmc import GPMC
4 changes: 2 additions & 2 deletions candlegp/parameter.py
@@ -53,8 +53,8 @@ def get_prior(self):
if self.prior is None:
return 0.0

log_jacobian = self.log_jacobian() #(unconstrained_tensor)
logp_var = self.prior.logp(self.get())
log_jacobian = self.log_jacobian().sum() #(unconstrained_tensor)
logp_var = self.prior.logp(self.get()).sum()
return log_jacobian+logp_var

class PositiveParam(ParamWithPrior): # log(1+exp(r))
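
The parameter.py change reduces the prior terms to scalars: for a multi-element parameter, prior.logp and log_jacobian are per-element, while the model's log posterior needs a single number. A minimal sketch of the shapes involved (illustrative, not the candlegp API):

    import torch
    from torch.autograd import Variable

    theta = Variable(torch.randn(3))    # a 3-element parameter
    logp_var = -0.5 * theta**2          # per-element N(0,1) log density, up to a constant
    prior_term = logp_var.sum()         # scalar contribution, as get_prior now returns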