-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
155 lines (130 loc) · 4.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import torch
import os
from argparse import ArgumentParser
# local imports
from dataloader import build_dataloader
from utils import (
set_encoder_decoder,
create_model,
create_save_folders,
Model_Inference,
)
# *------------------------------------------------------------------ Arguments --------------------------------------------------------------------*
# Command-line interface for CAM-based segmentation inference. Defaults are
# tuned for the kits23 dataset; per-dataset values are noted in the comments.
parser = ArgumentParser()
parser.add_argument("--dataset_name", type=str, default="kits23")
parser.add_argument("--test_path", type=str, default="data/kits23/test/")
parser.add_argument(
    "--color",
    type=str,
    default="rgb",
    help="possible values: grey (greyscale) and color (rgb).",
)
parser.add_argument(
    "--amount_classes",
    type=int,
    default=4,  # 33 for opg, 17 for word, 2 for binary, 34 for cityscapes
    help="Only needed for multiple class segmentation. Number of classes + Background (e.g. 4 classes + Background = 5).",
)
# FIX: the original declared type=int with a tuple default, so supplying
# --resize on the command line produced a single int instead of the
# (height, width) pair the pipeline expects. nargs=2 makes the option
# usable from the CLI: --resize 512 512. The tuple default is unchanged.
parser.add_argument(
    "--resize",
    type=int,
    nargs=2,
    default=(512, 512),
    metavar=("HEIGHT", "WIDTH"),
    help="Target height and width, e.g. --resize 512 512.",
)
# opg multiclass: 560, 992
# word multiclass: 70, 102
# cityscapes: 512, 1024
# kits: 512, 512
# Model choice
parser.add_argument(
    "--model", type=str, default="vanilla_unet", help="possible values: vanilla_unet"
)
parser.add_argument(
    "--encoder_depth",
    type=int,
    default=4,
    help="possible values up to 5. Depth 1 = [1, 64] for the encoder. Depth 2 = [1, 64, 128]. And so on.",
)
# Seg-Grad CAM or Seg-HiRes-Grad CAM
parser.add_argument(
    "--cam",
    type=str,
    default="gradcam",
    help="Decide whether to use Seg-Grad CAM (gradcam) or Seg-HiRes-Grad CAM (hirescam).",
)
parser.add_argument(
    "--level",
    type=int,
    default=4,  # FIX: was the string "4"; argparse coerced it, but the plain int is the idiomatic default
    help="Decide which layer to use for CAM.",
)
parser.add_argument(
    "--px_set",
    type=str,
    default="class",
    help="Decide whether to use image, class, point or zero. Best results with class.",
)
# FIX: the original declared type=str with a tuple default, so a CLI-supplied
# value arrived as one raw string rather than the (x, y) pair used downstream.
parser.add_argument(
    "--px_set_point",
    type=int,
    nargs=2,
    default=(300, 300),
    metavar=("X", "Y"),
    help="X, y coordinates if px_set is chosen to be point.",
)
# Output data parameters
parser.add_argument("--result_folder", type=str, default="results/")
parser.add_argument("--extension", type=str, default=".pdf")  # .png
# Operating System
parser.add_argument("--gpu", type=str, default="0")
args = parser.parse_args()
# *--------------------------------------------------------------- Argument Parsing ----------------------------------------------------------------*
# Resize transform handed to the dataloader (height, width).
transform_resize = {"resize": args.resize}
model_name = f"{args.model}"
# Build the CAM-specific output folder. FIX: the original used two separate
# `if` branches and fell through to a NameError at create_model() on any
# unrecognized --cam value; fail early with a clear message instead. Both
# valid branches produced the same pattern, so they collapse into one f-string.
if args.cam in ("gradcam", "hirescam"):
    path_save_folder_grad_cam = f"{args.result_folder}{model_name}/visualizations/{args.dataset_name}/seg_grad_cam/{args.cam}_{args.px_set}/"
else:
    raise ValueError(
        f"Unknown --cam value: {args.cam!r} (expected 'gradcam' or 'hirescam')."
    )
path_load_model = f"{args.result_folder}{model_name}/models/{args.dataset_name}/model"
# Restrict which GPUs CUDA may see, then select cuda if any remain available.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"DEVICE CHOICE: {device}")
# NOTE(review): any amount_classes != 1 is treated as multiclass, although the
# parser comment says "2 for binary" — confirm against utils.set_encoder_decoder.
classes = "multiclass" if args.amount_classes != 1 else "binary"
# Encoder/decoder channel layouts derived from color mode, class count and depth.
encoder_channels, decoder_channels = set_encoder_decoder(
    color=args.color, amount_classes=args.amount_classes, depth=args.encoder_depth
)
# Instantiate the model; create_model may also adjust the save/load paths.
model, path_save_folder_grad_cam, path_load_model = create_model(
    modelname=model_name,
    encoder_channels=encoder_channels,
    decoder_channels=decoder_channels,
    path_save_folder_grad_cam=path_save_folder_grad_cam,
    path_load_model=path_load_model,
)
# Make sure the output directories exist before inference writes into them.
create_save_folders(path_save_folder_grad_cam=path_save_folder_grad_cam)
# *------------------------------------------------------------------ Main -----------------------------------------------------------------------*
if __name__ == "__main__":
    # Build the test-set dataloader: single worker, batch size 1, fixed order,
    # and filenames returned so outputs can be matched back to input images.
    loader = build_dataloader(
        dir_root=args.test_path,
        transform=dict(transform_resize),
        num_workers=1,
        batch_size=1,
        return_filenames=True,
        shuffle=False,
        dataset_color=args.color,
        classes=classes,
    )
    # Bundle the model, weights and data into the inference helper, then run
    # inference with CAM visualization enabled at the requested layer.
    inference_runner = Model_Inference(
        model=model,
        model_name=model_name,
        dataset_name=args.dataset_name,
        path_model_state_dict=path_load_model,
        device=device,
        dataloader=loader,
        amount_classes=args.amount_classes,
        cam_type=args.cam,
        pixel_set=args.px_set,
        pixel_set_point=args.px_set_point,
        path_save_folder_grad_cam=path_save_folder_grad_cam,
    )
    inference_runner.inference(cam=True, cam_level=args.level)