Add DatabricksWorkflowTaskGroup #39771

Merged
merged 10 commits on May 30, 2024
18 changes: 18 additions & 0 deletions airflow/providers/databricks/hooks/databricks.py
@@ -29,6 +29,7 @@
from __future__ import annotations

import json
from enum import Enum
from typing import Any

from requests import exceptions as requests_exceptions
@@ -63,6 +64,23 @@
SPARK_VERSIONS_ENDPOINT = ("GET", "api/2.0/clusters/spark-versions")


class RunLifeCycleState(Enum):
"""Enum for the run life cycle state concept of Databricks runs.

See more information at: https://docs.databricks.com/api/azure/workspace/jobs/listruns#runs-state-life_cycle_state
"""

BLOCKED = "BLOCKED"
INTERNAL_ERROR = "INTERNAL_ERROR"
PENDING = "PENDING"
QUEUED = "QUEUED"
RUNNING = "RUNNING"
SKIPPED = "SKIPPED"
TERMINATED = "TERMINATED"
TERMINATING = "TERMINATING"
WAITING_FOR_RETRY = "WAITING_FOR_RETRY"
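
For illustration only (not part of the diff): a minimal sketch of comparing the raw `life_cycle_state` string returned by the Databricks Jobs API against this enum instead of a string literal. The sample payload is hypothetical.

from airflow.providers.databricks.hooks.databricks import RunLifeCycleState

# Hypothetical "state" block from a Jobs API runs/get response.
state = {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS", "state_message": ""}
if state["life_cycle_state"] == RunLifeCycleState.TERMINATED.value:
    print("run reached a terminal state")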


class RunState:
"""Utility class for the run state concept of Databricks runs."""

141 changes: 116 additions & 25 deletions airflow/providers/databricks/operators/databricks.py
@@ -29,13 +29,18 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.models import BaseOperator, BaseOperatorLink, XCom
from airflow.providers.databricks.hooks.databricks import DatabricksHook, RunLifeCycleState, RunState
from airflow.providers.databricks.operators.databricks_workflow import (
DatabricksWorkflowTaskGroup,
WorkflowRunMetadata,
)
from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger
from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event

if TYPE_CHECKING:
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.context import Context
from airflow.utils.task_group import TaskGroup

DEFER_METHOD_NAME = "execute_complete"
XCOM_RUN_ID_KEY = "run_id"
@@ -926,7 +931,10 @@ class DatabricksNotebookOperator(BaseOperator):
:param deferrable: Run operator in the deferrable mode.
"""

template_fields = ("notebook_params",)
template_fields = (
"notebook_params",
"workflow_run_metadata",
)
CALLER = "DatabricksNotebookOperator"

def __init__(
@@ -944,6 +952,7 @@ def __init__(
databricks_retry_args: dict[Any, Any] | None = None,
wait_for_termination: bool = True,
databricks_conn_id: str = "databricks_default",
workflow_run_metadata: dict | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs: Any,
):
@@ -962,6 +971,10 @@ def __init__(
self.databricks_conn_id = databricks_conn_id
self.databricks_run_id: int | None = None
self.deferrable = deferrable

# This is used to store the metadata of the Databricks job run when the job is launched from within a DatabricksWorkflowTaskGroup.
self.workflow_run_metadata: dict | None = workflow_run_metadata

super().__init__(**kwargs)

@cached_property
@@ -1016,6 +1029,79 @@ def _get_databricks_task_id(self, task_id: str) -> str:
"""Get the databricks task ID using dag_id and task_id. Removes illegal characters."""
return f"{self.dag_id}__{task_id.replace('.', '__')}"

@property
def _databricks_workflow_task_group(self) -> DatabricksWorkflowTaskGroup | None:
"""
Traverse up parent TaskGroups until the `is_databricks` flag associated with the root DatabricksWorkflowTaskGroup is found.

If found, returns the task group. Otherwise, return None.
"""
parent_tg: TaskGroup | DatabricksWorkflowTaskGroup | None = self.task_group

while parent_tg:
if getattr(parent_tg, "is_databricks", False):
return parent_tg # type: ignore[return-value]

if getattr(parent_tg, "task_group", None):
parent_tg = parent_tg.task_group
else:
return None

return None

def _extend_workflow_notebook_packages(
self, databricks_workflow_task_group: DatabricksWorkflowTaskGroup
) -> None:
"""Extend the task group packages into the notebook's packages, without adding any duplicates."""
for task_group_package in databricks_workflow_task_group.notebook_packages:
exists = any(
task_group_package == existing_package for existing_package in self.notebook_packages
)
if not exists:
self.notebook_packages.append(task_group_package)
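
For illustration, a small sketch of the deduplication described above, using plain lists and hypothetical package specs in the Databricks libraries format; it mirrors the loop in `_extend_workflow_notebook_packages`.

# The operator already requests pandas; the task group contributes pandas and simplejson.
notebook_packages = [{"pypi": {"package": "pandas"}}]
task_group_packages = [{"pypi": {"package": "pandas"}}, {"pypi": {"package": "simplejson"}}]

for pkg in task_group_packages:
    if pkg not in notebook_packages:  # value equality on dicts, so the identical spec is skipped
        notebook_packages.append(pkg)

# notebook_packages is now:
# [{"pypi": {"package": "pandas"}}, {"pypi": {"package": "simplejson"}}]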

def _convert_to_databricks_workflow_task(
self, relevant_upstreams: list[BaseOperator], context: Context | None = None
) -> dict[str, object]:
"""Convert the operator to a Databricks workflow task that can be a task in a workflow."""
databricks_workflow_task_group = self._databricks_workflow_task_group
if not databricks_workflow_task_group:
raise AirflowException(
"Calling `_convert_to_databricks_workflow_task` without a parent TaskGroup."
)

if hasattr(databricks_workflow_task_group, "notebook_packages"):
self._extend_workflow_notebook_packages(databricks_workflow_task_group)

if hasattr(databricks_workflow_task_group, "notebook_params"):
self.notebook_params = {
**self.notebook_params,
**databricks_workflow_task_group.notebook_params,
}

base_task_json = self._get_task_base_json()
result = {
"task_key": self._get_databricks_task_id(self.task_id),
"depends_on": [
{"task_key": self._get_databricks_task_id(task_id)}
for task_id in self.upstream_task_ids
if task_id in relevant_upstreams
],
**base_task_json,
}

if self.existing_cluster_id and self.job_cluster_key:
raise ValueError(
"Both existing_cluster_id and job_cluster_key are set. Only one can be set per task."
)

if self.existing_cluster_id:
result["existing_cluster_id"] = self.existing_cluster_id
elif self.job_cluster_key:
result["job_cluster_key"] = self.job_cluster_key

return result
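
To make the return shape concrete, a hedged sketch of the dictionary this method can produce for a notebook task that runs on a shared job cluster. Task ids, paths, parameters, and the cluster key are hypothetical, and the notebook_task/libraries entries come from `_get_task_base_json`, so treat the exact contents as an approximation rather than the definitive payload.

example_workflow_task = {
    "task_key": "example_dag__workflow__notebook_task",
    "depends_on": [{"task_key": "example_dag__workflow__launch"}],
    "notebook_task": {
        "notebook_path": "/Shared/example_notebook",
        "source": "WORKSPACE",
        "base_parameters": {"run_date": "2024-05-30"},
    },
    "libraries": [{"pypi": {"package": "simplejson"}}],
    "job_cluster_key": "example_cluster",
}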

def _get_run_json(self) -> dict[str, Any]:
"""Get run json to be used for task submissions."""
run_json = {
@@ -1039,6 +1125,17 @@ def launch_notebook_job(self) -> int:
self.log.info("Check the job run in Databricks: %s", url)
return self.databricks_run_id

def _handle_terminal_run_state(self, run_state: RunState) -> None:
if run_state.life_cycle_state != RunLifeCycleState.TERMINATED.value:
raise AirflowException(
f"Databricks job failed with state {run_state.life_cycle_state}. Message: {run_state.state_message}"
)
if not run_state.is_successful:
raise AirflowException(
f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
)
self.log.info("Task succeeded. Final state %s.", run_state.result_state)

def monitor_databricks_job(self) -> None:
if self.databricks_run_id is None:
raise ValueError("Databricks job not yet launched. Please run launch_notebook_job first.")
@@ -1063,34 +1160,28 @@ def monitor_databricks_job(self) -> None:
run = self._hook.get_run(self.databricks_run_id)
run_state = RunState(**run["state"])
self.log.info(
"Current state of the databricks task %s is %s",
self._get_databricks_task_id(self.task_id),
run_state.life_cycle_state,
)
self._handle_terminal_run_state(run_state)

def execute(self, context: Context) -> None:
if self._databricks_workflow_task_group:
# If we are in a DatabricksWorkflowTaskGroup, the workflow's upstream launch task has already created the Databricks job run.
if not self.workflow_run_metadata:
launch_task_id = next(task for task in self.upstream_task_ids if task.endswith(".launch"))
self.workflow_run_metadata = context["ti"].xcom_pull(task_ids=launch_task_id)
workflow_run_metadata = WorkflowRunMetadata( # type: ignore[arg-type]
**self.workflow_run_metadata
)
self.databricks_run_id = workflow_run_metadata.run_id
self.databricks_conn_id = workflow_run_metadata.conn_id
else:
self.launch_notebook_job()
if self.wait_for_termination:
self.monitor_databricks_job()
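
For context, a minimal usage sketch of what this PR enables: DatabricksNotebookOperator tasks nested inside a DatabricksWorkflowTaskGroup, where the group's generated `.launch` task submits a single Databricks workflow and each notebook task then attaches to that run via `workflow_run_metadata`, as handled in `execute` above. The connection id, notebook paths, and cluster spec are placeholders, and task-group parameters such as `job_clusters` follow the provider documentation, so treat the exact signature as an assumption.

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator
from airflow.providers.databricks.operators.databricks_workflow import DatabricksWorkflowTaskGroup

with DAG("example_databricks_workflow", start_date=datetime(2024, 1, 1), schedule=None, catchup=False):
    with DatabricksWorkflowTaskGroup(
        group_id="workflow",
        databricks_conn_id="databricks_default",
        job_clusters=[
            {
                "job_cluster_key": "example_cluster",
                "new_cluster": {
                    "spark_version": "13.3.x-scala2.12",
                    "node_type_id": "i3.xlarge",
                    "num_workers": 1,
                },
            }
        ],
    ):
        first = DatabricksNotebookOperator(
            task_id="first_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/first_notebook",
            source="WORKSPACE",
            job_cluster_key="example_cluster",
        )
        second = DatabricksNotebookOperator(
            task_id="second_notebook",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/second_notebook",
            source="WORKSPACE",
            job_cluster_key="example_cluster",
        )
        first >> second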

def execute_complete(self, context: dict | None, event: dict) -> None:
run_state = RunState.from_json(event["run_state"])
if run_state.life_cycle_state != "TERMINATED":
raise AirflowException(
f"Databricks job failed with state {run_state.life_cycle_state}. "
f"Message: {run_state.state_message}"
)
if not run_state.is_successful:
raise AirflowException(
f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
)
self.log.info("Task succeeded. Final state %s.", run_state.result_state)
self._handle_terminal_run_state(run_state)