-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathmain.py
137 lines (116 loc) · 5.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
"""
Command-line entry point: load <dir>/config.ini into a Conf object and run main_train or main_challenge (both imported from main_runner).
"""
import os
import argparse
import configparser
from main_runner import main_train, main_challenge
class Conf:
    """Typed view of an experiment's ``config.ini``.

    The constructor reads the ``[BASE]`` section. Each ``set_*_conf`` method
    then reads one further section (``[DAE]``, ``[PRETRAIN]``, ``[TITLE]`` or
    ``[CHALLENGE]``) and stores the converted values as instance attributes.
    Several methods write the same attribute names (``epochs``, ``batch``,
    ``lr``, ``save``, ``mode`` ...), so the method called last wins for those.

    Args:
        dir: Experiment directory; relative paths from the config
            (``initval``, ``save``, ``DAEval``) are joined onto it.
        ini: A parser object exposing ``get(section, option)`` that returns
            strings (``configparser.ConfigParser`` in the calling script).
    """

    # Spellings accepted for boolean options, compared case-insensitively.
    # An empty value counts as False, as it did with the former ``bool(str)``.
    _TRUE_WORDS = frozenset(('1', 'true', 'yes', 'on', 't', 'y'))
    _FALSE_WORDS = frozenset(('', '0', 'false', 'no', 'off', 'f', 'n'))

    def __init__(self, dir, ini):
        self.dir = dir
        self.ini = ini
        self.data_dir = self.ini.get('BASE', 'data_dir')
        self.result_dir = self.ini.get('BASE', 'result_dir')
        self.testsize = int(self.ini.get('BASE', 'testsize'))
        self.verbose = self._get_bool('BASE', 'verbose')

    def _get_bool(self, section, option):
        """Return option ``[section] option`` parsed as a real boolean.

        ``bool("False")`` is True, so the text has to be interpreted
        explicitly instead of being passed to ``bool``.

        Raises:
            ValueError: if the text is not a recognised boolean spelling.
        """
        raw = self.ini.get(section, option)
        word = raw.strip().lower()
        if word in self._TRUE_WORDS:
            return True
        if word in self._FALSE_WORDS:
            return False
        raise ValueError(
            "[{}] {}: cannot interpret {!r} as a boolean".format(section, option, raw))

    @staticmethod
    def _float_list(text):
        """Convert a comma-separated string into a list of floats."""
        return [float(item) for item in text.split(',')]

    @staticmethod
    def _seed_names(text):
        """Prefix every comma-separated item of ``text`` with ``'test-'``."""
        return ['test-' + item for item in text.split(',')]

    @staticmethod
    def _check_first_n(first_n):
        """Validate the parsed ``firstN_range`` list.

        Accepted shapes:
          * a single value, which must be ``-1``;
          * ``low, high`` with ``low <= high`` where either ``high < 1`` and
            ``low`` is ``0`` or non-integral, or ``high >= 1``, ``low >= 1``
            and both values are whole numbers.
        Values after the second are ignored.

        Raises:
            AssertionError: on an invalid range. The type matches the
                ``assert`` statements this replaces, but it is raised
                explicitly so the check still runs under ``python -O``.
        """
        def fail(reason):
            raise AssertionError(
                "invalid firstN_range {!r}: {}".format(first_n, reason))

        if len(first_n) == 1:
            if first_n[0] != -1.0:
                fail("a single value must be -1")
            return

        low, high = first_n[0], first_n[1]
        if not low <= high:
            fail("lower bound must not exceed upper bound")
        if high < 1:
            if not (low == 0 or not low.is_integer()):
                fail("with an upper bound below 1 the lower bound must be 0 or fractional")
        else:
            if not low >= 1:
                fail("with an upper bound of at least 1 the lower bound must be at least 1")
            if not (low.is_integer() and high.is_integer()):
                fail("bounds of at least 1 must be whole numbers")

    def set_dae_conf(self):
        """Load the ``[DAE]`` section and set ``mode`` to ``'dae'``.

        Raises:
            AssertionError: if ``firstN_range`` fails validation.
        """
        self.epochs = int(self.ini.get('DAE', 'epochs'))
        self.batch = int(self.ini.get('DAE', 'batch'))
        self.lr = float(self.ini.get('DAE', 'lr'))
        self.reg_lambda = float(self.ini.get('DAE', 'reg_lambda'))

        self.test_seed = self._seed_names(self.ini.get('DAE', 'test_seed'))
        # NOTE(review): update seeds also receive the 'test-' prefix, as in
        # the original code - confirm this is intended.
        self.update_seed = self._seed_names(self.ini.get('DAE', 'update_seed'))

        self.input_kp = self._float_list(self.ini.get('DAE', 'input_kp'))
        self.kp = float(self.ini.get('DAE', 'keep_prob'))

        self.firstN = self._float_list(self.ini.get('DAE', 'firstN_range'))
        self._check_first_n(self.firstN)

        self.initval = os.path.join(self.dir, self.ini.get('DAE', 'initval'))
        self.save = os.path.join(self.dir, self.ini.get('DAE', 'save'))
        self.hidden = int(self.ini.get('DAE', 'hidden'))
        self.mode = 'dae'

    def set_pretrain_conf(self):
        """Load the ``[PRETRAIN]`` section and set ``mode`` to ``'pretrain'``."""
        self.epochs = int(self.ini.get('PRETRAIN', 'epochs'))
        self.batch = int(self.ini.get('PRETRAIN', 'batch'))
        self.lr = float(self.ini.get('PRETRAIN', 'lr'))
        self.reg_lambda = float(self.ini.get('PRETRAIN', 'reg_lambda'))
        self.is_pretrain = True
        self.save = os.path.join(self.dir, self.ini.get('PRETRAIN', 'save'))
        self.mode = 'pretrain'

    def set_title_conf(self):
        """Load the ``[TITLE]`` section and set ``mode`` to ``'title'``.

        Also makes sure the parent directory of ``save`` exists.

        Raises:
            ValueError: if ``bi`` (read for ``Char_LSTM`` only) is not a
                recognised boolean spelling.
        """
        self.epochs = int(self.ini.get('TITLE', 'epochs'))
        self.batch = int(self.ini.get('TITLE', 'batch'))
        self.lr = float(self.ini.get('TITLE', 'lr'))

        self.input_kp = self._float_list(self.ini.get('TITLE', 'input_kp'))
        # Stored as the raw config string, exactly as before.
        # NOTE(review): confirm consumers do not expect a float here.
        self.title_kp = self.ini.get('TITLE', 'title_kp')

        self.test_seed = self._seed_names(self.ini.get('TITLE', 'test_seed'))
        self.update_seed = self._seed_names(self.ini.get('TITLE', 'update_seed'))

        self.char_emb = int(self.ini.get('TITLE', 'char_emb'))
        self.char_model = self.ini.get('TITLE', 'char_model')

        # Only the options of the selected character model are read; any
        # other char_model value loads no model-specific options.
        if self.char_model == 'Char_CNN':
            self.filter_num = int(self.ini.get('TITLE', 'filter_num'))
            filter_size = self.ini.get('TITLE', 'filter_size')
            self.filter_size = [int(item) for item in filter_size.split(',')]
        elif self.char_model == 'Char_LSTM':
            self.rnn_hidden = int(self.ini.get('TITLE', 'rnn_hidden'))
            self.bi = self._get_bool('TITLE', 'bi')

        self.DAEval = os.path.join(self.dir, self.ini.get('TITLE', 'DAEval'))
        self.save = os.path.join(self.dir, self.ini.get('TITLE', 'save'))

        # Create the parent of `save` directly instead of creating `save`
        # as a directory and removing it again.
        parent = os.path.dirname(os.path.normpath(self.save))
        if parent:
            os.makedirs(parent, exist_ok=True)
        self.mode = 'title'

    def set_challenge_oonf(self):
        """Load the ``[CHALLENGE]`` section and create ``result_dir``.

        The misspelled name is kept because existing callers use it;
        ``set_challenge_conf`` is an alias with the intended spelling.
        """
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(self.result_dir, exist_ok=True)
        self.challenge_data = self.ini.get('CHALLENGE', 'challenge_data')
        self.result = os.path.join(self.result_dir, self.ini.get('CHALLENGE', 'result'))
        self.batch = int(self.ini.get('CHALLENGE', 'batch'))

    # Correctly spelled alias for the method above.
    set_challenge_conf = set_challenge_oonf
