Skip to content

Commit

Permalink
feature: add _generate_works_json method (#15767)
Browse files Browse the repository at this point in the history
(cherry picked from commit 51bb845)
  • Loading branch information
yurijmikhalevich authored and Borda committed Nov 30, 2022
1 parent 24ffa59 commit 197c94d
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 79 deletions.
189 changes: 115 additions & 74 deletions src/lightning_app/runners/cloud.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import fnmatch
import json
import os
import random
import string
Expand Down Expand Up @@ -70,6 +71,119 @@
logger = Logger(__name__)


def _get_work_specs(app: LightningApp) -> List[V1Work]:
    """Build a ``V1Work`` spec for every work in *app* that is started with the flow.

    Works whose ``_start_with_flow`` flag is False are validated but excluded
    from the returned list.
    """
    specs: List[V1Work] = []
    for work in app.works:
        _validate_build_spec_and_compute(work)

        if not work._start_with_flow:
            continue

        build_config = work.cloud_build_config
        build_spec = V1BuildSpec(
            commands=build_config.build_commands(),
            python_dependencies=V1PythonDependencyInfo(
                package_manager=V1PackageManager.PIP,
                # requirements are shipped as one newline-separated string
                packages="\n".join(build_config.requirements),
            ),
            image=build_config.image,
        )

        compute = work.cloud_compute
        compute_config = V1UserRequestedComputeConfig(
            name=compute.name,
            count=1,
            disk_size=compute.disk_size,
            preemptible=compute.preemptible,
            shm_size=compute.shm_size,
        )

        drives: List[V1LightningworkDrives] = []
        for attr_name in work._state:
            drive = getattr(work, attr_name)
            if not isinstance(drive, Drive):
                continue
            if drive.protocol != "lit://":
                raise RuntimeError(
                    f"unknown drive protocol `{drive.protocol}`. Please verify this "
                    f"drive type has been configured for use in the cloud dispatcher."
                )
            drives.append(
                V1LightningworkDrives(
                    drive=V1Drive(
                        metadata=V1Metadata(
                            name=f"{work.name}.{attr_name}",
                        ),
                        spec=V1DriveSpec(
                            drive_type=V1DriveType.NO_MOUNT_S3,
                            source_type=V1SourceType.S3,
                            source=f"{drive.protocol}{drive.id}",
                        ),
                        status=V1DriveStatus(),
                    ),
                ),
            )

        # TODO: Move this to the CloudCompute class and update backend
        if compute.mounts is not None:
            mounts = [compute.mounts] if isinstance(compute.mounts, Mount) else compute.mounts
            for mount in mounts:
                drives.append(_create_mount_drive_spec(work_name=work.name, mount=mount))

        # random suffix keeps network-config names unique per dispatch
        net_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
        specs.append(
            V1Work(
                name=work.name,
                spec=V1LightningworkSpec(
                    build_spec=build_spec,
                    drives=drives,
                    user_requested_compute_config=compute_config,
                    network_config=[V1NetworkConfig(name=net_name, port=work.port)],
                ),
            )
        )

    return specs


def _to_clean_dict(swagger_object, map_attributes):
"""Returns the swagger object properties as a dict with correct object names."""

if hasattr(swagger_object, "to_dict"):
attribute_map = swagger_object.attribute_map
result = {}
for key in attribute_map.keys():
value = getattr(swagger_object, key)
value = _to_clean_dict(value, map_attributes)
if value is not None and value != {}:
key = attribute_map[key] if map_attributes else key
result[key] = value
return result
elif isinstance(swagger_object, list):
return [_to_clean_dict(x, map_attributes) for x in swagger_object]
elif isinstance(swagger_object, dict):
return {key: _to_clean_dict(value, map_attributes) for key, value in swagger_object.items()}
return swagger_object


def _generate_works_json(filepath: str, map_attributes: bool) -> str:
    """Load the app defined in *filepath* and return its work specs as compact JSON.

    *map_attributes* selects camelCase wire names (True) or python attribute
    names (False) in the serialized output.
    """
    app = CloudRuntime.load_app_from_file(filepath)
    cleaned = _to_clean_dict(_get_work_specs(app), map_attributes)
    return json.dumps(cleaned, separators=(",", ":"))


def _generate_works_json_web(filepath: str) -> str:
    """Return the app's work specs as JSON using camelCase (web/API) attribute names."""
    return _generate_works_json(filepath, map_attributes=True)


def _generate_works_json_gallery(filepath: str) -> str:
    """Return the app's work specs as JSON using python (gallery) attribute names."""
    return _generate_works_json(filepath, map_attributes=False)


@dataclass
class CloudRuntime(Runtime):

Expand Down Expand Up @@ -141,80 +255,7 @@ def dispatch(
if not ENABLE_PUSHING_STATE_ENDPOINT:
v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))

