[do not merge] Implement dynamic configuration and much more #94

Draft · wants to merge 9 commits into main

Merge branch 'main' into issue-81-mess
tatiana authored Nov 29, 2024
commit 76aeb5be03c1aeb3ff0507221fea47828ed1d72f
30 changes: 16 additions & 14 deletions dev/dags/ray_dynamic_config.py
@@ -9,22 +9,21 @@

The print_context tasks in the downstream DAGs output the received context to the logs.
"""
from pathlib import Path

import re
from pathlib import Path

import yaml
from airflow import DAG
from airflow.decorators import dag, task
from airflow.decorators import task
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.utils.context import Context
from airflow.utils.dates import days_ago
from jinja2 import Template
import yaml

from ray_provider.decorators import ray


CONN_ID = "ray_conn"
RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml"
FOLDER_PATH = Path(__file__).parent / "ray_scripts"
@@ -44,7 +43,7 @@ def slugify(value):
"""
Replace invalid characters with hyphens and make lowercase.
"""
return re.sub(r'[^\w\-\.]', '-', value).lower()
return re.sub(r"[^\w\-\.]", "-", value).lower()


def create_config_from_context(context, **kwargs):
@@ -54,11 +53,13 @@
raycluster_name = Template(raycluster_name_template).render(context).replace("_", "-")
raycluster_name = slugify(raycluster_name)

raycluster_k8s_yml_filename_template = context.get("dag_run").conf.get("raycluster_k8s_yml_filename", default_name + ".yml")
raycluster_k8s_yml_filename_template = context.get("dag_run").conf.get(
"raycluster_k8s_yml_filename", default_name + ".yml"
)
raycluster_k8s_yml_filename = Template(raycluster_k8s_yml_filename_template).render(context).replace("_", "-")
raycluster_k8s_yml_filename = slugify(raycluster_k8s_yml_filename)

with open(RAY_SPEC, "r") as file:
with open(RAY_SPEC) as file:
data = yaml.safe_load(file)
data["metadata"]["name"] = raycluster_name

@@ -75,7 +76,9 @@ def print_context(**kwargs):
# Retrieve `conf` passed from the parent DAG
print(kwargs)
cluster_name = kwargs.get("dag_run").conf.get("raycluster_name", "No ray cluster name provided")
raycluster_k8s_yml_filename = kwargs.get("dag_run").conf.get("raycluster_k8s_yml_filename", "No ray cluster YML filename provided")
raycluster_k8s_yml_filename = kwargs.get("dag_run").conf.get(
"raycluster_k8s_yml_filename", "No ray cluster YML filename provided"
)
print(f"Received cluster name: {cluster_name}")
print(f"Received cluster K8s YML filename: {raycluster_k8s_yml_filename}")

@@ -153,7 +156,6 @@ def square(x):
print(f"Mean of this population is {mean}")
return mean


data = generate_data()
process_data_with_ray(data)

@@ -172,7 +174,7 @@ def square(x):
trigger_dag_id="ray_dynamic_config_child_1",
conf={
"raycluster_name": "first-{{ dag_run.id }}",
"raycluster_k8s_yml_filename": "first-{{ dag_run.id }}.yaml"
"raycluster_k8s_yml_filename": "first-{{ dag_run.id }}.yaml",
},
)

@@ -184,12 +186,12 @@ def square(x):

# Illustrates that by default two runs of the same DAG will use different Ray clusters
# Disabled because on a local macOS dev environment we only manage to spin up two Ray cluster services concurrently
#trigger_dag_3 = TriggerDagRunOperator(
# trigger_dag_3 = TriggerDagRunOperator(
# task_id="trigger_downstream_dag_3",
# trigger_dag_id="ray_dynamic_config_child_2",
# conf={},
#)
# )

empty_task >> trigger_dag_1
trigger_dag_1 >> trigger_dag_2
#trigger_dag_1 >> trigger_dag_3
# trigger_dag_1 >> trigger_dag_3
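
The dynamic naming in the DAG above can be exercised in isolation. Below is a minimal sketch, assuming a hypothetical `dag_run` stand-in, of how a name template from `dag_run.conf` is rendered with Jinja and then slugified into a valid Kubernetes resource name:

```python
import re

from jinja2 import Template


def slugify(value):
    # Replace characters that are invalid in Kubernetes resource names.
    return re.sub(r"[^\w\-\.]", "-", value).lower()


class FakeDagRun:
    # Hypothetical stand-in for the DagRun object exposed in the Airflow context.
    id = "manual__2024-11-29T00:00:00"


context = {"dag_run": FakeDagRun()}
name = Template("first-{{ dag_run.id }}").render(context).replace("_", "-")
print(slugify(name))  # first-manual--2024-11-29t00-00-00
```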
2 changes: 1 addition & 1 deletion ray_provider/constants.py
@@ -1,4 +1,4 @@
from ray.job_submission import JobStatus


TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}
TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}
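
For context, a hedged sketch of how a terminal-state set like this is typically consumed; the `is_terminal` helper is illustrative, not part of the provider:

```python
from ray.job_submission import JobStatus

TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}


def is_terminal(status: JobStatus) -> bool:
    # A job in a terminal state will never change status again,
    # so polling (or deferring to a trigger) can stop here.
    return status in TERMINAL_JOB_STATUSES
```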
61 changes: 28 additions & 33 deletions ray_provider/decorators.py
@@ -3,12 +3,10 @@
import inspect
import os
import re
import shutil
import tempfile
import textwrap
from datetime import timedelta
from pathlib import Path
from tempfile import mkdtemp
from typing import Any, Callable

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
@@ -31,54 +29,49 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):
"""

custom_operator_name = "@task.ray"
_config: None | dict[str, Any] | Callable[..., dict[str, Any]] = None
_config: dict[str, Any] | Callable[..., dict[str, Any]] = dict()

template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs")

def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None:
self._config = config
self.kwargs = kwargs
super().__init__(
conn_id="",
entrypoint="python script.py",
runtime_env={},
**kwargs
)

def _build_config(self, context: Context) -> dict:
if isinstance(self._config, Callable):
return self._build_config_from_callable(context)
super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs)

def _build_config(self, context: Context) -> dict[str, Any]:
if callable(self._config):
config_params = inspect.signature(self._config).parameters
config_kwargs = {k: v for k, v in self.kwargs.items() if k in config_params and k != "context"}
if "context" in config_params:
config_kwargs["context"] = context
config = self._config(**config_kwargs)
assert isinstance(config, dict)
return config
return self._config

def _build_config_from_callable(self, context: Context) -> dict[str, Any]:
config_params = inspect.signature(self._config).parameters

config_kwargs = {k: v for k, v in self.kwargs.items() if k in config_params and k != "context"}

if "context" in config_params:
config_kwargs["context"] = context

# Call config with the prepared arguments
return self._config(**config_kwargs)

def _load_config(self, config: dict) -> None:
def _load_config(self, config: dict[str, Any]) -> None:
self.conn_id: str = config.get("conn_id", "")
self.is_decorated_function = False if "entrypoint" in config else True
self.entrypoint: str = config.get("entrypoint", "python script.py")
self.runtime_env: dict[str, Any] = config.get("runtime_env", {})

self.num_cpus: int | float = config.get("num_cpus", 1)
self.num_gpus: int | float = config.get("num_gpus", 0)
self.memory: int | float = config.get("memory")
self.memory: int | float = config.get("memory", 1)
self.ray_resources: dict[str, Any] | None = config.get("resources")
self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml")
self.update_if_exists: bool = config.get("update_if_exists", False)
self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
self.gpu_device_plugin_yaml: str = config.get("gpu_device_plugin_yaml")
self.gpu_device_plugin_yaml: str = config.get(
"gpu_device_plugin_yaml",
"https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
job_timeout_seconds = config.get("job_timeout_seconds", 600)
self.job_timeout_seconds: int = timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None
self.job_timeout_seconds: timedelta | None = (
timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None
)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key")

@@ -87,7 +80,7 @@ def _load_config(self, config: dict) -> None:
if not isinstance(self.num_cpus, (int, float)):
raise RayAirflowException("num_cpus should be an integer or float value")
if not isinstance(self.num_gpus, (int, float)):
raise TypeError("num_gpus should be an integer or float value")
raise RayAirflowException("num_gpus should be an integer or float value")

def execute(self, context: Context) -> Any:
"""
@@ -139,8 +132,8 @@ def _extract_function_body(self, source: str) -> str:
if "@ray.task" not in source:
raise RayAirflowException("Unable to parse this body. Expects the `@ray.task` decorator.")
lines = source.split("\n")
# TODO: This approach is extremely hacky. Review it.
# It feels a mistake to have a user-facing module named the same as the offical ray SDK
# TODO: Review the current approach, which is quite hacky.
# It feels like a mistake to have a user-facing module named the same as the official ray SDK.
# In particular, the decorator works in a very artificial way, where `ray` means two different things
# in the scope of the task definition (Astro Ray Provider decorator) and inside the decorated method (Ray SDK).
# Find the line where the ray.task decorator is
@@ -150,17 +143,19 @@

# Include everything except the ray.task decorator line
body = "\n".join(lines[:ray_task_line] + lines[ray_task_line + 1 :])
self.log.info("Ray job that is going to be executed: \m %s", body)

if not body:
raise RayAirflowException("Failed to extract Ray pipeline code decorated with @ray.task")
# Dedent the body
return textwrap.dedent(body)


class ray:
@staticmethod
def task(
python_callable: Callable[..., Any] | None = None,
multiple_outputs: bool | None = None,
config: dict[str, Any] | Callable[[], dict[str, Any]] | None = None,
config: dict[str, Any] | Callable[[], dict[str, Any]] | None = None,
**kwargs: Any,
) -> TaskDecorator:
"""
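
With the `_build_config` change above, `config` may be a plain dict or a callable whose parameters are filled from the decorator kwargs, with the Airflow context injected when the callable declares a `context` parameter. A minimal sketch of that dispatch, using illustrative names outside the provider:

```python
import inspect
from typing import Any, Callable


def build_config(
    config: dict[str, Any] | Callable[..., dict[str, Any]],
    kwargs: dict[str, Any],
    context: dict[str, Any],
) -> dict[str, Any]:
    if callable(config):
        params = inspect.signature(config).parameters
        # Forward only the kwargs the callable actually declares.
        config_kwargs = {k: v for k, v in kwargs.items() if k in params and k != "context"}
        if "context" in params:
            config_kwargs["context"] = context
        return config(**config_kwargs)
    return config


def dynamic_config(context: dict[str, Any]) -> dict[str, Any]:
    # Hypothetical callable config: one Ray cluster spec per DAG run.
    return {"conn_id": "ray_conn", "ray_cluster_yaml": f"/tmp/{context['run_id']}.yml"}


print(build_config(dynamic_config, {}, {"run_id": "manual__2024-11-29"}))
# {'conn_id': 'ray_conn', 'ray_cluster_yaml': '/tmp/manual__2024-11-29.yml'}
```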
17 changes: 11 additions & 6 deletions ray_provider/operators.py
@@ -55,15 +55,17 @@ def execute(self, context: Context) -> None:

:param context: The context in which the operator is being executed.
"""
self.log.info("Trying to setup ray cluster")
self.log.info(f"Trying to setup the ray cluster defined in {self.ray_cluster_yaml}")

self.hook.setup_ray_cluster(
context=context,
ray_cluster_yaml=self.ray_cluster_yaml,
kuberay_version=self.kuberay_version,
gpu_device_plugin_yaml=self.gpu_device_plugin_yaml,
update_if_exists=self.update_if_exists,
)
self.log.info("Finished setting up the ray cluster")

self.log.info("Finished setting up the ray cluster.")


class DeleteRayCluster(BaseOperator):
@@ -274,7 +276,6 @@ def execute(self, context: Context) -> str:
:return: The job ID of the submitted Ray job.
"""

#try:
self.log.info("::group:: (SubmitJob 1/5) Setup Cluster")
self._setup_cluster(context=context)
self.log.info("::endgroup::")
@@ -284,7 +285,9 @@
self.log.info("::endgroup::")

self.log.info("::group:: (SubmitJob 3/5) Submit job")

self.log.info(f"Ray job submitted with id: {self.job_id}")

self.job_id = self.hook.submit_ray_job(
dashboard_url=self.dashboard_url,
entrypoint=self.entrypoint,
@@ -301,7 +304,7 @@
current_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Current job status for {self.job_id} is: {current_status}")

if current_status not in self.terminal_states:
if current_status not in TERMINAL_JOB_STATUSES:
self.log.info("Deferring the polling to RayJobTrigger...")
self.defer(
trigger=RayJobTrigger(
@@ -317,6 +320,8 @@ def execute(self, context: Context) -> str:
timeout=self.job_timeout_seconds,
)

return self.job_id

def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
"""
Handle the completion of a deferred Ray job execution.
@@ -336,7 +341,7 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
job_status = event["status"]
if job_status == JobStatus.SUCCEEDED:
self.log.info("Job %s completed successfully", self.job_id)
return self.job_id
return
else:
self.log.info(f"Ray job {self.job_id} execution not completed successfully...")
if job_status in (JobStatus.FAILED, JobStatus.STOPPED):
@@ -346,4 +351,4 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:

self.log.info("::endgroup::")

raise AirflowException(msg)
raise RayAirflowException(msg)
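
The `execute`/`execute_complete` pair above follows Airflow's deferral round-trip: `execute` submits the job and, if it is not already terminal, defers to `RayJobTrigger`; the scheduler later calls `execute_complete` with the trigger's event payload. A minimal sketch of that completion handling, assuming the event shape yielded by the trigger:

```python
from ray.job_submission import JobStatus


def handle_completion(event: dict) -> None:
    # `event` mirrors the TriggerEvent payload: status, message, job_id.
    if event["status"] == JobStatus.SUCCEEDED:
        print(f"Job {event['job_id']} completed successfully")
        return
    # FAILED, STOPPED, or the trigger's "EXCEPTION" status all surface as errors.
    raise RuntimeError(f"Job {event['job_id']} did not succeed: {event['message']}")


handle_completion({"status": JobStatus.SUCCEEDED, "message": "ok", "job_id": "raysubmit_123"})
```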
68 changes: 29 additions & 39 deletions ray_provider/triggers.py
@@ -121,45 +121,35 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

:yield: TriggerEvent containing the status, message, and job ID related to the job.
"""
# This is used indirectly when the Ray decorator is used.
# If not imported, DAGs that use the Ray decorator fail when triggered

self.log.info(f"::group:: Trigger 1/2: Checking the job status")
self.log.info(f"Polling for job {self.job_id} every {self.poll_interval} seconds...")

tasks = [self._poll_status()]
if self.fetch_logs:
tasks.append(self._stream_logs())
self.log.info(f"::endgroup::")
await asyncio.gather(*tasks)

self.log.info(f"::group:: Trigger 2/2: Job reached a terminal state")
completed_status = self.hook.get_ray_job_status(self.dashboard_url, self.job_id)
self.log.info(f"Status of completed job {self.job_id} is: {completed_status}")
self.log.info(f"::endgroup::")

yield TriggerEvent(
{
"status": completed_status,
"message": f"Job {self.job_id} completed with status {completed_status}",
"job_id": self.job_id,
}
)
#except Exception as e:
# self.log.error(f"Error occurred: {str(e)}")
# await self.cleanup()
# yield TriggerEvent({"status": str(JobStatus.FAILED), "message": str(e), "job_id": self.job_id})

def _is_terminal_state(self) -> bool:
"""
Checks if the Ray job is in a terminal state.

A terminal state is one of the following: SUCCEEDED, STOPPED, or FAILED.

:return: True if the job is in a terminal state, False otherwise.
"""
return self.hook.get_ray_job_status(self.dashboard_url, self.job_id) in (
JobStatus.SUCCEEDED,
JobStatus.STOPPED,
JobStatus.FAILED,
)
try:
tasks = [self._poll_status()]
if self.fetch_logs:
tasks.append(self._stream_logs())
await asyncio.gather(*tasks)
except ApiException as e:
error_msg = str(e)
self.log.info(f"::endgroup::")
self.log.error("::group:: Trigger unable to poll job status")
self.log.error("Exception details:", exc_info=True)
self.log.info("Attempting to clean up...")
await self.cleanup()
self.log.info("Cleanup completed!")
self.log.info(f"::endgroup::")

yield TriggerEvent({"status": "EXCEPTION", "message": error_msg, "job_id": self.job_id})
else:
self.log.info(f"::endgroup::")
self.log.info(f"::group:: Trigger 2/2: Job reached a terminal state")
self.log.info(f"Status of completed job {self.job_id} is: {self._job_status}")
self.log.info(f"::endgroup::")

yield TriggerEvent(
{
"status": self._job_status,
"message": f"Job {self.job_id} completed with status {self._job_status}",
"job_id": self.job_id,
}
)
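
The restructured `run()` gathers status polling and log streaming concurrently, and converts an `ApiException` into an `EXCEPTION` event instead of letting the trigger crash. A runnable sketch of that pattern, with stand-in coroutines in place of the hook calls:

```python
import asyncio


async def poll_status() -> str:
    # Stand-in for repeated get_ray_job_status() calls until a terminal state.
    await asyncio.sleep(0.1)
    return "SUCCEEDED"


async def stream_logs() -> None:
    # Stand-in for tailing the Ray job logs while the job runs.
    await asyncio.sleep(0.05)


async def run() -> dict:
    try:
        status, _ = await asyncio.gather(poll_status(), stream_logs())
    except Exception as exc:  # the real trigger narrows this to ApiException
        return {"status": "EXCEPTION", "message": str(exc)}
    return {"status": status, "message": f"Job completed with status {status}"}


print(asyncio.run(run()))
```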