-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathloss.py
109 lines (73 loc) · 3.63 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""
PyTorch adaptation of https://omoindrot.github.io/triplet-loss
https://github.com/omoindrot/tensorflow-triplet-loss
"""
import torch
import torch.nn as nn
class TripletMarginLoss(nn.Module):
    """Triplet margin loss over a batch of embeddings with online triplet mining.

    Args:
        margin: Margin added to ``d(anchor, positive) - d(anchor, negative)``
            before hinging at zero.
        p: Norm degree forwarded to ``torch.cdist`` for pairwise distances.
        mining: Triplet selection strategy, ``"batch_all"`` or ``"batch_hard"``.

    Raises:
        ValueError: If ``mining`` is not one of the supported strategies.
    """

    def __init__(self, margin=1.0, p=2.0, mining="batch_all"):
        super().__init__()
        self.margin = margin
        self.p = p
        self.mining = mining
        # Fail fast on an unknown strategy. Otherwise ``loss_fn`` would never
        # be set and the mistake would only surface as an AttributeError the
        # first time forward() is called.
        if mining == "batch_all":
            self.loss_fn = batch_all_triplet_loss
        elif mining == "batch_hard":
            self.loss_fn = batch_hard_triplet_loss
        else:
            raise ValueError(
                f"Unknown mining strategy {mining!r}; "
                "expected 'batch_all' or 'batch_hard'"
            )

    def forward(self, embeddings, labels):
        """Compute the triplet loss for ``embeddings`` grouped by ``labels``.

        Returns:
            The 2-tuple produced by the selected mining function: the scalar
            loss, plus the fraction of positive triplets (``"batch_all"``) or
            ``-1`` (``"batch_hard"``).
        """
        return self.loss_fn(labels, embeddings, self.margin, self.p)
def batch_hard_triplet_loss(labels, embeddings, margin, p):
    """Batch-hard triplet loss: hardest positive and hardest negative per anchor.

    For every anchor, the farthest same-label sample and the closest
    different-label sample are selected. The hinge
    ``max(d_pos - d_neg + margin, 0)`` is then averaged over all anchors.

    Args:
        labels: Per-sample labels (presumably a 1-D tensor of class ids --
            TODO confirm against callers).
        embeddings: Embeddings passed to ``torch.cdist``; presumably
            ``(batch, dim)`` -- TODO confirm.
        margin: Margin added inside the hinge.
        p: Norm degree for ``torch.cdist``.

    Returns:
        ``(loss, -1)``: the scalar mean loss and a constant ``-1``, so that both
        mining functions return a 2-tuple.
    """
    pairwise_dist = torch.cdist(embeddings, embeddings, p=p)

    # Hardest positive: zero out non-positive pairs (distances are >= 0, so a
    # zeroed entry can never exceed a valid positive) and take the row max.
    mask_anchor_positive = _get_anchor_positive_triplet_mask(labels).float()
    anchor_positive_dist = mask_anchor_positive * pairwise_dist
    hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)

    # Hardest negative: shift invalid (same-label) entries up by the row max so
    # they can never fall below a valid negative, then take the row min.
    mask_anchor_negative = _get_anchor_negative_triplet_mask(labels).float()
    max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
    hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)

    # Hinge at zero with an out-of-place clamp rather than an in-place masked
    # write on an autograd-tracked tensor; values and subgradients are the same.
    triplet_loss = (hardest_positive_dist - hardest_negative_dist + margin).clamp(min=0)
    return triplet_loss.mean(), -1
def batch_all_triplet_loss(labels, embeddings, margin, p):
    """Batch-all triplet loss: mean over every valid triplet with positive loss.

    Builds the full ``[anchor, positive, negative]`` loss cube from pairwise
    distances and keeps only triplets allowed by ``_get_triplet_mask``. It
    hinges the cube at zero and averages over the triplets whose loss is still
    positive. Memory grows cubically with the batch size.

    Args:
        labels: Per-sample labels (presumably a 1-D tensor of class ids --
            TODO confirm against callers).
        embeddings: Embeddings passed to ``torch.cdist``; presumably
            ``(batch, dim)`` -- TODO confirm.
        margin: Margin added inside the hinge.
        p: Norm degree for ``torch.cdist``.

    Returns:
        ``(loss, fraction_positive_triplets)``: the scalar mean loss over
        positive triplets, and the fraction of valid triplets whose loss
        exceeds ``1e-16``.
    """
    pairwise_dist = torch.cdist(embeddings, embeddings, p=p)

    # triplet_loss[i, j, k] = d(i, j) - d(i, k) + margin via broadcasting.
    anchor_positive_dist = pairwise_dist.unsqueeze(2)
    anchor_negative_dist = pairwise_dist.unsqueeze(1)
    triplet_loss = anchor_positive_dist - anchor_negative_dist + margin

    # Zero out invalid triplets, then drop negative losses (easy triplets).
    mask = _get_triplet_mask(labels)
    triplet_loss = (mask.float() * triplet_loss).clamp(min=0)

    # Count positive triplets directly. Boolean indexing would copy every
    # selected element (and force a device sync) just to read its length.
    num_positive_triplets = (triplet_loss > 1e-16).sum()
    num_valid_triplets = mask.sum()
    fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)

    # Mean over positive valid triplets; the epsilon guards the all-easy case.
    triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
    return triplet_loss, fraction_positive_triplets
def _get_triplet_mask(labels):
# Check that i, j and k are distinct
indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
indices_not_equal = ~indices_equal
i_not_equal_j = indices_not_equal.unsqueeze(2)
i_not_equal_k = indices_not_equal.unsqueeze(1)
j_not_equal_k = indices_not_equal.unsqueeze(0)
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
i_equal_j = label_equal.unsqueeze(2)
i_equal_k = label_equal.unsqueeze(1)
valid_labels = ~i_equal_k & i_equal_j
return valid_labels & distinct_indices
def _get_anchor_positive_triplet_mask(labels):
# Check that i and j are distinct
indices_equal = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
indices_not_equal = ~indices_equal
# Check if labels[i] == labels[j]
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
return labels_equal & indices_not_equal
def _get_anchor_negative_triplet_mask(labels):
return labels.unsqueeze(0) != labels.unsqueeze(1)