diff --git a/candlegp/kernels.py b/candlegp/kernels.py
index 0d5c5be..1f0d950 100644
--- a/candlegp/kernels.py
+++ b/candlegp/kernels.py
@@ -367,6 +367,110 @@ def K(self, X, X2=None, presliced=False):
         return self.variance.get() * torch.cos(r)
 
 
+class ArcCosine(Kern):
+    """
+    The Arc-cosine family of kernels which mimics the computation in neural
+    networks. The order parameter specifies the assumed activation function.
+    The Multi Layer Perceptron (MLP) kernel is closely related to the ArcCosine
+    kernel of order 0. The key reference is
+
+    ::
+
+        @incollection{NIPS2009_3628,
+            title = {Kernel Methods for Deep Learning},
+            author = {Youngmin Cho and Lawrence K. Saul},
+            booktitle = {Advances in Neural Information Processing Systems 22},
+            year = {2009},
+            url = {http://papers.nips.cc/paper/3628-kernel-methods-for-deep-learning.pdf}
+        }
+    """
+
+    implemented_orders = {0, 1, 2}
+
+    def __init__(self, input_dim,
+                 order=0,
+                 variance=1.0, weight_variances=1., bias_variance=1.,
+                 active_dims=None, ARD=False, name=None):
+        """
+        - input_dim is the dimension of the input to the kernel
+        - order specifies the activation function of the neural network;
+          the function is a rectified monomial of the chosen order
+        - variance is the initial value for the variance parameter
+        - weight_variances is the initial value for the weight_variances parameter;
+          defaults to 1.0 (ARD=False) or np.ones(input_dim) (ARD=True)
+        - bias_variance is the initial value for the bias_variance parameter;
+          defaults to 1.0
+        - active_dims is a list of length input_dim which controls which
+          columns of X are used
+        - ARD specifies whether the kernel has one weight_variance per dimension
+          (ARD=True) or a single weight_variance (ARD=False)
+        """
+        super(ArcCosine, self).__init__(input_dim, active_dims, name=name)
+
+        if order not in self.implemented_orders:
+            raise ValueError('Requested kernel order is not implemented.')
+        self.order = order
+
+        self.variance = parameter.PositiveParam(variance)
+        self.bias_variance = parameter.PositiveParam(bias_variance)
+        self.ARD = ARD
+        if ARD:
+            if weight_variances is None:
+                weight_variances = self.variance.data.new(input_dim).fill_(1.0)
+            else:
+                # accepts a float or a Tensor; broadcast to one value per input dimension
+                weight_variances = weight_variances * self.variance.data.new(input_dim).fill_(1.0)
+            self.weight_variances = parameter.PositiveParam(weight_variances)
+        else:
+            if weight_variances is None:
+                weight_variances = 1.0
+            self.weight_variances = parameter.PositiveParam(weight_variances)
+
+    def _weighted_product(self, X, X2=None):
+        if X2 is None:
+            return (self.weight_variances.get() * (X ** 2)).sum(1) + self.bias_variance.get()
+        return torch.matmul(self.weight_variances.get() * X, X2.t()) + self.bias_variance.get()
+
+    def _J(self, theta):
+        """
+        Implements the order dependent family of functions defined in equations
+        4 to 7 in the reference paper.
+        """
+        if self.order == 0:
+            return float(numpy.pi) - theta
+        elif self.order == 1:
+            return torch.sin(theta) + (float(numpy.pi) - theta) * torch.cos(theta)
+        elif self.order == 2:
+            return (3. * torch.sin(theta) * torch.cos(theta)
+                    + (float(numpy.pi) - theta) * (1. + 2. * torch.cos(theta) ** 2))
+
+    def K(self, X, X2=None, presliced=False):
+        if not presliced:
+            X, X2 = self._slice(X, X2)
+
+        X_denominator = self._weighted_product(X) ** 0.5
+        if X2 is None:
+            X2 = X
+            X2_denominator = X_denominator
+        else:
+            X2_denominator = self._weighted_product(X2) ** 0.5
+
+        numerator = self._weighted_product(X, X2)
+        cos_theta = numerator / X_denominator[:, None] / X2_denominator[None, :]
+        # keep the argument of acos strictly inside (-1, 1) for numerical safety
+        jitter = 1e-15
+        theta = torch.acos(jitter + (1 - 2 * jitter) * cos_theta)
+
+        return (self.variance.get() * (1. / float(numpy.pi)) * self._J(theta)
+                * X_denominator[:, None] ** self.order
+                * X2_denominator[None, :] ** self.order)
+
+    def Kdiag(self, X, presliced=False):
+        if not presliced:
+            X, _ = self._slice(X, None)
+
+        X_product = self._weighted_product(X)
+        theta = X_product * 0.  # a zero tensor (not the int 0), so _J works for every order
+        return self.variance.get() * (1. / float(numpy.pi)) * self._J(theta) * X_product ** self.order
+
+
 class Combination(Kern):
     """
     Combine a list of kernels, e.g. by adding or multiplying (see inheriting
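
For reviewers who want to exercise the new kernel, here is a minimal usage sketch (not part of the diff). It assumes a reasonably recent PyTorch and that `parameter.PositiveParam` accepts plain Python floats, as the constructor defaults above suggest; the input data, the hyperparameters left at their defaults, and the tolerance are illustrative only::

    import torch
    from candlegp import kernels

    X = torch.randn(5, 3)   # five 3-dimensional inputs (illustrative values)

    # Order-1 arc-cosine kernel, the covariance of an infinite ReLU network.
    k = kernels.ArcCosine(input_dim=3, order=1)

    Kmat = k.K(X)           # 5 x 5 Gram matrix
    kdiag = k.Kdiag(X)      # its diagonal, computed without the full matrix

    # The diagonal of K(X, X) should match Kdiag(X) up to the acos jitter.
    assert torch.allclose(torch.diag(Kmat), kdiag, atol=1e-6)

Per Cho & Saul (2009), order 0 corresponds to a step activation, order 1 to a rectified linear unit, and order 2 to a rectified quadratic; order 0 is the variant closest to the MLP kernel mentioned in the class docstring.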