From b9bb24a086cc109ebcaf34fb7224d0f3927ed61d Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 9 Dec 2022 13:02:58 +0100 Subject: [PATCH] Fix multinode cloud component (#15965) * fix multinode cloud component * add tests (cherry picked from commit d21b8992eead8f544a41792e4ef40a2710423a62) --- src/lightning_app/CHANGELOG.md | 2 ++ .../components/multi_node/base.py | 2 +- .../utilities/packaging/cloud_compute.py | 10 +++++++++- .../components/multi_node/test_base.py | 12 ++++++++++++ .../utilities/packaging/test_cloud_compute.py | 18 ++++++++++++++++++ 5 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index c63af4ac40f0c..a25c8ecd9fe39 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -54,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed multiprocessing breakpoint ([#15950](https://github.com/Lightning-AI/lightning/pull/15950)) - Fixed detection of a Lightning App running in debug mode ([#15951](https://github.com/Lightning-AI/lightning/pull/15951)) - Fixed `ImportError` on Multinode if package not present ([#15963](https://github.com/Lightning-AI/lightning/pull/15963)) +- Fixed MultiNode Component to use separate cloud computes ([#15965](https://github.com/Lightning-AI/lightning/pull/15965)) + ## [1.8.3] - 2022-11-22 diff --git a/src/lightning_app/components/multi_node/base.py b/src/lightning_app/components/multi_node/base.py index ee4f2b3abd4fb..5662442b7375a 100644 --- a/src/lightning_app/components/multi_node/base.py +++ b/src/lightning_app/components/multi_node/base.py @@ -66,7 +66,7 @@ def run( *[ work_cls( *work_args, - cloud_compute=cloud_compute, + cloud_compute=cloud_compute.clone(), **work_kwargs, parallel=True, ) diff --git a/src/lightning_app/utilities/packaging/cloud_compute.py b/src/lightning_app/utilities/packaging/cloud_compute.py index f3b162ed042c6..ca6c9705ae866 100644 --- a/src/lightning_app/utilities/packaging/cloud_compute.py +++ b/src/lightning_app/utilities/packaging/cloud_compute.py @@ -82,7 +82,7 @@ def __post_init__(self) -> None: # All `default` CloudCompute are identified in the same way. if self._internal_id is None: - self._internal_id = "default" if self.name == "default" else uuid4().hex[:7] + self._internal_id = self._generate_id() # Internal arguments for now. self.preemptible = False @@ -118,6 +118,14 @@ def id(self) -> Optional[str]: def is_default(self) -> bool: return self.name == "default" + def _generate_id(self): + return "default" if self.name == "default" else uuid4().hex[:7] + + def clone(self): + new_dict = self.to_dict() + new_dict["_internal_id"] = self._generate_id() + return self.from_dict(new_dict) + def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], Tuple[Mount]]) -> None: if isinstance(mounts, (list, tuple, set)): diff --git a/tests/tests_app/components/multi_node/test_base.py b/tests/tests_app/components/multi_node/test_base.py index e23535fbfe970..2c6aed1120c0a 100644 --- a/tests/tests_app/components/multi_node/test_base.py +++ b/tests/tests_app/components/multi_node/test_base.py @@ -1,4 +1,5 @@ from re import escape +from unittest import mock import pytest from tests_app.helpers.utils import no_warning_call @@ -17,3 +18,14 @@ def run(self): with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")): MultiNode(Work, num_nodes=1, cloud_compute=CloudCompute("gpu")) + + +@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", mock.Mock(return_value=True)) +def test_multi_node_separate_cloud_computes(): + class Work(LightningWork): + def run(self): + pass + + m = MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu")) + + assert len({w.cloud_compute._internal_id for w in m.ws}) == len(m.ws) diff --git a/tests/tests_app/utilities/packaging/test_cloud_compute.py b/tests/tests_app/utilities/packaging/test_cloud_compute.py index aa0395aa5451a..f2670723f132a 100644 --- a/tests/tests_app/utilities/packaging/test_cloud_compute.py +++ b/tests/tests_app/utilities/packaging/test_cloud_compute.py @@ -41,3 +41,21 @@ def test_cloud_compute_with_non_unique_mount_root_dirs(): with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique"): CloudCompute("gpu", mounts=[mount_1, mount_2]) + + +def test_cloud_compute_clone(): + c1 = CloudCompute("gpu") + c2 = c1.clone() + + assert isinstance(c2, CloudCompute) + + c1_dict = c1.to_dict() + c2_dict = c2.to_dict() + + assert len(c1_dict) == len(c2_dict) + + for k in c1_dict.keys(): + if k == "_internal_id": + assert c1_dict[k] != c2_dict[k] + else: + assert c1_dict[k] == c2_dict[k]