Skip to content

Commit

Permalink
Add Trainer.validate(…) method to run one validation epoch (#4948)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: chaton <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
  • Loading branch information
4 people authored Mar 11, 2021
1 parent d1db604 commit f4cc745
Show file tree
Hide file tree
Showing 20 changed files with 446 additions and 483 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))


- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set ([#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))


- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))


Expand Down
15 changes: 14 additions & 1 deletion docs/source/common/trainer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,27 @@ So you can run it like so:

------------

Validation
----------
You can perform an evaluation epoch over the validation set, outside of the training loop,
using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be
useful if you want to collect new metrics from a model right at its initialization
or after it has already been trained.

.. code-block:: python
trainer.validate(val_dataloaders=val_dataloaders)
------------

Testing
-------
Once you're done training, feel free to run the test set!
(Only right before publishing your paper or pushing to production)

.. code-block:: python
trainer.test(test_dataloaders=test_dataloader)
trainer.test(test_dataloaders=test_dataloaders)
------------

Expand Down
11 changes: 8 additions & 3 deletions pytorch_lightning/callbacks/progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,9 +355,11 @@ def init_predict_tqdm(self) -> tqdm:

def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """
# The main progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.main_progress_bar is not None
bar = tqdm(
desc='Validating',
position=(2 * self.process_position + 1),
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
dynamic_ncols=True,
Expand Down Expand Up @@ -426,7 +428,8 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
if self.main_progress_bar is not None:
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)
self.val_progress_bar.close()

def on_train_end(self, trainer, pl_module):
Expand Down Expand Up @@ -479,8 +482,10 @@ def print(
def _should_update(self, current, total):
return self.is_enabled and (current % self.refresh_rate == 0 or current == total)

def _update_bar(self, bar):
def _update_bar(self, bar: Optional[tqdm]) -> None:
""" Updates the bar by the refresh rate without overshooting. """
if bar is None:
return
if bar.total is not None:
delta = min(self.refresh_rate, bar.total - bar.n)
else:
Expand Down
21 changes: 13 additions & 8 deletions pytorch_lightning/trainer/configuration_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
Expand All @@ -22,18 +23,24 @@ class ConfigValidator(object):
def __init__(self, trainer):
self.trainer = trainer

def verify_loop_configurations(self, model: LightningModule):
def verify_loop_configurations(self, model: LightningModule) -> None:
r"""
Checks that the model is configured correctly before the run is started.
Args:
model: The model to check the configuration.
"""
if self.trainer.training:
if self.trainer.state == TrainerState.FITTING:
self.__verify_train_loop_configuration(model)
elif self.trainer.evaluating:
self.__verify_eval_loop_configuration(model)
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state == TrainerState.TUNING:
self.__verify_train_loop_configuration(model)
elif self.trainer.state == TrainerState.VALIDATING:
self.__verify_eval_loop_configuration(model, 'val')
elif self.trainer.state == TrainerState.TESTING:
self.__verify_eval_loop_configuration(model, 'test')
# TODO: add predict

def __verify_train_loop_configuration(self, model):
# -----------------------------------
Expand Down Expand Up @@ -81,11 +88,9 @@ def __verify_train_loop_configuration(self, model):
' It ensures optimizer_step or optimizer_zero_grad are called on every batch.'
)

def __verify_eval_loop_configuration(self, model):
stage = "val" if self.trainer.validating else "test"

def __verify_eval_loop_configuration(self, model: LightningModule, stage: str) -> None:
loader_name = f'{stage}_dataloader'
step_name = f'{stage}_step'
step_name = 'validation_step' if stage == 'val' else 'test_step'

has_loader = is_overridden(loader_name, model)
has_step = is_overridden(step_name, model)
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/connectors/data_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
def attach_dataloaders(
self,
model,
train_dataloader=None,
val_dataloaders=None,
test_dataloaders=None,
predict_dataloaders=None,
train_dataloader: Optional[Union[DataLoader, List[DataLoader]]] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
predict_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
):
# when dataloader is passed via fit, patch the train_dataloader
# functions to overwrite with these implementations
Expand All @@ -119,7 +119,7 @@ def attach_dataloaders(
if predict_dataloaders is not None:
model.predict_dataloader = _PatchDataLoader(predict_dataloaders)

def attach_datamodule(self, model, datamodule: Optional[LightningDataModule]) -> None:
def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = None) -> None:
# We use datamodule if it's been provided, otherwise we check model for it
datamodule = datamodule or getattr(model, 'datamodule', None)

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class RunningStage(LightningEnum):
"""
TRAINING = 'train'
SANITY_CHECKING = 'sanity_check'
VALIDATING = 'validation'
VALIDATING = 'validate'
TESTING = 'test'
PREDICTING = 'predict'
TUNING = 'tune'
Expand Down
140 changes: 90 additions & 50 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,69 @@ def run_sanity_check(self, ref_model):

