-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathoptions.py
38 lines (33 loc) · 2.38 KB
/
options.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import argparse
class TrainOptions():
def __init__(self):
self.parser = argparse.ArgumentParser()
# data loader related
self.parser.add_argument('--dataroot', type=str, default='/data1/timer/SuperFusion/dataset/train/MSRS', help='path of data')
self.parser.add_argument('--phase', type=str, default='train', help='phase for dataloading')
self.parser.add_argument('--batch_size', type=int, default=16, help='batch size')
self.parser.add_argument('--nThreads', type=int, default=8, help='# of threads for data loader')
# ouptput related
self.parser.add_argument('--name', type=str, default='FS_MSRS', help='folder name to save outputs')
self.parser.add_argument('--display_dir', type=str, default='./logs', help='path for saving display results')
self.parser.add_argument('--result_dir', type=str, default='./results', help='path for saving result images and models')
self.parser.add_argument('--display_freq', type=int, default=50, help='freq (iteration) of display')
self.parser.add_argument('--img_save_freq', type=int, default=50, help='freq (epoch) of saving images')
self.parser.add_argument('--model_save_freq', type=int, default=50, help='freq (epoch) of saving models')
self.parser.add_argument('--no_display_img', action='store_true', help='specified if no dispaly')
# training related
self.parser.add_argument('--lr_policy', type=str, default='lambda', help='type of learn rate decay')
self.parser.add_argument('--n_ep', type=int, default=2000, help='number of epochs') # 400 * d_iter
self.parser.add_argument('--n_ep_decay', type=int, default=1600, help='epoch start decay learning rate, set -1 if no decay') # 200 * d_iter
self.parser.add_argument('--resume', type=str, default='./checkpoint/MSRS.pth', help='specified the dir of saved models for resume the training')
self.parser.add_argument('--gpu', type=int, default=0, help='gpu')
self.parser.add_argument('--stage', type=str, default='FS', help='reg&fus (RF) or fus&seg (FS)')
#segmentation related
self.parser.add_argument('--dataroot_val', type=str, default='./dataset/test/MSRS/', help="data for segmentation validation")
def parse(self):
self.opt = self.parser.parse_args()
args = vars(self.opt)
print('\n--- load options ---')
for name, value in sorted(args.items()):
print('%s: %s' % (str(name), str(value)))
return self.opt