-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathcommon.py
84 lines (73 loc) · 2.77 KB
/
common.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from copy import deepcopy
import numpy as np
import skimage.io as io
import torch
import torchvision
from silk.backbones.silk.silk import SiLKVGG as SiLK
from silk.backbones.superpoint.vgg import ParametricVGG
from silk.config.model import load_model_from_checkpoint
from silk.models.silk import matcher
# Alternative checkpoint kept for reference:
# CHECKPOINT_PATH = os.path.join(os.path.dirname(__file__), "../../assets/models/silk/analysis/alpha/pvgg-4.ckpt")
# Pretrained SiLK checkpoint, resolved relative to this file.
CHECKPOINT_PATH = os.path.join(
    os.path.dirname(__file__), "../../assets/models/silk/coco-rgb-aug.ckpt"
)
# Torch device string used for image loading and inference.
DEVICE = "cuda:0"
SILK_NMS = 0  # NMS radius, 0 = disabled
SILK_BORDER = 0  # remove detection on border, 0 = disabled
SILK_THRESHOLD = 1.0  # keypoint score thresholding, if # of keypoints is less than provided top-k, then will add keypoints to reach top-k value, 1.0 = disabled
SILK_TOP_K = 10000  # minimum number of best keypoints to output, could be higher if threshold specified above has low value
SILK_DEFAULT_OUTPUT = (  # outputs required when running the model
    "dense_positions",
    "normalized_descriptors",
    "probability",
)
SILK_SCALE_FACTOR = 1.41  # scaling of descriptor output, do not change
# Backbone config: VGG variant without max pooling, zero padding, and
# per-stage BatchNorm matching the four stage widths.
SILK_BACKBONE = ParametricVGG(
    use_max_pooling=False,
    padding=0,
    normalization_fn=[torch.nn.BatchNorm2d(i) for i in (64, 64, 128, 128)],
)
# Descriptor matcher; alternative post-processing options kept for reference.
SILK_MATCHER = matcher(postprocessing="ratio-test", threshold=0.6)
# SILK_MATCHER = matcher(postprocessing="double-softmax", threshold=0.6, temperature=0.1)
# SILK_MATCHER = matcher(postprocessing="none")
def load_images(*paths, as_gray=True, device=DEVICE):
    """Read image files and stack them into a float32 tensor in [0, 1].

    Args:
        *paths: Image file paths. All images must share the same spatial
            size, since they are stacked along a new batch dimension.
        as_gray: If True, load as grayscale (scikit-image returns floats
            already in [0, 1]); if False, load as RGB uint8 and rescale
            to [0, 1] here.
        device: Torch device string for the returned tensor. Defaults to
            the module-level DEVICE, preserving the previous behavior.

    Returns:
        Tensor of shape (N, 1, H, W) when as_gray is True, otherwise
        (N, C, H, W), dtype float32, values in [0, 1].
    """
    images = np.stack([io.imread(path, as_gray=as_gray) for path in paths])
    images = torch.tensor(images, device=device, dtype=torch.float32)
    if as_gray:
        # Insert a singleton channel dimension: (N, H, W) -> (N, 1, H, W).
        images = images.unsqueeze(1)
    else:
        # HWC -> CHW, then rescale uint8 values to [0, 1].
        images = images.permute(0, 3, 1, 2)
        images = images / 255.0
    return images
def get_model(
    checkpoint=CHECKPOINT_PATH,
    nms=SILK_NMS,
    device=DEVICE,
    default_outputs=SILK_DEFAULT_OUTPUT,
):
    """Construct a SiLK model and load pretrained weights.

    Args:
        checkpoint: Path to the checkpoint file to load.
        nms: NMS radius passed to the detector (0 disables NMS).
        device: Torch device string the model is loaded onto.
        default_outputs: Output names the model produces when called.

    Returns:
        The frozen model in eval mode, with weights loaded from `checkpoint`.
    """
    # Deep-copy the shared backbone so module state (e.g. BatchNorm buffers)
    # is not shared between independently constructed models.
    backbone = deepcopy(SILK_BACKBONE)
    model = SiLK(
        in_channels=1,
        backbone=backbone,
        detection_threshold=SILK_THRESHOLD,
        detection_top_k=SILK_TOP_K,
        nms_dist=nms,
        border_dist=SILK_BORDER,
        default_outputs=default_outputs,
        descriptor_scale_factor=SILK_SCALE_FACTOR,
        padding=0,
    )

    def _strip_checkpoint_prefix(state_dict):
        # Checkpoint keys carry a "_mods.model." prefix; drop it so they
        # match the model's own parameter names.
        skip = len("_mods.model.")
        return {name[skip:]: tensor for name, tensor in state_dict.items()}

    return load_model_from_checkpoint(
        model,
        checkpoint_path=checkpoint,
        state_dict_fn=_strip_checkpoint_prefix,
        device=device,
        freeze=True,
        eval=True,
    )