-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
76 lines (58 loc) · 2.62 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
class NCESoftmaxLoss(nn.Module):
    """Contrastive (softmax / InfoNCE-style) loss over point-to-point correspondences.

    For each batch element, every valid correspondence ``features_2[j] <-> features_1[map21[j]]``
    is a positive pair. All other sampled pairs in the same element act as negatives. The
    similarity is the negative Euclidean distance divided by the temperature ``nce_t``.

    Args:
        nce_t (float): Temperature dividing the logits.
        nce_num_pairs (int): Maximum number of correspondences sampled per batch element.
    """

    def __init__(self, nce_t=0.07, nce_num_pairs=1024):
        super().__init__()
        self.nce_t = nce_t
        self.nce_num_pairs = nce_num_pairs
        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, features_1, features_2, map21):
        """Computes the NCE loss between two sets of features.

        Args:
            features_1 (torch.Tensor): The first set of features. Shape (B, N1, F).
            features_2 (torch.Tensor): The second set of features. Shape (B, N2, F).
            map21 (torch.Tensor): p2p correspondences from set 2 into set 1, i.e. an index
                into N1 for each point of set 2, with -1 marking unmatched points.
                Shape (B, N2), or any shape with B * N2 elements (e.g. flattened).

        Returns:
            torch.Tensor: Scalar loss. It is the sum (not the mean) over batch elements of
            the per-element cross-entropy. Elements with no valid correspondence contribute 0.
        """
        batch_size = features_1.size(0)
        # map21 may arrive flattened; restore one row of correspondences per batch element.
        map21 = map21.reshape(batch_size, -1)
        loss = features_1.new_zeros(())
        for map_21, feat1, feat2 in zip(map21, features_1, features_2):
            mask = map_21 != -1
            map_21_masked = map_21[mask]
            num_valid = map_21_masked.shape[0]
            if num_valid == 0:
                # Cross-entropy over an empty logits matrix is NaN; skip such elements.
                continue
            keys = feat2[mask]
            if num_valid > self.nce_num_pairs:
                # Subsample without replacement. The torch RNG keeps this reproducible
                # under torch.manual_seed.
                selected = torch.randperm(num_valid, device=map_21.device)[: self.nce_num_pairs]
                map_21_masked = map_21_masked[selected]
                keys = keys[selected]
            query = feat1[map_21_masked]
            # Row k of the logits scores query k against every key; the positive is key k.
            logits = -torch.cdist(query, keys) / self.nce_t
            labels = torch.arange(query.shape[0], device=feat1.device)
            loss = loss + self.cross_entropy(logits, labels)
        return loss
class LIELoss(nn.Module):
    """Position loss on soft feature-based correspondences.

    Each point of the second set is softly matched to the first set by a softmax over
    negative feature distances. The resulting soft-averaged xyz positions are compared with
    the xyz positions of the ground-truth matched points in the first set.
    """

    def __init__(self):
        super().__init__()

    def forward(self, xyz_1, features_1, features_2, map21):
        """Computes the LIE loss between two sets of features.

        Args:
            xyz_1 (torch.Tensor): The xyz coordinates of the first set. Shape (B, N1, 3).
            features_1 (torch.Tensor): The first set of features. Shape (B, N1, F).
            features_2 (torch.Tensor): The second set of features. Shape (B, N2, F).
            map21 (torch.Tensor): p2p correspondences from set 2 into set 1, i.e. an index
                into N1 for each point of set 2, with -1 marking unmatched points.
                Shape (B, N2), or any shape with B * N2 elements (e.g. flattened).

        Returns:
            torch.Tensor: Scalar loss. It is the squared position error summed over all
            valid correspondences, then divided by the batch size.
        """
        batch_size = features_1.size(0)
        # Accept flattened correspondences, matching NCESoftmaxLoss.
        map21 = map21.reshape(batch_size, -1)
        # S21[b, j, k] = softmax over k of -||features_2[b, j] - features_1[b, k]||.
        S21 = F.softmax(-torch.cdist(features_2, features_1), dim=-1)
        # Soft position in set 1 for every point of set 2: (B, N2, 3).
        xyz21 = torch.bmm(S21, xyz_1)
        loss = xyz21.new_zeros(())
        for soft_xyz, target_xyz, corr in zip(xyz21, xyz_1, map21):
            valid = corr != -1
            loss = loss + (soft_xyz[valid] - target_xyz[corr[valid]]).square().sum()
        return loss / batch_size