[RFC] Standardize all stateful components on state_dict/load_state_dict
#11429
Comments
I'm strongly in favor of this! This would clarify any confusion around checkpointing (when to use on_save/load_checkpoint vs state_dict) while also following the most familiar PyTorch conventions (state_dict/load_state_dict). This also allows the framework to be more judicious about which components need access to the global checkpoint dict before saving/loading, to minimize concerns like what order they take effect in. Nit: Given that the Lightning framework dictates the Callback, DataModule, and PrecisionPlugin interfaces, and if the …
good call, edited!
Old related issue/PR that would probably also benefit from this standardization? #6361
I'm a bit confused about what this is actually proposing, on one hand snippets use
but then you mention that
will become
So for each of these components, what does this propose for the following use-cases?
Moved code changes for individual components, checkpoint_connector to the bottom appendix to make this more clear/readable. Want to get alignment on the convention for Motivation point 1
Correct, this is only for aligning all component APIs for Motivation point 1. No changes to Motivation point 2.
Updated the Callbacks save/load section snippets to include the method name changes from
These should map to the Motivation point 1 vs. Motivation point 2 distinctions:
Hey @jjenniferdai. I think this adds value but would need to be properly documented to avoid confusion on which method takes either the full checkpoint or the component-associated state.
Thumbs up for this! "checkpoint" is a loaded term that represents various things.
Moving this back to 1.6 if that's ok - the 2 remaining linked PRs are already reviewed and just need to be merged |
Proposed refactor
Standardize all stateful components on `state_dict`/`load_state_dict`.

Background
PyTorch convention uses `state_dict`/`load_state_dict` for gathering and loading object state. In Lightning, some components follow this convention, while other components do not (see the Appendix: current state gathering section).

Motivation
Each component should contribute saving/loading their own state with the same APIs:
1. Component-local state (`state_dict`/`load_state_dict`)
2. Access to the full checkpoint dict (`CheckpointHooks`: `on_save/load_checkpoint`)

This issue is focused on aligning all components on 1. Following this convention will allow consistency across Lightning components and consistency with PyTorch conventions. Any stateful component can simply implement its own `state_dict`/`load_state_dict` methods to contribute its own state. For now, 2 is only adjusted as needed (Callbacks).
Pitch
Lightning already has this Stateful protocol in `auto_restart.py`. We can move this `_SupportsStateDict` out to the more central `core/hooks.py` file:
https://github.com/PyTorchLightning/pytorch-lightning/blob/34c62da37dc1ed9a1de7e023c690c7528ee56c60/pytorch_lightning/utilities/auto_restart.py#L638-L646
Additional context
Part of #7740
Appendix: current state gathering:
Specifically, current components contribute state in the following different ways:
1. `on_save/load_checkpoint` [DataModule, PrecisionPlugin, Callbacks]
   a. [DataModule, PrecisionPlugin] use `on_save/load_checkpoint` from `CheckpointHooks`: https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/core/hooks.py#L765-L807
   b. [Callbacks] contribute saving/loading their own state with different `on_save/load_checkpoint` hook signatures and functionality: https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/callbacks/base.py#L293-L323
   c. though [Loops] falls under 3. below, noting here that Loops also have their own `on_save/load_checkpoint` methods with different signatures: https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/loops/base.py#L253-L262
2. `state_dict`/`load_state_dict` calls [Optimizers, LR schedulers]

Appendix: component, checkpoint_connector changes
Save aligning on state_dict
dump_checkpoint
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L310-L393
becomes:
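As an illustrative sketch of the save-side direction (the helper name and checkpoint keys below are assumptions, not the actual connector code): every stateful component contributes its own `state_dict()` under a dedicated key, rather than mutating the full checkpoint dict from inside `on_save_checkpoint`.

```python
from typing import Any, Dict, Iterable, Tuple


def dump_component_states(named_components: Iterable[Tuple[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: gather per-component state via ``state_dict()``."""
    checkpoint: Dict[str, Any] = {}
    for key, component in named_components:
        if hasattr(component, "state_dict"):
            checkpoint[key] = component.state_dict()
    return checkpoint
```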
Load aligning on load_state_dict
see component Load sections below
Callbacks:
Base Class
https://github.com/PyTorchLightning/pytorch-lightning/blob/59a7ba760548baadf6dbb30864b54cb01c7225a3/pytorch_lightning/callbacks/base.py#L293-L323
becomes
BC: `on_save_checkpoint` returns `None` instead of `dict`
BC: `on_load_checkpoint` arg takes the entire `checkpoint` dict instead of `callback_state`
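As a simplified sketch of the resulting base-class surface (signatures abbreviated; not the exact diff), callbacks would own `state_dict`/`load_state_dict` for their component state, while `on_save/load_checkpoint` keep full-checkpoint access with the BC changes noted above:

```python
from typing import Any, Dict


class Callback:
    """Simplified sketch of the proposed Callback interface."""

    def state_dict(self) -> Dict[str, Any]:
        # Motivation point 1: component-local state.
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        pass

    def on_save_checkpoint(self, trainer: Any, pl_module: Any, checkpoint: Dict[str, Any]) -> None:
        # BC: returns None; mutate ``checkpoint`` in place if full-dict access is needed.
        pass

    def on_load_checkpoint(self, trainer: Any, pl_module: Any, checkpoint: Dict[str, Any]) -> None:
        # BC: receives the entire checkpoint dict, not just this callback's state.
        pass
```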
Save
https://github.com/PyTorchLightning/pytorch-lightning/blob/85304d4672a9ed24a16f7f5b2abaa34148ab86f4/pytorch_lightning/trainer/trainer.py#L1603
becomes
Load
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/trainer.py#L1652
becomes
restore_callbacks
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L205
becomes
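Again only a sketch (the helper name and the checkpoint key are assumptions): on restore, each callback gets its own saved state back through `load_state_dict`, while `on_load_checkpoint` remains available for full-checkpoint access.

```python
from typing import Any, Dict, Iterable


def restore_callback_states(callbacks: Iterable[Any], checkpoint: Dict[str, Any]) -> None:
    """Hypothetical helper: route each callback's saved state through ``load_state_dict``."""
    callback_states: Dict[str, Any] = checkpoint.get("callbacks", {})
    for callback in callbacks:
        state = callback_states.get(type(callback).__qualname__)
        if state is not None:
            callback.load_state_dict(state)
```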
Callback classes
update `timer`, `pruning`, `model_checkpoint`, `finetuning`, `early_stopping` to use `state_dict`/`load_state_dict` instead of `on_save/load_checkpoint`.
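As a concrete, hypothetical example of that migration, a timer-style callback that today returns its state from `on_save_checkpoint` would instead implement the state pair directly:

```python
import time
from typing import Any, Dict


class TimerCallbackSketch:
    """Hypothetical callback migrated to ``state_dict``/``load_state_dict``."""

    def __init__(self) -> None:
        self._start = time.monotonic()
        self._elapsed = 0.0

    def state_dict(self) -> Dict[str, Any]:
        # Previously returned from ``on_save_checkpoint``.
        return {"time_elapsed": self._elapsed + (time.monotonic() - self._start)}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Previously received as ``callback_state`` in ``on_load_checkpoint``.
        self._elapsed = state_dict.get("time_elapsed", 0.0)
        self._start = time.monotonic()
```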
Precision Plugin:
Base Class
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/precision_plugin.py
add dummy `state_dict`/`load_state_dict`:
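A minimal sketch of those defaults (illustrative; the class name here is a stand-in, and exact signatures may differ in the PR): stateless plugins simply return an empty dict.

```python
from typing import Any, Dict


class PrecisionPluginSketch:
    """Stand-in for the base plugin: default no-op state hooks."""

    def state_dict(self) -> Dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        pass
```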
Load
restore_training_state
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L189-L190
becomes
Precision Plugin classes
Update `apex_amp`, `native_amp` to use `state_dict`/`load_state_dict` instead of `on_save/load_checkpoint`.
Datamodule:
Base Class
https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/datamodule.py
add dummy `state_dict`/`load_state_dict`:
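Beyond the dummy defaults on the base class (same shape as the precision-plugin sketch above), a user datamodule could then contribute its own state, e.g. (hypothetical example):

```python
from typing import Any, Dict


class MyDataModuleSketch:
    """Hypothetical user datamodule contributing resumable state."""

    def __init__(self) -> None:
        self.current_fold = 0  # e.g. cross-validation fold to resume from

    def state_dict(self) -> Dict[str, Any]:
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.current_fold = state_dict["current_fold"]
```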
Load
restore_datamodule
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L151
becomes
If you enjoy Lightning, check out our other projects! ⚡
Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.
Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.
Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.
Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers leveraging PyTorch Lightning, Transformers, and Hydra.
cc @justusschock @awaelchli @akihironitta @rohitgr7 @ananthsub @ninginthecloud