-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathloss.py
288 lines (254 loc) · 11.3 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
import math
import torch.nn.functional as F
import pdb
class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
        Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.

    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes (K in the equation).
        epsilon (float): smoothing weight.
        use_gpu (bool): kept for backward compatibility and stored on the
            module, but no longer consulted: the smoothed targets are created
            on whatever device the logits live on.
        reduction (bool): if True return the mean over the batch, otherwise
            return the per-sample losses.
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground truth class indices with shape (batch_size)

        Returns:
            Scalar mean loss when ``reduction`` is true, otherwise a tensor
            with one loss value per sample.
        """
        log_probs = self.logsoftmax(inputs)
        # Build the one-hot matrix directly on the logits' device instead of
        # on the CPU followed by an unconditional ``.cuda()``.
        index = targets.unsqueeze(1).to(log_probs.device)
        one_hot = torch.zeros_like(log_probs).scatter_(1, index, 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (-smoothed * log_probs).sum(dim=1)
        if self.reduction:
            return loss.mean()
        return loss
class MMD_loss(nn.Module):
    """Discrepancy loss between a source batch and a target batch.

    Args:
        kernel_type (str): ``'rbf'`` for the multi-bandwidth Gaussian kernel
            estimate, ``'linear'`` for the estimate in ``linear_mmd2``.
        kernel_mul (float): ratio between consecutive kernel bandwidths.
        kernel_num (int): number of Gaussian kernels that are summed.

    The attribute ``fix_sigma`` is initialised to ``None`` (bandwidth is then
    estimated from the data); it can be overwritten on the instance.
    """

    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):
        super(MMD_loss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        self.fix_sigma = None
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
        """Return the sum of ``kernel_num`` Gaussian kernels over all row pairs.

        ``source`` and ``target`` are concatenated along dim 0; the result is a
        square matrix with one row/column per concatenated sample.
        """
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
        L2_distance = ((total0 - total1) ** 2).sum(2)
        if fix_sigma:
            bandwidth = fix_sigma
        else:
            # Mean off-diagonal squared distance, detached from the graph.
            bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
        # Out-of-place division: an in-place ``/=`` would shrink a tensor
        # ``fix_sigma`` owned by the caller on every call.
        bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
        bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Mean over consecutive row pairs of the dot product of the row-wise differences."""
        delta = f_of_X - f_of_Y
        return torch.mean((delta[:-1] * delta[1:]).sum(1))

    def forward(self, source, target):
        """Compute the loss selected by ``kernel_type``.

        Raises:
            ValueError: if ``kernel_type`` is neither ``'linear'`` nor ``'rbf'``
                (previously ``None`` was returned silently).
        """
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        if self.kernel_type == 'rbf':
            batch_size = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            # Computed with autograd enabled: wrapping these reductions in
            # ``torch.no_grad()`` produced a loss that could not be back-propagated.
            XX = torch.mean(kernels[:batch_size, :batch_size])
            YY = torch.mean(kernels[batch_size:, batch_size:])
            XY = torch.mean(kernels[:batch_size, batch_size:])
            YX = torch.mean(kernels[batch_size:, :batch_size])
            return XX + YY - XY - YX
        raise ValueError("unsupported kernel_type: {!r}".format(self.kernel_type))
def bsp_loss(feature):
    """Penalty on the leading singular value of each half of a batch.

    The rows of ``feature`` are split into a first half and a second half of
    ``feature.size(0) // 2`` rows each (a trailing odd row is ignored).  The
    squares of the first singular value of each half are added and scaled
    by 1e-4.
    """
    half = feature.size(0) // 2
    first_half = feature.narrow(0, 0, half)
    second_half = feature.narrow(0, half, half)
    # torch.svd returns (U, S, V); only S is needed.
    singular_first = torch.svd(first_half)[1]
    singular_second = torch.svd(second_half)[1]
    penalty = singular_first[0].pow(2) + singular_second[0].pow(2)
    return penalty * 0.0001
def Entropy(input_):
    """Return ``-sum(p * log(p + 1e-5))`` over dim 1 for every row of ``input_``."""
    eps = 1e-5  # keeps log() finite for zero entries
    per_element = -input_ * torch.log(input_ + eps)
    return per_element.sum(dim=1)
