Add Trainer.validate(…) method to run one validation epoch #4707

Closed

39 commits
62bd29e
Add Trainer.validate(…) to run one validation epoch
EliaCereda Nov 17, 2020
055e1ba
Support val_progress_bar without main_progress_bar in ProgressBar
EliaCereda Nov 17, 2020
156b669
Fix PEP 8 issue
EliaCereda Nov 17, 2020
1429548
Use `main_progress_bar is not None` to test if the bar is present in …
EliaCereda Nov 17, 2020
50427e7
Simplify selection of dataloaders arg to be set
EliaCereda Nov 17, 2020
d1988e0
Call setup(…) with stage ‘validation’ when running Trainer.validate(…)
EliaCereda Nov 17, 2020
ae03c6b
Check self.trainer.evaluating instead of self.trainer.testing in Acce…
EliaCereda Nov 17, 2020
5493a5b
Set Trainer.evaluating to None by default
EliaCereda Nov 17, 2020
860fef5
Replace the remaining instances of self.evaluating = False with None
EliaCereda Nov 18, 2020
99a6161
Add a first batch of tests for Trainer.validate(…)
EliaCereda Nov 18, 2020
307c89a
Avoid an if/else in ProgressBar
EliaCereda Nov 18, 2020
9e59e6d
Modify ModelCheckpoint to never save a checkpoint automatically when …
EliaCereda Nov 18, 2020
a844f40
Update test_config_validator.py to match the messages of expected err…
EliaCereda Nov 18, 2020
3f9f927
Fix Trainer.validate(…, verbose=True)
EliaCereda Nov 19, 2020
db22f2b
Transform Trainer.testing to a read-only deprecated property, remove …
EliaCereda Nov 19, 2020
f8647c5
Update docs for Trainer.validate and Trainer.test
EliaCereda Nov 19, 2020
99281a0
Remove usages of deprecated Trainer.testing
EliaCereda Nov 20, 2020
58d1c36
Rename methods and attributes to reflect their new behavior
EliaCereda Nov 20, 2020
7330ad4
Rename Trainer.tested_ckpt_path to Trainer.evaluated_ckpt_path since …
EliaCereda Nov 20, 2020
7abc67d
Update CHANGELOG.md
EliaCereda Nov 20, 2020
14799da
Fix PEP 8 issues
EliaCereda Nov 20, 2020
f8f4d3b
Update documentation of .setup(stage) methods to mention the new ‘val…
EliaCereda Nov 20, 2020
1818f22
Added more tests for Trainer.validate
EliaCereda Nov 20, 2020
ab89faa
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Nov 20, 2020
d0cd34a
Fix hook that tracks LightningDataModule.setup(‘validation’) calls, a…
EliaCereda Nov 20, 2020
0209cfc
Add a test for Trainer.validate on DataParallel
EliaCereda Nov 20, 2020
6a04280
Disable EarlyStopping in evaluation mode
EliaCereda Nov 21, 2020
2115350
Clean up LoggerConnector.get_evaluate_epoch_results
EliaCereda Nov 21, 2020
92acb12
Improve description of Trainer.validate in docs/source/trainer.rst
EliaCereda Nov 21, 2020
8090193
Clean up setup() methods in tests/base/datamodules.py
EliaCereda Nov 21, 2020
a098489
Update deprecation warnings
EliaCereda Nov 21, 2020
f8ab391
Update Trainer.{validate, test} docstrings
EliaCereda Nov 21, 2020
605e7b0
Fix PEP 8 issue
EliaCereda Nov 21, 2020
14a7767
Consistently use the serial comma in docstrings
EliaCereda Nov 23, 2020
e9a6956
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Nov 23, 2020
873099e
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Nov 25, 2020
6f2ce28
Fix PEP 8 issue
EliaCereda Nov 25, 2020
0f4e474
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 2, 2020
d4cb1b0
Rewrite assertions for Trainer.validate in test_callbacks.py using Ma…
EliaCereda Dec 2, 2020
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -52,6 +52,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `F1` class metric ([#4656](https://github.com/PyTorchLightning/pytorch-lightning/pull/4656))


- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4707](https://github.com/PyTorchLightning/pytorch-lightning/pull/4707))


### Changed

- Consistently use `step=trainer.global_step` in `LearningRateMonitor` independently of `logging_interval` ([#4376](https://github.com/PyTorchLightning/pytorch-lightning/pull/4376))
15 changes: 14 additions & 1 deletion docs/source/trainer.rst
@@ -148,14 +148,27 @@ So you can run it like so:

------------

Validation
----------
You can perform an evaluation epoch over the validation set, outside of the training loop,
using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This can be
useful if you want to collect new metrics from a model either right at its
initialization or after it has already been trained.

.. code-block:: python

trainer.validate(val_dataloaders=val_dataloaders)
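As a fuller illustration, here is a sketch of how this might look end to end (the model and dataloader names are placeholders, and the signature is assumed to mirror ``Trainer.test``):

.. code-block:: python

    model = LitModel()  # any LightningModule; the name is illustrative
    trainer = Trainer()

    # collect metrics from the freshly initialized model
    trainer.validate(model, val_dataloaders=val_dataloader)

    # ... or run validation again after training
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.validate(val_dataloaders=val_dataloader)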

------------

Testing
-------
Once you're done training, feel free to run the test set!
(Only right before publishing your paper or pushing to production)

.. code-block:: python

trainer.test(test_dataloader=test_dataloader)
trainer.test(test_dataloaders=test_dataloaders)

------------

8 changes: 4 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -59,9 +59,9 @@ def barrier(self, name: Optional[str] = None):
def broadcast(self, obj, src=0):
return obj

def train_or_test(self):
if self.trainer.testing:
results = self.trainer.run_test()
def train_or_evaluate(self):
Member:
The name here is a bit misleading as this also runs test

Contributor Author:
I used evaluate here to refer to either test or validate. I think it was inspired by the pre-existing Trainer.run_evaluation method, which is used to run either the test or validation loop depending on the value of the test_mode parameter.

Let me know if you have a better idea for the name!

if self.trainer.evaluating:
results = self.trainer.run_test_or_validate()
else:
results = self.trainer.train()
return results
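
To summarize the semantics that emerge from the commits and this hunk (an inference from the diffs, not text taken from the PR):

    # Inferred semantics of the new Trainer.evaluating attribute:
    #   None          -> the user called Trainer.fit()      (not evaluating)
    #   'validation'  -> the user called Trainer.validate()
    #   'test'        -> the user called Trainer.test()
    # Any truthy value routes train_or_evaluate() into run_test_or_validate();
    # the specific string selects the loop and labels the printed results, as
    # in the logger_connector hunk further down:
    stage = trainer.evaluating  # e.g. 'validation'
    if stage:
        print(f'DATALOADER:0 {stage.upper()} RESULTS')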
@@ -160,7 +160,7 @@ def early_stopping_should_stop(self, pl_module):
return self.trainer.should_stop

def setup_optimizers(self, model):
if self.trainer.testing is True:
if self.trainer.evaluating:
return

optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -55,8 +55,8 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()
return results

def training_step(self, args):
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -183,8 +183,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# clean up memory
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -280,8 +280,8 @@ def ddp_train(self, process_idx, model):
self.barrier('ddp_setup')
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# clean up memory
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -143,8 +143,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# get original model
model = self.trainer.get_model()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -177,8 +177,8 @@ def ddp_train(self, process_idx, model):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# clean up memory
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -154,8 +154,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# get original model
model = self.trainer.get_model()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -104,8 +104,8 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

return results

5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -60,8 +60,9 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

return results

def training_step(self, args):
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -109,8 +109,8 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(self.trainer.model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# Make sure all workers have finished training before returning to the user
hvd.join()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -133,8 +133,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# save weights at the end of training
self.__save_end_of_training_weights(model, trainer)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/base.py
@@ -28,11 +28,11 @@ class Callback(abc.ABC):
"""

def setup(self, trainer, pl_module, stage: str):
"""Called when fit or test begins"""
"""Called when fit, validate, or test begins"""
pass

def teardown(self, trainer, pl_module, stage: str):
"""Called when fit or test ends"""
"""Called when fit, validate, or test ends"""
pass

def on_init_start(self, trainer):
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -138,13 +138,13 @@ def on_load_checkpoint(self, checkpointed_state):
self.patience = checkpointed_state['patience']

def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:
if trainer.running_sanity_check or trainer.evaluating:
return

self._run_early_stopping_check(trainer, pl_module)

def on_validation_epoch_end(self, trainer, pl_module):
if trainer.running_sanity_check:
if trainer.running_sanity_check or trainer.evaluating:
return

if self._validate_condition_metric(trainer.logger_connector.callback_metrics):
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -220,6 +220,7 @@ def save_checkpoint(self, trainer, pl_module):
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or trainer.evaluating # don't save anything during evaluation: might delete the checkpoint being evaluated
or self.last_global_step_saved == global_step # already saved at the last step
):
return
22 changes: 18 additions & 4 deletions pytorch_lightning/callbacks/progress.py
@@ -282,9 +282,13 @@ def init_train_tqdm(self) -> tqdm:

def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """

# The main progress bar doesn't exist in trainer.validate(...)
has_main_bar = int(self.main_progress_bar is not None)

bar = tqdm(
desc='Validating',
position=(2 * self.process_position + 1),
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
dynamic_ncols=True,
@@ -341,19 +345,29 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
if not trainer.running_sanity_check:
self._update_bar(self.main_progress_bar) # fill up remaining
# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self._update_bar(self.main_progress_bar) # fill up remaining

self.val_progress_bar = self.init_validation_tqdm()
self.val_progress_bar.total = convert_inf(self.total_val_batches)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.val_batch_idx, self.total_val_batches):
self._update_bar(self.val_progress_bar)
self._update_bar(self.main_progress_bar)

# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self._update_bar(self.main_progress_bar)

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

self.val_progress_bar.close()

def on_train_end(self, trainer, pl_module):
15 changes: 14 additions & 1 deletion pytorch_lightning/core/datamodule.py
@@ -76,13 +76,16 @@ def wrapped_fn(*args, **kwargs):
if fn.__name__ == "setup":

# Get stage either by grabbing from args or checking kwargs.
# If not provided, set call status of 'fit' and 'test' to True.
# If not provided, set call status of 'fit', 'validation', and 'test' to True.
# We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
stage = args[1] if len(args) > 1 else kwargs.get("stage", None)

if stage == "fit" or stage is None:
obj._has_setup_fit = True

if stage == "validation" or stage is None:
Contributor:
I would make all references in datamodule "validate" (instead of "validation") to keep it consistent with fit and test

Contributor:
fit and test can be nouns or verbs; however, we are talking about stages, which means they should be nouns.
So if I am not mistaken (English is not my first language), using validation is more consistent:
the validation stage vs. the validate stage

Contributor:
I don't think grammar should be a consideration here since we're only talking about variables in code... and that variable name consistency is more important. Thoughts on this? @justusschock @rohitgr7

Member:
I think both of you have a point here, and one could certainly use both. However, I feel that validation stage is more intuitive, and personally I would go with it since it sounds 'more correct' to me, but this is just a personal opinion. Also, I think this should definitely not be a blocker here.

Contributor Author:
Yes, it is a good point. I was also ambivalent about it while I was writing the code.

There is another occurrence of this issue: the Trainer.evaluating attribute, which can be either test or validation. Here validation is the right choice in my opinion, reading it as "currently evaluating over the test/validation set".

It was not so clear cut in the data module: I'd say that 'validation' sounds better to me too, but I would not be opposed to using 'validate' either.

Contributor:
Gotcha, it does have a better ring to it :]

obj._has_setup_validation = True

if stage == "test" or stage is None:
obj._has_setup_test = True

@@ -155,6 +158,7 @@ def __init__(
# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False
self._has_setup_fit = False
self._has_setup_validation = False
self._has_setup_test = False

@property
@@ -230,6 +234,15 @@ def has_setup_fit(self):
"""
return self._has_setup_fit

@property
def has_setup_validation(self):
"""Return bool letting you know if datamodule.setup('validation') has been called or not.

Returns:
bool: True if datamodule.setup('validation') has been called. False by default.
"""
return self._has_setup_validation

@property
def has_setup_test(self):
"""Return bool letting you know if datamodule.setup('test') has been called or not.
8 changes: 4 additions & 4 deletions pytorch_lightning/core/hooks.py
@@ -33,12 +33,12 @@ class ModelHooks:
"""Hooks to be used in LightningModule."""
def setup(self, stage: str):
"""
Called at the beginning of fit and test.
Called at the beginning of fit (training + validation), validation, and test.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.

Args:
stage: either 'fit' or 'test'
stage: either 'fit', 'validation', or 'test'

Example::

@@ -61,10 +61,10 @@ def setup(stage):

def teardown(self, stage: str):
"""
Called at the end of fit and test.
Called at the end of fit (training + validation), validation, and test.

Args:
stage: either 'fit' or 'test'
stage: either 'fit', 'validation', or 'test'
"""

def on_fit_start(self):
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -31,12 +31,12 @@ def verify_loop_configurations(self, model: LightningModule):
model: The model to check the configuration.

"""
if not self.trainer.testing:
if not self.trainer.evaluating:
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'validation')
else:
# check test loop configuration
self.__verify_eval_loop_configuration(model, 'test')
# check evaluation loop configurations
self.__verify_eval_loop_configuration(model, self.trainer.evaluating)

def __verify_train_loop_configuration(self, model):
# -----------------------------------
@@ -260,7 +260,7 @@ def prepare_eval_loop_results(self):
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
self.add_to_eval_loop_results(dl_idx, has_been_initialized)

def get_evaluate_epoch_results(self, test_mode):
def get_evaluate_epoch_results(self):
if not self.trainer.running_sanity_check:
# log all the metrics as a single dict
metrics_to_log = self.cached_results.get_epoch_log_metrics()
@@ -269,11 +269,11 @@ def get_evaluate_epoch_results(self):

self.prepare_eval_loop_results()

# log results of test
if test_mode and self.trainer.is_global_zero and self.trainer.verbose_test:
# log results of evaluation
if self.trainer.evaluating and self.trainer.is_global_zero and self.trainer.verbose_evaluate:
print('-' * 80)
for result_idx, results in enumerate(self.eval_loop_results):
print(f'DATALOADER:{result_idx} TEST RESULTS')
print(f'DATALOADER:{result_idx} {self.trainer.evaluating.upper()} RESULTS')
pprint(results)
print('-' * 80)

5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/connectors/model_connector.py
@@ -36,7 +36,10 @@ def copy_trainer_model_properties(self, model):
m.use_ddp2 = self.trainer.use_ddp2
m.use_ddp = self.trainer.use_ddp
m.use_amp = self.trainer.amp_backend is not None
m.testing = self.trainer.testing
# TODO: I only find usages of m.testing in DDP, where it's used to
# discriminate test from validation, as opposed to test from fit in
# Trainer. Still need to fully determine if it's correct.
m.testing = self.trainer.evaluating == 'test'
Contributor Author:
It seems to me that m.testing is interpreted a bit differently than trainer.testing:

  • The latter (trainer.testing) is used in most of the code to discriminate whether the top-level function the user called was fit() or test(). As such, in the rest of the code I replaced if trainer.testing with if trainer.evaluating, i.e. validate() and test() take the same code path
  • The former (m.testing) is used by DP and DDP to discriminate whether the model is currently inside the validation loop or the test loop (e.g. here)

If this interpretation is correct, the code should be good as is. If not, it needs to be changed.
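
A sketch of that distinction (illustrative only — it paraphrases how the DP/DDP wrappers of this era picked the step method, and is not code from this PR):

    # model.testing selects which loop body the distributed wrapper runs,
    # while trainer.evaluating records which top-level entry point was called.
    if model.testing:
        output = model.test_step(*args, **kwargs)        # inside the test loop
    else:
        output = model.validation_step(*args, **kwargs)  # inside the validation loop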

m.use_single_gpu = self.trainer.use_single_gpu
m.use_tpu = self.trainer.use_tpu
m.tpu_local_core_rank = self.trainer.tpu_local_core_rank