Validation metrics assumed to be logged within the first training epoch #6791

Closed
tmcclintock opened this issue Apr 1, 2021 · 8 comments
Labels: bug (Something isn't working), help wanted (Open to be worked on), priority: 1 (Medium priority task), won't fix (This will not be worked on)

Comments

tmcclintock commented Apr 1, 2021

🐛 Bug

In TrainLoop.on_train_end, a call to check_checkpoint_callback is made. Within that method, a call to on_validation_end is performed. As per the docs (and the fact that ModelCheckpoint fires on on_validation_end), the expectation is to monitor validation metrics. However, if we set num_sanity_val_steps to 0 in the Trainer, then no validation metrics have been logged yet at that point, resulting in a MisconfigurationException in _validate_monitor_key.

Note that this is only an issue on the first epoch -- after this the val keys appear in the callback metrics and this issue is moot.

Please reproduce using the BoringModel

To Reproduce

Use the following BoringModel and post here

I cannot reproduce this with the BoringModel since it uses deprecated *_step methods (e.g. validation_step returns the loss rather than logging it). It should be updated for 1.2.6 in a different issue.
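
In lieu of a runnable BoringModel, here is a minimal sketch of the configuration described above. The module name (ReproModel) and the metric names are placeholders, and whether the exception actually fires depends on the checkpoint callback being invoked before the first validation loop has completed:

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

class ReproModel(pl.LightningModule):  # placeholder stand-in for the BoringModel
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # "val_loss" only appears in callback_metrics once a validation loop has run
        self.log("val_loss", self.layer(batch[0]).sum())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

trainer = pl.Trainer(
    max_epochs=1,
    num_sanity_val_steps=0,  # skip the sanity check, so no val metrics exist early on
    callbacks=[ModelCheckpoint(monitor="val_loss")],
)
# If check_checkpoint_callback runs (via on_train_end) before validation has
# logged "val_loss", _validate_monitor_key raises a MisconfigurationException.
trainer.fit(ReproModel(), data, data)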

Expected behavior

If the model checkpoint only implements on_validation_end, then it should only fire on that callback, not secretly in on_train_end. If it should also fire in on_train_end, it should either have a second monitor specific to the callback_metrics logged during training, or its logic should be moved out from under on_validation_end into a more general (less misleading) hook.

Note that callbacks have access to Trainer.state, so it is possible to move the ModelCheckpoint.on_validation_end logic into a higher-level hook and leverage this state info. An elegant (imo) attribute to add to ModelCheckpoint could be monitor_state, so that, for instance, a user can say "monitor metric 'loss', but only while the trainer is in the 'train' state":

class ModelCheckpoint(Callback):
    def __init__(
        self,
        ...
        monitor: Optional[str] = None,
        monitor_state: Optional[Union[str, List[str]]] = None,  # must be a subset of fit/validate/test/predict/etc.
        ...
    ):
        ...
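
For illustration, usage under this proposal might look like the following (monitor_state is hypothetical and not an existing ModelCheckpoint argument):

# Hypothetical: track "loss", but only the values logged while the trainer
# is in the "train" state; "loss" logged during validation would be ignored.
checkpoint_cb = ModelCheckpoint(monitor="loss", monitor_state="train")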

Environment

On PL master (1.2.6)

  • PyTorch Version (e.g., 1.0): 1.7.1
  • OS (e.g., Linux): linux
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source): N/A
  • Python version: 3.7
  • CUDA/cuDNN version: 10.2
@tmcclintock tmcclintock added bug Something isn't working help wanted Open to be worked on labels Apr 1, 2021
w-copper commented Apr 5, 2021

I also have this issue. I set num_sanity_val_steps = 2, but I still get the MisconfigurationException.
The metrics logged in training_step_end can be monitored, but those logged in validation_step_end cannot.

ananthsub commented Apr 5, 2021

I agree. There are 2 major failure cases with the existing error handling logic:

  1. If the training loop fails => no validation is run => no validation metric is logged for the checkpoint callback to monitor => the try/catch forces the checkpoint callback to run => the checkpoint callback looks up the monitor and fails with a MisconfigurationException. Even worse, this obscures the original error message.

  2. If a subset of ranks fails => try/catch => checkpointing => the checkpoint hangs during the broadcast here: https://github.com/PyTorchLightning/pytorch-lightning/blob/22a266d8b8cf57455cc863e20491e416ec635ba7/pytorch_lightning/callbacks/model_checkpoint.py#L724-L730

I think we should:

  • Remove the error handling for "graceful" shutdown, because it misses these cases
  • Move the special casing for early stopping and checkpointing out of the training loop and into the respective callbacks

Rather than have a monitor_state on the checkpoint callback, I'd prefer to have individual checkpoint callbacks, each tracking a single monitor and a list of top-k models for it. Otherwise, with monitor_state, the logic for tracking the best model across training + validation, multiple monitor values, and multiple save_top_k values becomes messy.
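
A rough sketch of that alternative, assuming one ModelCheckpoint per monitored quantity is passed in the Trainer's callbacks list (the metric names and filename templates here are illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# One callback per monitor, each with its own independent top-k bookkeeping,
# instead of a single callback switched by a monitor_state attribute.
train_ckpt = ModelCheckpoint(monitor="train_loss", save_top_k=1, filename="best-train-{epoch}")
val_ckpt = ModelCheckpoint(monitor="val_loss", save_top_k=3, filename="best-val-{epoch}")

trainer = Trainer(callbacks=[train_ckpt, val_ckpt])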

Related RFC: #6504

@awaelchli @shuyingsunshine21 @carmocca

@tchaton tchaton added the priority: 1 Medium priority task label Apr 6, 2021
carmocca commented Apr 7, 2021

Remove the error handling for "graceful" shutdown because it misses these cases

What do you mean by this, exactly? Do you not want to try to save on keyboard interrupt?

I agree with everything else.

tmcclintock (Author) commented:

@carmocca I think what @ananthsub is saying is that in, e.g., a DDP setting, if a subset of the ranks fails (for whatever reason), you get a hang when broadcast is called within the ModelCheckpoint callback.

Personally, I feel like that should be handled in a separate issue, but y'all would know better than me.

ananthsub (Contributor) commented:

@tmcclintock I agree. There are 2 issues here:

  1. The error handling, already mentioned above
  2. The training loop should not force-call these callbacks' on_validation_end. There are existing TODOs to incorporate that logic into the callbacks themselves; the training loop should not explicitly run these functions. @justusschock @awaelchli @carmocca

stale bot commented May 10, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label May 10, 2021
@carmocca carmocca added this to the v1.4 milestone May 10, 2021
@stale stale bot removed the won't fix This will not be worked on label May 10, 2021
@edenlightning edenlightning removed this from the v1.4 milestone Jul 1, 2021
stale bot commented Aug 1, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Aug 1, 2021
carmocca commented Aug 1, 2021

I believe this was fixed with the addition of #8389

@carmocca carmocca closed this as completed Aug 1, 2021