
[RFC] Standardize all stateful components on state_dict/load_state_dict #11429

Closed
jjenniferdai opened this issue Jan 11, 2022 · 8 comments · Fixed by #11469, #11637, #11638, #11887 or #11998
Labels: checkpointing, refactor

@jjenniferdai
Contributor

jjenniferdai commented Jan 11, 2022

Proposed refactor

Standardize all stateful components on state_dict/load_state_dict.

Background

PyTorch convention uses state_dict/load_state_dict for gathering and loading object state. In Lightning, some components follow this convention, while others do not (see the Appendix: current state gathering section).

Motivation

Each component should contribute saving/loading its own state through the same APIs.

  1. independently contributing/loading one’s own local component state (aligning with PyTorch primitives: state_dict/load_state_dict)
  2. operating and depending on global component state (Lightning CheckpointHooks: on_save/load_checkpoint)

This issue is focused on aligning all components on 1. Following this convention will give consistency across Lightning components and consistency with PyTorch conventions. Any stateful component can simply implement its own state_dict/load_state_dict methods to contribute its state.
For now, 2 is only adjusted as needed (Callbacks).

Pitch

Lightning already has this Stateful protocol (_SupportsStateDict) in auto_restart.py. We can move it out to the more central core/hooks.py file:

https://github.com/PyTorchLightning/pytorch-lightning/blob/34c62da37dc1ed9a1de7e023c690c7528ee56c60/pytorch_lightning/utilities/auto_restart.py#L638-L646

from typing import Any, Dict, Protocol, runtime_checkable  # Protocol requires Python 3.8+ (or typing_extensions)


@runtime_checkable
class Stateful(Protocol):
    """This class is used to detect if an object is stateful using `isinstance(obj, Stateful)`."""

    def state_dict(self) -> Dict[str, Any]:
        ...

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        ...
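For illustration, any object implementing both methods satisfies the protocol structurally; no inheritance is needed. A minimal sketch, assuming the Stateful protocol above (MyComponent is a hypothetical example, not part of the proposal):

from typing import Any, Dict

class MyComponent:
    """Hypothetical component; defining both methods is enough to satisfy Stateful."""

    def __init__(self) -> None:
        self.counter = 0

    def state_dict(self) -> Dict[str, Any]:
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.counter = state_dict["counter"]

assert isinstance(MyComponent(), Stateful)  # structural check enabled by @runtime_checkable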

Additional context

Part of #7740

Appendix: current state gathering:

Specifically, current components contribute state in the following different ways:

  1. Some components contribute saving/loading their own state with only on_save/load_checkpoint [DataModule, PrecisionPlugin, Callbacks]
    a. [DataModule, PrecisionPlugin] use on_save/load_checkpoint from CheckpointHooks https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/core/hooks.py#L765-L807
    b. [Callbacks] contribute saving/loading their own state with different on_save/load_checkpoint hook signatures and functionality https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/callbacks/base.py#L293-L323
    c. though [Loops] falls under 3. below, noting here that Loops also have their own on_save/load_checkpoint methods with different signatures. https://github.com/PyTorchLightning/pytorch-lightning/blob/948cfd24de4f64a2980395581f15544e5e37eab0/pytorch_lightning/loops/base.py#L253-L262
  2. Some components contribute saving/loading their own state with only state_dict/load_state_dict calls [Optimizers, LR schedulers]
  3. Some components have both [LightningModule, Loops]

Appendix: component, checkpoint_connector changes

Save aligning on state_dict

dump_checkpoint
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L310-L393
becomes:

def dump_checkpoint(self, weights_only: bool = False) -> dict:
    ...
    # dump callbacks
    # checkpoint["callbacks"] = self.trainer._call_callbacks_on_save_checkpoint(checkpoint)
    # becomes
    checkpoint["callbacks"] = self.trainer._call_callbacks_state_dict()

    ...
    # precision plugin
    # self.trainer.precision_plugin.on_save_checkpoint(checkpoint)
    # becomes
    prec_plugin = self.trainer.precision_plugin
    checkpoint[prec_plugin.__class__.__name__] = prec_plugin.state_dict()

    ...
    # give the model a chance to dump a few things
    # model.on_save_checkpoint(checkpoint)
    # if self.trainer.datamodule is not None:
    #    self.trainer.datamodule.on_save_checkpoint(checkpoint)
    # becomes

    # datamodule state
    dm = self.trainer.datamodule
    if dm is not None:
        checkpoint[dm.__class__.__name__] = dm.state_dict()

    # on_save_checkpoint calls
    model.on_save_checkpoint(checkpoint)
    if dm is not None:
        dm.on_save_checkpoint(checkpoint)
    prec_plugin.on_save_checkpoint(checkpoint)
    for callback in self.trainer.callbacks:
        callback.on_save_checkpoint(self.trainer, model, checkpoint)
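For illustration, a checkpoint produced by the updated dump_checkpoint would then carry per-component entries roughly like the sketch below. The precision plugin and datamodule entries are keyed by class name, so NativeMixedPrecisionPlugin and MyDataModule are example names; exact keys and contents depend on the configured components:

checkpoint = {
    "epoch": ...,
    "global_step": ...,
    "state_dict": ...,                    # LightningModule weights
    "optimizer_states": [...],
    "lr_schedulers": [...],
    "callbacks": {...},                   # trainer._call_callbacks_state_dict()
    "NativeMixedPrecisionPlugin": {...},  # precision_plugin.state_dict()
    "MyDataModule": {...},                # datamodule.state_dict()
}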

Load aligning on load_state_dict

See the component Load sections below.

Callbacks:

Base Class

https://github.com/PyTorchLightning/pytorch-lightning/blob/59a7ba760548baadf6dbb30864b54cb01c7225a3/pytorch_lightning/callbacks/base.py#L293-L323

becomes the following (breaking changes: on_save_checkpoint now returns None instead of a dict, and on_load_checkpoint takes the entire checkpoint dict instead of the callback_state):

def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass

def on_save_checkpoint(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
    """Called by Lightning when saving a checkpoint to give you a chance to store or customize anything
    else you might want to save.

    Args:
        trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
        pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
        checkpoint: the checkpoint dictionary that will be saved.
    """

def on_load_checkpoint(
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> None:
    """Called by Lightning when loading a checkpoint to give you a chance to reload or customize anything
    else you may have saved in on_save_checkpoint.

    Args:
        trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance.
        pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance.
        checkpoint: the entire loaded checkpoint dictionary.
    """
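Under this split, a callback that keeps purely local state only needs to override the two new methods. A sketch of the proposed convention (MyCounterCallback and its epochs_seen attribute are hypothetical examples, not part of the proposal):

from typing import Any, Dict

from pytorch_lightning.callbacks import Callback


class MyCounterCallback(Callback):
    """Hypothetical callback that counts completed training epochs."""

    def __init__(self) -> None:
        self.epochs_seen = 0

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        self.epochs_seen += 1

    # Motivation point 1: local state travels through state_dict/load_state_dict
    def state_dict(self) -> Dict[str, Any]:
        return {"epochs_seen": self.epochs_seen}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.epochs_seen = state_dict["epochs_seen"]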

Save

https://github.com/PyTorchLightning/pytorch-lightning/blob/85304d4672a9ed24a16f7f5b2abaa34148ab86f4/pytorch_lightning/trainer/trainer.py#L1603

def _call_callbacks_on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[str, dict]:
    ...
    state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)

becomes

def _call_callbacks_state_dict(self) -> Dict[str, dict]:
    ...
    state = callback.state_dict()

Load

https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/trainer.py#L1652

def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    ...
    callback.on_load_checkpoint(self, self.lightning_module, state)

becomes

def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
    ...
    callback.load_state_dict(state)

restore_callbacks
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L205
becomes

for callback in self.trainer.callbacks:
    callback.on_load_checkpoint(self.trainer, self.trainer.lightning_module, self._loaded_checkpoint)
self.trainer._call_callbacks_load_state_dict(self._loaded_checkpoint)
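One possible shape for the two new Trainer helpers, assuming the per-callback entries stay keyed by callback.state_key as they are today (an illustrative sketch only; the linked PRs define the actual implementation):

from copy import deepcopy
from typing import Any, Dict


def _call_callbacks_state_dict(self) -> Dict[str, dict]:
    """Collect each callback's local state, keyed by its state_key."""
    callback_state_dicts = {}
    for callback in self.callbacks:
        state = callback.state_dict()
        if state:
            callback_state_dicts[callback.state_key] = deepcopy(state)
    return callback_state_dicts


def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
    """Route each callback's saved entry back to it via load_state_dict."""
    callback_states = checkpoint.get("callbacks", {})
    for callback in self.callbacks:
        state = callback_states.get(callback.state_key)
        if state:
            callback.load_state_dict(deepcopy(state))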

Callback classes

Update timer, pruning, model_checkpoint, finetuning, and early_stopping to use state_dict/load_state_dict instead of on_save/load_checkpoint.

Precision Plugin:

Base Class

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/precision_plugin.py
Add dummy state_dict/load_state_dict implementations:

def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass

Load

restore_training_state
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L189-L190
becomes

# restore precision plugin (scaler etc.)
prec_plugin = self.trainer.precision_plugin
prec_plugin.on_load_checkpoint(self._loaded_checkpoint)
if prec_plugin.__class__.__name__ in self._loaded_checkpoint:
    prec_plugin.load_state_dict(self._loaded_checkpoint[prec_plugin.__class__.__name__])

Precision Plugin classes

Update apex_amp and native_amp to use state_dict/load_state_dict instead of on_save/load_checkpoint.
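For native_amp this can delegate directly to the underlying torch.cuda.amp.GradScaler, which already exposes state_dict/load_state_dict. A sketch, assuming the plugin holds the scaler as self.scaler (shown as a subclass only for illustration; in practice the overrides would live on the plugin itself):

from typing import Any, Dict

from pytorch_lightning.plugins import NativeMixedPrecisionPlugin


class StatefulNativeAMP(NativeMixedPrecisionPlugin):
    """Illustrative sketch of the proposed overrides."""

    def state_dict(self) -> Dict[str, Any]:
        # the GradScaler is already stateful, so simply delegate to it
        return self.scaler.state_dict()

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.scaler.load_state_dict(state_dict)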

Datamodule:

Base Class

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/datamodule.py
Add dummy state_dict/load_state_dict implementations:

def state_dict(self) -> Dict[str, Any]:
    return {}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
    pass

Load

restore_datamodule
https://github.com/PyTorchLightning/pytorch-lightning/blob/06b8f82b8a97e7e5653486731751521478f58cce/pytorch_lightning/trainer/connectors/checkpoint_connector.py#L151
becomes

datamodule.on_load_checkpoint(self._loaded_checkpoint)
if datamodule.__class__.__name__ in self._loaded_checkpoint:
    datamodule.load_state_dict(self._loaded_checkpoint[datamodule.__class__.__name__])
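As a usage example, a user datamodule could then persist lightweight state that should survive restarts. A sketch (MyDataModule and samples_consumed are hypothetical names):

from typing import Any, Dict

from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    """Hypothetical datamodule that remembers how far it has advanced through a data stream."""

    def __init__(self) -> None:
        super().__init__()
        self.samples_consumed = 0

    def state_dict(self) -> Dict[str, Any]:
        # stored under checkpoint["MyDataModule"] by dump_checkpoint above
        return {"samples_consumed": self.samples_consumed}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.samples_consumed = state_dict["samples_consumed"]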


cc @justusschock @awaelchli @akihironitta @rohitgr7 @ananthsub @ninginthecloud

@ananthsub added the checkpointing label Jan 11, 2022
@ananthsub
Contributor

ananthsub commented Jan 11, 2022

I'm strongly in favor of this! This would clarify any confusion around checkpointing (when to use on_save/load_checkpoint vs state_dict) while also following the most familiar PyTorch conventions (state_dict/load_state_dict).

This also allows the framework to be more judicious around which components need access to the global checkpoint dict before saving/loading to minimize concerns like what order they take effect in.

Nit: Given that the Lightning framework dictates the Callback, DataModule, and PrecisionPlugin interfaces, and if the isinstance(obj, Stateful) checks are cumbersome, we could offer dummy implementations of state_dict/load_state_dict on those base classes with the expectation that they are meant to be overridden. But I personally prefer keeping the base classes as lean as possible, and I'm not sure the isinstance checks are really so bad.

@jjenniferdai
Contributor Author

> Nit: If we don't want to deal with the isinstance checks everywhere, we could offer dummy implementations of state_dict/load_state_dict on the callback, datamodule, and precision plugin interfaces, given that the Lightning framework dictates these interfaces.

good call, edited!

@awaelchli
Contributor

  1. I see the value; I am slightly in favor of this.
  2. We need to be careful with this; checkpointing bugs are the worst because in most cases we can't fix "broken" checkpoints.
  3. To be clear that I don't misunderstand: the proposal is not to remove the on_x_checkpoint hooks, yes? Because these are the ones that users implement, and they express in plain, simple words what they do. In contrast, it is not very clear to a user what state_dict does.

Old related issue/PR that would probably also benefit from this standardization? #6361

@carmocca
Contributor

I'm a bit confused about what this is actually proposing. On one hand, the snippets use

self.trainer._call_callbacks_on_save_checkpoint(checkpoint)

but then you mention that

state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)

will become

state = callback.state_dict()

So for each of these components, what does this propose for the following use-cases?

  • User wants to get the state
  • User wants to set a state
  • User wants to customize the state generation
  • User wants to customize the state loading

@jjenniferdai
Contributor Author

jjenniferdai commented Jan 12, 2022

Moved the code changes for individual components and the checkpoint_connector to the bottom appendix to make this clearer and more readable. I want to get alignment on the convention for Motivation point 1.

> To be clear that I don't misunderstand, the proposal is not to remove the on_x_checkpoint hooks, yes?

Correct, this is only for aligning all component APIs on Motivation point 1. There are no changes to the Motivation point 2 on_save/load_checkpoint hooks yet, except for necessary accommodations (Callbacks).

> _call_callbacks_on_save_checkpoint

Updated the Callbacks Save/Load section snippets to include the method name change _call_callbacks_on_save_checkpoint --> _call_callbacks_state_dict.

> So for each of these components, what does this propose for the following use-cases?

These map to the Motivation point 1 vs. point 2 distinction:

  • User wants to get/set the state (Motivation point 1: use state_dict/load_state_dict)
  • User wants to customize the state generation/loading (Motivation point 2: use on_save/load_checkpoint)
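For a callback, the contrast would look roughly like this (a hypothetical sketch; best_score and my_extra_metadata are made-up names used only to illustrate the split):

from typing import Any, Dict

from pytorch_lightning.callbacks import Callback


class MyCallback(Callback):
    def __init__(self) -> None:
        self.best_score = 0.0

    # point 1: get/set purely local state
    def state_dict(self) -> Dict[str, Any]:
        return {"best_score": self.best_score}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.best_score = state_dict["best_score"]

    # point 2: hook that sees the full checkpoint dict and can customize it
    def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> None:
        checkpoint["my_extra_metadata"] = {"written_by": "MyCallback"}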

@tchaton
Contributor

tchaton commented Jan 12, 2022

Hey @jjenniferdai. I think this adds value, but it would need to be properly documented to avoid confusion about which method takes the full checkpoint versus the component's associated state.

@zzzwen

zzzwen commented Jan 13, 2022

Thumbs up for this! "checkpoint" is a loaded term that represents various things.
state_dict()/load_state_dict() is standard in PyTorch core, and for someone who usually deals with nn.Module or torch.optim optimizers, it is more natural to work with state_dict.

@jjenniferdai self-assigned this Feb 9, 2022
@ananthsub mentioned this issue Feb 12, 2022
@carmocca added this to the 1.6 milestone Feb 14, 2022
@carmocca moved this to In Progress in Frameworks Planning Feb 14, 2022
Repository owner moved this from In Progress to Done in Frameworks Planning Mar 19, 2022
@awaelchli reopened this Mar 19, 2022
@carmocca modified the milestones: 1.6, 1.7 Mar 21, 2022
@carmocca moved this from Done to In Progress in Frameworks Planning Mar 21, 2022
@jjenniferdai
Contributor Author

Moving this back to 1.6 if that's ok - the 2 remaining linked PRs are already reviewed and just need to be merged
cc @carmocca @ananthsub @awaelchli @tchaton

@jjenniferdai modified the milestones: 1.7, 1.6 Mar 24, 2022
Repository owner moved this from In Progress to Done in Frameworks Planning Mar 25, 2022