works: List[V1Work] = []
for work in self.app.works:
_validate_build_spec_and_compute(work)

if not work._start_with_flow:
continue

work_requirements = "\n".join(work.cloud_build_config.requirements)
build_spec = V1BuildSpec(
commands=work.cloud_build_config.build_commands(),
python_dependencies=V1PythonDependencyInfo(
package_manager=V1PackageManager.PIP, packages=work_requirements
),
image=work.cloud_build_config.image,
)
user_compute_config = V1UserRequestedComputeConfig(
name=work.cloud_compute.name,
count=1,
disk_size=work.cloud_compute.disk_size,
preemptible=work.cloud_compute.preemptible,
shm_size=work.cloud_compute.shm_size,
)

drive_specs: List[V1LightningworkDrives] = []
for drive_attr_name, drive in [
(k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
]:
if drive.protocol == "lit://":
drive_type = V1DriveType.NO_MOUNT_S3
source_type = V1SourceType.S3
else:
raise RuntimeError(
f"unknown drive protocol `{drive.protocol}`. Please verify this "
f"drive type has been configured for use in the cloud dispatcher."
)

drive_specs.append(
V1LightningworkDrives(
drive=V1Drive(
metadata=V1Metadata(
name=f"{work.name}.{drive_attr_name}",
),
spec=V1DriveSpec(
drive_type=drive_type,
source_type=source_type,
source=f"{drive.protocol}{drive.id}",
),
status=V1DriveStatus(),
),
mount_location=str(drive.root_folder),
),
)

# TODO: Move this to the CloudCompute class and update backend
if work.cloud_compute.mounts is not None:
mounts = work.cloud_compute.mounts
if isinstance(mounts, Mount):
mounts = [mounts]
for mount in mounts:
drive_specs.append(
_create_mount_drive_spec(
work_name=work.name,
mount=mount,
)
)

random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
work_spec = V1LightningworkSpec(
build_spec=build_spec,
drives=drive_specs,
user_requested_compute_config=user_compute_config,
network_config=[V1NetworkConfig(name=random_name, port=work.port)],
)
works.append(V1Work(name=work.name, spec=work_spec))
works: List[V1Work] = _get_work_specs(self.app)

# We need to collect a spec for each flow that contains a frontend so that the backend knows
# for which flows it needs to start servers by invoking the cli (see the serve_frontend() method below)
Expand Down
84 changes: 79 additions & 5 deletions tests/tests_app/runners/test_cloud.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import re
import sys
from copy import copy
from pathlib import Path
Expand Down Expand Up @@ -42,7 +43,11 @@

from lightning_app import _PROJECT_ROOT, BuildConfig, LightningApp, LightningWork
from lightning_app.runners import backends, cloud, CloudRuntime
from lightning_app.runners.cloud import _validate_build_spec_and_compute
from lightning_app.runners.cloud import (
_generate_works_json_gallery,
_generate_works_json_web,
_validate_build_spec_and_compute,
)
from lightning_app.storage import Drive, Mount
from lightning_app.testing.helpers import EmptyFlow, EmptyWork
from lightning_app.utilities.cloud import _get_project
Expand Down Expand Up @@ -644,7 +649,6 @@ def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch
),
status=V1DriveStatus(),
),
mount_location=str(tmpdir),
),
],
user_requested_compute_config=V1UserRequestedComputeConfig(
Expand Down Expand Up @@ -869,7 +873,6 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo
),
status=V1DriveStatus(),
),
mount_location=str(tmpdir),
)
lit_drive_2_spec = V1LightningworkDrives(
drive=V1Drive(
Expand All @@ -883,7 +886,6 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo
),
status=V1DriveStatus(),
),
mount_location=str(tmpdir),
)

