Stop catching generic Exception in operators #100

Merged · 1 commit · Nov 29, 2024
2 changes: 2 additions & 0 deletions ray_provider/exceptions.py
@@ -0,0 +1,2 @@
class RayAirflowException(Exception):
pass
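
The new `RayAirflowException` gives callers a Ray-specific type to catch instead of a blanket `Exception` or `AirflowException`. A minimal sketch of what that enables on the caller side, assuming the `ray_provider` package is importable; the `handle_ray_result` helper is purely illustrative and not part of the provider:

```python
# Illustrative only: how a caller might use the new exception type.
# `RayAirflowException` and its module path come from this PR; the helper
# below is a hypothetical example, not provider code.
from ray_provider.exceptions import RayAirflowException


def handle_ray_result(run_job) -> str:
    """Run a callable that drives a Ray job and classify its failure mode."""
    try:
        return run_job()
    except RayAirflowException as err:
        # Ray-specific failure: job failed, was stopped, or hit an unexpected state.
        return f"ray-error: {err}"
    except Exception as err:  # anything else is not a Ray job outcome
        return f"other-error: {err}"
```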
163 changes: 89 additions & 74 deletions ray_provider/operators.py
@@ -4,12 +4,14 @@
from functools import cached_property
from typing import Any

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodOperatorHookProtocol
from airflow.utils.context import Context
from kubernetes.client.exceptions import ApiException
from ray.job_submission import JobStatus

from ray_provider.constants import TERMINAL_JOB_STATUSES
from ray_provider.exceptions import RayAirflowException
from ray_provider.hooks import RayHook
from ray_provider.triggers import RayJobTrigger

@@ -41,7 +43,7 @@ def __init__(
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml
self.update_if_exists = update_if_exists

-    @cached_property
+    @property
def hook(self) -> RayHook:
"""Lazily initialize and return the RayHook."""
return RayHook(conn_id=self.conn_id)
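
Switching `hook` from `functools.cached_property` to a plain `property` means a fresh `RayHook` (and connection lookup) is built on every `self.hook` access instead of being constructed once and reused. A minimal sketch of that behavioural difference, using stand-in classes rather than the real operator and hook:

```python
# Minimal sketch of cached_property vs. property; stand-in classes only.
from functools import cached_property


class WithCachedProperty:
    def __init__(self) -> None:
        self.builds = 0

    @cached_property
    def hook(self) -> object:
        self.builds += 1  # runs only on first access, result is cached
        return object()


class WithPlainProperty:
    def __init__(self) -> None:
        self.builds = 0

    @property
    def hook(self) -> object:
        self.builds += 1  # runs on every access
        return object()


cached, plain = WithCachedProperty(), WithPlainProperty()
for _ in range(2):
    _ = cached.hook
    _ = plain.hook

print(cached.builds, plain.builds)  # -> 1 2: cached_property builds once, property rebuilds per access
```

Rebuilding the hook per access avoids carrying a stale client across retries or deferral boundaries, at the cost of re-reading the Airflow connection each time; that appears to be the motivation for the swap.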
@@ -52,13 +54,15 @@ def execute(self, context: Context) -> None:

:param context: The context in which the operator is being executed.
"""
self.log.info(f"Trying to setup the ray cluster defined in {self.ray_cluster_yaml}")
self.hook.setup_ray_cluster(
context=context,
ray_cluster_yaml=self.ray_cluster_yaml,
kuberay_version=self.kuberay_version,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
update_if_exists=self.update_if_exists,
)
self.log.info("Finished setting up the ray cluster.")


class DeleteRayCluster(BaseOperator):
@@ -82,7 +86,7 @@ def __init__(
self.ray_cluster_yaml = ray_cluster_yaml
self.gpu_device_plugin_yaml = gpu_device_plugin_yaml

-    @cached_property
+    @property
def hook(self) -> PodOperatorHookProtocol:
"""Lazily initialize and return the RayHook."""
return RayHook(conn_id=self.conn_id)
@@ -93,7 +97,9 @@ def execute(self, context: Context) -> None:

:param context: The context in which the operator is being executed.
"""
self.log.info(f"Trying to delete the ray cluster defined in {self.ray_cluster_yaml}")
self.hook.delete_ray_cluster(self.ray_cluster_yaml, self.gpu_device_plugin_yaml)
self.log.info("Finished deleting the ray cluster.")


class SubmitRayJob(BaseOperator):
@@ -173,7 +179,6 @@ def __init__(
self.xcom_task_key = xcom_task_key
self.dashboard_url: str | None = None
self.job_id = ""
-        self.terminal_states = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}

def on_kill(self) -> None:
"""
@@ -226,28 +231,36 @@ def _setup_cluster(self, context: Context) -> None:
Set up the Ray cluster if a cluster YAML is provided.

:param context: The context in which the task is being executed.
-        :raises Exception: If there's an error during cluster setup.
"""
if self.ray_cluster_yaml:
-            self.hook.setup_ray_cluster(
-                context=context,
-                ray_cluster_yaml=self.ray_cluster_yaml,
-                kuberay_version=self.kuberay_version,
-                gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
-                update_if_exists=self.update_if_exists,
-            )
+            try:
+                self.hook.setup_ray_cluster(
+                    context=context,
+                    ray_cluster_yaml=self.ray_cluster_yaml,
+                    kuberay_version=self.kuberay_version,
+                    gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
+                    update_if_exists=self.update_if_exists,
+                )
+            except ApiException as e:
+                self.log.info(f"Unable to setup the Ray cluster using {self.ray_cluster_yaml}")
+                self.log.error("Exception details:", exc_info=True)
+                self.log.info("Trying to delete any parts of the RayCluster that may have been spun up...")
+                self._delete_cluster()
+                raise e
+        else:
+            self.log.info(f"Skipping setting up a Ray cluster because no `ray_cluster_yaml` was given.")

def _delete_cluster(self) -> None:
"""
Delete the Ray cluster if a cluster YAML is provided.

:raises Exception: If there's an error during cluster deletion.
"""
if self.ray_cluster_yaml:
self.hook.delete_ray_cluster(
ray_cluster_yaml=self.ray_cluster_yaml,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
)
else:
self.log.info(f"Skipping deleting the Ray cluster because no `ray_cluster_yaml` was given.")

def execute(self, context: Context) -> str:
"""
@@ -258,58 +271,51 @@ def execute(self, context: Context) -> str:

:param context: The context in which the task is being executed.
:return: The job ID of the submitted Ray job.
-        :raises AirflowException: If the job fails, is cancelled, or reaches an unexpected state.
"""

-        try:
-            self._setup_cluster(context=context)

-            self.dashboard_url = self._get_dashboard_url(context)

-            self.job_id = self.hook.submit_ray_job(
-                dashboard_url=self.dashboard_url,
-                entrypoint=self.entrypoint,
-                runtime_env=self.runtime_env,
-                entrypoint_num_cpus=self.num_cpus,
-                entrypoint_num_gpus=self.num_gpus,
-                entrypoint_memory=self.memory,
-                entrypoint_resources=self.ray_resources,
-            )
-            self.log.info(f"Ray job submitted with id: {self.job_id}")

-            if self.wait_for_completion:
-                current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
-                self.log.info(f"Current job status for {self.job_id} is: {current_status}")

-                if current_status not in self.terminal_states:
-                    self.log.info("Deferring the polling to RayJobTrigger...")
-                    self.defer(
-                        trigger=RayJobTrigger(
-                            job_id=self.job_id,
-                            conn_id=self.conn_id,
-                            xcom_dashboard_url=self.dashboard_url,
-                            ray_cluster_yaml=self.ray_cluster_yaml,
-                            gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
-                            poll_interval=self.poll_interval,
-                            fetch_logs=self.fetch_logs,
-                        ),
-                        method_name="execute_complete",
-                        timeout=self.job_timeout_seconds,
-                    )
-                elif current_status == JobStatus.SUCCEEDED:
-                    self.log.info("Job %s completed successfully", self.job_id)
-                elif current_status == JobStatus.FAILED:
-                    raise AirflowException(f"Job failed:\n{self.job_id}")
-                elif current_status == JobStatus.STOPPED:
-                    raise AirflowException(f"Job was cancelled:\n{self.job_id}")
-                else:
-                    raise AirflowException(
-                        f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`"
-                    )
-            return self.job_id
-        except Exception as e:
-            self._delete_cluster()
-            raise AirflowException(f"SubmitRayJob operator failed due to {e}. Cleaning up resources...")
+        self.log.info("::group:: (SubmitJob 1/5) Setup Cluster")
+        self._setup_cluster(context=context)
+        self.log.info("::endgroup::")

+        self.log.info("::group:: (SubmitJob 2/5) Identify Dashboard URL")
+        self.dashboard_url = self._get_dashboard_url(context)
+        self.log.info("::endgroup::")

+        self.log.info("::group:: (SubmitJob 3/5) Submit job")
+        self.log.info(f"Ray job with id {self.job_id} submitted")
+        self.job_id = self.hook.submit_ray_job(
+            dashboard_url=self.dashboard_url,
+            entrypoint=self.entrypoint,
+            runtime_env=self.runtime_env,
+            entrypoint_num_cpus=self.num_cpus,
+            entrypoint_num_gpus=self.num_gpus,
+            entrypoint_memory=self.memory,
+            entrypoint_resources=self.ray_resources,
+        )
+        self.log.info("::endgroup::")

+        self.log.info("::group:: (SubmitJob 4/5) Wait for completion")
+        if self.wait_for_completion:
+            current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
+            self.log.info(f"Current job status for {self.job_id} is: {current_status}")

+            if current_status not in TERMINAL_JOB_STATUSES:
+                self.log.info("Deferring the polling to RayJobTrigger...")
+                self.defer(
+                    trigger=RayJobTrigger(
+                        job_id=self.job_id,
+                        conn_id=self.conn_id,
+                        xcom_dashboard_url=self.dashboard_url,
+                        ray_cluster_yaml=self.ray_cluster_yaml,
+                        gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
+                        poll_interval=self.poll_interval,
+                        fetch_logs=self.fetch_logs,
+                    ),
+                    method_name="execute_complete",
+                    timeout=self.job_timeout_seconds,
+                )

+        return self.job_id
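
`TERMINAL_JOB_STATUSES`, imported at the top of the file, replaces the per-instance `self.terminal_states` set removed above. `ray_provider/constants.py` is not shown in this diff, but judging from the set it replaces it is presumably equivalent to:

```python
# Presumed contents of ray_provider/constants.py (not shown in this diff),
# inferred from the `self.terminal_states` set it replaces.
from ray.job_submission import JobStatus

TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}
```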

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
@@ -320,15 +326,24 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:

:param context: The context in which the task is being executed.
:param event: The event containing the job execution result.
-        :raises AirflowException: If the job execution fails, is cancelled, or reaches an unexpected state.
+        :raises RayAirflowException: If the job execution fails, is cancelled, or reaches an unexpected state.
"""
-        try:
-            if event["status"] in [JobStatus.STOPPED, JobStatus.FAILED]:
-                self.log.info(f"Ray job {self.job_id} execution not completed successfully...")
-                raise AirflowException(f"Job {self.job_id} {event['status'].lower()}: {event['message']}")
-            elif event["status"] == JobStatus.SUCCEEDED:
-                self.log.info(f"Ray job {self.job_id} execution succeeded.")
+        self.log.info("::endgroup::")
+        self.log.info("::group:: (SubmitJob 5/5) Execution completed")

+        self._delete_cluster()

+        job_status = event["status"]
+        if job_status == JobStatus.SUCCEEDED:
+            self.log.info("Job %s completed successfully", self.job_id)
+            return
+        else:
+            self.log.info(f"Ray job {self.job_id} execution not completed successfully...")
+            if job_status in (JobStatus.FAILED, JobStatus.STOPPED):
+                msg = f"Job {self.job_id} {job_status.lower()}: {event['message']}"
            else:
-                raise AirflowException(f"Unexpected event status for job {self.job_id}: {event['status']}")
-        finally:
-            self._delete_cluster()
+                msg = f"Encountered unexpected state `{job_status}` for job_id `{self.job_id}`"

+        self.log.info("::endgroup::")

+        raise RayAirflowException(msg)
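
Once the task is deferred, `RayJobTrigger` fires an event whose `status` and `message` keys drive `execute_complete`. The payloads below are illustrative examples of how each outcome maps under the logic above, not captured trigger output:

```python
# Illustrative event payloads and the outcome execute_complete produces for
# each; the dict shapes mirror the keys the method reads above.
from ray.job_submission import JobStatus

events = [
    {"status": JobStatus.SUCCEEDED, "message": "done"},           # logs success, returns None
    {"status": JobStatus.FAILED, "message": "OOM in worker"},     # raises RayAirflowException("Job <id> failed: OOM in worker")
    {"status": JobStatus.STOPPED, "message": "stopped by user"},  # raises RayAirflowException("Job <id> stopped: stopped by user")
    {"status": "UNKNOWN", "message": ""},                         # raises RayAirflowException("Encountered unexpected state ...")
]
```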