-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfskd.py
61 lines (47 loc) · 2.67 KB
/
fskd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import torch
from torch_cluster import nearest
def normalize(v):
    """Scale a point set so its axis-aligned bounding-box diagonal has length 1.

    Args:
        v (np.ndarray): Points, shape (N, D). Converted with ``np.asarray``;
            the caller's array is never modified.

    Returns:
        np.ndarray: A new array equal to ``v`` divided by the length of its
        bounding-box diagonal, i.e. ``sqrt(sum((max - min) ** 2))`` over the
        per-axis extents. Floating input keeps its dtype; integer input gives a
        floating result. If the diagonal is 0 (a single point, or all points
        coincide), an unscaled copy is returned rather than NaN/inf values.
    """
    v = np.asarray(v)
    extent = v.max(axis=0) - v.min(axis=0)
    scale = float(np.sqrt(np.sum(extent ** 2)))
    if scale == 0.0:
        # Degenerate point set: there is no size to normalize by.
        return v.copy()
    # Out-of-place division: it does not mutate the caller's array, and it
    # works for integer input too.
    return v / scale
def fskd(verts_target, feat_target, feat_templates, annotation_templates, threshold_matching=0.05, sigma=0.01):
    """
    Few Shot Keypoint Detection (FSKD) algorithm.

    Keypoints annotated on a few template shapes are transferred to a target
    shape by nearest-neighbour matching in feature space. For each template,
    a transferred keypoint is kept only if it is annotated and the
    target -> template -> target round trip through the feature maps lands
    within ``threshold_matching`` of it. The surviving candidates from all
    templates are averaged with weights ``exp(-d / sigma)``, where ``d`` is
    that round-trip distance. The average is then snapped to the nearest
    target vertex.

    Args:
        verts_target (np.ndarray): The target shape vertices, shape (N, 3).
            Must be a NumPy array, because ``normalize`` calls
            ``torch.from_numpy`` on it. A normalized copy is used, so the
            caller's array is not modified.
        feat_target (torch.Tensor): The target shape features extracted by the
            trained network, shape (N, F).
        feat_templates (list[torch.Tensor]): The per-template features, each
            of shape (M_i, F).
        annotation_templates (list[np.ndarray]): For each template, an integer
            array of shape (K,) giving the template vertex index of each of the
            K keypoints. All templates must list the same K keypoints.
        threshold_matching (float): The maximum round-trip distance for a
            candidate to be kept. It is measured on the normalized target,
            whose bounding-box diagonal is 1.
        sigma (float): The scale of the exponential weighting ``exp(-d / sigma)``.

    Returns:
        np.ndarray: The target vertex indices of the detected keypoints, in
        keypoint order. Keypoints that no template could transfer are omitted.

    Raises:
        ValueError: If no templates are given, or if ``feat_templates`` and
            ``annotation_templates`` have different lengths.
    """
    num_source = len(feat_templates)
    if num_source == 0:
        raise ValueError("fskd needs at least one template")
    if len(annotation_templates) != num_source:
        raise ValueError(
            f"feat_templates has {num_source} entries but annotation_templates "
            f"has {len(annotation_templates)}"
        )

    # Copy before normalizing so the caller's vertex array is never rescaled
    # in place, whatever ``normalize`` does with its argument.
    verts_target = normalize(np.array(verts_target))

    candidates, cycle_dists, keep = [], [], []
    for feat_template, kp_template in zip(feat_templates, annotation_templates):
        # nearest(x, y) returns, for each row of x, the index of its nearest
        # row of y. ``.cpu()`` allows the features to live on the GPU.
        tem_to_tgt = nearest(feat_template, feat_target).cpu().numpy()  # template -> target vertex
        tgt_to_tem = nearest(feat_target, feat_template).cpu().numpy()  # target -> template vertex
        kp_template = np.asarray(kp_template)
        cand = tem_to_tgt[kp_template]  # (K,) target vertex proposed per keypoint
        # Round trip target -> template -> target, evaluated only at the
        # candidate vertices, which are the only ones the result uses.
        back = tem_to_tgt[tgt_to_tem[cand]]
        dist = np.linalg.norm(verts_target[back] - verts_target[cand], axis=1)
        candidates.append(cand)
        cycle_dists.append(dist)
        # NOTE(review): an annotation equal to 0 is falsy here and is treated
        # as "keypoint absent", so template vertex 0 can never serve as a
        # keypoint. Confirm that 0 is the intended sentinel.
        keep.append(np.logical_and(kp_template, dist < threshold_matching))

    candidates = np.vstack(candidates)    # (num_source, K)
    cycle_dists = np.vstack(cycle_dists)  # (num_source, K)
    keep = np.vstack(keep)                # (num_source, K), bool

    predicted_kpts = np.ones_like(annotation_templates[0]) * -1  # -1 = not detected
    kp_ids, mean_points = [], []
    for kp_id in range(len(predicted_kpts)):
        point, total = 0, 0
        for s in range(num_source):
            if keep[s, kp_id]:
                weight = np.exp(-cycle_dists[s, kp_id] / sigma)
                point += verts_target[candidates[s, kp_id]] * weight
                total += weight
        # total is 0 when no template survived (or every weight underflowed).
        if total:
            kp_ids.append(kp_id)
            mean_points.append(point / total)

    if kp_ids:
        # Snap every weighted mean to its nearest target vertex in one query.
        snapped = nearest(torch.tensor(np.stack(mean_points)), torch.tensor(verts_target))
        predicted_kpts[kp_ids] = snapped.cpu().numpy()
    return predicted_kpts[predicted_kpts != -1]