
Commit

Merge branch 'Lightning-AI:master' into hybrid_fsdp_stage
Liyang90 authored Mar 5, 2024
2 parents 3315893 + b871f7a commit 4652b74
Showing 23 changed files with 169 additions and 131 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-pkg-install.yml
@@ -115,7 +115,7 @@ jobs:
done
- name: Install pytest doctest extension
run: |
pip install -q "pytest-doctestplus>=0.9.0"
pip install -q -r requirements/doctests.txt
pip list
- name: DocTest package
19 changes: 18 additions & 1 deletion docs/source-fabric/advanced/compile.rst
@@ -3,7 +3,7 @@ Speed up models by compiling them
#################################

Compiling your PyTorch model can result in significant speedups, especially on the latest generations of GPUs.
This guide shows you how to apply ``torch.compile`` correctly in your code.
This guide shows you how to apply `torch.compile <https://pytorch.org/docs/2.2/generated/torch.compile.html>`_ correctly in your code.

.. note::

@@ -223,6 +223,9 @@ On PyTorch 2.2 and later, ``torch.compile`` will detect dynamism automatically and you should no longer need to set this.
Numbers produced with NVIDIA A100 SXM4 40GB, PyTorch 2.2.0, CUDA 12.1.


If you still see recompilation issues after dealing with the aforementioned cases, there is a `Compile Profiler in PyTorch <https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html#excessive-recompilation>`_ for further investigation.
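
For instance, a minimal sketch of running that profiler to see what is triggering recompilation (the import path and exact API follow the linked guide and may differ between PyTorch releases):

.. code-block:: python

    import torch
    from torch._dynamo.utils import CompileProfiler

    model = torch.nn.Linear(128, 64)  # illustrative stand-in for your model

    with CompileProfiler() as prof:
        compiled_model = torch.compile(model, backend=prof)
        # Feed a couple of differently shaped batches, then inspect the report
        # to see which guards or shape changes caused recompilation
        compiled_model(torch.randn(8, 128))
        compiled_model(torch.randn(16, 128))
        print(prof.report())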


----


@@ -301,4 +304,18 @@ However, should you have issues compiling DDP and FSDP models, you can opt out of this feature:
model = fabric.setup(model, _reapply_compile=False)
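
For reference, a minimal sketch of how this flag fits into the usual setup flow (the model below is an illustrative placeholder):

.. code-block:: python

    import torch
    import lightning as L

    fabric = L.Fabric(accelerator="gpu", devices=2, strategy="ddp")
    fabric.launch()

    model = torch.nn.Linear(32, 2)  # stand-in for your own module
    model = torch.compile(model)

    # Keep the user-compiled module as-is instead of letting Fabric re-apply
    # torch.compile on top of the DDP/FSDP wrapper
    model = fabric.setup(model, _reapply_compile=False)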
----


********************
Additional Resources
********************

Here are a few resources for further reading after you complete this tutorial:

- `PyTorch 2.0 Paper <https://pytorch.org/blog/pytorch-2-paper-tutorial/>`_
- `GenAI with PyTorch 2.0 blog post series <https://pytorch.org/blog/accelerating-generative-ai-4/>`_
- `Training Production AI Models with PyTorch 2.0 <https://pytorch.org/blog/training-production-ai-models/>`_
- `Empowering Models with Performance: The Art of Generalized Model Transformation Approach <https://pytorch.org/blog/empowering-models-performance/>`_

|
5 changes: 4 additions & 1 deletion docs/source-fabric/api/fabric_args.rst
@@ -36,13 +36,16 @@ See also: :doc:`../fundamentals/accelerators`
strategy
========

Choose a training strategy: ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"xla"``, ``"deepspeed"``, ``"fsdp"````.
Choose a training strategy: ``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"ddp_find_unused_parameters_true"``, ``"xla"``, ``"deepspeed"``, ``"fsdp"``.

.. code-block:: python
# Running with the DistributedDataParallel strategy on 4 GPUs
fabric = Fabric(strategy="ddp", accelerator="gpu", devices=4)
# Running with the DDP strategy with find unused parameters enabled on 4 GPUs
fabric = Fabric(strategy="ddp_find_unused_parameters_true", accelerator="gpu", devices=4)
# Running with the DDP Spawn strategy using 4 CPU processes
fabric = Fabric(strategy="ddp_spawn", accelerator="cpu", devices=4)
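# Running with the FSDP strategy on 4 GPUs (illustrative addition)
fabric = Fabric(strategy="fsdp", accelerator="gpu", devices=4)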
Changes to the Fabric distributed-checkpoint guide (file path not shown)
@@ -187,7 +187,7 @@ It is possible to convert a distributed checkpoint to a regular, single-file checkpoint

.. code-block:: bash
python -m lightning.fabric.utilities.consolidate_checkpoint path/to/my/checkpoint
fabric consolidate path/to/my/checkpoint
You will need to do this for example if you want to load the checkpoint into a script that doesn't use FSDP, or need to export the checkpoint to a different format for deployment, evaluation, etc.

@@ -202,7 +202,7 @@ You will need to do this for example if you want to load the checkpoint into a script

.. code-block:: bash
python -m lightning.fabric.utilities.consolidate_checkpoint my-checkpoint.ckpt
fabric consolidate my-checkpoint.ckpt
This saves a new file ``my-checkpoint.ckpt.consolidated`` next to the sharded checkpoint which you can load normally in PyTorch:
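
For example, a minimal sketch of that load step (file name as produced above):

.. code-block:: python

    import torch

    # The consolidated file is a regular checkpoint and loads like any other;
    # depending on your PyTorch version you may need weights_only=False
    checkpoint = torch.load("my-checkpoint.ckpt.consolidated")
    print(list(checkpoint))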

23 changes: 20 additions & 3 deletions docs/source-pytorch/advanced/compile.rst
@@ -3,7 +3,7 @@ Speed up models by compiling them
#################################

Compiling your LightningModule can result in significant speedups, especially on the latest generations of GPUs.
This guide shows you how to apply ``torch.compile`` correctly in your code.
This guide shows you how to apply `torch.compile <https://pytorch.org/docs/2.2/generated/torch.compile.html>`_ correctly in your code.

.. note::

@@ -192,6 +192,8 @@ However, when this is not possible, you can request PyTorch to compile the code
A model compiled with ``dynamic=True`` will typically be slower than a model compiled with static shapes, but it will avoid the extreme cost of recompilation every iteration.
On PyTorch 2.2 and later, ``torch.compile`` will detect dynamism automatically and you should no longer need to set this.

If you still see recompilation issues after dealing with the aforementioned cases, there is a `Compile Profiler in PyTorch <https://pytorch.org/docs/stable/torch.compiler_troubleshooting.html#excessive-recompilation>`_ for further investigation.
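
If you do opt into dynamic shapes by hand, a minimal sketch (the LightningModule below is an illustrative placeholder):

.. code-block:: python

    import torch
    import lightning as L

    model = MyLightningModule()  # your own LightningModule
    # Request dynamic-shape compilation up front to avoid a recompile for
    # every new input shape
    compiled_model = torch.compile(model, dynamic=True)

    trainer = L.Trainer()
    trainer.fit(compiled_model)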


----

@@ -251,9 +253,9 @@ Always compare the speed and memory usage of the compiled model against the original model.
Limitations
***********

There are a few limitations you should be aware of when using ``torch.compile`` in conjunction with the Trainer:
There are a few limitations you should be aware of when using ``torch.compile`` **in conjunction with the Trainer**:

* ``torch.compile`` currently does not get reapplied over DDP/FSDP, meaning distributed operations can't benefit from speed ups at the moment.
* The Trainer currently does not reapply ``torch.compile`` over DDP/FSDP, meaning distributed operations can't benefit from speed ups at the moment.
This limitation will be lifted in the future.

* In some cases, using ``self.log()`` in your LightningModule will cause compilation errors.
@@ -270,4 +272,19 @@ There are a few limitations you should be aware of when using ``torch.compile``
self.model = torch.compile(self.model)
...
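
A fuller sketch of that pattern (module and shapes are illustrative):

.. code-block:: python

    import torch
    import lightning as L


    class LitModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            # Compile only the inner module; self.log() then stays outside the
            # compiled region and cannot trigger compilation errors
            self.model = torch.compile(torch.nn.Linear(32, 2))

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self.model(x), y)
            self.log("train_loss", loss)
            return loss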
----


********************
Additional Resources
********************

Here are a few resources for further reading after you complete this tutorial:

- `PyTorch 2.0 Paper <https://pytorch.org/blog/pytorch-2-paper-tutorial/>`_
- `GenAI with PyTorch 2.0 blog post series <https://pytorch.org/blog/accelerating-generative-ai-4/>`_
- `Training Production AI Models with PyTorch 2.0 <https://pytorch.org/blog/training-production-ai-models/>`_
- `Empowering Models with Performance: The Art of Generalized Model Transformation Approach <https://pytorch.org/blog/empowering-models-performance/>`_

|
2 changes: 1 addition & 1 deletion docs/source-pytorch/ecosystem/asr_nlp_tts.rst
@@ -86,7 +86,7 @@ To install from a local clone of NeMo:
./reinstall.sh # from cloned NeMo's git root
For Docker users, the NeMo container is available on
`NGC <https://ngc.nvidia.com/catalog/containers/nvidia:nemo>`_.
`NGC <https://catalog.ngc.nvidia.com/orgs/nvidia/collections/nemotrainingframework>`_.

.. code-block:: bash
2 changes: 2 additions & 0 deletions requirements/doctests.txt
@@ -0,0 +1,2 @@
pytest ==7.4.0
pytest-doctestplus ==1.0.0
2 changes: 1 addition & 1 deletion requirements/fabric/strategies.txt
@@ -6,4 +6,4 @@
# note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods`
# shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372
deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" # strict
bitsandbytes ==0.41.0 # strict
bitsandbytes >=0.42.0,<0.43.0
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
@@ -8,4 +8,4 @@ hydra-core >=1.0.5, <1.4.0
jsonargparse[signatures] >=4.27.5, <4.28.0
rich >=12.3.0, <13.6.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes ==0.41.0 # strict
bitsandbytes >=0.42.0,<0.43.0
2 changes: 1 addition & 1 deletion src/lightning/app/cli/react-ui-template/ui/package.json
@@ -24,7 +24,7 @@
"@vitejs/plugin-react": "^1.0.7",
"prettier": "^2.5.1",
"typescript": "^4.5.4",
"vite": "^2.9.16"
"vite": "^2.9.17"
},
"main": "index.js",
"license": "MIT"
8 changes: 4 additions & 4 deletions src/lightning/app/cli/react-ui-template/ui/yarn.lock
@@ -1260,10 +1260,10 @@ update-browserslist-db@^1.0.4:
escalade "^3.1.1"
picocolors "^1.0.0"

vite@^2.9.16:
version "2.9.16"
resolved "https://registry.yarnpkg.com/vite/-/vite-2.9.16.tgz#daf7ba50f5cc37a7bf51b118ba06bc36e97898e9"
integrity sha512-X+6q8KPyeuBvTQV8AVSnKDvXoBMnTx8zxh54sOwmmuOdxkjMmEJXH2UEchA+vTMps1xw9vL64uwJOWryULg7nA==
vite@^2.9.17:
version "2.9.17"
resolved "https://registry.yarnpkg.com/vite/-/vite-2.9.17.tgz#6b770525e12fa2a2e3a0fa0d028d304f4f7dc7d4"
integrity sha512-XxcRzra6d7xrKXH66jZUgb+srThoPu+TLJc06GifUyKq9JmjHkc1Numc8ra0h56rju2jfVWw3B3fs5l3OFMvUw==
dependencies:
esbuild "^0.14.27"
postcss "^8.4.13"
11 changes: 9 additions & 2 deletions src/lightning/fabric/CHANGELOG.md
@@ -9,7 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
- Enabled consolidating distributed checkpoints through `fabric consolidate` in the new CLI ([#19560](https://github.com/Lightning-AI/pytorch-lightning/pull/19560))

-

@@ -46,13 +46,20 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))
-

-

-


## [2.2.1] - 2024-03-04

### Fixed

- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))


## [2.2.0] - 2024-02-08

### Added
34 changes: 34 additions & 0 deletions src/lightning/fabric/cli.py
@@ -19,14 +19,17 @@
from argparse import Namespace
from typing import Any, List, Optional

import torch
from lightning_utilities.core.imports import RequirementCache
from typing_extensions import get_args

from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS
from lightning.fabric.strategies import STRATEGY_REGISTRY
from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args
from lightning.fabric.utilities.device_parser import _parse_gpu_ids
from lightning.fabric.utilities.distributed import _suggested_max_num_threads
from lightning.fabric.utilities.load import _load_distributed_checkpoint

_log = logging.getLogger(__name__)

@@ -154,6 +157,37 @@ def _run(**kwargs: Any) -> None:
script_args = list(kwargs.pop("script_args", []))
main(args=Namespace(**kwargs), script_args=script_args)

@_main.command(
"consolidate",
context_settings={
"ignore_unknown_options": True,
},
)
@click.argument(
"checkpoint_folder",
type=click.Path(exists=True),
)
@click.option(
"--output_file",
type=click.Path(exists=True),
default=None,
help=(
"Path to the file where the converted checkpoint should be saved. The file should not already exist."
" If no path is provided, the file will be saved next to the input checkpoint folder with the same name"
" and a '.consolidated' suffix."
),
)
def _consolidate(checkpoint_folder: str, output_file: Optional[str]) -> None:
"""Convert a distributed/sharded checkpoint into a single file that can be loaded with `torch.load()`.
Only supports FSDP sharded checkpoints at the moment.
"""
args = Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file)
config = _process_cli_args(args)
checkpoint = _load_distributed_checkpoint(config.checkpoint_folder)
torch.save(checkpoint, config.output_file)


def _set_env_variables(args: Namespace) -> None:
"""Set the environment variables for the new processes.
7 changes: 3 additions & 4 deletions src/lightning/fabric/plugins/precision/bitsandbytes.py
@@ -39,8 +39,7 @@

log = logging.getLogger(__name__)

# TODO: unpin after resolving the `quant_state` format breaking changes
_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes==0.41.0")
_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.42.0")


class BitsandbytesPrecision(Precision):
@@ -344,7 +343,7 @@ def quantize(
def to_empty(self, *, device: _DEVICE, recurse: bool = True) -> Self:
if self.weight.dtype == torch.uint8: # was quantized
# cannot init the quantized params directly
weight = torch.empty(self.weight.quant_state[1], device=device, dtype=torch.half)
weight = torch.empty(self.weight.quant_state.shape, device=device, dtype=torch.half)
else:
weight = torch.empty_like(self.weight.data, device=device)
device = torch.device(device)
@@ -366,7 +365,7 @@ def reset_parameters(self) -> None:
linear_init_finished = isinstance(self.weight, bnb.nn.Params4bit)
if linear_init_finished and self.weight.dtype == torch.uint8: # was quantized
# cannot init the quantized params directly
weight = torch.empty(self.weight.quant_state[1], device=self.weight.device, dtype=torch.half)
weight = torch.empty(self.weight.quant_state.shape, device=self.weight.device, dtype=torch.half)
else:
weight = self.weight.data
torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))
6 changes: 3 additions & 3 deletions src/lightning/fabric/utilities/consolidate_checkpoint.py
@@ -4,7 +4,7 @@

import torch

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.load import _METADATA_FILENAME, _load_distributed_checkpoint

_log = logging.getLogger(__name__)
@@ -38,8 +38,8 @@ def _parse_cli_args() -> Namespace:


def _process_cli_args(args: Namespace) -> Namespace:
if not _TORCH_GREATER_EQUAL_2_1:
_log.error("Processing distributed checkpoints requires PyTorch >= 2.1.")
if not _TORCH_GREATER_EQUAL_2_3:
_log.error("Processing distributed checkpoints requires PyTorch >= 2.3.")
exit(1)

checkpoint_folder = Path(args.checkpoint_folder)
(remaining changed files not shown)
