Skip to content

Commit db322f4

Browse files
rohitgr7ananthsubpre-commit-ci[bot]
authored
Deprecate checkpoint_callback from the Trainer constructor in favour of enable_checkpointing (#9754)
* enable_chekpointing * update codebase * chlog * update tests * fix warning * Apply suggestions from code review Co-authored-by: ananthsub <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: ananthsub <[email protected]> * Apply suggestions from code review Co-authored-by: ananthsub <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 14fb076 commit db322f4

33 files changed

+130
-109
lines changed

CHANGELOG.md

+3
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
322322
- Deprecated Accelerator collective API `barrier`, `broadcast`, and `all_gather`, call `TrainingTypePlugin` collective API directly ([#9677](https://github.com/PyTorchLightning/pytorch-lightning/pull/9677))
323323

324324

325+
- Deprecated `checkpoint_callback` from the `Trainer` constructor in favour of `enable_checkpointing` ([#9754](https://github.com/PyTorchLightning/pytorch-lightning/pull/9754))
326+
327+
325328
- Deprecated the `LightningModule.on_post_move_to_device` method ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))
326329

327330

docs/source/common/hyperparameters.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ To recap, add ALL possible trainer flags to the argparser and init the ``Trainer
201201
trainer = Trainer.from_argparse_args(hparams)
202202
203203
# or if you need to pass in callbacks
204-
trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...])
204+
trainer = Trainer.from_argparse_args(hparams, enable_checkpointing=..., callbacks=[...])
205205
206206
----------
207207

docs/source/common/trainer.rst

+35-35
Original file line numberDiff line numberDiff line change
@@ -528,6 +528,38 @@ Example::
528528
checkpoint_callback
529529
^^^^^^^^^^^^^^^^^^^
530530

531+
Deprecated: This has been deprecated in v1.5 and will be removed in v.17. Please use ``enable_checkpointing`` instead.
532+
533+
default_root_dir
534+
^^^^^^^^^^^^^^^^
535+
536+
.. raw:: html
537+
538+
<video width="50%" max-width="400px" controls
539+
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/default%E2%80%A8_root_dir.jpg"
540+
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/default_root_dir.mp4"></video>
541+
542+
|
543+
544+
Default path for logs and weights when no logger or
545+
:class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. On
546+
certain clusters you might want to separate where logs and checkpoints are
547+
stored. If you don't then use this argument for convenience. Paths can be local
548+
paths or remote paths such as `s3://bucket/path` or 'hdfs://path/'. Credentials
549+
will need to be set up to use remote filepaths.
550+
551+
.. testcode::
552+
553+
# default used by the Trainer
554+
trainer = Trainer(default_root_dir=os.getcwd())
555+
556+
distributed_backend
557+
^^^^^^^^^^^^^^^^^^^
558+
Deprecated: This has been renamed ``accelerator``.
559+
560+
enable_checkpointing
561+
^^^^^^^^^^^^^^^^^^^^
562+
531563
.. raw:: html
532564

533565
<video width="50%" max-width="400px" controls
@@ -542,11 +574,11 @@ To disable automatic checkpointing, set this to `False`.
542574

543575
.. code-block:: python
544576
545-
# default used by Trainer
546-
trainer = Trainer(checkpoint_callback=True)
577+
# default used by Trainer, saves the most recent model to a single checkpoint after each epoch
578+
trainer = Trainer(enable_checkpointing=True)
547579
548580
# turn off automatic checkpointing
549-
trainer = Trainer(checkpoint_callback=False)
581+
trainer = Trainer(enable_checkpointing=False)
550582
551583
552584
You can override the default behavior by initializing the :class:`~pytorch_lightning.callbacks.ModelCheckpoint`
@@ -563,38 +595,6 @@ See :doc:`Saving and Loading Weights <../common/weights_loading>` for how to cus
563595
# Add your callback to the callbacks list
564596
trainer = Trainer(callbacks=[checkpoint_callback])
565597

566-
567-
.. warning:: Passing a ModelCheckpoint instance to this argument is deprecated since
568-
v1.1 and will be unsupported from v1.3. Use `callbacks` argument instead.
569-
570-
571-
default_root_dir
572-
^^^^^^^^^^^^^^^^
573-
574-
.. raw:: html
575-
576-
<video width="50%" max-width="400px" controls
577-
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/default%E2%80%A8_root_dir.jpg"
578-
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/default_root_dir.mp4"></video>
579-
580-
|
581-
582-
Default path for logs and weights when no logger or
583-
:class:`pytorch_lightning.callbacks.ModelCheckpoint` callback passed. On
584-
certain clusters you might want to separate where logs and checkpoints are
585-
stored. If you don't then use this argument for convenience. Paths can be local
586-
paths or remote paths such as `s3://bucket/path` or 'hdfs://path/'. Credentials
587-
will need to be set up to use remote filepaths.
588-
589-
.. testcode::
590-
591-
# default used by the Trainer
592-
trainer = Trainer(default_root_dir=os.getcwd())
593-
594-
distributed_backend
595-
^^^^^^^^^^^^^^^^^^^
596-
Deprecated: This has been renamed ``accelerator``.
597-
598598
fast_dev_run
599599
^^^^^^^^^^^^
600600

pytorch_lightning/trainer/connectors/callback_connector.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@ def __init__(self, trainer):
3838
def on_trainer_init(
3939
self,
4040
callbacks: Optional[Union[List[Callback], Callback]],
41-
checkpoint_callback: bool,
41+
checkpoint_callback: Optional[bool],
42+
enable_checkpointing: bool,
4243
enable_progress_bar: bool,
4344
progress_bar_refresh_rate: Optional[int],
4445
process_position: int,
@@ -67,7 +68,7 @@ def on_trainer_init(
6768

6869
# configure checkpoint callback
6970
# pass through the required args to figure out defaults
70-
self._configure_checkpoint_callbacks(checkpoint_callback)
71+
self._configure_checkpoint_callbacks(checkpoint_callback, enable_checkpointing)
7172

7273
# configure swa callback
7374
self._configure_swa_callbacks()
@@ -140,22 +141,31 @@ def _configure_accumulated_gradients(
140141
self.trainer.accumulate_grad_batches = grad_accum_callback.get_accumulate_grad_batches(0)
141142
self.trainer.accumulation_scheduler = grad_accum_callback
142143

143-
def _configure_checkpoint_callbacks(self, checkpoint_callback: bool) -> None:
144+
def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], enable_checkpointing: bool) -> None:
145+
if checkpoint_callback is not None:
146+
rank_zero_deprecation(
147+
f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
148+
f"be removed in v1.7. Please consider using `Trainer(enable_checkpointing={checkpoint_callback})`."
149+
)
150+
# if both are set then checkpoint only if both are True
151+
enable_checkpointing = checkpoint_callback and enable_checkpointing
152+
144153
# TODO: Remove this error in v1.5 so we rely purely on the type signature
145-
if not isinstance(checkpoint_callback, bool):
154+
if not isinstance(enable_checkpointing, bool):
146155
error_msg = (
147-
"Invalid type provided for checkpoint_callback:"
148-
f" Expected bool but received {type(checkpoint_callback)}."
156+
"Invalid type provided for `enable_checkpointing`: "
157+
f"Expected bool but received {type(enable_checkpointing)}."
149158
)
150-
if isinstance(checkpoint_callback, Callback):
159+
if isinstance(enable_checkpointing, Callback):
151160
error_msg += " Pass callback instances to the `callbacks` argument in the Trainer constructor instead."
152161
raise MisconfigurationException(error_msg)
153-
if self._trainer_has_checkpoint_callbacks() and checkpoint_callback is False:
162+
if self._trainer_has_checkpoint_callbacks() and enable_checkpointing is False:
154163
raise MisconfigurationException(
155-
"Trainer was configured with checkpoint_callback=False but found ModelCheckpoint in callbacks list."
164+
"Trainer was configured with `enable_checkpointing=False`"
165+
" but found `ModelCheckpoint` in callbacks list."
156166
)
157167

158-
if not self._trainer_has_checkpoint_callbacks() and checkpoint_callback is True:
168+
if not self._trainer_has_checkpoint_callbacks() and enable_checkpointing is True:
159169
self.trainer.callbacks.append(ModelCheckpoint())
160170

161171
def _configure_model_summary_callback(self, weights_summary: Optional[str] = None) -> None:

pytorch_lightning/trainer/trainer.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ class Trainer(
120120
def __init__(
121121
self,
122122
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
123-
checkpoint_callback: bool = True,
123+
checkpoint_callback: Optional[bool] = None,
124+
enable_checkpointing: bool = True,
124125
callbacks: Optional[Union[List[Callback], Callback]] = None,
125126
default_root_dir: Optional[str] = None,
126127
gradient_clip_val: Union[int, float] = 0.0,
@@ -215,6 +216,12 @@ def __init__(
215216
callbacks: Add a callback or list of callbacks.
216217
217218
checkpoint_callback: If ``True``, enable checkpointing.
219+
220+
.. deprecated:: v1.5
221+
``checkpoint_callback`` has been deprecated in v1.5 and will be removed in v1.7.
222+
Please consider using ``enable_checkpointing`` instead.
223+
224+
enable_checkpointing: If ``True``, enable checkpointing.
218225
It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
219226
:paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.
220227
@@ -465,6 +472,7 @@ def __init__(
465472
self.callback_connector.on_trainer_init(
466473
callbacks,
467474
checkpoint_callback,
475+
enable_checkpointing,
468476
enable_progress_bar,
469477
progress_bar_refresh_rate,
470478
process_position,

tests/accelerators/test_tpu_backend.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_resume_training_on_cpu(tmpdir):
5151
"""Checks if training can be resumed from a saved checkpoint on CPU."""
5252
# Train a model on TPU
5353
model = BoringModel()
54-
trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=8)
54+
trainer = Trainer(max_epochs=1, tpu_cores=8)
5555
trainer.fit(model)
5656

5757
model_path = trainer.checkpoint_callback.best_model_path
@@ -62,9 +62,7 @@ def test_resume_training_on_cpu(tmpdir):
6262
assert weight_tensor.device == torch.device("cpu")
6363

6464
# Verify that training is resumed on CPU
65-
trainer = Trainer(
66-
resume_from_checkpoint=model_path, checkpoint_callback=True, max_epochs=1, default_root_dir=tmpdir
67-
)
65+
trainer = Trainer(resume_from_checkpoint=model_path, max_epochs=1, default_root_dir=tmpdir)
6866
trainer.fit(model)
6967
assert trainer.state.finished, f"Training failed with {trainer.state}"
7068

tests/callbacks/test_callbacks.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def configure_callbacks(self):
3535

3636
model = TestModel()
3737
trainer_options = dict(
38-
default_root_dir=tmpdir, checkpoint_callback=False, fast_dev_run=True, enable_progress_bar=False
38+
default_root_dir=tmpdir, enable_checkpointing=False, fast_dev_run=True, enable_progress_bar=False
3939
)
4040

4141
def assert_expected_calls(_trainer, model_callback, trainer_callback):
@@ -86,7 +86,7 @@ def configure_callbacks(self):
8686
return [model_callback_mock]
8787

8888
model = TestModel()
89-
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
89+
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, enable_checkpointing=False)
9090

9191
callbacks_before_fit = trainer.callbacks.copy()
9292
assert callbacks_before_fit

tests/callbacks/test_early_stopping.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_early_stopping_no_extraneous_invocations(tmpdir):
111111
limit_train_batches=4,
112112
limit_val_batches=4,
113113
max_epochs=expected_count,
114-
checkpoint_callback=False,
114+
enable_checkpointing=False,
115115
)
116116
trainer.fit(model, datamodule=dm)
117117

tests/callbacks/test_lr_monitor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
390390
callbacks=[TestFinetuning(), lr_monitor, Check()],
391391
enable_progress_bar=False,
392392
weights_summary=None,
393-
checkpoint_callback=False,
393+
enable_checkpointing=False,
394394
)
395395
model = TestModel()
396396
model.training_epoch_end = None

tests/callbacks/test_progress_bar.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def on_validation_epoch_end(self, *args):
263263
limit_val_batches=limit_val_batches,
264264
callbacks=[progress_bar],
265265
logger=False,
266-
checkpoint_callback=False,
266+
enable_checkpointing=False,
267267
)
268268
trainer.fit(model)
269269

@@ -342,7 +342,7 @@ def test_main_progress_bar_update_amount(
342342
limit_val_batches=val_batches,
343343
callbacks=[progress_bar],
344344
logger=False,
345-
checkpoint_callback=False,
345+
enable_checkpointing=False,
346346
)
347347
trainer.fit(model)
348348
if train_batches > 0:
@@ -362,7 +362,7 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate
362362
limit_test_batches=test_batches,
363363
callbacks=[progress_bar],
364364
logger=False,
365-
checkpoint_callback=False,
365+
enable_checkpointing=False,
366366
)
367367
trainer.test(model)
368368
progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas])
@@ -379,7 +379,7 @@ def training_step(self, batch, batch_idx):
379379
return super().training_step(batch, batch_idx)
380380

