[docs] Update checkpointing.rst and callbacks.rst for Stateful support #12351

Merged
26 changes: 17 additions & 9 deletions docs/source/common/checkpointing.rst
@@ -28,11 +28,27 @@ A Lightning checkpoint has everything needed to restore a training session including:
- LightningModule's state_dict
- State of all optimizers
- State of all learning rate schedulers
-- State of all callbacks
+- State of all callbacks (for stateful callbacks)
+- State of datamodule (for stateful datamodules)
- The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)
- State of Loops (if using Fault-Tolerant training)
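
For example, you can open a saved checkpoint with plain ``torch.load`` and inspect these entries directly (the file path below is a placeholder):

.. code-block:: python

    import torch

    # "path/to/your.ckpt" is a placeholder for one of your saved checkpoints
    checkpoint = torch.load("path/to/your.ckpt", map_location="cpu")

    # the top-level keys correspond to the components listed above,
    # e.g. "state_dict", "optimizer_states", "lr_schedulers", "callbacks", ...
    print(checkpoint.keys())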


Individual Component States
===========================

Each component can save and load its state by implementing the PyTorch ``state_dict`` and ``load_state_dict`` stateful protocol.
For details on implementing your own stateful callbacks and datamodules, see the :doc:`callbacks <../extensions/callbacks>` and :doc:`datamodules <../extensions/datamodules>` docs.
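
As an illustration, here is a minimal sketch of a stateful ``LightningDataModule`` (the ``MyDataModule`` name and its ``current_fold`` attribute are made up for this example); a stateful ``Callback`` follows the same pattern:

.. code-block:: python

    from pytorch_lightning import LightningDataModule


    class MyDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            # hypothetical state that should survive a restart
            self.current_fold = 0

        def state_dict(self):
            # everything returned here is written into the checkpoint
            return {"current_fold": self.current_fold}

        def load_state_dict(self, state_dict):
            # called with the dictionary returned by state_dict above
            self.current_fold = state_dict["current_fold"]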


Operating on Global Checkpoint Component States
===============================================

If you need to operate on the global component state (i.e. the entire checkpoint dictionary), you can add, delete, or modify custom states in your checkpoints before they are saved or loaded.
To do this, override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint` and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule``,
or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` in your ``Callback``.
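
For example, a minimal sketch of a ``LightningModule`` that injects and reads back an extra entry (the ``"my_extra_state"`` key is purely illustrative):

.. code-block:: python

    from pytorch_lightning import LightningModule


    class MyModel(LightningModule):
        def on_save_checkpoint(self, checkpoint):
            # `checkpoint` is the full checkpoint dictionary about to be written
            checkpoint["my_extra_state"] = {"something": "extra"}

        def on_load_checkpoint(self, checkpoint):
            # the same dictionary, read back from disk
            self.my_extra_state = checkpoint.get("my_extra_state", {})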


*****************
Checkpoint Saving
*****************
@@ -102,14 +102,6 @@ If using custom saving functions cannot be avoided, we recommend using the :func
model parallel distributed strategies such as deepspeed or sharded training.


-Modifying Checkpoint When Saving and Loading
-============================================
-
-You can add/delete/modify custom states in your checkpoints before they are being saved or loaded. For this you can override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
-and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule`` or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and
-:meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.


Checkpointing Hyperparameters
=============================

22 changes: 17 additions & 5 deletions docs/source/extensions/callbacks.rst
@@ -116,8 +116,8 @@ Persisting State
----------------

Some callbacks require internal state in order to function properly. You can optionally
-choose to persist your callback's state as part of model checkpoint files using the callback hooks
-:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
+choose to persist your callback's state as part of model checkpoint files using
+:meth:`~pytorch_lightning.callbacks.Callback.state_dict` and :meth:`~pytorch_lightning.callbacks.Callback.load_state_dict`.
Note that the returned state must be able to be pickled.

When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough
@@ -147,10 +147,10 @@ the following example.
        if self.what == "batches":
            self.state["batches"] += 1

-    def on_load_checkpoint(self, trainer, pl_module, callback_state):
-        self.state.update(callback_state)
+    def load_state_dict(self, state_dict):
+        self.state.update(state_dict)

-    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+    def state_dict(self):
        return self.state.copy()
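
For a sense of how this state round-trips, here is a hedged usage sketch (``Counter``'s constructor arguments and the checkpoint path are assumptions based on the fragment above):

.. code-block:: python

    from pytorch_lightning import Trainer

    # `model` is any LightningModule, defined elsewhere
    trainer = Trainer(callbacks=[Counter(what="batches")])
    trainer.fit(model)

    # on resume, Lightning restores the callback via `Counter.load_state_dict`
    # with the dictionary that `Counter.state_dict` returned at save time
    trainer = Trainer(callbacks=[Counter(what="batches")])
    trainer.fit(model, ckpt_path="path/to/your.ckpt")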


@@ -422,12 +422,24 @@ on_exception
.. automethod:: pytorch_lightning.callbacks.Callback.on_exception
:noindex:

state_dict
~~~~~~~~~~

.. automethod:: pytorch_lightning.callbacks.Callback.state_dict
:noindex:

on_save_checkpoint
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint
:noindex:

load_state_dict
~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.callbacks.Callback.load_state_dict
:noindex:

on_load_checkpoint
~~~~~~~~~~~~~~~~~~
