Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate save_function from model checkpoint callback #7201

Merged
merged 7 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated the `save_function` property from the `ModelCheckpoint` callback ([#7201](https://github.com/PyTorchLightning/pytorch-lightning/pull/7201))


- Deprecated `LightningModule.write_predictions` and `LigtningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))


Expand Down Expand Up @@ -190,6 +193,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Removed


- Removed `automatic_optimization` as a property from the training loop in favor of `LightningModule.automatic_optimization` ([#7130](https://github.com/PyTorchLightning/pytorch-lightning/pull/7130))


Expand Down
33 changes: 23 additions & 10 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

import numpy as np
import torch
Expand Down Expand Up @@ -201,19 +201,19 @@ def __init__(
self.best_model_score = None
self.best_model_path = ""
self.last_model_path = ""
self.save_function = None

self.__init_monitor_mode(monitor, mode)
self.__init_ckpt_dir(dirpath, filename, save_top_k)
self.__init_triggers(every_n_train_steps, every_n_val_epochs, period)
self.__validate_init_configuration()
self._save_function = None

def on_pretrain_routine_start(self, trainer, pl_module):
"""
When pretrain routine starts we build the ckpt dir on the fly
"""
self.__resolve_ckpt_dir(trainer)
self.save_function = trainer.save_checkpoint
ananthsub marked this conversation as resolved.
Show resolved Hide resolved
self._save_function = trainer.save_checkpoint

def on_train_batch_end(
self, trainer, pl_module, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int
Expand Down Expand Up @@ -254,9 +254,9 @@ def on_load_checkpoint(self, callback_state: Dict[str, Any]):

def save_checkpoint(self, trainer, unused: Optional = None):
"""
Performs the main logic around saving a checkpoint.
This method runs on all ranks, it is the responsibility of `self.save_function`
to handle correct behaviour in distributed training, i.e., saving only on rank 0.
Performs the main logic around saving a checkpoint. This method runs on all ranks.
It is the responsibility of `trainer.save_checkpoint` to correctly handle the behaviour in distributed training,
i.e., saving only on rank 0 for data parallel use cases.
"""
if unused is not None:
rank_zero_deprecation(
Expand Down Expand Up @@ -396,6 +396,22 @@ def period(self, value: Optional[int]) -> None:
)
self._period = value

@property
def save_function(self) -> Optional[Callable]:
rank_zero_deprecation(
'Property `save_function` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `trainer.save_checkpoint` instead.'
)
return self._save_function

@save_function.setter
def save_function(self, value: Optional[Callable]) -> None:
rank_zero_deprecation(
'Property `save_function` in `ModelCheckpoint` is deprecated in v1.3 and will be removed in v1.5.'
' Please use `trainer.save_checkpoint` instead.'
)
self._save_function = value

@rank_zero_only
def _del_model(self, filepath: str):
if self._fs.exists(filepath):
Expand All @@ -420,10 +436,7 @@ def _do_save(self, trainer, filepath: str):
self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)

# delegate the saving to the trainer
if self.save_function is not None:
self.save_function(filepath, self.save_weights_only)
else:
raise ValueError(".save_function() not set")
trainer.save_checkpoint(filepath, self.save_weights_only)

def check_monitor_top_k(self, trainer, current: Optional[torch.Tensor] = None) -> bool:
if current is None:
Expand Down
13 changes: 11 additions & 2 deletions tests/deprecated_api/test_remove_1-5.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,18 @@

def test_v1_5_0_model_checkpoint_save_checkpoint():
model_ckpt = ModelCheckpoint()
model_ckpt.save_function = lambda *_, **__: None
trainer = Trainer()
trainer.save_checkpoint = lambda *_, **__: None
with pytest.deprecated_call(match="ModelCheckpoint.save_checkpoint` signature has changed"):
model_ckpt.save_checkpoint(Trainer(), object())
model_ckpt.save_checkpoint(trainer, object())


def test_v1_5_0_model_checkpoint_save_function():
model_ckpt = ModelCheckpoint()
with pytest.deprecated_call(match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"):
model_ckpt.save_function = lambda *_, **__: None
with pytest.deprecated_call(match="Property `save_function` in `ModelCheckpoint` is deprecated in v1.3"):
_ = model_ckpt.save_function


@mock.patch('pytorch_lightning.loggers.wandb.wandb')
Expand Down
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,9 @@ def mock_save_function(filepath, *args):
save_last=save_last,
verbose=True
)
checkpoint_callback.save_function = mock_save_function
trainer = Trainer()
trainer.state = TrainerState.FITTING
trainer.save_checkpoint = mock_save_function

# emulate callback's calls during the training
for i, loss in enumerate(losses):
Expand Down