def grl_hook(coeff):
    """Build a gradient hook that negates a gradient and scales it by ``coeff``.

    The returned callable takes a gradient tensor and returns a new tensor
    (the input gradient is cloned, not modified).
    """
    def reverse_gradient(grad):
        flipped = grad.clone()
        return -coeff * flipped

    return reverse_gradient
def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
    """Adversarial domain loss on features conditioned on the classifier output.

    Args:
        input_list: ``[feature, softmax_output]``; both are used as 2-D tensors
            with one row per sample.  The softmax output is detached.
        ad_net: callable (domain discriminator) applied to the conditioned
            features.  Its output is passed to ``nn.BCELoss`` against a
            ``(2 * batch_size, 1)`` target, so it must hold probabilities.
        entropy: optional per-sample tensor used to weight the samples.  A hook
            built by ``grl_hook(coeff)`` is registered on it, so it must
            require grad.
        coeff: scale passed to ``grl_hook``; only used when ``entropy`` is given.
        random_layer: optional object whose ``forward([feature, softmax_output])``
            replaces the outer product of softmax output and feature.

    Returns:
        Scalar binary cross entropy.  The first half of the batch is labelled 1
        and the second half 0 (the variable names below treat the first half
        as source and the second as target).
    """
    softmax_output = input_list[1].detach()
    feature = input_list[0]
    if random_layer is None:
        # Per-sample outer product, flattened to one row per sample.
        op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
        ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
    else:
        random_out = random_layer.forward([feature, softmax_output])
        ad_out = ad_net(random_out.view(-1, random_out.size(1)))
    batch_size = softmax_output.size(0) // 2
    # Create the domain labels next to the discriminator output instead of
    # forcing them onto CUDA, so the loss also works on CPU.
    dc_target = torch.cat(
        [ad_out.new_ones(batch_size, 1), ad_out.new_zeros(batch_size, 1)], dim=0)
    if entropy is None:
        return nn.BCELoss()(ad_out, dc_target)

    entropy.register_hook(grl_hook(coeff))
    entropy = 1.0 + torch.exp(-entropy)
    half = feature.size(0) // 2
    source_mask = torch.ones_like(entropy)
    source_mask[half:] = 0
    source_weight = entropy * source_mask
    target_mask = torch.ones_like(entropy)
    target_mask[0:half] = 0
    target_weight = entropy * target_mask
    # Each half is normalised by its own (constant) total weight.
    weight = source_weight / torch.sum(source_weight).detach().item() + \
        target_weight / torch.sum(target_weight).detach().item()
    per_sample = nn.BCELoss(reduction='none')(ad_out, dc_target)
    return torch.sum(weight.view(-1, 1) * per_sample) / torch.sum(weight).detach().item()
def DANN(features, ad_net, entropy=None, coeff=None):
    """Adversarial domain loss computed directly on ``features``.

    Args:
        features: tensor with one row per sample, passed to ``ad_net``.
        ad_net: callable (domain discriminator).  Its output is passed to
            ``nn.BCELoss`` against a ``(2 * batch_size, 1)`` target, so it must
            hold probabilities.
        entropy: optional per-sample tensor used to weight the samples.  A hook
            built by ``grl_hook(coeff)`` is registered on it, so it must
            require grad.
        coeff: scale passed to ``grl_hook``; only used when ``entropy`` is given.

    Returns:
        Scalar binary cross entropy.  The first half of the batch is labelled 1
        and the second half 0 (the variable names below treat the first half
        as source and the second as target).
    """
    ad_out = ad_net(features)
    batch_size = ad_out.size(0) // 2
    # Create the domain labels next to the discriminator output instead of
    # forcing them onto CUDA, so the loss also works on CPU.
    dc_target = torch.cat(
        [ad_out.new_ones(batch_size, 1), ad_out.new_zeros(batch_size, 1)], dim=0)
    if entropy is None:
        return nn.BCELoss()(ad_out, dc_target)

    entropy.register_hook(grl_hook(coeff))
    entropy = 1.0 + torch.exp(-entropy)
    # The masks are sized from ``features``; the previous code referenced an
    # undefined name ``feature`` here and raised NameError.
    half = features.size(0) // 2
    source_mask = torch.ones_like(entropy)
    source_mask[half:] = 0
    source_weight = entropy * source_mask
    target_mask = torch.ones_like(entropy)
    target_mask[0:half] = 0
    target_weight = entropy * target_mask
    # Each half is normalised by its own (constant) total weight.
    weight = source_weight / torch.sum(source_weight).detach().item() + \
        target_weight / torch.sum(target_weight).detach().item()
    per_sample = nn.BCELoss(reduction='none')(ad_out, dc_target)
    return torch.sum(weight.view(-1, 1) * per_sample) / torch.sum(weight).detach().item()
