Skip to content

Commit

Permalink
Fix multinode cloud component (#15965)
Browse files Browse the repository at this point in the history
* fix multinode cloud component
* add tests

(cherry picked from commit d21b899)
  • Loading branch information
justusschock authored and Borda committed Dec 9, 2022
1 parent d1509ad commit b9bb24a
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed multiprocessing breakpoint ([#15950](https://github.com/Lightning-AI/lightning/pull/15950))
- Fixed detection of a Lightning App running in debug mode ([#15951](https://github.com/Lightning-AI/lightning/pull/15951))
- Fixed `ImportError` on Multinode if package not present ([#15963](https://github.com/Lightning-AI/lightning/pull/15963))
- Fixed MultiNode Component to use separate cloud computes ([#15965](https://github.com/Lightning-AI/lightning/pull/15965))


## [1.8.3] - 2022-11-22

Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/components/multi_node/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def run(
*[
work_cls(
*work_args,
cloud_compute=cloud_compute,
cloud_compute=cloud_compute.clone(),
**work_kwargs,
parallel=True,
)
Expand Down
10 changes: 9 additions & 1 deletion src/lightning_app/utilities/packaging/cloud_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __post_init__(self) -> None:

# All `default` CloudCompute are identified in the same way.
if self._internal_id is None:
self._internal_id = "default" if self.name == "default" else uuid4().hex[:7]
self._internal_id = self._generate_id()

# Internal arguments for now.
self.preemptible = False
Expand Down Expand Up @@ -118,6 +118,14 @@ def id(self) -> Optional[str]:
def is_default(self) -> bool:
return self.name == "default"

def _generate_id(self):
    """Return the internal identifier to use for this object.

    The literal ``"default"`` is returned when ``self.name`` equals
    ``"default"``; otherwise the first seven hexadecimal characters of a
    freshly generated ``uuid4`` are returned.
    """
    if self.name == "default":
        return "default"
    return uuid4().hex[:7]

def clone(self):
    """Build a new object from this one's dict form with a regenerated id.

    The output of ``self.to_dict()`` is taken as the starting state, its
    ``"_internal_id"`` entry is overwritten with the result of
    ``self._generate_id()``, and the state is handed to ``self.from_dict``.
    """
    state = self.to_dict()
    # Replace only the identifier; every other entry is passed through as-is.
    state.update({"_internal_id": self._generate_id()})
    return self.from_dict(state)


def _verify_mount_root_dirs_are_unique(mounts: Union[None, Mount, List[Mount], Tuple[Mount]]) -> None:
if isinstance(mounts, (list, tuple, set)):
Expand Down
12 changes: 12 additions & 0 deletions tests/tests_app/components/multi_node/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from re import escape
from unittest import mock

import pytest
from tests_app.helpers.utils import no_warning_call
Expand All @@ -17,3 +18,14 @@ def run(self):

with no_warning_call(UserWarning, match=escape("You set MultiNode(num_nodes=1, ...)` but ")):
MultiNode(Work, num_nodes=1, cloud_compute=CloudCompute("gpu"))


@mock.patch("lightning_app.components.multi_node.base.is_running_in_cloud", mock.Mock(return_value=True))
def test_multi_node_separate_cloud_computes():
    """Check that every work in ``MultiNode.ws`` has its own ``_internal_id``.

    ``is_running_in_cloud`` is patched to return ``True`` for the duration of
    the test.
    """

    class Work(LightningWork):
        def run(self):
            pass

    multi_node = MultiNode(Work, num_nodes=2, cloud_compute=CloudCompute("gpu"))

    # A set drops duplicates, so equal sizes mean all ids are distinct.
    internal_ids = {work.cloud_compute._internal_id for work in multi_node.ws}
    assert len(internal_ids) == len(multi_node.ws)
18 changes: 18 additions & 0 deletions tests/tests_app/utilities/packaging/test_cloud_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,21 @@ def test_cloud_compute_with_non_unique_mount_root_dirs():

with pytest.raises(ValueError, match="Every Mount attached to a work must have a unique"):
CloudCompute("gpu", mounts=[mount_1, mount_2])


def test_cloud_compute_clone():
    """Check that ``clone`` yields a ``CloudCompute`` differing only in ``_internal_id``."""
    original = CloudCompute("gpu")
    cloned = original.clone()

    assert isinstance(cloned, CloudCompute)

    original_state = original.to_dict()
    cloned_state = cloned.to_dict()

    assert len(original_state) == len(cloned_state)

    for key, value in original_state.items():
        if key == "_internal_id":
            # The identifier is the one entry that must change.
            assert value != cloned_state[key]
        else:
            assert value == cloned_state[key]

0 comments on commit b9bb24a

Please sign in to comment.