Add Trainer.validate(…) method to run one validation epoch #4707

Closed

Commits (39)
62bd29e
Add Trainer.validate(…) to run one validation epoch
EliaCereda Nov 17, 2020
055e1ba
Support val_progress_bar without main_progress_bar in ProgressBar
EliaCereda Nov 17, 2020
156b669
Fix PEP 8 issue
EliaCereda Nov 17, 2020
1429548
Use `main_progress_bar is not None` to test if the bar is present in …
EliaCereda Nov 17, 2020
50427e7
Simplify selection of dataloaders arg to be set
EliaCereda Nov 17, 2020
d1988e0
Call setup(…) with stage ‘validation’ when running Trainer.validate(…)
EliaCereda Nov 17, 2020
ae03c6b
Check self.trainer.evaluating instead of self.trainer.testing in Acce…
EliaCereda Nov 17, 2020
5493a5b
Set Trainer.evaluating to None by default
EliaCereda Nov 17, 2020
860fef5
Replace the remaining instances of self.evaluating = False with None
EliaCereda Nov 18, 2020
99a6161
Add a first batch of tests for Trainer.validate(…)
EliaCereda Nov 18, 2020
307c89a
Avoid an if/else in ProgressBar
EliaCereda Nov 18, 2020
9e59e6d
Modify ModelCheckpoint to never save a checkpoint automatically when …
EliaCereda Nov 18, 2020
a844f40
Update test_config_validator.py to match the messages of expected err…
EliaCereda Nov 18, 2020
3f9f927
Fix Trainer.validate(…, verbose=True)
EliaCereda Nov 19, 2020
db22f2b
Transform Trainer.testing to a read-only deprecated property, remove …
EliaCereda Nov 19, 2020
f8647c5
Update docs for Trainer.validate and Trainer.test
EliaCereda Nov 19, 2020
99281a0
Remove usages of deprecated Trainer.testing
EliaCereda Nov 20, 2020
58d1c36
Rename methods and attributes to reflect their new behavior
EliaCereda Nov 20, 2020
7330ad4
Rename Trainer.tested_ckpt_path to Trainer.evaluated_ckpt_path since …
EliaCereda Nov 20, 2020
7abc67d
Update CHANGELOG.md
EliaCereda Nov 20, 2020
14799da
Fix PEP 8 issues
EliaCereda Nov 20, 2020
f8f4d3b
Update documentation of .setup(stage) methods to mention the new ‘val…
EliaCereda Nov 20, 2020
1818f22
Added more tests for Trainer.validate
EliaCereda Nov 20, 2020
ab89faa
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Nov 20, 2020
d0cd34a
Fix hook that tracks LightningDataModule.setup(‘validation’) calls, a…
EliaCereda Nov 20, 2020
0209cfc
Add a test for Trainer.validate on DataParallel
EliaCereda Nov 20, 2020
6a04280
Disable EarlyStopping in evaluation mode
EliaCereda Nov 21, 2020
2115350
Clean up LoggerConnector.get_evaluate_epoch_results
EliaCereda Nov 21, 2020
92acb12
Improve description of Trainer.validate in docs/source/trainer.rst
EliaCereda Nov 21, 2020
8090193
Clean up setup() methods in tests/base/datamodules.py
EliaCereda Nov 21, 2020
a098489
Update deprecation warnings
EliaCereda Nov 21, 2020
f8ab391
Update Trainer.{validate, test} docstrings
EliaCereda Nov 21, 2020
605e7b0
Fix PEP 8 issue
EliaCereda Nov 21, 2020
14a7767
Consistently use the serial comma in docstrings
EliaCereda Nov 23, 2020
e9a6956
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Nov 23, 2020
873099e
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Nov 25, 2020
6f2ce28
Fix PEP 8 issue
EliaCereda Nov 25, 2020
0f4e474
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 2, 2020
d4cb1b0
Rewrite assertions for Trainer.validate in test_callbacks.py using Ma…
EliaCereda Dec 2, 2020
5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/accelerator.py
@@ -59,8 +59,9 @@ def barrier(self, name: Optional[str] = None):
def broadcast(self, obj, src=0):
return obj

# TODO: rename train_or_evaluate
def train_or_test(self):
if self.trainer.testing:
if self.trainer.evaluating:
results = self.trainer.run_test()
else:
results = self.trainer.train()
@@ -160,7 +161,7 @@ def early_stopping_should_stop(self, pl_module):
return self.trainer.should_stop

def setup_optimizers(self, model):
if self.trainer.testing is True:
if self.trainer.evaluating:
return

optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
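
A minimal sketch (not Lightning code) of the dispatch idea this hunk relies on: `trainer.evaluating` is assumed to be `False` during fit and the stage name (`'validation'` or `'test'`) during evaluation, so a plain truthiness check routes both evaluation stages to the evaluation path.

```python
# Toy stand-in, not PyTorch Lightning code: shows how a string-or-False attribute
# can serve both as a boolean switch and as the stage name.
class TrainerStub:
    def __init__(self, evaluating=False):
        # False during fit, or the stage name: 'validation' / 'test'
        self.evaluating = evaluating

    def run(self):
        if self.evaluating:  # truthy for both evaluation stages
            return f"run evaluation loop for stage '{self.evaluating}'"
        return "run training loop"

assert TrainerStub().run() == "run training loop"
assert TrainerStub('validation').run() == "run evaluation loop for stage 'validation'"
assert TrainerStub('test').run() == "run evaluation loop for stage 'test'"
```
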
17 changes: 14 additions & 3 deletions pytorch_lightning/callbacks/progress.py
@@ -282,9 +282,13 @@ def init_train_tqdm(self) -> tqdm:

def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """

# The main progress bar doesn't exist in trainer.validate(...)
has_main_bar = 1 if self.main_progress_bar is not None else 0

bar = tqdm(
desc='Validating',
position=(2 * self.process_position + 1),
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
dynamic_ncols=True,
@@ -348,11 +352,18 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self.is_enabled and self.val_batch_idx % self.refresh_rate == 0:
self.val_progress_bar.update(self.refresh_rate)
self.main_progress_bar.update(self.refresh_rate)

# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self.main_progress_bar.update(self.refresh_rate)

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

self.val_progress_bar.close()

def on_train_end(self, trainer, pl_module):
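
A hedged sketch of the bar positioning shown above: when `Trainer.validate(...)` is invoked directly there is no main training bar, so the validation bar should take the base slot instead of the row below it. `make_validation_bar` is an illustrative helper, not the `ProgressBar` API.

```python
from tqdm import tqdm

def make_validation_bar(process_position: int, main_bar_exists: bool) -> tqdm:
    # Mirrors the diff: offset the validation bar by one row only when a main bar is shown.
    has_main_bar = 1 if main_bar_exists else 0
    return tqdm(
        desc='Validating',
        position=2 * process_position + has_main_bar,
        leave=False,
        dynamic_ncols=True,
    )

# During fit: the validation bar sits one row below the main bar (position 1).
fit_bar = make_validation_bar(process_position=0, main_bar_exists=True)
# During trainer.validate(...): it occupies the main bar's row (position 0).
val_only_bar = make_validation_bar(process_position=0, main_bar_exists=False)
fit_bar.close()
val_only_bar.close()
```
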
10 changes: 10 additions & 0 deletions pytorch_lightning/core/datamodule.py
@@ -155,6 +155,7 @@ def __init__(
# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False
self._has_setup_fit = False
self._has_setup_validation = False
self._has_setup_test = False

@property
@@ -230,6 +231,15 @@ def has_setup_fit(self):
"""
return self._has_setup_fit

@property
def has_setup_validation(self):
"""Return bool letting you know if datamodule.setup('validation') has been called or not.

Returns:
bool: True if datamodule.setup('validation') has been called. False by default.
"""
return self._has_setup_validation

@property
def has_setup_test(self):
"""Return bool letting you know if datamodule.setup('test') has been called or not.
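
Since `has_setup_validation` tracks calls to `setup('validation')`, a user-side datamodule is expected to handle that stage alongside `'fit'` and `'test'`. The sketch below is hypothetical (toy data, illustrative class name) and only shows where the new stage fits.

```python
from pytorch_lightning import LightningDataModule

class ToyDataModule(LightningDataModule):
    """Illustrative only; the datasets are placeholders."""

    def setup(self, stage=None):
        if stage in (None, 'fit'):
            self.train_data = list(range(100))
        if stage in (None, 'fit', 'validation'):
            # also built for the new 'validation' stage used by Trainer.validate(...)
            self.val_data = list(range(20))
        if stage in (None, 'test'):
            self.test_data = list(range(20))
```
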
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/configuration_validator.py
@@ -31,12 +31,12 @@ def verify_loop_configurations(self, model: LightningModule):
model: The model to check the configuration.

"""
if not self.trainer.testing:
if not self.trainer.evaluating:
self.__verify_train_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'validation')
else:
# check test loop configuration
self.__verify_eval_loop_configuration(model, 'test')
# check evaluation loop configurations
self.__verify_eval_loop_configuration(model, self.trainer.evaluating)

def __verify_train_loop_configuration(self, model):
# -----------------------------------
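
A small sketch of why the branch above collapses: because `trainer.evaluating` stores the stage name itself, it can be forwarded straight to the eval-loop check instead of hard-coding `'test'`. The helper below is illustrative, not the validator's real code.

```python
def loops_to_verify(evaluating):
    # evaluating is assumed to be False during fit, or 'validation' / 'test' otherwise
    if not evaluating:
        return ['train', 'validation']   # fit verifies the train loop plus the val loop
    return [evaluating]                  # evaluation verifies only the requested stage

assert loops_to_verify(False) == ['train', 'validation']
assert loops_to_verify('validation') == ['validation']
assert loops_to_verify('test') == ['test']
```
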
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -39,6 +39,7 @@ def on_trainer_init(self):
self.trainer.val_dataloaders = None
self.trainer.running_sanity_check = False
self.trainer.testing = False
self.trainer.evaluating = False

# when .test() is called, it sets this
self.trainer.tested_ckpt_path = None
101 changes: 85 additions & 16 deletions pytorch_lightning/trainer/trainer.py
@@ -659,10 +659,12 @@ def track_output_for_epoch_end(self, outputs, output):
outputs.append(output)
return outputs

# TODO: rename run_test_or_validate?
def run_test(self):
# only load test dataloader for testing
# self.reset_test_dataloader(ref_model)
eval_loop_results, _ = self.run_evaluation(test_mode=True)
test_mode = True if self.evaluating == 'test' else False
eval_loop_results, _ = self.run_evaluation(test_mode=test_mode)

if len(eval_loop_results) == 0:
return 1
@@ -710,6 +712,60 @@ def run_sanity_check(self, ref_model):
self.on_sanity_check_end()
self.running_sanity_check = False

def validate(
self,
model: Optional[LightningModule] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
ckpt_path: Optional[str] = 'best',
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
):
r"""
Perform one evaluation epoch over the validation set, separate from fit.

Args:
ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
If ``None``, use the weights from the last epoch. Defaults to ``best``.

datamodule: An instance of :class:`LightningDataModule`.

model: The model to validate.

val_dataloaders: Either a single
PyTorch DataLoader or a list of them, specifying validation samples.

verbose: If True, prints the validation results

Returns:
The final validation result dictionary. If no validation_epoch_end is defined, returns a list of dictionaries
"""
# --------------------
# SETUP HOOK
# --------------------
self.verbose_test = verbose # TODO: rename / else?

self.logger_connector.set_stage("validation")

# If you supply a datamodule you can't supply val_dataloaders
if val_dataloaders and datamodule:
raise MisconfigurationException(
'You cannot pass val_dataloaders to trainer.validate if you supply a datamodule'
)

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'validation')

if model is not None:
results = self.__evaluate_given_model(model, val_dataloaders, 'validation')
else:
results = self.__evaluate_using_best_weights(ckpt_path, val_dataloaders, 'validation')

self.teardown('validation')

return results

def test(
self,
model: Optional[LightningModule] = None,
@@ -745,7 +801,7 @@ def test(

self.logger_connector.set_stage("test")

# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
# If you supply a datamodule you can't supply test_dataloaders
if test_dataloaders and datamodule:
raise MisconfigurationException(
'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
@@ -755,15 +811,15 @@
self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')

if model is not None:
results = self.__test_given_model(model, test_dataloaders)
results = self.__evaluate_given_model(model, test_dataloaders, 'test')
else:
results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
results = self.__evaluate_using_best_weights(ckpt_path, test_dataloaders, 'test')

self.teardown('test')

return results

def __test_using_best_weights(self, ckpt_path, test_dataloaders):
def __evaluate_using_best_weights(self, ckpt_path, dataloaders, stage: str):
model = self.get_model()

# if user requests the best checkpoint but we don't have it, error
@@ -791,41 +847,47 @@ def __test_using_best_weights(self, ckpt_path, test_dataloaders):
model.load_state_dict(ckpt['state_dict'])

# attach dataloaders
if test_dataloaders is not None:
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
if dataloaders is not None:
kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders}
self.data_connector.attach_dataloaders(model, **kwargs)

# run tests
self.evaluating = stage
self.tested_ckpt_path = ckpt_path
self.testing = True
os.environ['PL_TESTING_MODE'] = '1'
self.model = model
results = self.fit(model)
self.testing = False
del os.environ['PL_TESTING_MODE']
self.testing = False
self.evaluating = False

# teardown
if self.is_function_implemented('teardown'):
model_ref = self.get_model()
model_ref.teardown('test')
model_ref.teardown(stage)

return results

def __test_given_model(self, model, test_dataloaders):
def __evaluate_given_model(self, model, dataloaders, stage: str):

# attach data
if test_dataloaders is not None:
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
if dataloaders is not None:
kwargs = {'test_dataloaders' if stage == 'test' else 'val_dataloaders': dataloaders}
self.data_connector.attach_dataloaders(model, **kwargs)

# run test
# sets up testing so we short circuit to eval
self.testing = True
self.evaluating = stage
self.testing = True # TODO: remove, keep only evaluating
self.model = model
results = self.fit(model)
self.testing = False
self.evaluating = False

# teardown
if self.is_function_implemented('teardown'):
model.teardown('test')
model.teardown(stage)

return results

@@ -855,11 +917,18 @@ def tune(

def call_setup_hook(self, model):
# call setup after the ddp process has connected
stage_name = 'test' if self.testing else 'fit'
stage_name = self.evaluating or 'fit'

if self.datamodule is not None:
called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
called = {
False: self.datamodule.has_setup_fit,
'validation': self.datamodule.has_setup_validation,
'test': self.datamodule.has_setup_test,
}[self.evaluating]

if not called:
self.datamodule.setup(stage_name)

self.setup(model, stage_name)
model.setup(stage_name)

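
To tie the trainer.py changes together, here is a hedged end-to-end usage sketch of the API this PR adds, assuming the signature shown in the diff (`Trainer.validate(model=None, val_dataloaders=None, ckpt_path='best', verbose=True, datamodule=None)`). The model and data below are toys; only the `validate` calls mirror the new behavior.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log('val_loss', torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
train_loader = DataLoader(dataset, batch_size=8)
val_loader = DataLoader(dataset, batch_size=8)

trainer = pl.Trainer(max_epochs=1)
trainer.fit(ToyModel(), train_loader, val_loader)

# Run exactly one validation epoch; ckpt_path='best' loads the best checkpoint by default.
results = trainer.validate(val_dataloaders=val_loader)

# Or validate a given in-memory model directly, skipping checkpoint loading.
results = trainer.validate(ToyModel(), val_dataloaders=val_loader)
```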