Add error handling for all trainer entry points #8819

Merged · 36 commits · Aug 18, 2021

Commits
dfaf7a1  [lightning] Ensure error handling works different trainer entry points (Aug 9, 2021)
7747faf  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 10, 2021)
dc3b571  Rebase and update after review (daniellepintz, Aug 10, 2021)
297caae  Merge branch 'error_handling' of https://github.com/daniellepintz/pyt… (daniellepintz, Aug 10, 2021)
2ed6b66  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 10, 2021)
17af386  [lightning] Ensure error handling works different trainer entry points (Aug 9, 2021)
1a07547  Rebase and update after review (daniellepintz, Aug 10, 2021)
ef58012  modularize duplicate code (daniellepintz, Aug 11, 2021)
b287588  Merge branch 'error_handling' of https://github.com/daniellepintz/pyt… (daniellepintz, Aug 11, 2021)
0b9cbeb  Merge branch 'PyTorchLightning:master' into error_handling (daniellepintz, Aug 11, 2021)
dda1c36  Merge branch 'error_handling' of https://github.com/daniellepintz/pyt… (daniellepintz, Aug 11, 2021)
083198c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 11, 2021)
b0af6ed  Apply black formatting (daniellepintz, Aug 11, 2021)
0260eb3  Merge branch 'error_handling' of https://github.com/daniellepintz/pyt… (daniellepintz, Aug 11, 2021)
fc10ebf  Add typing and DDPSpawnPlugin check (daniellepintz, Aug 11, 2021)
ebe8b42  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 11, 2021)
fde1b44  Fix CLI test (carmocca, Aug 12, 2021)
eb24c17  separate ddp test (daniellepintz, Aug 14, 2021)
e4a11a4  Merge branch 'error_handling' of https://github.com/daniellepintz/pyt… (daniellepintz, Aug 14, 2021)
3ee8f11  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 14, 2021)
356bbdf  add ddp test and remove dummyexception (daniellepintz, Aug 16, 2021)
bd94015  add ddp test (daniellepintz, Aug 16, 2021)
9c5fac2  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 16, 2021)
9913438  fix cli test (daniellepintz, Aug 16, 2021)
a0145c9  fix style (daniellepintz, Aug 16, 2021)
7d20612  fix style (daniellepintz, Aug 16, 2021)
0b7b29a  Merge branch 'error_handling' of https://github.com/daniellepintz/pyt… (daniellepintz, Aug 16, 2021)
e7b03ef  Runif special (daniellepintz, Aug 17, 2021)
cb8567d  remove accel.teardown and move test to test_trainer.py (daniellepintz, Aug 17, 2021)
bb9be0e  delete ddp test (daniellepintz, Aug 17, 2021)
7e6051a  Fix sphinx formatting (daniellepintz, Aug 17, 2021)
10bcd73  Merge branch 'error_handling' of https://github.com/daniellepintz/pyt… (daniellepintz, Aug 17, 2021)
8e13f1d  Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-… (daniellepintz, Aug 17, 2021)
7def3fe  update changelog (daniellepintz, Aug 17, 2021)
f7ec908  Fix style in unrleated file to make linter happy (daniellepintz, Aug 17, 2021)
37bb48b  Update pytorch_lightning/trainer/trainer.py (daniellepintz, Aug 18, 2021)
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,8 +9,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added error handling including calling of `on_keyboard_interrupt()` and `on_exception()` for all entrypoints (fit, validate, test, predict) ([#8819](https://github.com/PyTorchLightning/pytorch-lightning/pull/8819))


- Added a flavor of `training_step` that takes `dataloader_iter` as an argument ([#8807](https://github.com/PyTorchLightning/pytorch-lightning/pull/8807))


- Added `state_id` property to the `Callback` base class ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/progress.py
@@ -78,7 +78,7 @@ def disable(self):
self.enable = False

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
- super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) # important :-)
+ super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) # important
percent = (self.train_batch_idx / self.total_train_batches) * 100
sys.stdout.flush()
sys.stdout.write(f'{percent:.01f} percent complete \r')
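Note: the hunk above trims a smiley from the docstring example in `ProgressBarBase`. For reference, a self-contained version of that example as it reads after this change (1.4-era callback signature assumed):

```python
import sys

from pytorch_lightning.callbacks.progress import ProgressBarBase


class LitProgressBar(ProgressBarBase):
    def disable(self):
        self.enable = False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)  # important
        percent = (self.train_batch_idx / self.total_train_batches) * 100
        sys.stdout.flush()
        sys.stdout.write(f'{percent:.01f} percent complete \r')


# usage: Trainer(callbacks=[LitProgressBar()])
```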
131 changes: 89 additions & 42 deletions pytorch_lightning/trainer/trainer.py
@@ -18,7 +18,7 @@
import warnings
from datetime import timedelta
from pathlib import Path
- from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from weakref import proxy

import torch
@@ -486,6 +486,34 @@ def _setup_on_init(self, num_sanity_val_steps: int) -> None:

self.num_predict_batches = []

def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
r"""
Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict),
as all errors should funnel through them.

Args:
trainer_fn: one of (fit, validate, test, predict)
*args: positional arguments to be passed to the `trainer_fn`
**kwargs: keyword arguments to be passed to `trainer_fn`
"""
try:
return trainer_fn(*args, **kwargs)
except KeyboardInterrupt:
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
# user could press Ctrl+c many times... only shutdown once
if not self.interrupted:
self.state.status = TrainerStatus.INTERRUPTED
self.on_keyboard_interrupt()
except BaseException:
self.state.status = TrainerStatus.INTERRUPTED
if distributed_available() and self.world_size > 1:
# try syncing remaining processes, kill otherwise
self.training_type_plugin.reconciliate_processes(traceback.format_exc())
self._on_exception()
# reset bookkeeping
self.state.stage = None
raise
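
Note: every public entry point now delegates to a private `*_impl` and routes errors through this wrapper. A minimal standalone sketch of the same funnel pattern (the `Runner` class and hook names are illustrative stand-ins, not Lightning's API):

```python
from typing import Any, Callable


class Runner:
    """Illustrative stand-in for the Trainer's error funnel."""

    def __init__(self) -> None:
        self.interrupted = False

    def _on_keyboard_interrupt(self) -> None:
        print("graceful-shutdown hook")

    def _on_exception(self) -> None:
        print("exception hook")

    def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
        try:
            return trainer_fn(*args, **kwargs)
        except KeyboardInterrupt:
            # Ctrl+C may be pressed repeatedly; only shut down once
            if not self.interrupted:
                self.interrupted = True
                self._on_keyboard_interrupt()
        except BaseException:
            self.interrupted = True
            self._on_exception()
            raise  # re-raise so the caller still sees the original error

    def fit(self) -> Any:
        # public entry point: all errors funnel through the wrapper
        return self._call_and_handle_interrupt(self._fit_impl)

    def _fit_impl(self) -> None:
        raise RuntimeError("Error during fit")
```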

def fit(
self,
model: "pl.LightningModule",
@@ -508,18 +536,27 @@ def fit(

datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
"""
Trainer._log_api_event("fit")

self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True

if train_dataloader is not None:
rank_zero_deprecation(
"`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
" Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
)
train_dataloaders = train_dataloader
self._call_and_handle_interrupt(self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule)

def _fit_impl(
self,
model: "pl.LightningModule",
train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
val_dataloaders: Optional[EVAL_DATALOADERS] = None,
datamodule: Optional[LightningDataModule] = None,
) -> None:
Trainer._log_api_event("fit")

self.state.fn = TrainerFn.FITTING
self.state.status = TrainerStatus.RUNNING
self.training = True

# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloaders, LightningDataModule):
datamodule = train_dataloaders
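
Note: the deprecation shim stays in the public `fit` so the warning fires before delegating. A hedged sketch of the same keyword-rename pattern using plain `warnings` (Lightning uses its `rank_zero_deprecation` helper; the function below is illustrative):

```python
import warnings
from typing import Iterable, Optional


def fit(
    train_dataloaders: Optional[Iterable] = None,
    train_dataloader: Optional[Iterable] = None,  # deprecated alias
) -> None:
    if train_dataloader is not None:
        warnings.warn(
            "`fit(train_dataloader)` is deprecated; use `fit(train_dataloaders)` instead.",
            DeprecationWarning,
        )
        train_dataloaders = train_dataloader
    # ... hand off to the private implementation using the new name only
    print(f"fitting with {train_dataloaders!r}")
```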
@@ -572,6 +609,22 @@ def validate(
:meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc.
The length of the list corresponds to the number of validation dataloaders used.
"""
if val_dataloaders is not None:
rank_zero_deprecation(
"`trainer.validate(val_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
" Use `trainer.validate(dataloaders)` instead."
)
dataloaders = val_dataloaders
return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule)

def _validate_impl(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
# --------------------
# SETUP HOOK
# --------------------
@@ -582,12 +635,6 @@ def validate(
self.state.status = TrainerStatus.RUNNING
self.validating = True

if val_dataloaders is not None:
rank_zero_deprecation(
"`trainer.validate(val_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
" Use `trainer.validate(dataloaders)` instead."
)
dataloaders = val_dataloaders
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(dataloaders, LightningDataModule):
datamodule = dataloaders
@@ -651,6 +698,22 @@ def test(
:meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc.
The length of the list corresponds to the number of test dataloaders used.
"""
if test_dataloaders is not None:
rank_zero_deprecation(
"`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
" Use `trainer.test(dataloaders)` instead."
)
dataloaders = test_dataloaders
return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)

def _test_impl(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
ckpt_path: Optional[str] = None,
verbose: bool = True,
datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
# --------------------
# SETUP HOOK
# --------------------
@@ -661,12 +724,6 @@ def test(
self.state.status = TrainerStatus.RUNNING
self.testing = True

if test_dataloaders is not None:
rank_zero_deprecation(
"`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
" Use `trainer.test(dataloaders)` instead."
)
dataloaders = test_dataloaders
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(dataloaders, LightningDataModule):
datamodule = dataloaders
@@ -728,7 +785,18 @@ def predict(
Returns:
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
"""
return self._call_and_handle_interrupt(
self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path
)

def _predict_impl(
self,
model: Optional["pl.LightningModule"] = None,
dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
datamodule: Optional[LightningDataModule] = None,
return_predictions: Optional[bool] = None,
ckpt_path: Optional[str] = None,
) -> Optional[_PREDICT_OUTPUT]:
# --------------------
# SETUP HOOK
# --------------------
@@ -1067,29 +1135,8 @@ def _run_train(self) -> None:

self.reset_train_val_dataloaders(model)

- try:
-     # reset trainer on this loop and all child loops in case user connected a custom loop
-     self.fit_loop.trainer = self
-     self.fit_loop.run()
- except KeyboardInterrupt:
-     rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
-     # user could press Ctrl+c many times... only shutdown once
-     if not self.interrupted:
-         self.state.status = TrainerStatus.INTERRUPTED
-         self.on_keyboard_interrupt()
-     # same treatment as below
-     self.accelerator.on_train_end()
- except BaseException:
-     self.state.status = TrainerStatus.INTERRUPTED
-     if distributed_available() and self.world_size > 1:
-         # try syncing remaing processes, kill otherwise
-         self.training_type_plugin.reconciliate_processes(traceback.format_exc())
-     # give accelerators a chance to finish
-     self.accelerator.on_train_end()
-     self._on_exception()
-     # reset bookkeeping
-     self.state.stage = None
-     raise
+ self.fit_loop.trainer = self
+ self.fit_loop.run()
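
Note: `_run_train` no longer needs its own try/except because the handling moved up to the entry points. From the user's side, interrupt handling is therefore no longer fit-only; a hedged usage sketch (1.4-era callback signature, model construction omitted):

```python
from pytorch_lightning import Callback, Trainer


class GracefulShutdown(Callback):
    def on_keyboard_interrupt(self, trainer, pl_module):
        # with this PR, Ctrl+C during fit, validate, test, or predict lands here
        print(f"interrupted, trainer status: {trainer.state.status}")


trainer = Trainer(callbacks=[GracefulShutdown()], fast_dev_run=True)
# any entry point now routes KeyboardInterrupt through the same handler:
# trainer.fit(model); trainer.validate(model); trainer.test(model); trainer.predict(model)
```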

def _run_evaluate(self) -> _EVALUATE_OUTPUT:
if not self.is_global_zero and self.progress_bar_callback is not None:
2 changes: 1 addition & 1 deletion tests/helpers/runif.py
@@ -87,7 +87,7 @@ def __new__(
ipu: if IPU is available
horovod: if Horovod is installed
horovod_nccl: if Horovod is installed with NCCL support
- skip_windows: skip test for Windows platform (typically fo some limited torch functionality)
+ skip_windows: skip test for Windows platform (typically for some limited torch functionality)
special: running in special mode, outside pytest suit
fairscale: if `fairscale` module is required to run the test
fairscale_fully_sharded: if `fairscale` fully sharded module is required to run the test
34 changes: 34 additions & 0 deletions tests/trainer/test_trainer.py
@@ -1834,3 +1834,37 @@ def current_memory():
trainer_2.fit(model)

assert current_memory() <= initial


class TrainerStagesErrorsModel(BoringModel):
def on_train_start(self) -> None:
raise Exception("Error during train")

def on_validation_start(self) -> None:
raise Exception("Error during validation")

def on_test_start(self) -> None:
raise Exception("Error during test")

def on_predict_start(self) -> None:
raise Exception("Error during predict")


@pytest.mark.parametrize(
"accelerator,num_processes",
[
(None, 1),
pytest.param("ddp_cpu", 2, marks=RunIf(skip_windows=True)),
],
)
def test_error_handling_all_stages(tmpdir, accelerator, num_processes):
model = TrainerStagesErrorsModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator=accelerator, num_processes=num_processes, fast_dev_run=True)
with pytest.raises(Exception, match=r"Error during train"), patch("pytorch_lightning.Trainer._on_exception"):
trainer.fit(model)
with pytest.raises(Exception, match=r"Error during validation"), patch("pytorch_lightning.Trainer._on_exception"):
trainer.validate(model)
with pytest.raises(Exception, match=r"Error during test"), patch("pytorch_lightning.Trainer._on_exception"):
trainer.test(model)
with pytest.raises(Exception, match=r"Error during predict"), patch("pytorch_lightning.Trainer._on_exception"):
trainer.predict(model, model.val_dataloader(), return_predictions=False)
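
Note: the test stacks `pytest.raises` with `mock.patch` in one `with` statement, so the assertion on the raised error runs while the private `_on_exception` hook is stubbed out (presumably to suppress its side effects under test). The same pattern in a self-contained sketch (the `Worker` class is illustrative):

```python
from unittest.mock import patch

import pytest


class Worker:
    def _on_exception(self):
        print("side effect we do not want in tests")

    def run(self):
        try:
            raise RuntimeError("Error during train")
        except BaseException:
            self._on_exception()
            raise


def test_run_propagates_error():
    # stub the hook and still assert that the original error propagates
    with pytest.raises(RuntimeError, match=r"Error during train"), patch.object(Worker, "_on_exception"):
        Worker().run()
```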
4 changes: 2 additions & 2 deletions tests/utilities/test_cli.py
@@ -549,7 +549,7 @@ def add_arguments_to_parser(self, parser):

class EarlyExitTestModel(BoringModel):
def on_fit_start(self):
- raise KeyboardInterrupt()
+ raise Exception("Error on fit start")


@pytest.mark.parametrize("logger", (False, True))
@@ -562,7 +562,7 @@ def on_fit_start(self):
),
)
def test_cli_ddp_spawn_save_config_callback(tmpdir, logger, trainer_kwargs):
- with mock.patch("sys.argv", ["any.py"]), pytest.raises(KeyboardInterrupt):
+ with mock.patch("sys.argv", ["any.py"]), pytest.raises(Exception, match=r"Error on fit start"):
LightningCLI(
EarlyExitTestModel,
trainer_defaults={
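
Note: switching the CLI test from `KeyboardInterrupt` to a plain `Exception` follows directly from the new wrapper: `_call_and_handle_interrupt` catches `KeyboardInterrupt` and returns without re-raising, so `pytest.raises(KeyboardInterrupt)` would see nothing, while ordinary exceptions take the `BaseException` branch and propagate. A minimal sketch of that asymmetry (not Lightning source):

```python
def handle(entry_point):
    try:
        return entry_point()
    except KeyboardInterrupt:
        print("handled gracefully; nothing re-raised")
    except BaseException:
        raise  # ordinary errors still propagate to the caller


def interrupt():
    raise KeyboardInterrupt


def explode():
    raise Exception("Error on fit start")


handle(interrupt)  # prints; the caller sees no exception
try:
    handle(explode)
except Exception as err:
    print(f"caller still sees: {err}")  # so tests must expect Exception
```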