Revert removal of empty-parameters check for configure_optimizers() with FSDP #18785

Merged 5 commits on Oct 12, 2023
src/lightning/pytorch/strategies/fsdp.py (13 changes: 12 additions & 1 deletion)
@@ -327,7 +327,18 @@ def setup(self, trainer: "pl.Trainer") -> None:
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
         if self.kwargs.get("use_orig_params"):
             return super().setup_optimizers(trainer)
-        if any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+
+        invalid_params_error = False
+        try:
+            # In PyTorch < 2.0, or if `use_orig_params=False`, the user needs to access
+            # `self.trainer.model.parameters()` in `configure_optimizers()`
+            super().setup_optimizers(trainer)
+        except ValueError as ex:
+            if "optimizer got an empty parameter list" not in str(ex):
+                raise
+            invalid_params_error = True
+
+        if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
             # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
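
For context, the hint in that error message refers to the pattern sketched below. This is a minimal illustrative example, not code from this PR; the class name and layer sizes are made up. With PyTorch < 2.0, or with `use_orig_params=False`, `configure_optimizers()` needs to reference the FSDP-wrapped module via `self.trainer.model.parameters()`, because `self.parameters()` can come up empty after wrapping, which is exactly the `ValueError("optimizer got an empty parameter list")` the new try/except intercepts.

import torch
import lightning.pytorch as pl


class DemoModel(pl.LightningModule):  # hypothetical model, not part of the PR
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # Reference the wrapped model so the optimizer sees FSDP's flat
        # parameters; with `use_orig_params=False` (or PyTorch < 2.0),
        # `self.parameters()` may not yield anything usable here.
        return torch.optim.Adam(self.trainer.model.parameters(), lr=1e-2)
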
tests/tests_pytorch/strategies/test_fsdp.py (18 changes: 15 additions & 3 deletions)
@@ -359,16 +359,22 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):


 @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True)
-def test_invalid_parameters_in_optimizer():
+@pytest.mark.parametrize("use_orig_params", [None, False, True])
+def test_invalid_parameters_in_optimizer(use_orig_params):
+    fsdp_kwargs = {}
+    if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not None:
+        fsdp_kwargs = {"use_orig_params": use_orig_params}
+
     trainer = Trainer(
-        strategy="fsdp",
+        strategy=FSDPStrategy(**fsdp_kwargs),
         accelerator="cuda",
         devices=1,
         fast_dev_run=1,
     )
+
     error_context = (
         nullcontext()
-        if _TORCH_GREATER_EQUAL_2_0
+        if _TORCH_GREATER_EQUAL_2_0 and (_TORCH_GREATER_EQUAL_2_1 or use_orig_params is not False)
         else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
     )

@@ -385,6 +391,12 @@ def configure_optimizers(self):
             layer = torch.nn.Linear(4, 5)
             return torch.optim.Adam(layer.parameters(), lr=1e-2)

+    error_context = (
+        nullcontext()
+        if _TORCH_GREATER_EQUAL_2_0 and use_orig_params is not False
+        else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
+    )
+
     model = NoFlatParametersModel()
     with error_context:
         trainer.fit(model)
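
As the comment in `setup_optimizers` notes, the limitation is avoided on PyTorch >= 2.0 by passing `use_orig_params=True` through to FSDP. A minimal sketch of opting in (assumes a single CUDA device; `MyModel` is a placeholder LightningModule):

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

# `use_orig_params=True` is forwarded to torch's FSDP wrapper, so
# `configure_optimizers()` can keep using `self.parameters()` directly.
trainer = Trainer(
    strategy=FSDPStrategy(use_orig_params=True),
    accelerator="cuda",
    devices=1,
)
# trainer.fit(MyModel())  # MyModel is a placeholder
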
@@ -388,6 +388,7 @@ def on_before_optimizer_step(self, optimizer, *_):

 def test_step_with_optimizer_closure(tmpdir):
     """Tests that `step` works with optimizer_closure."""
+    seed_everything(1)

     class TestModel(BoringModel):
         _losses = []
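
The last hunk only seeds `test_step_with_optimizer_closure` so the losses it records are deterministic. A minimal sketch of the utility in isolation (assuming the standard import path):

from lightning.pytorch import seed_everything

# Seeds Python's `random`, NumPy, and torch RNGs with the same value
# so repeated runs produce identical results.
seed_everything(1)
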