rename fast_dev_run -> unit_test #1087

Closed
wants to merge 21 commits into from
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- Renamed `fast_dev_run` to `unit_test` ([#1087](https://github.com/PyTorchLightning/pytorch-lightning/pull/1087))

### Deprecated

8 changes: 4 additions & 4 deletions docs/source/debugging.rst
@@ -2,18 +2,18 @@ Debugging
=========
The following are flags that make debugging much easier.

Fast dev run
------------
Unit test
---------
This flag runs a "unit test" by running 1 training batch and 1 validation batch.
The point is to detect any bugs in the training/validation loop without having to wait for
a full epoch to crash.

(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.fast_dev_run`
(See: :paramref:`~pytorch_lightning.trainer.trainer.Trainer.unit_test`
argument of :class:`~pytorch_lightning.trainer.trainer.Trainer`)

.. code-block:: python

trainer = pl.Trainer(fast_dev_run=True)
trainer = pl.Trainer(unit_test=True)

Inspect gradient norms
----------------------
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/__init__.py
@@ -273,8 +273,8 @@ def on_train_end(self):

.. note:: If ``'val_loss'`` is not found will work as if early stopping is disabled.

fast_dev_run
^^^^^^^^^^^^
unit_test
^^^^^^^^^

Runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).

@@ -298,10 +298,10 @@ def on_train_end(self):
Example::

# default used by the Trainer
trainer = Trainer(fast_dev_run=False)
trainer = Trainer(unit_test=False)

# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)
trainer = Trainer(unit_test=True)

gpus
^^^^
15 changes: 15 additions & 0 deletions pytorch_lightning/trainer/deprecated_api.py
@@ -87,3 +87,18 @@ def nb_sanity_val_steps(self, nb):
"`num_sanity_val_steps` since v0.5.0"
" and this method will be removed in v0.8.0", DeprecationWarning)
self.num_sanity_val_steps = nb

@property
def fast_dev_run(self):
"""Back compatibility, will be removed in v0.9.0"""
warnings.warn("Attribute `fast_dev_run` has renamed to `unit_test ` since v0.7.2"
" and this method will be removed in v0.9.0", DeprecationWarning)
return self.unit_test

@fast_dev_run.setter
def fast_dev_run(self, unit_test):
"""Back compatibility, will be removed in v0.9.0"""
warnings.warn("Attribute `fast_dev_run` has renamed to "
"`unit_test` since v0.7.2"
" and this method will be removed in v0.9.0", DeprecationWarning)
self.unit_test = unit_test
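
A minimal sketch of how the shim above behaves from user code, assuming this branch is installed; the attribute names come directly from the property and setter in this diff:

import warnings

from pytorch_lightning import Trainer

trainer = Trainer()  # new-style construction; unit_test defaults to False

# Reading or writing the old attribute still works, but each access routes
# through the deprecated property above and emits a DeprecationWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    trainer.fast_dev_run = True        # deprecated setter -> self.unit_test
    old_value = trainer.fast_dev_run   # deprecated getter -> self.unit_test

assert trainer.unit_test is True and old_value is True
assert all(issubclass(w.category, DeprecationWarning) for w in caught)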
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -160,7 +160,7 @@ class TrainerEvaluationLoopMixin(ABC):
model: LightningModule
num_test_batches: int
num_val_batches: int
fast_dev_run: ...
unit_test: bool
process_position: ...
show_progress_bar: ...
process_output: ...
@@ -252,7 +252,7 @@ def evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_m
if batch is None:
continue

# stop short when on fast_dev_run (sets max_batch=1)
# stop short when on unit_test (sets max_batch=1)
if batch_idx >= max_batches:
break

@@ -350,8 +350,8 @@ def run_evaluation(self, test_mode: bool = False):
dataloaders = self.val_dataloaders
max_batches = self.num_val_batches

# cap max batches to 1 when using fast_dev_run
if self.fast_dev_run:
# cap max batches to 1 when using unit_test
if self.unit_test:
max_batches = 1

# init validation or test progress bar
31 changes: 20 additions & 11 deletions pytorch_lightning/trainer/trainer.py
@@ -91,7 +91,7 @@ def __init__(
overfit_pct: float = 0.0,
track_grad_norm: int = -1,
check_val_every_n_epoch: int = 1,
fast_dev_run: bool = False,
unit_test: bool = False,
accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
max_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
min_nb_epochs=None, # backward compatible, todo: remove in v0.8.0
@@ -120,6 +120,7 @@ def __init__(
profiler: Optional[BaseProfiler] = None,
benchmark: bool = False,
reload_dataloaders_every_epoch: bool = False,
fast_dev_run=None, # backward compatible, todo: remove in v0.9.0
**kwargs
):
r"""
@@ -167,7 +168,12 @@ def __init__(

check_val_every_n_epoch: Check val every n train epochs.

fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
fast_dev_run:
.. warning:: .. deprecated:: 0.7.2

Use `unit_test` instead. Will be removed in 0.9.0.

unit_test: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).

accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

@@ -319,15 +325,16 @@ def __init__(
self.resume_from_checkpoint = resume_from_checkpoint
self.shown_warnings = set()

self.fast_dev_run = fast_dev_run
if self.fast_dev_run:
self.unit_test = unit_test
# Backward compatibility, TODO: remove in v0.9.0
if fast_dev_run is not None:
self.fast_dev_run = fast_dev_run

if self.unit_test:
self.num_sanity_val_steps = 1
self.max_epochs = 1
m = '''
Running in fast_dev_run mode: will run a full train,
val loop using a single batch
'''
log.info(m)
log.info("Running in unit_test mode: will run a full train,"
" val and test loop using a single batch")

# set default save path if user didn't provide one
self.default_save_path = default_save_path
@@ -795,14 +802,15 @@ def run_pretrain_routine(self, model: LightningModule):
self.restore_weights(model)

# when testing requested only run test and return
if self.testing:
# also run a test batch when unit_test is enabled
if self.testing or self.unit_test:
# only load test dataloader for testing
# self.reset_test_dataloader(ref_model)
self.run_evaluation(test_mode=True)
return

# check if we should run validation during training
self.disable_validation = not self.is_overriden('validation_step') and not self.fast_dev_run
self.disable_validation = not self.is_overriden('validation_step') and not self.unit_test

# run tiny validation (if validation defined)
# to make sure program won't crash during val
Expand Down Expand Up @@ -896,6 +904,7 @@ class _PatchDataLoader(object):
dataloader: Dataloader object to return when called.

"""

def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
self.dataloader = dataloader

20 changes: 10 additions & 10 deletions pytorch_lightning/trainer/training_loop.py
@@ -180,7 +180,7 @@ class TrainerTrainLoopMixin(ABC):
val_check_batch: ...
num_val_batches: int
disable_validation: bool
fast_dev_run: ...
unit_test: bool
main_progress_bar: ...
accumulation_scheduler: ...
lr_schedulers: ...
@@ -326,8 +326,8 @@ def train(self):
self.total_batches = self.num_training_batches + total_val_batches
self.batch_loss_value = 0 # accumulated grads

if self.fast_dev_run:
# limit the number of batches to 2 (1 train and 1 val) in fast_dev_run
if self.unit_test:
# limit the number of batches to 2 (1 train and 1 val) in unit_test
num_iterations = 2
elif self.total_batches == float('inf'):
# for infinite train or val loader, the progress bar never ends
@@ -360,7 +360,7 @@ def train(self):

# TODO wrap this logic into the callback
if self.enable_early_stop and not self.disable_validation and is_val_epoch:
if ((met_min_epochs and met_min_steps) or self.fast_dev_run):
if ((met_min_epochs and met_min_steps) or self.unit_test):
should_stop = self.early_stop_callback.on_epoch_end(self, self.get_model())
# stop training
stop = should_stop and met_min_epochs
@@ -432,19 +432,19 @@ def run_training_epoch(self):
should_check_val = not self.disable_validation and can_check_epoch
should_check_val = should_check_val and (is_val_check_batch or early_stop_epoch)

# fast_dev_run always forces val checking after train batch
if self.fast_dev_run or should_check_val:
# unit_test always forces val checking after train batch
if self.unit_test or should_check_val:
self.run_evaluation(test_mode=self.testing)

# when logs should be saved
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
if should_save_log or self.fast_dev_run:
if should_save_log or self.unit_test:
if self.proc_rank == 0 and self.logger is not None:
self.logger.save()

# when metrics should be logged
should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
if should_log_metrics or self.fast_dev_run:
if should_log_metrics or self.unit_test:
# logs user requested information to logger
self.log_metrics(batch_step_metrics, grad_norm_dic)

@@ -453,7 +453,7 @@
# ---------------
# save checkpoint even when no test or val step are defined
train_step_only = not self.is_overriden('validation_step')
if self.fast_dev_run or should_check_val or train_step_only:
if self.unit_test or should_check_val or train_step_only:
self.call_checkpoint_callback()

if self.enable_early_stop:
@@ -471,7 +471,7 @@
# end epoch early
# stop when the flag is changed or we've gone past the amount
# requested in the batches
if early_stop_epoch or self.fast_dev_run:
if early_stop_epoch or self.unit_test:
break

# Epoch end events
17 changes: 17 additions & 0 deletions tests/test_deprecated.py
@@ -58,6 +58,23 @@ def test_tbd_remove_in_v0_9_0_module_imports():
from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402


def test_tbd_remove_in_v0_9_0_trainer():
mapping_old_new = {
'fast_dev_run': 'unit_test',
}
# skip 0 since it may be interpreted as False
kwargs = {k: (i + 1) for i, k in enumerate(mapping_old_new)}

trainer = Trainer(**kwargs)

for attr_old in mapping_old_new:
attr_new = mapping_old_new[attr_old]
assert kwargs[attr_old] == getattr(trainer, attr_old), \
'Missing deprecated attribute "%s"' % attr_old
assert kwargs[attr_old] == getattr(trainer, attr_new), \
'Wrongly passed deprecated argument "%s" to attribute "%s"' % (attr_old, attr_new)


class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):

# todo: this shall not be needed while evaluate asks for dataloader explicitly
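
A possible complementary check, not part of this diff, that asserts the warning itself rather than only the attribute mapping (hypothetical test name; relies on `pytest.warns` and the setter added in deprecated_api.py):

import pytest

from pytorch_lightning import Trainer


def test_fast_dev_run_warns_on_rename():
    """Hypothetical companion test: the deprecated argument should also warn."""
    with pytest.warns(DeprecationWarning):
        trainer = Trainer(fast_dev_run=True)
    assert trainer.unit_test is True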