-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathconfig_ssl_upload.py
177 lines (160 loc) · 6.1 KB
/
config_ssl_upload.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""
Experiment configuration file
Extended from config file from original PANet Repository
"""
import os
import re
import glob
import itertools
import sacred
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
from platform import node
from datetime import datetime
from util.consts import IMG_SIZE
sacred.SETTINGS['CONFIG']['READ_ONLY_CONFIG'] = False
sacred.SETTINGS.CAPTURE_MODE = 'no'
ex = Experiment('mySSL')
ex.captured_out_filter = apply_backspaces_and_linefeeds
source_folders = ['.', './dataloaders', './models', './util']
sources_to_save = list(itertools.chain.from_iterable(
[glob.glob(f'{folder}/*.py') for folder in source_folders]))
for source_file in sources_to_save:
ex.add_source_file(source_file)
@ex.config
def cfg():
"""Default configurations"""
seed = 1234
gpu_id = 0
mode = 'train' # for now only allows 'train'
do_validation=False
num_workers = 4 # 0 for debugging.
dataset = 'CHAOST2_Superpix' # i.e. abdominal MRI
use_coco_init = True # initialize backbone with MS_COCO initialization. Anyway coco does not contain medical images
### Training
n_steps = 100100
batch_size = 1
lr_milestones = [ (ii + 1) * 1000 for ii in range(n_steps // 1000 - 1)]
lr_step_gamma = 0.95
ignore_label = 255
print_interval = 100
save_snapshot_every = 25000
max_iters_per_load = 1000 # epoch size, interval for reloading the dataset
epochs=1
scan_per_load = -1 # numbers of 3d scans per load for saving memory. If -1, load the entire dataset to the memory
which_aug = 'sabs_aug' # standard data augmentation with intensity and geometric transforms
input_size = (IMG_SIZE, IMG_SIZE)
min_fg_data='100' # when training with manual annotations, indicating number of foreground pixels in a single class single slice. This empirically stablizes the training process
label_sets = 0 # which group of labels taking as training (the rest are for testing)
curr_cls = "" # choose between rk, lk, spleen and liver
exclude_cls_list = [2, 3] # testing classes to be excluded in training. Set to [] if testing under setting 1
usealign = True # see vanilla PANet
use_wce = True
use_dinov2_loss = False
dice_loss = False
### Validation
z_margin = 0
eval_fold = 0 # which fold for 5 fold cross validation
support_idx=[-1] # indicating which scan is used as support in testing.
val_wsize=2 # L_H, L_W in testing
n_sup_part = 3 # number of chuncks in testing
use_clahe = False
use_slice_adapter = False
adapter_layers=3
debug=False
skip_no_organ_slices=True
# Network
modelname = 'dlfcn_res101' # resnet 101 backbone from torchvision fcn-deeplab
clsname = None #
reload_model_path = None # path for reloading a trained model (overrides ms-coco initialization)
proto_grid_size = 8 # L_H, L_W = (32, 32) / 8 = (4, 4) in training
feature_hw = [input_size[0]//8, input_size[0]//8] # feature map size, should couple this with backbone in future
lora = 0
use_3_slices=False
do_cca=False
use_edge_detector=False
finetune_on_support=False
sliding_window_confidence_segmentation=False
finetune_model_on_single_slice=False
online_finetuning=True
use_bbox=True # for SAM
use_points=True # for SAM
use_mask=False # for SAM
base_model="alpnet" # or "SAM"
# SSL
superpix_scale = 'MIDDLE' #MIDDLE/ LARGE
use_pos_enc=False
support_txt_file = None # path to a txt file containing support slices
augment_support_set=False
coarse_pred_only=False # for ProtoSAM
point_mode="both" # for ProtoSAM, choose: both, conf, centroid
use_neg_points=False
n_support=1 # num support images
protosam_sam_ver="sam_h" # or medsam
grad_accumulation_steps=1
ttt=False
reset_after_slice=True # for TTT, if to reset the model after finetuning on each slice
model = {
'align': usealign,
'dinov2_loss': use_dinov2_loss,
'use_coco_init': use_coco_init,
'which_model': modelname,
'cls_name': clsname,
'proto_grid_size' : proto_grid_size,
'feature_hw': feature_hw,
'reload_model_path': reload_model_path,
'lora': lora,
'use_slice_adapter': use_slice_adapter,
'adapter_layers': adapter_layers,
'debug': debug,
'use_pos_enc': use_pos_enc
}
task = {
'n_ways': 1,
'n_shots': 1,
'n_queries': 1,
'npart': n_sup_part
}
optim_type = 'sgd'
lr=1e-3
momentum=0.9
weight_decay=0.0005
optim = {
'lr': lr,
'momentum': momentum,
'weight_decay': weight_decay
}
exp_prefix = ''
exp_str = '_'.join(
[exp_prefix]
+ [dataset,]
+ [f'sets_{label_sets}_{task["n_shots"]}shot'])
path = {
'log_dir': './runs',
'SABS':{'data_dir': "./data/SABS/sabs_CT_normalized"
},
'SABS_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"
},
'SABS_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"
},
'C0':{'data_dir': "feed your dataset path here"
},
'CHAOST2':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized/"
},
'CHAOST2_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"
},
'SABS_Superpix':{'data_dir': "./data/SABS/sabs_CT_normalized"},
'C0_Superpix':{'data_dir': "feed your dataset path here"},
'CHAOST2_Superpix':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized/"},
'CHAOST2_Superpix_672':{'data_dir': "./data/CHAOST2/chaos_MR_T2_normalized_672/"},
'SABS_Superpix_448':{'data_dir': "./data/SABS/sabs_CT_normalized_448"},
'SABS_Superpix_672':{'data_dir': "./data/SABS/sabs_CT_normalized_672"},
}
@ex.config_hook
def add_observer(config, command_name, logger):
"""A hook fucntion to add observer"""
exp_name = f'{ex.path}_{config["exp_str"]}'
observer = FileStorageObserver.create(os.path.join(config['path']['log_dir'], exp_name))
ex.observers.append(observer)
return config