Skip to content

Commit

Permalink
Add backward compatibility for old Airflow version for CloudComposerDAGRunSensor
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksYermak committed Jan 24, 2025
1 parent 7a28f29 commit dbf0024
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
import json
from collections.abc import Iterable, Sequence
from datetime import datetime, timedelta
from functools import cached_property
from typing import TYPE_CHECKING

from dateutil import parser
from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse

from airflow.configuration import conf
from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -122,6 +123,7 @@ def poke(self, context: Context) -> bool:
if datetime.now(end_date.tzinfo) < end_date:
return False

self._composer_airflow_version = self._get_composer_airflow_version()
dag_runs = self._pull_dag_runs()

self.log.info("Sensor waits for allowed states: %s", self.allowed_states)
Expand All @@ -135,19 +137,20 @@ def poke(self, context: Context) -> bool:

def _pull_dag_runs(self) -> list[dict]:
"""Pull the list of dag runs."""
hook = CloudComposerHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
cmd_parameters = (
["-d", self.composer_dag_id, "-o", "json"]
if self._composer_airflow_version < 3
else [self.composer_dag_id, "-o", "json"]
)
dag_runs_cmd = hook.execute_airflow_command(
dag_runs_cmd = self.hook.execute_airflow_command(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
command="dags",
subcommand="list-runs",
parameters=["-d", self.composer_dag_id, "-o", "json"],
parameters=cmd_parameters,
)
cmd_result = hook.wait_command_execution_result(
cmd_result = self.hook.wait_command_execution_result(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
Expand All @@ -165,15 +168,29 @@ def _check_dag_runs_states(
for dag_run in dag_runs:
if (
start_date.timestamp()
< parser.parse(dag_run["logical_date"]).timestamp()
< parser.parse(
dag_run["execution_date" if self._composer_airflow_version < 3 else "logical_date"]
).timestamp()
< end_date.timestamp()
) and dag_run["state"] not in self.allowed_states:
return False
return True

def _get_composer_airflow_version(self) -> int:
"""Return Composer Airflow version."""
environment_obj = self.hook.get_environment(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
)
environment_config = Environment.to_dict(environment_obj)
image_version = environment_config["config"]["software_config"]["image_version"]
return int(image_version.split("airflow-")[1].split(".")[0])

def execute(self, context: Context) -> None:
if self.deferrable:
start_date, end_date = self._get_logical_dates(context)
self._composer_airflow_version = self._get_composer_airflow_version()
self.defer(
trigger=CloudComposerDAGRunTrigger(
project_id=self.project_id,
Expand All @@ -186,6 +203,7 @@ def execute(self, context: Context) -> None:
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
poll_interval=self.poll_interval,
composer_airflow_version=self._composer_airflow_version,
),
method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
)
Expand All @@ -195,3 +213,10 @@ def execute_complete(self, context: Context, event: dict):
if event and event["status"] == "error":
raise AirflowException(event["message"])
self.log.info("DAG %s has executed successfully.", self.composer_dag_id)

@cached_property
def hook(self) -> CloudComposerHook:
return CloudComposerHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
poll_interval: int = 10,
composer_airflow_version: int = 2,
):
super().__init__()
self.project_id = project_id
Expand All @@ -181,6 +182,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.poll_interval = poll_interval
self.composer_airflow_version = composer_airflow_version

self.gcp_hook = CloudComposerAsyncHook(
gcp_conn_id=self.gcp_conn_id,
Expand All @@ -201,18 +203,24 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"gcp_conn_id": self.gcp_conn_id,
"impersonation_chain": self.impersonation_chain,
"poll_interval": self.poll_interval,
"composer_airflow_version": self.composer_airflow_version,
},
)

async def _pull_dag_runs(self) -> list[dict]:
"""Pull the list of dag runs."""
cmd_parameters = (
["-d", self.composer_dag_id, "-o", "json"]
if self.composer_airflow_version < 3
else [self.composer_dag_id, "-o", "json"]
)
dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
project_id=self.project_id,
region=self.region,
environment_id=self.environment_id,
command="dags",
subcommand="list-runs",
parameters=["-d", self.composer_dag_id, "-o", "json"],
parameters=cmd_parameters,
)
cmd_result = await self.gcp_hook.wait_command_execution_result(
project_id=self.project_id,
Expand All @@ -232,7 +240,9 @@ def _check_dag_runs_states(
for dag_run in dag_runs:
if (
start_date.timestamp()
< parser.parse(dag_run["logical_date"]).timestamp()
< parser.parse(
dag_run["execution_date" if self.composer_airflow_version < 3 else "logical_date"]
).timestamp()
< end_date.timestamp()
) and dag_run["state"] not in self.allowed_states:
return False
Expand Down
32 changes: 24 additions & 8 deletions providers/tests/google/cloud/sensors/test_cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,36 +21,45 @@
from datetime import datetime
from unittest import mock

import pytest

from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor

TEST_PROJECT_ID = "test_project_id"
TEST_OPERATION_NAME = "test_operation_name"
TEST_REGION = "region"
TEST_ENVIRONMENT_ID = "test_env_id"
TEST_JSON_RESULT = lambda state: json.dumps(
TEST_JSON_RESULT = lambda state, date_key: json.dumps(
[
{
"dag_id": "test_dag_id",
"run_id": "scheduled__2024-05-22T11:10:00+00:00",
"state": state,
"logical_date": "2024-05-22T11:10:00+00:00",
date_key: "2024-05-22T11:10:00+00:00",
"start_date": "2024-05-22T11:20:01.531988+00:00",
"end_date": "2024-05-22T11:20:11.997479+00:00",
}
]
)
TEST_EXEC_RESULT = lambda state: {
"output": [{"line_number": 1, "content": TEST_JSON_RESULT(state)}],
TEST_EXEC_RESULT = lambda state, date_key: {
"output": [{"line_number": 1, "content": TEST_JSON_RESULT(state, date_key)}],
"output_end": True,
"exit_info": {"exit_code": 0, "error": ""},
}


class TestCloudComposerDAGRunSensor:
@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch(
"airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerDAGRunSensor._get_composer_airflow_version"
)
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
def test_wait_ready(self, mock_hook, to_dict_mode):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("success")
def test_wait_ready(self, mock_hook, mock_get_version, to_dict_mode, composer_airflow_version):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"success", "execution_date" if composer_airflow_version < 3 else "logical_date"
)
mock_get_version.return_value = composer_airflow_version

task = CloudComposerDAGRunSensor(
task_id="task-id",
Expand All @@ -63,10 +72,17 @@ def test_wait_ready(self, mock_hook, to_dict_mode):

assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)})

@pytest.mark.parametrize("composer_airflow_version", [2, 3])
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict")
@mock.patch(
"airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerDAGRunSensor._get_composer_airflow_version"
)
@mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook")
def test_wait_not_ready(self, mock_hook, to_dict_mode):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("running")
def test_wait_not_ready(self, mock_hook, mock_get_version, to_dict_mode, composer_airflow_version):
mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT(
"running", "execution_date" if composer_airflow_version < 3 else "logical_date"
)
mock_get_version.return_value = composer_airflow_version

task = CloudComposerDAGRunSensor(
task_id="task-id",
Expand Down
3 changes: 3 additions & 0 deletions providers/tests/google/cloud/triggers/test_cloud_composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
TEST_STATES = ["success"]
TEST_GCP_CONN_ID = "test_gcp_conn_id"
TEST_POLL_INTERVAL = 10
TEST_COMPOSER_AIRFLOW_VERSION = 3
TEST_IMPERSONATION_CHAIN = "test_impersonation_chain"
TEST_EXEC_RESULT = {
"output": [{"line_number": 1, "content": "test_content"}],
Expand Down Expand Up @@ -86,6 +87,7 @@ def dag_run_trigger(mock_conn):
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=TEST_IMPERSONATION_CHAIN,
poll_interval=TEST_POLL_INTERVAL,
composer_airflow_version=TEST_COMPOSER_AIRFLOW_VERSION,
)


Expand Down Expand Up @@ -140,6 +142,7 @@ def test_serialize(self, dag_run_trigger):
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": TEST_IMPERSONATION_CHAIN,
"poll_interval": TEST_POLL_INTERVAL,
"composer_airflow_version": TEST_COMPOSER_AIRFLOW_VERSION,
},
)
assert actual_data == expected_data

0 comments on commit dbf0024

Please sign in to comment.