# order of drives in the spec is non-deterministic, so there are two options
Expand Down Expand Up @@ -1103,7 +1105,6 @@ def test_call_with_work_app_and_attached_mount_and_drive(self, lightningapps, mo
),
status=V1DriveStatus(),
),
mount_location=str(tmpdir),
),
V1LightningworkDrives(
drive=V1Drive(
Expand Down Expand Up @@ -1282,6 +1283,79 @@ def test_load_app_from_file_mock_imports(tmpdir, lines):
os.remove(app_file)


@pytest.mark.parametrize(
    "generator,expected",
    [
        (
            # web variant: attribute_map applied -> camelCase keys
            _generate_works_json_web,
            [
                {
                    "name": "root.work",
                    "spec": {
                        "buildSpec": {
                            "commands": [],
                            "pythonDependencies": {"packageManager": "PACKAGE_MANAGER_PIP", "packages": ""},
                        },
                        "drives": [],
                        # "*" entries are wildcards: matched by regex below, not compared literally
                        "networkConfig": [{"name": "*", "port": "*"}],
                        "userRequestedComputeConfig": {
                            "count": 1,
                            "diskSize": 0,
                            "name": "default",
                            "preemptible": "*",
                            "shmSize": 0,
                        },
                    },
                }
            ],
        ),
        (
            # gallery variant: python attribute names (snake_case) are kept
            _generate_works_json_gallery,
            [
                {
                    "name": "root.work",
                    "spec": {
                        "build_spec": {
                            "commands": [],
                            "python_dependencies": {"package_manager": "PACKAGE_MANAGER_PIP", "packages": ""},
                        },
                        "drives": [],
                        "network_config": [{"name": "*", "port": "*"}],
                        "user_requested_compute_config": {
                            "count": 1,
                            "disk_size": 0,
                            "name": "default",
                            "preemptible": "*",
                            "shm_size": 0,
                        },
                    },
                }
            ],
        ),
    ],
)
@pytest.mark.skipif(sys.platform != "linux", reason="Causing conflicts on non-linux")
def test_generate_works_json(tmpdir, generator, expected):
    """Check that the works-JSON generators serialize a minimal app as expected."""
    # loading the app file mutates sys.path; snapshot it so we can restore below
    path = copy(sys.path)
    app_file = os.path.join(tmpdir, "app.py")

    with open(app_file, "w") as f:
        lines = [
            "from lightning_app import LightningApp",
            "from lightning_app.testing.helpers import EmptyWork",
            "app = LightningApp(EmptyWork())",
        ]
        f.write("\n".join(lines))

    works_string = generator(app_file)
    # Turn the expected structure into a regex: normalize to compact JSON
    # (double quotes, no spaces), escape it, then replace "*" placeholders
    # with wildcard groups for the random name / port / preemptible values.
    expected = re.escape(str(expected).replace("'", '"').replace(" ", "")).replace('"\\*"', "(.*)")
    assert re.fullmatch(expected, works_string)

    # Cleanup PATH to prevent conflict with other tests
    sys.path = path
    os.remove(app_file)


def test_incompatible_cloud_compute_and_build_config():
"""Test that an exception is raised when a build config has a custom image defined, but the cloud compute is
the default.
Expand Down

0 comments on commit 197c94d

Please sign in to comment.