-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
23 lines (21 loc) · 1.15 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import trainer
import argparse
def main(param):
model = trainer.Trainer(load=param.load, emb_size=param.emb_size, num_epoch=param.num_epoch, nhead=param.nhead,
ffn_hid_dim=param.ffn_hid_dim, batch_size=param.batch_size, lr=param.learning_rate,
n_layers=param.n_layers, dropout=param.dropout, variation=param.variation)
model.train()
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Transformer Machine Translation')
parser.add_argument('--load', type=bool, default=False)
parser.add_argument('--num_epoch', type=int, default=15)
parser.add_argument('--nhead', type=int, default=8)
parser.add_argument('--emb_size', type=int, default=512)
parser.add_argument('--ffn_hid_dim', type=int, default=512)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--n_layers', type=int, default=6)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--variation', type=bool, default=False)
parser.add_argument('--learning_rate', type=float, default=5e-5)
args = parser.parse_args()
main(args)