Skip to content

Commit

Permalink
Fix num_nodes not set for FSDPStrategy (#17438)
Browse files Browse the repository at this point in the history
(cherry picked from commit 2e5a7f9)
  • Loading branch information
awaelchli authored and lantiga committed Apr 24, 2023
1 parent 8b11e99 commit 8c9cf00
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed issue where `Model.load_from_checkpoint("checkpoint.ckpt", map_location=map_location)` would always return model on CPU ([#17308](https://github.com/Lightning-AI/lightning/pull/17308))

- Fixed an issue that caused `num_nodes` not to be set correctly for `FSDPStrategy` ([#17438](https://github.com/Lightning-AI/lightning/pull/17438))


## [2.0.1] - 2023-03-30
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def _lazy_init_strategy(self) -> None:
else:
self.strategy.parallel_devices = self._parallel_devices
if hasattr(self.strategy, "num_nodes"):
self.strategy._num_nodes = self._num_nodes_flag
self.strategy.num_nodes = self._num_nodes_flag
if hasattr(self.strategy, "_layer_sync"):
self.strategy._layer_sync = self._layer_sync
if hasattr(self.strategy, "set_world_ranks"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -971,3 +971,17 @@ def _mock_tpu_available(value):
assert isinstance(connector.strategy.cluster_environment, XLAEnvironment)
assert connector.strategy._start_method == "fork"
assert connector.strategy.launcher.is_interactive_compatible


@pytest.mark.parametrize(
"strategy",
[
"ddp",
"ddp_spawn",
pytest.param("deepspeed", marks=RunIf(deepspeed=True)),
pytest.param("fsdp", marks=RunIf(min_torch="1.12.0")),
],
)
def test_connector_sets_num_nodes(strategy, cuda_count_2):
trainer = Trainer(accelerator="cuda", strategy=strategy, devices=2, num_nodes=2)
assert trainer.strategy.num_nodes == 2

0 comments on commit 8c9cf00

Please sign in to comment.