diff --git a/docs/source-app/workflows/byoc/index.rst b/docs/source-app/workflows/byoc/index.rst index 2cabf046939ba..1134301e6dd43 100644 --- a/docs/source-app/workflows/byoc/index.rst +++ b/docs/source-app/workflows/byoc/index.rst @@ -63,7 +63,7 @@ Parameters ^^^^^^^^^^ +------------------------+----------------------------------------------------------------------------------------------------+ -|Parameter | Descritption | +|Parameter | Description | +========================+====================================================================================================+ | provider | The cloud provider where your cluster is located. | | | | @@ -78,18 +78,7 @@ Parameters +------------------------+----------------------------------------------------------------------------------------------------+ | region | AWS region containing compute resources | +------------------------+----------------------------------------------------------------------------------------------------+ -| instance-types | Instance types that you want to support, for computer jobs within the cluster. | -| | | -| | For now, this is the AWS instance types supported by the cluster. | -+------------------------+----------------------------------------------------------------------------------------------------+ -| enable-performance | Specifies if the cluster uses cost savings mode. | -| | | -| | In cost saving mode the number of compute nodes is reduced to one, reducing the cost for clusters | -| | with low utilization. | -+------------------------+----------------------------------------------------------------------------------------------------+ -| edit-before-creation | Enables interactive editing of requests before submitting it to Lightning AI. | -+------------------------+----------------------------------------------------------------------------------------------------+ -| wait | Waits for the cluster to be in a RUNNING state. Only use this for debugging. | +| async | Cluster creation will happen in the background. | +------------------------+----------------------------------------------------------------------------------------------------+ ---- diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index e0357d8a28f41..c143a5fc7154e 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed - The `MultiNode` components now warn the user when running with `num_nodes > 1` locally ([#15806](https://github.com/Lightning-AI/lightning/pull/15806)) +- Cluster creation and deletion now waits by default [#15458](https://github.com/Lightning-AI/lightning/pull/15458) ### Deprecated diff --git a/src/lightning_app/cli/cmd_clusters.py b/src/lightning_app/cli/cmd_clusters.py index e76b9c0695842..c1baf11c5e273 100644 --- a/src/lightning_app/cli/cmd_clusters.py +++ b/src/lightning_app/cli/cmd_clusters.py @@ -3,9 +3,10 @@ import time from datetime import datetime from textwrap import dedent -from typing import Any, List +from typing import Any, List, Union import click +import lightning_cloud from lightning_cloud.openapi import ( Externalv1Cluster, V1AWSClusterDriverSpec, @@ -15,8 +16,10 @@ V1ClusterState, V1ClusterType, V1CreateClusterRequest, + V1GetClusterResponse, V1KubernetesClusterDriver, ) +from lightning_utilities.core.enums import StrEnum from rich.console import Console from rich.table import Table from rich.text import Text @@ -25,10 +28,26 @@ from lightning_app.utilities.network import LightningClient from lightning_app.utilities.openapi import create_openapi_object, string2dict -CLUSTER_STATE_CHECKING_TIMEOUT = 60 MAX_CLUSTER_WAIT_TIME = 5400 +class ClusterState(StrEnum): + UNSPECIFIED = "unspecified" + QUEUED = "queued" + PENDING = "pending" + RUNNING = "running" + FAILED = "error" + DELETED = "deleted" + + def __str__(self) -> str: + return str(self.value) + + @classmethod + def from_api(cls, status: V1ClusterState) -> "ClusterState": + parsed = str(status).lower().split("_", maxsplit=2)[-1] + return cls.from_str(parsed) + + class ClusterList(Formatable): def __init__(self, clusters: List[Externalv1Cluster]): self.clusters = clusters @@ -86,7 +105,7 @@ def create( region: str = "us-east-1", external_id: str = None, edit_before_creation: bool = False, - wait: bool = False, + do_async: bool = False, ) -> None: """request Lightning AI BYOC compute cluster creation. @@ -97,7 +116,7 @@ def create( region: AWS region containing compute resources external_id: AWS IAM Role external ID edit_before_creation: Enables interactive editing of requests before submitting it to Lightning AI. - wait: Waits for the cluster to be in a RUNNING state. Only use this for debugging. + do_async: Triggers cluster creation in the background and exits """ performance_profile = V1ClusterPerformanceProfile.DEFAULT if cost_savings: @@ -130,22 +149,31 @@ def create( click.echo("cluster unchanged") resp = self.api_client.cluster_service_create_cluster(body=new_body) - if wait: - _wait_for_cluster_state(self.api_client, resp.id, V1ClusterState.RUNNING) - click.echo( dedent( f"""\ - {resp.id} is now being created... This can take up to an hour. + BYOC cluster creation triggered successfully! + This can take up to an hour to complete. To view the status of your clusters use: - `lightning list clusters` + lightning list clusters To view cluster logs use: - `lightning show cluster logs {resp.id}` - """ + lightning show cluster logs {cluster_name} + + To delete the cluster run: + lightning delete cluster {cluster_name} + """ ) ) + background_message = "\nCluster will be created in the background!" + if do_async: + click.echo(background_message) + else: + try: + _wait_for_cluster_state(self.api_client, resp.id, V1ClusterState.RUNNING) + except KeyboardInterrupt: + click.echo(background_message) def get_clusters(self) -> ClusterList: resp = self.api_client.cluster_service_list_clusters(phase_not_in=[V1ClusterState.DELETED]) @@ -156,7 +184,7 @@ def list(self) -> None: console = Console() console.print(clusters.as_table()) - def delete(self, cluster_id: str, force: bool = False, wait: bool = False) -> None: + def delete(self, cluster_id: str, force: bool = False, do_async: bool = False) -> None: if force: click.echo( """ @@ -167,47 +195,86 @@ def delete(self, cluster_id: str, force: bool = False, wait: bool = False) -> No ) click.confirm("Do you want to continue?", abort=True) + resp: V1GetClusterResponse = self.api_client.cluster_service_get_cluster(id=cluster_id) + bucket_name = resp.spec.driver.kubernetes.aws.bucket_name + self.api_client.cluster_service_delete_cluster(id=cluster_id, force=force) - click.echo("Cluster deletion triggered successfully") + click.echo( + dedent( + f"""\ + Cluster deletion triggered successfully + + For safety purposes we will not delete anything in the S3 bucket associated with the cluster: + {bucket_name} - if wait: - _wait_for_cluster_state(self.api_client, cluster_id, V1ClusterState.DELETED) + You may want to delete it manually using the AWS CLI: + aws s3 rb --force s3://{bucket_name} + """ + ) + ) + + background_message = "\nCluster will be deleted in the background!" + if do_async: + click.echo(background_message) + else: + try: + _wait_for_cluster_state(self.api_client, cluster_id, V1ClusterState.DELETED) + except KeyboardInterrupt: + click.echo(background_message) def _wait_for_cluster_state( api_client: LightningClient, cluster_id: str, target_state: V1ClusterState, - max_wait_time: int = MAX_CLUSTER_WAIT_TIME, - check_timeout: int = CLUSTER_STATE_CHECKING_TIMEOUT, + timeout_seconds: int = MAX_CLUSTER_WAIT_TIME, + poll_duration_seconds: int = 10, ) -> None: """_wait_for_cluster_state waits until the provided cluster has reached a desired state, or failed. + Messages will be displayed to the user as the cluster changes state. + We poll the API server for any changes + Args: api_client: LightningClient used for polling cluster_id: Specifies the cluster to wait for target_state: Specifies the desired state the target cluster needs to meet - max_wait_time: Maximum duration to wait (in seconds) - check_timeout: duration between polling for the cluster state (in seconds) + timeout_seconds: Maximum duration to wait + poll_duration_seconds: duration between polling for the cluster state """ start = time.time() elapsed = 0 - while elapsed < max_wait_time: - cluster_resp = api_client.cluster_service_list_clusters() - new_cluster = None - for clust in cluster_resp.clusters: - if clust.id == cluster_id: - new_cluster = clust - break - if new_cluster is not None: - if new_cluster.status.phase == target_state: + + click.echo(f"Waiting for cluster to be {ClusterState.from_api(target_state)}...") + while elapsed < timeout_seconds: + try: + resp: V1GetClusterResponse = api_client.cluster_service_get_cluster(id=cluster_id) + click.echo(_cluster_status_long(cluster=resp, desired_state=target_state, elapsed=elapsed)) + if resp.status.phase == target_state: break - elif new_cluster.status.phase == V1ClusterState.FAILED: - raise click.ClickException(f"Cluster {cluster_id} is in failed state.") - time.sleep(check_timeout) - elapsed = int(time.time() - start) + time.sleep(poll_duration_seconds) + elapsed = int(time.time() - start) + except lightning_cloud.openapi.rest.ApiException as e: + if e.status == 404 and target_state == V1ClusterState.DELETED: + return + raise else: - raise click.ClickException("Max wait time elapsed") + state_str = ClusterState.from_api(target_state) + raise click.ClickException( + dedent( + f"""\ + The cluster has not entered the {state_str} state within {_format_elapsed_seconds(timeout_seconds)}. + + The cluster may eventually be {state_str} afterwards, please check its status using: + lighting list clusters + + To view cluster logs use: + lightning show cluster logs {cluster_id} + + Contact support@lightning.ai for additional help + """ + ) + ) def _check_cluster_name_is_valid(_ctx: Any, _param: Any, value: str) -> str: @@ -219,3 +286,76 @@ def _check_cluster_name_is_valid(_ctx: Any, _param: Any, value: str) -> str: Provide a cluster name using valid characters and try again.""" ) return value + + +def _cluster_status_long(cluster: V1GetClusterResponse, desired_state: V1ClusterState, elapsed: float) -> str: + """Echos a long-form status message to the user about the cluster state. + + Args: + cluster: The cluster object + elapsed: Seconds since we've started polling + """ + + cluster_name = cluster.name + current_state = cluster.status.phase + current_reason = cluster.status.reason + bucket_name = cluster.spec.driver.kubernetes.aws.bucket_name + + duration = _format_elapsed_seconds(elapsed) + + if current_state == V1ClusterState.FAILED: + return dedent( + f"""\ + The requested cluster operation for cluster {cluster_name} has errors: + {current_reason} + + --- + We are automatically retrying, and an automated alert has been created + + WARNING: Any non-deleted cluster may be using resources. + To avoid incuring cost on your cloud provider, delete the cluster using the following command: + lightning delete cluster {cluster_name} + + Contact support@lightning.ai for additional help + """ + ) + + if desired_state == current_state == V1ClusterState.RUNNING: + return dedent( + f"""\ + Cluster {cluster_name} is now running and ready to use. + To launch an app on this cluster use: lightning run app app.py --cloud --cluster-id {cluster_name} + """ + ) + + if desired_state == V1ClusterState.RUNNING: + return f"Cluster {cluster_name} is being created [elapsed={duration}]" + + if desired_state == current_state == V1ClusterState.DELETED: + return dedent( + f"""\ + Cluster {cluster_name} has been successfully deleted, and almost all AWS resources have been removed + + For safety purposes we kept the S3 bucket associated with the cluster: {bucket_name} + + You may want to delete it manually using the AWS CLI: + aws s3 rb --force s3://{bucket_name} + """ + ) + + if desired_state == V1ClusterState.DELETED: + return f"Cluster {cluster_name} is being deleted [elapsed={duration}]" + + raise click.ClickException(f"Unknown cluster desired state {desired_state}") + + +def _format_elapsed_seconds(seconds: Union[float, int]) -> str: + """Turns seconds into a duration string. + + >>> _format_elapsed_seconds(5) + '05s' + >>> _format_elapsed_seconds(60) + '01m00s' + """ + minutes, seconds = divmod(seconds, 60) + return (f"{minutes:02}m" if minutes else "") + f"{seconds:02}s" diff --git a/src/lightning_app/cli/lightning_cli_create.py b/src/lightning_app/cli/lightning_cli_create.py index 75803056a85ce..6ff602b451459 100644 --- a/src/lightning_app/cli/lightning_cli_create.py +++ b/src/lightning_app/cli/lightning_cli_create.py @@ -37,6 +37,7 @@ def create() -> None: type=bool, required=False, default=False, + hidden=True, is_flag=True, help=""""Use this flag to ensure that the cluster is created with a profile that is optimized for performance. This makes runs more expensive but start-up times decrease.""", @@ -45,16 +46,17 @@ def create() -> None: "--edit-before-creation", default=False, is_flag=True, + hidden=True, help="Edit the cluster specs before submitting them to the API server.", ) @click.option( - "--wait", - "wait", + "--async", + "do_async", type=bool, required=False, default=False, is_flag=True, - help="Enabling this flag makes the CLI wait until the cluster is running.", + help="This flag makes the CLI return immediately and lets the cluster creation happen in the background.", ) def create_cluster( cluster_name: str, @@ -64,7 +66,7 @@ def create_cluster( provider: str, edit_before_creation: bool, enable_performance: bool, - wait: bool, + do_async: bool, **kwargs: Any, ) -> None: """Create a Lightning AI BYOC compute cluster with your cloud provider credentials.""" @@ -79,7 +81,7 @@ def create_cluster( external_id=external_id, edit_before_creation=edit_before_creation, cost_savings=not enable_performance, - wait=wait, + do_async=do_async, ) diff --git a/src/lightning_app/cli/lightning_cli_delete.py b/src/lightning_app/cli/lightning_cli_delete.py index bbe2508c27d3a..cf9915cd6f832 100644 --- a/src/lightning_app/cli/lightning_cli_delete.py +++ b/src/lightning_app/cli/lightning_cli_delete.py @@ -13,27 +13,15 @@ def delete() -> None: @delete.command("cluster") @click.argument("cluster", type=str) @click.option( - "--force", - "force", + "--async", + "do_async", type=bool, required=False, default=False, is_flag=True, - help="""Delete a BYOC cluster from Lightning AI. This does NOT delete any resources created by the cluster, - it just removes the entry from Lightning AI. - - WARNING: You should NOT use this under normal circumstances.""", -) -@click.option( - "--wait", - "wait", - type=bool, - required=False, - default=False, - is_flag=True, - help="Enabling this flag makes the CLI wait until the cluster is deleted.", + help="This flag makes the CLI return immediately and lets the cluster deletion happen in the background", ) -def delete_cluster(cluster: str, force: bool = False, wait: bool = False) -> None: +def delete_cluster(cluster: str, force: bool = False, do_async: bool = False) -> None: """Delete a Lightning AI BYOC cluster and all associated cloud provider resources. Deleting a cluster also deletes all apps that were started on the cluster. @@ -49,7 +37,7 @@ def delete_cluster(cluster: str, force: bool = False, wait: bool = False) -> Non VPC components, etc. are irreversibly deleted and cannot be recovered! """ cluster_manager = AWSClusterManager() - cluster_manager.delete(cluster_id=cluster, force=force, wait=wait) + cluster_manager.delete(cluster_id=cluster, force=force, do_async=do_async) @delete.command("ssh-key") diff --git a/tests/tests_app/cli/test_cli.py b/tests/tests_app/cli/test_cli.py index c3f5085d9c322..d5a3c4780a248 100644 --- a/tests/tests_app/cli/test_cli.py +++ b/tests/tests_app/cli/test_cli.py @@ -106,14 +106,7 @@ def test_main_lightning_cli_help(): @mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) @mock.patch("lightning_app.cli.cmd_clusters.AWSClusterManager.create") -@pytest.mark.parametrize( - "extra_arguments,expected_cost_savings_mode", - [ - ([], True), - (["--enable-performance"], False), - ], -) -def test_create_cluster(create_command: mock.MagicMock, extra_arguments, expected_cost_savings_mode): +def test_create_cluster(create_command: mock.MagicMock): runner = CliRunner() runner.invoke( create_cluster, @@ -125,8 +118,7 @@ def test_create_cluster(create_command: mock.MagicMock, extra_arguments, expecte "dummy", "--role-arn", "arn:aws:iam::1234567890:role/lai-byoc", - ] - + extra_arguments, + ], ) create_command.assert_called_once_with( @@ -135,8 +127,8 @@ def test_create_cluster(create_command: mock.MagicMock, extra_arguments, expecte role_arn="arn:aws:iam::1234567890:role/lai-byoc", external_id="dummy", edit_before_creation=False, - cost_savings=expected_cost_savings_mode, - wait=False, + cost_savings=True, + do_async=False, ) @@ -164,7 +156,7 @@ def test_delete_cluster(delete: mock.MagicMock): runner = CliRunner() runner.invoke(delete_cluster, ["test-7"]) - delete.assert_called_once_with(cluster_id="test-7", force=False, wait=False) + delete.assert_called_once_with(cluster_id="test-7", force=False, do_async=False) @mock.patch("lightning_app.utilities.login.Auth._run_server") diff --git a/tests/tests_app/cli/test_cmd_clusters.py b/tests/tests_app/cli/test_cmd_clusters.py index 92df6c172c9f0..a39610f0bb4a8 100644 --- a/tests/tests_app/cli/test_cmd_clusters.py +++ b/tests/tests_app/cli/test_cmd_clusters.py @@ -4,7 +4,6 @@ import click import pytest from lightning_cloud.openapi import ( - Externalv1Cluster, V1AWSClusterDriverSpec, V1ClusterDriver, V1ClusterPerformanceProfile, @@ -13,36 +12,58 @@ V1ClusterStatus, V1ClusterType, V1CreateClusterRequest, + V1GetClusterResponse, V1KubernetesClusterDriver, - V1ListClustersResponse, ) from lightning_app.cli import cmd_clusters from lightning_app.cli.cmd_clusters import AWSClusterManager +@pytest.fixture(params=[True, False]) +def async_or_interrupt(request, monkeypatch): + # Simulate hitting ctrl-c immediately while waiting for cluster to create + if not request.param: + monkeypatch.setattr(cmd_clusters, "_wait_for_cluster_state", mock.MagicMock(side_effect=KeyboardInterrupt)) + return request.param + + +@pytest.fixture +def spec(): + return V1ClusterSpec( + driver=V1ClusterDriver( + kubernetes=V1KubernetesClusterDriver( + aws=V1AWSClusterDriverSpec( + bucket_name="test-bucket", + ), + ), + ), + ) + + class FakeLightningClient: - def __init__(self, list_responses=[], consume=True): - self.list_responses = list_responses - self.list_call_count = 0 + def __init__(self, get_responses=[], consume=True): + self.get_responses = get_responses + self.get_call_count = 0 self.consume = consume - def cluster_service_list_clusters(self, phase_not_in=None): - self.list_call_count = self.list_call_count + 1 + def cluster_service_get_cluster(self, id: str): + self.get_call_count = self.get_call_count + 1 if self.consume: - return self.list_responses.pop() - return self.list_responses[0] + return self.get_responses.pop(0) + return self.get_responses[0] @mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) @mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_create_cluster") -def test_create_cluster(api: mock.MagicMock): +def test_create_cluster_api(api: mock.MagicMock, async_or_interrupt): cluster_manager = AWSClusterManager() cluster_manager.create( cluster_name="test-7", external_id="dummy", role_arn="arn:aws:iam::1234567890:role/lai-byoc", region="us-west-2", + do_async=async_or_interrupt, ) api.assert_called_once_with( @@ -76,11 +97,13 @@ def test_list_clusters(api: mock.MagicMock): @mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock()) @mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_delete_cluster") -def test_delete_cluster(api: mock.MagicMock): +@mock.patch("lightning_app.utilities.network.LightningClient.cluster_service_get_cluster") +def test_delete_cluster_api(api_get: mock.MagicMock, api_delete: mock.MagicMock, async_or_interrupt, spec): + api_get.return_value = V1GetClusterResponse(spec=spec) cluster_manager = AWSClusterManager() - cluster_manager.delete(cluster_id="test-7") + cluster_manager.delete(cluster_id="test-7", do_async=async_or_interrupt) - api.assert_called_once_with(id="test-7", force=False) + api_delete.assert_called_once_with(id="test-7", force=False) class Test_check_cluster_name_is_valid: @@ -104,32 +127,40 @@ class Test_wait_for_cluster_state: @pytest.mark.parametrize( "previous_state", [V1ClusterState.QUEUED, V1ClusterState.PENDING, V1ClusterState.UNSPECIFIED] ) - def test_happy_path(self, target_state, previous_state): + def test_happy_path(self, target_state, previous_state, spec): client = FakeLightningClient( - list_responses=[ - V1ListClustersResponse( - clusters=[Externalv1Cluster(id="test-cluster", status=V1ClusterStatus(phase=state))] + get_responses=[ + V1GetClusterResponse( + id="test-cluster", + status=V1ClusterStatus(phase=state), + spec=spec, ) for state in [previous_state, target_state] ] ) - cmd_clusters._wait_for_cluster_state(client, "test-cluster", target_state, check_timeout=0.1) - assert client.list_call_count == 1 + cmd_clusters._wait_for_cluster_state(client, "test-cluster", target_state, poll_duration_seconds=0.1) + assert client.get_call_count == 2 @pytest.mark.parametrize("target_state", [V1ClusterState.RUNNING, V1ClusterState.DELETED]) - def test_times_out(self, target_state): + def test_times_out(self, target_state, spec): client = FakeLightningClient( - list_responses=[ - V1ListClustersResponse( - clusters=[ - Externalv1Cluster(id="test-cluster", status=V1ClusterStatus(phase=V1ClusterState.UNSPECIFIED)) - ] + get_responses=[ + V1GetClusterResponse( + id="test-cluster", + status=V1ClusterStatus(phase=V1ClusterState.UNSPECIFIED), + spec=spec, ) ], consume=False, ) with pytest.raises(click.ClickException) as e: cmd_clusters._wait_for_cluster_state( - client, "test-cluster", target_state, max_wait_time=0.4, check_timeout=0.2 + client, "test-cluster", target_state, timeout_seconds=0.1, poll_duration_seconds=0.1 ) - assert "Max wait time elapsed" in str(e.value) + + if target_state == V1ClusterState.DELETED: + expected_state = "deleted" + if target_state == V1ClusterState.RUNNING: + expected_state = "running" + + assert e.match(f"The cluster has not entered the {expected_state} state")