Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Feb 24, 2021
1 parent 5ffdbd3 commit dcb14bb
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 34 deletions.
4 changes: 4 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def _format_checkpoint_name(
epoch: int,
step: int,
metrics: Dict[str, Any],
prefix: str = "",
) -> str:
if not filename:
# filename is not set, use default name
Expand All @@ -351,6 +352,9 @@ def _format_checkpoint_name(
metrics[name] = 0
filename = filename.format(**metrics)

if prefix:
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])

return filename

def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None) -> str:
Expand Down
3 changes: 1 addition & 2 deletions tests/accelerators/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def test_unsupported_precision_plugins():
trainer = Mock()
model = Mock()
accelerator = CPUAccelerator(
training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
precision_plugin=MixedPrecisionPlugin()
training_type_plugin=SingleDevicePlugin(torch.device("cpu")), precision_plugin=MixedPrecisionPlugin()
)
with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
accelerator.setup(trainer=trainer, model=model)
4 changes: 2 additions & 2 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,9 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')

# with version
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test')
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name')
ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3)
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
assert ckpt_name == tmpdir / 'name-v3.ckpt'

# using slashes
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}')
Expand Down
34 changes: 8 additions & 26 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,34 +421,17 @@ def test_dp_output_reduce():


@pytest.mark.parametrize(
["save_top_k", "save_last", "file_prefix", "expected_files"],
["save_top_k", "save_last", "expected_files"],
[
pytest.param(
-1,
False,
"",
{"epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt", "epoch=1.ckpt", "epoch=0.ckpt"},
id="CASE K=-1 (all)",
),
pytest.param(1, False, "test_prefix", {"test_prefix-epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"),
pytest.param(2, False, "", {"epoch=4.ckpt", "epoch=2.ckpt"}, id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
pytest.param(
4,
False,
"",
{"epoch=1.ckpt", "epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt"},
id="CASE K=4 (save all 4 base)",
),
pytest.param(
3,
False,
"", {"epoch=2.ckpt", "epoch=3.ckpt", "epoch=4.ckpt"},
id="CASE K=3 (save the 2nd, 3rd, 4th model)"
),
pytest.param(1, True, "", {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"),
pytest.param(-1, False, [f"epoch={i}.ckpt" for i in range(5)], id="CASE K=-1 (all)"),
pytest.param(1, False, {"epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"),
pytest.param(2, False, [f"epoch={i}.ckpt" for i in (2, 4)], id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
pytest.param(4, False, [f"epoch={i}.ckpt" for i in range(1, 5)], id="CASE K=4 (save all 4 base)"),
pytest.param(3, False, [f"epoch={i}.ckpt" for i in range(2, 5)], id="CASE K=3 (save the 2nd, 3rd, 4th model)"),
pytest.param(1, True, {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"),
],
)
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files):
"""Test ModelCheckpoint options."""

def mock_save_function(filepath, *args):
Expand All @@ -463,7 +446,6 @@ def mock_save_function(filepath, *args):
monitor='checkpoint_on',
save_top_k=save_top_k,
save_last=save_last,
prefix=file_prefix,
verbose=1
)
checkpoint_callback.save_function = mock_save_function
Expand Down
8 changes: 4 additions & 4 deletions tests/utilities/test_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def test_lightning_getattr(tmpdir):

for m in models:
with pytest.raises(
AttributeError,
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
AttributeError,
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
):
lightning_getattr(m, "this_attr_not_exist")

Expand All @@ -140,7 +140,7 @@ def test_lightning_setattr(tmpdir):

for m in models:
with pytest.raises(
AttributeError,
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
AttributeError,
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
):
lightning_setattr(m, "this_attr_not_exist", None)

0 comments on commit dcb14bb

Please sign in to comment.