-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun.py
61 lines (46 loc) · 1.75 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import cfgs.config as config
import argparse, yaml
import random
from easydict import EasyDict as edict
def parse_args(argv=None):
    '''
    Parse the command-line options for a run.

    Args:
        argv: Optional list of argument strings to parse. When None
            (the default) argparse reads the process arguments from
            sys.argv[1:], which is what the original no-argument call did.

    Returns:
        argparse.Namespace with the attributes run_mode, mode, dataset,
        gpu, seed and version.

    Raises:
        SystemExit: raised by argparse when --run is missing or an option
            value is not among its allowed choices.
    '''
    parser = argparse.ArgumentParser(description='Bilinear Args')

    # NOTE(review): '' is accepted as a run mode although the help text
    # does not list it -- confirm whether that is intentional.
    parser.add_argument('--run', dest='run_mode',
                        choices=['train', 'val', 'test', ''],
                        help='{train, val, test}',
                        type=str, required=True)

    parser.add_argument('--mode', dest='mode',
                        choices=['maa'],
                        help='{maa, ...}',
                        default='maa', type=str)

    parser.add_argument('--dataset', dest='dataset',
                        choices=['imdb', 'yelp_13', 'yelp_14'],
                        help='{imdb, yelp_13, yelp_14}',
                        default='imdb', type=str)

    parser.add_argument('--gpu', dest='gpu',
                        help="gpu select, eg.'0, 1, 2'",
                        type=str,
                        default="0,1")

    # The default seed is drawn once, when the parser is built, so a run
    # without --seed gets a fresh random integer in [0, 99999999].
    parser.add_argument('--seed', dest='seed',
                        help='fix random seed',
                        type=int,
                        default=random.randint(0, 99999999))

    parser.add_argument('--version', dest='version',
                        help='version control',
                        type=str,
                        default="default")

    return parser.parse_args(argv)
if __name__ == '__main__':
    # Configuration object exposed by the cfgs.config module.
    cfg = config.__C

    # Wrap a copy of the parsed command-line options in an EasyDict and
    # hand it to the config module together with the configuration.
    cli_options = edict(dict(vars(parse_args())))
    config.add_edit(cli_options, cfg)
    config.proc(cfg)

    print('Hyper Parameters:')
    config.config_print(cfg)

    # The trainer is imported only after the configuration has been
    # processed, keeping the original statement order.
    # NOTE(review): presumably config.proc sets up state (e.g. GPU
    # selection) that must exist before this import -- confirm.
    from common.trainer_maa import MAATrainer

    trainer = MAATrainer(cfg)
    trainer.run(cfg.run_mode)