From 8b25d8c8f31cc55b5e264e3b0f3b81214e621ad3 Mon Sep 17 00:00:00 2001 From: Matt Painter Date: Tue, 25 Feb 2020 03:33:11 +0000 Subject: [PATCH] Fix/test pass overrides (#918) * Fix test requiring both test_step and test_end * Add test Co-authored-by: William Falcon --- pytorch_lightning/trainer/evaluation_loop.py | 4 +-- tests/test_trainer.py | 34 ++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 37847393a10ccb..db514cdd7451df 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -301,8 +301,8 @@ def evaluate(self, model, dataloaders, max_batches, test=False): def run_evaluation(self, test=False): # when testing make sure user defined a test step - if test and not (self.is_overriden('test_step')): - m = '''You called `.test()` without defining model's `.test_step()`. + if test and not (self.is_overriden('test_step') or self.is_overriden('test_end')): + m = '''You called `.test()` without defining model's `.test_step()` or `.test_end()`. Please define and try again''' raise MisconfigurationException(m) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index bc5b2979de781b..595dc4255392e8 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -782,5 +782,39 @@ def test_trainer_min_steps_and_epochs(tmpdir): trainer.current_epoch > 0, "Model did not train for at least min_steps" +def test_testpass_overrides(tmpdir): + hparams = tutils.get_hparams() + from pytorch_lightning.utilities.debugging import MisconfigurationException + + class TestModelNoEnd(LightningTestModelBase): + def test_step(self, *args, **kwargs): + return {} + + def test_dataloader(self): + return self.train_dataloader() + + class TestModelNoStep(LightningTestModelBase): + def test_end(self, outputs): + return {} + + def test_dataloader(self): + return self.train_dataloader() + + # Misconfig when neither test_step or test_end is implemented + with pytest.raises(MisconfigurationException): + model = LightningTestModelBase(hparams) + Trainer().test(model) + + # No exceptions when one or both of test_step or test_end are implemented + model = TestModelNoStep(hparams) + Trainer().test(model) + + model = TestModelNoEnd(hparams) + Trainer().test(model) + + model = LightningTestModel(hparams) + Trainer().test(model) + + # if __name__ == '__main__': # pytest.main([__file__])