Make monitor required arg of EarlyStopping callback #10328

Merged: 18 commits, Nov 9, 2021
Changes from 9 commits
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- The `monitor` argument in the `EarlyStopping` callback is no longer optional ([#10328](https://github.com/PyTorchLightning/pytorch-lightning/pull/10328))


-
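For readers updating their own code, a minimal sketch of what this entry means in practice (assuming a LightningModule that logs a "val_loss" metric; the metric name is illustrative, not part of the PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# `monitor` must now be passed explicitly; there is no default metric to fall back to.
early_stopping = EarlyStopping(monitor="val_loss")
trainer = Trainer(callbacks=[early_stopping])
```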
2 changes: 1 addition & 1 deletion docs/source/common/lightning_cli.rst
@@ -625,7 +625,7 @@ This can be implemented as follows:
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.add_lightning_class_args(EarlyStopping, "my_early_stopping")
parser.set_defaults({"my_early_stopping.patience": 5})
parser.set_defaults({"monitor": "val_loss", "my_early_stopping.patience": 5})


cli = MyLightningCLI(MyModel)
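One note on the docs snippet above: defaults registered through `add_lightning_class_args` are namespaced under the name given to the callback ("my_early_stopping" here), which is why the nested key is needed to satisfy the now-required `monitor`. Roughly, the CLI ends up building the callback as in this illustrative sketch:

```python
from pytorch_lightning.callbacks import EarlyStopping

# what the defaults in the docs example amount to at construction time
my_early_stopping = EarlyStopping(monitor="val_loss", patience=5)
```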
12 changes: 3 additions & 9 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -26,7 +26,7 @@

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)
@@ -89,7 +89,7 @@ class EarlyStopping(Callback):

def __init__(
self,
monitor: Optional[str] = None,
monitor: str,
min_delta: float = 0.0,
patience: int = 3,
verbose: bool = False,
@@ -101,6 +101,7 @@ def __init__(
check_on_train_epoch_end: Optional[bool] = None,
):
super().__init__()
self.monitor = monitor
self.min_delta = min_delta
self.patience = patience
self.verbose = verbose
@@ -120,13 +121,6 @@ def __init__(
torch_inf = torch.tensor(np.Inf)
self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

if monitor is None:
rank_zero_deprecation(
"The `EarlyStopping(monitor)` argument will be required starting in v1.6."
" For backward compatibility, setting this to `early_stop_on`."
)
self.monitor = monitor or "early_stop_on"

@property
def state_key(self) -> str:
return self._generate_state_key(monitor=self.monitor, mode=self.mode)
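The net effect of the signature change above, sketched for illustration (this snippet is not part of the diff): leaving out the metric is now an ordinary Python error instead of a deprecation warning that falls back to "early_stop_on".

```python
from pytorch_lightning.callbacks import EarlyStopping

EarlyStopping(monitor="val_loss")  # OK: keyword form
EarlyStopping("val_loss", patience=5, mode="min")  # OK: positional form used in the updated tests

try:
    EarlyStopping()  # previously warned and fell back to monitor="early_stop_on"
except TypeError as err:
    print(err)  # missing 1 required positional argument: 'monitor'
```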
32 changes: 15 additions & 17 deletions tests/callbacks/test_early_stopping.py
@@ -34,14 +34,14 @@


def test_early_stopping_state_key():
early_stopping = EarlyStopping(monitor="val_loss")
early_stopping = EarlyStopping("val_loss")
assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"


class EarlyStoppingTestRestore(EarlyStopping):
# this class has to be defined outside the test function, otherwise we get pickle error
def __init__(self, expected_state, *args, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, expected_state, monitor, *args, **kwargs):
super().__init__(monitor, *args, **kwargs)
self.expected_state = expected_state
# cache the state for each epoch
self.saved_states = []
@@ -65,7 +65,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
model = ClassificationModel()
dm = ClassifDataModule()
checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1)
early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
early_stop_callback = EarlyStoppingTestRestore(None, "train_loss")
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stop_callback, checkpoint_callback],
@@ -86,7 +86,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
assert checkpoint["callbacks"][es_name] == early_stop_callback_state

# ensure state is reloaded properly (assertion in the callback)
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, "train_loss")
new_trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
@@ -101,7 +101,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
"""Test to ensure that callback methods aren't being invoked outside of the callback handler."""
model = ClassificationModel()
dm = ClassifDataModule()
early_stop_callback = EarlyStopping(monitor="train_loss")
early_stop_callback = EarlyStopping("train_loss")
early_stop_callback._run_early_stopping_check = Mock()
expected_count = 4
trainer = Trainer(
@@ -134,7 +134,7 @@ def validation_epoch_end(self, outputs):
self.log("test_val_loss", loss)

model = ModelOverrideValidationReturn()
early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
early_stop_callback = EarlyStopping("test_val_loss", patience=patience, verbose=True)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stop_callback],
@@ -168,9 +168,7 @@ def training_epoch_end(self, outputs):
if validation_step_none:
model.validation_step = None

early_stop_callback = EarlyStopping(
monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
)
early_stop_callback = EarlyStopping("train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stop_callback],
@@ -183,7 +181,7 @@ def training_epoch_end(self, outputs):


def test_pickling(tmpdir):
early_stopping = EarlyStopping()
early_stopping = EarlyStopping("val_loss")

early_stopping_pickled = pickle.dumps(early_stopping)
early_stopping_loaded = pickle.loads(early_stopping_pickled)
@@ -202,7 +200,7 @@ def test_early_stopping_no_val_step(tmpdir):
model.validation_step = None
model.val_dataloader = None

stopping = EarlyStopping(monitor="train_loss", min_delta=0.1, patience=0, check_on_train_epoch_end=True)
stopping = EarlyStopping("train_loss", min_delta=0.1, patience=0, check_on_train_epoch_end=True)
trainer = Trainer(default_root_dir=tmpdir, callbacks=[stopping], overfit_batches=0.20, max_epochs=10)
trainer.fit(model, datamodule=dm)

@@ -226,7 +224,7 @@ def validation_epoch_end(self, outputs):

model = CurrentModel()
early_stopping = EarlyStopping(
monitor="abc", stopping_threshold=stopping_threshold, divergence_threshold=divergence_theshold
"abc", stopping_threshold=stopping_threshold, divergence_threshold=divergence_theshold
)
trainer = Trainer(default_root_dir=tmpdir, callbacks=[early_stopping], overfit_batches=0.20, max_epochs=20)
trainer.fit(model)
@@ -245,7 +243,7 @@ def validation_epoch_end(self, outputs):
self.log("val_loss", val_loss)

model = CurrentModel()
early_stopping = EarlyStopping(monitor="val_loss", check_finite=True)
early_stopping = EarlyStopping("val_loss", check_finite=True)
trainer = Trainer(default_root_dir=tmpdir, callbacks=[early_stopping], overfit_batches=0.20, max_epochs=10)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
@@ -309,7 +307,7 @@ def validation_epoch_end(self, outputs):
model = Model(step_freeze)
model.training_step_end = None
model.test_dataloader = None
early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
early_stop_callback = EarlyStopping("test_val_loss", patience=patience, verbose=True)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stop_callback],
@@ -350,7 +348,7 @@ def validation_epoch_end(self, outputs):

def test_early_stopping_mode_options():
with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
EarlyStopping(mode="unknown_option")
EarlyStopping("val_loss", mode="unknown_option")


class EarlyStoppingModel(BoringModel):
@@ -453,7 +451,7 @@ def validation_step(self, batch, batch_idx):
trainer = Trainer(
default_root_dir=tmpdir,
limit_val_batches=1,
callbacks=EarlyStopping(monitor="foo"),
callbacks=EarlyStopping("foo"),
enable_progress_bar=False,
**kwargs,
)
9 changes: 0 additions & 9 deletions tests/deprecated_api/test_remove_1-6.py
@@ -20,7 +20,6 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.plugins import PrecisionPlugin
from pytorch_lightning.plugins.environments import (
KubeflowEnvironment,
@@ -231,14 +230,6 @@ def test_v1_6_0_is_overridden_model():
assert not is_overridden("foo", model=model)


def test_v1_6_0_early_stopping_monitor(tmpdir):
with pytest.deprecated_call(
match=r"The `EarlyStopping\(monitor\)` argument will be required starting in v1.6."
" For backward compatibility, setting this to `early_stop_on`."
):
EarlyStopping()


def test_v1_6_0_extras_with_gradients(tmpdir):
class TestModel(BoringModel):
def training_step(self, *args):
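The removed deprecation test has no direct replacement in this diff. If one were wanted, a hypothetical version (not part of the PR) could simply assert the new hard requirement:

```python
import pytest

from pytorch_lightning.callbacks.early_stopping import EarlyStopping


def test_early_stopping_monitor_required():
    # with `monitor` a required positional argument, omitting it raises immediately
    with pytest.raises(TypeError):
        EarlyStopping()
```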
23 changes: 15 additions & 8 deletions tests/trainer/connectors/test_callback_connector.py
@@ -33,7 +33,7 @@ def test_checkpoint_callbacks_are_last(tmpdir):
checkpoint1 = ModelCheckpoint(tmpdir)
checkpoint2 = ModelCheckpoint(tmpdir)
model_summary = ModelSummary()
early_stopping = EarlyStopping()
early_stopping = EarlyStopping("val_loss")
lr_monitor = LearningRateMonitor()
progress_bar = TQDMProgressBar()

@@ -154,7 +154,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
cb_connector._attach_model_callbacks()
return trainer

early_stopping = EarlyStopping()
early_stopping = EarlyStopping("val_loss")
progress_bar = TQDMProgressBar()
lr_monitor = LearningRateMonitor()
grad_accumulation = GradientAccumulationScheduler({1: 1})
@@ -169,14 +169,19 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):

# same callback type twice, different instance
trainer = _attach_callbacks(
trainer_callbacks=[progress_bar, EarlyStopping()],
trainer_callbacks=[progress_bar, EarlyStopping("val_loss")],
model_callbacks=[early_stopping],
)
assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping]

# multiple callbacks of the same type in trainer
trainer = _attach_callbacks(
trainer_callbacks=[LearningRateMonitor(), EarlyStopping(), LearningRateMonitor(), EarlyStopping()],
trainer_callbacks=[
LearningRateMonitor(),
EarlyStopping("val_loss"),
LearningRateMonitor(),
EarlyStopping("val_loss"),
],
model_callbacks=[early_stopping, lr_monitor],
)
assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]
@@ -186,9 +191,9 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
trainer_callbacks=[
LearningRateMonitor(),
progress_bar,
EarlyStopping(),
EarlyStopping("val_loss"),
LearningRateMonitor(),
EarlyStopping(),
EarlyStopping("val_loss"),
],
model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
)
@@ -198,8 +203,10 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
def test_attach_model_callbacks_override_info(caplog):
"""Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
model = LightningModule()
model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), TQDMProgressBar()])
model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping("val_loss")]
trainer = Trainer(
enable_checkpointing=False, callbacks=[EarlyStopping("val_loss"), LearningRateMonitor(), TQDMProgressBar()]
)
trainer.model = model
cb_connector = CallbackConnector(trainer)
with caplog.at_level(logging.INFO):
2 changes: 1 addition & 1 deletion tests/trainer/flags/test_fast_dev_run.py
@@ -65,7 +65,7 @@ def test_step(self, batch, batch_idx):

checkpoint_callback = ModelCheckpoint()
checkpoint_callback.save_checkpoint = Mock()
early_stopping_callback = EarlyStopping()
early_stopping_callback = EarlyStopping("val_loss")
early_stopping_callback._evaluate_stopping_criteria = Mock()
trainer_config = dict(
default_root_dir=tmpdir,