def CORAL(source, target):
    """Distance between the feature covariance matrices of two batches.

    Args:
        source: 2-D tensor, one row per sample.
        target: 2-D tensor with the same number of columns as ``source``.

    Returns:
        Scalar: square root of the summed squared differences between the two
        covariance matrices, divided by ``4 * d * d`` where ``d`` is the number
        of columns.  Each covariance divides by ``n - 1``, so both batches
        need at least two rows for a finite result.
    """
    d = source.size(1)
    ns, nt = source.size(0), target.size(0)

    # Column sums taken on the input's own device/dtype; the previous
    # ``torch.ones(...).cuda() @ x`` required a GPU and float32 input.
    sum_s = source.sum(dim=0, keepdim=True)
    cs = (source.t() @ source - (sum_s.t() @ sum_s) / ns) / (ns - 1)

    sum_t = target.sum(dim=0, keepdim=True)
    ct = (target.t() @ target - (sum_t.t() @ sum_t) / nt) / (nt - 1)

    loss = (cs - ct).pow(2).sum().sqrt()
    return loss / (4 * d * d)
def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Sum of ``kernel_num`` Gaussian kernels over all pairs of rows.

    Args:
        source: 2-D tensor, one row per sample.
        target: 2-D tensor with the same number of columns as ``source``.
        kernel_mul: ratio between consecutive bandwidths.
        kernel_num: number of kernels that are summed.
        fix_sigma: base bandwidth; when falsy the mean off-diagonal squared
            distance of the concatenated batch is used instead.

    Returns:
        Square matrix with one row/column per row of ``cat([source, target])``.
    """
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    rows, cols = int(total.size(0)), int(total.size(1))
    total0 = total.unsqueeze(0).expand(rows, rows, cols)
    total1 = total.unsqueeze(1).expand(rows, rows, cols)
    L2_distance = ((total0 - total1) ** 2).sum(2)
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Detached from the graph: the bandwidth is treated as a constant.
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    # Out-of-place division: an in-place ``/=`` would modify a tensor
    # ``fix_sigma`` that belongs to the caller.
    bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)
def DAN(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Multi-kernel discrepancy between a source and a target batch.

    With ``n = source.size(0)``, the result is the mean kernel value over the
    ``n * (n - 1) // 2`` unordered within-domain pairs (source-source plus
    target-target) minus twice the mean kernel value over all ``n * n``
    source-target pairs.  Only the first ``n`` target rows enter the sums.

    Args:
        source: 2-D tensor, one row per sample.
        target: 2-D tensor with at least as many rows as ``source``.
        kernel_mul, kernel_num, fix_sigma: forwarded to ``guassian_kernel``.

    Raises:
        ValueError: if ``source`` has fewer than two rows (no within-domain
            pair exists) or ``target`` has fewer rows than ``source``.
    """
    batch_size = int(source.size()[0])
    if batch_size < 2:
        raise ValueError("DAN needs at least two samples per domain, got %d" % batch_size)
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    end = 2 * batch_size
    if kernels.size(0) < end:
        raise ValueError("target batch must have at least as many rows as source")

    # Block sums replace the element-by-element Python loops.
    k_ss = kernels[:batch_size, :batch_size]
    k_tt = kernels[batch_size:end, batch_size:end]
    k_st = kernels[:batch_size, batch_size:end]

    # Strict upper triangle == all pairs (s1, s2) with s1 < s2.
    within = torch.triu(k_ss, diagonal=1).sum() + torch.triu(k_tt, diagonal=1).sum()
    loss1 = within / float(batch_size * (batch_size - 1) // 2)

    # kernels[s1, t2] and kernels[s2, t1] each run over the whole
    # source-target block, hence the factor of two.
    loss2 = -2.0 * k_st.sum() / float(batch_size * batch_size)
    return loss1 + loss2
def DAN_Linear(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Linear-time variant of ``DAN``.

    Each sample index ``i`` is paired only with its cyclic successor
    ``(i + 1) % n``; the within-domain kernel values of that pair are added,
    the two cross-domain values are subtracted, and the total is divided by
    the number of source rows ``n``.
    """
    n = int(source.size()[0])
    gram = guassian_kernel(source, target,
                           kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    total = 0
    for cur in range(n):
        nxt = (cur + 1) % n
        # Target rows follow the n source rows in the kernel matrix.
        cur_t = cur + n
        nxt_t = nxt + n
        total += gram[cur, nxt] + gram[cur_t, nxt_t]
        total -= gram[cur, nxt_t] + gram[nxt, cur_t]
    return total / float(n)
def RTN():
    """Placeholder entry for the RTN loss; does nothing and returns None."""
    return None
def JAN(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fix_sigma_list=[None, 1.68]):
    """Joint multi-layer kernel discrepancy between source and target batches.

    A Gaussian kernel matrix is computed per layer and the matrices are
    multiplied element-wise.  With ``n = source_list[0].size(0)``, the result
    is the mean joint-kernel value over the ``n * (n - 1) // 2`` unordered
    within-domain pairs minus twice the mean over all ``n * n`` source-target
    pairs.

    Args:
        source_list: one 2-D source tensor per layer.
        target_list: one 2-D target tensor per layer.
        kernel_muls, kernel_nums, fix_sigma_list: per-layer arguments for
            ``guassian_kernel``; each needs an entry for every layer.  The
            default lists are only read, never modified.

    Raises:
        ValueError: if the first source tensor has fewer than two rows or a
            target has fewer rows than that.
    """
    batch_size = int(source_list[0].size()[0])
    if batch_size < 2:
        raise ValueError("JAN needs at least two samples per domain, got %d" % batch_size)

    joint_kernels = None
    for layer in range(len(source_list)):
        kernels = guassian_kernel(source_list[layer], target_list[layer],
                                  kernel_mul=kernel_muls[layer],
                                  kernel_num=kernel_nums[layer],
                                  fix_sigma=fix_sigma_list[layer])
        joint_kernels = kernels if joint_kernels is None else joint_kernels * kernels

    end = 2 * batch_size
    if joint_kernels.size(0) < end:
        raise ValueError("target batches must have at least as many rows as source")

    # Block sums replace the element-by-element Python loops.
    k_ss = joint_kernels[:batch_size, :batch_size]
    k_tt = joint_kernels[batch_size:end, batch_size:end]
    k_st = joint_kernels[:batch_size, batch_size:end]

    # Strict upper triangle == all pairs (s1, s2) with s1 < s2.
    within = torch.triu(k_ss, diagonal=1).sum() + torch.triu(k_tt, diagonal=1).sum()
    loss1 = within / float(batch_size * (batch_size - 1) // 2)

    # joint_kernels[s1, t2] and joint_kernels[s2, t1] each run over the whole
    # source-target block, hence the factor of two.
    loss2 = -2.0 * k_st.sum() / float(batch_size * batch_size)
    return loss1 + loss2
def JAN_Linear(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[5, 1], fix_sigma_list=[None, 1.68]):
    """Linear-time variant of ``JAN``.

    The per-layer Gaussian kernel matrices are multiplied element-wise; then
    each sample index ``i`` is paired with its cyclic successor
    ``(i + 1) % n``, within-domain values are added, cross-domain values are
    subtracted, and the total is divided by ``n`` (rows of the first source
    tensor).
    """
    n = int(source_list[0].size()[0])

    product = None
    for layer in range(len(source_list)):
        layer_kernel = guassian_kernel(
            source_list[layer], target_list[layer],
            kernel_mul=kernel_muls[layer],
            kernel_num=kernel_nums[layer],
            fix_sigma=fix_sigma_list[layer])
        if product is None:
            product = layer_kernel
        else:
            product = product * layer_kernel

    total = 0
    for cur in range(n):
        nxt = (cur + 1) % n
        # Target rows follow the n source rows in the kernel matrix.
        cur_t = cur + n
        nxt_t = nxt + n
        total += product[cur, nxt] + product[cur_t, nxt_t]
        total -= product[cur, nxt_t] + product[nxt, cur_t]
    return total / float(n)
# Lookup table from loss name to the function implementing it.
loss_dict = {
    "DAN": DAN,
    "DAN_Linear": DAN_Linear,
    "RTN": RTN,
    "JAN": JAN,
    "JAN_Linear": JAN_Linear,
}