multiclass likelihood and example
t-vi committed Nov 25, 2017
1 parent 3b4c2f4 commit 59e03ce
Showing 5 changed files with 753 additions and 36 deletions.
2 changes: 1 addition & 1 deletion candlegp/conditionals.py
@@ -92,7 +92,7 @@ def conditional(Xnew, X, kern, f, full_cov=False, q_sqrt=None, whiten=False, jit

     if q_sqrt is not None:
         if q_sqrt.dim() == 2:
-            LTA = A * q_sqrt.unsqueeze(2)  # K x M x N
+            LTA = A * q_sqrt.t().unsqueeze(2)  # K x M x N
         elif q_sqrt.dim() == 3:
             L = batch_tril(q_sqrt.permute(2, 0, 1))  # K x M x M
             # A_tiled = tf.tile(tf.expand_dims(A, 0), tf.stack([num_func, 1, 1]))  # I don't think I need this
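Not part of the diff, but worth spelling out: in the 2-D branch, q_sqrt is stored as M x K (one column of per-inducing-point scales per latent function), so the latent-function axis has to be moved in front before broadcasting against A. A minimal shape sketch with made-up sizes, plain torch:

import torch

M, N, K = 4, 3, 2                     # inducing points, new points, latent functions
A = torch.randn(M, N)                 # stand-in for A in conditional()
q_sqrt = torch.rand(M, K)             # 2-D (diagonal) case: M x K

# q_sqrt.t() is K x M; unsqueeze(2) makes it K x M x 1, which broadcasts
# against the M x N matrix A to the intended K x M x N result.
LTA = A * q_sqrt.t().unsqueeze(2)
print(LTA.size())                     # torch.Size([2, 4, 3])

# The pre-fix version broadcast M x K x 1 against M x N, which in general
# fails to broadcast (or silently yields an M x M x N tensor when K == M).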
1 change: 0 additions & 1 deletion candlegp/kernels.py
@@ -212,7 +212,6 @@ def K(self, X, X2=None, presliced=False):
             d = self.variance.get().expand(X.size(0))
             return torch.diag(d)
         else:
-            shape = tf.stack([tf.shape(X)[0], tf.shape(X2)[0]])
             return Variable(X.data.new(X.size(0),X2.size(0)).zero_())
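Aside (not in the diff): the dropped line was a leftover from the TensorFlow original. A white-noise kernel has zero covariance between distinct inputs, so the X2 branch only needs a len(X) x len(X2) zero matrix, allocated with X's tensor type. Standalone illustration with made-up sizes:

import torch
from torch.autograd import Variable

X = Variable(torch.randn(5, 1))
X2 = Variable(torch.randn(3, 1))

# mirrors the returned expression: a 5 x 3 zero cross-covariance
Kx = Variable(X.data.new(X.size(0), X2.size(0)).zero_())
print(Kx.size())                      # torch.Size([5, 3])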
2 changes: 1 addition & 1 deletion candlegp/kullback_leiblers.py
@@ -70,7 +70,7 @@ def gauss_kl_white_diag(q_mu, q_sqrt):

     KL = 0.5 * (q_mu**2).sum()  # Mahalanobis term
     KL += -0.5 * q_sqrt.numel()
-    KL -= q_sqrt().abs().log()  # Log-det of q-cov
+    KL = KL - q_sqrt.abs().log().sum()  # Log-det of q-cov
     KL += 0.5 * (q_sqrt**2).sum()  # Trace term
     return KL
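The old line called q_sqrt as a function and never reduced to a scalar; with the fix, the four terms reproduce the closed form KL(N(q_mu, diag(q_sqrt**2)) || N(0, I)) = 0.5 * sum(q_mu**2 + q_sqrt**2 - 1 - log(q_sqrt**2)). A quick sanity check (made-up shapes, not part of the commit):

import torch
from torch.autograd import Variable

q_mu = Variable(torch.randn(5, 2))
q_sqrt = Variable(torch.rand(5, 2) + 0.1)   # diagonal scales, M x K

kl = 0.5 * (q_mu**2).sum()                  # Mahalanobis term
kl += -0.5 * q_sqrt.numel()                 # constant term
kl = kl - q_sqrt.abs().log().sum()          # log-det of q-cov
kl += 0.5 * (q_sqrt**2).sum()               # trace term

closed_form = 0.5 * (q_mu**2 + q_sqrt**2 - 1 - (q_sqrt**2).log()).sum()
print(float(kl), float(closed_form))        # the two numbers should agree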
61 changes: 28 additions & 33 deletions candlegp/likelihoods.py
@@ -16,6 +16,7 @@

 import numpy
 import torch
+from torch.autograd import Variable
 from . import parameter
 from . import quadrature
 from . import densities
@@ -256,27 +257,25 @@ def __call__(self, F):
     def prob_is_largest(self, Y, mu, var, gh_x, gh_w):
         Y = Y.long()
         # work out what the mean and variance is of the indicated latent function.
-        oh_on = tf.cast(tf.one_hot(tf.reshape(Y, (-1,)), self.num_classes, 1., 0.), settings.tf_float)
-        mu_selected = tf.reduce_sum(oh_on * mu, 1)
-        var_selected = tf.reduce_sum(oh_on * var, 1)
+        oh_on = Variable(mu.data.new(Y.numel(), self.num_classes).fill_(0.).scatter_(1,Y.data,1))
+        mu_selected = (oh_on * mu ).sum(1)
+        var_selected = (oh_on * var).sum(1)

         # generate Gauss Hermite grid
-        X = tf.reshape(mu_selected, (-1, 1)) + gh_x * tf.reshape(
-            tf.sqrt(tf.clip_by_value(2. * var_selected, 1e-10, np.inf)), (-1, 1))
+        X = mu_selected.view(-1, 1) + gh_x * ((2. * var_selected).clamp(min=1e-10)**0.5).view(-1,1)

         # compute the CDF of the Gaussian between the latent functions and the grid (including the selected function)
-        dist = (tf.expand_dims(X, 1) - tf.expand_dims(mu, 2)) / tf.expand_dims(
-            tf.sqrt(tf.clip_by_value(var, 1e-10, np.inf)), 2)
-        cdfs = 0.5 * (1.0 + tf.erf(dist / np.sqrt(2.0)))
+        dist = (X.unsqueeze(1) - mu.unsqueeze(2)) / (var.clamp(min=1e-10)**0.5).unsqueeze(2)
+        cdfs = 0.5 * (1.0 + torch.erf(dist / 2.0**0.5))

         cdfs = cdfs * (1 - 2e-4) + 1e-4

         # blank out all the distances on the selected latent function
-        oh_off = tf.cast(tf.one_hot(tf.reshape(Y, (-1,)), self.num_classes, 0., 1.), settings.tf_float)
-        cdfs = cdfs * tf.expand_dims(oh_off, 2) + tf.expand_dims(oh_on, 2)
+        oh_off = Variable(mu.data.new(Y.numel(), self.num_classes).fill_(1.).scatter_(1,Y.data,0))
+        cdfs = cdfs * oh_off.unsqueeze(2) + oh_on.unsqueeze(2)

         # take the product over the latent functions, and the sum over the GH grid.
-        return tf.matmul(tf.reduce_prod(cdfs, reduction_indices=[1]), tf.reshape(gh_w / np.sqrt(np.pi), (-1, 1)))
+        return torch.matmul(cdfs.prod(1), gh_w.view(-1,1) / (numpy.pi**0.5))
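Per row, prob_is_largest now returns the probability that the latent function selected by Y is the largest of the K independent Gaussians N(mu_k, var_k), with the outer integral over the selected function evaluated on the Gauss-Hermite grid. A rough standalone Monte Carlo cross-check of that quantity (made-up mu/var, numpy only, not part of the commit):

import numpy

rng = numpy.random.RandomState(0)
mu = numpy.array([0.5, -0.2, 0.1])          # one data point, K = 3 latents
var = numpy.array([0.3, 0.4, 0.2])
y = 0                                       # index of the selected latent

samples = rng.randn(200000, 3) * numpy.sqrt(var) + mu
p_mc = (samples.argmax(1) == y).mean()
print(p_mc)   # should approach the Gauss-Hermite value for the same mu/var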


 class MultiClass(Likelihood):
@@ -286,65 +285,61 @@ def __init__(self, num_classes, invlink=None):
         Currently the only valid choice
         of inverse-link function (invlink) is an instance of RobustMax.
         """
-        Likelihood.__init__(self)
+        super(MultiClass, self).__init__()
         self.num_classes = num_classes
         if invlink is None:
             invlink = RobustMax(self.num_classes)
         elif not isinstance(invlink, RobustMax):
-            raise NotImplementedError
+            raise NotImplementedError("Multiclass currently only supports RobustMax link")
         self.invlink = invlink

     def _check_targets(self, Y_np):
         super(MultiClass, self)._check_targets(Y_np)
-        if not set(Y_np.flatten()).issubset(set(np.arange(self.num_classes))):
+        if not set(Y_np.view(-1)).issubset(set(range(self.num_classes))):
             raise ValueError('multiclass likelihood expects inputs to be in {0., 1., 2.,...,k-1}')
-        if Y_np.shape[1] != 1:
+        if Y_np.size(1) != 1:
             raise ValueError('only one dimension currently supported for multiclass likelihood')

     def logp(self, F, Y):
         if isinstance(self.invlink, RobustMax):
-            hits = tf.equal(tf.expand_dims(tf.argmax(F, 1), 1), tf.cast(Y, tf.int64))
-            yes = tf.ones(tf.shape(Y), dtype=settings.tf_float) - self.invlink.epsilon
-            no = tf.zeros(tf.shape(Y), dtype=settings.tf_float) + self.invlink._eps_K1
-            p = tf.where(hits, yes, no)
-            return tf.log(p)
+            p = (torch.max(F, 1)[1].unsqueeze(1)==Y.long())*(1-self.invlink.epsilon-self.invlink._eps_K1)+self.invlink._eps_K1
+            return torch.log(p)
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Multiclass currently only supports RobustMax link")

     def variational_expectations(self, Fmu, Fvar, Y):
         if isinstance(self.invlink, RobustMax):
-            gh_x, gh_w = hermgauss(self.num_gauss_hermite_points)
+            gh_x, gh_w = quadrature.hermgauss(self.num_gauss_hermite_points, ttype=type(Fmu.data))
             p = self.invlink.prob_is_largest(Y, Fmu, Fvar, gh_x, gh_w)
-            return p * np.log(1 - self.invlink.epsilon) + (1. - p) * np.log(self.invlink._eps_K1)
+            return p * numpy.log(1 - self.invlink.epsilon) + (1. - p) * numpy.log(self.invlink._eps_K1)
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Multiclass currently only supports RobustMax link")

     def predict_mean_and_var(self, Fmu, Fvar):
         if isinstance(self.invlink, RobustMax):
             # To compute this, we'll compute the density for each possible output
-            possible_outputs = [tf.fill(tf.stack([tf.shape(Fmu)[0], 1]), np.array(i, dtype=np.int64)) for i in
-                                range(self.num_classes)]
+            possible_outputs = [Variable(Fmu.data.new().long().resize_(Fmu.size(0),1).fill_(i)) for i in range(self.num_classes)]
             ps = [self._predict_non_logged_density(Fmu, Fvar, po) for po in possible_outputs]
-            ps = tf.transpose(tf.stack([tf.reshape(p, (-1,)) for p in ps]))
-            return ps, ps - tf.square(ps)
+            ps = torch.stack([p.view(-1) for p in ps],1)
+            return ps, ps - ps**2
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Multiclass currently only supports RobustMax link")

     def predict_density(self, Fmu, Fvar, Y):
-        return tf.log(self._predict_non_logged_density(Fmu, Fvar, Y))
+        return torch.log(self._predict_non_logged_density(Fmu, Fvar, Y))

     def _predict_non_logged_density(self, Fmu, Fvar, Y):
         if isinstance(self.invlink, RobustMax):
-            gh_x, gh_w = hermgauss(self.num_gauss_hermite_points)
+            gh_x, gh_w = quadrature.hermgauss(self.num_gauss_hermite_points, ttype=type(Fmu.data))
             p = self.invlink.prob_is_largest(Y, Fmu, Fvar, gh_x, gh_w)
             return p * (1 - self.invlink.epsilon) + (1. - p) * (self.invlink._eps_K1)
         else:
-            raise NotImplementedError
+            raise NotImplementedError("Multiclass currently only supports RobustMax link")

     def conditional_mean(self, F):
         return self.invlink(F)

     def conditional_variance(self, F):
         p = self.conditional_mean(F)
-        return p - tf.square(p)
+        return p - p**2
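A small usage sketch of the new class (hypothetical values; assumes the defaults visible in this file, e.g. RobustMax being constructed when invlink is None):

import torch
from torch.autograd import Variable
from candlegp import likelihoods

lik = likelihoods.MultiClass(3)             # RobustMax invlink by default
F = Variable(torch.randn(4, 3))             # latent values, N x K
Y = Variable(torch.LongTensor([[0], [2], [1], [0]]))  # labels, N x 1

print(lik.logp(F, Y))                       # log p(Y | F) under RobustMax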

723 changes: 723 additions & 0 deletions notebooks/multiclass.ipynb

Large diffs are not rendered by default.