if __name__ == '__main__':
    # Command-line entry point: parse the flags, load <dir>/config.ini into a
    # Conf object, then run the step selected by the first matching mode flag.
    parser = argparse.ArgumentParser(description="args")
    parser.add_argument('--dir', type=str, default='qwerty', help="directory name which contains config file")
    parser.add_argument('--pretrain', action='store_true', default=False, help="pretrain mode if Specified")
    parser.add_argument('--dae', action='store_true', default=False, help="DAE mode if Specified")
    parser.add_argument('--title', action='store_true', default=False, help="title mode if Specified")
    parser.add_argument('--challenge', action='store_true', default=False, help="challenge mode if Specified")
    parser.add_argument('--testmode', action='store_true', default=False, help="test mode if Specified(just check the result)")
    args = parser.parse_args()

    # Without a mode flag the script used to load the config and then quit
    # silently; report the usage error up front instead.
    if not (args.pretrain or args.dae or args.title or args.challenge):
        parser.error("no mode selected: pass one of --pretrain, --dae, --title, --challenge")

    conf_dir = os.path.join(".", args.dir)
    if not os.path.isdir(conf_dir):
        print("ERROR: Cannot find "+conf_dir+" ->Create directory and config.ini file first")
        # Non-zero status so shells and schedulers see the failure.
        raise SystemExit(1)

    ini_path = os.path.join(conf_dir, 'config.ini')
    if not os.path.isfile(ini_path):
        print("ERROR: Cannot find config.ini in " +conf_dir+" ->Create config.ini file in the directory first")
        raise SystemExit(1)

    ini = configparser.ConfigParser()
    ini.read(ini_path)

    conf = Conf(conf_dir, ini)
    # The [DAE] settings are loaded for every mode; the mode-specific setters
    # below then overwrite the attributes they share with it.
    # NOTE(review): presumably the other modes depend on DAE-only fields - confirm.
    conf.set_dae_conf()

    # When several mode flags are given, the first match in this order wins.
    if args.pretrain:
        conf.set_pretrain_conf()
        main_train.run(conf, args.testmode)
    elif args.dae:
        # set_dae_conf() has already been applied above.
        main_train.run(conf, args.testmode)
    elif args.title:
        conf.set_title_conf()
        main_train.run(conf, args.testmode)
    elif args.challenge:
        conf.set_title_conf()
        conf.set_challenge_oonf()
        main_challenge.run(conf)