From affae7af130590c29c72fe05238ec3ad4fd24bef Mon Sep 17 00:00:00 2001
From: Jennifer Dai
Date: Wed, 16 Mar 2022 18:41:45 -0700
Subject: [PATCH 1/6] first commit

---
 docs/source/common/checkpointing.rst | 26 +++++++++++++++++---------
 docs/source/extensions/callbacks.rst | 22 +++++++++++++++++-----
 2 files changed, 34 insertions(+), 14 deletions(-)

diff --git a/docs/source/common/checkpointing.rst b/docs/source/common/checkpointing.rst
index 8d96bfb3c1822..446a3a231ad54 100644
--- a/docs/source/common/checkpointing.rst
+++ b/docs/source/common/checkpointing.rst
@@ -28,11 +28,27 @@ A Lightning checkpoint has everything needed to restore a training session inclu
 - LightningModule's state_dict
 - State of all optimizers
 - State of all learning rate schedulers
-- State of all callbacks
+- State of all callbacks (for stateful callbacks)
+- State of datamodule (for stateful datamodules)
 - The hyperparameters used for that model if passed in as hparams (Argparse.Namespace)
 - State of Loops (if using Fault-Tolerant training)
 
+
+Individual Component States
+===========================
+
+Each component can save and load their state by implementing PyTorch ``state_dict``, ``load_state_dict`` stateful protocol.
+For details on implementing your own stateful callbacks and datamodules, reference their docs pages at :doc:`callbacks <../extensions/callbacks>` and :doc:`datamodules <../extensions/datamodules>`.
+
+
+Operating on Global Checkpoint Component States
+===============================================
+
+If you need to operate on the global component state (i.e. the entire checkpoint dictionary), you can add/delete/modify custom states in your checkpoints before they are being saved or loaded.
+For this you can override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint` and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule``
+or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
+
 
 *****************
 Checkpoint Saving
 *****************
@@ -102,14 +118,6 @@ If using custom saving functions cannot be avoided, we recommend using the :func
 model parallel distributed strategies such as deepspeed or sharded training.
 
 
-Modifying Checkpoint When Saving and Loading
-============================================
-
-You can add/delete/modify custom states in your checkpoints before they are being saved or loaded. For this you can override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint`
-and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule`` or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and
-:meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
-
-
 Checkpointing Hyperparameters
 =============================
 
diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index d5f02ee5c1cdb..ef44e3ba6ce31 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -116,8 +116,8 @@ Persisting State
 ----------------
 
 Some callbacks require internal state in order to function properly. You can optionally
-choose to persist your callback's state as part of model checkpoint files using the callback hooks
-:meth:`~pytorch_lightning.callbacks.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.Callback.on_load_checkpoint`.
+choose to persist your callback's state as part of model checkpoint files using
+:meth:`~pytorch_lightning.callbacks.Callback.state_dict` and :meth:`~pytorch_lightning.callbacks.Callback.load_state_dict`.
 Note that the returned state must be able to be pickled.
 
 When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough
@@ -147,10 +147,10 @@ the following example.
         if self.what == "batches":
             self.state["batches"] += 1
 
-    def on_load_checkpoint(self, trainer, pl_module, callback_state):
-        self.state.update(callback_state)
+    def load_state_dict(self, state_dict):
+        self.state.update(state_dict)
 
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+    def state_dict(self):
         return self.state.copy()
 
 
@@ -422,12 +422,24 @@ on_exception
 .. automethod:: pytorch_lightning.callbacks.Callback.on_exception
    :noindex:
 
+state_dict
+~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.callbacks.Callback.state_dict
+   :noindex:
+
 on_save_checkpoint
 ~~~~~~~~~~~~~~~~~~
 
 .. automethod:: pytorch_lightning.callbacks.Callback.on_save_checkpoint
    :noindex:
 
+load_state_dict
+~~~~~~~~~~~~~~~~~~
+
+.. automethod:: pytorch_lightning.callbacks.Callback.load_state_dict
+   :noindex:
+
 on_load_checkpoint
 ~~~~~~~~~~~~~~~~~~

From 5fd3b3628757dd5e8e54030ee51dcc65d4767c36 Mon Sep 17 00:00:00 2001
From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com>
Date: Thu, 17 Mar 2022 09:54:11 -0700
Subject: [PATCH 2/6] Update docs/source/common/checkpointing.rst

Co-authored-by: Rohit Gupta
---
 docs/source/common/checkpointing.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/common/checkpointing.rst b/docs/source/common/checkpointing.rst
index 446a3a231ad54..b4ec8bf23fb65 100644
--- a/docs/source/common/checkpointing.rst
+++ b/docs/source/common/checkpointing.rst
@@ -44,7 +44,7 @@ For details on implementing your own stateful callbacks and datamodules, referen
 Operating on Global Checkpoint Component States
 ===============================================
 
-If you need to operate on the global component state (i.e. the entire checkpoint dictionary), you can add/delete/modify custom states in your checkpoints before they are being saved or loaded.
+If you need to operate on the global component state (i.e. the entire checkpoint dictionary), you can read/add/delete/modify custom states in your checkpoints before they are being saved or loaded.
 For this you can override :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint` and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` in your ``LightningModule``
 or :meth:`~pytorch_lightning.callbacks.base.Callback.on_save_checkpoint` and :meth:`~pytorch_lightning.callbacks.base.Callback.on_load_checkpoint` methods in your ``Callback``.
 

From 3443304010a6aabe357ef384dac83b4fe5020d42 Mon Sep 17 00:00:00 2001
From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com>
Date: Thu, 17 Mar 2022 09:55:01 -0700
Subject: [PATCH 3/6] Update docs/source/common/checkpointing.rst

Co-authored-by: Rohit Gupta
---
 docs/source/common/checkpointing.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/common/checkpointing.rst b/docs/source/common/checkpointing.rst
index b4ec8bf23fb65..34b00584835d2 100644
--- a/docs/source/common/checkpointing.rst
+++ b/docs/source/common/checkpointing.rst
@@ -37,7 +37,7 @@ A Lightning checkpoint has everything needed to restore a training session inclu
 Individual Component States
 ===========================
 
-Each component can save and load their state by implementing PyTorch ``state_dict``, ``load_state_dict`` stateful protocol.
+Each component can save and load its state by implementing PyTorch ``state_dict``, ``load_state_dict`` stateful protocol.
 For details on implementing your own stateful callbacks and datamodules, reference their docs pages at :doc:`callbacks <../extensions/callbacks>` and :doc:`datamodules <../extensions/datamodules>`.
 

From 191eb3596f048d1f2c504082f9363665b5e6b811 Mon Sep 17 00:00:00 2001
From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com>
Date: Sat, 19 Mar 2022 19:25:26 -0700
Subject: [PATCH 4/6] Update docs/source/common/checkpointing.rst
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 docs/source/common/checkpointing.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/common/checkpointing.rst b/docs/source/common/checkpointing.rst
index 34b00584835d2..1c899d1f42987 100644
--- a/docs/source/common/checkpointing.rst
+++ b/docs/source/common/checkpointing.rst
@@ -37,7 +37,7 @@ A Lightning checkpoint has everything needed to restore a training session inclu
 Individual Component States
 ===========================
 
-Each component can save and load its state by implementing PyTorch ``state_dict``, ``load_state_dict`` stateful protocol.
+Each component can save and load its state by implementing the PyTorch ``state_dict``, ``load_state_dict`` stateful protocol.
 For details on implementing your own stateful callbacks and datamodules, reference their docs pages at :doc:`callbacks <../extensions/callbacks>` and :doc:`datamodules <../extensions/datamodules>`.
 

From 80accfad8f022bd1f24fe941a833c51c1d837477 Mon Sep 17 00:00:00 2001
From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com>
Date: Sat, 19 Mar 2022 19:25:40 -0700
Subject: [PATCH 5/6] Update docs/source/common/checkpointing.rst
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Adrian Wälchli
---
 docs/source/common/checkpointing.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/common/checkpointing.rst b/docs/source/common/checkpointing.rst
index 1c899d1f42987..73496052793b6 100644
--- a/docs/source/common/checkpointing.rst
+++ b/docs/source/common/checkpointing.rst
@@ -38,7 +38,7 @@ Individual Component States
 ===========================
 
 Each component can save and load its state by implementing the PyTorch ``state_dict``, ``load_state_dict`` stateful protocol.
-For details on implementing your own stateful callbacks and datamodules, reference their docs pages at :doc:`callbacks <../extensions/callbacks>` and :doc:`datamodules <../extensions/datamodules>`.
+For details on implementing your own stateful callbacks and datamodules, refer to the individual docs pages at :doc:`callbacks <../extensions/callbacks>` and :doc:`datamodules <../extensions/datamodules>`.
 
 
 Operating on Global Checkpoint Component States

From 9d78d4202d8497b7e273fc42862f453d90de9956 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 21 Mar 2022 18:23:01 +0100
Subject: [PATCH 6/6] Apply suggestions from code review

---
 docs/source/extensions/callbacks.rst | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/docs/source/extensions/callbacks.rst b/docs/source/extensions/callbacks.rst
index ef44e3ba6ce31..19b43f4f4d127 100644
--- a/docs/source/extensions/callbacks.rst
+++ b/docs/source/extensions/callbacks.rst
@@ -423,7 +423,7 @@ on_exception
    :noindex:
 
 state_dict
-~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~
 
 .. automethod:: pytorch_lightning.callbacks.Callback.state_dict
    :noindex:
@@ -435,7 +435,7 @@ on_save_checkpoint
    :noindex:
 
 load_state_dict
-~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~
 
 .. automethod:: pytorch_lightning.callbacks.Callback.load_state_dict
    :noindex:
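
Supplementary notes on the protocols these patches document follow below.

The callbacks.rst hunks in PATCH 1/6 show only the changed lines of the docs' ``Counter`` example. Pieced together, a minimal stateful callback under the new ``state_dict``/``load_state_dict`` protocol looks roughly like this sketch (the ``*args, **kwargs`` hook signatures and the surrounding code are assumptions, not lines from these patches)::

    from pytorch_lightning.callbacks import Callback


    class Counter(Callback):
        def __init__(self, what="epochs"):
            self.what = what
            self.state = {"epochs": 0, "batches": 0}

        def on_train_epoch_end(self, *args, **kwargs):
            # Count completed epochs when configured to do so.
            if self.what == "epochs":
                self.state["epochs"] += 1

        def on_train_batch_end(self, *args, **kwargs):
            # Count completed batches when configured to do so.
            if self.what == "batches":
                self.state["batches"] += 1

        def state_dict(self):
            # The returned dict is stored in the checkpoint under this
            # callback's state key and must be picklable.
            return self.state.copy()

        def load_state_dict(self, state_dict):
            # Receives exactly what state_dict() returned at save time.
            self.state.update(state_dict)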
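
The "Operating on Global Checkpoint Component States" section added in PATCH 1/6 names the ``on_save_checkpoint``/``on_load_checkpoint`` hook pair but shows no usage. A minimal sketch for the ``LightningModule`` side (``"my_extra_state"`` is an illustrative key, not part of any API)::

    import pytorch_lightning as pl


    class LitModel(pl.LightningModule):
        def on_save_checkpoint(self, checkpoint):
            # `checkpoint` is the entire dictionary about to be written to disk;
            # custom keys added here are stored next to Lightning's own entries.
            checkpoint["my_extra_state"] = {"note": "anything picklable"}

        def on_load_checkpoint(self, checkpoint):
            # The same dictionary, read back before training state is restored.
            self.my_extra_state = checkpoint.get("my_extra_state")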
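
The checkpoint-contents list in PATCH 1/6 also names stateful datamodules. Assuming the datamodule side mirrors the callback protocol, as the "Individual Component States" section implies (``current_fold`` is an illustrative attribute, not part of any API)::

    import pytorch_lightning as pl


    class KFoldDataModule(pl.LightningDataModule):
        def __init__(self):
            super().__init__()
            self.current_fold = 0  # resumable state

        def state_dict(self):
            # Stored as the datamodule entry of the checkpoint.
            return {"current_fold": self.current_fold}

        def load_state_dict(self, state_dict):
            # Restores the attribute when resuming from a checkpoint.
            self.current_fold = state_dict["current_fold"]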