-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathextract_latent_features.py
66 lines (54 loc) · 2.5 KB
/
extract_latent_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
from argparse import ArgumentParser
from os.path import join, split
import numpy as np
import torch
from deep_utils import DirUtils
import dataset
from utils import load_model
def get_features(model_path: str, dataset_path: str, device: str, threshold: float = 0.5):
    """Extract flattened latent features and binary predictions for every sample of a dataset.

    :param model_path: path to the model, passed to ``load_model``.
    :param dataset_path: path to the numpy preprocessed dataset, passed to ``dataset.CTDataset3D``.
    :param device: device on which to run the model (anything accepted by ``torch.device``,
        e.g. ``"cuda"`` or ``"cpu"``).
    :param threshold: a sample is predicted as class 1 when the network's first output
        ``logits[0, 0]`` is strictly greater than this value, otherwise class 0.
    :return: tuple ``(features_names, features, labels, predicted_lbl)``:
        ``features_names`` - list of image names yielded by the dataset;
        ``features`` - ``np.array`` built from one flattened feature vector per sample
        (2-D when all vectors share the same length);
        ``labels`` - list of labels exactly as yielded by the dataset;
        ``predicted_lbl`` - list of 0/1 ints, one per sample.
    """
    device = torch.device(device)
    # No augmentation transform is passed, so features come from the samples as stored.
    ct_dataset = dataset.CTDataset3D(dataset_path, augm_transform=None, get_img_names=True)
    net = load_model(model_path=model_path, device=device).eval()

    features_names = []
    features_list = []
    labels = []
    predicted_lbl = []
    with torch.no_grad():
        for img, lbl, img_name in ct_dataset:
            # Add a batch dimension and move to the device once; both calls below reuse it.
            batch = img[None, ...].to(device)
            features = net.features(batch)
            # NOTE(review): the raw first output is compared against `threshold` below.
            # If the network returns un-normalized logits (no sigmoid inside the model),
            # 0.5 is not a probability threshold - confirm against the model definition.
            logits = net(batch)
            # Drop the batch dimension and flatten into a 1-D feature vector.
            features_list.append(features[0].cpu().numpy().reshape(-1))
            features_names.append(img_name)
            labels.append(lbl)
            predicted_lbl.append(1 if logits[0, 0].item() > threshold else 0)
    return features_names, np.array(features_list), labels, predicted_lbl
def main():
    """Command-line entry point.

    For each dataset path given on the command line, extract latent features with
    ``get_features`` and save names, features, labels and predicted labels into a
    compressed ``.npz`` file inside ``--output_path``.
    """
    parser = ArgumentParser(description="Extract latent features from a trained model.")
    parser.add_argument("--model_path", required=True, help="Path to model", type=str)
    parser.add_argument("--suffix", default="_features", help="suffix to save the extracted features", type=str)
    # required=True: without it, omitting the flag leaves None and the loop below raises TypeError.
    parser.add_argument("--dataset_path", nargs="+", required=True, help="Path to numpy preprocessed dataset")
    parser.add_argument("--output_path", default="latent_features",
                        help="Directory to put the outputs",
                        )
    parser.add_argument("--device", default="cuda", help="Device to run the model on")
    parser.add_argument("--threshold", default=0.5, type=float,
                        help="Threshold applied to the model's first output to produce predicted labels")
    args = parser.parse_args()

    os.makedirs(args.output_path, exist_ok=True)
    for path in args.dataset_path:
        if "cuda" in args.device:
            # get_features reloads the model for every path; release cached GPU memory first.
            torch.cuda.empty_cache()
        features_names, features_list, labels, predicted_lbl = get_features(
            args.model_path, path, args.device, threshold=args.threshold
        )
        # normpath strips a trailing separator so split() yields the real base name, not "".
        base_name = split(os.path.normpath(path))[-1]
        npz_path = join(args.output_path, DirUtils.split_extension(base_name, suffix=args.suffix))
        # save the final extracted features!
        np.savez_compressed(npz_path,
                            names=features_names,
                            features=features_list,
                            labels=labels,
                            predicted_lbl=predicted_lbl)


if __name__ == "__main__":
    main()