Make monitor required arg of EarlyStopping callback #10328

Merged · 18 commits · Nov 9, 2021
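In short, `EarlyStopping` must now be told which metric to watch at construction time. A minimal before/after sketch in Python (the `"val_loss"` name is an assumption; it must match a metric logged via `self.log`):

from pytorch_lightning.callbacks import EarlyStopping

# Before this PR: `monitor` was optional and silently fell back to "early_stop_on".
# early_stopping = EarlyStopping()

# After this PR: the monitored metric is a required argument.
early_stopping = EarlyStopping(monitor="val_loss", patience=3, mode="min")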
2 changes: 1 addition & 1 deletion — CHANGELOG.md

@@ -28,7 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Raise exception in `init_dist_connection()` when torch distributed is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))


--
+- The `monitor` argument in the `EarlyStopping` callback is no longer optional ([#10328](https://github.com/PyTorchLightning/pytorch-lightning/pull/10328))


 -
2 changes: 1 addition & 1 deletion — docs/source/common/lightning_cli.rst

@@ -625,7 +625,7 @@ This can be implemented as follows:

     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.add_lightning_class_args(EarlyStopping, "my_early_stopping")
-            parser.set_defaults({"my_early_stopping.patience": 5})
+            parser.set_defaults({"my_early_stopping.monitor": "val_loss", "my_early_stopping.patience": 5})


     cli = MyLightningCLI(MyModel)
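Because the monitor now has a parser default, a command line run only needs to override the remaining options; a hypothetical invocation (the script name `trainer.py` and the `fit` subcommand are assumptions):

python trainer.py fit --my_early_stopping.patience=7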
12 changes: 3 additions & 9 deletions — pytorch_lightning/callbacks/early_stopping.py

@@ -26,7 +26,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 log = logging.getLogger(__name__)

@@ -89,7 +89,7 @@ class EarlyStopping(Callback):

     def __init__(
         self,
-        monitor: Optional[str] = None,
+        monitor: str,
         min_delta: float = 0.0,
         patience: int = 3,
         verbose: bool = False,

@@ -101,6 +101,7 @@ def __init__(
         check_on_train_epoch_end: Optional[bool] = None,
     ):
         super().__init__()
+        self.monitor = monitor
         self.min_delta = min_delta
         self.patience = patience
         self.verbose = verbose

@@ -120,13 +121,6 @@ def __init__(
         torch_inf = torch.tensor(np.Inf)
         self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

-        if monitor is None:
-            rank_zero_deprecation(
-                "The `EarlyStopping(monitor)` argument will be required starting in v1.6."
-                " For backward compatibility, setting this to `early_stop_on`."
-            )
-        self.monitor = monitor or "early_stop_on"
-
     @property
     def state_key(self) -> str:
         return self._generate_state_key(monitor=self.monitor, mode=self.mode)
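With `monitor` now a required positional parameter, constructing the callback without it fails immediately with a plain `TypeError` instead of the old deprecation warning. A minimal sketch (the exact message wording varies by Python version):

from pytorch_lightning.callbacks import EarlyStopping

try:
    EarlyStopping()  # previously warned and fell back to "early_stop_on"
except TypeError as err:
    print(err)  # e.g. "__init__() missing 1 required positional argument: 'monitor'"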
4 changes: 2 additions & 2 deletions — tests/callbacks/test_early_stopping.py

@@ -183,7 +183,7 @@ def training_epoch_end(self, outputs):


 def test_pickling(tmpdir):
-    early_stopping = EarlyStopping()
+    early_stopping = EarlyStopping("foo")

     early_stopping_pickled = pickle.dumps(early_stopping)
     early_stopping_loaded = pickle.loads(early_stopping_pickled)

@@ -350,7 +350,7 @@ def validation_epoch_end(self, outputs):

 def test_early_stopping_mode_options():
     with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
-        EarlyStopping(mode="unknown_option")
+        EarlyStopping("foo", mode="unknown_option")


 class EarlyStoppingModel(BoringModel):
9 changes: 0 additions & 9 deletions — tests/deprecated_api/test_remove_1-6.py

@@ -17,7 +17,6 @@

 import pytest

 from pytorch_lightning import Trainer
-from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.plugins.training_type import DDPPlugin
 from pytorch_lightning.utilities.distributed import rank_zero_deprecation, rank_zero_warn
 from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -73,14 +72,6 @@ def test_v1_6_0_is_overridden_model():
     assert not is_overridden("foo", model=model)


-def test_v1_6_0_early_stopping_monitor(tmpdir):
-    with pytest.deprecated_call(
-        match=r"The `EarlyStopping\(monitor\)` argument will be required starting in v1.6."
-        " For backward compatibility, setting this to `early_stop_on`."
-    ):
-        EarlyStopping()
-
-
 def test_v1_6_0_extras_with_gradients(tmpdir):
     class TestModel(BoringModel):
         def training_step(self, *args):
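The deprecation test is deleted outright because the fallback path it exercised no longer exists. If one wanted to pin down the new behavior, a hypothetical replacement test (not part of this PR) could assert the constructor failure directly:

import pytest

from pytorch_lightning.callbacks import EarlyStopping


def test_early_stopping_monitor_is_required():
    # Hypothetical: `monitor` has no default, so the call fails at the signature level.
    with pytest.raises(TypeError):
        EarlyStopping()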
18 changes: 10 additions & 8 deletions — tests/trainer/connectors/test_callback_connector.py

@@ -33,7 +33,7 @@ def test_checkpoint_callbacks_are_last(tmpdir):
     checkpoint1 = ModelCheckpoint(tmpdir)
     checkpoint2 = ModelCheckpoint(tmpdir)
     model_summary = ModelSummary()
-    early_stopping = EarlyStopping()
+    early_stopping = EarlyStopping("foo")
     lr_monitor = LearningRateMonitor()
     progress_bar = TQDMProgressBar()

@@ -154,7 +154,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         cb_connector._attach_model_callbacks()
         return trainer

-    early_stopping = EarlyStopping()
+    early_stopping = EarlyStopping("foo")
     progress_bar = TQDMProgressBar()
     lr_monitor = LearningRateMonitor()
     grad_accumulation = GradientAccumulationScheduler({1: 1})

@@ -169,14 +169,14 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):

     # same callback type twice, different instance
     trainer = _attach_callbacks(
-        trainer_callbacks=[progress_bar, EarlyStopping()],
+        trainer_callbacks=[progress_bar, EarlyStopping("foo")],
         model_callbacks=[early_stopping],
     )
     assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping]

     # multiple callbacks of the same type in trainer
     trainer = _attach_callbacks(
-        trainer_callbacks=[LearningRateMonitor(), EarlyStopping(), LearningRateMonitor(), EarlyStopping()],
+        trainer_callbacks=[LearningRateMonitor(), EarlyStopping("foo"), LearningRateMonitor(), EarlyStopping("foo")],
         model_callbacks=[early_stopping, lr_monitor],
     )
     assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]

@@ -186,9 +186,9 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         trainer_callbacks=[
             LearningRateMonitor(),
             progress_bar,
-            EarlyStopping(),
+            EarlyStopping("foo"),
             LearningRateMonitor(),
-            EarlyStopping(),
+            EarlyStopping("foo"),
         ],
         model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
    )

@@ -198,8 +198,10 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
 def test_attach_model_callbacks_override_info(caplog):
     """Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
     model = LightningModule()
-    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
-    trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), TQDMProgressBar()])
+    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping("foo")]
+    trainer = Trainer(
+        enable_checkpointing=False, callbacks=[EarlyStopping("foo"), LearningRateMonitor(), TQDMProgressBar()]
+    )
     trainer.model = model
     cb_connector = CallbackConnector(trainer)
     with caplog.at_level(logging.INFO):
2 changes: 1 addition & 1 deletion — tests/trainer/flags/test_fast_dev_run.py

@@ -65,7 +65,7 @@ def test_step(self, batch, batch_idx):

     checkpoint_callback = ModelCheckpoint()
     checkpoint_callback.save_checkpoint = Mock()
-    early_stopping_callback = EarlyStopping()
+    early_stopping_callback = EarlyStopping(monitor="foo")
     early_stopping_callback._evaluate_stopping_criteria = Mock()
     trainer_config = dict(
         default_root_dir=tmpdir,