self._running_stage = stage

def validate(
self,
model: Optional[LightningModule] = None,
val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
ckpt_path: Optional[str] = 'best',
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
):
r"""
Perform one evaluation epoch over the validation set.
Args:
model: The model to validate.
val_dataloaders: Either a single PyTorch DataLoader or a list of them,
specifying validation samples.
ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
If ``None``, use the current weights of the model.
When the model is given as argument, this parameter will not apply.
verbose: If True, prints the validation results.
datamodule: A instance of :class:`LightningDataModule`.
Returns:
The dictionary with final validation results returned by validation_epoch_end.
If validation_epoch_end is not defined, the output is a list of the dictionaries
returned by validation_step.
"""
# --------------------
# SETUP HOOK
# --------------------
self.verbose_evaluate = verbose

self.state = TrainerState.VALIDATING
self.validating = True

# If you supply a datamodule you can't supply val_dataloaders
if val_dataloaders and datamodule:
raise MisconfigurationException(
'You cannot pass both `trainer.validate(val_dataloaders=..., datamodule=...)`'
)

model_provided = model is not None
model = model or self.lightning_module

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, val_dataloaders=val_dataloaders)

if not model_provided:
self.validated_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)

# run validate
results = self.fit(model)

assert self.state.stopped
self.validating = False

return results

def test(
self,
model: Optional[LightningModule] = None,
Expand All @@ -833,17 +896,19 @@ def test(
fit to make sure you never run on your test set until you want to.
Args:
ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
If ``None``, use the current weights of the model. Default to ``best``.
datamodule: A instance of :class:`LightningDataModule`.
model: The model to test.
test_dataloaders: Either a single PyTorch DataLoader or a list of them,
specifying test samples.
ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
If ``None``, use the current weights of the model.
When the model is given as argument, this parameter will not apply.
verbose: If True, prints the test results.
datamodule: A instance of :class:`LightningDataModule`.
Returns:
Returns a list of dictionaries, one for each test dataloader containing their respective metrics.
"""
Expand All @@ -858,30 +923,33 @@ def test(
# If you supply a datamodule you can't supply test_dataloaders
if test_dataloaders and datamodule:
raise MisconfigurationException(
'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
'You cannot pass both `trainer.test(test_dataloaders=..., datamodule=...)`'
)

model_provided = model is not None
model = model or self.lightning_module

# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
results = (
self.__evaluate_given_model(model, dataloaders=test_dataloaders) if model_provided else
self.__evaluate_using_weights(model, ckpt_path=ckpt_path, dataloaders=test_dataloaders)
)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

if not model_provided:
self.tested_ckpt_path = self.__load_ckpt_weights(model, ckpt_path=ckpt_path)

# run test
results = self.fit(model)

assert self.state.stopped
self.testing = False

return results

def __evaluate_using_weights(
def __load_ckpt_weights(
self,
model,
ckpt_path: Optional[str] = None,
dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None
):
) -> Optional[str]:
# if user requests the best checkpoint but we don't have it, error
if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
raise MisconfigurationException(
Expand All @@ -894,42 +962,18 @@ def __evaluate_using_weights(
if ckpt_path == 'best':
ckpt_path = self.checkpoint_callback.best_model_path

if len(ckpt_path) == 0:
rank_zero_warn(
f'`.test()` found no path for the best weights, {ckpt_path}. Please'
' specify a path for a checkpoint `.test(ckpt_path=PATH)`'
if not ckpt_path:
fn = self.state.value
raise MisconfigurationException(
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)
return {}

self.training_type_plugin.barrier()

ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])

# attach dataloaders
if dataloaders is not None:
self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)

if self.validating:
self.validated_ckpt_path = ckpt_path
else:
self.tested_ckpt_path = ckpt_path

# run test
results = self.fit(model)

return results

def __evaluate_given_model(self, model, dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None):
# attach data
if dataloaders is not None:
self.data_connector.attach_dataloaders(model, test_dataloaders=dataloaders)

# run test
# sets up testing so we short circuit to eval
results = self.fit(model)

return results
return ckpt_path

def predict(
self,
Expand Down Expand Up @@ -970,15 +1014,11 @@ def predict(
'You cannot pass dataloaders to trainer.predict if you supply a datamodule.'
)

if datamodule is not None:
# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)

# attach data
if dataloaders is not None:
self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)
# Attach datamodule to get setup/prepare_data added to model before the call to it below
self.data_connector.attach_datamodule(model, datamodule)
# Attach dataloaders (if given)
self.data_connector.attach_dataloaders(model, predict_dataloaders=dataloaders)

self.model = model
results = self.fit(model)

assert self.state.stopped
Expand Down
Loading

0 comments on commit f4cc745

Please sign in to comment.