Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[openlineage] Add error stacktrace to task fail event #39813

Merged
merged 1 commit into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion airflow/providers/openlineage/plugins/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def fail_task(
parent_run_id: str | None,
end_time: str,
task: OperatorLineage,
error: str | BaseException | None = None,
) -> RunEvent:
"""
Emit openlineage event of type FAIL.
Expand All @@ -287,7 +288,16 @@ def fail_task(
:param parent_run_id: identifier of job spawning this task
:param end_time: time of task completion
:param task: metadata container with information extracted from operator
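:param error: error that caused the task to fail, given as a message string or an exception instance (an exception is converted to its formatted traceback); when None or empty, no error message facet is added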
"""
error_facet = {}
if error:
if isinstance(error, BaseException):
import traceback

error = "\\n".join(traceback.format_exception(type(error), error, error.__traceback__))
error_facet = {"errorMessage": ErrorMessageRunFacet(message=error, programmingLanguage="python")}

event = RunEvent(
eventType=RunState.FAIL,
eventTime=end_time,
Expand All @@ -296,7 +306,7 @@ def fail_task(
job_name=job_name,
parent_job_name=parent_job_name,
parent_run_id=parent_run_id,
run_facets=task.run_facets,
run_facets={**task.run_facets, **error_facet},
),
job=self._build_job(job_name, job_type=_JOB_TYPE_TASK, job_facets=task.job_facets),
inputs=task.inputs,
Expand Down
66 changes: 49 additions & 17 deletions airflow/providers/openlineage/plugins/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
from typing import TYPE_CHECKING

from openlineage.client.serde import Serde
from packaging.version import Version

from airflow import __version__ as airflow_version, settings
from airflow import __version__ as AIRFLOW_VERSION, settings
from airflow.listeners import hookimpl
from airflow.providers.openlineage import conf
from airflow.providers.openlineage.extractors import ExtractorManager
Expand All @@ -43,18 +44,17 @@
from sqlalchemy.orm import Session

from airflow.models import DagRun, TaskInstance
from airflow.utils.state import TaskInstanceState

_openlineage_listener: OpenLineageListener | None = None
_IS_AIRFLOW_2_10_OR_HIGHER = Version(Version(AIRFLOW_VERSION).base_version) >= Version("2.10.0")


def _get_try_number_success(val):
# todo: remove when min airflow version >= 2.10.0
from packaging.version import parse

if parse(parse(airflow_version).base_version) < parse("2.10.0"):
return val.try_number - 1
else:
if _IS_AIRFLOW_2_10_OR_HIGHER:
return val.try_number
return val.try_number - 1


class OpenLineageListener:
Expand All @@ -69,10 +69,10 @@ def __init__(self):
@hookimpl
def on_task_instance_running(
self,
previous_state,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
session: Session, # This will always be QUEUED
):
) -> None:
if not getattr(task_instance, "task", None) is not None:
self.log.warning(
"No task set for TI object task_id: %s - dag_id: %s - run_id %s",
Expand Down Expand Up @@ -159,7 +159,9 @@ def on_running():
on_running()

@hookimpl
def on_task_instance_success(self, previous_state, task_instance: TaskInstance, session):
def on_task_instance_success(
self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session
) -> None:
self.log.debug("OpenLineage listener got notification about task instance success")

dagrun = task_instance.dag_run
Expand Down Expand Up @@ -223,8 +225,37 @@ def on_success():

on_success()

@hookimpl
def on_task_instance_failed(self, previous_state, task_instance: TaskInstance, session):
if _IS_AIRFLOW_2_10_OR_HIGHER:

@hookimpl
def on_task_instance_failed(
self,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
error: None | str | BaseException,
session: Session,
) -> None:
self._on_task_instance_failed(
previous_state=previous_state, task_instance=task_instance, error=error, session=session
)

else:

@hookimpl
def on_task_instance_failed(
self, previous_state: TaskInstanceState, task_instance: TaskInstance, session: Session
) -> None:
self._on_task_instance_failed(
previous_state=previous_state, task_instance=task_instance, error=None, session=session
)

def _on_task_instance_failed(
self,
previous_state: TaskInstanceState,
task_instance: TaskInstance,
session: Session,
error: None | str | BaseException = None,
) -> None:
self.log.debug("OpenLineage listener got notification about task instance failure")

dagrun = task_instance.dag_run
Expand Down Expand Up @@ -280,6 +311,7 @@ def on_failure():
parent_run_id=parent_run_id,
end_time=end_date.isoformat(),
task=task_metadata,
error=error,
)
Stats.gauge(
f"ol.event.size.{event_type}.{operator_name}",
Expand All @@ -289,7 +321,7 @@ def on_failure():
on_failure()

@property
def executor(self):
def executor(self) -> ProcessPoolExecutor:
def initializer():
# Re-configure the ORM engine as there are issues with multiple processes
# if process calls Airflow DB.
Expand All @@ -303,17 +335,17 @@ def initializer():
return self._executor

@hookimpl
def on_starting(self, component):
def on_starting(self, component) -> None:
self.log.debug("on_starting: %s", component.__class__.__name__)

@hookimpl
def before_stopping(self, component):
def before_stopping(self, component) -> None:
self.log.debug("before_stopping: %s", component.__class__.__name__)
with timeout(30):
self.executor.shutdown(wait=True)

@hookimpl
def on_dag_run_running(self, dag_run: DagRun, msg: str):
def on_dag_run_running(self, dag_run: DagRun, msg: str) -> None:
if dag_run.dag and not is_selective_lineage_enabled(dag_run.dag):
self.log.debug(
"Skipping OpenLineage event emission for DAG `%s` "
Expand All @@ -338,7 +370,7 @@ def on_dag_run_running(self, dag_run: DagRun, msg: str):
)

@hookimpl
def on_dag_run_success(self, dag_run: DagRun, msg: str):
def on_dag_run_success(self, dag_run: DagRun, msg: str) -> None:
if dag_run.dag and not is_selective_lineage_enabled(dag_run.dag):
self.log.debug(
"Skipping OpenLineage event emission for DAG `%s` "
Expand All @@ -355,7 +387,7 @@ def on_dag_run_success(self, dag_run: DagRun, msg: str):
self.executor.submit(self.adapter.dag_success, dag_run=dag_run, msg=msg)

@hookimpl
def on_dag_run_failed(self, dag_run: DagRun, msg: str):
def on_dag_run_failed(self, dag_run: DagRun, msg: str) -> None:
if dag_run.dag and not is_selective_lineage_enabled(dag_run.dag):
self.log.debug(
"Skipping OpenLineage event emission for DAG `%s` "
Expand Down
109 changes: 41 additions & 68 deletions tests/providers/openlineage/plugins/test_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
pytestmark = pytest.mark.db_test

EXPECTED_TRY_NUMBER_1 = 1 if AIRFLOW_V_2_10_PLUS else 0
EXPECTED_TRY_NUMBER_2 = 2 if AIRFLOW_V_2_10_PLUS else 1

TRY_NUMBER_BEFORE_EXECUTION = 0 if AIRFLOW_V_2_10_PLUS else 1
TRY_NUMBER_RUNNING = 0 if AIRFLOW_V_2_10_PLUS else 1
Expand Down Expand Up @@ -276,14 +275,21 @@ def mock_task_id(dag_id, task_id, try_number, execution_date):
mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id
mock_disabled.return_value = False

listener.on_task_instance_failed(None, task_instance, None)
err = ValueError("test")
on_task_failed_listener_kwargs = {"error": err} if AIRFLOW_V_2_10_PLUS else {}
expected_err_kwargs = {"error": err if AIRFLOW_V_2_10_PLUS else None}
mobuchowski marked this conversation as resolved.
Show resolved Hide resolved

listener.on_task_instance_failed(
previous_state=None, task_instance=task_instance, session=None, **on_task_failed_listener_kwargs
)
listener.adapter.fail_task.assert_called_once_with(
end_time="2023-01-03T13:01:01",
job_name="job_name",
parent_job_name="dag_id",
parent_run_id="execution_date.dag_id",
run_id="execution_date.dag_id.task_id.1",
task=listener.extractor_manager.extract_metadata(),
**expected_err_kwargs,
)


Expand Down Expand Up @@ -316,7 +322,7 @@ def mock_task_id(dag_id, task_id, try_number, execution_date):

listener.on_task_instance_success(None, task_instance, None)
# This run_id will be different as we did NOT simulate increase of the try_number attribute,
# which happens in Airflow.
# which happens in Airflow < 2.10.
calls = listener.adapter.complete_task.call_args_list
assert len(calls) == 1
assert calls[0][1] == dict(
Expand All @@ -328,65 +334,8 @@ def mock_task_id(dag_id, task_id, try_number, execution_date):
task=listener.extractor_manager.extract_metadata(),
)

# Now we simulate the increase of try_number, and the run_id should reflect that change.
listener.adapter.complete_task.reset_mock()
task_instance.try_number += 1
listener.on_task_instance_success(None, task_instance, None)
calls = listener.adapter.complete_task.call_args_list
assert len(calls) == 1
assert calls[0][1] == dict(
end_time="2023-01-03T13:01:01",
job_name="job_name",
parent_job_name="dag_id",
parent_run_id="execution_date.dag_id",
run_id=f"execution_date.dag_id.task_id.{EXPECTED_TRY_NUMBER_2}",
task=listener.extractor_manager.extract_metadata(),
)


@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
def test_run_id_is_constant_across_all_methods(mocked_adapter):
"""Tests that the run_id remains constant across different methods of the listener.

It ensures that the run_id generated for starting, failing, and completing a task is consistent,
reflecting the task's identity and execution context. The test also simulates the change in the
try_number attribute, as it would occur in Airflow, to verify that the run_id updates accordingly.
"""

def mock_task_id(dag_id, task_id, try_number, execution_date):
returned_try_number = try_number if AIRFLOW_V_2_10_PLUS else max(try_number - 1, 1)
return f"{execution_date}.{dag_id}.{task_id}.{returned_try_number}"

listener, task_instance = _create_listener_and_task_instance()
mocked_adapter.build_task_instance_run_id.side_effect = mock_task_id
expected_run_id_1 = "execution_date.dag_id.task_id.1"
expected_run_id_2 = "execution_date.dag_id.task_id.2"
listener.on_task_instance_running(None, task_instance, None)
assert listener.adapter.start_task.call_args.kwargs["run_id"] == expected_run_id_1

listener.on_task_instance_failed(None, task_instance, None)
assert (
listener.adapter.fail_task.call_args.kwargs["run_id"] == expected_run_id_1
if AIRFLOW_V_2_10_PLUS
else expected_run_id_2
)

# This run_id will not be different as we did NOT simulate increase of the try_number attribute,
listener.on_task_instance_success(None, task_instance, None)
assert listener.adapter.complete_task.call_args.kwargs["run_id"] == expected_run_id_1

# Now we simulate the increase of try_number, and the run_id should reflect that change.
# This is how airflow works, and that's why we expect the run_id to remain constant across all methods.
task_instance.try_number += 1
listener.on_task_instance_success(None, task_instance, None)
assert (
listener.adapter.complete_task.call_args.kwargs["run_id"] == expected_run_id_2
if AIRFLOW_V_2_10_PLUS
else expected_run_id_1
)


def test_running_task_correctly_calls_openlineage_adapter_run_id_method():
def test_on_task_instance_running_correctly_calls_openlineage_adapter_run_id_method():
"""Tests the OpenLineageListener's response when a task instance is in the running state.

This test ensures that when an Airflow task instance transitions to the running state,
Expand All @@ -404,15 +353,19 @@ def test_running_task_correctly_calls_openlineage_adapter_run_id_method():


@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
def test_failed_task_correctly_calls_openlineage_adapter_run_id_method(mock_adapter):
def test_on_task_instance_failed_correctly_calls_openlineage_adapter_run_id_method(mock_adapter):
"""Tests the OpenLineageListener's response when a task instance is in the failed state.

This test ensures that when an Airflow task instance transitions to the failed state,
the OpenLineageAdapter's `build_task_instance_run_id` method is called exactly once with the correct
parameters derived from the task instance.
"""
listener, task_instance = _create_listener_and_task_instance()
listener.on_task_instance_failed(None, task_instance, None)
on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {}

listener.on_task_instance_failed(
previous_state=None, task_instance=task_instance, session=None, **on_task_failed_kwargs
)
mock_adapter.build_task_instance_run_id.assert_called_once_with(
dag_id="dag_id",
task_id="task_id",
Expand All @@ -422,7 +375,7 @@ def test_failed_task_correctly_calls_openlineage_adapter_run_id_method(mock_adap


@mock.patch("airflow.providers.openlineage.plugins.listener.OpenLineageAdapter")
def test_successful_task_correctly_calls_openlineage_adapter_run_id_method(mock_adapter):
def test_on_task_instance_success_correctly_calls_openlineage_adapter_run_id_method(mock_adapter):
"""Tests the OpenLineageListener's response when a task instance is in the success state.

This test ensures that when an Airflow task instance transitions to the success state,
Expand Down Expand Up @@ -530,7 +483,11 @@ def test_listener_on_task_instance_failed_do_not_call_adapter_when_disabled_oper
listener, task_instance = _create_listener_and_task_instance()
mock_disabled.return_value = True

listener.on_task_instance_failed(None, task_instance, None)
on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {}

listener.on_task_instance_failed(
previous_state=None, task_instance=task_instance, session=None, **on_task_failed_kwargs
)
mock_disabled.assert_called_once_with(task_instance.task)
mocked_adapter.build_dag_run_id.assert_not_called()
mocked_adapter.build_task_instance_run_id.assert_not_called()
Expand Down Expand Up @@ -645,6 +602,8 @@ def test_listener_with_task_enabled(
if enable_task:
enable_lineage(self.task_1)

on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {}

conf.selective_enable.cache_clear()
with conf_vars({("openlineage", "selective_enable"): selective_enable}):
listener = OpenLineageListener()
Expand All @@ -662,14 +621,24 @@ def test_listener_with_task_enabled(
# run TaskInstance-related hooks for lineage enabled task
listener.on_task_instance_running(None, self.task_instance_1, None)
listener.on_task_instance_success(None, self.task_instance_1, None)
listener.on_task_instance_failed(None, self.task_instance_1, None)
listener.on_task_instance_failed(
previous_state=None,
task_instance=self.task_instance_1,
session=None,
**on_task_failed_kwargs,
)

assert expected_task_call_count == listener.extractor_manager.extract_metadata.call_count

# run TaskInstance-related hooks for lineage disabled task
listener.on_task_instance_running(None, self.task_instance_2, None)
listener.on_task_instance_success(None, self.task_instance_2, None)
listener.on_task_instance_failed(None, self.task_instance_2, None)
listener.on_task_instance_failed(
previous_state=None,
task_instance=self.task_instance_2,
session=None,
**on_task_failed_kwargs,
)

# with selective-enable disabled both task_1 and task_2 should trigger metadata extraction
if selective_enable == "False":
Expand Down Expand Up @@ -697,6 +666,8 @@ def test_listener_with_dag_disabled_task_enabled(
if enable_task:
enable_lineage(self.task_1)

on_task_failed_kwargs = {"error": ValueError("test")} if AIRFLOW_V_2_10_PLUS else {}

conf.selective_enable.cache_clear()
with conf_vars({("openlineage", "selective_enable"): selective_enable}):
listener = OpenLineageListener()
Expand All @@ -712,7 +683,9 @@ def test_listener_with_dag_disabled_task_enabled(
# run TaskInstance-related hooks for lineage enabled task
listener.on_task_instance_running(None, self.task_instance_1, None)
listener.on_task_instance_success(None, self.task_instance_1, None)
listener.on_task_instance_failed(None, self.task_instance_1, None)
listener.on_task_instance_failed(
previous_state=None, task_instance=self.task_instance_1, session=None, **on_task_failed_kwargs
)

try:
assert expected_call_count == listener._executor.submit.call_count
Expand Down