From dbf0024946152f3b9c7b9b967ad1c7f1af84c3f9 Mon Sep 17 00:00:00 2001 From: Maksim Yermakou Date: Fri, 24 Jan 2025 11:00:09 +0000 Subject: [PATCH] Add backward compatibility for old Airflow version for CloudComposerDAGRunSensor --- .../google/cloud/sensors/cloud_composer.py | 41 +++++++++++++++---- .../google/cloud/triggers/cloud_composer.py | 14 ++++++- .../cloud/sensors/test_cloud_composer.py | 32 +++++++++++---- .../cloud/triggers/test_cloud_composer.py | 3 ++ 4 files changed, 72 insertions(+), 18 deletions(-) diff --git a/providers/src/airflow/providers/google/cloud/sensors/cloud_composer.py b/providers/src/airflow/providers/google/cloud/sensors/cloud_composer.py index d6e1622abefae..06ecbc579d564 100644 --- a/providers/src/airflow/providers/google/cloud/sensors/cloud_composer.py +++ b/providers/src/airflow/providers/google/cloud/sensors/cloud_composer.py @@ -22,10 +22,11 @@ import json from collections.abc import Iterable, Sequence from datetime import datetime, timedelta +from functools import cached_property from typing import TYPE_CHECKING from dateutil import parser -from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse +from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse from airflow.configuration import conf from airflow.exceptions import AirflowException @@ -122,6 +123,7 @@ def poke(self, context: Context) -> bool: if datetime.now(end_date.tzinfo) < end_date: return False + self._composer_airflow_version = self._get_composer_airflow_version() dag_runs = self._pull_dag_runs() self.log.info("Sensor waits for allowed states: %s", self.allowed_states) @@ -135,19 +137,20 @@ def poke(self, context: Context) -> bool: def _pull_dag_runs(self) -> list[dict]: """Pull the list of dag runs.""" - hook = CloudComposerHook( - gcp_conn_id=self.gcp_conn_id, - impersonation_chain=self.impersonation_chain, + cmd_parameters = ( + ["-d", self.composer_dag_id, "-o", "json"] + if self._composer_airflow_version < 3 + else [self.composer_dag_id, "-o", "json"] ) - dag_runs_cmd = hook.execute_airflow_command( + dag_runs_cmd = self.hook.execute_airflow_command( project_id=self.project_id, region=self.region, environment_id=self.environment_id, command="dags", subcommand="list-runs", - parameters=["-d", self.composer_dag_id, "-o", "json"], + parameters=cmd_parameters, ) - cmd_result = hook.wait_command_execution_result( + cmd_result = self.hook.wait_command_execution_result( project_id=self.project_id, region=self.region, environment_id=self.environment_id, @@ -165,15 +168,29 @@ def _check_dag_runs_states( for dag_run in dag_runs: if ( start_date.timestamp() - < parser.parse(dag_run["logical_date"]).timestamp() + < parser.parse( + dag_run["execution_date" if self._composer_airflow_version < 3 else "logical_date"] + ).timestamp() < end_date.timestamp() ) and dag_run["state"] not in self.allowed_states: return False return True + def _get_composer_airflow_version(self) -> int: + """Return Composer Airflow version.""" + environment_obj = self.hook.get_environment( + project_id=self.project_id, + region=self.region, + environment_id=self.environment_id, + ) + environment_config = Environment.to_dict(environment_obj) + image_version = environment_config["config"]["software_config"]["image_version"] + return int(image_version.split("airflow-")[1].split(".")[0]) + def execute(self, context: Context) -> None: if self.deferrable: start_date, end_date = self._get_logical_dates(context) + self._composer_airflow_version = self._get_composer_airflow_version() self.defer( trigger=CloudComposerDAGRunTrigger( project_id=self.project_id, @@ -186,6 +203,7 @@ def execute(self, context: Context) -> None: gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, poll_interval=self.poll_interval, + composer_airflow_version=self._composer_airflow_version, ), method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME, ) @@ -195,3 +213,10 @@ def execute_complete(self, context: Context, event: dict): if event and event["status"] == "error": raise AirflowException(event["message"]) self.log.info("DAG %s has executed successfully.", self.composer_dag_id) + + @cached_property + def hook(self) -> CloudComposerHook: + return CloudComposerHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + ) diff --git a/providers/src/airflow/providers/google/cloud/triggers/cloud_composer.py b/providers/src/airflow/providers/google/cloud/triggers/cloud_composer.py index c13681aaeff7c..96748bc2a7c49 100644 --- a/providers/src/airflow/providers/google/cloud/triggers/cloud_composer.py +++ b/providers/src/airflow/providers/google/cloud/triggers/cloud_composer.py @@ -169,6 +169,7 @@ def __init__( gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, poll_interval: int = 10, + composer_airflow_version: int = 2, ): super().__init__() self.project_id = project_id @@ -181,6 +182,7 @@ def __init__( self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain self.poll_interval = poll_interval + self.composer_airflow_version = composer_airflow_version self.gcp_hook = CloudComposerAsyncHook( gcp_conn_id=self.gcp_conn_id, @@ -201,18 +203,24 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "gcp_conn_id": self.gcp_conn_id, "impersonation_chain": self.impersonation_chain, "poll_interval": self.poll_interval, + "composer_airflow_version": self.composer_airflow_version, }, ) async def _pull_dag_runs(self) -> list[dict]: """Pull the list of dag runs.""" + cmd_parameters = ( + ["-d", self.composer_dag_id, "-o", "json"] + if self.composer_airflow_version < 3 + else [self.composer_dag_id, "-o", "json"] + ) dag_runs_cmd = await self.gcp_hook.execute_airflow_command( project_id=self.project_id, region=self.region, environment_id=self.environment_id, command="dags", subcommand="list-runs", - parameters=["-d", self.composer_dag_id, "-o", "json"], + parameters=cmd_parameters, ) cmd_result = await self.gcp_hook.wait_command_execution_result( project_id=self.project_id, @@ -232,7 +240,9 @@ def _check_dag_runs_states( for dag_run in dag_runs: if ( start_date.timestamp() - < parser.parse(dag_run["logical_date"]).timestamp() + < parser.parse( + dag_run["execution_date" if self.composer_airflow_version < 3 else "logical_date"] + ).timestamp() < end_date.timestamp() ) and dag_run["state"] not in self.allowed_states: return False diff --git a/providers/tests/google/cloud/sensors/test_cloud_composer.py b/providers/tests/google/cloud/sensors/test_cloud_composer.py index 091839686ab7f..53f1651a986c5 100644 --- a/providers/tests/google/cloud/sensors/test_cloud_composer.py +++ b/providers/tests/google/cloud/sensors/test_cloud_composer.py @@ -21,36 +21,45 @@ from datetime import datetime from unittest import mock +import pytest + from airflow.providers.google.cloud.sensors.cloud_composer import CloudComposerDAGRunSensor TEST_PROJECT_ID = "test_project_id" TEST_OPERATION_NAME = "test_operation_name" TEST_REGION = "region" TEST_ENVIRONMENT_ID = "test_env_id" -TEST_JSON_RESULT = lambda state: json.dumps( +TEST_JSON_RESULT = lambda state, date_key: json.dumps( [ { "dag_id": "test_dag_id", "run_id": "scheduled__2024-05-22T11:10:00+00:00", "state": state, - "logical_date": "2024-05-22T11:10:00+00:00", + date_key: "2024-05-22T11:10:00+00:00", "start_date": "2024-05-22T11:20:01.531988+00:00", "end_date": "2024-05-22T11:20:11.997479+00:00", } ] ) -TEST_EXEC_RESULT = lambda state: { - "output": [{"line_number": 1, "content": TEST_JSON_RESULT(state)}], +TEST_EXEC_RESULT = lambda state, date_key: { + "output": [{"line_number": 1, "content": TEST_JSON_RESULT(state, date_key)}], "output_end": True, "exit_info": {"exit_code": 0, "error": ""}, } class TestCloudComposerDAGRunSensor: + @pytest.mark.parametrize("composer_airflow_version", [2, 3]) @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict") + @mock.patch( + "airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerDAGRunSensor._get_composer_airflow_version" + ) @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook") - def test_wait_ready(self, mock_hook, to_dict_mode): - mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("success") + def test_wait_ready(self, mock_hook, mock_get_version, to_dict_mode, composer_airflow_version): + mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT( + "success", "execution_date" if composer_airflow_version < 3 else "logical_date" + ) + mock_get_version.return_value = composer_airflow_version task = CloudComposerDAGRunSensor( task_id="task-id", @@ -63,10 +72,17 @@ def test_wait_ready(self, mock_hook, to_dict_mode): assert task.poke(context={"logical_date": datetime(2024, 5, 23, 0, 0, 0)}) + @pytest.mark.parametrize("composer_airflow_version", [2, 3]) @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.ExecuteAirflowCommandResponse.to_dict") + @mock.patch( + "airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerDAGRunSensor._get_composer_airflow_version" + ) @mock.patch("airflow.providers.google.cloud.sensors.cloud_composer.CloudComposerHook") - def test_wait_not_ready(self, mock_hook, to_dict_mode): - mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT("running") + def test_wait_not_ready(self, mock_hook, mock_get_version, to_dict_mode, composer_airflow_version): + mock_hook.return_value.wait_command_execution_result.return_value = TEST_EXEC_RESULT( + "running", "execution_date" if composer_airflow_version < 3 else "logical_date" + ) + mock_get_version.return_value = composer_airflow_version task = CloudComposerDAGRunSensor( task_id="task-id", diff --git a/providers/tests/google/cloud/triggers/test_cloud_composer.py b/providers/tests/google/cloud/triggers/test_cloud_composer.py index 00d109ed975a1..8716805fa133d 100644 --- a/providers/tests/google/cloud/triggers/test_cloud_composer.py +++ b/providers/tests/google/cloud/triggers/test_cloud_composer.py @@ -44,6 +44,7 @@ TEST_STATES = ["success"] TEST_GCP_CONN_ID = "test_gcp_conn_id" TEST_POLL_INTERVAL = 10 +TEST_COMPOSER_AIRFLOW_VERSION = 3 TEST_IMPERSONATION_CHAIN = "test_impersonation_chain" TEST_EXEC_RESULT = { "output": [{"line_number": 1, "content": "test_content"}], @@ -86,6 +87,7 @@ def dag_run_trigger(mock_conn): gcp_conn_id=TEST_GCP_CONN_ID, impersonation_chain=TEST_IMPERSONATION_CHAIN, poll_interval=TEST_POLL_INTERVAL, + composer_airflow_version=TEST_COMPOSER_AIRFLOW_VERSION, ) @@ -140,6 +142,7 @@ def test_serialize(self, dag_run_trigger): "gcp_conn_id": TEST_GCP_CONN_ID, "impersonation_chain": TEST_IMPERSONATION_CHAIN, "poll_interval": TEST_POLL_INTERVAL, + "composer_airflow_version": TEST_COMPOSER_AIRFLOW_VERSION, }, ) assert actual_data == expected_data