Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feature: add _generate_works_json method #15767

Merged
merged 11 commits into from
Nov 22, 2022
189 changes: 115 additions & 74 deletions src/lightning_app/runners/cloud.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import fnmatch
import json
import os
import random
import string
Expand Down Expand Up @@ -70,6 +71,119 @@
logger = Logger(__name__)


def _get_work_specs(app: LightningApp) -> List[V1Work]:
    """Build a ``V1Work`` spec for every work in ``app`` that is started with the flow.

    Works with ``_start_with_flow`` set to ``False`` are skipped (they are started
    lazily later). Raises ``RuntimeError`` for drives whose protocol is not
    supported by the cloud dispatcher.
    """
    specs: List[V1Work] = []
    for work in app.works:
        _validate_build_spec_and_compute(work)

        if not work._start_with_flow:
            continue

        build_config = work.cloud_build_config
        build_spec = V1BuildSpec(
            commands=build_config.build_commands(),
            python_dependencies=V1PythonDependencyInfo(
                package_manager=V1PackageManager.PIP,
                packages="\n".join(build_config.requirements),
            ),
            image=build_config.image,
        )

        compute = work.cloud_compute
        compute_config = V1UserRequestedComputeConfig(
            name=compute.name,
            count=1,
            disk_size=compute.disk_size,
            preemptible=compute.preemptible,
            shm_size=compute.shm_size,
        )

        drives: List[V1LightningworkDrives] = []
        for attr_name in work._state:
            candidate = getattr(work, attr_name)
            if not isinstance(candidate, Drive):
                continue
            if candidate.protocol == "lit://":
                drive_type = V1DriveType.NO_MOUNT_S3
                source_type = V1SourceType.S3
            else:
                raise RuntimeError(
                    f"unknown drive protocol `{candidate.protocol}`. Please verify this "
                    f"drive type has been configured for use in the cloud dispatcher."
                )
            drives.append(
                V1LightningworkDrives(
                    drive=V1Drive(
                        metadata=V1Metadata(name=f"{work.name}.{attr_name}"),
                        spec=V1DriveSpec(
                            drive_type=drive_type,
                            source_type=source_type,
                            source=f"{candidate.protocol}{candidate.id}",
                        ),
                        status=V1DriveStatus(),
                    ),
                ),
            )

        # TODO: Move this to the CloudCompute class and update backend
        mounts = compute.mounts
        if mounts is not None:
            # A single Mount is accepted as a convenience; normalize to a list.
            for mount in ([mounts] if isinstance(mounts, Mount) else mounts):
                drives.append(
                    _create_mount_drive_spec(
                        work_name=work.name,
                        mount=mount,
                    )
                )

        # Random lowercase name for the network config entry of this work.
        network_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
        specs.append(
            V1Work(
                name=work.name,
                spec=V1LightningworkSpec(
                    build_spec=build_spec,
                    drives=drives,
                    user_requested_compute_config=compute_config,
                    network_config=[V1NetworkConfig(name=network_name, port=work.port)],
                ),
            )
        )

    return specs


def _to_clean_dict(swagger_object, map_attributes):
"""Returns the swagger object properties as a dict with correct object names."""

if hasattr(swagger_object, "to_dict"):
attribute_map = swagger_object.attribute_map
result = {}
for key in attribute_map.keys():
value = getattr(swagger_object, key)
value = _to_clean_dict(value, map_attributes)
if value is not None and value != {}:
key = attribute_map[key] if map_attributes else key
result[key] = value
return result
elif isinstance(swagger_object, list):
return [_to_clean_dict(x, map_attributes) for x in swagger_object]
elif isinstance(swagger_object, dict):
return {key: _to_clean_dict(value, map_attributes) for key, value in swagger_object.items()}
return swagger_object


def _generate_works_json(filepath: str, map_attributes: bool) -> str:
    """Load the app defined in ``filepath`` and serialize its work specs as compact JSON.

    ``map_attributes`` selects camelCase swagger names (True) or python
    attribute names (False) for the JSON keys.
    """
    loaded_app = CloudRuntime.load_app_from_file(filepath)
    specs = _get_work_specs(loaded_app)
    return json.dumps(_to_clean_dict(specs, map_attributes), separators=(",", ":"))


