Merge branch 'dev/zeping/fix_smoke_tests' into dev/zeping/more_fix_of_smoke_tests
zpoint committed Jan 10, 2025
2 parents ad89aae + ddcfb0a commit 834536b
Showing 11 changed files with 332 additions and 31 deletions.
146 changes: 146 additions & 0 deletions sky/backends/cloud_vm_ray_backend.py
@@ -3891,6 +3891,152 @@ def tail_managed_job_logs(self,
stdin=subprocess.DEVNULL,
)

def sync_down_managed_job_logs(
self,
handle: CloudVmRayResourceHandle,
job_id: Optional[int] = None,
job_name: Optional[str] = None,
controller: bool = False,
local_dir: str = constants.SKY_LOGS_DIRECTORY) -> Dict[str, str]:
"""Sync down logs for a managed job.
Args:
handle: The handle to the cluster.
job_id: The job ID to sync down logs for.
job_name: The job name to sync down logs for.
controller: Whether to sync down logs for the controller.
local_dir: The local directory to sync down logs to.
Returns:
A dictionary mapping job_id to log path.
"""
        # At most one of job_name and job_id may be provided.
assert job_name is None or job_id is None, (job_name, job_id)
if job_id is None and job_name is not None:
            # Generate code to look up the job id(s) by name.
code = managed_jobs.ManagedJobCodeGen.get_all_job_ids_by_name(
job_name=job_name)
            returncode, job_ids_payload, stderr = self.run_on_head(
handle,
code,
stream_logs=False,
require_outputs=True,
separate_stderr=True)
subprocess_utils.handle_returncode(returncode, code,
'Failed to sync down logs.',
stderr)
            job_ids = common_utils.decode_payload(job_ids_payload)
if not job_ids:
logger.info(f'{colorama.Fore.YELLOW}'
'No matching job found'
f'{colorama.Style.RESET_ALL}')
return {}
elif len(job_ids) > 1:
logger.info(
f'{colorama.Fore.YELLOW}'
                    f'Multiple job IDs found under the name {job_name}. '
'Downloading the latest job logs.'
f'{colorama.Style.RESET_ALL}')
                job_ids = [job_ids[0]]  # IDs are in descending order; take the latest.
else:
job_ids = [job_id]

        # Get the run_timestamp; the code-gen function takes a list of job ids.
code = job_lib.JobLibCodeGen.get_run_timestamp_with_globbing(job_ids)
returncode, run_timestamps, stderr = self.run_on_head(
handle,
code,
stream_logs=False,
require_outputs=True,
separate_stderr=True)
subprocess_utils.handle_returncode(returncode, code,
'Failed to sync logs.', stderr)
        # Returns a dict of {job_id: run_timestamp}.
run_timestamps = common_utils.decode_payload(run_timestamps)
if not run_timestamps:
logger.info(f'{colorama.Fore.YELLOW}'
'No matching log directories found'
f'{colorama.Style.RESET_ALL}')
return {}

run_timestamp = list(run_timestamps.values())[0]
job_id = list(run_timestamps.keys())[0]
local_log_dir = ''
if controller: # download controller logs
remote_log_dir = os.path.join(constants.SKY_LOGS_DIRECTORY,
run_timestamp)
local_log_dir = os.path.expanduser(
os.path.join(local_dir, run_timestamp))

logger.info(f'{colorama.Fore.CYAN}'
f'Job {job_ids} local logs: {local_log_dir}'
f'{colorama.Style.RESET_ALL}')

runners = handle.get_command_runners()

def _rsync_down(args) -> None:
"""Rsync down logs from remote nodes.
Args:
args: A tuple of (runner, local_log_dir, remote_log_dir)
"""
(runner, local_log_dir, remote_log_dir) = args
try:
os.makedirs(local_log_dir, exist_ok=True)
runner.rsync(
source=f'{remote_log_dir}/',
target=local_log_dir,
up=False,
stream_logs=False,
)
except exceptions.CommandError as e:
if e.returncode == exceptions.RSYNC_FILE_NOT_FOUND_CODE:
                    # Raised by rsync_down. The remote log dir may not
                    # exist, since the job may have run on only a subset
                    # of the nodes.
logger.debug(
f'{runner.node_id} does not have the tasks/*.')
else:
raise

            parallel_args = [(runner, local_log_dir, remote_log_dir)
                             for runner in runners]
subprocess_utils.run_in_parallel(_rsync_down, parallel_args)
else: # download job logs
local_log_dir = os.path.expanduser(
os.path.join(local_dir, 'managed_jobs', run_timestamp))
os.makedirs(os.path.dirname(local_log_dir), exist_ok=True)
log_file = os.path.join(local_log_dir, 'run.log')

code = managed_jobs.ManagedJobCodeGen.stream_logs(job_name=None,
job_id=job_id,
follow=False,
controller=False)

            # With stdin=subprocess.DEVNULL, ctrl-c will not kill the
            # process, so we need to handle it manually here.
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, backend_utils.interrupt_handler)
signal.signal(signal.SIGTSTP, backend_utils.stop_handler)

            # Redirect the output to the log file and disable stdout and
            # stderr.
self.run_on_head(
handle,
code,
log_path=log_file,
stream_logs=False,
process_stream=False,
ssh_mode=command_runner.SshMode.INTERACTIVE,
stdin=subprocess.DEVNULL,
)

logger.info(f'{colorama.Fore.CYAN}'
f'Job {job_id} logs: {local_log_dir}'
f'{colorama.Style.RESET_ALL}')
return {str(job_id): local_log_dir}

def tail_serve_logs(self, handle: CloudVmRayResourceHandle,
service_name: str, target: serve_lib.ServiceComponent,
replica_id: Optional[int], follow: bool) -> None:
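The controller branch above fans rsync out across all node runners in parallel and tolerates nodes that never produced logs. A minimal self-contained sketch of the same fan-out pattern, using only the standard library and a plain rsync binary; the hostnames and paths are hypothetical placeholders, not SkyPilot APIs, and rsync's exit code 23 plays the role of exceptions.RSYNC_FILE_NOT_FOUND_CODE in the diff:

import os
import subprocess
from concurrent.futures import ThreadPoolExecutor

# rsync exits with 23 on a partial transfer, e.g. a missing source dir;
# a node that ran no tasks has no log dir, so treat that as a non-error.
RSYNC_MISSING_SOURCE = 23

def rsync_down(args):
    """Pull one node's remote log dir into a local dir, tolerating absence."""
    host, remote_dir, local_dir = args
    os.makedirs(local_dir, exist_ok=True)
    proc = subprocess.run(
        ['rsync', '-az', f'{host}:{remote_dir}/', local_dir],
        capture_output=True, text=True, check=False)
    if proc.returncode not in (0, RSYNC_MISSING_SOURCE):
        raise RuntimeError(proc.stderr)

hosts = ['node-0', 'node-1']  # hypothetical SSH-reachable node names
args = [(h, '~/sky_logs/some-run-timestamp', f'./logs/{h}') for h in hosts]
with ThreadPoolExecutor() as pool:
    list(pool.map(rsync_down, args))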
26 changes: 19 additions & 7 deletions sky/cli.py
@@ -3933,17 +3933,29 @@ def jobs_cancel(name: Optional[str], job_ids: Tuple[int], all: bool, yes: bool):
required=False,
help='Query the latest job logs, restarting the jobs controller if stopped.'
)
@click.option('--sync-down',
'-s',
default=False,
is_flag=True,
required=False,
help='Download logs for all jobs shown in the queue.')
@click.argument('job_id', required=False, type=int)
@usage_lib.entrypoint
def jobs_logs(name: Optional[str], job_id: Optional[int], follow: bool,
              controller: bool, refresh: bool, sync_down: bool):
    """Tail or sync down the log of a managed job."""
    try:
        if sync_down:
            managed_jobs.sync_down_logs(name=name,
                                        job_id=job_id,
                                        controller=controller,
                                        refresh=refresh)
        else:
            managed_jobs.tail_logs(name=name,
                                   job_id=job_id,
                                   follow=follow,
                                   controller=controller,
                                   refresh=refresh)
except exceptions.ClusterNotUpError:
with ux_utils.print_exception_no_traceback():
raise
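A hedged sketch of exercising the new flag without a terminal, via click's test runner. It assumes jobs_logs is exposed on sky.cli as a registered click command (its @click.command decorator sits above the hunk shown) and that a jobs controller is reachable; the job id 42 is a placeholder:

from click.testing import CliRunner

from sky import cli

runner = CliRunner()
# Equivalent to `sky jobs logs --sync-down 42` in a shell: with the new
# flag, logs are downloaded to a local directory instead of tailed.
result = runner.invoke(cli.jobs_logs, ['--sync-down', '42'])
print(result.exit_code, result.output)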
29 changes: 17 additions & 12 deletions sky/clouds/aws.py
@@ -97,6 +97,8 @@ class AWSIdentityType(enum.Enum):

CUSTOM_PROCESS = 'custom-process'

ASSUME_ROLE = 'assume-role'

# Name Value Type Location
# ---- ----- ---- --------
# profile <not set> None None
@@ -626,6 +628,17 @@ def check_credentials(cls) -> Tuple[bool, Optional[str]]:
# the file. i.e. the custom process will be assigned the IAM role of the
# task: skypilot-v1.
hints = f'AWS custom-process is set.{single_cloud_hint}'
        elif identity_type == AWSIdentityType.ASSUME_ROLE:
            # When using ASSUME_ROLE, the credentials come from a different
            # source profile, so we do not check for the existence of
            # ~/.aws/credentials, i.e. the assumed role will be assigned the
            # IAM role of the task: skypilot-v1.
            hints = f'AWS assume-role is set.{single_cloud_hint}'
        elif identity_type == AWSIdentityType.ENV:
            # When using environment variables, the credentials come from
            # them rather than from ~/.aws/credentials, so we do not check
            # for that file, i.e. the identity is not determined by the file.
            hints = f'AWS env is set.{single_cloud_hint}'
else:
# This file is required because it is required by the VMs launched on
# other clouds to access private s3 buckets and resources like EC2.
@@ -677,18 +690,10 @@ def _is_access_key_of_type(type_str: str) -> bool:
f'Unexpected `aws configure list` output:\n{output}')
return len(results) == 1

        for identity_type in AWSIdentityType:
            if _is_access_key_of_type(identity_type.value):
                return identity_type
        return AWSIdentityType.SHARED_CREDENTIALS_FILE

@classmethod
@functools.lru_cache(maxsize=1)
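The loop above relies on enum.Enum iterating its members in definition order, which is what lets the refactor keep the precedence of the if/elif chain it replaces (SSO checked before IAM_ROLE, and so on). A minimal sketch of that property, using an abbreviated stand-in rather than the real AWSIdentityType:

import enum

class IdentityType(enum.Enum):
    # Abbreviated stand-in for AWSIdentityType; member order is what matters.
    SSO = 'sso'
    IAM_ROLE = 'iam-role'
    CONTAINER_ROLE = 'container-role'
    ENV = 'env'
    CUSTOM_PROCESS = 'custom-process'
    ASSUME_ROLE = 'assume-role'

# Enum iteration follows definition order, so the first matching member
# wins -- the same precedence the replaced if/elif chain encoded.
print([t.value for t in IdentityType])
# ['sso', 'iam-role', 'container-role', 'env', 'custom-process', 'assume-role']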
2 changes: 2 additions & 0 deletions sky/jobs/__init__.py
@@ -9,6 +9,7 @@
from sky.jobs.core import launch
from sky.jobs.core import queue
from sky.jobs.core import queue_from_kubernetes_pod
from sky.jobs.core import sync_down_logs
from sky.jobs.core import tail_logs
from sky.jobs.recovery_strategy import DEFAULT_RECOVERY_STRATEGY
from sky.jobs.recovery_strategy import RECOVERY_STRATEGIES
@@ -37,6 +38,7 @@
'queue',
'queue_from_kubernetes_pod',
'tail_logs',
'sync_down_logs',
# utils
'ManagedJobCodeGen',
'format_job_table',
46 changes: 46 additions & 0 deletions sky/jobs/core.py
@@ -427,6 +427,52 @@ def tail_logs(name: Optional[str], job_id: Optional[int], follow: bool,
controller=controller)


@usage_lib.entrypoint
def sync_down_logs(
name: Optional[str],
job_id: Optional[int],
refresh: bool,
controller: bool,
local_dir: str = skylet_constants.SKY_LOGS_DIRECTORY) -> None:
"""Sync down logs of managed jobs.
Please refer to sky.cli.job_logs for documentation.
Raises:
ValueError: invalid arguments.
sky.exceptions.ClusterNotUpError: the jobs controller is not up.
"""
# TODO(zhwu): Automatically restart the jobs controller
if name is not None and job_id is not None:
with ux_utils.print_exception_no_traceback():
raise ValueError('Cannot specify both name and job_id.')

jobs_controller_type = controller_utils.Controllers.JOBS_CONTROLLER
job_name_or_id_str = ''
if job_id is not None:
job_name_or_id_str = str(job_id)
elif name is not None:
job_name_or_id_str = f'-n {name}'
else:
job_name_or_id_str = ''
handle = _maybe_restart_controller(
refresh,
stopped_message=(
f'{jobs_controller_type.value.name.capitalize()} is stopped. To '
f'get the logs, run: {colorama.Style.BRIGHT}sky jobs logs '
f'-r --sync-down {job_name_or_id_str}{colorama.Style.RESET_ALL}'),
spinner_message='Retrieving job logs')

backend = backend_utils.get_backend_from_handle(handle)
assert isinstance(backend, backends.CloudVmRayBackend), backend

backend.sync_down_managed_job_logs(handle,
job_id=job_id,
job_name=name,
controller=controller,
local_dir=local_dir)


spot_launch = common_utils.deprecated_function(
launch,
name='sky.jobs.launch',
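For completeness, a hedged sketch of calling the new function from Python instead of the CLI; the job id is a placeholder and a running jobs controller is assumed:

from sky import jobs as managed_jobs

# Download the logs of managed job 42 into the default log directory.
# With refresh=False this raises sky.exceptions.ClusterNotUpError when
# the jobs controller is stopped; refresh=True restarts it first.
managed_jobs.sync_down_logs(name=None, job_id=42,
                            refresh=False, controller=False)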
27 changes: 27 additions & 0 deletions sky/jobs/state.py
@@ -564,6 +564,33 @@ def get_nonterminal_job_ids_by_name(name: Optional[str]) -> List[int]:
return job_ids


def get_all_job_ids_by_name(name: Optional[str]) -> List[int]:
"""Get all job ids by name."""
name_filter = ''
field_values = []
if name is not None:
# We match the job name from `job_info` for the jobs submitted after
# #1982, and from `spot` for the jobs submitted before #1982, whose
# job_info is not available.
name_filter = ('WHERE (job_info.name=(?) OR '
'(job_info.name IS NULL AND spot.task_name=(?)))')
field_values = [name, name]

    # A left outer join is used instead of an inner join because the job_info
    # table does not contain managed jobs submitted before #1982.
with db_utils.safe_cursor(_DB_PATH) as cursor:
rows = cursor.execute(
f"""\
SELECT DISTINCT spot.spot_job_id
FROM spot
LEFT OUTER JOIN job_info
ON spot.spot_job_id=job_info.spot_job_id
{name_filter}
ORDER BY spot.spot_job_id DESC""", field_values).fetchall()
job_ids = [row[0] for row in rows if row[0] is not None]
return job_ids


def _get_all_task_ids_statuses(
job_id: int) -> List[Tuple[int, ManagedJobStatus]]:
with db_utils.safe_cursor(_DB_PATH) as cursor:
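The LEFT OUTER JOIN above matters because pre-#1982 jobs have rows in spot but none in job_info; an inner join would silently drop them. A toy in-memory reproduction of the fallback name match, with the table shapes simplified to just the columns used:

import sqlite3

conn = sqlite3.connect(':memory:')
conn.executescript("""
    CREATE TABLE spot (spot_job_id INTEGER, task_name TEXT);
    CREATE TABLE job_info (spot_job_id INTEGER, name TEXT);
    INSERT INTO spot VALUES (1, 'train'), (2, 'train');
    INSERT INTO job_info VALUES (2, 'train');  -- job 1 predates job_info
""")
rows = conn.execute(
    """SELECT DISTINCT spot.spot_job_id
       FROM spot LEFT OUTER JOIN job_info
       ON spot.spot_job_id=job_info.spot_job_id
       WHERE (job_info.name=(?) OR
              (job_info.name IS NULL AND spot.task_name=(?)))
       ORDER BY spot.spot_job_id DESC""", ('train', 'train')).fetchall()
print([r[0] for r in rows])  # [2, 1]: the pre-#1982 job is still found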
9 changes: 9 additions & 0 deletions sky/jobs/utils.py
@@ -855,6 +855,15 @@ def cancel_job_by_name(cls, job_name: str) -> str:
""")
return cls._build(code)

@classmethod
def get_all_job_ids_by_name(cls, job_name: str) -> str:
        code = textwrap.dedent(f"""\
        from sky.utils import common_utils
        job_ids = managed_job_state.get_all_job_ids_by_name({job_name!r})
        print(common_utils.encode_payload(job_ids), end="", flush=True)
        """)
return cls._build(code)

@classmethod
def stream_logs(cls,
job_name: Optional[str],
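The code-gen above runs on the controller and prints an encoded payload that the backend side decodes with common_utils.decode_payload. A toy stand-in for that round trip, assuming a JSON-with-sentinel scheme; the real wire format in common_utils may differ:

import json

# Hypothetical stand-ins for common_utils.encode_payload /
# common_utils.decode_payload, for illustration only.
_PREFIX, _SUFFIX = '<payload>', '</payload>'

def encode_payload(payload):
    # Sentinels let the receiver pick the payload out of mixed stdout.
    return f'{_PREFIX}{json.dumps(payload)}{_SUFFIX}'

def decode_payload(output):
    start = output.index(_PREFIX) + len(_PREFIX)
    end = output.index(_SUFFIX)
    return json.loads(output[start:end])

stdout = 'stray log line\n' + encode_payload([3, 2, 1])
print(decode_payload(stdout))  # [3, 2, 1]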