-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
Copy pathmodel_checkpoint.py
181 lines (153 loc) · 7 KB
/
model_checkpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import os
import shutil
import logging as log
import warnings
import numpy as np
from .base import Callback
class ModelCheckpoint(Callback):
    r"""
    Save checkpoints at the end of validation, optionally keeping only the
    best ``save_top_k`` of them according to a monitored metric.

    Args:
        filepath (str): directory in which checkpoints are written. It is
            created on construction if it does not exist. It is used as a plain
            directory path (no format fields are substituted). Each checkpoint
            file is named ``{filepath}/{prefix}_ckpt_epoch_{epoch}.ckpt``. If a
            file with that name already exists (e.g. the callback ran more than
            once in the same epoch), a version suffix is appended instead:
            ``..._ckpt_epoch_{epoch}_v0.ckpt``, ``..._v1.ckpt``, and so on.
        monitor (str): key looked up in ``trainer.callback_metrics`` to rank
            checkpoints.
        verbose (bool): if True, log a message whenever a checkpoint is saved
            or skipped.
        save_top_k (int): if ``save_top_k == k``, the best k models according
            to the monitored quantity are kept; older, worse checkpoints written
            by this callback are deleted.
            If ``save_top_k == 0``, no models are saved.
            If ``save_top_k == -1``, every checkpoint is saved and the monitored
            value is ignored.
            Note that the monitor is only checked every ``period`` calls.
        mode (str): one of {auto, min, max}.
            With ``min`` a lower monitored value is better (e.g. ``val_loss``).
            With ``max`` a higher value is better (e.g. ``val_acc``).
            In ``auto`` mode, ``max`` is chosen when the monitored name contains
            ``'acc'`` or starts with ``'fmeasure'``; otherwise ``min`` is used.
            Unknown values emit a ``RuntimeWarning`` and fall back to ``auto``.
        save_weights_only (bool): stored as ``self.save_weights_only``. This
            class does not read it itself; the actual write is delegated to
            ``self.save_function`` (presumably configured by the trainer
            according to this flag — TODO confirm).
        period (int): number of ``on_validation_end`` calls between checkpoint
            attempts (one per epoch when validation runs once per epoch).
        prefix (str): string prepended to every checkpoint file name.

    Example::

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import ModelCheckpoint

        # saves checkpoints into my_path whenever 'val_loss' has a new min
        checkpoint_callback = ModelCheckpoint(filepath='my_path')
        Trainer(checkpoint_callback=checkpoint_callback)
    """

    def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
                 save_top_k: int = 1, save_weights_only: bool = False,
                 mode: str = 'auto', period: int = 1, prefix: str = ''):
        super().__init__()
        if save_top_k and os.path.isdir(filepath) and os.listdir(filepath):
            # NOTE(review): this class itself only ever deletes checkpoints it
            # recorded in best_k_models (see _do_check_save); the broader claim
            # in this message may refer to behaviour elsewhere — confirm.
            warnings.warn(
                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0. "
                "All files in this directory will be deleted when a checkpoint is saved!"
            )
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = filepath
        os.makedirs(filepath, exist_ok=True)
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
        # counts on_validation_end calls since the last checkpoint attempt
        self.epochs_since_last_check = 0
        self.prefix = prefix
        # {checkpoint filepath: monitored value} for the currently kept top-k models
        self.best_k_models = {}
        # path of the worst checkpoint among the kept top-k (the one evicted next)
        self.kth_best_model = ''
        # best monitored value among kept checkpoints; refreshed in _do_check_save
        self.best = 0
        # callable(filepath) that performs the actual write; must be assigned
        # before a checkpoint is saved, otherwise _save_model raises ValueError
        self.save_function = None

        # np.inf (lowercase) is used because the np.Inf alias was removed in NumPy 2.0.
        mode_dict = {
            'min': (np.less, np.inf, 'min'),
            'max': (np.greater, -np.inf, 'max'),
            'auto': (np.greater, -np.inf, 'max')
            if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
            else (np.less, np.inf, 'min'),
        }

        if mode not in mode_dict:
            warnings.warn(
                f'ModelCheckpoint mode {mode} is unknown, '
                'fallback to auto mode.', RuntimeWarning)
            mode = 'auto'

        # monitor_op(a, b) is True when a is strictly better than b
        self.monitor_op, self.kth_value, self.mode = mode_dict[mode]

    def _del_model(self, filepath):
        """Delete a checkpoint, whether it was written as a directory or a file.

        ``shutil.rmtree`` is tried first; any ``OSError`` (e.g. the path is a
        regular file) falls back to ``os.remove``. An error from ``os.remove``
        (e.g. the path no longer exists) propagates to the caller.
        """
        try:
            shutil.rmtree(filepath)
        except OSError:
            os.remove(filepath)

    def _save_model(self, filepath):
        """Create the parent directory of ``filepath`` and delegate the write.

        Raises:
            ValueError: if ``self.save_function`` has not been assigned.
        """
        # make paths; os.makedirs('') would raise, so skip when there is no dir part
        dirname = os.path.dirname(filepath)
        if dirname:
            os.makedirs(dirname, exist_ok=True)

        # delegate the saving to the model
        if self.save_function is not None:
            self.save_function(filepath)
        else:
            raise ValueError(".save_function() not set")

    def check_monitor_top_k(self, current):
        """Return True if a checkpoint with monitored value ``current`` belongs in the top k.

        While fewer than ``save_top_k`` checkpoints are kept, every value
        qualifies. Otherwise ``current`` must be strictly better (per
        ``monitor_op``) than the value of the current k-th best checkpoint.
        """
        less_than_k_models = len(self.best_k_models) < self.save_top_k
        if less_than_k_models:
            return True
        return self.monitor_op(current, self.best_k_models[self.kth_best_model])

    def on_validation_end(self):
        """Possibly save a checkpoint after a validation run.

        Reads ``self.trainer.callback_metrics`` and ``self.trainer.current_epoch``
        (``self.trainer`` is assumed to be attached by the framework before this
        hook runs — it is not set in this class). A checkpoint is only
        considered every ``period`` calls.
        """
        logs = self.trainer.callback_metrics
        epoch = self.trainer.current_epoch
        self.epochs_since_last_check += 1

        if self.save_top_k == 0:
            # no models are saved
            return
        if self.epochs_since_last_check < self.period:
            return
        self.epochs_since_last_check = 0

        filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}.ckpt'
        version_cnt = 0
        while os.path.isfile(filepath):
            # this epoch called before: pick the next free versioned name
            filepath = f'{self.filepath}/{self.prefix}_ckpt_epoch_{epoch}_v{version_cnt}.ckpt'
            version_cnt += 1

        if self.save_top_k == -1:
            # keep everything; the monitored value is not consulted
            if self.verbose > 0:
                log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
            self._save_model(filepath)
            return

        current = logs.get(self.monitor)
        if current is None:
            warnings.warn(
                f'Can save best model only with {self.monitor} available,'
                ' skipping.', RuntimeWarning)
        elif self.check_monitor_top_k(current):
            self._do_check_save(filepath, current, epoch)
        elif self.verbose > 0:
            log.info(
                f'\nEpoch {epoch:05d}: {self.monitor}'
                f' was not in top {self.save_top_k}')

    def _do_check_save(self, filepath, current, epoch):
        """Record and save a new top-k checkpoint, evicting the k-th best if full.

        Updates ``best_k_models`` and, once it holds ``save_top_k`` entries,
        ``kth_best_model`` / ``kth_value``. Also refreshes ``best``.
        """
        # remove kth: the dict is full, so the current worst kept checkpoint makes room
        if len(self.best_k_models) == self.save_top_k:
            delpath = self.kth_best_model
            self.best_k_models.pop(self.kth_best_model)
            self._del_model(delpath)

        self.best_k_models[filepath] = current
        if len(self.best_k_models) == self.save_top_k:
            # monitor dict has reached k elements: the worst one is the next to evict
            _op = max if self.mode == 'min' else min
            self.kth_best_model = _op(self.best_k_models,
                                      key=self.best_k_models.get)
            self.kth_value = self.best_k_models[self.kth_best_model]

        _op = min if self.mode == 'min' else max
        self.best = _op(self.best_k_models.values())

        if self.verbose > 0:
            log.info(
                f'\nEpoch {epoch:05d}: {self.monitor} reached'
                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
                f' {filepath} as top {self.save_top_k}')
        self._save_model(filepath)