diff --git a/providers/src/airflow/providers/microsoft/azure/operators/powerbi.py b/providers/src/airflow/providers/microsoft/azure/operators/powerbi.py index b334a1708a6f7..1c6878c27af68 100644 --- a/providers/src/airflow/providers/microsoft/azure/operators/powerbi.py +++ b/providers/src/airflow/providers/microsoft/azure/operators/powerbi.py @@ -114,6 +114,35 @@ def execute(self, context: Context): check_interval=self.check_interval, wait_for_termination=self.wait_for_termination, ), + method_name=self.get_refresh_status.__name__, + ) + + def get_refresh_status(self, context: Context, event: dict[str, str] | None = None): + """Push the refresh Id to XCom then runs the Trigger to wait for refresh completion.""" + if event: + if event["status"] == "error": + raise AirflowException(event["message"]) + + dataset_refresh_id = event["dataset_refresh_id"] + + if dataset_refresh_id: + self.xcom_push( + context=context, + key=f"{self.task_id}.powerbi_dataset_refresh_Id", + value=dataset_refresh_id, + ) + self.defer( + trigger=PowerBITrigger( + conn_id=self.conn_id, + group_id=self.group_id, + dataset_id=self.dataset_id, + dataset_refresh_id=dataset_refresh_id, + timeout=self.timeout, + proxies=self.proxies, + api_version=self.api_version, + check_interval=self.check_interval, + wait_for_termination=self.wait_for_termination, + ), method_name=self.execute_complete.__name__, ) @@ -124,10 +153,10 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> Any: Relies on trigger to throw an exception, otherwise it assumes execution was successful. """ if event: - if event["status"] == "error": - raise AirflowException(event["message"]) - self.xcom_push( - context=context, key="powerbi_dataset_refresh_Id", value=event["dataset_refresh_id"] + context=context, + key=f"{self.task_id}.powerbi_dataset_refresh_status", + value=event["dataset_refresh_status"], ) - self.xcom_push(context=context, key="powerbi_dataset_refresh_status", value=event["status"]) + if event["status"] == "error": + raise AirflowException(event["message"]) diff --git a/providers/src/airflow/providers/microsoft/azure/triggers/powerbi.py b/providers/src/airflow/providers/microsoft/azure/triggers/powerbi.py index a2132f6c39384..8f749311806ec 100644 --- a/providers/src/airflow/providers/microsoft/azure/triggers/powerbi.py +++ b/providers/src/airflow/providers/microsoft/azure/triggers/powerbi.py @@ -22,7 +22,13 @@ from collections.abc import AsyncIterator from typing import TYPE_CHECKING -from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus, PowerBIHook +import tenacity + +from airflow.providers.microsoft.azure.hooks.powerbi import ( + PowerBIDatasetRefreshException, + PowerBIDatasetRefreshStatus, + PowerBIHook, +) from airflow.triggers.base import BaseTrigger, TriggerEvent if TYPE_CHECKING: @@ -43,6 +49,7 @@ class PowerBITrigger(BaseTrigger): You can pass an enum named APIVersion which has 2 possible members v1 and beta, or you can pass a string as `v1.0` or `beta`. :param dataset_id: The dataset Id to refresh. + :param dataset_refresh_id: The dataset refresh Id to poll for the status, if not provided a new refresh will be triggered. :param group_id: The workspace Id where dataset is located. :param end_time: Time in seconds when trigger should stop polling. :param check_interval: Time in seconds to wait between each poll. @@ -55,6 +62,7 @@ def __init__( dataset_id: str, group_id: str, timeout: float = 60 * 60 * 24 * 7, + dataset_refresh_id: str | None = None, proxies: dict | None = None, api_version: APIVersion | str | None = None, check_interval: int = 60, @@ -63,6 +71,7 @@ def __init__( super().__init__() self.hook = PowerBIHook(conn_id=conn_id, proxies=proxies, api_version=api_version, timeout=timeout) self.dataset_id = dataset_id + self.dataset_refresh_id = dataset_refresh_id self.timeout = timeout self.group_id = group_id self.check_interval = check_interval @@ -77,6 +86,7 @@ def serialize(self): "proxies": self.proxies, "api_version": self.api_version, "dataset_id": self.dataset_id, + "dataset_refresh_id": self.dataset_refresh_id, "group_id": self.group_id, "timeout": self.timeout, "check_interval": self.check_interval, @@ -98,19 +108,53 @@ def api_version(self) -> APIVersion | str: async def run(self) -> AsyncIterator[TriggerEvent]: """Make async connection to the PowerBI and polls for the dataset refresh status.""" - self.dataset_refresh_id = await self.hook.trigger_dataset_refresh( - dataset_id=self.dataset_id, - group_id=self.group_id, - ) - - async def fetch_refresh_status_and_error() -> tuple[str, str]: - """Fetch the current status and error of the dataset refresh.""" - refresh_details = await self.hook.get_refresh_details_by_refresh_id( + if not self.dataset_refresh_id: + # Trigger the dataset refresh + dataset_refresh_id = await self.hook.trigger_dataset_refresh( dataset_id=self.dataset_id, group_id=self.group_id, - refresh_id=self.dataset_refresh_id, ) - return refresh_details["status"], refresh_details["error"] + + if dataset_refresh_id: + self.log.info("Triggered dataset refresh %s", dataset_refresh_id) + yield TriggerEvent( + { + "status": "success", + "dataset_refresh_status": None, + "message": f"The dataset refresh {dataset_refresh_id} has been triggered.", + "dataset_refresh_id": dataset_refresh_id, + } + ) + return + + yield TriggerEvent( + { + "status": "error", + "dataset_refresh_status": None, + "message": "Failed to trigger the dataset refresh.", + "dataset_refresh_id": None, + } + ) + return + + # The dataset refresh is already triggered. Poll for the dataset refresh status. + @tenacity.retry( + stop=tenacity.stop_after_attempt(3), + wait=tenacity.wait_exponential(min=5, multiplier=2), + reraise=True, + retry=tenacity.retry_if_exception_type(PowerBIDatasetRefreshException), + ) + async def fetch_refresh_status_and_error() -> tuple[str, str]: + """Fetch the current status and error of the dataset refresh.""" + if self.dataset_refresh_id: + refresh_details = await self.hook.get_refresh_details_by_refresh_id( + dataset_id=self.dataset_id, + group_id=self.group_id, + refresh_id=self.dataset_refresh_id, + ) + return refresh_details["status"], refresh_details["error"] + + raise PowerBIDatasetRefreshException("Dataset refresh Id is missing.") try: dataset_refresh_status, dataset_refresh_error = await fetch_refresh_status_and_error() diff --git a/providers/tests/microsoft/azure/operators/test_powerbi.py b/providers/tests/microsoft/azure/operators/test_powerbi.py index a115b4c52dc50..7975411d50cc8 100644 --- a/providers/tests/microsoft/azure/operators/test_powerbi.py +++ b/providers/tests/microsoft/azure/operators/test_powerbi.py @@ -49,6 +49,13 @@ NEW_REFRESH_REQUEST_ID = "5e2d9921-e91b-491f-b7e1-e7d8db49194c" SUCCESS_TRIGGER_EVENT = { + "status": "success", + "dataset_refresh_status": None, + "message": "success", + "dataset_refresh_id": NEW_REFRESH_REQUEST_ID, +} + +SUCCESS_REFRESH_EVENT = { "status": "success", "dataset_refresh_status": PowerBIDatasetRefreshStatus.COMPLETED, "message": "success", @@ -88,6 +95,26 @@ def test_execute_wait_for_termination_with_deferrable(self, connection): operator.execute(context) assert isinstance(exc.value.trigger, PowerBITrigger) + assert exc.value.trigger.dataset_refresh_id is None + + @mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection) + def test_powerbi_operator_async_get_refresh_status_success(self, connection): + """Assert that get_refresh_status log success message""" + operator = PowerBIDatasetRefreshOperator( + **CONFIG, + ) + context = {"ti": MagicMock()} + context["ti"].task_id = TASK_ID + + with pytest.raises(TaskDeferred) as exc: + operator.get_refresh_status( + context=context, + event=SUCCESS_TRIGGER_EVENT, + ) + + assert isinstance(exc.value.trigger, PowerBITrigger) + assert exc.value.trigger.dataset_refresh_id is NEW_REFRESH_REQUEST_ID + assert context["ti"].xcom_push.call_count == 1 def test_powerbi_operator_async_execute_complete_success(self): """Assert that execute_complete log success message""" @@ -97,9 +124,9 @@ def test_powerbi_operator_async_execute_complete_success(self): context = {"ti": MagicMock()} operator.execute_complete( context=context, - event=SUCCESS_TRIGGER_EVENT, + event=SUCCESS_REFRESH_EVENT, ) - assert context["ti"].xcom_push.call_count == 2 + assert context["ti"].xcom_push.call_count == 1 def test_powerbi_operator_async_execute_complete_fail(self): """Assert that execute_complete raise exception on error""" @@ -117,7 +144,7 @@ def test_powerbi_operator_async_execute_complete_fail(self): "dataset_refresh_id": "1234", }, ) - assert context["ti"].xcom_push.call_count == 0 + assert context["ti"].xcom_push.call_count == 1 assert str(exc.value) == "error" def test_powerbi_operator_refresh_fail(self): @@ -136,7 +163,7 @@ def test_powerbi_operator_refresh_fail(self): "dataset_refresh_id": "1234", }, ) - assert context["ti"].xcom_push.call_count == 0 + assert context["ti"].xcom_push.call_count == 1 assert str(exc.value) == "error message" def test_execute_complete_no_event(self): diff --git a/providers/tests/microsoft/azure/triggers/test_powerbi.py b/providers/tests/microsoft/azure/triggers/test_powerbi.py index 303b7d06c8047..58bb3489fd56e 100644 --- a/providers/tests/microsoft/azure/triggers/test_powerbi.py +++ b/providers/tests/microsoft/azure/triggers/test_powerbi.py @@ -22,7 +22,10 @@ import pytest -from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus +from airflow.providers.microsoft.azure.hooks.powerbi import ( + PowerBIDatasetRefreshException, + PowerBIDatasetRefreshStatus, +) from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger from airflow.triggers.base import TriggerEvent @@ -46,6 +49,7 @@ def powerbi_trigger(timeout=TIMEOUT, check_interval=CHECK_INTERVAL) -> PowerBITr proxies=None, api_version=API_VERSION, dataset_id=DATASET_ID, + dataset_refresh_id=DATASET_REFRESH_ID, group_id=GROUP_ID, check_interval=check_interval, wait_for_termination=True, @@ -62,6 +66,7 @@ def test_powerbi_trigger_serialization(self, connection): proxies=None, api_version=API_VERSION, dataset_id=DATASET_ID, + dataset_refresh_id=DATASET_REFRESH_ID, group_id=GROUP_ID, check_interval=CHECK_INTERVAL, wait_for_termination=True, @@ -73,6 +78,7 @@ def test_powerbi_trigger_serialization(self, connection): assert kwargs == { "conn_id": POWERBI_CONN_ID, "dataset_id": DATASET_ID, + "dataset_refresh_id": DATASET_REFRESH_ID, "timeout": TIMEOUT, "group_id": GROUP_ID, "proxies": None, @@ -126,13 +132,32 @@ async def test_powerbi_trigger_run_failed( ) assert expected == actual + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_trigger_refresh(self, mock_trigger_dataset_refresh, powerbi_trigger): + """Assert event is triggered upon successful new refresh trigger.""" + powerbi_trigger.dataset_refresh_id = None + mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID + + task = [i async for i in powerbi_trigger.run()] + response = TriggerEvent( + { + "status": "success", + "dataset_refresh_status": None, + "message": f"The dataset refresh {DATASET_REFRESH_ID} has been triggered.", + "dataset_refresh_id": DATASET_REFRESH_ID, + } + ) + assert len(task) == 1 + assert response in task + @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") async def test_powerbi_trigger_run_completed( self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger ): - """Assert event is triggered upon successful dataset refresh.""" + """Assert event is triggered upon successful dataset refresh completion.""" mock_get_refresh_details_by_refresh_id.return_value = { "status": PowerBIDatasetRefreshStatus.COMPLETED, "error": None, @@ -180,6 +205,35 @@ async def test_powerbi_trigger_run_exception_during_refresh_check_loop( assert response in task mock_cancel_dataset_refresh.assert_called_once() + @pytest.mark.asyncio + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") + @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh") + async def test_powerbi_trigger_run_PowerBIDatasetRefreshException_during_refresh_check_loop( + self, + mock_trigger_dataset_refresh, + mock_get_refresh_details_by_refresh_id, + mock_cancel_dataset_refresh, + powerbi_trigger, + ): + """Assert that run catch PowerBIDatasetRefreshException and triggers retry mechanism""" + mock_get_refresh_details_by_refresh_id.side_effect = PowerBIDatasetRefreshException("Test exception") + mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID + + task = [i async for i in powerbi_trigger.run()] + response = TriggerEvent( + { + "status": "error", + "dataset_refresh_status": None, + "message": "An error occurred: Test exception", + "dataset_refresh_id": DATASET_REFRESH_ID, + } + ) + assert mock_get_refresh_details_by_refresh_id.call_count == 3 + assert len(task) == 1 + assert response in task + assert mock_cancel_dataset_refresh.call_count == 1 + @pytest.mark.asyncio @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh") @mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id") @@ -228,7 +282,7 @@ async def test_powerbi_trigger_run_exception_without_refresh_id( { "status": "error", "dataset_refresh_status": None, - "message": "An error occurred: Test exception for no dataset_refresh_id", + "message": "Failed to trigger the dataset refresh.", "dataset_refresh_id": None, } )