diff --git a/CHANGELOG.md b/CHANGELOG.md
index c948e22e7b553..946dc040a9aa2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,11 +17,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Renamed `fast_dev_run` to `unit_test` ([#1087](https://github.com/PyTorchLightning/pytorch-lightning/pull/1087))
 
 ### Deprecated
 
 - Deprecated Trainer argument `print_nan_grads` ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
+- Deprecated Trainer argument `fast_dev_run` ([#1087](https://github.com/PyTorchLightning/pytorch-lightning/pull/1087))
 
 ### Removed
 
diff --git a/docs/source/debugging.rst b/docs/source/debugging.rst
index 775862d8c1826..e88847b3b3943 100644
--- a/docs/source/debugging.rst
+++ b/docs/source/debugging.rst
@@ -2,18 +2,18 @@ Debugging
 =========
 The following are flags that make debugging much easier.
 
-Fast dev run
-------------
-This flag runs a "unit test" by running 1 training batch and 1 validation batch.
-The point is to detect any bugs in the training/validation loop without having to wait for
+Unit test
+---------
+This flag runs a "unit test" by running 1 training batch, 1 validation batch and 1 test batch.
+The point is to detect any bugs in the training/validation/test loop without having to wait for
 a full epoch to crash.
 
-(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.fast_dev_run`
+(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.unit_test`
 argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 
 .. code-block:: python
 
-    trainer = pl.Trainer(fast_dev_run=True)
+    trainer = pl.Trainer(unit_test=True)
 
 Inspect gradient norms
 ----------------------
@@ -75,4 +75,4 @@ argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)
 .. code-block:: python
 
     # DEFAULT
-    trainer = Trainer(num_sanity_val_steps=5)
\ No newline at end of file
+    trainer = Trainer(num_sanity_val_steps=5)
diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py
index 22d3e2d135633..dd52519d0dd20 100644
--- a/pytorch_lightning/trainer/__init__.py
+++ b/pytorch_lightning/trainer/__init__.py
@@ -273,8 +273,8 @@ def on_train_end(self):
 
 .. note:: If ``'val_loss'`` is not found will work as if early stopping is disabled.
 
-fast_dev_run
-^^^^^^^^^^^^
+unit_test
+^^^^^^^^^
 Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
 
@@ -298,10 +298,10 @@ def on_train_end(self):
 Example::
 
     # default used by the Trainer
-    trainer = Trainer(fast_dev_run=False)
+    trainer = Trainer(unit_test=False)
 
     # runs 1 train, val, test batch and program ends
-    trainer = Trainer(fast_dev_run=True)
+    trainer = Trainer(unit_test=True)
 
 gpus
 ^^^^
diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py
index 8c4ca8648b5bb..0fda8e7a6131f 100644
--- a/pytorch_lightning/trainer/deprecated_api.py
+++ b/pytorch_lightning/trainer/deprecated_api.py
@@ -87,3 +87,23 @@ def nb_sanity_val_steps(self, nb):
                       "`num_sanity_val_steps` since v0.5.0"
                       " and this method will be removed in v0.8.0", DeprecationWarning)
         self.num_sanity_val_steps = nb
+
+
+class TrainerDeprecatedAPITillVer0_9(ABC):
+
+    def __init__(self):
+        super().__init__()  # mixin calls super too
+
+    @property
+    def fast_dev_run(self):
+        """Back compatibility, will be removed in v0.9.0"""
+        warnings.warn("Attribute `fast_dev_run` has been renamed to `unit_test` since v0.7.2"
+                      " and this attribute will be removed in v0.9.0", DeprecationWarning)
+        return self.unit_test
+
+    @fast_dev_run.setter
+    def fast_dev_run(self, unit_test):
+        """Back compatibility, will be removed in v0.9.0"""
+        warnings.warn("Attribute `fast_dev_run` has been renamed to `unit_test` since v0.7.2"
+                      " and this attribute will be removed in v0.9.0", DeprecationWarning)
+        self.unit_test = unit_test
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index bfbca4c25db29..ce7204ed2b7cf 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -161,7 +161,7 @@ class TrainerEvaluationLoopMixin(ABC):
     model: LightningModule
     num_test_batches: int
    num_val_batches: int
-    fast_dev_run: ...
+    unit_test: bool
     process_position: ...
     show_progress_bar: ...
     process_output: ...
@@ -253,7 +253,7 @@ def evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_m
                 if batch is None:
                     continue
 
-                # stop short when on fast_dev_run (sets max_batch=1)
+                # stop short when on unit_test (sets max_batch=1)
                 if batch_idx >= max_batches:
                     break
 
@@ -351,8 +351,8 @@ def run_evaluation(self, test_mode: bool = False):
             dataloaders = self.val_dataloaders
             max_batches = self.num_val_batches
 
-        # cap max batches to 1 when using fast_dev_run
-        if self.fast_dev_run:
+        # cap max batches to 1 when using unit_test
+        if self.unit_test:
             max_batches = 1
 
         # init validation or test progress bar
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 881a2e9103301..ab28f8e7afe0b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -23,8 +23,8 @@
 from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
 from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
+from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8, TrainerDeprecatedAPITillVer0_9
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
-from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8
 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
 from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
 from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
@@ -66,6 +66,7 @@ class Trainer(
     TrainerCallbackConfigMixin,
     TrainerCallbackHookMixin,
     TrainerDeprecatedAPITillVer0_8,
+    TrainerDeprecatedAPITillVer0_9,
 ):
     DEPRECATED_IN_0_8 = (
         'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs',
@@ -93,7 +94,7 @@ def __init__(
             overfit_pct: float = 0.0,
             track_grad_norm: int = -1,
             check_val_every_n_epoch: int = 1,
-            fast_dev_run: bool = False,
+            unit_test: bool = False,
             accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
             max_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
             min_nb_epochs=None,  # backward compatible, todo: remove in v0.8.0
@@ -122,6 +123,7 @@
             profiler: Optional[BaseProfiler] = None,
             benchmark: bool = False,
             reload_dataloaders_every_epoch: bool = False,
+            fast_dev_run=None,  # backward compatible, todo: remove in v0.9.0
             **kwargs
     ):
         r"""
@@ -171,7 +173,12 @@ def __init__(
 
             check_val_every_n_epoch: Check val every n train epochs.
 
-            fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
+            fast_dev_run:
+                .. warning:: .. deprecated:: 0.7.2
+
+                    Use `unit_test` instead. Will remove 0.9.0.
+
+            unit_test: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
 
             accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.
 
@@ -323,15 +330,16 @@ def __init__(
         self.resume_from_checkpoint = resume_from_checkpoint
         self.shown_warnings = set()
 
-        self.fast_dev_run = fast_dev_run
-        if self.fast_dev_run:
+        self.unit_test = unit_test
+        # Backward compatibility, TODO: remove in v0.9.0
+        if fast_dev_run is not None:
+            self.fast_dev_run = fast_dev_run
+
+        if self.unit_test:
             self.num_sanity_val_steps = 1
             self.max_epochs = 1
-            m = '''
-            Running in fast_dev_run mode: will run a full train,
-            val loop using a single batch
-            '''
-            log.info(m)
+            log.info("Running in unit_test mode: will run a full train,"
+                     " validation and test loop using a single batch")
 
         # set default save path if user didn't provide one
         self.default_save_path = default_save_path
@@ -869,14 +877,16 @@ def run_pretrain_routine(self, model: LightningModule):
         self.restore_weights(model)
 
         # when testing requested only run test and return
-        if self.testing:
+        # in unit_test mode, also run a single test batch before training
+        if self.testing or self.unit_test:
             # only load test dataloader for testing
             # self.reset_test_dataloader(ref_model)
             self.run_evaluation(test_mode=True)
-            return
+            if self.testing:
+                return
 
         # check if we should run validation during training
-        self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run
+        self.disable_validation = not self.is_overriden('validation_step') and not self.unit_test
 
         # run tiny validation (if validation defined)
         # to make sure program won't crash during val
@@ -970,6 +979,7 @@ class _PatchDataLoader(object):
         dataloader: Dataloader object to return when called.
 
     """
+
     def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
         self.dataloader = dataloader
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index a90a43b0d8beb..ccfeb2824f611 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -180,7 +180,7 @@ class TrainerTrainLoopMixin(ABC):
     val_check_batch: ...
     num_val_batches: int
     disable_validation: bool
-    fast_dev_run: ...
+    unit_test: bool
     main_progress_bar: ...
     accumulation_scheduler: ...
     lr_schedulers: ...
@@ -326,8 +326,8 @@ def train(self):
         self.total_batches = self.num_training_batches + total_val_batches
         self.batch_loss_value = 0  # accumulated grads
 
-        if self.fast_dev_run:
-            # limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
+        if self.unit_test:
+            # limit the number of batches to 2 (1 train and 1 val) in unit_test
             num_iterations = 2
         elif self.total_batches == float('inf'):
             # for infinite train or val loader, the progress bar never ends
@@ -360,7 +360,7 @@ def train(self):
 
             # TODO wrap this logic into the callback
             if self.enable_early_stop and not self.disable_validation and is_val_epoch:
-                if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
+                if ((met_min_epochs and met_min_steps) or self.unit_test):
                     should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
                     # stop training
                     stop = should_stop and met_min_epochs
@@ -432,19 +432,19 @@ def run_training_epoch(self):
             should_check_val = not self.disable_validation and can_check_epoch
             should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)
 
-            # fast_dev_run always forces val checking after train batch
-            if self.fast_dev_run or should_check_val:
+            # unit_test always forces val checking after train batch
+            if self.unit_test or should_check_val:
                 self.run_evaluation(test_mode=self.testing)
 
             # when logs should be saved
             should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
-            if should_save_log or self.fast_dev_run:
+            if should_save_log or self.unit_test:
                 if self.proc_rank == 0 and self.logger is not None:
                     self.logger.save()
 
             # when metrics should be logged
             should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
-            if should_log_metrics or self.fast_dev_run:
+            if should_log_metrics or self.unit_test:
                 # logs user requested information to logger
                 self.log_metrics(batch_step_metrics, grad_norm_dic)
 
@@ -453,7 +453,7 @@ def run_training_epoch(self):
             # ---------------
             # save checkpoint even when no test or val step are defined
             train_step_only = not self.is_overriden('validation_step')
-            if self.fast_dev_run or should_check_val or train_step_only:
+            if self.unit_test or should_check_val or train_step_only:
                 self.call_checkpoint_callback()
 
             if self.enable_early_stop:
@@ -471,7 +471,7 @@ def run_training_epoch(self):
             # end epoch early
             # stop when the flag is changed or we've gone past the amount
             # requested in the batches
-            if early_stop_epoch or self.fast_dev_run:
+            if early_stop_epoch or self.unit_test:
                 break
 
         # Epoch end events
diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py
index a79eb7451305f..b73debe74d4d6 100644
--- a/tests/test_deprecated.py
+++ b/tests/test_deprecated.py
@@ -27,6 +27,15 @@ def test_tbd_remove_in_v0_8_0_module_imports():
     from pytorch_lightning.root_module.root_module import LightningModule  # noqa: F811
 
 
+def _assert_trainer_attribs(trainer, mapping_old_new, kwargs):
+    for attr_old in mapping_old_new:
+        attr_new = mapping_old_new[attr_old]
+        assert kwargs[attr_old] == getattr(trainer, attr_old), \
+            'Missing deprecated attribute "%s"' % attr_old
+        assert kwargs[attr_old] == getattr(trainer, attr_new), \
+            'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new)
+
+
 def test_tbd_remove_in_v0_8_0_trainer():
     mapping_old_new = {
         'gradient_clip': 'gradient_clip_val',
@@ -40,12 +49,7 @@ def test_tbd_remove_in_v0_8_0_trainer():
 
     trainer = Trainer(**kwargs)
 
-    for attr_old in mapping_old_new:
-        attr_new = mapping_old_new[attr_old]
-        assert kwargs[attr_old] == getattr(trainer, attr_old), \
-            'Missing deprecated attribute "%s"' % attr_old
-        assert kwargs[attr_old] == getattr(trainer, attr_new), \
-            'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new)
+    _assert_trainer_attribs(trainer, mapping_old_new, kwargs)
 
 
 def test_tbd_remove_in_v0_9_0_module_imports():
@@ -58,6 +62,18 @@ def test_tbd_remove_in_v0_9_0_module_imports():
     from pytorch_lightning.logging.wandb import WandbLogger  # noqa: F402
 
 
+def test_tbd_remove_in_v0_9_0_trainer():
+    mapping_old_new = {
+        'fast_dev_run': 'unit_test',
+    }
+    # skip 0 since it may be interpreted as False
+    kwargs = {'fast_dev_run': True}
+
+    trainer = Trainer(**kwargs)
+
+    _assert_trainer_attribs(trainer, mapping_old_new, kwargs)
+
+
 class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):
 
     # todo: this shall not be needed while evaluate asks for dataloader explicitly
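As a possible follow-up to the deprecation test above (a minimal sketch only, not part of the diff; the test name is made up and it assumes the pytest tooling already used by this test suite), the backward-compatibility path could also be checked explicitly: passing the old `fast_dev_run` argument should emit a `DeprecationWarning` from the `TrainerDeprecatedAPITillVer0_9` setter while still populating the new `unit_test` attribute.

    import pytest

    from pytorch_lightning import Trainer


    def test_fast_dev_run_arg_warns_and_maps_to_unit_test():
        # constructing the Trainer with the deprecated argument should warn ...
        with pytest.deprecated_call():
            trainer = Trainer(fast_dev_run=True)
        # ... and the value should land on the renamed attribute
        assert trainer.unit_test is True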