Remove magic monitor support for ModelCheckpoint #8293

Merged · 8 commits · Jul 7, 2021
CHANGELOG.md (3 additions, 0 deletions)

@@ -309,6 +309,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated trainer attributes - `get_model` and `accelerator_backend` ([#7502](https://github.com/PyTorchLightning/pytorch-lightning/pull/7502))
 
 
+- Removed support for automatically monitoring the `val_loss` key with `ModelCheckpoint`. Pass your `monitor` of choice to the `ModelCheckpoint` instance instead ([#8293](https://github.com/PyTorchLightning/pytorch-lightning/pull/8293))
+
+
 - Removed support for `self.log(tbptt_reduce_fx)` and `self.log(tbptt_pad_token)`. Please, open a discussion explaining your use-case if you relied on these. ([#7644](https://github.com/PyTorchLightning/pytorch-lightning/pull/7644))
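For anyone migrating: the deprecation message deleted below (see the `_add_backward_monitor_support` removal) already spells out the replacement. A minimal sketch, assuming your LightningModule logs a metric under the key `"val_loss"` — the key is a placeholder for whatever you actually log:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Before this PR, a bare ModelCheckpoint() silently started monitoring a
# logged "val_loss". Now the monitored metric must be named explicitly.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
trainer = Trainer(callbacks=[checkpoint_callback])
```

With no `monitor` set, the callback now simply keeps the most recent checkpoint, since `save_top_k` defaults to `1`.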
pytorch_lightning/callbacks/model_checkpoint.py (10 additions, 38 deletions)
@@ -195,7 +195,7 @@ def __init__(
         monitor: Optional[str] = None,
         verbose: bool = False,
         save_last: Optional[bool] = None,
-        save_top_k: Optional[int] = None,
+        save_top_k: int = 1,
         save_weights_only: bool = False,
         mode: str = "min",
         auto_insert_metric_name: bool = True,
@@ -221,7 +221,7 @@ def __init__(
         self.last_model_path = ""
 
         self.__init_monitor_mode(mode)
-        self.__init_ckpt_dir(dirpath, filename, save_top_k)
+        self.__init_ckpt_dir(dirpath, filename)
         self.__init_triggers(every_n_train_steps, every_n_val_epochs, train_time_interval, period)
         self.__validate_init_configuration()
         self._save_function = None
@@ -313,7 +313,6 @@ def save_checkpoint(self, trainer: 'pl.Trainer', unused: Optional['pl.LightningM
         epoch = trainer.current_epoch
         global_step = trainer.global_step
 
-        self._add_backward_monitor_support(trainer)
         self._validate_monitor_key(trainer)
 
         # track epoch when ckpt was last checked
Expand Down Expand Up @@ -345,8 +344,8 @@ def _should_skip_saving_checkpoint(self, trainer: 'pl.Trainer') -> bool:
)

def __validate_init_configuration(self) -> None:
if self.save_top_k is not None and self.save_top_k < -1:
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be None or >= -1')
if self.save_top_k < -1:
raise MisconfigurationException(f'Invalid value for save_top_k={self.save_top_k}. Must be >= -1')
if self._every_n_train_steps < 0:
raise MisconfigurationException(
f'Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0'
@@ -359,16 +358,16 @@ def __validate_init_configuration(self) -> None:
         every_n_train_steps_triggered = self._every_n_train_steps >= 1
         every_n_val_epochs_triggered = self._every_n_val_epochs >= 1
         train_time_interval_triggered = self._train_time_interval is not None
-        if (every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1):
+        if every_n_train_steps_triggered + every_n_val_epochs_triggered + train_time_interval_triggered > 1:
             raise MisconfigurationException(
                 f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
                 f"every_n_val_epochs={self._every_n_val_epochs} and train_time_interval={self._train_time_interval} "
                 "should be mutually exclusive."
             )
 
         if self.monitor is None:
-            # None: save last epoch, -1: save all epochs, 0: nothing is saved
-            if self.save_top_k not in (None, -1, 0):
+            # -1: save all epochs, 0: nothing is saved, 1: save last epoch
+            if self.save_top_k not in (-1, 0, 1):
                 raise MisconfigurationException(
                     f'ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid'
                     ' configuration. No quantity for top_k to track.'
@@ -384,18 +383,10 @@ def __validate_init_configuration(self) -> None:
                     ' will duplicate the last checkpoint saved.'
                 )
 
-    def __init_ckpt_dir(
-        self,
-        dirpath: Optional[Union[str, Path]],
-        filename: Optional[str],
-        save_top_k: Optional[int],
-    ) -> None:
+    def __init_ckpt_dir(self, dirpath: Optional[Union[str, Path]], filename: Optional[str]) -> None:
         self._fs = get_filesystem(str(dirpath) if dirpath else '')
 
-        if (
-            save_top_k is not None and save_top_k > 0 and dirpath is not None and self._fs.isdir(dirpath)
-            and len(self._fs.ls(dirpath)) > 0
-        ):
+        if self.save_top_k != 0 and dirpath is not None and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
             rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
 
         if dirpath and self._fs.protocol == 'file':
@@ -628,25 +619,6 @@ def __resolve_ckpt_dir(self, trainer: 'pl.Trainer') -> None:
         if not trainer.fast_dev_run and trainer.should_rank_save_checkpoint:
             self._fs.makedirs(self.dirpath, exist_ok=True)
 
-    def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None:
-        metrics = trainer.callback_metrics
-        deprecation_warning = False
-
-        if self.monitor is None and 'val_loss' in metrics:
-            self.monitor = 'val_loss'
-            deprecation_warning = True
-
-        if self.save_top_k is None and self.monitor is not None:
-            # TODO: Remove `Optional` from `save_top_k` when this is deleted in v1.4
-            self.save_top_k = 1
-
-        if deprecation_warning:
-            warning_cache.deprecation(
-                "Relying on `self.log('val_loss', ...)` to set the ModelCheckpoint monitor is deprecated in v1.2"
-                " and will be removed in v1.4. Please, create your own `mc = ModelCheckpoint(monitor='your_monitor')`"
-                " and use it as `Trainer(callbacks=[mc])`.",
-            )
-
     def _validate_monitor_key(self, trainer: 'pl.Trainer') -> None:
         metrics = trainer.callback_metrics
 
@@ -717,7 +689,7 @@ def _save_none_monitor_checkpoint(self, trainer: 'pl.Trainer', monitor_candidate
             self._save_model(trainer, filepath)
 
         if (
-            self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
+            self.save_top_k == 1 and self.best_model_path and self.best_model_path != filepath
             and trainer.should_rank_save_checkpoint
         ):
             self._del_model(trainer, self.best_model_path)
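For illustration, a sketch of the new `save_top_k` semantics when no monitor is set, mirroring the updated `test_none_monitor_top_k` below. The bare constructor calls are illustrative; in practice you would also pass `dirpath`:

```python
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException

# With monitor=None there is no metric to rank checkpoints by,
# so only these three values are meaningful:
ModelCheckpoint(save_top_k=-1)  # save a checkpoint every epoch
ModelCheckpoint(save_top_k=0)   # save nothing
ModelCheckpoint(save_top_k=1)   # keep only the last checkpoint (the new default)

# Any other value now raises at construction time instead of
# silently waiting for a magic "val_loss" to appear:
try:
    ModelCheckpoint(save_top_k=3)
except MisconfigurationException as err:
    print(err)
```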
tests/checkpointing/test_model_checkpoint.py (2 additions, 2 deletions)
@@ -533,7 +533,7 @@ def test_model_checkpoint_save_last(tmpdir):
 
 def test_invalid_top_k(tmpdir):
     """ Make sure that a MisconfigurationException is raised for a negative save_top_k argument. """
-    with pytest.raises(MisconfigurationException, match=r'.*Must be None or >= -1'):
+    with pytest.raises(MisconfigurationException, match=r'.*Must be >= -1'):
         ModelCheckpoint(dirpath=tmpdir, save_top_k=-3)

@@ -544,9 +544,9 @@ def test_none_monitor_top_k(tmpdir):
     ):
         ModelCheckpoint(dirpath=tmpdir, save_top_k=3)
     # These should not fail
-    ModelCheckpoint(dirpath=tmpdir, save_top_k=None)
     ModelCheckpoint(dirpath=tmpdir, save_top_k=-1)
     ModelCheckpoint(dirpath=tmpdir, save_top_k=0)
+    ModelCheckpoint(dirpath=tmpdir, save_top_k=1)
 
 
 def test_none_monitor_save_last(tmpdir):
tests/deprecated_api/test_remove_1-4.py (0 additions, 16 deletions)
@@ -26,22 +26,6 @@ def test_v1_4_0_deprecated_imports():
         from pytorch_lightning.utilities.argparse_utils import _gpus_arg_default  # noqa: F811 F401
 
 
-def test_v1_4_0_deprecated_checkpoint_on(tmpdir):
-    from pytorch_lightning.callbacks.model_checkpoint import warning_cache
-    warning_cache.clear()
-
-    class TestModel(BoringModel):
-
-        def training_step(self, batch, batch_idx):
-            self.log("val_loss", -batch_idx)
-            return super().training_step(batch, batch_idx)
-
-    trainer = Trainer(default_root_dir=tmpdir, checkpoint_callback=True, max_epochs=1)
-
-    with pytest.deprecated_call(match=r"Relying on.*is deprecated in v1.2 and will be removed in v1.4"):
-        trainer.fit(TestModel())
-
-
 def test_v1_4_0_deprecated_hpc_load(tmpdir):
     model = BoringModel()
     trainer = Trainer(
tests/loggers/test_wandb.py (1 addition, 1 deletion)
@@ -210,7 +210,7 @@ def test_wandb_log_model(wandb, tmpdir):
         'monitor': None,
         'mode': 'min',
         'save_last': None,
-        'save_top_k': None,
+        'save_top_k': 1,
         'save_weights_only': False,
         '_every_n_train_steps': 0,
         '_every_n_val_epochs': 1