loss.py
# list all the additional loss functions
import torch
import torch.nn as nn
import torch.nn.functional as F

from opts import parser

################## entropy loss (continuous target) #####################
def cross_entropy_soft(pred):
    softmax = nn.Softmax(dim=1)
    logsoftmax = nn.LogSoftmax(dim=1)
    loss = torch.mean(torch.sum(-softmax(pred) * logsoftmax(pred), 1))
    return loss
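
# Usage sketch (not part of the original file): a minimal illustration of entropy
# minimization on unlabeled target logits with cross_entropy_soft. The batch size
# and class count below are arbitrary assumptions for the example.
def _example_cross_entropy_soft():
    pred_target = torch.randn(8, 12)  # hypothetical logits: 8 target samples, 12 classes
    return cross_entropy_soft(pred_target)  # scalar loss; lower = more confident predictions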

################## attentive entropy loss (source + target) #####################
def attentive_entropy(pred, pred_domain):
    softmax = nn.Softmax(dim=1)
    logsoftmax = nn.LogSoftmax(dim=1)

    # attention weight: samples with uncertain domain predictions (high domain entropy) get larger weight
    entropy = torch.sum(-softmax(pred_domain) * logsoftmax(pred_domain), 1)
    weights = 1 + entropy

    # attentive entropy
    loss = torch.mean(weights * torch.sum(-softmax(pred) * logsoftmax(pred), 1))
    return loss
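
# Usage sketch (not part of the original file): attentive_entropy weights the class entropy
# of each sample by the uncertainty of its domain prediction. The shapes are illustrative
# assumptions: class logits with 12 classes and domain logits with 2 domains.
def _example_attentive_entropy():
    pred = torch.randn(8, 12)        # hypothetical class logits
    pred_domain = torch.randn(8, 2)  # hypothetical domain-discriminator logits
    return attentive_entropy(pred, pred_domain)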

################## ensemble-based loss #####################
# discrepancy loss used in MCD (CVPR 18)
def dis_MCD(out1, out2):
    return torch.mean(torch.abs(F.softmax(out1, dim=1) - F.softmax(out2, dim=1)))
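
# Usage sketch (not part of the original file): dis_MCD measures the mean absolute
# discrepancy between the softmax outputs of two classifiers, as in Maximum Classifier
# Discrepancy. The two logit tensors below are illustrative assumptions.
def _example_dis_MCD():
    out1 = torch.randn(8, 12)  # hypothetical logits from classifier 1
    out2 = torch.randn(8, 12)  # hypothetical logits from classifier 2
    return dis_MCD(out1, out2)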

################## MMD-based loss #####################
def mmd_linear(f_of_X, f_of_Y):
    # Consider linear-time MMD with a linear kernel:
    # K(f(x), f(y)) = f(x)^T f(y)
    # h(z_i, z_j) = k(x_i, x_j) + k(y_i, y_j) - k(x_i, y_j) - k(x_j, y_i)
    #             = [f(x_i) - f(y_i)]^T [f(x_j) - f(y_j)]
    #
    # f_of_X: batch_size * k
    # f_of_Y: batch_size * k
    delta = f_of_X - f_of_Y
    loss = torch.mean(torch.mm(delta, torch.transpose(delta, 0, 1)))
    return loss
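
# Usage sketch (not part of the original file): mmd_linear expects aligned source and target
# feature batches of the same shape (batch_size x k); the batch size and feature dimension
# below are illustrative assumptions.
def _example_mmd_linear():
    f_of_X = torch.randn(16, 256)  # hypothetical source features
    f_of_Y = torch.randn(16, 256)  # hypothetical target features
    return mmd_linear(f_of_X, f_of_Y)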

def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    # multi-bandwidth Gaussian (RBF) kernel matrix over the concatenated source/target batch
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    L2_distance = ((total0 - total1) ** 2).sum(2)

    # bandwidth: either fixed, or estimated from the mean pairwise squared distance
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
    return sum(kernel_val)

def mmd_rbf(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None, ver=2):
    batch_size = int(source.size()[0])
    kernels = guassian_kernel(source, target, kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

    loss = 0
    if ver == 1:
        # linear-time estimator over consecutive sample pairs
        for i in range(batch_size):
            s1, s2 = i, (i + 1) % batch_size
            t1, t2 = s1 + batch_size, s2 + batch_size
            loss += kernels[s1, s2] + kernels[t1, t2]
            loss -= kernels[s1, t2] + kernels[s2, t1]
        loss = loss.abs_() / float(batch_size)
    elif ver == 2:
        # full estimator over the four kernel blocks (source-source, target-target, cross terms)
        XX = kernels[:batch_size, :batch_size]
        YY = kernels[batch_size:, batch_size:]
        XY = kernels[:batch_size, batch_size:]
        YX = kernels[batch_size:, :batch_size]
        loss = torch.mean(XX + YY - XY - YX)
    else:
        raise ValueError('ver == 1 or 2')

    return loss
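
# Usage sketch (not part of the original file): mmd_rbf with the default multi-bandwidth
# Gaussian kernel; ver=2 uses the full block estimator, ver=1 the linear-time pairing.
# Equal batch sizes are expected; batch size and feature dimension are illustrative assumptions.
def _example_mmd_rbf():
    source = torch.randn(16, 256)  # hypothetical source features
    target = torch.randn(16, 256)  # hypothetical target features
    return mmd_rbf(source, target, kernel_num=5, ver=2)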

def JAN(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[2, 5], fix_sigma_list=[None, None], ver=2):
    batch_size = int(source_list[0].size()[0])
    layer_num = len(source_list)

    # joint kernel: element-wise product of the per-layer kernel matrices
    joint_kernels = None
    for i in range(layer_num):
        source = source_list[i]
        target = target_list[i]
        kernel_mul = kernel_muls[i]
        kernel_num = kernel_nums[i]
        fix_sigma = fix_sigma_list[i]
        kernels = guassian_kernel(source, target,
                                  kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
        if joint_kernels is not None:
            joint_kernels = joint_kernels * kernels
        else:
            joint_kernels = kernels

    loss = 0
    if ver == 1:
        for i in range(batch_size):
            s1, s2 = i, (i + 1) % batch_size
            t1, t2 = s1 + batch_size, s2 + batch_size
            loss += joint_kernels[s1, s2] + joint_kernels[t1, t2]
            loss -= joint_kernels[s1, t2] + joint_kernels[s2, t1]
        loss = loss.abs_() / float(batch_size)
    elif ver == 2:
        XX = joint_kernels[:batch_size, :batch_size]
        YY = joint_kernels[batch_size:, batch_size:]
        XY = joint_kernels[:batch_size, batch_size:]
        YX = joint_kernels[batch_size:, :batch_size]
        loss = torch.mean(XX + YY - XY - YX)
    else:
        raise ValueError('ver == 1 or 2')

    return loss
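
# Usage sketch (not part of the original file): JAN takes per-layer feature lists from source
# and target (here two layers) and multiplies their kernel matrices into a joint kernel before
# computing MMD. The layer shapes below are illustrative assumptions.
def _example_JAN():
    source_list = [torch.randn(16, 256), torch.randn(16, 12)]  # e.g. bottleneck features + logits
    target_list = [torch.randn(16, 256), torch.randn(16, 12)]
    return JAN(source_list, target_list, kernel_muls=[2.0, 2.0], kernel_nums=[2, 5],
               fix_sigma_list=[None, None], ver=2)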

################## feature-norm (AFN) loss #####################
def get_L2norm_loss_self_drivenHAFN(x):
    # Hard AFN: push the mean feature norm toward a fixed radius (read from the command-line options)
    global args
    args = parser.parse_args()
    radius = args.radius_hafn
    l = (x.norm(p=2, dim=1).mean() - radius) ** 2
    return args.weight_afn * l


def get_cls_lossHAFN(pred, gt):
    cls_loss = F.nll_loss(F.log_softmax(pred, dim=1), gt)
    return cls_loss


def get_cls_lossSAFN(pred, gt):
    cls_loss = F.nll_loss(F.log_softmax(pred, dim=1), gt)
    return cls_loss


def get_L2norm_loss_self_drivenSAFN(x):
    # Stepwise AFN: enlarge each sample's feature norm by a fixed step from its current (detached) value
    global args
    args = parser.parse_args()
    radius = x.norm(p=2, dim=1).detach()
    assert radius.requires_grad == False
    radius = radius + args.radius_safn
    l = ((x.norm(p=2, dim=1) - radius) ** 2).mean()
    return args.weight_afn * l
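
# Usage sketch (not part of the original file): the AFN losses combine a classification term
# with a feature-norm regularizer; radius_hafn / radius_safn / weight_afn are read from the
# command-line options in opts.py, so this sketch only runs when those options are provided.
# The feature shape, logits, and labels below are illustrative assumptions.
def _example_afn():
    feat = torch.randn(16, 256, requires_grad=True)  # hypothetical bottleneck features
    pred = torch.randn(16, 12)                       # hypothetical class logits
    gt = torch.randint(0, 12, (16,))                 # hypothetical labels
    return get_cls_lossSAFN(pred, gt) + get_L2norm_loss_self_drivenSAFN(feat)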