
Add error handling for all trainer entry points #8819

Merged · 36 commits · Aug 18, 2021

Conversation

daniellepintz
Contributor

@daniellepintz daniellepintz commented Aug 9, 2021

What does this PR do?

Before this PR, Lightning only had error handling for trainer.fit(). This PR moves the error handling to a higher level of abstraction so that it also applies to trainer.validate(), trainer.test(), and trainer.predict().
Fixes #8723
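The pattern this PR introduces can be sketched in plain Python (a simplified illustration only, not the actual Lightning source; the placeholder bodies and the model argument here are made up for the example):

```python
# Simplified sketch of the pattern this PR introduces (illustration only,
# not the actual Lightning code): every public entry point delegates to a
# private implementation through one shared error-handling wrapper, so
# fit/validate/test/predict all get the same interrupt handling.
class Trainer:
    def _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs):
        try:
            return trainer_fn(*args, **kwargs)
        except KeyboardInterrupt:
            # Graceful shutdown instead of a raw traceback.
            self.interrupted = True

    def fit(self, *args, **kwargs):
        return self._call_and_handle_interrupt(self._fit_impl, *args, **kwargs)

    def validate(self, *args, **kwargs):
        return self._call_and_handle_interrupt(self._validate_impl, *args, **kwargs)

    def _fit_impl(self, model=None):
        return f"fit({model})"

    def _validate_impl(self, model=None):
        return f"validate({model})"
```

Because the wrapper takes the implementation function plus its arguments, each entry point stays a thin one-liner and no entry point can bypass the handler.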

Does your PR introduce any breaking changes? If yes, please list them.

No

Before submitting

  • Was this discussed/approved via a GitHub issue? (not for typos and docs)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes? (if necessary)
  • Did you write any new necessary tests? (not for typos and docs)
  • Did you verify new and existing tests pass locally with your changes?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you update the CHANGELOG? (not for typos, docs, test updates, or internal minor changes/refactorings)

PR review

Anyone in the community is welcome to review the PR.
Before you start reviewing make sure you have read Review guidelines. In short, see the following bullet-list:

  • Is this pull request ready for review? (if not, please submit in draft mode)
  • Check that all items from Before submitting are resolved
  • Make sure the title is self-explanatory and the description concisely explains the PR
  • Add labels and milestones (and optionally projects) to the PR so it can be classified

Did you have fun?

Yes!

@pep8speaks

pep8speaks commented Aug 9, 2021

Hello @daniellepintz! Thanks for updating this PR. We checked the lines you've touched for PEP 8 issues, and found:

There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻

Comment last updated at 2021-08-12 00:25:10 UTC

Contributor

@ananthsub ananthsub left a comment


Thanks for working on this @daniellepintz!

On reading through the entry points again, all of fit/validate/test/predict can raise exceptions before _run is called. For the error handling to be complete, I think applying the try/except around each of them directly (with something like _fit_impl, as opposed to _run_impl) will avoid the risk of missing other exceptions. What do you think?

The accelerator is also calling on_train_end in the error handling; I think this should be calling accelerator.teardown() instead.

cc @carmocca @awaelchli @yifuwang @tchaton
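The teardown point above can be illustrated with a small sketch (hypothetical names, not the real Accelerator API): on an unexpected exception, a stage-agnostic teardown is appropriate for every stage, whereas on_train_end only makes sense after training.

```python
# Sketch of the teardown suggestion (hypothetical names, not the real
# Accelerator API): on an unexpected exception, call a stage-agnostic
# teardown rather than the train-specific on_train_end hook, so that
# validate/test/predict are cleaned up as well.
class FakeAccelerator:
    def __init__(self):
        self.calls = []

    def on_train_end(self):
        self.calls.append("on_train_end")  # only meaningful after fit()

    def teardown(self):
        self.calls.append("teardown")      # valid for every trainer stage

def run_stage(accelerator, stage_fn):
    try:
        return stage_fn()
    except BaseException:
        accelerator.teardown()  # stage-agnostic cleanup, then re-raise
        raise

def failing_predict():
    raise RuntimeError("boom")
```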

@codecov

codecov bot commented Aug 10, 2021

Codecov Report

Merging #8819 (37bb48b) into master (522df2b) will decrease coverage by 4%.
The diff coverage is 97%.

@@           Coverage Diff           @@
##           master   #8819    +/-   ##
=======================================
- Coverage      93%     89%    -4%     
=======================================
  Files         176     176            
  Lines       14402   14410     +8     
=======================================
- Hits        13343   12771   -572     
- Misses       1059    1639   +580     

Contributor

@ananthsub ananthsub left a comment


+1 to @yifuwang's comment; the context manager to consolidate the error handling would be really nice. That will make it less likely to miss handling across the various entry points.
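The context-manager idea referenced here could look roughly like this (a sketch under assumed names, not the code that was merged): every entry point enters the same guard, so no path can miss the error handling.

```python
from contextlib import contextmanager

# Rough sketch of the context-manager consolidation (assumed names, not
# the merged implementation). KeyboardInterrupt is handled first because
# it is a BaseException subclass and would otherwise be swallowed by the
# generic clause below.
@contextmanager
def _handle_errors(on_exception):
    try:
        yield
    except KeyboardInterrupt:
        on_exception("interrupted")
    except BaseException:
        on_exception("failed")
        raise

events = []

def fit():
    with _handle_errors(events.append):
        return "fit ok"

def predict():
    with _handle_errors(events.append):
        raise RuntimeError("bad dataloader")
```

A successful fit() passes through untouched, while a failing predict() records the failure and re-raises, keeping one handler for all stages.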

@daniellepintz
Contributor Author

Thanks for the review @ananthsub and @yifuwang!! I have updated according to the comments. Unfortunately I am still unable to test locally due to an error when I run python -m pytest -v tests/trainer/test_trainer_error_handling.py::test_error_handling_all_stages -> P436961746

Contributor

@awaelchli awaelchli left a comment


There seems to be a bit of code duplication with this addition. If I see it correctly, the only difference is the call to trainer.fit/test, etc.

Would it make sense to have an intermediate function for interrupt handling, like so:

def _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs):
    try:
        ...
        trainer_fn(*args, **kwargs)
    except:
        ...

and then in fit/test/etc. we do

self._call_and_handle_interrupt(self._fit_impl)

?

@carmocca carmocca added the "feature" (Is an improvement or enhancement) and "refactor" labels Aug 10, 2021
@carmocca carmocca added this to the v1.5 milestone Aug 10, 2021
@ananthsub
Contributor

@daniellepintz from the CI:

____________________ test_spawn_predict_return_predictions ____________________

self = <pytorch_lightning.trainer.trainer.Trainer object at 0x000001B977E75460>
trainer_fn = <bound method Trainer._predict_impl of <pytorch_lightning.trainer.trainer.Trainer object at 0x000001B977E75460>>
args = (BoringModel(
  (layer): Linear(in_features=32, out_features=2, bias=True)
), <torch.utils.data.dataloader.DataLoader object at 0x000001B977F9D100>, None, True, None)
kwargs = {}

    def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any):
        r"""
        Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
        as all errors should funnel through them
    
        Args:
            trainer_fn: one of (fit, validate, test, predict)
    
            *args/**kwargs: args to be passed to trainer_fn
        """
        try:
>           return trainer_fn(*args, **kwargs)

D:\a\pytorch-lightning\pytorch-lightning\pytorch_lightning\trainer\trainer.py:500: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <pytorch_lightning.trainer.trainer.Trainer object at 0x000001B977E75460>
model = BoringModel(
  (layer): Linear(in_features=32, out_features=2, bias=True)
)
dataloaders = <torch.utils.data.dataloader.DataLoader object at 0x000001B977F9D100>
datamodule = None, return_predictions = True, ckpt_path = None

    def _predict_impl(
        self,
        model: Optional["pl.LightningModule"] = None,
        dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
        datamodule: Optional[LightningDataModule] = None,
        return_predictions: Optional[bool] = None,
        ckpt_path: Optional[str] = None,
    ) -> Optional[_PREDICT_OUTPUT]:
        # --------------------
        # SETUP HOOK
        # --------------------
        Trainer._log_api_event("predict")
    
        self.state.fn = TrainerFn.PREDICTING
        self.state.status = TrainerStatus.RUNNING
        self.predicting = True
    
>       self.predict_loop.return_predictions = return_predictions

D:\a\pytorch-lightning\pytorch-lightning\pytorch_lightning\trainer\trainer.py:813: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <pytorch_lightning.loops.dataloader.prediction_loop.PredictionLoop object at 0x000001B977F9D850>
return_predictions = True

    @return_predictions.setter
    def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
        # `DDPSpawnPlugin` plugins and derivatives don't support return predictions.
        is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin)
        if return_predictions and is_ddp_spawn:
>           raise MisconfigurationException(
                "`return_predictions` should be set to `False` when using the `DDPSpawnPlugin` or children class. "
                f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}."
            )
E           pytorch_lightning.utilities.exceptions.MisconfigurationException: `return_predictions` should be set to `False` when using the `DDPSpawnPlugin` or children class. Found True with training_type_plugin <class 'pytorch_lightning.plugins.training_type.ddp_spawn.DDPSpawnPlugin'>.

D:\a\pytorch-lightning\pytorch-lightning\pytorch_lightning\loops\dataloader\prediction_loop.py:35: MisconfigurationException

https://github.com/PyTorchLightning/pytorch-lightning/blob/938a191406fff5f51fba03fcf824f22d8d23c2e0/pytorch_lightning/trainer/trainer.py#L700-L723

So you can set trainer.predict(..., return_predictions=False) in this test.

@daniellepintz
Contributor Author

Thanks @ananthsub, I saw that; I am just waiting for a resolution on the accelerator issue before pushing another update.

Contributor

@awaelchli awaelchli left a comment


@daniellepintz thanks for working on this

I noticed that the PR is labelled refactor, but I think we should not classify it as such. We should also add a changelog entry noting that the entry points are now fully guarded by the exception handling, and that the on_keyboard_interrupt() callback hook will be called in all trainer stages.

@mergify mergify bot added the "ready" label (PRs ready to be merged) and removed the "has conflicts" label Aug 17, 2021
@daniellepintz daniellepintz changed the title Ensure error handling works across different trainer entry points Add error handling for all trainer entry points Aug 17, 2021
@ananthsub
Contributor

great work!

@daniellepintz
Contributor Author

daniellepintz commented Aug 18, 2021

So I am in a bit of a conundrum: the PR says "1 conversation must be resolved before merging", but when I click on the conversation to be resolved it says "We went looking everywhere, but couldn't find those commits." This is probably because I force-pushed a commit earlier... 😅😅😅

Does anyone know how to get around this? I tried Googling it but no luck.

@ananthsub
Contributor

> So I am in a bit of a conundrum where the PR says "1 conversation must be resolved before merging" [...] Does anyone know how to get around this?

@daniellepintz I resolved the conversation from the Conversation tab

@ananthsub ananthsub enabled auto-merge (squash) August 18, 2021 01:41
@ananthsub ananthsub merged commit bd13d39 into Lightning-AI:master Aug 18, 2021
@daniellepintz daniellepintz deleted the error_handling branch August 18, 2021 02:11
Successfully merging this pull request may close these issues.

[RFC] Ensure error handling is supported across all Trainer entry points