From 26e9067396edde8af3c8272fc75b19a35de485f9 Mon Sep 17 00:00:00 2001 From: Aylei Date: Thu, 19 Dec 2024 00:07:39 +0800 Subject: [PATCH 01/11] feat: make per-cloud catalog lookup parallel Signed-off-by: Aylei --- sky/clouds/service_catalog/__init__.py | 10 ++++++---- sky/clouds/service_catalog/aws_catalog.py | 2 +- sky/clouds/service_catalog/common.py | 4 ++-- sky/optimizer.py | 7 +++++-- sky/utils/subprocess_utils.py | 15 ++++++++++++++- 5 files changed, 28 insertions(+), 10 deletions(-) diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index d28b530ff06..d503edb3a80 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -9,7 +9,7 @@ from sky.clouds.service_catalog.constants import CATALOG_DIR from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL -from sky.utils import resources_utils +from sky.utils import resources_utils, subprocess_utils if typing.TYPE_CHECKING: from sky.clouds import cloud @@ -31,8 +31,7 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs): if single: clouds = [clouds] # type: ignore - results = [] - for cloud in clouds: + def _execute_catalog_method(cloud: str): try: cloud_module = importlib.import_module( f'sky.clouds.service_catalog.{cloud.lower()}_catalog') @@ -46,7 +45,9 @@ def _map_clouds_catalog(clouds: CloudFilter, method_name: str, *args, **kwargs): raise AttributeError( f'Module "{cloud}_catalog" does not ' f'implement the "{method_name}" method') from None - results.append(method(*args, **kwargs)) + return method(*args, **kwargs) + + results = subprocess_utils.maybe_parallelize_cloud_operation(_execute_catalog_method, clouds) if single: return results[0] return results @@ -360,6 +361,7 @@ def is_image_tag_valid(tag: str, return _map_clouds_catalog(clouds, 'is_image_tag_valid', tag, region) + __all__ = [ 'list_accelerators', 'list_accelerator_counts', diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index bbd48863755..0557d2babae 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -101,7 +101,6 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']: return az_mappings -@timeline.event def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame': """Maps zone IDs (use1-az1) to zone names (us-east-1x). @@ -292,6 +291,7 @@ def get_region_zones_for_instance_type(instance_type: str, return us_region_list + other_region_list +@timeline.event def list_accelerators( gpus_only: bool, name_filter: Optional[str], diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index 67c6e09b27e..31f63f9af3f 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -5,7 +5,7 @@ import os import time import typing -from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union import filelock import requests @@ -15,7 +15,7 @@ from sky.clouds import cloud as cloud_lib from sky.clouds import cloud_registry from sky.clouds.service_catalog import constants -from sky.utils import common_utils +from sky.utils import common_utils, subprocess_utils from sky.utils import rich_utils from sky.utils import ux_utils diff --git a/sky/optimizer.py b/sky/optimizer.py index 2f70dd39429..1c94b00274f 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -1293,9 +1293,12 @@ def _fill_in_launchable_resources( if resources.cloud is not None else enabled_clouds) # If clouds provide hints, store them for later printing. hints: Dict[clouds.Cloud, str] = {} - for cloud in clouds_list: - feasible_resources = cloud.get_feasible_launchable_resources( + def _get_feasible_launchable_resources(cloud: clouds.Cloud) -> Tuple[clouds.Cloud, resources_lib.Resources]: + return cloud, cloud.get_feasible_launchable_resources( resources, num_nodes=task.num_nodes) + feasible_resources_list = subprocess_utils.maybe_parallelize_cloud_operation( + _get_feasible_launchable_resources, clouds_list) + for cloud, feasible_resources in feasible_resources_list: if feasible_resources.hint is not None: hints[cloud] = feasible_resources.hint if len(feasible_resources.resources_list) > 0: diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index 992c6bbe3ff..62030f20021 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -120,7 +120,6 @@ def run_in_parallel(func: Callable, # Run the function in parallel on the arguments, keeping the order. return list(p.imap(func, args)) - def handle_returncode(returncode: int, command: str, error_msg: Union[str, Callable[[], str]], @@ -293,3 +292,17 @@ def kill_process_daemon(process_pid: int) -> None: # Disable input stdin=subprocess.DEVNULL, ) + +def maybe_parallelize_cloud_operation(func: Callable, clouds: List[Any], + num_threads: Optional[int] = None) -> List[Any]: + """Apply a function to a list of clouds, with parallelism if there is more than one cloud.""" + count = len(clouds) + if count == 0: + return [] + # Short-circuit in single cloud setup. + if count == 1: + return [func(clouds[0])] + # Cloud operations are assumed to be IO-bound, so the parallelism is set to the number of clouds by default, + # we are still safe because the number of clouds is enumarable even if this assumption does not hold. + processes = num_threads if num_threads is not None else count + return run_in_parallel(func, clouds, processes) \ No newline at end of file From c9a2ce37d960129ec865fa97ca4cd90e6d822f0c Mon Sep 17 00:00:00 2001 From: Aylei Date: Thu, 19 Dec 2024 01:03:58 +0800 Subject: [PATCH 02/11] lint and test Signed-off-by: Aylei --- sky/clouds/service_catalog/__init__.py | 9 +++++---- sky/clouds/service_catalog/common.py | 4 ++-- sky/optimizer.py | 24 +++++++++++++++++------- sky/utils/subprocess_utils.py | 19 +++++++++++++------ 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index d503edb3a80..c871c382def 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -9,7 +9,8 @@ from sky.clouds.service_catalog.constants import CATALOG_DIR from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL -from sky.utils import resources_utils, subprocess_utils +from sky.utils import resources_utils +from sky.utils import subprocess_utils if typing.TYPE_CHECKING: from sky.clouds import cloud @@ -46,8 +47,9 @@ def _execute_catalog_method(cloud: str): f'Module "{cloud}_catalog" does not ' f'implement the "{method_name}" method') from None return method(*args, **kwargs) - - results = subprocess_utils.maybe_parallelize_cloud_operation(_execute_catalog_method, clouds) + + results = subprocess_utils.maybe_parallelize_cloud_operation( + _execute_catalog_method, clouds) # type: ignore if single: return results[0] return results @@ -361,7 +363,6 @@ def is_image_tag_valid(tag: str, return _map_clouds_catalog(clouds, 'is_image_tag_valid', tag, region) - __all__ = [ 'list_accelerators', 'list_accelerator_counts', diff --git a/sky/clouds/service_catalog/common.py b/sky/clouds/service_catalog/common.py index 31f63f9af3f..67c6e09b27e 100644 --- a/sky/clouds/service_catalog/common.py +++ b/sky/clouds/service_catalog/common.py @@ -5,7 +5,7 @@ import os import time import typing -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union import filelock import requests @@ -15,7 +15,7 @@ from sky.clouds import cloud as cloud_lib from sky.clouds import cloud_registry from sky.clouds.service_catalog import constants -from sky.utils import common_utils, subprocess_utils +from sky.utils import common_utils from sky.utils import rich_utils from sky.utils import ux_utils diff --git a/sky/optimizer.py b/sky/optimizer.py index 1c94b00274f..0c29ced7878 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -4,7 +4,7 @@ import enum import json import typing -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple import colorama import numpy as np @@ -1254,6 +1254,18 @@ def _check_specified_clouds(dag: 'dag_lib.Dag') -> None: f'{colorama.Fore.YELLOW}{msg}{colorama.Style.RESET_ALL}') +def _make_resource_finder( + resources: resources_lib.Resources, + num_nodes: int, +) -> Callable[[clouds.Cloud], Tuple[clouds.Cloud, resources_lib.Resources]]: + + def fn(cloud: clouds.Cloud) -> Tuple[clouds.Cloud, resources_lib.Resources]: + return cloud, cloud.get_feasible_launchable_resources( + resources, num_nodes) + + return fn + + def _fill_in_launchable_resources( task: task_lib.Task, blocked_resources: Optional[Iterable[resources_lib.Resources]], @@ -1293,12 +1305,10 @@ def _fill_in_launchable_resources( if resources.cloud is not None else enabled_clouds) # If clouds provide hints, store them for later printing. hints: Dict[clouds.Cloud, str] = {} - def _get_feasible_launchable_resources(cloud: clouds.Cloud) -> Tuple[clouds.Cloud, resources_lib.Resources]: - return cloud, cloud.get_feasible_launchable_resources( - resources, num_nodes=task.num_nodes) - feasible_resources_list = subprocess_utils.maybe_parallelize_cloud_operation( - _get_feasible_launchable_resources, clouds_list) - for cloud, feasible_resources in feasible_resources_list: + + feasible_list = subprocess_utils.maybe_parallelize_cloud_operation( + _make_resource_finder(resources, task.num_nodes), clouds_list) + for cloud, feasible_resources in feasible_list: if feasible_resources.hint is not None: hints[cloud] = feasible_resources.hint if len(feasible_resources.resources_list) > 0: diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index 62030f20021..b74b55fddab 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -120,6 +120,7 @@ def run_in_parallel(func: Callable, # Run the function in parallel on the arguments, keeping the order. return list(p.imap(func, args)) + def handle_returncode(returncode: int, command: str, error_msg: Union[str, Callable[[], str]], @@ -293,16 +294,22 @@ def kill_process_daemon(process_pid: int) -> None: stdin=subprocess.DEVNULL, ) -def maybe_parallelize_cloud_operation(func: Callable, clouds: List[Any], - num_threads: Optional[int] = None) -> List[Any]: - """Apply a function to a list of clouds, with parallelism if there is more than one cloud.""" + +def maybe_parallelize_cloud_operation( + func: Callable, + clouds: List[Any], + num_threads: Optional[int] = None) -> List[Any]: + """Apply a function to a list of clouds, + with parallelism if there is more than one cloud. + """ count = len(clouds) if count == 0: return [] # Short-circuit in single cloud setup. if count == 1: return [func(clouds[0])] - # Cloud operations are assumed to be IO-bound, so the parallelism is set to the number of clouds by default, - # we are still safe because the number of clouds is enumarable even if this assumption does not hold. + # Cloud operations are assumed to be IO-bound, so the parallelism is set to + # the number of clouds by default, we are still safe because the number of + # clouds is enumarable even if this assumption does not hold. processes = num_threads if num_threads is not None else count - return run_in_parallel(func, clouds, processes) \ No newline at end of file + return run_in_parallel(func, clouds, processes) From fb6c8ace43432595028e01d0c7fd887ac0de8b1f Mon Sep 17 00:00:00 2001 From: Aylei Date: Thu, 19 Dec 2024 11:00:20 +0800 Subject: [PATCH 03/11] address review comments Signed-off-by: Aylei --- sky/clouds/service_catalog/__init__.py | 4 ++-- sky/optimizer.py | 17 +++------------- sky/utils/subprocess_utils.py | 27 +++++++------------------- 3 files changed, 12 insertions(+), 36 deletions(-) diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index c871c382def..0698aef4fee 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -48,8 +48,8 @@ def _execute_catalog_method(cloud: str): f'implement the "{method_name}" method') from None return method(*args, **kwargs) - results = subprocess_utils.maybe_parallelize_cloud_operation( - _execute_catalog_method, clouds) # type: ignore + results = subprocess_utils.run_in_parallel(_execute_catalog_method, clouds, + len(clouds)) # type: ignore if single: return results[0] return results diff --git a/sky/optimizer.py b/sky/optimizer.py index 0c29ced7878..d0f2999256f 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -1254,18 +1254,6 @@ def _check_specified_clouds(dag: 'dag_lib.Dag') -> None: f'{colorama.Fore.YELLOW}{msg}{colorama.Style.RESET_ALL}') -def _make_resource_finder( - resources: resources_lib.Resources, - num_nodes: int, -) -> Callable[[clouds.Cloud], Tuple[clouds.Cloud, resources_lib.Resources]]: - - def fn(cloud: clouds.Cloud) -> Tuple[clouds.Cloud, resources_lib.Resources]: - return cloud, cloud.get_feasible_launchable_resources( - resources, num_nodes) - - return fn - - def _fill_in_launchable_resources( task: task_lib.Task, blocked_resources: Optional[Iterable[resources_lib.Resources]], @@ -1306,8 +1294,9 @@ def _fill_in_launchable_resources( # If clouds provide hints, store them for later printing. hints: Dict[clouds.Cloud, str] = {} - feasible_list = subprocess_utils.maybe_parallelize_cloud_operation( - _make_resource_finder(resources, task.num_nodes), clouds_list) + feasible_list = subprocess_utils.run_in_parallel( + lambda cloud, r=resources, n=task.num_nodes: + (cloud, cloud.get_feasible_launchable_resources(r, n)), clouds_list) for cloud, feasible_resources in feasible_list: if feasible_resources.hint is not None: hints[cloud] = feasible_resources.hint diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index b74b55fddab..dbacf33a7e2 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -1,4 +1,5 @@ """Utility functions for subprocesses.""" +import collections from multiprocessing import pool import os import random @@ -113,6 +114,12 @@ def run_in_parallel(func: Callable, A list of the return values of the function func, in the same order as the arguments. """ + if isinstance(args, collections.abc.Sized): + if len(args) == 0: + return [] + # Short-circuit for single element + if len(args) == 1: + return [func(next(iter(args)))] # Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long processes = num_threads if num_threads is not None else get_parallel_threads( ) @@ -293,23 +300,3 @@ def kill_process_daemon(process_pid: int) -> None: # Disable input stdin=subprocess.DEVNULL, ) - - -def maybe_parallelize_cloud_operation( - func: Callable, - clouds: List[Any], - num_threads: Optional[int] = None) -> List[Any]: - """Apply a function to a list of clouds, - with parallelism if there is more than one cloud. - """ - count = len(clouds) - if count == 0: - return [] - # Short-circuit in single cloud setup. - if count == 1: - return [func(clouds[0])] - # Cloud operations are assumed to be IO-bound, so the parallelism is set to - # the number of clouds by default, we are still safe because the number of - # clouds is enumarable even if this assumption does not hold. - processes = num_threads if num_threads is not None else count - return run_in_parallel(func, clouds, processes) From 3571b55a7b172ac65072776d508a47a29d09c77d Mon Sep 17 00:00:00 2001 From: Aylei Date: Thu, 19 Dec 2024 11:02:36 +0800 Subject: [PATCH 04/11] fix lint Signed-off-by: Aylei --- sky/optimizer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sky/optimizer.py b/sky/optimizer.py index d0f2999256f..d22029f1dc9 100644 --- a/sky/optimizer.py +++ b/sky/optimizer.py @@ -4,7 +4,7 @@ import enum import json import typing -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple import colorama import numpy as np @@ -1296,7 +1296,8 @@ def _fill_in_launchable_resources( feasible_list = subprocess_utils.run_in_parallel( lambda cloud, r=resources, n=task.num_nodes: - (cloud, cloud.get_feasible_launchable_resources(r, n)), clouds_list) + (cloud, cloud.get_feasible_launchable_resources(r, n)), + clouds_list) for cloud, feasible_resources in feasible_list: if feasible_resources.hint is not None: hints[cloud] = feasible_resources.hint From e8e13bbac13a970fae57f49f51083c09f5c02650 Mon Sep 17 00:00:00 2001 From: Aylei Date: Thu, 19 Dec 2024 20:26:59 +0800 Subject: [PATCH 05/11] feat: support sub thread status attaching Signed-off-by: Aylei --- sky/adaptors/oci.py | 4 ++ sky/clouds/service_catalog/__init__.py | 6 +- sky/clouds/service_catalog/aws_catalog.py | 1 + sky/task.py | 2 +- sky/utils/dag_utils.py | 2 +- sky/utils/rich_utils.py | 76 ++++++++++++++++++++--- 6 files changed, 76 insertions(+), 15 deletions(-) diff --git a/sky/adaptors/oci.py b/sky/adaptors/oci.py index 7a5fafa854a..9fcb63dce0d 100644 --- a/sky/adaptors/oci.py +++ b/sky/adaptors/oci.py @@ -1,9 +1,13 @@ """Oracle OCI cloud adaptor""" import os +import logging from sky.adaptors import common +# Get rid of circuit breaker info logs, which may mess up console status +logging.getLogger('oci.circuit_breaker').setLevel(logging.WARNING) + CONFIG_PATH = '~/.oci/config' ENV_VAR_OCI_CONFIG = 'OCI_CONFIG' diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index 0698aef4fee..e2679f5149e 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -9,8 +9,7 @@ from sky.clouds.service_catalog.constants import CATALOG_DIR from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL -from sky.utils import resources_utils -from sky.utils import subprocess_utils +from sky.utils import resources_utils, subprocess_utils, rich_utils, ux_utils if typing.TYPE_CHECKING: from sky.clouds import cloud @@ -74,7 +73,8 @@ def list_accelerators( Returns: A dictionary of canonical accelerator names mapped to a list of instance type offerings. See usage in cli.py. """ - results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only, + with rich_utils.safe_status(ux_utils.spinner_message('Listing accelerators')): + results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only, name_filter, region_filter, quantity_filter, case_sensitive, all_regions, require_price) if not isinstance(results, list): diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index 0557d2babae..218e6fc4742 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -101,6 +101,7 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']: return az_mappings +@timeline.event def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame': """Maps zone IDs (use1-az1) to zone names (us-east-1x). diff --git a/sky/task.py b/sky/task.py index cebc616dc6d..9b27a6fb948 100644 --- a/sky/task.py +++ b/sky/task.py @@ -23,7 +23,7 @@ from sky.utils import common_utils from sky.utils import schemas from sky.utils import ux_utils - +from sky.utils import timeline if typing.TYPE_CHECKING: from sky import resources as resources_lib diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index 3229f86abf9..9b62908af35 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -8,7 +8,7 @@ from sky import task as task_lib from sky.backends import backend_utils from sky.utils import common_utils -from sky.utils import ux_utils +from sky.utils import ux_utils, rich_utils logger = sky_logging.init_logger(__name__) diff --git a/sky/utils/rich_utils.py b/sky/utils/rich_utils.py index 6badf621294..a016b42b759 100644 --- a/sky/utils/rich_utils.py +++ b/sky/utils/rich_utils.py @@ -1,16 +1,20 @@ """Rich status spinner utils.""" import contextlib import threading -from typing import Union +from typing import Union, Dict, Optional import rich.console as rich_console console = rich_console.Console(soft_wrap=True) _status = None _status_nesting_level = 0 +_main_message = None _logging_lock = threading.RLock() +# Track sub thread progress statuses +_thread_statuses: Dict[int, Optional[str]] = {} +_status_lock = threading.RLock() class _NoOpConsoleStatus: """An empty class for multi-threaded console.status.""" @@ -35,15 +39,17 @@ class _RevertibleStatus: """A wrapper for status that can revert to previous message after exit.""" def __init__(self, message: str): - if _status is not None: - self.previous_message = _status.status + if _main_message is not None: + self.previous_message = _main_message else: self.previous_message = None self.message = message def __enter__(self): global _status_nesting_level - _status.update(self.message) + global _main_message + _main_message = self.message + refresh() _status_nesting_level += 1 _status.__enter__() return _status @@ -57,10 +63,15 @@ def __exit__(self, exc_type, exc_val, exc_tb): _status.__exit__(exc_type, exc_val, exc_tb) _status = None else: - _status.update(self.previous_message) + global _main_message + _main_message = self.previous_message + refresh() def update(self, *args, **kwargs): _status.update(*args, **kwargs) + global _main_message + _main_message = _status.status + refresh() def stop(self): _status.stop() @@ -69,17 +80,62 @@ def start(self): _status.start() -def safe_status(msg: str) -> Union['rich_console.Status', _NoOpConsoleStatus]: +class _ThreadStatus: + """A wrapper of sub thread status""" + def __init__(self, message: str): + self.thread_id = threading.get_ident() + self.message = message + self.previous_message = _thread_statuses.get(self.thread_id) + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.previous_message is not None: + _thread_statuses[self.thread_id] = self.previous_message + else: + # No previous message, remove the thread status + if self.thread_id in _thread_statuses: + del _thread_statuses[self.thread_id] + refresh() + + def update(self, new_message: str): + self.message = new_message + _thread_statuses[self.thread_id] = new_message + refresh() + + def stop(self): + _thread_statuses[self.thread_id] = None + refresh() + + def start(self): + _thread_statuses[self.thread_id] = self.message + refresh() + +def refresh(): + """Refresh status to include all thread statuses.""" + if _status is None or _main_message is None: + return + with _status_lock: + msg = _main_message + for v in _thread_statuses.values(): + if v is not None: + msg = msg + f'\n └─ {v}' + _status.update(msg) + +def safe_status(msg: str) -> Union['rich_console.Status', '_NoOpConsoleStatus']: """A wrapper for multi-threaded console.status.""" from sky import sky_logging # pylint: disable=import-outside-toplevel global _status - if (threading.current_thread() is threading.main_thread() and - not sky_logging.is_silent()): + if sky_logging.is_silent(): + return _NoOpConsoleStatus() + if threading.current_thread() is threading.main_thread(): if _status is None: _status = console.status(msg, refresh_per_second=8) return _RevertibleStatus(msg) - return _NoOpConsoleStatus() - + else: + return _ThreadStatus(msg) def stop_safe_status(): """Stops all nested statuses. From dad3a6527c1efd4d208cd26081befbecc8e75b55 Mon Sep 17 00:00:00 2001 From: Aylei Date: Thu, 19 Dec 2024 20:29:40 +0800 Subject: [PATCH 06/11] fix lint Signed-off-by: Aylei --- sky/adaptors/oci.py | 2 +- sky/clouds/service_catalog/__init__.py | 13 +++++++++---- sky/task.py | 2 +- sky/utils/dag_utils.py | 2 +- sky/utils/rich_utils.py | 7 ++++++- 5 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sky/adaptors/oci.py b/sky/adaptors/oci.py index 9fcb63dce0d..8712a717503 100644 --- a/sky/adaptors/oci.py +++ b/sky/adaptors/oci.py @@ -1,7 +1,7 @@ """Oracle OCI cloud adaptor""" -import os import logging +import os from sky.adaptors import common diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index e2679f5149e..be665fa8b39 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -9,7 +9,10 @@ from sky.clouds.service_catalog.constants import CATALOG_DIR from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL -from sky.utils import resources_utils, subprocess_utils, rich_utils, ux_utils +from sky.utils import resources_utils +from sky.utils import rich_utils +from sky.utils import subprocess_utils +from sky.utils import ux_utils if typing.TYPE_CHECKING: from sky.clouds import cloud @@ -73,10 +76,12 @@ def list_accelerators( Returns: A dictionary of canonical accelerator names mapped to a list of instance type offerings. See usage in cli.py. """ - with rich_utils.safe_status(ux_utils.spinner_message('Listing accelerators')): + with rich_utils.safe_status( + ux_utils.spinner_message('Listing accelerators')): results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only, - name_filter, region_filter, quantity_filter, - case_sensitive, all_regions, require_price) + name_filter, region_filter, + quantity_filter, case_sensitive, + all_regions, require_price) if not isinstance(results, list): results = [results] ret: Dict[str, diff --git a/sky/task.py b/sky/task.py index 9b27a6fb948..cebc616dc6d 100644 --- a/sky/task.py +++ b/sky/task.py @@ -23,7 +23,7 @@ from sky.utils import common_utils from sky.utils import schemas from sky.utils import ux_utils -from sky.utils import timeline + if typing.TYPE_CHECKING: from sky import resources as resources_lib diff --git a/sky/utils/dag_utils.py b/sky/utils/dag_utils.py index 9b62908af35..3229f86abf9 100644 --- a/sky/utils/dag_utils.py +++ b/sky/utils/dag_utils.py @@ -8,7 +8,7 @@ from sky import task as task_lib from sky.backends import backend_utils from sky.utils import common_utils -from sky.utils import ux_utils, rich_utils +from sky.utils import ux_utils logger = sky_logging.init_logger(__name__) diff --git a/sky/utils/rich_utils.py b/sky/utils/rich_utils.py index a016b42b759..d724c968045 100644 --- a/sky/utils/rich_utils.py +++ b/sky/utils/rich_utils.py @@ -1,7 +1,7 @@ """Rich status spinner utils.""" import contextlib import threading -from typing import Union, Dict, Optional +from typing import Dict, Optional, Union import rich.console as rich_console @@ -16,6 +16,7 @@ _thread_statuses: Dict[int, Optional[str]] = {} _status_lock = threading.RLock() + class _NoOpConsoleStatus: """An empty class for multi-threaded console.status.""" @@ -82,6 +83,7 @@ def start(self): class _ThreadStatus: """A wrapper of sub thread status""" + def __init__(self, message: str): self.thread_id = threading.get_ident() self.message = message @@ -113,6 +115,7 @@ def start(self): _thread_statuses[self.thread_id] = self.message refresh() + def refresh(): """Refresh status to include all thread statuses.""" if _status is None or _main_message is None: @@ -124,6 +127,7 @@ def refresh(): msg = msg + f'\n └─ {v}' _status.update(msg) + def safe_status(msg: str) -> Union['rich_console.Status', '_NoOpConsoleStatus']: """A wrapper for multi-threaded console.status.""" from sky import sky_logging # pylint: disable=import-outside-toplevel @@ -137,6 +141,7 @@ def safe_status(msg: str) -> Union['rich_console.Status', '_NoOpConsoleStatus']: else: return _ThreadStatus(msg) + def stop_safe_status(): """Stops all nested statuses. From 99a1edb00fa8105e49dc06ea99317371a28e24a8 Mon Sep 17 00:00:00 2001 From: Aylei Date: Fri, 20 Dec 2024 16:27:54 +0800 Subject: [PATCH 07/11] revert sub thread status Signed-off-by: Aylei --- sky/clouds/service_catalog/__init__.py | 11 +--- sky/clouds/service_catalog/aws_catalog.py | 1 - sky/utils/rich_utils.py | 79 +++-------------------- 3 files changed, 12 insertions(+), 79 deletions(-) diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index be665fa8b39..0698aef4fee 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -10,9 +10,7 @@ from sky.clouds.service_catalog.constants import CATALOG_SCHEMA_VERSION from sky.clouds.service_catalog.constants import HOSTED_CATALOG_DIR_URL from sky.utils import resources_utils -from sky.utils import rich_utils from sky.utils import subprocess_utils -from sky.utils import ux_utils if typing.TYPE_CHECKING: from sky.clouds import cloud @@ -76,12 +74,9 @@ def list_accelerators( Returns: A dictionary of canonical accelerator names mapped to a list of instance type offerings. See usage in cli.py. """ - with rich_utils.safe_status( - ux_utils.spinner_message('Listing accelerators')): - results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only, - name_filter, region_filter, - quantity_filter, case_sensitive, - all_regions, require_price) + results = _map_clouds_catalog(clouds, 'list_accelerators', gpus_only, + name_filter, region_filter, quantity_filter, + case_sensitive, all_regions, require_price) if not isinstance(results, list): results = [results] ret: Dict[str, diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index 218e6fc4742..0557d2babae 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -101,7 +101,6 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']: return az_mappings -@timeline.event def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame': """Maps zone IDs (use1-az1) to zone names (us-east-1x). diff --git a/sky/utils/rich_utils.py b/sky/utils/rich_utils.py index d724c968045..6badf621294 100644 --- a/sky/utils/rich_utils.py +++ b/sky/utils/rich_utils.py @@ -1,21 +1,16 @@ """Rich status spinner utils.""" import contextlib import threading -from typing import Dict, Optional, Union +from typing import Union import rich.console as rich_console console = rich_console.Console(soft_wrap=True) _status = None _status_nesting_level = 0 -_main_message = None _logging_lock = threading.RLock() -# Track sub thread progress statuses -_thread_statuses: Dict[int, Optional[str]] = {} -_status_lock = threading.RLock() - class _NoOpConsoleStatus: """An empty class for multi-threaded console.status.""" @@ -40,17 +35,15 @@ class _RevertibleStatus: """A wrapper for status that can revert to previous message after exit.""" def __init__(self, message: str): - if _main_message is not None: - self.previous_message = _main_message + if _status is not None: + self.previous_message = _status.status else: self.previous_message = None self.message = message def __enter__(self): global _status_nesting_level - global _main_message - _main_message = self.message - refresh() + _status.update(self.message) _status_nesting_level += 1 _status.__enter__() return _status @@ -64,15 +57,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): _status.__exit__(exc_type, exc_val, exc_tb) _status = None else: - global _main_message - _main_message = self.previous_message - refresh() + _status.update(self.previous_message) def update(self, *args, **kwargs): _status.update(*args, **kwargs) - global _main_message - _main_message = _status.status - refresh() def stop(self): _status.stop() @@ -81,65 +69,16 @@ def start(self): _status.start() -class _ThreadStatus: - """A wrapper of sub thread status""" - - def __init__(self, message: str): - self.thread_id = threading.get_ident() - self.message = message - self.previous_message = _thread_statuses.get(self.thread_id) - - def __enter__(self): - self.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.previous_message is not None: - _thread_statuses[self.thread_id] = self.previous_message - else: - # No previous message, remove the thread status - if self.thread_id in _thread_statuses: - del _thread_statuses[self.thread_id] - refresh() - - def update(self, new_message: str): - self.message = new_message - _thread_statuses[self.thread_id] = new_message - refresh() - - def stop(self): - _thread_statuses[self.thread_id] = None - refresh() - - def start(self): - _thread_statuses[self.thread_id] = self.message - refresh() - - -def refresh(): - """Refresh status to include all thread statuses.""" - if _status is None or _main_message is None: - return - with _status_lock: - msg = _main_message - for v in _thread_statuses.values(): - if v is not None: - msg = msg + f'\n └─ {v}' - _status.update(msg) - - -def safe_status(msg: str) -> Union['rich_console.Status', '_NoOpConsoleStatus']: +def safe_status(msg: str) -> Union['rich_console.Status', _NoOpConsoleStatus]: """A wrapper for multi-threaded console.status.""" from sky import sky_logging # pylint: disable=import-outside-toplevel global _status - if sky_logging.is_silent(): - return _NoOpConsoleStatus() - if threading.current_thread() is threading.main_thread(): + if (threading.current_thread() is threading.main_thread() and + not sky_logging.is_silent()): if _status is None: _status = console.status(msg, refresh_per_second=8) return _RevertibleStatus(msg) - else: - return _ThreadStatus(msg) + return _NoOpConsoleStatus() def stop_safe_status(): From 2ff4517c2ff7aadf5ef2a5f29c2ada96190ff92a Mon Sep 17 00:00:00 2001 From: Aylei Date: Fri, 20 Dec 2024 16:51:34 +0800 Subject: [PATCH 08/11] address review comments Signed-off-by: Aylei --- sky/backends/cloud_vm_ray_backend.py | 2 +- sky/clouds/service_catalog/__init__.py | 4 ++-- sky/provision/azure/instance.py | 2 +- sky/provision/kubernetes/instance.py | 3 ++- sky/utils/subprocess_utils.py | 16 +++++++--------- 5 files changed, 13 insertions(+), 14 deletions(-) diff --git a/sky/backends/cloud_vm_ray_backend.py b/sky/backends/cloud_vm_ray_backend.py index 8974a0129bd..99726719f1d 100644 --- a/sky/backends/cloud_vm_ray_backend.py +++ b/sky/backends/cloud_vm_ray_backend.py @@ -3309,7 +3309,7 @@ def error_message() -> str: # even if some of them raise exceptions. We should replace it with # multi-process. rich_utils.stop_safe_status() - subprocess_utils.run_in_parallel(_setup_node, range(num_nodes)) + subprocess_utils.run_in_parallel(_setup_node, list(range(num_nodes))) if detach_setup: # Only set this when setup needs to be run outside the self._setup() diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index 0698aef4fee..9c096444bfb 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -48,8 +48,8 @@ def _execute_catalog_method(cloud: str): f'implement the "{method_name}" method') from None return method(*args, **kwargs) - results = subprocess_utils.run_in_parallel(_execute_catalog_method, clouds, - len(clouds)) # type: ignore + results = subprocess_utils.run_in_parallel(_execute_catalog_method, + list(clouds), len(clouds)) if single: return results[0] return results diff --git a/sky/provision/azure/instance.py b/sky/provision/azure/instance.py index 229d7361e22..4e461375a14 100644 --- a/sky/provision/azure/instance.py +++ b/sky/provision/azure/instance.py @@ -343,7 +343,7 @@ def create_single_instance(vm_i): _create_vm(compute_client, vm_name, node_tags, provider_config, node_config, network_interface.id) - subprocess_utils.run_in_parallel(create_single_instance, range(count)) + subprocess_utils.run_in_parallel(create_single_instance, list(range(count))) # Update disk performance tier performance_tier = node_config.get('disk_performance_tier', None) diff --git a/sky/provision/kubernetes/instance.py b/sky/provision/kubernetes/instance.py index c431b023ab9..a849dfc3044 100644 --- a/sky/provision/kubernetes/instance.py +++ b/sky/provision/kubernetes/instance.py @@ -804,7 +804,8 @@ def _create_pod_thread(i: int): # Create pods in parallel pods = subprocess_utils.run_in_parallel(_create_pod_thread, - range(to_start_count), _NUM_THREADS) + list(range(to_start_count)), + _NUM_THREADS) # Process created pods for pod in pods: diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index dbacf33a7e2..52a34ab7b68 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -1,12 +1,11 @@ """Utility functions for subprocesses.""" -import collections from multiprocessing import pool import os import random import resource import subprocess import time -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import colorama import psutil @@ -98,7 +97,7 @@ def get_parallel_threads(cloud_str: Optional[str] = None) -> int: def run_in_parallel(func: Callable, - args: Iterable[Any], + args: List[Any], num_threads: Optional[int] = None) -> List[Any]: """Run a function in parallel on a list of arguments. @@ -114,12 +113,11 @@ def run_in_parallel(func: Callable, A list of the return values of the function func, in the same order as the arguments. """ - if isinstance(args, collections.abc.Sized): - if len(args) == 0: - return [] - # Short-circuit for single element - if len(args) == 1: - return [func(next(iter(args)))] + if len(args) == 0: + return [] + # Short-circuit for single element + if len(args) == 1: + return [func(next(iter(args)))] # Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long processes = num_threads if num_threads is not None else get_parallel_threads( ) From 46194133ee0f2c38712ab7f09ec01e5935045c71 Mon Sep 17 00:00:00 2001 From: Aylei Date: Fri, 20 Dec 2024 17:24:30 +0800 Subject: [PATCH 09/11] revert debug change Signed-off-by: Aylei --- sky/clouds/service_catalog/aws_catalog.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/clouds/service_catalog/aws_catalog.py b/sky/clouds/service_catalog/aws_catalog.py index 0557d2babae..bbd48863755 100644 --- a/sky/clouds/service_catalog/aws_catalog.py +++ b/sky/clouds/service_catalog/aws_catalog.py @@ -101,6 +101,7 @@ def _get_az_mappings(aws_user_hash: str) -> Optional['pd.DataFrame']: return az_mappings +@timeline.event def _fetch_and_apply_az_mapping(df: common.LazyDataFrame) -> 'pd.DataFrame': """Maps zone IDs (use1-az1) to zone names (us-east-1x). @@ -291,7 +292,6 @@ def get_region_zones_for_instance_type(instance_type: str, return us_region_list + other_region_list -@timeline.event def list_accelerators( gpus_only: bool, name_filter: Optional[str], From 9bb757e7d081b4446c356f44c7a4a24b44420f6c Mon Sep 17 00:00:00 2001 From: Aylei Date: Sat, 21 Dec 2024 20:49:12 +0800 Subject: [PATCH 10/11] Update sky/utils/subprocess_utils.py Co-authored-by: Zhanghao Wu --- sky/utils/subprocess_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sky/utils/subprocess_utils.py b/sky/utils/subprocess_utils.py index 52a34ab7b68..88d351632a3 100644 --- a/sky/utils/subprocess_utils.py +++ b/sky/utils/subprocess_utils.py @@ -117,7 +117,7 @@ def run_in_parallel(func: Callable, return [] # Short-circuit for single element if len(args) == 1: - return [func(next(iter(args)))] + return [func(args[0])] # Reference: https://stackoverflow.com/questions/25790279/python-multiprocessing-early-termination # pylint: disable=line-too-long processes = num_threads if num_threads is not None else get_parallel_threads( ) From a9026db663ec47defe71220d440c2c841774d365 Mon Sep 17 00:00:00 2001 From: Aylei Date: Mon, 23 Dec 2024 14:45:17 +0800 Subject: [PATCH 11/11] address review comments Signed-off-by: Aylei --- sky/clouds/service_catalog/__init__.py | 3 ++- sky/utils/accelerator_registry.py | 20 ++++++++++++-------- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sky/clouds/service_catalog/__init__.py b/sky/clouds/service_catalog/__init__.py index 9c096444bfb..3aad5a0b7fd 100644 --- a/sky/clouds/service_catalog/__init__.py +++ b/sky/clouds/service_catalog/__init__.py @@ -49,7 +49,8 @@ def _execute_catalog_method(cloud: str): return method(*args, **kwargs) results = subprocess_utils.run_in_parallel(_execute_catalog_method, - list(clouds), len(clouds)) + args=list(clouds), + num_threads=len(clouds)) if single: return results[0] return results diff --git a/sky/utils/accelerator_registry.py b/sky/utils/accelerator_registry.py index 78a708efb91..11dd5280ac4 100644 --- a/sky/utils/accelerator_registry.py +++ b/sky/utils/accelerator_registry.py @@ -3,6 +3,7 @@ from typing import Optional from sky.clouds import service_catalog +from sky.utils import rich_utils from sky.utils import ux_utils if typing.TYPE_CHECKING: @@ -88,14 +89,17 @@ def canonicalize_accelerator_name(accelerator: str, if accelerator.lower() in mapping: return mapping[accelerator.lower()] - # _ACCELERATORS may not be comprehensive. - # Users may manually add new accelerators to the catalogs, or download new - # catalogs (that have new accelerators) without upgrading SkyPilot. - # To cover such cases, we should search the accelerator name - # in the service catalog. - searched = service_catalog.list_accelerators(name_filter=accelerator, - case_sensitive=False, - clouds=cloud_str) + # Listing accelerators can be time-consuming since canonicalizing usually + # involves catalog reading with cache not warmed up. + with rich_utils.safe_status('Listing accelerators...'): + # _ACCELERATORS may not be comprehensive. + # Users may manually add new accelerators to the catalogs, or download + # new catalogs (that have new accelerators) without upgrading SkyPilot. + # To cover such cases, we should search the accelerator name + # in the service catalog. + searched = service_catalog.list_accelerators(name_filter=accelerator, + case_sensitive=False, + clouds=cloud_str) names = list(searched.keys()) # Exact match.