Wait for job completion feature #20

Merged · 4 commits · Jul 21, 2024
ray_provider/hooks/ray.py (1 change: 0 additions & 1 deletion)
@@ -186,7 +186,6 @@ def get_ray_job_logs(self, job_id: str) -> str:
"""
client = self.ray_client
logs = client.get_job_logs(job_id=job_id)
self.log.info(f"Logs for job {job_id}: {logs}")
return str(logs)

async def get_ray_tail_logs(self, job_id: str) -> AsyncIterator[str]:
ray_provider/operators/ray.py (70 changes: 42 additions & 28 deletions)
@@ -110,8 +110,11 @@ def _setup_load_balancer(self, name: str, namespace: str, context: Context) -> N
def execute(self, context: Context) -> None:
"""Execute the operator to set up the Ray cluster."""
try:
self.log.info("::group::Add KubeRay operator")
self.hook.install_kuberay_operator(version=self.kuberay_version)
self.log.info("::endgroup::")

self.log.info("::group::Create Ray Cluster")
self.log.info("Loading yaml content for Ray cluster CRD...")
cluster_spec = self.hook.load_yaml_content(self.ray_cluster_yaml)

@@ -123,11 +126,14 @@ def execute(self, context: Context) -> None:
group, version = api_version.split("/") if "/" in api_version else ("", api_version)

self._create_or_update_cluster(group, version, plural, name, namespace, cluster_spec)
self.log.info("::endgroup::")

if self.use_gpu:
self._setup_gpu_driver()

self.log.info("::group::Setup Load Balancer service")
self._setup_load_balancer(name, namespace, context)
self.log.info("::endgroup::")

except Exception as e:
self.log.error(f"Error setting up Ray cluster: {e}")
@@ -210,7 +216,9 @@ def execute(self, context: Context) -> None:
try:
if self.use_gpu:
self._delete_gpu_daemonset()
self.log.info("::group:: Delete Ray Cluster")
self._delete_ray_cluster()
self.log.info("::endgroup::")
self.hook.uninstall_kuberay_operator()
except Exception as e:
self.log.error(f"Error deleting Ray cluster: {e}")
@@ -231,7 +239,7 @@ class SubmitRayJob(BaseOperator):
:param num_gpus: Number of GPUs required for the job. Defaults to 0.
:param memory: Amount of memory required for the job. Defaults to 0.
:param resources: Additional resources required for the job. Defaults to None.
:param timeout: Maximum time to wait for job completion in seconds. Defaults to 600 seconds.
:param job_timeout_seconds: Maximum time to wait for job completion in seconds. Defaults to 600 seconds.
:param poll_interval: Interval between job status checks in seconds. Defaults to 60 seconds.
:param xcom_task_key: XCom key to retrieve dashboard URL. Defaults to None.
"""
@@ -248,7 +256,9 @@ def __init__(
num_gpus: int | float = 0,
memory: int | float = 0,
resources: dict[str, Any] | None = None,
timeout: int = 600,
fetch_logs: bool = True,
wait_for_completion: bool = True,
job_timeout_seconds: int = 600,
poll_interval: int = 60,
xcom_task_key: str | None = None,
**kwargs: Any,
@@ -261,7 +271,9 @@ def __init__(
self.num_gpus = num_gpus
self.memory = memory
self.ray_resources = resources
self.timeout = timeout
self.fetch_logs = fetch_logs
self.wait_for_completion = wait_for_completion
self.job_timeout_seconds = job_timeout_seconds
self.poll_interval = poll_interval
self.xcom_task_key = xcom_task_key
self.dashboard_url: str | None = None
@@ -303,29 +315,31 @@ def execute(self, context: Context) -> str:
)
self.log.info(f"Ray job submitted with id: {self.job_id}")

current_status = self.hook.get_ray_job_status(self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_state:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
timeout=timedelta(seconds=self.timeout),
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
poll_interval=self.poll_interval,
),
method_name="execute_complete",
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
elif current_status == JobStatus.FAILED:
raise AirflowException(f"Job failed:\n{self.job_id}")
elif current_status == JobStatus.STOPPED:
raise AirflowException(f"Job was cancelled:\n{self.job_id}")
else:
raise AirflowException(f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`")
if self.wait_for_completion:
current_status = self.hook.get_ray_job_status(self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_state:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
trigger=RayJobTrigger(
job_id=self.job_id,
conn_id=self.conn_id,
xcom_dashboard_url=self.dashboard_url,
poll_interval=self.poll_interval,
fetch_logs=self.fetch_logs,
),
method_name="execute_complete",
timeout=timedelta(seconds=self.job_timeout_seconds),
)
elif current_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
elif current_status == JobStatus.FAILED:
raise AirflowException(f"Job failed:\n{self.job_id}")
elif current_status == JobStatus.STOPPED:
raise AirflowException(f"Job was cancelled:\n{self.job_id}")
else:
raise AirflowException(f"Encountered unexpected state `{current_status}` for job_id `{self.job_id}`")

return self.job_id

@@ -337,10 +351,10 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
:param event: The event containing the job execution result.
:raises AirflowException: If the job execution fails or is cancelled.
"""
if event["status"] in ["error", "cancelled"]:
if event["status"] in [JobStatus.STOPPED, JobStatus.FAILED]:
self.log.info(f"Ray job {self.job_id} execution not completed...")
raise AirflowException(event["message"])
elif event["status"] == "success":
elif event["status"] == JobStatus.SUCCEEDED:
self.log.info(f"Ray job {self.job_id} execution succeeded ...")
else:
raise AirflowException(f"Unexpected event status: {event['status']}")
ray_provider/triggers/ray.py (58 changes: 28 additions & 30 deletions)
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from functools import cached_property
from functools import cached_property, partial
from typing import Any, AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent
@@ -24,11 +24,19 @@ class RayJobTrigger(BaseTrigger):
:param poll_interval: The interval in seconds at which to poll the job status. Defaults to 30 seconds.
"""

def __init__(self, job_id: str, conn_id: str, xcom_dashboard_url: str | None, poll_interval: int = 30):
def __init__(
self,
job_id: str,
conn_id: str,
xcom_dashboard_url: str | None,
poll_interval: int = 30,
fetch_logs: bool = True,
):
super().__init__() # type: ignore[no-untyped-call]
self.job_id = job_id
self.conn_id = conn_id
self.dashboard_url = xcom_dashboard_url
self.fetch_logs = fetch_logs
self.poll_interval = poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -43,6 +51,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"job_id": self.job_id,
"conn_id": self.conn_id,
"xcom_dashboard_url": self.dashboard_url,
"fetch_logs": self.fetch_logs,
"poll_interval": self.poll_interval,
},
)
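Serializing fetch_logs alongside the other arguments matters because the triggerer process rebuilds the trigger from exactly this classpath-plus-kwargs pair. The sketch below round-trips the trigger by hand; it assumes serialize() returns the trigger's dotted classpath as its first element (conventional for BaseTrigger subclasses, but not visible in this hunk), and the job and connection ids are placeholders.

from airflow.utils.module_loading import import_string

from ray_provider.triggers.ray import RayJobTrigger

trigger = RayJobTrigger(
    job_id="job_12345",      # placeholder job id
    conn_id="ray_conn",      # placeholder connection id
    xcom_dashboard_url=None,
    poll_interval=30,
    fetch_logs=False,
)
classpath, kwargs = trigger.serialize()
# Re-create the trigger the way the triggerer would, from its serialized form.
rehydrated = import_string(classpath)(**kwargs)
assert rehydrated.fetch_logs is False  # the new flag survives the round trip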
@@ -71,38 +80,27 @@ async def run(self) -> AsyncIterator[TriggerEvent]:
while not self._is_terminal_state():
await asyncio.sleep(self.poll_interval)

# Stream logs if available
async for multi_line in self.hook.get_ray_tail_logs(self.job_id):
self.log.info(multi_line)
self.log.info(f"Fetch logs flag is set to : {self.fetch_logs}")
if self.fetch_logs:
# Stream logs if available
loop = asyncio.get_event_loop()
logs = await loop.run_in_executor(None, partial(self.hook.get_ray_job_logs, job_id=self.job_id))
self.log.info(f"::group::{self.job_id} logs")
for log in logs.split("\n"):
self.log.info(log)
self.log.info("::endgroup::")

completed_status = self.hook.get_ray_job_status(self.job_id)
self.log.info(f"Status of completed job {self.job_id} is: {completed_status}")
if completed_status == JobStatus.SUCCEEDED:
yield TriggerEvent(
{
"status": "success",
"message": f"Job run {self.job_id} has completed successfully.",
"job_id": self.job_id,
}
)
elif completed_status == JobStatus.STOPPED:
yield TriggerEvent(
{
"status": "cancelled",
"message": f"Job run {self.job_id} has been stopped.",
"job_id": self.job_id,
}
)
else:
yield TriggerEvent(
{
"status": "error",
"message": f"Job run {self.job_id} has failed.",
"job_id": self.job_id,
}
)
yield TriggerEvent(
{
"status": completed_status,
"message": f"Job {self.job_id} completed with status {completed_status}",
"job_id": self.job_id,
}
)
except Exception as e:
yield TriggerEvent({"status": "error", "message": str(e), "job_id": self.job_id})
yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id})

def _is_terminal_state(self) -> bool:
"""
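Because get_ray_job_logs is a synchronous SDK call, the trigger now runs it on the default thread-pool executor instead of awaiting it directly, so the triggerer's event loop is not blocked while logs are fetched. A standalone sketch of that pattern follows; blocking_get_logs is a stand-in for the hook's synchronous call and is not part of this PR (the sketch also uses get_running_loop, whereas the trigger itself calls get_event_loop).

import asyncio
from functools import partial


def blocking_get_logs(job_id: str) -> str:
    """Stand-in for a synchronous SDK call such as get_ray_job_logs."""
    return f"logs for {job_id}\nsecond line"


async def stream_logs(job_id: str) -> None:
    loop = asyncio.get_running_loop()
    # Run the blocking call in the default thread-pool executor and await the result.
    logs = await loop.run_in_executor(None, partial(blocking_get_logs, job_id=job_id))
    for line in logs.split("\n"):
        print(line)


if __name__ == "__main__":
    asyncio.run(stream_logs("job_12345"))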
tests/operators/test_ray_operators.py (11 changes: 6 additions & 5 deletions)
@@ -2,6 +2,7 @@

import pytest
from airflow.exceptions import AirflowException, TaskDeferred
from ray.job_submission import JobStatus

from ray_provider.operators.ray import SubmitRayJob

@@ -13,7 +14,7 @@
num_gpus = 1
memory = 1024
resources = {"CPU": 2}
timeout = 600
job_timeout_seconds = 600
context = MagicMock()


@@ -27,7 +28,7 @@ def operator():
num_gpus=num_gpus,
memory=memory,
resources=resources,
timeout=timeout,
job_timeout_seconds=job_timeout_seconds,
task_id="Testcases",
)

@@ -42,7 +43,7 @@ def test_init(self, operator):
assert operator.num_gpus == num_gpus
assert operator.memory == memory
# assert operator.resources == resources
assert operator.timeout == timeout
assert operator.job_timeout_seconds == job_timeout_seconds

@patch("ray_provider.operators.ray.SubmitRayJob.hook")
def test_execute(self, mock_hook, operator):
@@ -65,13 +66,13 @@ def test_on_kill(self, mock_hook, operator):
mock_hook.delete_ray_job.assert_called_once_with("job_12345")

def test_execute_complete_success(self, operator):
event = {"status": "success", "message": "Job completed successfully"}
event = {"status": JobStatus.SUCCEEDED, "message": "Job completed successfully"}
operator.job_id = "job_12345"

assert operator.execute_complete(context, event) is None

def test_execute_complete_failure(self, operator):
event = {"status": "error", "message": "Job failed"}
event = {"status": JobStatus.FAILED, "message": "Job failed"}
operator.job_id = "job_12345"

with pytest.raises(AirflowException, match="Job failed"):
tests/triggers/test_ray_triggers.py (34 changes: 17 additions & 17 deletions)
@@ -2,38 +2,38 @@

import pytest
from airflow.triggers.base import TriggerEvent
from ray.dashboard.modules.job.sdk import JobStatus
from ray.job_submission import JobStatus

from ray_provider.triggers.ray import RayJobTrigger


class TestRayJobTrigger:

@pytest.mark.asyncio
@patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.ray.RayJobTrigger.hook")
async def test_run_no_job_id(self, mock_hook, mock_is_terminal):
mock_is_terminal.return_value = True
mock_hook.get_ray_job_status.return_value = JobStatus.FAILED
trigger = RayJobTrigger(job_id="", poll_interval=1, conn_id="test", xcom_dashboard_url="test")

generator = trigger.run()
event = await generator.asend(None)
assert event == TriggerEvent({"status": "error", "message": "Job run has failed.", "job_id": ""})
assert event == TriggerEvent(
{"status": JobStatus.FAILED, "message": "Job completed with status FAILED", "job_id": ""}
)

@pytest.mark.asyncio
@patch("ray_provider.triggers.ray.RayJobTrigger._is_terminal_state")
@patch("ray_provider.triggers.ray.RayJobTrigger.hook")
async def test_run_job_succeeded(self, mock_hook):
trigger = RayJobTrigger(job_id="test_job_id", poll_interval=1, conn_id="test", xcom_dashboard_url="test")

async def test_run_job_succeeded(self, mock_hook, mock_is_terminal):
mock_is_terminal.side_effect = [False, True]
mock_hook.get_ray_job_status.return_value = JobStatus.SUCCEEDED

trigger = RayJobTrigger(job_id="test_job_id", poll_interval=1, conn_id="test", xcom_dashboard_url="test")
generator = trigger.run()
async for event in generator:
assert event == TriggerEvent(
{
"status": "success",
"message": "Job run test_job_id has completed successfully.",
"job_id": "test_job_id",
}
)
break # Stop after the first event for testing purposes
event = await generator.asend(None)
assert event == TriggerEvent(
{
"status": JobStatus.SUCCEEDED,
"message": f"Job test_job_id completed with status {JobStatus.SUCCEEDED}",
"job_id": "test_job_id",
}
)