Add a first batch of tests for Trainer.validate(…)
EliaCereda committed Nov 18, 2020
1 parent 860fef5 commit 3b5ae9b
Showing 6 changed files with 232 additions and 0 deletions.
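
For orientation, the call patterns these tests exercise look roughly like the sketch below. It is distilled from the diffs that follow and is not part of the commit itself; it assumes `Trainer.validate(…)` mirrors the existing `Trainer.test(…)` signature.

    import pytorch_lightning as pl
    from tests.base import EvalModelTemplate

    model = EvalModelTemplate()
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
    trainer.fit(model)

    # run one validation epoch on the in-memory weights; returns a list of metric dicts
    results = trainer.validate(model)

    # or reload a checkpoint first: 'best', a specific path, or None for the current weights
    results = trainer.validate(ckpt_path='best')

    # dataloaders can also be passed explicitly
    results = trainer.validate(model, val_dataloaders=model.dataloader(train=False))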
20 changes: 20 additions & 0 deletions tests/trainer/test_config_validator.py
@@ -92,3 +92,23 @@ def test_test_loop_config(tmpdir):
    model = EvalModelTemplate(**hparams)
    model.test_step = None
    trainer.test(model, test_dataloaders=model.dataloader(train=False))


def test_validation_loop_config(tmpdir):
    """
    Test that a warning is raised when either the validation loop or the validation data is missing.
    """
    hparams = EvalModelTemplate.get_default_hparams()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)

    # has val loop but no val data
    with pytest.warns(UserWarning):
        model = EvalModelTemplate(**hparams)
        model.val_dataloader = None
        trainer.validate(model)

    # has val data but no val loop
    with pytest.warns(UserWarning):
        model = EvalModelTemplate(**hparams)
        model.validation_step = None
        trainer.validate(model, val_dataloaders=model.dataloader(train=False))
39 changes: 39 additions & 0 deletions tests/trainer/test_dataloaders.py
@@ -170,6 +170,45 @@ def test_step(self, batch, batch_idx, *args, **kwargs):
    trainer.test(ckpt_path=ckpt_path)


@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
def test_multiple_validate_dataloader(tmpdir, ckpt_path):
    """Verify multiple val_dataloaders."""

    model_template = EvalModelTemplate()

    class MultipleValDataloaderModel(EvalModelTemplate):
        def val_dataloader(self):
            return model_template.val_dataloader__multiple()

        def validation_step(self, batch, batch_idx, *args, **kwargs):
            return model_template.validation_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs)

    model = MultipleValDataloaderModel()

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=0.1,
        limit_train_batches=0.2,
    )
    trainer.fit(model)
    if ckpt_path == 'specific':
        ckpt_path = trainer.checkpoint_callback.best_model_path
    trainer.validate(ckpt_path=ckpt_path)

    # verify there are 2 val loaders
    assert len(trainer.val_dataloaders) == 2, \
        'Multiple val_dataloaders not initiated properly'

    # make sure predictions are good for each val set
    for dataloader in trainer.val_dataloaders:
        tpipes.run_prediction(dataloader, trainer.model)

    # run the validate method again
    trainer.validate(ckpt_path=ckpt_path)


def test_train_dataloader_passed_to_fit(tmpdir):
    """Verify that the train dataloader can be passed to fit."""

18 changes: 18 additions & 0 deletions tests/trainer/test_optimizers.py
@@ -335,6 +335,24 @@ def test_init_optimizers_during_testing(tmpdir):
    assert len(trainer.optimizer_frequencies) == 0


def test_init_optimizers_during_validation(tmpdir):
    """
    Test that lr_schedulers, optimizers, and optimizer_frequencies are empty during validation.
    """
    model = EvalModelTemplate()
    model.configure_optimizers = model.configure_optimizers__multiple_schedulers

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_val_batches=10,
    )
    trainer.validate(model, ckpt_path=None)

    assert len(trainer.lr_schedulers) == 0
    assert len(trainer.optimizers) == 0
    assert len(trainer.optimizer_frequencies) == 0


def test_multiple_optimizers_callbacks(tmpdir):
    """
    Tests that multiple optimizers can be used with callbacks
34 changes: 34 additions & 0 deletions tests/trainer/test_states.py
@@ -191,6 +191,40 @@ def test_finished_state_after_test(tmpdir):
    assert trainer.state == TrainerState.FINISHED


def test_running_state_during_validation(tmpdir):
    """ Tests that state is set to RUNNING during validation """

    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    snapshot_callback = StateSnapshotCallback(snapshot_method='on_validation_batch_start')

    trainer = Trainer(
        callbacks=[snapshot_callback],
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )

    trainer.validate(model)

    assert snapshot_callback.trainer_state == TrainerState.RUNNING


def test_finished_state_after_validation(tmpdir):
    """ Tests that state is FINISHED after validation """
    hparams = EvalModelTemplate.get_default_hparams()
    model = EvalModelTemplate(**hparams)

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
    )

    trainer.validate(model)

    assert trainer.state == TrainerState.FINISHED


@pytest.mark.parametrize("extra_params", [
    pytest.param(dict(fast_dev_run=True), id='Fast-Run'),
    pytest.param(dict(max_steps=1), id='Single-Step'),
45 changes: 45 additions & 0 deletions tests/trainer/test_trainer.py
@@ -747,6 +747,45 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
    assert trainer.tested_ckpt_path == ckpt_path


@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"])
@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
def test_validate_checkpoint_path(tmpdir, ckpt_path, save_top_k):
    hparams = EvalModelTemplate.get_default_hparams()

    model = EvalModelTemplate(**hparams)
    trainer = Trainer(
        max_epochs=2,
        progress_bar_refresh_rate=0,
        default_root_dir=tmpdir,
        checkpoint_callback=ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k),
    )
    trainer.fit(model)
    if ckpt_path == "best":
        # ckpt_path is 'best', meaning we load the best weights
        if save_top_k == 0:
            with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"):
                trainer.validate(ckpt_path=ckpt_path)
        else:
            trainer.validate(ckpt_path=ckpt_path)
            assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path
    elif ckpt_path is None:
        # ckpt_path is None, meaning we don't load any checkpoints and
        # use the weights from the end of training
        trainer.validate(ckpt_path=ckpt_path)
        assert trainer.tested_ckpt_path is None
    else:
        # specific checkpoint, pick one from the saved ones
        if save_top_k == 0:
            with pytest.raises(FileNotFoundError):
                trainer.validate(ckpt_path="random.ckpt")
        else:
            ckpt_path = str(
                list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir())[
                    0
                ].absolute()
            )
            trainer.validate(ckpt_path=ckpt_path)
            assert trainer.tested_ckpt_path == ckpt_path


def test_disabled_training(tmpdir):
    """Verify that `limit_train_batches=0` disables the training loop unless `fast_dev_run=True`."""

@@ -1448,6 +1489,10 @@ def setup(self, model, stage):
    assert trainer.stage == "test"
    assert trainer.get_model().stage == "test"

    trainer.validate(ckpt_path=None)
    assert trainer.stage == "validation"
    assert trainer.get_model().stage == "validation"


@pytest.mark.parametrize(
"train_batches, max_steps, log_interval",
76 changes: 76 additions & 0 deletions tests/trainer/test_trainer_validate_loop.py
@@ -0,0 +1,76 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch

import pytorch_lightning as pl
import tests.base.develop_utils as tutils
from tests.base import EvalModelTemplate


@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
def test_single_gpu_validate(tmpdir):
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0],
    )
    trainer.fit(model)
    assert 'ckpt' in trainer.checkpoint_callback.best_model_path
    results = trainer.validate()
    assert 'val_acc' in results[0]

    old_weights = model.c_d1.weight.clone().detach().cpu()

    results = trainer.validate(model)
    assert 'val_acc' in results[0]

    # make sure weights didn't change
    new_weights = model.c_d1.weight.clone().detach().cpu()

    assert torch.all(torch.eq(old_weights, new_weights))


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_ddp_spawn_validate(tmpdir):
    tutils.set_random_master_port()

    model = EvalModelTemplate()
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        distributed_backend='ddp_spawn',
    )
    trainer.fit(model)
    assert 'ckpt' in trainer.checkpoint_callback.best_model_path
    results = trainer.validate()
    assert 'val_acc' in results[0]

    old_weights = model.c_d1.weight.clone().detach().cpu()

    results = trainer.validate(model)
    assert 'val_acc' in results[0]

    # make sure weights didn't change
    new_weights = model.c_d1.weight.clone().detach().cpu()

    assert torch.all(torch.eq(old_weights, new_weights))
