Skip to content

Commit

Permalink
Merge branch 'main' into issue-81-mess
Browse files Browse the repository at this point in the history
  • Loading branch information
tatiana authored Nov 29, 2024
2 parents 206e57d + 56387cc commit 76aeb5b
Show file tree
Hide file tree
Showing 14 changed files with 375 additions and 332 deletions.
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ setup-dev: ## Setup development environment

.PHONY: build-whl
build-whl: setup-dev ## Build installable whl file
# Delete any previous wheels, so different versions don't conflict
rm dev/include/*
cd dev
python3 -m build --outdir dev/include/

Expand Down
30 changes: 16 additions & 14 deletions dev/dags/ray_dynamic_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,21 @@
The print_context tasks in the downstream DAGs output the received context to the logs.
"""
from pathlib import Path

import re
from pathlib import Path

import yaml
from airflow import DAG
from airflow.decorators import dag, task
from airflow.decorators import task
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from airflow.utils.context import Context
from airflow.utils.dates import days_ago
from jinja2 import Template
import yaml

from ray_provider.decorators import ray


CONN_ID = "ray_conn"
RAY_SPEC = Path(__file__).parent / "scripts/ray.yaml"
FOLDER_PATH = Path(__file__).parent / "ray_scripts"
Expand All @@ -44,7 +43,7 @@ def slugify(value):
"""
Replace invalid characters with hyphens and make lowercase.
"""
return re.sub(r'[^\w\-\.]', '-', value).lower()
return re.sub(r"[^\w\-\.]", "-", value).lower()


def create_config_from_context(context, **kwargs):
Expand All @@ -54,11 +53,13 @@ def create_config_from_context(context, **kwargs):
raycluster_name = Template(raycluster_name_template).render(context).replace("_", "-")
raycluster_name = slugify(raycluster_name)

raycluster_k8s_yml_filename_template = context.get("dag_run").conf.get("raycluster_k8s_yml_filename", default_name + ".yml")
raycluster_k8s_yml_filename_template = context.get("dag_run").conf.get(
"raycluster_k8s_yml_filename", default_name + ".yml"
)
raycluster_k8s_yml_filename = Template(raycluster_k8s_yml_filename_template).render(context).replace("_", "-")
raycluster_k8s_yml_filename = slugify(raycluster_k8s_yml_filename)

with open(RAY_SPEC, "r") as file:
with open(RAY_SPEC) as file:
data = yaml.safe_load(file)
data["metadata"]["name"] = raycluster_name

Expand All @@ -75,7 +76,9 @@ def print_context(**kwargs):
# Retrieve `conf` passed from the parent DAG
print(kwargs)
cluster_name = kwargs.get("dag_run").conf.get("raycluster_name", "No ray cluster name provided")
raycluster_k8s_yml_filename = kwargs.get("dag_run").conf.get("raycluster_k8s_yml_filename", "No ray cluster YML filename provided")
raycluster_k8s_yml_filename = kwargs.get("dag_run").conf.get(
"raycluster_k8s_yml_filename", "No ray cluster YML filename provided"
)
print(f"Received cluster name: {cluster_name}")
print(f"Received cluster K8s YML filename: {raycluster_k8s_yml_filename}")

Expand Down Expand Up @@ -153,7 +156,6 @@ def square(x):
print(f"Mean of this population is {mean}")
return mean


data = generate_data()
process_data_with_ray(data)

Expand All @@ -172,7 +174,7 @@ def square(x):
trigger_dag_id="ray_dynamic_config_child_1",
conf={
"raycluster_name": "first-{{ dag_run.id }}",
"raycluster_k8s_yml_filename": "first-{{ dag_run.id }}.yaml"
"raycluster_k8s_yml_filename": "first-{{ dag_run.id }}.yaml",
},
)

Expand All @@ -184,12 +186,12 @@ def square(x):

# Illustrates that by default two DAG runs of the same DAG will be using different Ray clusters
# Disabled because, in local development on macOS, we only manage to spin up two Ray Cluster services concurrently
#trigger_dag_3 = TriggerDagRunOperator(
# trigger_dag_3 = TriggerDagRunOperator(
# task_id="trigger_downstream_dag_3",
# trigger_dag_id="ray_dynamic_config_child_2",
# conf={},
#)
# )

empty_task >> trigger_dag_1
trigger_dag_1 >> trigger_dag_2
#trigger_dag_1 >> trigger_dag_3
# trigger_dag_1 >> trigger_dag_3
2 changes: 1 addition & 1 deletion dev/tests/dags/test_dag_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ def test_dag_retries(dag_id, dag, fileloc):
"""
test if a DAG has retries set
"""
assert dag.default_args.get("retries", None) >= 2, f"{dag_id} in {fileloc} must have task retries >= 2."
assert dag.default_args.get("retries", 2) >= 2, f"{dag_id} in {fileloc} must have task retries >= 2."
2 changes: 1 addition & 1 deletion ray_provider/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ray.job_submission import JobStatus


TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}
TERMINAL_JOB_STATUSES = {JobStatus.SUCCEEDED, JobStatus.STOPPED, JobStatus.FAILED}
72 changes: 35 additions & 37 deletions ray_provider/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
import inspect
import os
import re
import shutil
import tempfile
import textwrap
from datetime import timedelta
from pathlib import Path
from tempfile import mkdtemp
from typing import Any, Callable

from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
from airflow.exceptions import AirflowException
from airflow.utils.context import Context

from ray_provider.exceptions import RayAirflowException
from ray_provider.operators import SubmitRayJob


Expand All @@ -31,63 +29,58 @@ class _RayDecoratedOperator(DecoratedOperator, SubmitRayJob):
"""

custom_operator_name = "@task.ray"
_config: None | dict[str, Any] | Callable[..., dict[str, Any]] = None
_config: dict[str, Any] | Callable[..., dict[str, Any]] = dict()

template_fields: Any = (*SubmitRayJob.template_fields, "op_args", "op_kwargs")

def __init__(self, config: dict[str, Any] | Callable[..., dict[str, Any]], **kwargs: Any) -> None:
self._config = config
self.kwargs = kwargs
super().__init__(
conn_id="",
entrypoint="python script.py",
runtime_env={},
**kwargs
)

def _build_config(self, context: Context) -> dict:
if isinstance(self._config, Callable):
return self._build_config_from_callable(context)
super().__init__(conn_id="", entrypoint="python script.py", runtime_env={}, **kwargs)

def _build_config(self, context: Context) -> dict[str, Any]:
if callable(self._config):
config_params = inspect.signature(self._config).parameters
config_kwargs = {k: v for k, v in self.kwargs.items() if k in config_params and k != "context"}
if "context" in config_params:
config_kwargs["context"] = context
config = self._config(**config_kwargs)
assert isinstance(config, dict)
return config
return self._config

def _build_config_from_callable(self, context: Context) -> dict[str, Any]:
config_params = inspect.signature(self._config).parameters

config_kwargs = {k: v for k, v in self.kwargs.items() if k in config_params and k != "context"}

if "context" in config_params:
config_kwargs["context"] = context

# Call config with the prepared arguments
return self._config(**config_kwargs)

def _load_config(self, config: dict) -> None:
def _load_config(self, config: dict[str, Any]) -> None:
self.conn_id: str = config.get("conn_id", "")
self.is_decorated_function = False if "entrypoint" in config else True
self.entrypoint: str = config.get("entrypoint", "python script.py")
self.runtime_env: dict[str, Any] = config.get("runtime_env", {})

self.num_cpus: int | float = config.get("num_cpus", 1)
self.num_gpus: int | float = config.get("num_gpus", 0)
self.memory: int | float = config.get("memory")
self.memory: int | float = config.get("memory", 1)
self.ray_resources: dict[str, Any] | None = config.get("resources")
self.ray_cluster_yaml: str | None = config.get("ray_cluster_yaml")
self.update_if_exists: bool = config.get("update_if_exists", False)
self.kuberay_version: str = config.get("kuberay_version", "1.0.0")
self.gpu_device_plugin_yaml: str = config.get("gpu_device_plugin_yaml")
self.gpu_device_plugin_yaml: str = config.get(
"gpu_device_plugin_yaml",
"https://raw.githubusercontent.com/NVIDIA/k8s-device-plugin/v0.9.0/nvidia-device-plugin.yml",
)
self.fetch_logs: bool = config.get("fetch_logs", True)
self.wait_for_completion: bool = config.get("wait_for_completion", True)
job_timeout_seconds = config.get("job_timeout_seconds", 600)
self.job_timeout_seconds: int = timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None
self.job_timeout_seconds: timedelta | None = (
timedelta(seconds=job_timeout_seconds) if job_timeout_seconds > 0 else None
)
self.poll_interval: int = config.get("poll_interval", 60)
self.xcom_task_key: str | None = config.get("xcom_task_key")

self.config = config

if not isinstance(self.num_cpus, (int, float)):
raise TypeError("num_cpus should be an integer or float value")
raise RayAirflowException("num_cpus should be an integer or float value")
if not isinstance(self.num_gpus, (int, float)):
raise TypeError("num_gpus should be an integer or float value")
raise RayAirflowException("num_gpus should be an integer or float value")

def execute(self, context: Context) -> Any:
"""
Expand All @@ -112,8 +105,6 @@ def execute(self, context: Context) -> Any:
# Get the Python source code and extract just the function body
full_source = inspect.getsource(self.python_callable)
function_body = self._extract_function_body(full_source)
if not function_body:
raise ValueError("Failed to retrieve Python source code")

# Prepare the function call
args_str = ", ".join(repr(arg) for arg in self.op_args)
Expand All @@ -137,27 +128,34 @@ def execute(self, context: Context) -> Any:

def _extract_function_body(self, source: str) -> str:
"""Extract the function, excluding only the ray.task decorator."""
self.log.info(r"Ray pipeline intended to be executed: \n %s", source)
if "@ray.task" not in source:
raise RayAirflowException("Unable to parse this body. Expects the `@ray.task` decorator.")
lines = source.split("\n")
# TODO: This approach is extremely hacky. Review it.
# It feels a mistake to have a user-facing module named the same as the offical ray SDK
# TODO: Review the current approach, which is quite hacky.
# It feels like a mistake to have a user-facing module named the same as the official ray SDK.
# In particular, the decorator is working in a very artificial way, where ray means two different things
# at the scope of the task definition (Astro Ray Provider decorator) and inside the decorated method (Ray SDK)
# Find the line where the ray.task decorator is
# Additionally, if users import the ray decorator under an alias (e.g. "from ray_provider.decorators import ray as ray_decorator"),
# the following lookup will stop working.
ray_task_line = next((i for i, line in enumerate(lines) if re.match(r"^\s*@ray\.task", line.strip())), -1)

# Include everything except the ray.task decorator line
body = "\n".join(lines[:ray_task_line] + lines[ray_task_line + 1 :])
self.log.info("Ray job that is going to be executed: \m %s", body)

if not body:
raise RayAirflowException("Failed to extract Ray pipeline code decorated with @ray.task")
# Dedent the body
return textwrap.dedent(body)


class ray:
@staticmethod
def task(
python_callable: Callable[..., Any] | None = None,
multiple_outputs: bool | None = None,
config: dict[str, Any] | Callable[[], dict[str, Any]] | None = None,
config: dict[str, Any] | Callable[[], dict[str, Any]] | None = None,
**kwargs: Any,
) -> TaskDecorator:
"""
Expand Down
2 changes: 2 additions & 0 deletions ray_provider/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class RayAirflowException(Exception):
    """Error raised by the Ray provider for provider-specific failures.

    Within this change set it is raised by ``ray_provider.decorators`` when
    ``num_cpus`` / ``num_gpus`` are not numeric and when the source of a
    function decorated with ``@ray.task`` cannot be parsed or extracted.
    """
Loading

0 comments on commit 76aeb5b

Please sign in to comment.