-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinference.py
126 lines (99 loc) · 4.36 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import os
import sys
import json
import argparse
from tqdm import tqdm
from termcolor import colored, cprint
from utils import print_inference_header
import numpy as np
import pandas as pd
import torch
import modules.losses as losses_
import modules.metrics as metrics_
import modules.trainers as trainers_
from torch.utils.data import DataLoader
from modules.datasets import FinHerTiles
from modules.networks import instantiate_network
def _load_model_config(snapshot_path):
    """Return the ``'config'`` entry stored in a training checkpoint.

    Only the config is read here, so all tensors in the checkpoint are mapped
    to the CPU; without ``map_location`` a checkpoint saved from a GPU cannot
    even be opened for this lookup on a CPU-only machine.
    """
    return torch.load(snapshot_path, map_location='cpu')['config']


def _make_val_loaders(df, tile_dir, config):
    """Build the ``{'val': DataLoader}`` mapping the trainer predicts on.

    NOTE: if ``config['drop_last']`` is true, tiles in a final partial batch
    receive no prediction and end up as NaN after the left merge in ``main``.
    """
    data_set = FinHerTiles(df, tile_dir, t_size=config['tile_size'])
    return {
        'val': DataLoader(
            data_set,
            batch_size=config["batch_size"],
            shuffle=False,  # keep tile order deterministic
            num_workers=config["num_workers"],
            drop_last=config["drop_last"],
        ),
    }


def _build_trainer(model_dict, model_config, data_loaders):
    """Instantiate network, losses, metrics and an inference-mode Trainer."""
    model = instantiate_network(model_config)
    loss_dict = {}
    metric_dict = {}
    # one loss and one metric list per auxiliary output declared in the config
    for d_ in model_config['architecture']['args']['aux_outputs']:
        loss_dict[d_['name']] = getattr(losses_, d_['loss']['type'])(**d_['loss']['args'])
        metric_dict[d_['name']] = [getattr(metrics_, met) for met in d_['metrics']]
    return trainers_.Trainer(
        model, loss_dict, metric_dict,
        resume=model_dict['snapshot'],
        config=model_config,
        data_loaders=data_loaders,
        train_logger=None,
        inference=True,
    )


def _suffix_columns(result_dict, suffix):
    """Return a copy of ``result_dict`` with ``suffix`` appended to every key but 'id'.

    Built as a new dict so a renamed key can never overwrite another
    not-yet-renamed key (e.g. 'p' -> 'pX' when 'pX' also exists).

    Raises:
        ValueError: if ``result_dict`` has no 'id' key (needed for merging).
    """
    if 'id' not in result_dict:
        raise ValueError("prediction results have no 'id' column to merge on")
    return {
        (key if key == 'id' else '{}{}'.format(key, suffix)): values
        for key, values in result_dict.items()
    }


def main(config):
    """Run every configured model on the tiles of every FinHer slide under ``config['root']``.

    For each first-level sub-directory of ``config['root']`` whose name starts
    with 'FinHer', the tile table ``Normalized-<slide>-Tiles.csv`` is read,
    each model in ``config['models']`` predicts on the tiles in
    ``<slide>/Normalized-Tiles``, the predictions (columns suffixed with the
    model's ``'name'``) are left-merged onto the table by row id, and the
    result is written to ``<slide>/<slide>-<config['analysis']>.csv``.

    Args:
        config: dict with keys 'models' (list of dicts with 'snapshot' and
            'name'), 'root', 'tile_size', 'batch_size', 'num_workers',
            'drop_last' and 'analysis'.
    """
    import copy

    root = config["root"]
    # Checkpoint configs do not depend on the slide, so read each one once.
    models = [
        (model_dict, _load_model_config(model_dict['snapshot']))
        for model_dict in config['models']
    ]
    # first-level sub-directories named FinHer*, sorted for a reproducible order
    slide_names = sorted(
        name for name in os.listdir(root)
        if name.startswith('FinHer') and os.path.isdir(os.path.join(root, name))
    )
    for slide_name in tqdm(slide_names, leave=True):
        slide_dir = os.path.join(root, slide_name)
        tile_dir = os.path.join(slide_dir, 'Normalized-Tiles')
        tiles_csv = os.path.join(slide_dir, 'Normalized-{}-Tiles.csv'.format(slide_name))
        df = pd.read_csv(tiles_csv)
        # row key the predictions are merged back on; presumably echoed back
        # by FinHerTiles/Trainer as 'id' — TODO confirm
        df['id'] = df.index
        for model_dict, model_config in models:
            data_loaders = _make_val_loaders(df, tile_dir, config)
            # fresh copy per run, as the original reloaded the config each time
            trainer = _build_trainer(model_dict, copy.deepcopy(model_config), data_loaders)
            result_dict = trainer._predict(vrbs=False)
            # drop attention maps ('a1', 'a2') when the model returns them
            result_dict.pop("a1", None)
            result_dict.pop("a2", None)
            predictions = pd.DataFrame(_suffix_columns(result_dict, model_dict['name']))
            df = pd.merge(
                left=df,
                right=predictions,
                on='id',
                how='left',
                validate='one_to_one',
            )
        # --- save results to csv-file ---
        csv_name = '{}-{}.csv'.format(slide_name, config['analysis'])
        df.to_csv(os.path.join(slide_dir, csv_name), index=False)
# ------------------------- __main__-------------------------------------
if __name__ == '__main__':
    print(colored('----- Inference on tiles -----', 'blue', attrs=['bold']))
    parser = argparse.ArgumentParser(description='Inference script')
    parser.add_argument('-c', '--config', default=None, type=str,
                        help='json config file path (default: None)')
    args = parser.parse_args()
    # --- validate arguments --- #
    if not args.config:
        print(colored('Configuration file needs to be specified: --config', 'red'))
        sys.exit(1)
    # load config file; `with` closes the handle, and JSON is UTF-8 by spec
    with open(args.config, encoding='utf-8') as config_file:
        config = json.load(config_file)
    main(config)
    print(colored('\n--------------- Done ----------------', 'blue', attrs=['bold']))
# ---------------------- Done -----------------------------------------