Skip to content

Commit

Permalink
Change error to warning if state_dict is empty in `load_from_checkpoi…
Browse files Browse the repository at this point in the history
…nt` (#18266)

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 10, 2023
1 parent c83774a commit e24620c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/lightning/pytorch/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def _load_from_checkpoint(
model = _load_state(cls, checkpoint, strict=strict, **kwargs)
state_dict = checkpoint["state_dict"]
if not state_dict:
raise ValueError(f"The state dict in {checkpoint_path!r} contains no parameters.")
rank_zero_warn(f"The state dict in {checkpoint_path!r} contains no parameters.")
return model

device = next((t for t in state_dict.values() if isinstance(t, torch.Tensor)), torch.tensor(0)).device
assert isinstance(model, pl.LightningModule)
Expand Down
13 changes: 13 additions & 0 deletions tests/tests_pytorch/core/test_saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,16 @@ def set_extra_state(self, state):
create_boring_checkpoint(tmp_path, ExtraStateModel(), accelerator="cuda")
model = ExtraStateModel.load_from_checkpoint(f"{tmp_path}/checkpoint.ckpt", map_location=None)
assert model.device.type == "cuda"


def test_load_from_checkpoint_warn_on_empty_state_dict(tmp_path):
"""Test that checkpoints can be loaded with an empty state dict and that the appropriate warning is raised."""
create_boring_checkpoint(tmp_path, BoringModel(), accelerator="cpu")
# Now edit so the state_dict is empty
checkpoint = torch.load(tmp_path / "checkpoint.ckpt")
checkpoint["state_dict"] = {}
torch.save(checkpoint, tmp_path / "checkpoint.ckpt")

with pytest.warns(UserWarning, match="contains no parameters"):
model = BoringModel.load_from_checkpoint(tmp_path / "checkpoint.ckpt", strict=False)
assert model.device.type == "cpu"

0 comments on commit e24620c

Please sign in to comment.