def _generate_works_json_web(filepath: str) -> str:
    """Serialize the app's work specs using camelCase (web/API) attribute names."""
    return _generate_works_json(filepath, map_attributes=True)


def _generate_works_json_gallery(filepath: str) -> str:
    """Serialize the app's work specs using python (snake_case) attribute names."""
    return _generate_works_json(filepath, map_attributes=False)


@dataclass
class CloudRuntime(Runtime):

Expand Down Expand Up @@ -141,80 +255,7 @@ def dispatch(
if not ENABLE_PUSHING_STATE_ENDPOINT:
v1_env_vars.append(V1EnvVar(name="ENABLE_PUSHING_STATE_ENDPOINT", value="0"))

works: List[V1Work] = []
for work in self.app.works:
_validate_build_spec_and_compute(work)

if not work._start_with_flow:
continue

work_requirements = "\n".join(work.cloud_build_config.requirements)
build_spec = V1BuildSpec(
commands=work.cloud_build_config.build_commands(),
python_dependencies=V1PythonDependencyInfo(
package_manager=V1PackageManager.PIP, packages=work_requirements
),
image=work.cloud_build_config.image,
)
user_compute_config = V1UserRequestedComputeConfig(
name=work.cloud_compute.name,
count=1,
disk_size=work.cloud_compute.disk_size,
preemptible=work.cloud_compute.preemptible,
shm_size=work.cloud_compute.shm_size,
)

drive_specs: List[V1LightningworkDrives] = []
for drive_attr_name, drive in [
(k, getattr(work, k)) for k in work._state if isinstance(getattr(work, k), Drive)
]:
if drive.protocol == "lit://":
drive_type = V1DriveType.NO_MOUNT_S3
source_type = V1SourceType.S3
else:
raise RuntimeError(
f"unknown drive protocol `{drive.protocol}`. Please verify this "
f"drive type has been configured for use in the cloud dispatcher."
)

drive_specs.append(
V1LightningworkDrives(
drive=V1Drive(
metadata=V1Metadata(
name=f"{work.name}.{drive_attr_name}",
),
spec=V1DriveSpec(
drive_type=drive_type,
source_type=source_type,
source=f"{drive.protocol}{drive.id}",
),
status=V1DriveStatus(),
),
mount_location=str(drive.root_folder),
),
)

# TODO: Move this to the CloudCompute class and update backend
if work.cloud_compute.mounts is not None:
mounts = work.cloud_compute.mounts
if isinstance(mounts, Mount):
mounts = [mounts]
for mount in mounts:
drive_specs.append(
_create_mount_drive_spec(
work_name=work.name,
mount=mount,
)
)

random_name = "".join(random.choice(string.ascii_lowercase) for _ in range(5))
work_spec = V1LightningworkSpec(
build_spec=build_spec,
drives=drive_specs,
user_requested_compute_config=user_compute_config,
network_config=[V1NetworkConfig(name=random_name, port=work.port)],
)
works.append(V1Work(name=work.name, spec=work_spec))
works: List[V1Work] = _get_work_specs(self.app)

# We need to collect a spec for each flow that contains a frontend so that the backend knows
# for which flows it needs to start servers by invoking the cli (see the serve_frontend() method below)
Expand Down
84 changes: 79 additions & 5 deletions tests/tests_app/runners/test_cloud.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
import os
import re
import sys
from copy import copy
from pathlib import Path
Expand Down Expand Up @@ -42,7 +43,11 @@

from lightning_app import _PROJECT_ROOT, BuildConfig, LightningApp, LightningWork
from lightning_app.runners import backends, cloud, CloudRuntime
from lightning_app.runners.cloud import _validate_build_spec_and_compute
from lightning_app.runners.cloud import (
_generate_works_json_gallery,
_generate_works_json_web,
_validate_build_spec_and_compute,
)
from lightning_app.storage import Drive, Mount
from lightning_app.testing.helpers import EmptyFlow, EmptyWork
from lightning_app.utilities.cloud import _get_project
Expand Down Expand Up @@ -644,7 +649,6 @@ def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch
),
status=V1DriveStatus(),
),
mount_location=str(tmpdir),
),
],
user_requested_compute_config=V1UserRequestedComputeConfig(
Expand Down Expand Up @@ -869,7 +873,6 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo
),
status=V1DriveStatus(),
),
mount_location=str(tmpdir),
)
lit_drive_2_spec = V1LightningworkDrives(
drive=V1Drive(
Expand All @@ -883,7 +886,6 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo
),
status=V1DriveStatus(),
),
mount_location=str(tmpdir),
)

