Raise a MisconfigurationException when trainer functions are called with ckpt_path="best" but checkpoint_callback isn't configured #9841

Merged 5 commits on Oct 12, 2021
Changes from 2 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -172,6 +172,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled automatic parameters tying for TPUs ([#9525](https://github.com/PyTorchLightning/pytorch-lightning/pull/9525))


- Added a check to look for `checkpoint_callback` if `ckpt_path="best"` ([#9841](https://github.com/PyTorchLightning/pytorch-lightning/pull/9841))
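A minimal sketch of the check this entry adds, assuming `BoringModel` from the repository's test helpers (`tests.helpers`, an assumption) and the `checkpoint_callback=False` Trainer flag used in the tests below:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel  # assumption: PL's internal test helper

model = BoringModel()
trainer = Trainer(max_epochs=1, checkpoint_callback=False)  # no ModelCheckpoint attached
trainer.fit(model)

try:
    trainer.validate(model, ckpt_path="best")
except MisconfigurationException as err:
    # `.validate(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.
    print(err)
```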


### Changed

- `pytorch_lightning.loggers.neptune.NeptuneLogger` is now consistent with new [neptune-client](https://github.com/neptune-ai/neptune-client) API ([#6867](https://github.com/PyTorchLightning/pytorch-lightning/pull/6867)).
22 changes: 15 additions & 7 deletions pytorch_lightning/trainer/trainer.py
@@ -651,7 +651,8 @@ def validate(

ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
If ``None`` and the model instance was passed, use the current weights.
-    Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
+    Otherwise, the best model from the previous ``trainer.fit`` call will be loaded
+    if a checkpoint callback is configured.

verbose: If True, prints the validation results.
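In practice, the three `ckpt_path` modes described by this docstring look like the following, a sketch assuming `BoringModel` from the repository's test helpers; `test` and `predict` below accept `ckpt_path` identically:

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # assumption: PL's internal test helper

model = BoringModel()
trainer = Trainer(max_epochs=1)  # the default Trainer attaches a ModelCheckpoint
trainer.fit(model)

# `None` with a model instance: validate the current in-memory weights.
trainer.validate(model, ckpt_path=None)

# "best": load the best checkpoint tracked by ModelCheckpoint during `fit`.
trainer.validate(ckpt_path="best")

# Explicit path: load exactly this checkpoint file.
trainer.validate(ckpt_path=trainer.checkpoint_callback.best_model_path)
```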

@@ -740,7 +741,8 @@ def test(

ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
If ``None`` and the model instance was passed, use the current weights.
-    Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
+    Otherwise, the best model from the previous ``trainer.fit`` call will be loaded
+    if a checkpoint callback is configured.

verbose: If True, prints the test results.

@@ -834,7 +836,8 @@ def predict(

ckpt_path: Either ``best`` or path to the checkpoint you wish to predict.
If ``None`` and the model instance was passed, use the current weights.
-    Otherwise, the best model from the previous ``trainer.fit`` call will be loaded.
+    Otherwise, the best model from the previous ``trainer.fit`` call will be loaded
+    if a checkpoint callback is configured.

Returns:
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
@@ -1258,15 +1261,20 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_

        if model_connected and ckpt_path is None:
            rank_zero_warn(
-                f"`.{fn}(ckpt_path=None)` was called without a model. "
-                "The best model of the previous `fit` call will be used. "
-                f"You can pass `{fn}(ckpt_path='best')` to avoid this warning "
-                "or `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
+                f"`.{fn}(ckpt_path=None)` was called without a model."
+                " The best model of the previous `fit` call will be used."
+                f" You can pass `{fn}(ckpt_path='best')` to use the best model"
+                " checkpoint and avoid this warning or"
+                " `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
            )
            ckpt_path = "best"

        if ckpt_path == "best":
            # if user requests the best checkpoint but we don't have it, error
+            if self.checkpoint_callback is None:
+                raise MisconfigurationException(
+                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
+                )
            if not self.checkpoint_callback.best_model_path:
                if self.fast_dev_run:
                    raise MisconfigurationException(
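Combined, the two changes in this hunk alter behavior at the call sites roughly as follows, a sketch under the same assumptions as above. Previously the `"best"` fallback fell through to `self.checkpoint_callback.best_model_path` and surfaced as an `AttributeError` when no callback was attached:

```python
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # assumption: PL's internal test helper

# Warning path: with the default ModelCheckpoint attached, calling an
# evaluation function with neither a model nor a ckpt_path warns and
# falls back to ckpt_path="best".
trainer = Trainer(max_epochs=1)
trainer.fit(BoringModel())
trainer.validate()  # warns: `.validate(ckpt_path=None)` was called without a model. ...

# Error path: with no ModelCheckpoint, requesting "best" now raises
# MisconfigurationException up front.
trainer = Trainer(max_epochs=1, checkpoint_callback=False)
trainer.fit(BoringModel())
trainer.validate(ckpt_path="best")  # raises MisconfigurationException
```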
44 changes: 44 additions & 0 deletions tests/trainer/test_trainer.py
@@ -785,6 +785,50 @@ def predict_step(self, batch, *_):
assert getattr(trainer, path_attr) == ckpt_path


@pytest.mark.parametrize("checkpoint_callback", (False, True))
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
def test_tested_checkpoint_path_best(tmpdir, checkpoint_callback, fn):
    class TestModel(BoringModel):
        def validation_step(self, batch, batch_idx):
            self.log("foo", -batch_idx)
            return super().validation_step(batch, batch_idx)

        def test_step(self, *args):
            return self.validation_step(*args)

        def predict_step(self, batch, *_):
            return self(batch)

    model = TestModel()
    model.test_epoch_end = None
    trainer = Trainer(
        max_epochs=2,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        enable_progress_bar=False,
        default_root_dir=tmpdir,
        checkpoint_callback=checkpoint_callback,
    )
    trainer.fit(model)

    trainer_fn = getattr(trainer, fn)
    path_attr = f"{fn}{'d' if fn == 'validate' else 'ed'}_ckpt_path"
    assert getattr(trainer, path_attr) is None

    if checkpoint_callback:
        trainer_fn(ckpt_path="best")
        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path

        trainer_fn(model, ckpt_path="best")
        assert getattr(trainer, path_attr) == trainer.checkpoint_callback.best_model_path
    else:
        with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
            trainer_fn(ckpt_path="best")
        with pytest.raises(MisconfigurationException, match="`ModelCheckpoint` is not configured."):
            trainer_fn(model, ckpt_path="best")

def test_disabled_training(tmpdir):
    """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""
