Fixed retry of PowerBIDatasetRefreshOperator when dataset refresh wasn't directly available (#45513)
Ohashiro authored Jan 26, 2025
1 parent 7cc4202 commit 6b87d1c
Showing 4 changed files with 177 additions and 23 deletions.
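For orientation, here is a minimal, hypothetical sketch of how PowerBIDatasetRefreshOperator is typically wired into a DAG (assuming Airflow 2.x); the connection id and GUIDs are placeholders rather than values from this commit, and only parameters that appear in the diff below are used.

```python
# Hypothetical usage sketch; "powerbi_default" and the GUIDs are placeholders.
from datetime import datetime

from airflow import DAG
from airflow.providers.microsoft.azure.operators.powerbi import PowerBIDatasetRefreshOperator

with DAG(
    dag_id="powerbi_refresh_example",
    start_date=datetime(2025, 1, 1),
    schedule=None,
) as dag:
    refresh_powerbi_dataset = PowerBIDatasetRefreshOperator(
        task_id="refresh_powerbi_dataset",
        conn_id="powerbi_default",    # placeholder connection id
        dataset_id="<dataset-guid>",  # placeholder
        group_id="<workspace-guid>",  # placeholder (workspace id)
        check_interval=60,            # poll the refresh status every 60 seconds
        timeout=60 * 60,              # stop waiting after one hour
    )
```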
providers/src/airflow/providers/microsoft/azure/operators/powerbi.py
@@ -114,6 +114,35 @@ def execute(self, context: Context):
check_interval=self.check_interval,
wait_for_termination=self.wait_for_termination,
),
method_name=self.get_refresh_status.__name__,
)

def get_refresh_status(self, context: Context, event: dict[str, str] | None = None):
"""Push the refresh Id to XCom then runs the Trigger to wait for refresh completion."""
if event:
if event["status"] == "error":
raise AirflowException(event["message"])

dataset_refresh_id = event["dataset_refresh_id"]

if dataset_refresh_id:
self.xcom_push(
context=context,
key=f"{self.task_id}.powerbi_dataset_refresh_Id",
value=dataset_refresh_id,
)
self.defer(
trigger=PowerBITrigger(
conn_id=self.conn_id,
group_id=self.group_id,
dataset_id=self.dataset_id,
dataset_refresh_id=dataset_refresh_id,
timeout=self.timeout,
proxies=self.proxies,
api_version=self.api_version,
check_interval=self.check_interval,
wait_for_termination=self.wait_for_termination,
),
method_name=self.execute_complete.__name__,
)
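Because the refresh Id and final status are now pushed under task-id-prefixed XCom keys (as in the code above), a downstream task could pull them roughly as in the sketch below; the task ids are hypothetical, and only the key format comes from this commit.

```python
# Hypothetical downstream consumer; only the XCom key format comes from the commit above.
from airflow.decorators import task
from airflow.operators.python import get_current_context


@task
def report_refresh() -> None:
    ti = get_current_context()["ti"]
    # Keys are prefixed with the upstream task_id, e.g. "refresh_powerbi_dataset".
    refresh_id = ti.xcom_pull(
        task_ids="refresh_powerbi_dataset",
        key="refresh_powerbi_dataset.powerbi_dataset_refresh_Id",
    )
    refresh_status = ti.xcom_pull(
        task_ids="refresh_powerbi_dataset",
        key="refresh_powerbi_dataset.powerbi_dataset_refresh_status",
    )
    print(f"Refresh {refresh_id} finished with status {refresh_status}")
```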

@@ -124,10 +153,10 @@ def execute_complete(self, context: Context, event: dict[str, str]) -> Any:
Relies on trigger to throw an exception, otherwise it assumes execution was successful.
"""
if event:
if event["status"] == "error":
raise AirflowException(event["message"])

self.xcom_push(
context=context, key="powerbi_dataset_refresh_Id", value=event["dataset_refresh_id"]
context=context,
key=f"{self.task_id}.powerbi_dataset_refresh_status",
value=event["dataset_refresh_status"],
)
self.xcom_push(context=context, key="powerbi_dataset_refresh_status", value=event["status"])
if event["status"] == "error":
raise AirflowException(event["message"])
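Taken together, the operator now defers twice: once to obtain the refresh Id (resuming in get_refresh_status) and once more to wait for completion (resuming in execute_complete). Below is a simplified, self-contained sketch of that chained-deferral control flow, assuming the Airflow 2.x deferrable-operator API and using a plain TimeDeltaTrigger as a stand-in for PowerBITrigger; it is illustrative only, not the provider code.

```python
# Illustrative stand-in for PowerBIDatasetRefreshOperator's control flow.
from datetime import timedelta

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.triggers.temporal import TimeDeltaTrigger


class ChainedDeferralOperator(BaseOperator):
    def execute(self, context):
        # First deferral: hand off to a trigger, resume in get_refresh_status.
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(seconds=5)),
            method_name="get_refresh_status",
        )

    def get_refresh_status(self, context, event=None):
        # Intermediate resume: record state in XCom, then defer again until completion.
        self.xcom_push(context=context, key=f"{self.task_id}.refresh_id", value="example-refresh-id")
        self.defer(
            trigger=TimeDeltaTrigger(timedelta(seconds=5)),
            method_name="execute_complete",
        )

    def execute_complete(self, context, event=None):
        # Final resume: fail the task if the trigger reported an error payload.
        if isinstance(event, dict) and event.get("status") == "error":
            raise AirflowException(event["message"])
```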
66 changes: 55 additions & 11 deletions providers/src/airflow/providers/microsoft/azure/triggers/powerbi.py
@@ -22,7 +22,13 @@
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus, PowerBIHook
import tenacity

from airflow.providers.microsoft.azure.hooks.powerbi import (
PowerBIDatasetRefreshException,
PowerBIDatasetRefreshStatus,
PowerBIHook,
)
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
@@ -43,6 +49,7 @@ class PowerBITrigger(BaseTrigger):
You can pass an enum named APIVersion which has 2 possible members v1 and beta,
or you can pass a string as `v1.0` or `beta`.
:param dataset_id: The dataset Id to refresh.
:param dataset_refresh_id: The dataset refresh Id to poll for the status. If not provided, a new refresh is triggered.
:param group_id: The workspace Id where dataset is located.
:param end_time: Time in seconds when trigger should stop polling.
:param check_interval: Time in seconds to wait between each poll.
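With the new dataset_refresh_id parameter, the trigger can also be constructed to resume polling an already-started refresh rather than requesting a new one. A hypothetical construction follows; the connection id and GUIDs are placeholders, and a real run requires a configured Power BI connection.

```python
# Placeholders only; a real run needs a configured Power BI connection.
from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger

trigger = PowerBITrigger(
    conn_id="powerbi_default",                     # placeholder connection id
    dataset_id="<dataset-guid>",                   # placeholder
    group_id="<workspace-guid>",                   # placeholder
    dataset_refresh_id="<existing-refresh-guid>",  # poll this refresh instead of starting a new one
    check_interval=60,
    wait_for_termination=True,
)
```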
@@ -55,6 +62,7 @@ def __init__(
dataset_id: str,
group_id: str,
timeout: float = 60 * 60 * 24 * 7,
dataset_refresh_id: str | None = None,
proxies: dict | None = None,
api_version: APIVersion | str | None = None,
check_interval: int = 60,
@@ -63,6 +71,7 @@
super().__init__()
self.hook = PowerBIHook(conn_id=conn_id, proxies=proxies, api_version=api_version, timeout=timeout)
self.dataset_id = dataset_id
self.dataset_refresh_id = dataset_refresh_id
self.timeout = timeout
self.group_id = group_id
self.check_interval = check_interval
@@ -77,6 +86,7 @@ def serialize(self):
"proxies": self.proxies,
"api_version": self.api_version,
"dataset_id": self.dataset_id,
"dataset_refresh_id": self.dataset_refresh_id,
"group_id": self.group_id,
"timeout": self.timeout,
"check_interval": self.check_interval,
@@ -98,19 +108,53 @@ def api_version(self) -> APIVersion | str:

async def run(self) -> AsyncIterator[TriggerEvent]:
"""Make async connection to the PowerBI and polls for the dataset refresh status."""
self.dataset_refresh_id = await self.hook.trigger_dataset_refresh(
dataset_id=self.dataset_id,
group_id=self.group_id,
)

async def fetch_refresh_status_and_error() -> tuple[str, str]:
"""Fetch the current status and error of the dataset refresh."""
refresh_details = await self.hook.get_refresh_details_by_refresh_id(
if not self.dataset_refresh_id:
# Trigger the dataset refresh
dataset_refresh_id = await self.hook.trigger_dataset_refresh(
dataset_id=self.dataset_id,
group_id=self.group_id,
refresh_id=self.dataset_refresh_id,
)
return refresh_details["status"], refresh_details["error"]

if dataset_refresh_id:
self.log.info("Triggered dataset refresh %s", dataset_refresh_id)
yield TriggerEvent(
{
"status": "success",
"dataset_refresh_status": None,
"message": f"The dataset refresh {dataset_refresh_id} has been triggered.",
"dataset_refresh_id": dataset_refresh_id,
}
)
return

yield TriggerEvent(
{
"status": "error",
"dataset_refresh_status": None,
"message": "Failed to trigger the dataset refresh.",
"dataset_refresh_id": None,
}
)
return

# The dataset refresh is already triggered. Poll for the dataset refresh status.
@tenacity.retry(
stop=tenacity.stop_after_attempt(3),
wait=tenacity.wait_exponential(min=5, multiplier=2),
reraise=True,
retry=tenacity.retry_if_exception_type(PowerBIDatasetRefreshException),
)
async def fetch_refresh_status_and_error() -> tuple[str, str]:
"""Fetch the current status and error of the dataset refresh."""
if self.dataset_refresh_id:
refresh_details = await self.hook.get_refresh_details_by_refresh_id(
dataset_id=self.dataset_id,
group_id=self.group_id,
refresh_id=self.dataset_refresh_id,
)
return refresh_details["status"], refresh_details["error"]

raise PowerBIDatasetRefreshException("Dataset refresh Id is missing.")

try:
dataset_refresh_status, dataset_refresh_error = await fetch_refresh_status_and_error()
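The tenacity policy above retries the status lookup up to three times with exponential backoff and re-raises the last PowerBIDatasetRefreshException once the attempts are exhausted. Below is a standalone sketch of the same policy using a stand-in exception (not the provider's class); running it sleeps for the backoff intervals before succeeding on the third attempt.

```python
# Stand-in demonstration of the retry policy; TransientRefreshError mimics
# PowerBIDatasetRefreshException. Succeeds on the third attempt after two backoff waits.
import asyncio

import tenacity


class TransientRefreshError(Exception):
    """Stand-in for PowerBIDatasetRefreshException."""


attempts = 0


@tenacity.retry(
    stop=tenacity.stop_after_attempt(3),
    wait=tenacity.wait_exponential(min=5, multiplier=2),
    reraise=True,
    retry=tenacity.retry_if_exception_type(TransientRefreshError),
)
async def fetch_refresh_status() -> str:
    global attempts
    attempts += 1
    if attempts < 3:
        raise TransientRefreshError("transient Power BI API failure")  # triggers a retry
    return "Completed"


if __name__ == "__main__":
    print(asyncio.run(fetch_refresh_status()))  # prints "Completed" on the third attempt
```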
35 changes: 31 additions & 4 deletions providers/tests/microsoft/azure/operators/test_powerbi.py
@@ -49,6 +49,13 @@
NEW_REFRESH_REQUEST_ID = "5e2d9921-e91b-491f-b7e1-e7d8db49194c"

SUCCESS_TRIGGER_EVENT = {
"status": "success",
"dataset_refresh_status": None,
"message": "success",
"dataset_refresh_id": NEW_REFRESH_REQUEST_ID,
}

SUCCESS_REFRESH_EVENT = {
"status": "success",
"dataset_refresh_status": PowerBIDatasetRefreshStatus.COMPLETED,
"message": "success",
@@ -88,6 +95,26 @@ def test_execute_wait_for_termination_with_deferrable(self, connection):
operator.execute(context)

assert isinstance(exc.value.trigger, PowerBITrigger)
assert exc.value.trigger.dataset_refresh_id is None

@mock.patch("airflow.hooks.base.BaseHook.get_connection", side_effect=get_airflow_connection)
def test_powerbi_operator_async_get_refresh_status_success(self, connection):
"""Assert that get_refresh_status log success message"""
operator = PowerBIDatasetRefreshOperator(
**CONFIG,
)
context = {"ti": MagicMock()}
context["ti"].task_id = TASK_ID

with pytest.raises(TaskDeferred) as exc:
operator.get_refresh_status(
context=context,
event=SUCCESS_TRIGGER_EVENT,
)

assert isinstance(exc.value.trigger, PowerBITrigger)
assert exc.value.trigger.dataset_refresh_id == NEW_REFRESH_REQUEST_ID
assert context["ti"].xcom_push.call_count == 1

def test_powerbi_operator_async_execute_complete_success(self):
"""Assert that execute_complete log success message"""
@@ -97,9 +124,9 @@ def test_powerbi_operator_async_execute_complete_success(self):
context = {"ti": MagicMock()}
operator.execute_complete(
context=context,
event=SUCCESS_TRIGGER_EVENT,
event=SUCCESS_REFRESH_EVENT,
)
assert context["ti"].xcom_push.call_count == 2
assert context["ti"].xcom_push.call_count == 1

def test_powerbi_operator_async_execute_complete_fail(self):
"""Assert that execute_complete raise exception on error"""
@@ -117,7 +144,7 @@
"dataset_refresh_id": "1234",
},
)
assert context["ti"].xcom_push.call_count == 0
assert context["ti"].xcom_push.call_count == 1
assert str(exc.value) == "error"

def test_powerbi_operator_refresh_fail(self):
@@ -136,7 +163,7 @@ def test_powerbi_operator_refresh_fail(self):
"dataset_refresh_id": "1234",
},
)
assert context["ti"].xcom_push.call_count == 0
assert context["ti"].xcom_push.call_count == 1
assert str(exc.value) == "error message"

def test_execute_complete_no_event(self):
60 changes: 57 additions & 3 deletions providers/tests/microsoft/azure/triggers/test_powerbi.py
@@ -22,7 +22,10 @@

import pytest

from airflow.providers.microsoft.azure.hooks.powerbi import PowerBIDatasetRefreshStatus
from airflow.providers.microsoft.azure.hooks.powerbi import (
PowerBIDatasetRefreshException,
PowerBIDatasetRefreshStatus,
)
from airflow.providers.microsoft.azure.triggers.powerbi import PowerBITrigger
from airflow.triggers.base import TriggerEvent

@@ -46,6 +49,7 @@ def powerbi_trigger(timeout=TIMEOUT, check_interval=CHECK_INTERVAL) -> PowerBITr
proxies=None,
api_version=API_VERSION,
dataset_id=DATASET_ID,
dataset_refresh_id=DATASET_REFRESH_ID,
group_id=GROUP_ID,
check_interval=check_interval,
wait_for_termination=True,
@@ -62,6 +66,7 @@ def test_powerbi_trigger_serialization(self, connection):
proxies=None,
api_version=API_VERSION,
dataset_id=DATASET_ID,
dataset_refresh_id=DATASET_REFRESH_ID,
group_id=GROUP_ID,
check_interval=CHECK_INTERVAL,
wait_for_termination=True,
@@ -73,6 +78,7 @@
assert kwargs == {
"conn_id": POWERBI_CONN_ID,
"dataset_id": DATASET_ID,
"dataset_refresh_id": DATASET_REFRESH_ID,
"timeout": TIMEOUT,
"group_id": GROUP_ID,
"proxies": None,
@@ -126,13 +132,32 @@ async def test_powerbi_trigger_run_failed(
)
assert expected == actual

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_trigger_refresh(self, mock_trigger_dataset_refresh, powerbi_trigger):
"""Assert event is triggered upon successful new refresh trigger."""
powerbi_trigger.dataset_refresh_id = None
mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID

task = [i async for i in powerbi_trigger.run()]
response = TriggerEvent(
{
"status": "success",
"dataset_refresh_status": None,
"message": f"The dataset refresh {DATASET_REFRESH_ID} has been triggered.",
"dataset_refresh_id": DATASET_REFRESH_ID,
}
)
assert len(task) == 1
assert response in task

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_completed(
self, mock_trigger_dataset_refresh, mock_get_refresh_details_by_refresh_id, powerbi_trigger
):
"""Assert event is triggered upon successful dataset refresh."""
"""Assert event is triggered upon successful dataset refresh completion."""
mock_get_refresh_details_by_refresh_id.return_value = {
"status": PowerBIDatasetRefreshStatus.COMPLETED,
"error": None,
@@ -180,6 +205,35 @@ async def test_powerbi_trigger_run_exception_during_refresh_check_loop(
assert response in task
mock_cancel_dataset_refresh.assert_called_once()

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.trigger_dataset_refresh")
async def test_powerbi_trigger_run_PowerBIDatasetRefreshException_during_refresh_check_loop(
self,
mock_trigger_dataset_refresh,
mock_get_refresh_details_by_refresh_id,
mock_cancel_dataset_refresh,
powerbi_trigger,
):
"""Assert that run catch PowerBIDatasetRefreshException and triggers retry mechanism"""
mock_get_refresh_details_by_refresh_id.side_effect = PowerBIDatasetRefreshException("Test exception")
mock_trigger_dataset_refresh.return_value = DATASET_REFRESH_ID

task = [i async for i in powerbi_trigger.run()]
response = TriggerEvent(
{
"status": "error",
"dataset_refresh_status": None,
"message": "An error occurred: Test exception",
"dataset_refresh_id": DATASET_REFRESH_ID,
}
)
assert mock_get_refresh_details_by_refresh_id.call_count == 3
assert len(task) == 1
assert response in task
assert mock_cancel_dataset_refresh.call_count == 1

@pytest.mark.asyncio
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.cancel_dataset_refresh")
@mock.patch(f"{MODULE}.hooks.powerbi.PowerBIHook.get_refresh_details_by_refresh_id")
@@ -228,7 +282,7 @@ async def test_powerbi_trigger_run_exception_without_refresh_id(
{
"status": "error",
"dataset_refresh_status": None,
"message": "An error occurred: Test exception for no dataset_refresh_id",
"message": "Failed to trigger the dataset refresh.",
"dataset_refresh_id": None,
}
)