# order of drives in the spec is non-deterministic, so there are two options
Expand Down Expand Up @@ -1103,7 +1105,6 @@ def test_call_with_work_app_and_attached_mount_and_drive(self, lightningapps, mo
),
status=V1DriveStatus(),
),
mount_location=str(tmpdir),
),
V1LightningworkDrives(
drive=V1Drive(
Expand Down Expand Up @@ -1282,6 +1283,79 @@ def test_load_app_from_file_mock_imports(tmpdir, lines):
os.remove(app_file)


@pytest.mark.parametrize(
    "generator,expected",
    [
        (
            _generate_works_json_web,
            [
                {
                    "name": "root.work",
                    "spec": {
                        "buildSpec": {
                            "commands": [],
                            "pythonDependencies": {"packageManager": "PACKAGE_MANAGER_PIP", "packages": ""},
                        },
                        "drives": [],
                        "networkConfig": [{"name": "*", "port": "*"}],
                        "userRequestedComputeConfig": {
                            "count": 1,
                            "diskSize": 0,
                            "name": "default",
                            "preemptible": "*",
                            "shmSize": 0,
                        },
                    },
                }
            ],
        ),
        (
            _generate_works_json_gallery,
            [
                {
                    "name": "root.work",
                    "spec": {
                        "build_spec": {
                            "commands": [],
                            "python_dependencies": {"package_manager": "PACKAGE_MANAGER_PIP", "packages": ""},
                        },
                        "drives": [],
                        "network_config": [{"name": "*", "port": "*"}],
                        "user_requested_compute_config": {
                            "count": 1,
                            "disk_size": 0,
                            "name": "default",
                            "preemptible": "*",
                            "shm_size": 0,
                        },
                    },
                }
            ],
        ),
    ],
)
@pytest.mark.skipif(sys.platform != "linux", reason="Causing conflicts on non-linux")
def test_generate_works_json(tmpdir, generator, expected):
    """Check the serialized work specs for a minimal app, in both naming styles."""
    # Snapshot sys.path so that loading the app file cannot leak import paths
    # into other tests, even if the assertion below fails.
    path = copy(sys.path)
    app_file = os.path.join(tmpdir, "app.py")

    try:
        with open(app_file, "w") as f:
            lines = [
                "from lightning_app import LightningApp",
                "from lightning_app.testing.helpers import EmptyWork",
                "app = LightningApp(EmptyWork())",
            ]
            f.write("\n".join(lines))

        works_string = generator(app_file)
        # Build a regex from the expected structure: the JSON output uses double
        # quotes and no whitespace; "*" placeholders match run-dependent values
        # (random network name, port, preemptible flag).
        expected = re.escape(str(expected).replace("'", '"').replace(" ", "")).replace('"\\*"', "(.*)")
        assert re.fullmatch(expected, works_string)
    finally:
        # Restore sys.path and remove the temporary app file unconditionally to
        # prevent conflicts with other tests (previously skipped on failure).
        sys.path = path
        if os.path.exists(app_file):
            os.remove(app_file)


def test_incompatible_cloud_compute_and_build_config():
"""Test that an exception is raised when a build config has a custom image defined, but the cloud compute is
the default.
Expand Down