Dedicated docs page for distributed checkpoints (Trainer) (#19299)
awaelchli authored Jan 17, 2024
1 parent 6655c4d commit 93c1ab0
Showing 6 changed files with 112 additions and 89 deletions.
6 changes: 3 additions & 3 deletions docs/source-fabric/advanced/model_parallel/fsdp.rst
@@ -12,9 +12,9 @@ Even a single H100 GPU with 80 GB of VRAM (the biggest today) is not enough to t
The memory consumption for training is generally made up of

1. the model parameters,
2. the layer activations (forward) and
3. the gradients (backward).
4. the optimizer states (e.g., Adam has two additional exponential averages per parameter),
2. the layer activations (forward),
3. the gradients (backward) and
4. the optimizer states (e.g., Adam has two additional exponential averages per parameter).

|
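As a rough back-of-envelope sketch (not part of this diff; the model size and byte counts are illustrative assumptions), items 1, 3 and 4 alone already exceed a single 80 GB GPU for a 7B-parameter model trained in plain fp32 with Adam:

.. code-block:: python

    # Illustrative estimate of items 1, 3 and 4 above; activations (item 2) are
    # excluded because they depend on batch size and sequence length.
    num_params = 7e9   # assume a 7B-parameter model
    bytes_weights = 4  # fp32 parameters
    bytes_grads = 4    # fp32 gradients
    bytes_adam = 8     # Adam keeps two fp32 exponential averages per parameter

    total_gib = num_params * (bytes_weights + bytes_grads + bytes_adam) / 2**30
    print(f"~{total_gib:.0f} GiB before activations")  # ~104 GiB, well above 80 GB of VRAM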
@@ -47,7 +47,6 @@ The distributed checkpoint format is the default when you train with the :doc:`F
With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
This reduces memory peaks and speeds up the saving to disk.
The resulting checkpoint folder will have this structure:

.. collapse:: Full example

@@ -103,6 +102,7 @@ The resulting checkpoint folder will have this structure:
├── __1_0.distcp
├── __2_0.distcp
├── __3_0.distcp
├── .metadata
└── meta.pt
The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files
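Not shown in the diff: a minimal sketch of how such a sharded checkpoint folder is produced with Fabric (assuming 4 CUDA GPUs; the placeholder model, optimizer and checkpoint name are made up for illustration):

.. code-block:: python

    import torch
    import lightning as L
    from lightning.fabric.strategies import FSDPStrategy

    # "sharded" selects the distributed checkpoint format described above
    fabric = L.Fabric(accelerator="cuda", devices=4, strategy=FSDPStrategy(state_dict_type="sharded"))
    fabric.launch()

    model = fabric.setup_module(torch.nn.Linear(1024, 1024))  # placeholder model
    optimizer = fabric.setup_optimizers(torch.optim.Adam(model.parameters()))

    # Each rank writes its own __{rank}_0.distcp shard into the "my-checkpoint" folder
    state = {"model": model, "optimizer": optimizer, "step": 0}
    fabric.save("my-checkpoint", state)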
12 changes: 7 additions & 5 deletions docs/source-pytorch/advanced/model_parallel/fsdp.rst
@@ -16,9 +16,9 @@ Even a single H100 GPU with 80 GB of VRAM (the biggest today) is not enough to t
The memory consumption for training is generally made up of

1. the model parameters,
2. the layer activations (forward) and
3. the gradients (backward).
4. the optimizer states (e.g., Adam has two additional exponential averages per parameter),
2. the layer activations (forward),
3. the gradients (backward) and
4. the optimizer states (e.g., Adam has two additional exponential averages per parameter).

|
@@ -200,7 +200,8 @@ Before:
class LanguageModel(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)  # 1B parameters
        # 1B parameters
        self.model = Transformer(vocab_size=vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64)
After:
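The "After" code is folded out of view in this hunk. As a sketch of the delayed-initialization pattern the surrounding FSDP docs describe (an assumption, not the literal diff content), model construction moves from ``__init__`` into the ``configure_model`` hook:

.. code-block:: python

    class LanguageModel(L.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.vocab_size = vocab_size
            self.model = None

        def configure_model(self):
            # Sketch: called by the Trainer after the strategy is set up, so the
            # 1B-parameter Transformer can be instantiated directly in sharded form
            # instead of being fully materialized on every rank first.
            if self.model is None:
                self.model = Transformer(
                    vocab_size=self.vocab_size, nlayers=32, nhid=4096, ninp=1024, nhead=64
                )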

@@ -397,13 +398,14 @@ The resulting checkpoint folder will have this structure:
├── .metadata
├── __0_0.distcp
├── __1_0.distcp
...
└── meta.pt
The “sharded” checkpoint format is the most efficient to save and load in Lightning.

**Which checkpoint format should I use?**

- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable. An extra step is needed to convert the sharded checkpoint into a regular checkpoint file.
- ``state_dict_type="sharded"``: Use for pre-training very large models. It is fast and uses less memory, but it is less portable. An extra step is needed to :doc:`convert the sharded checkpoint into a regular checkpoint file <../../common/checkpointing_expert>`.
- ``state_dict_type="full"``: Use when pre-training small to moderately large models (less than 10B parameters), when fine-tuning, and when portability is required.


2 changes: 1 addition & 1 deletion docs/source-pytorch/common/checkpointing.rst
@@ -45,7 +45,7 @@ Checkpointing

.. displayitem::
:header: Distributed checkpoints
:description: Customize checkpointing for custom distributed strategies and accelerators.
:description: Save and load very large models efficiently with distributed checkpoints
:col_css: col-md-4
:button_link: checkpointing_expert.html
:height: 150
171 changes: 92 additions & 79 deletions docs/source-pytorch/common/checkpointing_expert.rst
@@ -6,121 +6,134 @@

Distributed checkpoints (expert)
################################

*********************************
Writing your own Checkpoint class
*********************************

We provide ``Checkpoint`` class, for easier subclassing. Users may want to subclass this class in case of writing custom ``ModelCheckpoint`` callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback.

----

***********************
Customize Checkpointing
***********************

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

Lightning supports modifying the checkpointing save/load functionality through the ``CheckpointIO``. This encapsulates the save/load logic
that is managed by the ``Strategy``. ``CheckpointIO`` is different from :meth:`~lightning.pytorch.core.hooks.CheckpointHooks.on_save_checkpoint`
and :meth:`~lightning.pytorch.core.hooks.CheckpointHooks.on_load_checkpoint` methods as it determines how the checkpoint is saved/loaded to storage rather than
what's saved in the checkpoint.

.. TODO:: I don't understand this...

******************************
Built-in Checkpoint IO Plugins
******************************

.. list-table:: Built-in Checkpoint IO Plugins
   :widths: 25 75
   :header-rows: 1

   * - Plugin
     - Description
   * - :class:`~lightning.pytorch.plugins.io.TorchCheckpointIO`
     - CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints
       respectively, common for most use cases.
   * - :class:`~lightning.pytorch.plugins.io.XLACheckpointIO`
     - CheckpointIO that utilizes ``xm.save`` to save checkpoints for TPU training strategies.
   * - :class:`~lightning.pytorch.plugins.io.AsyncCheckpointIO`
     - ``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.

***************************
Custom Checkpoint IO Plugin
***************************

``CheckpointIO`` can be extended to include your custom save/load functionality to and from a path. The ``CheckpointIO`` object can be passed to either a ``Trainer`` directly or a ``Strategy`` as shown below:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint
    from lightning.pytorch.plugins import CheckpointIO
    from lightning.pytorch.strategies import SingleDeviceStrategy


    class CustomCheckpointIO(CheckpointIO):
        def save_checkpoint(self, checkpoint, path, storage_options=None):
            ...

        def load_checkpoint(self, path, storage_options=None):
            ...

        def remove_checkpoint(self, path):
            ...


    custom_checkpoint_io = CustomCheckpointIO()

    # Either pass into the Trainer object
    model = MyModel()
    trainer = Trainer(
        plugins=[custom_checkpoint_io],
        callbacks=ModelCheckpoint(save_last=True),
    )
    trainer.fit(model)

    # or pass into Strategy
    model = MyModel()
    device = torch.device("cpu")
    trainer = Trainer(
        strategy=SingleDeviceStrategy(device, checkpoint_io=custom_checkpoint_io),
        callbacks=ModelCheckpoint(save_last=True),
    )
    trainer.fit(model)

.. note::

    Some ``Strategy``s like ``DeepSpeedStrategy`` do not support custom ``CheckpointIO`` as checkpointing logic is not modifiable.

----

**************************
Asynchronous Checkpointing
**************************

.. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.

To enable saving the checkpoints asynchronously without blocking your training, you can configure
:class:`~lightning.pytorch.plugins.io.async_plugin.AsyncCheckpointIO` plugin to ``Trainer``.

.. code-block:: python

    from lightning.pytorch.plugins.io import AsyncCheckpointIO

    async_ckpt_io = AsyncCheckpointIO()
    trainer = Trainer(plugins=[async_ckpt_io])

It uses its base ``CheckpointIO`` plugin's saving logic to save the checkpoint but performs this operation asynchronously.
By default, this base ``CheckpointIO`` will be set-up for you and all you need to provide is the ``AsyncCheckpointIO`` instance to the ``Trainer``.
But if you want the plugin to use your own custom base ``CheckpointIO`` and want the base to behave asynchronously, pass it as an argument while initializing ``AsyncCheckpointIO``.

.. code-block:: python

    from lightning.pytorch.plugins.io import AsyncCheckpointIO

    base_ckpt_io = MyCustomCheckpointIO()
    async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
    trainer = Trainer(plugins=[async_ckpt_io])

Generally, the bigger your model is, the longer it takes to save a checkpoint to disk.
With distributed checkpoints (sometimes called sharded checkpoints), you can save and load the state of your training script with multiple GPUs or nodes more efficiently, avoiding memory issues.

----

*****************************
Save a distributed checkpoint
*****************************

The distributed checkpoint format can be enabled when you train with the :doc:`FSDP strategy <../advanced/model_parallel/fsdp>`.

.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    # 1. Select the FSDP strategy and set the sharded/distributed checkpoint format
    strategy = FSDPStrategy(state_dict_type="sharded")

    # 2. Pass the strategy to the Trainer
    trainer = L.Trainer(devices=2, strategy=strategy, ...)

    # 3. Run the trainer
    trainer.fit(model)

With ``state_dict_type="sharded"``, each process/GPU will save its own file into a folder at the given path.
This reduces memory peaks and speeds up the saving to disk.

.. collapse:: Full example

    .. code-block:: python

        import lightning as L
        from lightning.pytorch.strategies import FSDPStrategy
        from lightning.pytorch.demos import LightningTransformer

        model = LightningTransformer()

        strategy = FSDPStrategy(state_dict_type="sharded")
        trainer = L.Trainer(
            accelerator="cuda",
            devices=4,
            strategy=strategy,
            max_steps=3,
        )
        trainer.fit(model)

    Check the contents of the checkpoint folder:

    .. code-block:: bash

        ls -a lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt/

    .. code-block::

        epoch=0-step=3.ckpt/
        ├── __0_0.distcp
        ├── __1_0.distcp
        ├── __2_0.distcp
        ├── __3_0.distcp
        ├── .metadata
        └── meta.pt

The ``.distcp`` files contain the tensor shards from each process/GPU. You can see that the size of these files
is roughly 1/4 of the total size of the checkpoint since the script distributes the model across 4 GPUs.

----

*****************************
Load a distributed checkpoint
*****************************

You can easily load a distributed checkpoint in Trainer if your script uses :doc:`FSDP <../advanced/model_parallel/fsdp>`.

.. code-block:: python

    import lightning as L
    from lightning.pytorch.strategies import FSDPStrategy

    # 1. Select the FSDP strategy and set the sharded/distributed checkpoint format
    strategy = FSDPStrategy(state_dict_type="sharded")

    # 2. Pass the strategy to the Trainer
    trainer = L.Trainer(devices=2, strategy=strategy, ...)

    # 3. Set the checkpoint path to load
    trainer.fit(model, ckpt_path="path/to/checkpoint")

Note that you can load the distributed checkpoint even if the world size has changed, i.e., you are running on a different number of GPUs than when you saved the checkpoint.

.. collapse:: Full example

    .. code-block:: python

        import lightning as L
        from lightning.pytorch.strategies import FSDPStrategy
        from lightning.pytorch.demos import LightningTransformer

        model = LightningTransformer()

        strategy = FSDPStrategy(state_dict_type="sharded")
        trainer = L.Trainer(
            accelerator="cuda",
            devices=2,
            strategy=strategy,
            max_steps=5,
        )
        trainer.fit(model, ckpt_path="lightning_logs/version_0/checkpoints/epoch=0-step=3.ckpt")

.. important::

    If you want to load a distributed checkpoint into a script that doesn't use FSDP (or Trainer at all), then you will have to :ref:`convert it to a single-file checkpoint first <Convert dist-checkpoint>`.


----


.. _Convert dist-checkpoint:

********************************
Convert a distributed checkpoint
********************************

Coming soon.
8 changes: 8 additions & 0 deletions docs/source-pytorch/glossary/index.rst
@@ -11,6 +11,7 @@
Console Logging <../common/console_logs>
Debugging <../debug/debugging>
DeepSpeed <../advanced/model_parallel/deepspeed>
Distributed Checkpoints <../common/checkpointing_expert>
Early stopping <../common/early_stopping>
Experiment manager (Logger) <../visualize/experiment_managers>
Finetuning <../advanced/finetuning>
@@ -113,6 +114,13 @@ Glossary
:button_link: ../advanced/model_parallel/deepspeed.html
:height: 100

.. displayitem::
:header: Distributed Checkpoints
:description: Save and load very large models efficiently with distributed checkpoints
:col_css: col-md-12
:button_link: ../common/checkpointing_expert.html
:height: 100

.. displayitem::
:header: Early stopping
:description: Stop the training when no improvement is observed
