Support torch.optim.lr_scheduler.ReduceLROnPlateau #320
```diff
@@ -1,7 +1,8 @@
-from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
+from .pt_callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateauScheduler, GradientAccumulationScheduler

 __all__ = [
     'EarlyStopping',
     'ModelCheckpoint',
+    'ReduceLROnPlateauScheduler',
     'GradientAccumulationScheduler',
 ]
```
```diff
@@ -145,6 +145,36 @@ def on_train_end(self, logs=None):
             print('Epoch %05d: early stopping' % (self.stopped_epoch + 1))


+class ReduceLROnPlateauScheduler(Callback):
+    """
+    Reduce learning rate when the monitored metric has stopped improving.
+    Wrapper for torch.optim.lr_scheduler.ReduceLROnPlateau learning rate
+    schedulers.
+
+    # Arguments
+        schedulers: list of torch.optim.lr_scheduler.ReduceLROnPlateau
+        monitor: quantity to be monitored.
+    """
+
+    def __init__(self, schedulers, monitor='val_loss'):
+        super(ReduceLROnPlateauScheduler, self).__init__()
+
+        self.monitor = monitor
+        self.schedulers = schedulers
+
+    def on_epoch_end(self, epoch, logs=None):
+        current = logs.get(self.monitor)
+        stop_training = False
+        if current is None:
+            print('ReduceLROnPlateau conditioned on metric `%s` '
+                  'which is not available. Available metrics are: %s' %
+                  (self.monitor, ','.join(list(logs.keys()))), RuntimeWarning)
+            exit(-1)
+
+        for scheduler in self.schedulers:
+            scheduler.step(current, epoch=epoch)
+
+
```
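For context on why this wrapper exists: unlike epoch-driven schedulers, `torch.optim.lr_scheduler.ReduceLROnPlateau.step()` expects the monitored metric as its argument, so something has to feed it a validation quantity each epoch. A minimal standalone sketch (the model, optimizer settings, and loss values are made up for illustration):

```python
import torch

# Tiny model and optimizer, only to give the scheduler something to act on.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=1)

# Simulated validation losses that plateau after the second epoch.
for val_loss in [1.0, 0.9, 0.9, 0.9, 0.9]:
    scheduler.step(val_loss)  # takes a metric, not an epoch counter

print(optimizer.param_groups[0]['lr'])
```

With `patience=1`, the learning rate is halved once the loss has failed to improve for two consecutive epochs, which is exactly the signal the callback above forwards from the training logs.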
Review comment: Why do we need to create our own ReduceLROnPlateauScheduler?
```diff
 class ModelCheckpoint(Callback):
     """Save the model after every epoch.
     `filepath` can contain named formatting options,
```
```diff
@@ -21,7 +21,7 @@
 from pytorch_lightning.pt_overrides.override_data_parallel import (
     LightningDistributedDataParallel, LightningDataParallel)
 from pytorch_lightning.callbacks import GradientAccumulationScheduler, \
-    ModelCheckpoint, EarlyStopping
+    ReduceLROnPlateauScheduler, ModelCheckpoint, EarlyStopping
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 import pdb
 from pytorch_lightning.trainer import ignored_warnings
```
```diff
@@ -190,6 +190,7 @@ def __init__(self,
         else:
             self.early_stop_callback = early_stop_callback
             self.enable_early_stop = True
+        self.lr_scheduler_callback = None

         # configure logger
         if logger is True:
```
```diff
@@ -793,12 +794,25 @@ def init_optimizers(self, optimizers):
         # two lists
         elif len(optimizers) == 2 and isinstance(optimizers[0], list):
             optimizers, lr_schedulers = optimizers
+            lr_schedulers = self.configure_schedulers(lr_schedulers)
             return optimizers, lr_schedulers

         # single list or tuple
         elif isinstance(optimizers, list) or isinstance(optimizers, tuple):
             return optimizers, []

+    def configure_schedulers(self, schedulers):
+        custom_schedulers = []
+        i = 0
+        while i < len(schedulers):
+            if isinstance(schedulers[i], torch.optim.lr_scheduler.ReduceLROnPlateau):
```
Review comment: proof here
```diff
+                custom_schedulers.append(schedulers.pop(i))
+            i += 1
```
Review comment: There is a small issue with this snippet. When a ReduceLROnPlateau scheduler is popped, `i` should not be incremented; otherwise the element that followed the popped one slides into position `schedulers[i]`, and since `i` is immediately incremented, that element never gets checked to see whether it is another ReduceLROnPlateau.

Suggested change

Reply: Hi, thank you for the fix. I decided to support only one ReduceLROnPlateau scheduler.
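The index problem the reviewer describes is easy to reproduce with plain lists. This hypothetical sketch stands in strings for scheduler objects and contrasts the PR's loop with the reviewer's fix:

```python
def partition_buggy(schedulers):
    # Mirrors the PR's loop: i advances even on the pass that pops an element.
    custom, i = [], 0
    while i < len(schedulers):
        if schedulers[i] == 'plateau':
            custom.append(schedulers.pop(i))
        i += 1  # after a pop, this skips the element that slid into slot i
    return custom, schedulers

def partition_fixed(schedulers):
    # Fix: only advance i when nothing was popped.
    custom, i = [], 0
    while i < len(schedulers):
        if schedulers[i] == 'plateau':
            custom.append(schedulers.pop(i))
        else:
            i += 1
    return custom, schedulers

print(partition_buggy(['plateau', 'plateau', 'step']))   # second 'plateau' is missed
print(partition_fixed(['plateau', 'plateau', 'step']))   # both are extracted
```

With a single ReduceLROnPlateau in the list the two versions behave identically, which is why restricting the feature to one scheduler also sidesteps the bug.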
```diff
+
+        if custom_schedulers:
+            self.lr_scheduler_callback = ReduceLROnPlateauScheduler(custom_schedulers,
+                                                                    monitor='val_loss')
+        return schedulers

     def __single_gpu_train(self, model):
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
```
```diff
@@ -1096,7 +1110,7 @@ def __train(self):
         # update LR schedulers
         if self.lr_schedulers is not None:
             for lr_scheduler in self.lr_schedulers:
-                lr_scheduler.step(self.current_epoch)
+                lr_scheduler.step(epoch=self.current_epoch)

         # early stopping
         met_min_epochs = epoch_nb > self.min_nb_epochs
```
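The one-line change above only makes the epoch argument explicit; the behavioural contrast worth remembering is that epoch-based schedulers step on an internal counter, while ReduceLROnPlateau (handled separately via the callback) steps on a metric. A hedged sketch with `StepLR`, whose hyperparameters here are arbitrary:

```python
import torch

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Decay the learning rate by 10x every 2 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

for epoch in range(4):
    optimizer.step()   # recent PyTorch expects optimizer.step() before scheduler.step()
    scheduler.step()   # epoch-based: no metric needed

print(optimizer.param_groups[0]['lr'])
```

Note that passing an explicit `epoch` to `step()`, as the diff does, was valid at the time of this PR but has since been deprecated in PyTorch in favour of plain `step()` calls.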
```diff
@@ -1544,6 +1558,11 @@ def __run_evaluation(self, test=False):
             tqdm_metrics = self.__training_tqdm_dict
             self.progress_bar.set_postfix(**tqdm_metrics)

+        # reduce learning rate based on metrics
+        if self.lr_scheduler_callback is not None and not test:
+            self.lr_scheduler_callback.on_epoch_end(epoch=self.current_epoch,
+                                                    logs=self.__training_tqdm_dict)
+
         # model checkpointing
         if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
             self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
```
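To see the pieces work end to end without a Trainer, here is a standalone sketch of the PR's callback driving a real ReduceLROnPlateau instance. The Callback base class is omitted, `exit(-1)` is replaced by a raise so the sketch stays testable, and the `epoch` argument to `scheduler.step` is dropped; names otherwise follow the diff, everything else is illustrative:

```python
import torch

class ReduceLROnPlateauScheduler:
    """Standalone sketch of the PR's callback (no Callback base class)."""
    def __init__(self, schedulers, monitor='val_loss'):
        self.monitor = monitor
        self.schedulers = schedulers

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get(self.monitor)
        if current is None:
            # The PR prints a warning and exits; raising keeps the sketch testable.
            raise KeyError('metric %r not in logs: %s' % (self.monitor, list(logs)))
        for scheduler in self.schedulers:
            scheduler.step(current)

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=0)
callback = ReduceLROnPlateauScheduler([scheduler])

# A flat validation loss: the lr should drop as soon as the plateau is seen.
for epoch, loss in enumerate([1.0, 1.0, 1.0]):
    callback.on_epoch_end(epoch, logs={'val_loss': loss})

print(optimizer.param_groups[0]['lr'])
```

This mirrors the `__run_evaluation` hook above: the trainer passes its metrics dict as `logs`, and the callback forwards the monitored value to every wrapped scheduler.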
Review comment: it looks like a relative import which we shall not use... :)

Reply: Relative imports of EarlyStopping, ModelCheckpoint, etc. are taken from the original repository. Why is a relative import of ReduceLROnPlateauScheduler inappropriate?

Reply: #402 (comment)

Reply: #402 never fixed relative imports in the callbacks __init__, as well as in many other places. I'd say the above comment is out of scope for this PR. @Borda, it might be better to create a separate PR that properly fixes relative imports.

Reply: It was not that this PR was fixing relative imports, but I tried to make them, which was stopped...

Reply: let's do a separate PR for relative imports.

Reply: I have opened ticket #459.