381381
trainer = Trainer(
382-
default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, checkpoint_callback=False
382+
default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False
383383
)
384384
trainer.fit(TestModel())
385385

tests/callbacks/test_pruning.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def train_with_pruning_callback(
109109
default_root_dir=tmpdir,
110110
enable_progress_bar=False,
111111
weights_summary=None,
112-
checkpoint_callback=False,
112+
enable_checkpointing=False,
113113
logger=False,
114114
limit_train_batches=10,
115115
limit_val_batches=2,
@@ -227,7 +227,7 @@ def apply_lottery_ticket_hypothesis(self):
227227
default_root_dir=tmpdir,
228228
enable_progress_bar=False,
229229
weights_summary=None,
230-
checkpoint_callback=False,
230+
enable_checkpointing=False,
231231
logger=False,
232232
limit_train_batches=10,
233233
limit_val_batches=2,
@@ -254,7 +254,7 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent: bool
254254
default_root_dir=tmpdir,
255255
enable_progress_bar=False,
256256
weights_summary=None,
257-
checkpoint_callback=False,
257+
enable_checkpointing=False,
258258
logger=False,
259259
limit_train_batches=10,
260260
limit_val_batches=2,

tests/callbacks/test_quantization.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
5656
assert torch.allclose(org_score, quant_score, atol=0.45)
5757
model_path = trainer.checkpoint_callback.best_model_path
5858

59-
trainer_args.update(dict(max_epochs=1, checkpoint_callback=False))
59+
trainer_args.update(dict(max_epochs=1, enable_checkpointing=False))
6060
if not convert:
6161
trainer = Trainer(callbacks=[QuantizationAwareTraining()], **trainer_args)
6262
trainer.fit(qmodel, datamodule=dm)

tests/checkpointing/test_checkpoint_callback_frequency.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from tests.helpers.runif import RunIf
2323

2424

25-
def test_checkpoint_callback_disabled(tmpdir):
25+
def test_disabled_checkpointing(tmpdir):
2626
# no callback
27-
trainer = Trainer(max_epochs=3, checkpoint_callback=False)
27+
trainer = Trainer(max_epochs=3, enable_checkpointing=False)
2828
assert not trainer.checkpoint_callbacks
2929
trainer.fit(BoringModel())
3030
assert not trainer.checkpoint_callbacks

tests/checkpointing/test_legacy_checkpoints.py

-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,6 @@ def test_resume_legacy_checkpoints(tmpdir, pl_version: str):
7979
default_root_dir=str(tmpdir),
8080
gpus=int(torch.cuda.is_available()),
8181
precision=(16 if torch.cuda.is_available() else 32),
82-
checkpoint_callback=True,
8382
callbacks=[es, stop],
8483
max_epochs=21,
8584
accumulate_grad_batches=2,

tests/checkpointing/test_model_checkpoint.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -998,17 +998,17 @@ def test_configure_model_checkpoint(tmpdir):
998998
callback2 = ModelCheckpoint()
999999

10001000
# no callbacks
1001-
trainer = Trainer(checkpoint_callback=False, callbacks=[], **kwargs)
1001+
trainer = Trainer(enable_checkpointing=False, callbacks=[], **kwargs)
10021002
assert not any(isinstance(c, ModelCheckpoint) for c in trainer.callbacks)
10031003
assert trainer.checkpoint_callback is None
10041004

10051005
# default configuration
1006-
trainer = Trainer(checkpoint_callback=True, callbacks=[], **kwargs)
1006+
trainer = Trainer(callbacks=[], **kwargs)
10071007
assert sum(1 for c in trainer.callbacks if isinstance(c, ModelCheckpoint)) == 1
10081008
assert isinstance(trainer.checkpoint_callback, ModelCheckpoint)
10091009

1010-
# custom callback passed to callbacks list, checkpoint_callback=True is ignored
1011-
trainer = Trainer(checkpoint_callback=True, callbacks=[callback1], **kwargs)
1010+
# custom callback passed to callbacks list, enable_checkpointing=True is ignored
1011+
trainer = Trainer(enable_checkpointing=True, callbacks=[callback1], **kwargs)
10121012
assert [c for c in trainer.callbacks if isinstance(c, ModelCheckpoint)] == [callback1]
10131013
assert trainer.checkpoint_callback == callback1
10141014

@@ -1017,8 +1017,8 @@ def test_configure_model_checkpoint(tmpdir):
10171017
assert trainer.checkpoint_callback == callback1
10181018
assert trainer.checkpoint_callbacks == [callback1, callback2]
10191019

1020-
with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
1021-
Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)
1020+
with pytest.raises(MisconfigurationException, match="`enable_checkpointing=False` but found `ModelCheckpoint`"):
1021+
Trainer(enable_checkpointing=False, callbacks=[callback1], **kwargs)
10221022

10231023

10241024
def test_val_check_interval_checkpoint_files(tmpdir):
@@ -1189,8 +1189,8 @@ def test_model_checkpoint_mode_options():
11891189

11901190
def test_trainer_checkpoint_callback_bool(tmpdir):
11911191
mc = ModelCheckpoint(dirpath=tmpdir)
1192-
with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"):
1193-
Trainer(checkpoint_callback=mc)
1192+
with pytest.raises(MisconfigurationException, match="Invalid type provided for `enable_checkpointing`"):
1193+
Trainer(enable_checkpointing=mc)
11941194

11951195

11961196
def test_check_val_every_n_epochs_top_k_integration(tmpdir):

0 commit comments

Comments
 (0)