Accommodate FSDP full-precision param_dtype training with PyTorch < 2.0 #18278

Merged
Changes from 4 commits
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -165,6 +165,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))


- Fixed issue where DDP subprocesses that used Hydra would set hydra's working directory to current directory ([#18145](https://github.com/Lightning-AI/lightning/pull/18145))


14 changes: 9 additions & 5 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -24,7 +24,7 @@
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import Optimizable

if TYPE_CHECKING:
@@ -82,18 +82,22 @@ def __init__(self, precision: _PRECISION_INPUT, scaler: Optional["ShardedGradScaler"] = None) -> None:
def mixed_precision_config(self) -> "TorchMixedPrecision":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

# With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
# property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
# `torch.float32` here with PyTorch < 2.0.
if self.precision == "16-mixed":
param_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-mixed":
param_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "32-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.float32
else:
raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")

@@ -111,7 +115,7 @@ def init_context(self) -> Generator[None, None, None]:

"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.mixed_precision_config.param_dtype)
torch.set_default_dtype(self.mixed_precision_config.param_dtype or torch.float32)
yield
torch.set_default_dtype(default_dtype)
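
To make the branching above concrete, here is a minimal standalone sketch of the same dtype selection outside of Lightning. It assumes a PyTorch build (>= 1.12) with `torch.distributed.fsdp` available; the helper name `_resolve_fsdp_dtypes` and the ad-hoc version check (a stand-in for Lightning's `_TORCH_GREATER_EQUAL_2_0` flag) are illustrative, not Lightning API.

```python
# Minimal sketch of the version-gated dtype selection above (illustrative only).
from typing import Optional

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

# Crude stand-in for Lightning's `_TORCH_GREATER_EQUAL_2_0` flag.
_IS_TORCH_2_0 = int(torch.__version__.split(".")[0]) >= 2


def _resolve_fsdp_dtypes(precision: str) -> MixedPrecision:
    if precision in ("16-mixed", "bf16-mixed", "32-true"):
        # Parameters stay in full precision. FSDP in PyTorch < 2.0 treats a
        # non-None `param_dtype` as "parameters use mixed precision" and would
        # trip internal assertions, so full precision must be expressed as `None` there.
        param_dtype: Optional[torch.dtype] = torch.float32 if _IS_TORCH_2_0 else None
        reduce_dtype = buffer_dtype = {
            "16-mixed": torch.float16,
            "bf16-mixed": torch.bfloat16,
            "32-true": torch.float32,
        }[precision]
    elif precision == "16-true":
        param_dtype = reduce_dtype = buffer_dtype = torch.float16
    elif precision == "bf16-true":
        param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
    else:
        raise ValueError(f"Was unable to infer precision type, received {precision!r}.")
    return MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype)


print(_resolve_fsdp_dtypes("bf16-mixed"))
```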

5 changes: 5 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -174,6 +174,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))


- Issue warnings rather than failing when saving/loading the optimizer state dict with `FSDPStrategy` on PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))


- Fixed an issue with reusing the same model across multiple trainer stages when using the `DeepSpeedStrategy` ([#17531](https://github.com/Lightning-AI/lightning/pull/17531))

14 changes: 9 additions & 5 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -23,7 +23,7 @@
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.plugins.precision.fsdp import _PRECISION_INPUT
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision_plugin import PrecisionPlugin
@@ -91,18 +91,22 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
def mixed_precision_config(self) -> "TorchMixedPrecision":
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

# With PyTorch < 2.0, FSDP uses the noneness of `param_dtype` as a proxy for the `_uses_param_mixed_precision`
# property. In order to avoid FSDP assertion failures, we therefore avoid setting `param_dtype` to
# `torch.float32` here with PyTorch < 2.0.
if self.precision == "16-mixed":
param_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-mixed":
param_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
elif self.precision == "bf16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
elif self.precision == "32-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.float32
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")

@@ -120,7 +124,7 @@ def init_context(self) -> Generator[None, None, None]:

"""
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(self.mixed_precision_config.param_dtype)
torch.set_default_dtype(self.mixed_precision_config.param_dtype or torch.float32)
yield
torch.set_default_dtype(default_dtype)
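
The `init_context` change guards against `param_dtype` being `None` on PyTorch < 2.0 by falling back to `torch.float32` before setting the default dtype. A rough sketch of that pattern in isolation (the `default_dtype_for` helper is made up for illustration and is not Lightning API):

```python
# Sketch of the default-dtype fallback used in `init_context` above (illustrative).
from contextlib import contextmanager
from typing import Generator, Optional

import torch


@contextmanager
def default_dtype_for(param_dtype: Optional[torch.dtype]) -> Generator[None, None, None]:
    previous = torch.get_default_dtype()
    # On PyTorch < 2.0 the mixed-precision config may carry `param_dtype=None`,
    # which `torch.set_default_dtype` cannot accept, so fall back to float32.
    torch.set_default_dtype(param_dtype or torch.float32)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


with default_dtype_for(None):
    layer = torch.nn.Linear(2, 2)  # parameters are created in float32
assert layer.weight.dtype == torch.float32
```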

21 changes: 20 additions & 1 deletion src/lightning/pytorch/strategies/fsdp.py
@@ -75,6 +75,11 @@

log = logging.getLogger(__name__)

_LOAD_SAVE_OPTIM_STATE_WARN_1X = (
"Note that saving/restoring optimizer state using `FSDPStrategy` with PyTorch < 2.0"
" is not supported by Lightning."
)


class FSDPStrategy(ParallelStrategy):
r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.
@@ -176,7 +181,11 @@ def lightning_module_state_dict(self) -> Dict[str, Any]:

"""
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType

if _TORCH_GREATER_EQUAL_2_0:
from torch.distributed.fsdp.api import FullStateDictConfig, StateDictType
else:
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

assert self.model is not None

@@ -442,6 +451,11 @@ def register_strategies(cls, strategy_registry: _StrategyRegistry) -> None:
cls._registered_strategies.append("fsdp_cpu_offload")

def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
if not _TORCH_GREATER_EQUAL_2_0:
rank_zero_warn(
f"{_LOAD_SAVE_OPTIM_STATE_WARN_1X} Bypassing restoration of optimizer state.", category=UserWarning
)
return
from torch.distributed.fsdp import FullyShardedDataParallel, OptimStateKeyType

optimizer_states = checkpoint.get("optimizer_states")
@@ -476,6 +490,11 @@ def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
optimizer.load_state_dict(opt_state)

def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
if not _TORCH_GREATER_EQUAL_2_0:
rank_zero_warn(
f"{_LOAD_SAVE_OPTIM_STATE_WARN_1X} Bypassing saving of optimizer state.", category=UserWarning
)
return {}
from torch.distributed.fsdp import FullyShardedDataParallel, OptimStateKeyType

if isinstance(optimizer, LightningOptimizer):
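
The strategy changes follow one pattern: on PyTorch < 2.0, warn and skip optimizer-state saving/loading instead of hitting FSDP code paths Lightning does not support there (the diff also gates the `FullStateDictConfig`/`StateDictType` import on the torch version, since `torch.distributed.fsdp.api` only exists in 2.0+). A condensed sketch of the guard, using plain `warnings.warn` in place of Lightning's `rank_zero_warn` and a stub class rather than the real `FSDPStrategy`:

```python
# Condensed, illustrative sketch of the "warn and bypass" guard added above.
import warnings
from typing import Any, Dict, Mapping

import torch

_IS_TORCH_2_0 = int(torch.__version__.split(".")[0]) >= 2  # stand-in for `_TORCH_GREATER_EQUAL_2_0`
_WARN_1X = (
    "Note that saving/restoring optimizer state using `FSDPStrategy` with PyTorch < 2.0"
    " is not supported by Lightning."
)


class _SketchFSDPStrategy:
    def optimizer_state(self, optimizer: torch.optim.Optimizer) -> Dict[str, Any]:
        if not _IS_TORCH_2_0:
            warnings.warn(f"{_WARN_1X} Bypassing saving of optimizer state.", UserWarning)
            return {}
        # Placeholder: the real strategy consolidates the sharded optimizer state
        # through FSDP's optimizer-state APIs rather than calling this directly.
        return optimizer.state_dict()

    def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
        if not _IS_TORCH_2_0:
            warnings.warn(f"{_WARN_1X} Bypassing restoration of optimizer state.", UserWarning)
            return
        ...  # restore each optimizer from `checkpoint["optimizer_states"]`
```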
22 changes: 19 additions & 3 deletions tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -31,11 +31,27 @@ def test_fsdp_precision_support(*_):
@pytest.mark.parametrize(
("precision", "expected"),
[
("16-mixed", (torch.float32, torch.float16, torch.float16)),
("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
("16-true", (torch.float16, torch.float16, torch.float16)),
("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
("32-true", (torch.float32, torch.float32, torch.float32)),
pytest.param(
"16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
),
pytest.param(
"16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
),
pytest.param(
"bf16-mixed",
(torch.float32, torch.bfloat16, torch.bfloat16),
marks=RunIf(min_torch="2.0"),
id="bf16-mixed-ge2_0",
),
pytest.param(
"bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
),
pytest.param(
"32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0"
),
pytest.param("32-true", (None, torch.float32, torch.float32), marks=RunIf(max_torch="2.0"), id="32-true-lt2_0"),
],
)
def test_fsdp_precision_config(precision, expected):
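
For context, the parametrization above drives a test body along these lines. This is a sketch: the `FSDPPrecision` class name and the `(param_dtype, reduce_dtype, buffer_dtype)` tuple ordering are assumptions based on the surrounding diff, not verbatim source.

```python
# Sketch of the assertion the parametrized test presumably performs (assumptions noted above).
import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision


def check_fsdp_precision_config(precision: str, expected: tuple) -> None:
    config = FSDPPrecision(precision=precision).mixed_precision_config
    assert (config.param_dtype, config.reduce_dtype, config.buffer_dtype) == expected


# The expected `param_dtype` now depends on the installed PyTorch version:
if int(torch.__version__.split(".")[0]) >= 2:
    check_fsdp_precision_config("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16))
else:
    check_fsdp_precision_config("bf16-mixed", (None, torch.bfloat16, torch.bfloat16))
```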
22 changes: 19 additions & 3 deletions tests/tests_pytorch/plugins/precision/test_fsdp.py
@@ -31,11 +31,27 @@ def test_fsdp_precision_support(*_):
@pytest.mark.parametrize(
("precision", "expected"),
[
("16-mixed", (torch.float32, torch.float16, torch.float16)),
("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
("16-true", (torch.float16, torch.float16, torch.float16)),
("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
("32-true", (torch.float32, torch.float32, torch.float32)),
pytest.param(
"16-mixed", (torch.float32, torch.float16, torch.float16), marks=RunIf(min_torch="2.0"), id="16-mixed-ge2_0"
),
pytest.param(
"16-mixed", (None, torch.float16, torch.float16), marks=RunIf(max_torch="2.0"), id="16-mixed-lt2_0"
),
pytest.param(
"bf16-mixed",
(torch.float32, torch.bfloat16, torch.bfloat16),
marks=RunIf(min_torch="2.0"),
id="bf16-mixed-ge2_0",
),
pytest.param(
"bf16-mixed", (None, torch.bfloat16, torch.bfloat16), marks=RunIf(max_torch="2.0"), id="bf16-mixed-lt2_0"
),
pytest.param(
"32-true", (torch.float32, torch.float32, torch.float32), marks=RunIf(min_torch="2.0"), id="32-true-ge2_0"
),
pytest.param("32-true", (None, torch.float32, torch.float32), marks=RunIf(max_torch="2.0"), id="32-true-lt2_0"),
],
)
def test_fsdp_precision_config(precision, expected):
96 changes: 56 additions & 40 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -81,10 +81,10 @@ def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecisionPlugin)

if self.trainer.precision == "16-mixed":
param_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif self.trainer.precision == "bf16-mixed":
param_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif self.trainer.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
@@ -137,10 +137,10 @@ def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.trainer.strategy.precision_plugin, FSDPPrecisionPlugin)

if self.trainer.precision == "16-mixed":
param_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.float16
elif self.trainer.precision == "bf16-mixed":
param_dtype = torch.float32
param_dtype = None if not _TORCH_GREATER_EQUAL_2_0 else torch.float32
reduce_dtype = buffer_dtype = torch.bfloat16
elif self.trainer.precision == "16-true":
param_dtype = reduce_dtype = buffer_dtype = torch.float16
@@ -506,7 +506,7 @@ def test_set_timeout(init_process_group_mock):
)


@RunIf(min_torch="1.12")
@RunIf(min_torch="2.0")
def test_fsdp_strategy_load_optimizer_states_multiple():
strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])

@@ -551,32 +551,38 @@ def test_fsdp_strategy_save_optimizer_states(tmpdir, wrap_min_params):
trainer.save_checkpoint(model_path)

model_state_dict = trainer.strategy.lightning_module_state_dict()
optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
if not _TORCH_GREATER_EQUAL_2_0:
if trainer.global_rank == 0:
with pytest.warns(UserWarning, match="Bypassing saving of optimizer state"):
optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
else:
optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
else:
optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
if trainer.global_rank != 0:
assert len(model_state_dict) == 0

if trainer.global_rank != 0:
assert len(model_state_dict) == 0
if _TORCH_GREATER_EQUAL_2_1:
assert len(optimizer_state_dict) == 0

if _TORCH_GREATER_EQUAL_2_1:
assert len(optimizer_state_dict) == 0
# restore model to ddp
model = TestBoringModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)

# restore model to ddp
model = TestBoringModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
# This step will restore the model and optimizer states
trainer.fit(model, ckpt_path=model_path)

# This step will restore the model and optimizer states
trainer.fit(model, ckpt_path=model_path)
# Get the model and optimizer states from the restored ddp model
restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

# Get the model and optimizer states from the restored ddp model
restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
if trainer.global_rank == 0:
# assert everything is the same
assert len(model_state_dict) == len(restored_model_state_dict)
assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)

if trainer.global_rank == 0:
# assert everything is the same
assert len(model_state_dict) == len(restored_model_state_dict)
assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)

torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)
torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)

trainer.strategy.barrier()

@@ -619,23 +625,29 @@ def test_fsdp_strategy_load_optimizer_states(tmpdir, wrap_min_params):
barebones=True,
)

trainer.fit(model, ckpt_path=model_path)

restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())
if not _TORCH_GREATER_EQUAL_2_0:
if trainer.global_rank == 0:
with pytest.warns(UserWarning, match="Bypassing restoration of optimizer state"):
trainer.fit(model, ckpt_path=model_path)
else:
trainer.fit(model, ckpt_path=model_path)
else:
trainer.fit(model, ckpt_path=model_path)
restored_model_state_dict = trainer.strategy.lightning_module_state_dict()
restored_optimizer_state_dict = trainer.strategy.optimizer_state(model.optimizers())

if trainer.global_rank != 0:
assert len(restored_model_state_dict) == 0
if trainer.global_rank != 0:
assert len(restored_model_state_dict) == 0

if _TORCH_GREATER_EQUAL_2_1:
assert len(restored_optimizer_state_dict) == 0
if _TORCH_GREATER_EQUAL_2_1:
assert len(restored_optimizer_state_dict) == 0

if trainer.global_rank == 0:
# assert everything is the same
assert len(model_state_dict) == len(restored_model_state_dict)
assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)
torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)
if trainer.global_rank == 0:
# assert everything is the same
assert len(model_state_dict) == len(restored_model_state_dict)
assert len(optimizer_state_dict) == len(restored_optimizer_state_dict)
torch.testing.assert_close(model_state_dict, restored_model_state_dict, atol=0, rtol=0)
torch.testing.assert_close(optimizer_state_dict, restored_optimizer_state_dict, atol=0, rtol=0)

trainer.strategy.barrier()

@@ -654,7 +666,7 @@ def test_configure_model(precision, expected_dtype):
devices=2,
strategy=FSDPStrategy(auto_wrap_policy=always_wrap_policy),
precision=precision,
fast_dev_run=1,
max_epochs=1,
)

class MyModel(BoringModel):
@@ -666,6 +678,10 @@ def configure_model(self):
assert self.layer.weight.device == expected_device
assert self.layer.weight.dtype == expected_dtype

def configure_optimizers(self):
# There is some issue with SGD optimizer state in FSDP
return torch.optim.AdamW(self.layer.parameters(), lr=0.1)

def on_fit_start(self):
# Parameters get sharded in `.setup()` and moved to the target device
assert self.layer.weight.device == torch.device("cuda", self.local_rank)