Fix Mixed Precision settings for FSDP Plugins #17670

Merged · 4 commits · May 23, 2023
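In short: for "16-mixed" and "bf16-mixed", the FSDP precision plugins previously set param_dtype, reduce_dtype, and buffer_dtype all to the same half-precision dtype; they now keep parameters in torch.float32 and use the half-precision dtype only for gradient reduction and buffers. New "16-true" and "bf16-true" settings cast all three. A standalone sketch of the resulting mapping (the helper name is hypothetical and only illustrates the logic in the diffs below):

import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision

# Hypothetical helper mirroring the mapping introduced in this PR for the four
# supported values; the real plugins raise on any other precision string.
def fsdp_mixed_precision(precision: str) -> MixedPrecision:
    half = torch.bfloat16 if precision.startswith("bf16") else torch.float16
    param_dtype = torch.float32 if precision.endswith("-mixed") else half
    return MixedPrecision(param_dtype=param_dtype, reduce_dtype=half, buffer_dtype=half)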
3 changes: 3 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated the `TPUBf16Precision` in favor of `XLABf16Precision` ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))


+- Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/lightning/issues/17670))
+
+
 -


17 changes: 12 additions & 5 deletions src/lightning/fabric/plugins/precision/fsdp.py
@@ -48,13 +48,20 @@ def mixed_precision_config(self) -> "TorchMixedPrecision":
         from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision as TorchMixedPrecision

         if self.precision == "16-mixed":
-            dtype = torch.float16
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            dtype = torch.bfloat16
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
+
         return TorchMixedPrecision(
-            param_dtype=dtype,
-            reduce_dtype=dtype,
-            buffer_dtype=dtype,
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype,
         )
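For context, how this surfaces to a Fabric user, as a minimal sketch (the accelerator/device arguments are illustrative assumptions, not part of this diff):

from lightning.fabric import Fabric

# With "bf16-mixed", FSDP now keeps parameters in float32 and uses bfloat16 only
# for gradient reduction and buffers; "bf16-true" casts all three to bfloat16.
fabric = Fabric(accelerator="cuda", devices=2, strategy="fsdp", precision="bf16-mixed")
fabric.launch()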
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -85,6 +85,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Made the run initialization in `WandbLogger` lazy to avoid creating artifacts when the CLI is used ([#17573](https://github.com/Lightning-AI/lightning/pull/17573))


+- Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/lightning/issues/17670))
+
+
 ### Deprecated

 - Deprecated the `SingleTPUStrategy` (`strategy="single_tpu"`) in favor of `SingleDeviceXLAStrategy` (`strategy="single_xla"`) ([#17383](https://github.com/Lightning-AI/lightning/pull/17383))

18 changes: 13 additions & 5 deletions src/lightning/pytorch/plugins/precision/fsdp.py
@@ -55,14 +55,22 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
     @property
     def mixed_precision_config(self) -> Optional[MixedPrecision]:
         assert MixedPrecision is not None
+
         if self.precision == "16-mixed":
-            dtype = torch.float16
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
         elif self.precision == "bf16-mixed":
-            dtype = torch.bfloat16
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
         else:
             raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
+
         return MixedPrecision(
-            param_dtype=dtype,
-            reduce_dtype=dtype,
-            buffer_dtype=dtype,
+            param_dtype=param_dtype,
+            reduce_dtype=reduce_dtype,
+            buffer_dtype=buffer_dtype,
         )
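The same mapping applies on the Trainer side; a minimal sketch (the Trainer arguments are assumptions for a typical multi-GPU run, not part of this diff):

from lightning.pytorch import Trainer

# "16-mixed": parameters stay in float32, gradient reduction and buffers use float16.
# "16-true": parameters, reductions, and buffers all use float16.
trainer = Trainer(accelerator="cuda", devices=2, strategy="fsdp", precision="16-mixed")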
17 changes: 13 additions & 4 deletions tests/tests_fabric/plugins/precision/test_fsdp.py
@@ -27,10 +27,19 @@ def test_fsdp_precision_support(*_):


 @RunIf(min_torch="1.12", min_cuda_gpus=1)
-@pytest.mark.parametrize(("precision", "expected"), [("16-mixed", torch.float16), ("bf16-mixed", torch.bfloat16)])
+@pytest.mark.parametrize(
+    ("precision", "expected"),
+    [
+        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-true", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
+    ],
+)
 def test_fsdp_precision_config(precision, expected):
     plugin = FSDPPrecision(precision=precision, device="cuda")
     config = plugin.mixed_precision_config
-    assert config.param_dtype == expected
-    assert config.buffer_dtype == expected
-    assert config.reduce_dtype == expected
+
+    assert config.param_dtype == expected[0]
+    assert config.buffer_dtype == expected[1]
+    assert config.reduce_dtype == expected[2]
26 changes: 19 additions & 7 deletions tests/tests_fabric/strategies/test_fsdp_integration.py
@@ -44,16 +44,28 @@ def step(self, model, batch):
         assert isinstance(forward_module, FullyShardedDataParallel)
         assert isinstance(self._precision, FSDPPrecision)

-        precision = torch.float16 if self._precision.precision == "16-mixed" else torch.bfloat16
-        assert forward_module.mixed_precision.param_dtype == precision
-        assert forward_module.mixed_precision.reduce_dtype == precision
-        assert forward_module.mixed_precision.buffer_dtype == precision
+        if self._precision.precision == "16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
+        elif self._precision.precision == "bf16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self._precision.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self._precision.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        else:
+            raise ValueError(f"Unknown precision {self._precision.precision}")
+
+        assert forward_module.mixed_precision.param_dtype == param_dtype
+        assert forward_module.mixed_precision.reduce_dtype == reduce_dtype
+        assert forward_module.mixed_precision.buffer_dtype == buffer_dtype

         for layer_num in [0, 2]:
             assert isinstance(original_module[layer_num], FullyShardedDataParallel)
-            assert original_module[layer_num].mixed_precision.param_dtype == precision
-            assert original_module[layer_num].mixed_precision.reduce_dtype == precision
-            assert original_module[layer_num].mixed_precision.buffer_dtype == precision
+            assert original_module[layer_num].mixed_precision.param_dtype == param_dtype
+            assert original_module[layer_num].mixed_precision.reduce_dtype == reduce_dtype
+            assert original_module[layer_num].mixed_precision.buffer_dtype == buffer_dtype

         output = model(batch)
         return torch.nn.functional.mse_loss(output, torch.ones_like(output))
64 changes: 49 additions & 15 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -75,16 +75,29 @@ def on_predict_batch_end(self, *_) -> None:
     def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin)
-        precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16
-        assert self.layer.mixed_precision.param_dtype == precision
-        assert self.layer.mixed_precision.reduce_dtype == precision
-        assert self.layer.mixed_precision.buffer_dtype == precision
+
+        if self.trainer.precision == "16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
+        elif self.trainer.precision == "bf16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self.trainer.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        else:
+            raise ValueError(f"Unknown precision {self.trainer.precision}")
+
+        assert self.layer.mixed_precision.param_dtype == param_dtype
+        assert self.layer.mixed_precision.reduce_dtype == reduce_dtype
+        assert self.layer.mixed_precision.buffer_dtype == buffer_dtype

         for layer_num in [0, 2]:
             assert isinstance(self.layer.module[layer_num], FullyShardedDataParallel)
-            assert self.layer[layer_num].mixed_precision.param_dtype == precision
-            assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
-            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision
+            assert self.layer[layer_num].mixed_precision.param_dtype == param_dtype
+            assert self.layer[layer_num].mixed_precision.reduce_dtype == reduce_dtype
+            assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype


 class TestFSDPModelAutoWrapped(BoringModel):
@@ -114,16 +127,28 @@ def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, torch.nn.Sequential)
         assert isinstance(self.trainer.strategy.precision_plugin, FSDPMixedPrecisionPlugin)

-        precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16
+        if self.trainer.precision == "16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.float16
+        elif self.trainer.precision == "bf16-mixed":
+            param_dtype = torch.float32
+            reduce_dtype = buffer_dtype = torch.bfloat16
+        elif self.trainer.precision == "16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.float16
+        elif self.trainer.precision == "bf16-true":
+            param_dtype = reduce_dtype = buffer_dtype = torch.bfloat16
+        else:
+            raise ValueError(f"Unknown precision {self.trainer.precision}")
+
         for layer_num in [0, 2]:
             if not self.should_be_wrapped[layer_num]:
                 # this layer is not wrapped
                 assert not isinstance(self.layer[layer_num], FullyShardedDataParallel)
                 continue
             assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
-            assert self.layer[layer_num].mixed_precision.param_dtype == precision
-            assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
-            assert self.layer[layer_num].mixed_precision.buffer_dtype == precision
+            assert self.layer[layer_num].mixed_precision.param_dtype == param_dtype
+            assert self.layer[layer_num].mixed_precision.reduce_dtype == reduce_dtype
+            assert self.layer[layer_num].mixed_precision.buffer_dtype == buffer_dtype


 def _run_multiple_stages(trainer, model, model_path: Optional[str] = None):
@@ -174,13 +199,22 @@ def test_invalid_on_cpu(tmpdir):


 @RunIf(min_torch="1.12", min_cuda_gpus=1)
-@pytest.mark.parametrize(("precision", "expected"), [("16-mixed", torch.float16), ("bf16-mixed", torch.bfloat16)])
+@pytest.mark.parametrize(
+    ("precision", "expected"),
+    [
+        ("16-mixed", (torch.float32, torch.float16, torch.float16)),
+        ("bf16-mixed", (torch.float32, torch.bfloat16, torch.bfloat16)),
+        ("16-true", (torch.float16, torch.float16, torch.float16)),
+        ("bf16-true", (torch.bfloat16, torch.bfloat16, torch.bfloat16)),
+    ],
+)
 def test_precision_plugin_config(precision, expected):
     plugin = FSDPMixedPrecisionPlugin(precision=precision, device="cuda")
     config = plugin.mixed_precision_config
-    assert config.param_dtype == expected
-    assert config.buffer_dtype == expected
-    assert config.reduce_dtype == expected
+
+    assert config.param_dtype == expected[0]
+    assert config.buffer_dtype == expected[1]
+    assert config.reduce_dtype == expected[2]


 @RunIf(min_torch="1.12")
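As a quick way to sanity-check the new behavior outside the test suite, a minimal sketch mirroring the expectations above (import path taken from the Fabric plugin modified in this PR; requires the same torch/CUDA setup as the tests):

import torch
from lightning.fabric.plugins.precision.fsdp import FSDPPrecision

# "16-mixed" should now yield float32 parameters with float16 reduce/buffer dtypes.
config = FSDPPrecision(precision="16-mixed", device="cuda").mixed_precision_config
assert config.param_dtype == torch.float32
assert config.reduce_dtype == config.buffer_dtype == torch.float16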