diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c17cdc06cc19..bd60b96c8f106 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -70,6 +70,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `Pytorch Geometric` integration example with Lightning ([#4568](https://github.com/PyTorchLightning/pytorch-lightning/pull/4568)) +- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ( + [#4707](https://github.com/PyTorchLightning/pytorch-lightning/pull/4707)) + + ### Changed - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903)) diff --git a/docs/source/trainer.rst b/docs/source/trainer.rst index c390db8d7537e..04ac191458c1e 100644 --- a/docs/source/trainer.rst +++ b/docs/source/trainer.rst @@ -148,6 +148,19 @@ So you can run it like so: ------------ +Validation +---------- +You can perform an evaluation epoch over the validation set, outside of the training loop, +using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be +useful if you want to collect new metrics from a model right at its initialization +or that has already been trained. + +.. code-block:: python + + trainer.validate(val_dataloaders=val_dataloaders) + +------------ + Testing ------- Once you're done training, feel free to run the test set! @@ -155,7 +168,7 @@ Once you're done training, feel free to run the test set! .. code-block:: python - trainer.test(test_dataloader=test_dataloader) + trainer.test(test_dataloaders=test_dataloaders) ------------ diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 931a39e07af89..75e46dbce83dd 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -59,9 +59,9 @@ def barrier(self, name: Optional[str] = None): def broadcast(self, obj, src=0): return obj - def train_or_test(self): - if self.trainer.testing: - results = self.trainer.run_test() + def train_or_evaluate(self): + if self.trainer.evaluating: + results = self.trainer.run_test_or_validate() else: results = self.trainer.train() return results @@ -160,7 +160,7 @@ def early_stopping_should_stop(self, pl_module): return self.trainer.should_stop def setup_optimizers(self, model): - if self.trainer.testing is True: + if self.trainer.evaluating: return optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model) diff --git a/pytorch_lightning/accelerators/cpu_accelerator.py b/pytorch_lightning/accelerators/cpu_accelerator.py index fe0ab59fb554f..279b6327bba5a 100644 --- a/pytorch_lightning/accelerators/cpu_accelerator.py +++ b/pytorch_lightning/accelerators/cpu_accelerator.py @@ -57,8 +57,8 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() return results def training_step(self, args): diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py index f43866881cabb..0acc5d6b65339 100644 --- a/pytorch_lightning/accelerators/ddp2_accelerator.py +++ b/pytorch_lightning/accelerators/ddp2_accelerator.py @@ -181,8 +181,8 @@ def ddp_train(self, process_idx, mp_queue, model): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = 
self.train_or_evaluate() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py index 687b5c21874fb..90347a60a4566 100644 --- a/pytorch_lightning/accelerators/ddp_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_accelerator.py @@ -275,8 +275,8 @@ def ddp_train(self, process_idx, model): self.barrier('ddp_setup') self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py index 982da2f53216b..879ad3cdb8b74 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py @@ -145,8 +145,8 @@ def ddp_train(self, process_idx, mp_queue, model): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py index 28817c6845f5b..316fac61ca732 100644 --- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py @@ -174,8 +174,8 @@ def ddp_train(self, process_idx, model): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # clean up memory torch.cuda.empty_cache() diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py index a06d0b82d6d15..b871f6cbf0c6d 100644 --- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py +++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py @@ -157,8 +157,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0 # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # get original model model = self.trainer.get_model() diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py index 4b4e1eac8a66c..214b4d88f03aa 100644 --- a/pytorch_lightning/accelerators/dp_accelerator.py +++ b/pytorch_lightning/accelerators/dp_accelerator.py @@ -106,8 +106,8 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() return results diff --git a/pytorch_lightning/accelerators/gpu_accelerator.py b/pytorch_lightning/accelerators/gpu_accelerator.py index b12d275c8ac26..e3f0fb9890809 100644 --- a/pytorch_lightning/accelerators/gpu_accelerator.py +++ b/pytorch_lightning/accelerators/gpu_accelerator.py @@ -62,8 +62,9 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() + return results def training_step(self, args): diff --git a/pytorch_lightning/accelerators/horovod_accelerator.py 
b/pytorch_lightning/accelerators/horovod_accelerator.py index b2cec906178f9..d4027c772e061 100644 --- a/pytorch_lightning/accelerators/horovod_accelerator.py +++ b/pytorch_lightning/accelerators/horovod_accelerator.py @@ -111,8 +111,8 @@ def train(self): # set up training routine self.trainer.train_loop.setup_training(self.trainer.model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # Make sure all workers have finished training before returning to the user hvd.join() diff --git a/pytorch_lightning/accelerators/tpu_accelerator.py b/pytorch_lightning/accelerators/tpu_accelerator.py index 30cf6c9dbf169..303066c5e5310 100644 --- a/pytorch_lightning/accelerators/tpu_accelerator.py +++ b/pytorch_lightning/accelerators/tpu_accelerator.py @@ -129,8 +129,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine # set up training routine self.trainer.train_loop.setup_training(model) - # train or test - results = self.train_or_test() + # train or evaluate + results = self.train_or_evaluate() # save weights at the end of training self.__save_end_of_training_weights(model, trainer) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 3f6b4ffe9622a..8ca0ef301c260 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -28,11 +28,11 @@ class Callback(abc.ABC): """ def setup(self, trainer, pl_module, stage: str): - """Called when fit or test begins""" + """Called when fit, validate, or test begins""" pass def teardown(self, trainer, pl_module, stage: str): - """Called when fit or test ends""" + """Called when fit, validate, or test ends""" pass def on_init_start(self, trainer): diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 005a3f8cde4ad..3a2b5c2a57259 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -134,13 +134,13 @@ def on_load_checkpoint(self, checkpointed_state): self.patience = checkpointed_state['patience'] def on_validation_end(self, trainer, pl_module): - if trainer.running_sanity_check: + if trainer.running_sanity_check or trainer.evaluating: return self._run_early_stopping_check(trainer, pl_module) def on_validation_epoch_end(self, trainer, pl_module): - if trainer.running_sanity_check: + if trainer.running_sanity_check or trainer.evaluating: return if self._validate_condition_metric(trainer.logger_connector.callback_metrics): diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index d41928cd55aea..0efaef9c660b7 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -220,6 +220,7 @@ def save_checkpoint(self, trainer, pl_module): or self.period < 1 # no models are saved or (epoch + 1) % self.period # skip epoch or trainer.running_sanity_check # don't save anything during sanity check + or trainer.evaluating # don't save anything during evaluation: might delete the checkpoint being evaluated or self.last_global_step_saved == global_step # already saved at the last step ): return diff --git a/pytorch_lightning/callbacks/progress.py b/pytorch_lightning/callbacks/progress.py index 6582f16fd27be..b00dca548671f 100644 --- a/pytorch_lightning/callbacks/progress.py +++ b/pytorch_lightning/callbacks/progress.py @@ -282,9 +282,13 @@ def init_train_tqdm(self) -> tqdm: def 
init_validation_tqdm(self) -> tqdm: """ Override this to customize the tqdm bar for validation. """ + + # The main progress bar doesn't exist in trainer.validate(...) + has_main_bar = int(self.main_progress_bar is not None) + bar = tqdm( desc='Validating', - position=(2 * self.process_position + 1), + position=(2 * self.process_position + has_main_bar), disable=self.is_disabled, leave=False, dynamic_ncols=True, @@ -341,7 +345,10 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) if not trainer.running_sanity_check: - self._update_bar(self.main_progress_bar) # fill up remaining + # The main progress bar doesn't exist in trainer.validate(...) + if self.main_progress_bar is not None: + self._update_bar(self.main_progress_bar) # fill up remaining + self.val_progress_bar = self.init_validation_tqdm() self.val_progress_bar.total = convert_inf(self.total_val_batches) @@ -349,11 +356,18 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if self._should_update(self.val_batch_idx, self.total_val_batches): self._update_bar(self.val_progress_bar) - self._update_bar(self.main_progress_bar) + + # The main progress bar doesn't exist in trainer.validate(...) + if self.main_progress_bar is not None: + self._update_bar(self.main_progress_bar) def on_validation_end(self, trainer, pl_module): super().on_validation_end(trainer, pl_module) - self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + + # The main progress bar doesn't exist in trainer.validate(...) + if self.main_progress_bar is not None: + self.main_progress_bar.set_postfix(trainer.progress_bar_dict) + self.val_progress_bar.close() def on_train_end(self, trainer, pl_module): diff --git a/pytorch_lightning/core/datamodule.py b/pytorch_lightning/core/datamodule.py index fe81d641c86d6..3ff9f4cf889d4 100644 --- a/pytorch_lightning/core/datamodule.py +++ b/pytorch_lightning/core/datamodule.py @@ -76,13 +76,16 @@ def wrapped_fn(*args, **kwargs): if fn.__name__ == "setup": # Get stage either by grabbing from args or checking kwargs. - # If not provided, set call status of 'fit' and 'test' to True. + # If not provided, set call status of 'fit', 'validation', and 'test' to True. # We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test() stage = args[1] if len(args) > 1 else kwargs.get("stage", None) if stage == "fit" or stage is None: obj._has_setup_fit = True + if stage == "validation" or stage is None: + obj._has_setup_validation = True + if stage == "test" or stage is None: obj._has_setup_test = True @@ -155,6 +158,7 @@ def __init__( # Private attrs to keep track of whether or not data hooks have been called yet self._has_prepared_data = False self._has_setup_fit = False + self._has_setup_validation = False self._has_setup_test = False @property @@ -230,6 +234,15 @@ def has_setup_fit(self): """ return self._has_setup_fit + @property + def has_setup_validation(self): + """Return bool letting you know if datamodule.setup('validation') has been called or not. + + Returns: + bool: True if datamodule.setup('validation') has been called. False by default. + """ + return self._has_setup_validation + @property def has_setup_test(self): """Return bool letting you know if datamodule.setup('test') has been called or not. 
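The datamodule changes above add a 'validation' entry to the setup-stage bookkeeping that previously only distinguished 'fit' from 'test'. A minimal sketch (not part of the diff; the class name, tensor shapes, and split sizes are illustrative only) of how a user-defined ``LightningDataModule`` might branch on the three stages under this scheme:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split
    import pytorch_lightning as pl

    class SketchDataModule(pl.LightningDataModule):
        def setup(self, stage=None):
            # 'fit' and 'validation' both need the train/val split; only 'test'
            # needs the held-out set, mirroring the `stage != 'test'` pattern
            # used in tests/base/datamodules.py in this diff.
            if stage != 'test':
                full = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
                self.train_set, self.val_set = random_split(full, [48, 16])
            if stage == 'test' or stage is None:
                self.test_set = TensorDataset(torch.randn(16, 32), torch.randint(0, 2, (16,)))

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=8)

        def val_dataloader(self):
            return DataLoader(self.val_set, batch_size=8)

        def test_dataloader(self):
            return DataLoader(self.test_set, batch_size=8)

With the decorator change above, calling ``SketchDataModule().setup('validation')`` would set ``has_setup_validation`` to True without touching the 'fit' or 'test' flags.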
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 57979b73f2cb6..a4251484991f2 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -26,12 +26,12 @@ class ModelHooks: """Hooks to be used in LightningModule.""" def setup(self, stage: str): """ - Called at the beginning of fit and test. + Called at the beginning of fit (training + validation), validation, and test. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. Args: - stage: either 'fit' or 'test' + stage: either 'fit', 'validation', or 'test' Example:: @@ -54,10 +54,10 @@ def setup(stage): def teardown(self, stage: str): """ - Called at the end of fit and test. + Called at the end of fit (training + validation), validation, and test. Args: - stage: either 'fit' or 'test' + stage: either 'fit', 'validation', or 'test' """ def on_fit_start(self): diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 01c0119e857ec..23967dc1bc2a9 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -31,12 +31,12 @@ def verify_loop_configurations(self, model: LightningModule): model: The model to check the configuration. """ - if not self.trainer.testing: + if not self.trainer.evaluating: self.__verify_train_loop_configuration(model) self.__verify_eval_loop_configuration(model, 'validation') else: - # check test loop configuration - self.__verify_eval_loop_configuration(model, 'test') + # check evaluation loop configurations + self.__verify_eval_loop_configuration(model, self.trainer.evaluating) def __verify_train_loop_configuration(self, model): # ----------------------------------- diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index cab08edd58531..33ff30380eabb 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -265,7 +265,7 @@ def prepare_eval_loop_results(self): for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders): self.add_to_eval_loop_results(dl_idx, has_been_initialized) - def get_evaluate_epoch_results(self, test_mode): + def get_evaluate_epoch_results(self): if not self.trainer.running_sanity_check: # log all the metrics as a single dict metrics_to_log = self.cached_results.get_epoch_log_metrics() @@ -274,11 +274,11 @@ def get_evaluate_epoch_results(self, test_mode): self.prepare_eval_loop_results() - # log results of test - if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test: + # log results of evaluation + if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate: print('-' * 80) for result_idx, results in enumerate(self.eval_loop_results): - print(f'DATALOADER:{result_idx} TEST RESULTS') + print(f'DATALOADER:{result_idx} {self.trainer.evaluating.upper()} RESULTS') pprint(results) print('-' * 80) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index c5a8c48357b44..c665ee971b885 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -36,7 +36,10 @@ def 
copy_trainer_model_properties(self, model): m.use_ddp2 = self.trainer.use_ddp2 m.use_ddp = self.trainer.use_ddp m.use_amp = self.trainer.amp_backend is not None - m.testing = self.trainer.testing + # TODO: I only find usages of m.testing in DDP, where it's used to + # discriminate test from validation, as opposed to test from fit in + # Trainer. Still need to fully determine if it's correct. + m.testing = self.trainer.evaluating == 'test' m.use_single_gpu = self.trainer.use_single_gpu m.use_tpu = self.trainer.use_tpu m.tpu_local_core_rank = self.trainer.tpu_local_core_rank diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 097727a6bed78..11da428b83453 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -13,6 +13,7 @@ # limitations under the License. import torch +import pytorch_lightning as pl from pytorch_lightning.core.step_result import EvalResult, Result from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.distributed import rank_zero_warn @@ -22,7 +23,7 @@ class EvaluationLoop(object): - def __init__(self, trainer): + def __init__(self, trainer: 'pl.Trainer'): self.trainer = trainer self.testing = False self.outputs = [] @@ -39,13 +40,15 @@ def on_trainer_init(self): self.trainer.test_dataloaders = None self.trainer.val_dataloaders = None self.trainer.running_sanity_check = False - self.trainer.testing = False - # when .test() is called, it sets this - self.trainer.tested_ckpt_path = None + # .validate() sets this to 'validation' and .test() sets this to 'test' + self.trainer.evaluating = None - # when true, prints test results - self.trainer.verbose_test = True + # .validate() and .test() set this when they load a checkpoint + self.trainer.evaluated_ckpt_path = None + + # when true, print evaluation results in .validate() and .test() + self.trainer.verbose_evaluate = True def get_evaluation_dataloaders(self, max_batches): # select dataloaders @@ -216,7 +219,7 @@ def evaluation_epoch_end(self): def log_epoch_metrics_on_evaluation_end(self): # get the final loop results - eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results(self.testing) + eval_loop_results = self.trainer.logger_connector.get_evaluate_epoch_results() return eval_loop_results def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index fd715988ef370..1bad441eb0083 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -441,10 +441,6 @@ def fit( # hook self.data_connector.prepare_data(model) - # bookkeeping - # we reuse fit in .test() but change its behavior using this flag - self.testing = os.environ.get('PL_TESTING_MODE', self.testing) - # ---------------------------- # SET UP TRAINING # ---------------------------- @@ -659,11 +655,15 @@ def track_output_for_epoch_end(self, outputs, output): outputs.append(output) return outputs - def run_test(self): + def run_test_or_validate(self): # only load test dataloader for testing # self.reset_test_dataloader(ref_model) - with self.profiler.profile("run_test_evaluation"): - eval_loop_results, _ = self.run_evaluation(test_mode=True) + if self.evaluating == 'test': + with self.profiler.profile("run_test_evaluation"): + eval_loop_results, _ = self.run_evaluation(test_mode=True) + else: + with self.profiler.profile("run_validate_evaluation"): + 
eval_loop_results, _ = self.run_evaluation(test_mode=False) if len(eval_loop_results) == 0: return 1 @@ -711,42 +711,90 @@ def run_sanity_check(self, ref_model): self.on_sanity_check_end() self.running_sanity_check = False - def test( + def validate( self, model: Optional[LightningModule] = None, - test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, ckpt_path: Optional[str] = 'best', verbose: bool = True, datamodule: Optional[LightningDataModule] = None, ): r""" - - Separates from fit to make sure you never run on your test set until you want to. + Perform one evaluation epoch over the validation set. Args: - ckpt_path: Either ``best`` or path to the checkpoint you wish to test. - If ``None``, use the weights from the last epoch to test. Default to ``best``. - + ckpt_path: Either ``best`` or path to the checkpoint you wish to validate. + If ``None``, use the current weights of the model. Default to ``best``. datamodule: A instance of :class:`LightningDataModule`. + model: The model to evaluate. + val_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying validation samples. + verbose: If True, prints the validation results. + + Returns: + The dictionary with final validation results returned by validation_epoch_end. + If validation_epoch_end is not defined, the output is a list of the dictionaries + returned by validation_step. + """ + # -------------------- + # SETUP HOOK + # -------------------- + self.verbose_evaluate = verbose + + self.logger_connector.set_stage("validation") + + # If you supply a datamodule you can't supply val_dataloaders + if val_dataloaders and datamodule: + raise MisconfigurationException( + 'You cannot pass val_dataloaders to trainer.validate if you supply a datamodule' + ) + + # Attach datamodule to get setup/prepare_data added to model before the call to it below + self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'validation') + + if model is not None: + results = self.__evaluate_given_model(model, val_dataloaders, 'validation') + else: + results = self.__evaluate_using_best_weights(ckpt_path, val_dataloaders, 'validation') + + self.teardown('validation') - model: The model to test. + return results - test_dataloaders: Either a single - Pytorch Dataloader or a list of them, specifying validation samples. + def test( + self, + model: Optional[LightningModule] = None, + test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None, + ckpt_path: Optional[str] = 'best', + verbose: bool = True, + datamodule: Optional[LightningDataModule] = None, + ): + r""" + Perform one evaluation epoch over the test set. It's separated from + fit to make sure you never run on your test set until you want to. - verbose: If True, prints the test results + Args: + ckpt_path: Either ``best`` or path to the checkpoint you wish to test. + If ``None``, use the current weights of the model. Default to ``best``. + datamodule: A instance of :class:`LightningDataModule`. + model: The model to evaluate. + test_dataloaders: Either a single PyTorch DataLoader or a list of them, + specifying test samples. + verbose: If True, prints the test results. Returns: - The final test result dictionary. If no test_epoch_end is defined returns a list of dictionaries + The dictionary with final test results returned by test_epoch_end. + If test_epoch_end is not defined, the output is a list of the dictionaries + returned by test_step. 
""" # -------------------- # SETUP HOOK # -------------------- - self.verbose_test = verbose + self.verbose_evaluate = verbose self.logger_connector.set_stage("test") - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + # If you supply a datamodule you can't supply test_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( 'You cannot pass test_dataloaders to trainer.test if you supply a datamodule' @@ -756,15 +804,15 @@ def test( self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') if model is not None: - results = self.__test_given_model(model, test_dataloaders) + results = self.__evaluate_given_model(model, test_dataloaders, 'test') else: - results = self.__test_using_best_weights(ckpt_path, test_dataloaders) + results = self.__evaluate_using_best_weights(ckpt_path, test_dataloaders, 'test') self.teardown('test') return results - def __test_using_best_weights(self, ckpt_path, test_dataloaders): + def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str): model = self.get_model() # if user requests the best checkpoint but we don't have it, error @@ -792,44 +840,62 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders): model.load_state_dict(ckpt['state_dict']) # attach dataloaders - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders} + self.data_connector.attach_dataloaders(model, **kwargs) # run tests - self.tested_ckpt_path = ckpt_path - self.testing = True - os.environ['PL_TESTING_MODE'] = '1' + self.evaluating = stage + self.evaluated_ckpt_path = ckpt_path self.model = model results = self.fit(model) - self.testing = False - del os.environ['PL_TESTING_MODE'] + self.evaluating = None # teardown if self.is_function_implemented('teardown'): model_ref = self.get_model() - model_ref.teardown('test') + model_ref.teardown(stage) return results - def __test_given_model(self, model, test_dataloaders): + def __evaluate_given_model(self, model, dataloaders, stage: str): # attach data - if test_dataloaders is not None: - self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) + if dataloaders is not None: + kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders} + self.data_connector.attach_dataloaders(model, **kwargs) # run test # sets up testing so we short circuit to eval - self.testing = True + self.evaluating = stage self.model = model results = self.fit(model) - self.testing = False + self.evaluating = None # teardown if self.is_function_implemented('teardown'): - model.teardown('test') + model.teardown(stage) return results + @property + def testing(self): + warnings.warn( + 'Trainer.testing has been deprecated in v1.1 and will be removed ' + 'in v1.3, use Trainer.evaluating instead.', + DeprecationWarning, stacklevel=2 + ) + return bool(self.evaluating) + + @property + def tested_ckpt_path(self): + warnings.warn( + 'Trainer.tested_ckpt_path has been renamed Trainer.evaluated_ckpt_path ' + 'in v1.1 and will be removed in v1.3.', + DeprecationWarning, stacklevel=2 + ) + return self.evaluated_ckpt_path + def tune( self, model: LightningModule, @@ -856,11 +922,18 @@ def tune( def call_setup_hook(self, model): # call setup after the ddp process has connected - stage_name = 'test' if self.testing else 'fit' + stage_name = self.evaluating or 
'fit' + if self.datamodule is not None: - called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit + called = { + None: self.datamodule.has_setup_fit, + 'validation': self.datamodule.has_setup_validation, + 'test': self.datamodule.has_setup_test, + }[self.evaluating] + if not called: self.datamodule.setup(stage_name) + self.setup(model, stage_name) model.setup(stage_name) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9a4f324033d39..ff19b9b8a9858 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -161,7 +161,7 @@ def setup_training(self, model: LightningModule): ref_model.on_pretrain_routine_start() # print model summary - if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.testing: + if self.trainer.is_global_zero and self.trainer.weights_summary is not None and not self.trainer.evaluating: if self.trainer.weights_summary in ModelSummary.MODES: ref_model.summarize(mode=self.trainer.weights_summary) else: diff --git a/tests/backends/test_dp.py b/tests/backends/test_dp.py index c051b442cb7a7..b697440280f80 100644 --- a/tests/backends/test_dp.py +++ b/tests/backends/test_dp.py @@ -67,7 +67,7 @@ def test_multi_gpu_model_dp(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_dp_test(tmpdir): +def test_dp_evaluate(tmpdir): tutils.set_random_master_port() import os @@ -84,6 +84,22 @@ def test_dp_test(tmpdir): ) trainer.fit(model) assert 'ckpt' in trainer.checkpoint_callback.best_model_path + + # validate + results = trainer.validate() + assert 'val_acc' in results[0] + + old_weights = model.c_d1.weight.clone().detach().cpu() + + results = trainer.validate(model) + assert 'val_acc' in results[0] + + # make sure weights didn't change + new_weights = model.c_d1.weight.clone().detach().cpu() + + assert torch.all(torch.eq(old_weights, new_weights)) + + # test results = trainer.test() assert 'test_acc' in results[0] diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py index e4d0b4bff89d7..94e4ba9c1efe9 100644 --- a/tests/base/datamodules.py +++ b/tests/base/datamodules.py @@ -33,7 +33,7 @@ def prepare_data(self): def setup(self, stage: Optional[str] = None): - if stage == "fit" or stage is None: + if stage != 'test': mnist_full = TrialMNIST( root=self.data_dir, train=True, num_samples=64, download=True ) @@ -88,7 +88,7 @@ def setup(self, stage: Optional[str] = None): # Assign train/val datasets for use in dataloaders # TODO: need to split using random_split once updated to torch >= 1.6 - if stage == "fit" or stage is None: + if stage != 'test': self.mnist_train = MNIST( self.data_dir, train=True, normalize=(0.1307, 0.3081) ) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index bb740b1dcbb1c..6f427afef7728 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -101,6 +101,28 @@ def test_trainer_callback_system(torch_save): call.teardown(trainer, model, 'fit'), ] + callback_mock.reset_mock() + trainer = Trainer(**trainer_options) + trainer.validate(model) + + assert callback_mock.method_calls == [ + call.on_init_start(trainer), + call.on_init_end(trainer), + call.setup(trainer, model, 'validation'), + call.on_fit_start(trainer, model), + call.on_pretrain_routine_start(trainer, model), + call.on_pretrain_routine_end(trainer, model), + call.on_validation_start(trainer, 
model), + call.on_validation_epoch_start(trainer, model), + call.on_validation_batch_start(trainer, model, ANY, 0, 0), + call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0), + call.on_validation_epoch_end(trainer, model), + call.on_validation_end(trainer, model), + call.on_fit_end(trainer, model), + call.teardown(trainer, model, 'fit'), + call.teardown(trainer, model, 'validation'), + ] + callback_mock.reset_mock() trainer = Trainer(**trainer_options) trainer.test(model) diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index 3c19748765e52..988da6f233dd2 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -84,7 +84,7 @@ def test_progress_bar_totals(tmpdir): limit_val_batches=1.0, max_epochs=1, ) - bar = trainer.progress_bar_callback + bar: ProgressBar = trainer.progress_bar_callback assert 0 == bar.total_train_batches assert 0 == bar.total_val_batches assert 0 == bar.total_test_batches @@ -113,6 +113,17 @@ def test_progress_bar_totals(tmpdir): assert 0 == bar.total_test_batches assert bar.test_progress_bar is None + trainer.validate(model) + + # check validation progress bar total + k = bar.total_val_batches + assert sum(len(loader) for loader in trainer.val_dataloaders) == k + assert bar.val_progress_bar.total == k + + # validation progress bar should have reached the end + assert bar.val_progress_bar.n == k + assert bar.val_batch_idx == k + trainer.test(model) # check test progress bar total @@ -135,7 +146,7 @@ def test_progress_bar_fast_dev_run(tmpdir): trainer.fit(model) - progress_bar = trainer.progress_bar_callback + progress_bar: ProgressBar = trainer.progress_bar_callback assert 1 == progress_bar.total_train_batches # total val batches are known only after val dataloaders have reloaded @@ -150,6 +161,13 @@ def test_progress_bar_fast_dev_run(tmpdir): assert 2 == progress_bar.main_progress_bar.total assert 2 == progress_bar.main_progress_bar.n + trainer.validate(model) + + # the validation progress bar should display 1 batch + assert 1 == progress_bar.val_batch_idx + assert 1 == progress_bar.val_progress_bar.total + assert 1 == progress_bar.val_progress_bar.n + trainer.test(model) # the test progress bar should display 1 batch @@ -207,8 +225,16 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal trainer.fit(model) assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches assert progress_bar.val_batches_seen == 3 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 + + trainer.validate(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps + assert progress_bar.test_batches_seen == 0 trainer.test(model) + assert progress_bar.train_batches_seen == 3 * progress_bar.total_train_batches + assert progress_bar.val_batches_seen == 4 * progress_bar.total_val_batches + trainer.num_sanity_val_steps assert progress_bar.test_batches_seen == progress_bar.total_test_batches diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 33bc19a894d8f..e3e6dfe4ceddc 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -797,6 +797,9 @@ def get_model(): assert trainer.current_epoch == epochs - 1 assert_checkpoint_log_dir(0) + trainer.validate(model) + assert 
trainer.current_epoch == epochs - 1 + trainer.test(model) assert trainer.current_epoch == epochs - 1 @@ -817,6 +820,11 @@ def get_model(): ) assert_trainer_init(trainer) + trainer.validate(model) + assert not trainer.checkpoint_connector.has_trained + assert trainer.global_step == epochs * limit_train_batches + assert trainer.current_epoch == epochs + trainer.test(model) assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 3e683025e8867..32f4aebe445d4 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -111,6 +111,7 @@ def test_base_datamodule_with_verbose_setup(tmpdir): dm = TrialMNISTDataModule() dm.prepare_data() dm.setup('fit') + dm.setup('validation') dm.setup('test') @@ -118,16 +119,19 @@ def test_data_hooks_called(tmpdir): dm = TrialMNISTDataModule() assert dm.has_prepared_data is False assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.prepare_data() assert dm.has_prepared_data is True assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.setup() assert dm.has_prepared_data is True assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is True @@ -135,21 +139,31 @@ def test_data_hooks_called_verbose(tmpdir): dm = TrialMNISTDataModule() assert dm.has_prepared_data is False assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.prepare_data() assert dm.has_prepared_data is True assert dm.has_setup_fit is False + assert dm.has_setup_validation is False assert dm.has_setup_test is False dm.setup('fit') assert dm.has_prepared_data is True assert dm.has_setup_fit is True + assert dm.has_setup_validation is False + assert dm.has_setup_test is False + + dm.setup('validation') + assert dm.has_prepared_data is True + assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is False dm.setup('test') assert dm.has_prepared_data is True assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is True @@ -160,10 +174,17 @@ def test_data_hooks_called_with_stage_kwarg(tmpdir): dm.setup(stage='fit') assert dm.has_setup_fit is True + assert dm.has_setup_validation is False + assert dm.has_setup_test is False + + dm.setup(stage='validation') + assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is False dm.setup(stage='test') assert dm.has_setup_fit is True + assert dm.has_setup_validation is True assert dm.has_setup_test is True @@ -254,6 +275,21 @@ def test_dm_checkpoint_save(tmpdir): assert checkpoint[dm.__class__.__name__] == dm.__class__.__name__ +def test_validate_loop_only(tmpdir): + reset_seed() + + dm = TrialMNISTDataModule(tmpdir) + + model = EvalModelTemplate() + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=3, + weights_summary=None, + ) + trainer.validate(model, datamodule=dm) + + def test_test_loop_only(tmpdir): reset_seed() @@ -287,6 +323,11 @@ def test_full_loop(tmpdir): result = trainer.fit(model, dm) assert result == 1 + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert result['val_acc'] > 0.8 + # test result = trainer.test(datamodule=dm) result = result[0] @@ -312,6 +353,11 @@ def test_trainer_attached_to_dm(tmpdir): assert 
result == 1 assert dm.trainer is not None + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert dm.trainer is not None + # test result = trainer.test(datamodule=dm) result = result[0] @@ -338,6 +384,11 @@ def test_full_loop_single_gpu(tmpdir): result = trainer.fit(model, dm) assert result == 1 + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert result['val_acc'] > 0.8 + # test result = trainer.test(datamodule=dm) result = result[0] @@ -365,6 +416,11 @@ def test_full_loop_dp(tmpdir): result = trainer.fit(model, dm) assert result == 1 + # validate + result = trainer.validate(datamodule=dm) + result = result[0] + assert result['val_acc'] > 0.8 + # test result = trainer.test(datamodule=dm) result = result[0] diff --git a/tests/trainer/test_config_validator.py b/tests/trainer/test_config_validator.py index 1ab97304f2338..b724fc8587e24 100755 --- a/tests/trainer/test_config_validator.py +++ b/tests/trainer/test_config_validator.py @@ -19,9 +19,6 @@ from tests.base import EvalModelTemplate -# TODO: add matching messages - - def test_wrong_train_setting(tmpdir): """ * Test that an error is thrown when no `train_dataloader()` is defined @@ -31,12 +28,12 @@ def test_wrong_train_setting(tmpdir): hparams = EvalModelTemplate.get_default_hparams() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match=r'No `train_dataloader\(\)` method defined.'): model = EvalModelTemplate(**hparams) model.train_dataloader = None trainer.fit(model) - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match=r'No `training_step\(\)` method defined.'): model = EvalModelTemplate(**hparams) model.training_step = None trainer.fit(model) @@ -47,7 +44,7 @@ def test_wrong_configure_optimizers(tmpdir): tutils.reset_seed() trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - with pytest.raises(MisconfigurationException): + with pytest.raises(MisconfigurationException, match=r'No `configure_optimizers\(\)` method defined.'): model = EvalModelTemplate() model.configure_optimizers = None trainer.fit(model) @@ -62,13 +59,13 @@ def test_val_loop_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # no val data has val loop - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): model = EvalModelTemplate(**hparams) model.validation_step = None trainer.fit(model) # has val loop but no val data - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): model = EvalModelTemplate(**hparams) model.val_dataloader = None trainer.fit(model) @@ -82,13 +79,33 @@ def test_test_loop_config(tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) # has test loop but no test data - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you defined a test_step but have no test_dataloader'): model = EvalModelTemplate(**hparams) model.test_dataloader = None trainer.test(model) # has test data but no test loop - with pytest.warns(UserWarning): + with pytest.warns(UserWarning, match=r'you passed in a test_dataloader but have no test_step'): model = EvalModelTemplate(**hparams) model.test_step = None trainer.test(model, test_dataloaders=model.dataloader(train=False)) + + +def test_validation_loop_config(tmpdir): + """" + When either 
validation loop or validation data are missing + """ + hparams = EvalModelTemplate.get_default_hparams() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + + # has val loop but no val data + with pytest.warns(UserWarning, match=r'you defined a validation_step but have no val_dataloader'): + model = EvalModelTemplate(**hparams) + model.val_dataloader = None + trainer.validate(model) + + # has val data but no val loop + with pytest.warns(UserWarning, match=r'you passed in a val_dataloader but have no validation_step'): + model = EvalModelTemplate(**hparams) + model.validation_step = None + trainer.validate(model, val_dataloaders=model.dataloader(train=False)) diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index f16ef22faa507..d0b838b5fbf45 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -172,6 +172,48 @@ def test_step(self, batch, batch_idx, *args, **kwargs): trainer.test(ckpt_path=ckpt_path) +@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific']) +def test_multiple_validate_dataloader(tmpdir, ckpt_path): + """Verify multiple val_dataloaders.""" + + model_template = EvalModelTemplate() + + class MultipleValDataloaderModel(EvalModelTemplate): + def val_dataloader(self): + return model_template.val_dataloader__multiple() + + def validation_step(self, batch, batch_idx, *args, **kwargs): + return model_template.validation_step__multiple_dataloaders(batch, batch_idx, *args, **kwargs) + + def validation_epoch_end(self, outputs): + return model_template.validation_epoch_end__multiple_dataloaders(outputs) + + model = MultipleValDataloaderModel() + + # fit model + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0.1, + limit_train_batches=0.2, + ) + trainer.fit(model) + if ckpt_path == 'specific': + ckpt_path = trainer.checkpoint_callback.best_model_path + trainer.validate(ckpt_path=ckpt_path) + + # verify there are 2 test loaders + assert len(trainer.val_dataloaders) == 2, \ + 'Multiple val_dataloaders not initiated properly' + + # make sure predictions are good for each test set + for dataloader in trainer.val_dataloaders: + tpipes.run_prediction(dataloader, trainer.model) + + # run the validate method + trainer.validate(ckpt_path=ckpt_path) + + def test_train_dataloader_passed_to_fit(tmpdir): """Verify that train dataloader can be passed to fit """ diff --git a/tests/trainer/test_optimizers.py b/tests/trainer/test_optimizers.py index 2e76192836740..27f0bcda66926 100644 --- a/tests/trainer/test_optimizers.py +++ b/tests/trainer/test_optimizers.py @@ -337,6 +337,24 @@ def test_init_optimizers_during_testing(tmpdir): assert len(trainer.optimizer_frequencies) == 0 +def test_init_optimizers_during_validation(tmpdir): + """ + Test that optimizers is an empty list during validation. 
+ """ + model = EvalModelTemplate() + model.configure_optimizers = model.configure_optimizers__multiple_schedulers + + trainer = Trainer( + default_root_dir=tmpdir, + limit_test_batches=10 + ) + trainer.validate(model, ckpt_path=None) + + assert len(trainer.lr_schedulers) == 0 + assert len(trainer.optimizers) == 0 + assert len(trainer.optimizer_frequencies) == 0 + + def test_multiple_optimizers_callbacks(tmpdir): """ Tests that multiple optimizers can be used with callbacks diff --git a/tests/trainer/test_states.py b/tests/trainer/test_states.py index 0244f654227a2..f6e29b7187d61 100644 --- a/tests/trainer/test_states.py +++ b/tests/trainer/test_states.py @@ -23,7 +23,7 @@ class StateSnapshotCallback(Callback): def __init__(self, snapshot_method: str): super().__init__() - assert snapshot_method in ['on_batch_start', 'on_test_batch_start'] + assert snapshot_method in ['on_batch_start', 'on_validation_batch_start', 'on_test_batch_start'] self.snapshot_method = snapshot_method self.trainer_state = None @@ -31,6 +31,10 @@ def on_batch_start(self, trainer, pl_module): if self.snapshot_method == 'on_batch_start': self.trainer_state = trainer.state + def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + if self.snapshot_method == 'on_validation_batch_start': + self.trainer_state = trainer.state + def on_test_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): if self.snapshot_method == 'on_test_batch_start': self.trainer_state = trainer.state @@ -191,6 +195,40 @@ def test_finished_state_after_test(tmpdir): assert trainer.state == TrainerState.FINISHED +def test_running_state_during_validation(tmpdir): + """ Tests that state is set to RUNNING during test """ + + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + snapshot_callback = StateSnapshotCallback(snapshot_method='on_validation_batch_start') + + trainer = Trainer( + callbacks=[snapshot_callback], + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + trainer.validate(model) + + assert snapshot_callback.trainer_state == TrainerState.RUNNING + + +def test_finished_state_after_validation(tmpdir): + """ Tests that state is FINISHED after fit """ + hparams = EvalModelTemplate.get_default_hparams() + model = EvalModelTemplate(**hparams) + + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + ) + + trainer.validate(model) + + assert trainer.state == TrainerState.FINISHED + + @pytest.mark.parametrize("extra_params", [ pytest.param(dict(fast_dev_run=True), id='Fast-Run'), pytest.param(dict(max_steps=1), id='Single-Step'), diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 085d361952844..5cf1bb17218a2 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -728,12 +728,12 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): trainer.test(ckpt_path=ckpt_path) else: trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == trainer.checkpoint_callback.best_model_path + assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path elif ckpt_path is None: # ckpt_path is None, meaning we don't load any checkpoints and # use the weights from the end of training trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path is None + assert trainer.evaluated_ckpt_path is None else: # specific checkpoint, pick one from saved ones if save_top_k == 0: @@ -746,7 +746,48 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): 
].absolute() ) trainer.test(ckpt_path=ckpt_path) - assert trainer.tested_ckpt_path == ckpt_path + assert trainer.evaluated_ckpt_path == ckpt_path + + +@pytest.mark.parametrize("ckpt_path", [None, "best", "specific"]) +@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) +def test_validate_checkpoint_path(tmpdir, ckpt_path, save_top_k): + hparams = EvalModelTemplate.get_default_hparams() + + model = EvalModelTemplate(**hparams) + trainer = Trainer( + max_epochs=2, + progress_bar_refresh_rate=0, + default_root_dir=tmpdir, + checkpoint_callback=ModelCheckpoint(monitor="early_stop_on", save_top_k=save_top_k), + ) + trainer.fit(model) + if ckpt_path == "best": + # ckpt_path is 'best', meaning we load the best weights + if save_top_k == 0: + with pytest.raises(MisconfigurationException, match=".*is not configured to save the best.*"): + trainer.validate(ckpt_path=ckpt_path) + else: + trainer.validate(ckpt_path=ckpt_path) + assert trainer.evaluated_ckpt_path == trainer.checkpoint_callback.best_model_path + elif ckpt_path is None: + # ckpt_path is None, meaning we don't load any checkpoints and + # use the weights from the end of training + trainer.validate(ckpt_path=ckpt_path) + assert trainer.evaluated_ckpt_path is None + else: + # specific checkpoint, pick one from saved ones + if save_top_k == 0: + with pytest.raises(FileNotFoundError): + trainer.validate(ckpt_path="random.ckpt") + else: + ckpt_path = str( + list((Path(tmpdir) / f"lightning_logs/version_{trainer.logger.version}/checkpoints").iterdir())[ + 0 + ].absolute() + ) + trainer.validate(ckpt_path=ckpt_path) + assert trainer.evaluated_ckpt_path == ckpt_path def test_disabled_training(tmpdir): @@ -1450,6 +1491,10 @@ def setup(self, model, stage): assert trainer.stage == "test" assert trainer.get_model().stage == "test" + trainer.validate(ckpt_path=None) + assert trainer.stage == "validation" + assert trainer.get_model().stage == "validation" + @pytest.mark.parametrize( "train_batches, max_steps, log_interval", diff --git a/tests/trainer/test_trainer_validate_loop.py b/tests/trainer/test_trainer_validate_loop.py new file mode 100644 index 0000000000000..a2205a4b50dc2 --- /dev/null +++ b/tests/trainer/test_trainer_validate_loop.py @@ -0,0 +1,76 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import pytest +import torch + +import pytorch_lightning as pl +import tests.base.develop_utils as tutils +from tests.base import EvalModelTemplate + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_single_gpu_validate(tmpdir): + tutils.set_random_master_port() + + model = EvalModelTemplate() + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0], + ) + trainer.fit(model) + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + results = trainer.validate() + assert 'val_acc' in results[0] + + old_weights = model.c_d1.weight.clone().detach().cpu() + + results = trainer.validate(model) + assert 'val_acc' in results[0] + + # make sure weights didn't change + new_weights = model.c_d1.weight.clone().detach().cpu() + + assert torch.all(torch.eq(old_weights, new_weights)) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_ddp_spawn_validate(tmpdir): + tutils.set_random_master_port() + + model = EvalModelTemplate() + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', + ) + trainer.fit(model) + assert 'ckpt' in trainer.checkpoint_callback.best_model_path + results = trainer.validate() + assert 'val_acc' in results[0] + + old_weights = model.c_d1.weight.clone().detach().cpu() + + results = trainer.validate(model) + assert 'val_acc' in results[0] + + # make sure weights didn't change + new_weights = model.c_d1.weight.clone().detach().cpu() + + assert torch.all(torch.eq(old_weights, new_weights))
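Taken together, the new ``Trainer.validate()`` method and the renamed bookkeeping attributes support the workflow below. This is a hypothetical end-to-end sketch, not code from the diff: ``SketchModel`` and the random ``TensorDataset`` loaders are placeholders for illustration, and the ``Trainer`` flags are the defaults of this Lightning version.

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl

    class SketchModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.cross_entropy(self(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            self.log('val_loss', torch.nn.functional.cross_entropy(self(x), y))

        def test_step(self, batch, batch_idx):
            x, y = batch
            self.log('test_loss', torch.nn.functional.cross_entropy(self(x), y))

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)

    def loader():
        # illustrative random data, stands in for a real dataset
        dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
        return DataLoader(dataset, batch_size=8)

    model = SketchModel()
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=4, limit_val_batches=4)
    trainer.fit(model, train_dataloader=loader(), val_dataloaders=loader())

    # one evaluation epoch over the validation set, loading the best checkpoint
    val_results = trainer.validate(ckpt_path='best', val_dataloaders=loader())

    # one evaluation epoch over the test set, using the current weights
    test_results = trainer.test(model, test_dataloaders=loader(), ckpt_path=None)

    # path of the checkpoint that was loaded for evaluation, if any;
    # replaces the deprecated trainer.tested_ckpt_path
    print(trainer.evaluated_ckpt_path)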