diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a4da285179ed..b507203687f09 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `prefix` argument in `ModelCheckpoint` ([#4765](https://github.com/PyTorchLightning/pytorch-lightning/pull/4765)) ### Removed diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index c2a8c3a6ff859..d41928cd55aea 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -75,6 +75,10 @@ class ModelCheckpoint(Callback): saved (``model.save_weights(filepath)``), else the full model is saved (``model.save(filepath)``). period: Interval (number of epochs) between checkpoints. + prefix: A string to put at the beginning of checkpoint filename. + + .. warning:: + This argument has been deprecated in v1.1 and will be removed in v1.3 dirpath: directory to save the model file. @@ -167,6 +171,12 @@ def __init__( if save_top_k is None and monitor is not None: self.save_top_k = 1 + if prefix: + rank_zero_warn( + 'Argument `prefix` is deprecated in v1.1 and will be removed in v1.3.' + ' Please prepend your prefix in `filename` instead.', DeprecationWarning + ) + self.__init_monitor_mode(monitor, mode) self.__init_ckpt_dir(filepath, dirpath, filename, save_top_k) self.__validate_init_configuration() @@ -380,7 +390,11 @@ def _format_checkpoint_name( if name not in metrics: metrics[name] = 0 filename = filename.format(**metrics) - return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt]) + + if prefix: + filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) + + return filename def format_checkpoint_name( self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index 2956b95313ddd..2369922c31a7c 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -19,6 +19,10 @@ def test_tbd_remove_in_v1_3_0(tmpdir): callback = ModelCheckpoint() Trainer(checkpoint_callback=callback, callbacks=[], default_root_dir=tmpdir) + # Deprecate prefix + with pytest.deprecated_call(match='will be removed in v1.3'): + callback = ModelCheckpoint(prefix='temp') + def test_tbd_remove_in_v1_2_0(): with pytest.deprecated_call(match='will be removed in v1.2'):