-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtrain_seq2seq_lm.py
64 lines (54 loc) · 1.98 KB
/
train_seq2seq_lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import pytorch_lightning as pl
from models.seq2seq_lm import argparser
from models.seq2seq_lm.model import Model
from models.seq2seq_lm.data_module import DataModule
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
from models.seq2seq_lm.config import GPUS,ACCELERATOR
from copy import deepcopy
import torch
args = argparser.get_args()

if __name__ == "__main__":
    # Load the model from a checkpoint if one was given, otherwise build a fresh one.
    if args.from_checkpoint is None:
        model = Model()
    else:
        print('load from checkpoint')
        model = Model.load_from_checkpoint(args.from_checkpoint)

    # Serve the model as a flask API server instead of training/testing, then stop.
    if args.server:
        model.run_server()
        # `exit()` is an interactive helper injected by the `site` module;
        # raising SystemExit directly is what it does, without that dependency.
        raise SystemExit

    # Trainer configuration.
    trainer = pl.Trainer(
        gpus=GPUS,
        accelerator=ACCELERATOR,
        fast_dev_run=args.dev,
        precision=32,
        default_root_dir='.log_seq2seq_lm',
        max_epochs=args.epoch,
        callbacks=[
            # Stop training once dev_loss has not improved for `patience` checks.
            EarlyStopping(monitor='dev_loss', patience=5),
            # Track the best checkpoint by dev_loss and also save the last one.
            ModelCheckpoint(
                monitor='dev_loss',
                filename='{epoch}-{dev_loss:.2f}',
                save_last=True,
            ),
        ],
    )

    dm = DataModule()

    # Train (skipped when only testing is requested).
    if not args.run_test:
        # Run the batch-size search on a copy of the trainer so the trainer
        # used for fit() is presumably left untouched by the search.
        tuner = pl.tuner.tuning.Tuner(deepcopy(trainer))
        # torch.cuda.device_count() is 0 on CPU-only hosts and a batch size of
        # 0 is invalid, so always start the search from at least 1.
        new_batch_size = tuner.scale_batch_size(
            model,
            datamodule=dm,
            init_val=max(1, torch.cuda.device_count()),
        )
        del tuner
        # The scaler returns None when it is skipped (e.g. under fast_dev_run);
        # only overwrite the batch size when a value was actually found.
        if new_batch_size is not None:
            model.hparams.batch_size = new_batch_size
        trainer.fit(model, datamodule=dm)

    # Decide which checkpoint to test: prefer the best, fall back to the last.
    # Both are empty when no checkpoint was written in this process.
    checkpoint_callback = trainer.checkpoint_callback
    best_model_path = checkpoint_callback.best_model_path
    last_model_path = checkpoint_callback.last_model_path
    _use_model_path = best_model_path or last_model_path
    print('use checkpoint:', _use_model_path)

    # With a checkpoint path, Lightning restores the weights from it; without
    # one, test the in-memory model (fresh or loaded via from_checkpoint).
    # An empty string is not accepted as ckpt_path, so pass None instead.
    trainer.test(
        model=None if _use_model_path else model,
        datamodule=dm,
        ckpt_path=_use_model_path or None,
    )