From 01dc37893938e0e75c6f057e709a4021b5f14d37 Mon Sep 17 00:00:00 2001
From: Ross Johnstone
Date: Wed, 3 Nov 2021 13:40:35 +0900
Subject: [PATCH 01/13] Make monitor required arg of EarlyStopping

---
 pytorch_lightning/callbacks/early_stopping.py | 10 ++--------
 tests/deprecated_api/test_remove_1-6.py       |  9 ---------
 2 files changed, 2 insertions(+), 17 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index b5118846875db..fdb66c4113c1c 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -89,7 +89,7 @@ class EarlyStopping(Callback):

     def __init__(
         self,
-        monitor: Optional[str] = None,
+        monitor: str,
         min_delta: float = 0.0,
         patience: int = 3,
         verbose: bool = False,
@@ -101,6 +101,7 @@ def __init__(
         check_on_train_epoch_end: Optional[bool] = None,
     ):
         super().__init__()
+        self.monitor = monitor
         self.min_delta = min_delta
         self.patience = patience
         self.verbose = verbose
@@ -120,13 +121,6 @@ def __init__(
         torch_inf = torch.tensor(np.Inf)
         self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

-        if monitor is None:
-            rank_zero_deprecation(
-                "The `EarlyStopping(monitor)` argument will be required starting in v1.6."
-                " For backward compatibility, setting this to `early_stop_on`."
-            )
-        self.monitor = monitor or "early_stop_on"
-
     @property
     def state_key(self) -> str:
         return self._generate_state_key(monitor=self.monitor, mode=self.mode)

diff --git a/tests/deprecated_api/test_remove_1-6.py b/tests/deprecated_api/test_remove_1-6.py
index 62791b482c186..3bd995b4c1ec6 100644
--- a/tests/deprecated_api/test_remove_1-6.py
+++ b/tests/deprecated_api/test_remove_1-6.py
@@ -20,7 +20,6 @@

 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.plugins import PrecisionPlugin
 from pytorch_lightning.plugins.environments import (
     KubeflowEnvironment,
@@ -242,14 +241,6 @@ def test_v1_6_0_is_overridden_model():
     assert not is_overridden("foo", model=model)


-def test_v1_6_0_early_stopping_monitor(tmpdir):
-    with pytest.deprecated_call(
-        match=r"The `EarlyStopping\(monitor\)` argument will be required starting in v1.6."
-        " For backward compatibility, setting this to `early_stop_on`."
-    ):
-        EarlyStopping()
-
-
 def test_v1_6_0_extras_with_gradients(tmpdir):
     class TestModel(BoringModel):
         def training_step(self, *args):

From d9d001bf993e6a75f251ec8803bf62dfa9f085f4 Mon Sep 17 00:00:00 2001
From: Ross Johnstone
Date: Thu, 4 Nov 2021 11:52:36 +0900
Subject: [PATCH 02/13] Fix failing tests with EarlyStopping requiring monitor
 arg

---
 tests/callbacks/test_early_stopping.py        | 32 +++++++++----------
 .../connectors/test_callback_connector.py     | 23 ++++++++-----
 tests/trainer/flags/test_fast_dev_run.py      |  2 +-
 3 files changed, 31 insertions(+), 26 deletions(-)

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 2b4fe9f05eb87..dea9ea01b3b34 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -34,14 +34,14 @@


 def test_early_stopping_state_key():
-    early_stopping = EarlyStopping(monitor="val_loss")
+    early_stopping = EarlyStopping("val_loss")
     assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"


 class EarlyStoppingTestRestore(EarlyStopping):
     # this class has to be defined outside the test function, otherwise we get pickle error
-    def __init__(self, expected_state, *args, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, expected_state, monitor, *args, **kwargs):
+        super().__init__(monitor, *args, **kwargs)
         self.expected_state = expected_state
         # cache the state for each epoch
         self.saved_states = []
@@ -65,7 +65,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     model = ClassificationModel()
     dm = ClassifDataModule()
     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1)
-    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
+    early_stop_callback = EarlyStoppingTestRestore(None, "train_loss")
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback, checkpoint_callback],
@@ -86,7 +86,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     assert checkpoint["callbacks"][es_name] == early_stop_callback_state

     # ensure state is reloaded properly (assertion in the callback)
-    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
+    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, "train_loss")
     new_trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
@@ -101,7 +101,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
     """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
     model = ClassificationModel()
     dm = ClassifDataModule()
-    early_stop_callback = EarlyStopping(monitor="train_loss")
+    early_stop_callback = EarlyStopping("train_loss")
     early_stop_callback._run_early_stopping_check = Mock()
     expected_count = 4
     trainer = Trainer(
@@ -134,7 +134,7 @@ def validation_epoch_end(self, outputs):
             self.log("test_val_loss", loss)

     model = ModelOverrideValidationReturn()
-    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
+    early_stop_callback = EarlyStopping("test_val_loss", patience=patience, verbose=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback],
@@ -168,9 +168,7 @@ def training_epoch_end(self, outputs):
     if validation_step_none:
         model.validation_step = None

-    early_stop_callback = EarlyStopping(
-        monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
-    )
+    early_stop_callback = EarlyStopping("train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback],
@@ -183,7 +181,7 @@ def training_epoch_end(self, outputs):


 def test_pickling(tmpdir):
-    early_stopping = EarlyStopping()
+    early_stopping = EarlyStopping("val_loss")

     early_stopping_pickled = pickle.dumps(early_stopping)
     early_stopping_loaded = pickle.loads(early_stopping_pickled)
@@ -202,7 +200,7 @@ def test_early_stopping_no_val_step(tmpdir):
     model.validation_step = None
     model.val_dataloader = None

-    stopping = EarlyStopping(monitor="train_loss", min_delta=0.1, patience=0, check_on_train_epoch_end=True)
+    stopping = EarlyStopping("train_loss", min_delta=0.1, patience=0, check_on_train_epoch_end=True)
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[stopping], overfit_batches=0.20, max_epochs=10)
     trainer.fit(model, datamodule=dm)
@@ -226,7 +224,7 @@ def validation_epoch_end(self, outputs):

     model = CurrentModel()
     early_stopping = EarlyStopping(
-        monitor="abc", stopping_threshold=stopping_threshold, divergence_threshold=divergence_theshold
+        "abc", stopping_threshold=stopping_threshold, divergence_threshold=divergence_theshold
     )
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[early_stopping], overfit_batches=0.20, max_epochs=20)
     trainer.fit(model)
@@ -245,7 +243,7 @@ def validation_epoch_end(self, outputs):
             self.log("val_loss", val_loss)

     model = CurrentModel()
-    early_stopping = EarlyStopping(monitor="val_loss", check_finite=True)
+    early_stopping = EarlyStopping("val_loss", check_finite=True)
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[early_stopping], overfit_batches=0.20, max_epochs=10)
     trainer.fit(model)
     assert trainer.current_epoch == expected_stop_epoch
@@ -309,7 +307,7 @@ def validation_epoch_end(self, outputs):
     model = Model(step_freeze)
     model.training_step_end = None
     model.test_dataloader = None
-    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
+    early_stop_callback = EarlyStopping("test_val_loss", patience=patience, verbose=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback],
@@ -350,7 +348,7 @@ def validation_epoch_end(self, outputs):

 def test_early_stopping_mode_options():
     with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
-        EarlyStopping(mode="unknown_option")
+        EarlyStopping("val_loss", mode="unknown_option")


 class EarlyStoppingModel(BoringModel):
@@ -453,7 +451,7 @@ def validation_step(self, batch, batch_idx):
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_val_batches=1,
-        callbacks=EarlyStopping(monitor="foo"),
+        callbacks=EarlyStopping("foo"),
         enable_progress_bar=False,
         **kwargs,
     )

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 7ec238acf5682..be1aa1ad4aeaf 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -33,7 +33,7 @@ def test_checkpoint_callbacks_are_last(tmpdir):
     checkpoint1 = ModelCheckpoint(tmpdir)
     checkpoint2 = ModelCheckpoint(tmpdir)
     model_summary = ModelSummary()
-    early_stopping = EarlyStopping()
+    early_stopping = EarlyStopping("val_loss")
     lr_monitor = LearningRateMonitor()
     progress_bar = TQDMProgressBar()
@@ -154,7 +154,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         cb_connector._attach_model_callbacks()
         return trainer

-    early_stopping = EarlyStopping()
+    early_stopping = EarlyStopping("val_loss")
     progress_bar = TQDMProgressBar()
     lr_monitor = LearningRateMonitor()
     grad_accumulation = GradientAccumulationScheduler({1: 1})
@@ -169,14 +169,19 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):

     # same callback type twice, different instance
     trainer = _attach_callbacks(
-        trainer_callbacks=[progress_bar, EarlyStopping()],
+        trainer_callbacks=[progress_bar, EarlyStopping("val_loss")],
         model_callbacks=[early_stopping],
     )
     assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping]

     # multiple callbacks of the same type in trainer
     trainer = _attach_callbacks(
-        trainer_callbacks=[LearningRateMonitor(), EarlyStopping(), LearningRateMonitor(), EarlyStopping()],
+        trainer_callbacks=[
+            LearningRateMonitor(),
+            EarlyStopping("val_loss"),
+            LearningRateMonitor(),
+            EarlyStopping("val_loss"),
+        ],
         model_callbacks=[early_stopping, lr_monitor],
     )
     assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]
@@ -186,9 +191,9 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         trainer_callbacks=[
             LearningRateMonitor(),
             progress_bar,
-            EarlyStopping(),
+            EarlyStopping("val_loss"),
             LearningRateMonitor(),
-            EarlyStopping(),
+            EarlyStopping("val_loss"),
         ],
         model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
     )
@@ -198,8 +203,10 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
 def test_attach_model_callbacks_override_info(caplog):
     """Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
     model = LightningModule()
-    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping()]
-    trainer = Trainer(enable_checkpointing=False, callbacks=[EarlyStopping(), LearningRateMonitor(), TQDMProgressBar()])
+    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping("val_loss")]
+    trainer = Trainer(
+        enable_checkpointing=False, callbacks=[EarlyStopping("val_loss"), LearningRateMonitor(), TQDMProgressBar()]
+    )
     trainer.model = model
     cb_connector = CallbackConnector(trainer)
     with caplog.at_level(logging.INFO):

diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py
index 2816fe92bc258..8836a0521dc4a 100644
--- a/tests/trainer/flags/test_fast_dev_run.py
+++ b/tests/trainer/flags/test_fast_dev_run.py
@@ -65,7 +65,7 @@ def test_step(self, batch, batch_idx):

     checkpoint_callback = ModelCheckpoint()
     checkpoint_callback.save_checkpoint = Mock()
-    early_stopping_callback = EarlyStopping()
+    early_stopping_callback = EarlyStopping("val_loss")
     early_stopping_callback._evaluate_stopping_criteria = Mock()
     trainer_config = dict(
         default_root_dir=tmpdir,

From e5a74703c1533782e8a5fa492e9e0a3b36539695 Mon Sep 17 00:00:00 2001
From: Ross Johnstone
Date: Thu, 4 Nov 2021 11:59:00 +0900
Subject: [PATCH 03/13] Update changelog with EarlyStopping requiring monitor
 arg

---
 CHANGELOG.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b9c808464dadc..6fd5a4d42e5a2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -183,6 +183,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the model size calculation using `ByteCounter` ([#10123](https://github.com/PyTorchLightning/pytorch-lightning/pull/10123))
 - Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
 - Allowed separate config files for parameters with class type when LightningCLI is in `subclass_mode=False` ([#10286](https://github.com/PyTorchLightning/pytorch-lightning/pull/10286))
+- EarlyStopping callback now requires `monitor` argument ([#10328](https://github.com/PyTorchLightning/pytorch-lightning/pull/10328))

 ### Deprecated

From 138e39603936eb92951aea03f242d90a633d967e Mon Sep 17 00:00:00 2001
From: Ross Johnstone
Date: Thu, 4 Nov 2021 17:06:45 +0900
Subject: [PATCH 04/13] Remove unused import

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index fdb66c4113c1c..03b268f714a74 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -26,7 +26,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 log = logging.getLogger(__name__)

From e38bcedbb5d3dda14579faed2369056fff754a4b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 4 Nov 2021 10:47:06 +0100
Subject: [PATCH 05/13] move changelog

---
 CHANGELOG.md | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1c310288db0c7..9d99d5b392580 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,7 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

--
+- The `monitor` argument in the `EarlyStopping` callback is no longer optional ([#10328](https://github.com/PyTorchLightning/pytorch-lightning/pull/10328))

 -

@@ -66,7 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 -

--
+-

 -

@@ -265,7 +265,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the model size calculation using `ByteCounter` ([#10123](https://github.com/PyTorchLightning/pytorch-lightning/pull/10123))
 - Enabled `on_load_checkpoint` for `LightningDataModule` for all `trainer_fn` ([#10238](https://github.com/PyTorchLightning/pytorch-lightning/pull/10238))
 - Allowed separate config files for parameters with class type when LightningCLI is in `subclass_mode=False` ([#10286](https://github.com/PyTorchLightning/pytorch-lightning/pull/10286))
-- EarlyStopping callback now requires `monitor` argument ([#10328](https://github.com/PyTorchLightning/pytorch-lightning/pull/10328))

 ### Deprecated

From f2ed7110270ecfdf9d0d99bce98097a23e85bfef Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 4 Nov 2021 09:51:59 +0000
Subject: [PATCH 06/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9d99d5b392580..c102268edadad 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -66,7 +66,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 -

--
+-

 -

From 67823c0d2ff4794d1fdcd3d1ba0707fdaaa42529 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 4 Nov 2021 11:03:34 +0100
Subject: [PATCH 07/13] update docs with required arg

---
 docs/source/common/lightning_cli.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst
index 2f7b2bae599e4..99e78891ae2fe 100644
--- a/docs/source/common/lightning_cli.rst
+++ b/docs/source/common/lightning_cli.rst
@@ -625,7 +625,7 @@ This can be implemented as follows:
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.add_lightning_class_args(EarlyStopping, "my_early_stopping")
-            parser.set_defaults({"my_early_stopping.patience": 5})
+            parser.set_defaults({"monitor": "val_loss", "my_early_stopping.patience": 5})


     cli = MyLightningCLI(MyModel)

From b14c62db3946f4fdb0fba883fbf0ec37de8966e4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carlos=20Mochol=C3=AD?=
Date: Thu, 4 Nov 2021 17:35:22 +0100
Subject: [PATCH 08/13] Update docs/source/common/lightning_cli.rst

---
 docs/source/common/lightning_cli.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/common/lightning_cli.rst b/docs/source/common/lightning_cli.rst
index 99e78891ae2fe..7b4680b2d298c 100644
--- a/docs/source/common/lightning_cli.rst
+++ b/docs/source/common/lightning_cli.rst
@@ -625,7 +625,7 @@ This can be implemented as follows:
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.add_lightning_class_args(EarlyStopping, "my_early_stopping")
-            parser.set_defaults({"monitor": "val_loss", "my_early_stopping.patience": 5})
+            parser.set_defaults({"my_early_stopping.monitor": "val_loss", "my_early_stopping.patience": 5})


     cli = MyLightningCLI(MyModel)

From 1ab7cab47d45679e3f82e517ddb6b40590f1a3b3 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 9 Nov 2021 16:22:58 +0100
Subject: [PATCH 09/13] Address comments

---
 tests/callbacks/test_early_stopping.py        | 32 ++++++++++---------
 .../connectors/test_callback_connector.py     | 21 +++++-------
 tests/trainer/flags/test_fast_dev_run.py      |  2 +-
 3 files changed, 26 insertions(+), 29 deletions(-)

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index dea9ea01b3b34..fd594dfaa1c42 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -34,14 +34,14 @@


 def test_early_stopping_state_key():
-    early_stopping = EarlyStopping("val_loss")
+    early_stopping = EarlyStopping(monitor="val_loss")
     assert early_stopping.state_key == "EarlyStopping{'monitor': 'val_loss', 'mode': 'min'}"


 class EarlyStoppingTestRestore(EarlyStopping):
     # this class has to be defined outside the test function, otherwise we get pickle error
-    def __init__(self, expected_state, monitor, *args, **kwargs):
-        super().__init__(monitor, *args, **kwargs)
+    def __init__(self, expected_state, *args, **kwargs):
+        super().__init__(*args, **kwargs)
         self.expected_state = expected_state
         # cache the state for each epoch
         self.saved_states = []
@@ -65,7 +65,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     model = ClassificationModel()
     dm = ClassifDataModule()
     checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="train_loss", save_top_k=1)
-    early_stop_callback = EarlyStoppingTestRestore(None, "train_loss")
+    early_stop_callback = EarlyStoppingTestRestore(None, monitor="train_loss")
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback, checkpoint_callback],
@@ -86,7 +86,7 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     assert checkpoint["callbacks"][es_name] == early_stop_callback_state

     # ensure state is reloaded properly (assertion in the callback)
-    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, "train_loss")
+    early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor="train_loss")
     new_trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
@@ -101,7 +101,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
     """Test to ensure that callback methods aren't being invoked outside of the callback handler."""
     model = ClassificationModel()
     dm = ClassifDataModule()
-    early_stop_callback = EarlyStopping("train_loss")
+    early_stop_callback = EarlyStopping(monitor="train_loss")
     early_stop_callback._run_early_stopping_check = Mock()
     expected_count = 4
     trainer = Trainer(
@@ -134,7 +134,7 @@ def validation_epoch_end(self, outputs):
             self.log("test_val_loss", loss)

     model = ModelOverrideValidationReturn()
-    early_stop_callback = EarlyStopping("test_val_loss", patience=patience, verbose=True)
+    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback],
@@ -168,7 +168,9 @@ def training_epoch_end(self, outputs):
     if validation_step_none:
         model.validation_step = None

-    early_stop_callback = EarlyStopping("train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True)
+    early_stop_callback = EarlyStopping(
+        monitor="train_loss", patience=patience, verbose=True, check_on_train_epoch_end=True
+    )
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback],
@@ -181,7 +183,7 @@ def training_epoch_end(self, outputs):


 def test_pickling(tmpdir):
-    early_stopping = EarlyStopping("val_loss")
+    early_stopping = EarlyStopping("foo")

     early_stopping_pickled = pickle.dumps(early_stopping)
     early_stopping_loaded = pickle.loads(early_stopping_pickled)
@@ -200,7 +202,7 @@ def test_early_stopping_no_val_step(tmpdir):
     model.validation_step = None
     model.val_dataloader = None

-    stopping = EarlyStopping("train_loss", min_delta=0.1, patience=0, check_on_train_epoch_end=True)
+    stopping = EarlyStopping(monitor="train_loss", min_delta=0.1, patience=0, check_on_train_epoch_end=True)
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[stopping], overfit_batches=0.20, max_epochs=10)
     trainer.fit(model, datamodule=dm)
@@ -224,7 +226,7 @@ def validation_epoch_end(self, outputs):

     model = CurrentModel()
     early_stopping = EarlyStopping(
-        "abc", stopping_threshold=stopping_threshold, divergence_threshold=divergence_theshold
+        monitor="abc", stopping_threshold=stopping_threshold, divergence_threshold=divergence_theshold
     )
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[early_stopping], overfit_batches=0.20, max_epochs=20)
     trainer.fit(model)
@@ -243,7 +245,7 @@ def validation_epoch_end(self, outputs):
             self.log("val_loss", val_loss)

     model = CurrentModel()
-    early_stopping = EarlyStopping("val_loss", check_finite=True)
+    early_stopping = EarlyStopping(monitor="val_loss", check_finite=True)
     trainer = Trainer(default_root_dir=tmpdir, callbacks=[early_stopping], overfit_batches=0.20, max_epochs=10)
     trainer.fit(model)
     assert trainer.current_epoch == expected_stop_epoch
@@ -307,7 +309,7 @@ def validation_epoch_end(self, outputs):
     model = Model(step_freeze)
     model.training_step_end = None
     model.test_dataloader = None
-    early_stop_callback = EarlyStopping("test_val_loss", patience=patience, verbose=True)
+    early_stop_callback = EarlyStopping(monitor="test_val_loss", patience=patience, verbose=True)
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[early_stop_callback],
@@ -348,7 +350,7 @@ def validation_epoch_end(self, outputs):

 def test_early_stopping_mode_options():
     with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
-        EarlyStopping("val_loss", mode="unknown_option")
+        EarlyStopping("foo", mode="unknown_option")


 class EarlyStoppingModel(BoringModel):
@@ -451,7 +453,7 @@ def validation_step(self, batch, batch_idx):
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_val_batches=1,
-        callbacks=EarlyStopping("foo"),
+        callbacks=EarlyStopping(monitor="foo"),
         enable_progress_bar=False,
         **kwargs,
     )

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index be1aa1ad4aeaf..f0e4529a9ef29 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -33,7 +33,7 @@ def test_checkpoint_callbacks_are_last(tmpdir):
     checkpoint1 = ModelCheckpoint(tmpdir)
     checkpoint2 = ModelCheckpoint(tmpdir)
     model_summary = ModelSummary()
-    early_stopping = EarlyStopping("val_loss")
+    early_stopping = EarlyStopping("foo")
     lr_monitor = LearningRateMonitor()
     progress_bar = TQDMProgressBar()
@@ -154,7 +154,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         cb_connector._attach_model_callbacks()
         return trainer

-    early_stopping = EarlyStopping("val_loss")
+    early_stopping = EarlyStopping("foo")
     progress_bar = TQDMProgressBar()
     lr_monitor = LearningRateMonitor()
     grad_accumulation = GradientAccumulationScheduler({1: 1})
@@ -169,19 +169,14 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):

     # same callback type twice, different instance
     trainer = _attach_callbacks(
-        trainer_callbacks=[progress_bar, EarlyStopping("val_loss")],
+        trainer_callbacks=[progress_bar, EarlyStopping("foo")],
         model_callbacks=[early_stopping],
     )
     assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping]

     # multiple callbacks of the same type in trainer
     trainer = _attach_callbacks(
-        trainer_callbacks=[
-            LearningRateMonitor(),
-            EarlyStopping("val_loss"),
-            LearningRateMonitor(),
-            EarlyStopping("val_loss"),
-        ],
+        trainer_callbacks=[LearningRateMonitor(), EarlyStopping("foo"), LearningRateMonitor(), EarlyStopping("foo")],
         model_callbacks=[early_stopping, lr_monitor],
     )
     assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]
@@ -191,9 +186,9 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         trainer_callbacks=[
             LearningRateMonitor(),
             progress_bar,
-            EarlyStopping("val_loss"),
+            EarlyStopping("foo"),
             LearningRateMonitor(),
-            EarlyStopping("val_loss"),
+            EarlyStopping("foo"),
         ],
         model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
     )
@@ -203,9 +198,9 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
 def test_attach_model_callbacks_override_info(caplog):
     """Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
     model = LightningModule()
-    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping("val_loss")]
+    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping("foo")]
     trainer = Trainer(
-        enable_checkpointing=False, callbacks=[EarlyStopping("val_loss"), LearningRateMonitor(), TQDMProgressBar()]
+        enable_checkpointing=False, callbacks=[EarlyStopping("foo"), LearningRateMonitor(), TQDMProgressBar()]
     )
     trainer.model = model
     cb_connector = CallbackConnector(trainer)

diff --git a/tests/trainer/flags/test_fast_dev_run.py b/tests/trainer/flags/test_fast_dev_run.py
index 8836a0521dc4a..cdfdb70e73216 100644
--- a/tests/trainer/flags/test_fast_dev_run.py
+++ b/tests/trainer/flags/test_fast_dev_run.py
@@ -65,7 +65,7 @@ def test_step(self, batch, batch_idx):

     checkpoint_callback = ModelCheckpoint()
     checkpoint_callback.save_checkpoint = Mock()
-    early_stopping_callback = EarlyStopping("val_loss")
+    early_stopping_callback = EarlyStopping(monitor="foo")
     early_stopping_callback._evaluate_stopping_criteria = Mock()
     trainer_config = dict(
         default_root_dir=tmpdir,

From 696dd088b466b3ea5948edcf775e0ba16d39fa4a Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Tue, 9 Nov 2021 16:23:25 +0100
Subject: [PATCH 10/13] whitespace

---
 CHANGELOG.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index e700f5f1e2dc5..0e8efe0035b41 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,7 +25,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

-
 - Raise exception in `init_dist_connection()` when torch distibuted is not available ([#10418](https://github.com/PyTorchLightning/pytorch-lightning/issues/10418))


From 3c493314c19000d813db35ebb87d2fe82bf3d392 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Tue, 9 Nov 2021 21:11:24 +0530
Subject: [PATCH 11/13] Apply suggestions from code review

---
 tests/callbacks/test_early_stopping.py             |  4 ++--
 .../trainer/connectors/test_callback_connector.py  | 14 +++++++-------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index fd594dfaa1c42..9b20b96778e65 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -183,7 +183,7 @@ def training_epoch_end(self, outputs):


 def test_pickling(tmpdir):
-    early_stopping = EarlyStopping("foo")
+    early_stopping = EarlyStopping(monitor="foo")

     early_stopping_pickled = pickle.dumps(early_stopping)
     early_stopping_loaded = pickle.loads(early_stopping_pickled)
@@ -350,7 +350,7 @@ def validation_epoch_end(self, outputs):

 def test_early_stopping_mode_options():
     with pytest.raises(MisconfigurationException, match="`mode` can be .* got unknown_option"):
-        EarlyStopping("foo", mode="unknown_option")
+        EarlyStopping(monitor="foo", mode="unknown_option")


 class EarlyStoppingModel(BoringModel):

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index f0e4529a9ef29..98c1213f09881 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -33,7 +33,7 @@ def test_checkpoint_callbacks_are_last(tmpdir):
     checkpoint1 = ModelCheckpoint(tmpdir)
     checkpoint2 = ModelCheckpoint(tmpdir)
     model_summary = ModelSummary()
-    early_stopping = EarlyStopping("foo")
+    early_stopping = EarlyStopping(monitor="foo")
     lr_monitor = LearningRateMonitor()
     progress_bar = TQDMProgressBar()
@@ -154,7 +154,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         cb_connector._attach_model_callbacks()
         return trainer

-    early_stopping = EarlyStopping("foo")
+    early_stopping = EarlyStopping(monitor="foo")
     progress_bar = TQDMProgressBar()
     lr_monitor = LearningRateMonitor()
     grad_accumulation = GradientAccumulationScheduler({1: 1})
@@ -169,14 +169,14 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):

     # same callback type twice, different instance
     trainer = _attach_callbacks(
-        trainer_callbacks=[progress_bar, EarlyStopping("foo")],
+        trainer_callbacks=[progress_bar, EarlyStopping(monitor="foo")],
         model_callbacks=[early_stopping],
     )
     assert trainer.callbacks == [progress_bar, trainer.accumulation_scheduler, early_stopping]

     # multiple callbacks of the same type in trainer
     trainer = _attach_callbacks(
-        trainer_callbacks=[LearningRateMonitor(), EarlyStopping("foo"), LearningRateMonitor(), EarlyStopping("foo")],
+        trainer_callbacks=[LearningRateMonitor(), EarlyStopping(monitor="foo"), LearningRateMonitor(), EarlyStopping(monitor="foo")],
         model_callbacks=[early_stopping, lr_monitor],
     )
     assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]
@@ -186,9 +186,9 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
         trainer_callbacks=[
             LearningRateMonitor(),
             progress_bar,
-            EarlyStopping("foo"),
+            EarlyStopping(monitor="foo"),
             LearningRateMonitor(),
-            EarlyStopping("foo"),
+            EarlyStopping(monitor="foo"),
         ],
         model_callbacks=[early_stopping, lr_monitor, grad_accumulation, early_stopping],
     )
@@ -198,7 +198,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):
 def test_attach_model_callbacks_override_info(caplog):
     """Test that the logs contain the info about overriding callbacks returned by configure_callbacks."""
     model = LightningModule()
-    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping("foo")]
+    model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping(monitor="foo")]
     trainer = Trainer(
         enable_checkpointing=False, callbacks=[EarlyStopping("foo"), LearningRateMonitor(), TQDMProgressBar()]
     )
     trainer.model = model
     cb_connector = CallbackConnector(trainer)

From f48cf63d5326f3b98472449c54cde65a6e4289a5 Mon Sep 17 00:00:00 2001
From: Rohit Gupta
Date: Tue, 9 Nov 2021 21:12:13 +0530
Subject: [PATCH 12/13] Update tests/trainer/connectors/test_callback_connector.py

---
 tests/trainer/connectors/test_callback_connector.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 98c1213f09881..84c26fb2bd32c 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -200,7 +200,7 @@ def test_attach_model_callbacks_override_info(caplog):
     model = LightningModule()
     model.configure_callbacks = lambda: [LearningRateMonitor(), EarlyStopping(monitor="foo")]
     trainer = Trainer(
-        enable_checkpointing=False, callbacks=[EarlyStopping("foo"), LearningRateMonitor(), TQDMProgressBar()]
+        enable_checkpointing=False, callbacks=[EarlyStopping(monitor="foo"), LearningRateMonitor(), TQDMProgressBar()]
     )
     trainer.model = model
     cb_connector = CallbackConnector(trainer)

From e176b91e26be31668f53e5021a053113fc52b349 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 9 Nov 2021 15:42:34 +0000
Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/trainer/connectors/test_callback_connector.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/tests/trainer/connectors/test_callback_connector.py b/tests/trainer/connectors/test_callback_connector.py
index 84c26fb2bd32c..2cb68aa2e95bd 100644
--- a/tests/trainer/connectors/test_callback_connector.py
+++ b/tests/trainer/connectors/test_callback_connector.py
@@ -176,7 +176,12 @@ def _attach_callbacks(trainer_callbacks, model_callbacks):

     # multiple callbacks of the same type in trainer
     trainer = _attach_callbacks(
-        trainer_callbacks=[LearningRateMonitor(), EarlyStopping(monitor="foo"), LearningRateMonitor(), EarlyStopping(monitor="foo")],
+        trainer_callbacks=[
+            LearningRateMonitor(),
+            EarlyStopping(monitor="foo"),
+            LearningRateMonitor(),
+            EarlyStopping(monitor="foo"),
+        ],
         model_callbacks=[early_stopping, lr_monitor],
     )
     assert trainer.callbacks == [trainer.accumulation_scheduler, early_stopping, lr_monitor]
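
Note on usage once the series above is applied: `monitor` is now a required
argument of `EarlyStopping`, and the deprecated fallback to `early_stop_on`
is gone. For orientation, a minimal sketch of a call site after this change;
the `val_loss` key is illustrative and must match a metric name that the
LightningModule actually logs:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # `monitor` is required: EarlyStopping() with no arguments now raises a
    # TypeError instead of silently monitoring "early_stop_on".
    early_stopping = EarlyStopping(monitor="val_loss", patience=3, mode="min")
    trainer = Trainer(callbacks=[early_stopping])

    # Inside the LightningModule, log the monitored key so the callback can
    # read it from the trainer's logged metrics, e.g. in validation_step:
    #     self.log("val_loss", loss)

When the callback is configured through `LightningCLI`, the monitored metric
also has to be supplied explicitly, e.g.
`parser.set_defaults({"my_early_stopping.monitor": "val_loss"})`, as patches
07 and 08 do in the docs.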