Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor result_processor and event_handler signatures in MSGraphAsyncOperator #46637

Merged
merged 38 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
b3faafa
refactor: Refactor result_processor and event_handler signatures so t…
davidblain-infrabel Feb 11, 2025
4ae6aaf
refactor: Reformatted TestMSGraphAsyncOperator
Feb 11, 2025
5ba42c7
refactor: Ignore types
Feb 11, 2025
737ae14
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 11, 2025
1674746
refactor: Ignore types
Feb 11, 2025
e31d2d8
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 11, 2025
98b6b2d
refactor: Changed types xcom helper methods
davidblain-infrabel Feb 11, 2025
ad6d122
refactor: Fixed MSGraphSensor
davidblain-infrabel Feb 11, 2025
33aa735
refactor: Fixed TestMSGraphSensor
Feb 11, 2025
8ac5296
refactor: Changed type pull_xcom
Feb 11, 2025
080b829
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 11, 2025
39c57c7
refactor: Removed white line
Feb 12, 2025
39a4b44
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 12, 2025
ee57188
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 12, 2025
12cb9af
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 12, 2025
7fb6448
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 13, 2025
cb38423
refactor: Reduced timeout of sensor in tests to 5 seconds
davidblain-infrabel Feb 13, 2025
a7998e8
refactor: Extracted common execute_callback method for the handlers a…
davidblain-infrabel Feb 13, 2025
5fc806e
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 13, 2025
dc9b753
refactor: Changed types of execute_callable
davidblain-infrabel Feb 13, 2025
f99ac77
refactor: Changed func parameter type of execute_callable
davidblain-infrabel Feb 14, 2025
44bd449
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 14, 2025
009b1aa
refactor: Added test for execute_callable
davidblain-infrabel Feb 14, 2025
ac14d96
refactor: Changed func type for execute_callable as it still fails
davidblain-infrabel Feb 14, 2025
4ebfc63
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 14, 2025
6ba568d
refactor: Added test for request_information with custom host in Test…
davidblain-infrabel Feb 14, 2025
a9801b1
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 14, 2025
6d31fbf
refactor: Ignore types in execute_callable
davidblain-infrabel Feb 14, 2025
b05ee9d
refactor: Reformatted test_request_information_with_custom_host
davidblain-infrabel Feb 14, 2025
9e3100c
refactor: Reformatted test_execute_callable
davidblain-infrabel Feb 14, 2025
3fbbd66
refactor: Context should not be in type checking block
davidblain-infrabel Feb 17, 2025
221464d
refactor: Ignore type event_handler
davidblain-infrabel Feb 17, 2025
a096d2f
refactor: Context should be type checked just like in main
davidblain-infrabel Feb 17, 2025
7a7427d
Merge branch 'main' into feature/msgraph-refactor-context
davidblain-infrabel Feb 17, 2025
5c93ee8
refactor: Fixed TestMSGraphAsyncOperator
davidblain-infrabel Feb 17, 2025
3121963
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 18, 2025
60f5131
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 18, 2025
9950817
Merge branch 'main' into feature/msgraph-refactor-context
dabla Feb 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def get_host(self, connection: Connection) -> str:
@staticmethod
def format_no_proxy_url(url: str) -> str:
if "://" not in url:
url = f"all://{url}"
return f"all://{url}"
return url

@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import (
Expand All @@ -25,7 +26,7 @@
Callable,
)

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
from airflow.models import BaseOperator
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.triggers.msgraph import (
Expand All @@ -44,14 +45,31 @@
from airflow.utils.context import Context


def default_event_handler(context: Context, event: dict[Any, Any] | None = None) -> Any:
def default_event_handler(event: dict[Any, Any] | None = None, **context) -> Any:
    """
    Extract the response from a trigger event.

    :param event: Event payload; its ``status``, ``message`` and ``response`` keys are read.
    :param context: Task context passed as keyword arguments; accepted but not used here.
    :return: The value stored under ``response``, or None when no (or an empty) event is given.
    :raises AirflowException: When the event ``status`` equals ``"failure"``; the exception
        carries the event ``message``.
    """
    # Nothing to process: keep the implicit "no result" contract explicit.
    if not event:
        return None

    if event.get("status") == "failure":
        raise AirflowException(event.get("message"))

    return event.get("response")


def execute_callable(
    func: Callable[[dict[Any, Any] | None, Context], Any] | Callable[[dict[Any, Any] | None, Any], Any],
    value: Any,
    context: Context,
    message: str,
) -> Any:
    """
    Invoke a user supplied callback with either the new or the deprecated signature.

    The new signature is ``func(value, **context)``; the deprecated one is
    ``func(context, value)``. Which one applies is decided by binding the arguments
    against the callable's signature *before* calling it, so a ``TypeError`` raised
    inside a new-style callback propagates unchanged instead of being mistaken for a
    signature mismatch (which would run the callback a second time with swapped
    arguments and emit a misleading deprecation warning).

    :param func: The callback to execute.
    :param value: The value (event, response or result) handed to the callback.
    :param context: The task context; expanded into keyword arguments for the new signature.
    :param message: Deprecation message emitted when the deprecated signature is used.
    :return: Whatever the callback returns.
    """
    import inspect

    try:
        # Expand the context only once so the probe and the real call see the same kwargs.
        kwargs = {**context}
        try:
            signature = inspect.signature(func)
        except ValueError:
            # No introspectable signature (e.g. some builtins): probe by calling,
            # a TypeError here still falls back to the deprecated call below.
            return func(value, **kwargs)  # type: ignore
        signature.bind(value, **kwargs)
    except TypeError:
        warnings.warn(
            message,
            AirflowProviderDeprecationWarning,
            stacklevel=2,
        )
        return func(context, value)  # type: ignore

    # Deliberately outside the try block: errors raised by the callback itself must surface.
    return func(value, **kwargs)  # type: ignore


class MSGraphAsyncOperator(BaseOperator):
"""
A Microsoft Graph API operator which allows you to execute REST call to the Microsoft Graph API.
Expand All @@ -76,7 +94,7 @@ class MSGraphAsyncOperator(BaseOperator):
You can pass an enum named APIVersion which has 2 possible members v1 and beta,
or you can pass a string as `v1.0` or `beta`.
:param result_processor: Function to further process the response from MS Graph API
(default is lambda: context, response: response). When the response returned by the
(default is `lambda result, **context: result`). When the response returned by the
`KiotaRequestAdapterHook` is bytes, it will be base64 encoded into a string.
:param event_handler: Function to process the event returned from `MSGraphTrigger`. By default, when the
event returned by the `MSGraphTrigger` has a failed status, an AirflowException is being raised with
Expand Down Expand Up @@ -114,8 +132,8 @@ def __init__(
scopes: str | list[str] | None = None,
api_version: APIVersion | str | None = None,
pagination_function: Callable[[MSGraphAsyncOperator, dict, Context], tuple[str, dict]] | None = None,
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
event_handler: Callable[[Context, dict[Any, Any] | None], Any] | None = None,
result_processor: Callable[[Any, Context], Any] = lambda result, **context: result,
event_handler: Callable[[dict[Any, Any] | None, Context], Any] | None = None,
serializer: type[ResponseSerializer] = ResponseSerializer,
**kwargs: Any,
):
Expand Down Expand Up @@ -175,7 +193,12 @@ def execute_complete(
if event:
self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event)

response = self.event_handler(context, event)
response = execute_callable(
self.event_handler, # type: ignore
event,
context,
"event_handler signature has changed, event parameter should be defined before context!",
)

self.log.debug("response: %s", response)

Expand All @@ -186,7 +209,12 @@ def execute_complete(

self.log.debug("deserialize response: %s", response)

result = self.result_processor(context, response)
result = execute_callable(
self.result_processor,
response,
context,
"result_processor signature has changed, result parameter should be defined before context!",
)

self.log.debug("processed response: %s", result)

Expand Down Expand Up @@ -234,13 +262,14 @@ def append_result(
return result
return results

def pull_xcom(self, context: Context) -> list:
def pull_xcom(self, context: Context | dict[str, Any]) -> list:
map_index = context["ti"].map_index
value = list(
context["ti"].xcom_pull(
key=self.key,
task_ids=self.task_id,
dag_id=self.dag_id,
map_indexes=map_index,
)
or []
)
Expand All @@ -265,15 +294,15 @@ def pull_xcom(self, context: Context) -> list:

return value

def push_xcom(self, context: Context, value) -> None:
def push_xcom(self, context: Any, value) -> None:
    """
    Push ``value`` to XCom under ``self.key`` when ``do_xcom_push`` is enabled.

    :param context: The task context forwarded to ``xcom_push``.
    :param value: The value to store.
    """
    self.log.debug("do_xcom_push: %s", self.do_xcom_push)
    # Pushing is opt-in; bail out early when it is disabled.
    if not self.do_xcom_push:
        return
    self.log.info("Pushing XCom with key '%s': %s", self.key, value)
    self.xcom_push(context=context, key=self.key, value=value)

@staticmethod
def paginate(
operator: MSGraphAsyncOperator, response: dict, context: Context
operator: MSGraphAsyncOperator, response: dict, **context
) -> tuple[Any, dict[str, Any] | None]:
odata_count = response.get("@odata.count")
if odata_count and operator.query_parameters:
Expand All @@ -282,15 +311,15 @@ def paginate(

if top and odata_count:
if len(response.get("value", [])) == top and context:
results = operator.pull_xcom(context=context)
results = operator.pull_xcom(context)
skip = sum([len(result["value"]) for result in results]) + top if results else top # type: ignore
query_parameters["$skip"] = skip
return operator.url, query_parameters
return response.get("@odata.nextLink"), operator.query_parameters

def trigger_next_link(self, response, method_name: str, context: Context) -> None:
if isinstance(response, dict):
url, query_parameters = self.pagination_function(self, response, context)
url, query_parameters = self.pagination_function(self, response, **dict(context.items())) # type: ignore

self.log.debug("url: %s", url)
self.log.debug("query_parameters: %s", query_parameters)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.common.compat.standard.triggers import TimeDeltaTrigger
from airflow.providers.microsoft.azure.hooks.msgraph import KiotaRequestAdapterHook
from airflow.providers.microsoft.azure.operators.msgraph import execute_callable
from airflow.providers.microsoft.azure.triggers.msgraph import MSGraphTrigger, ResponseSerializer
from airflow.sensors.base import BaseSensorOperator

Expand Down Expand Up @@ -55,7 +56,7 @@ class MSGraphSensor(BaseSensorOperator):
`default_event_processor` method) and returns a boolean. When the result is True, the sensor
will stop poking, otherwise it will continue until it's True or times out.
:param result_processor: Function to further process the response from MS Graph API
(default is lambda: context, response: response). When the response returned by the
(default is `lambda result, **context: result`). When the response returned by the
`KiotaRequestAdapterHook` is bytes, it will be base64 encoded into a string.
:param serializer: Class which handles response serialization (default is ResponseSerializer).
Bytes will be base64 encoded into a string, so it can be stored as an XCom.
Expand Down Expand Up @@ -86,8 +87,8 @@ def __init__(
proxies: dict | None = None,
scopes: str | list[str] | None = None,
api_version: APIVersion | str | None = None,
event_processor: Callable[[Context, Any], bool] = lambda context, e: e.get("status") == "Succeeded",
result_processor: Callable[[Context, Any], Any] = lambda context, result: result,
event_processor: Callable[[Any, Context], bool] = lambda e, **context: e.get("status") == "Succeeded",
result_processor: Callable[[Any, Context], Any] = lambda result, **context: result,
serializer: type[ResponseSerializer] = ResponseSerializer,
retry_delay: timedelta | float = 60,
**kwargs,
Expand Down Expand Up @@ -164,12 +165,22 @@ def execute_complete(

self.log.debug("deserialize response: %s", response)

is_done = self.event_processor(context, response)
is_done = execute_callable(
self.event_processor,
response,
context,
"event_processor signature has changed, event parameter should be defined before context!",
)

self.log.debug("is_done: %s", is_done)

if is_done:
result = self.result_processor(context, response)
result = execute_callable(
self.result_processor,
response,
context,
"result_processor signature has changed, result parameter should be defined before context!",
)

self.log.debug("processed response: %s", result)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pytest
from httpx import Response
from httpx._utils import URLPattern
from kiota_abstractions.request_information import RequestInformation
from kiota_http.httpx_request_adapter import HttpxRequestAdapter
from kiota_serialization_json.json_parse_node import JsonParseNode
from kiota_serialization_text.text_parse_node import TextParseNode
Expand Down Expand Up @@ -271,6 +272,26 @@ def test_encoded_query_parameters(self):

assert actual == {"%24expand": "reports,users,datasets,dataflows,dashboards", "%24top": 5000}

def test_request_information_with_custom_host(self):
    """Check that a custom connection host and api version end up in the request URL."""

    def connection(conn_id):
        # Stand-in for BaseHook.get_connection returning a connection with a custom host.
        return get_airflow_connection(
            conn_id=conn_id,
            host="api.fabric.microsoft.com",
            api_version="v1",
        )

    with patch("airflow.hooks.base.BaseHook.get_connection", side_effect=connection):
        hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
        request_info = hook.request_information(url="myorg/admin/apps", query_parameters={"$top": 5000})
        request_adapter = hook.get_conn()
        request_adapter.set_base_url_for_request_information(request_info)

        assert isinstance(request_info, RequestInformation)
        assert isinstance(request_adapter, HttpxRequestAdapter)
        assert request_info.url == "https://api.fabric.microsoft.com/v1/myorg/admin/apps?%24top=5000"

@pytest.mark.asyncio
async def test_throw_failed_responses_with_text_plain_content_type(self):
with patch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.microsoft.azure.operators.msgraph import MSGraphAsyncOperator, execute_callable
from airflow.triggers.base import TriggerEvent
from airflow.utils import timezone
from unit.microsoft.azure.base import Base
from unit.microsoft.azure.test_utils import mock_json_response, mock_response

Expand All @@ -45,7 +46,7 @@

class TestMSGraphAsyncOperator(Base):
@pytest.mark.db_test
def test_execute(self):
def test_execute_with_old_result_processor_signature(self):
users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
next_users = load_json_from_resources(dirname(__file__), "..", "resources", "next_users.json")
response = mock_json_response(200, users, next_users)
Expand All @@ -58,6 +59,38 @@ def test_execute(self):
result_processor=lambda context, result: result.get("value"),
)

with pytest.warns(
AirflowProviderDeprecationWarning,
match="result_processor signature has changed, result parameter should be defined before context!",
):
results, events = execute_operator(operator)

assert len(results) == 30
assert results == users.get("value") + next_users.get("value")
assert len(events) == 2
assert isinstance(events[0], TriggerEvent)
assert events[0].payload["status"] == "success"
assert events[0].payload["type"] == "builtins.dict"
assert events[0].payload["response"] == json.dumps(users)
assert isinstance(events[1], TriggerEvent)
assert events[1].payload["status"] == "success"
assert events[1].payload["type"] == "builtins.dict"
assert events[1].payload["response"] == json.dumps(next_users)

@pytest.mark.db_test
def test_execute_with_new_result_processor_signature(self):
users = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
next_users = load_json_from_resources(dirname(__file__), "..", "resources", "next_users.json")
response = mock_json_response(200, users, next_users)

with self.patch_hook_and_request_adapter(response):
operator = MSGraphAsyncOperator(
task_id="users_delta",
conn_id="msgraph_api",
url="users",
result_processor=lambda result, **context: result.get("value"),
)

results, events = execute_operator(operator)

assert len(results) == 30
Expand Down Expand Up @@ -109,7 +142,7 @@ def test_execute_when_an_exception_occurs(self):
execute_operator(operator)

@pytest.mark.db_test
def test_execute_when_an_exception_occurs_on_custom_event_handler(self):
def test_execute_when_an_exception_occurs_on_custom_event_handler_with_old_signature(self):
with self.patch_hook_and_request_adapter(AirflowException("An error occurred")):

def custom_event_handler(context: Context, event: dict[Any, Any] | None = None):
Expand All @@ -126,6 +159,36 @@ def custom_event_handler(context: Context, event: dict[Any, Any] | None = None):
event_handler=custom_event_handler,
)

with pytest.warns(
AirflowProviderDeprecationWarning,
match="event_handler signature has changed, event parameter should be defined before context!",
):
results, events = execute_operator(operator)

assert not results
assert len(events) == 1
assert isinstance(events[0], TriggerEvent)
assert events[0].payload["status"] == "failure"
assert events[0].payload["message"] == "An error occurred"

@pytest.mark.db_test
def test_execute_when_an_exception_occurs_on_custom_event_handler_with_new_signature(self):
with self.patch_hook_and_request_adapter(AirflowException("An error occurred")):

def custom_event_handler(event: dict[Any, Any] | None = None, **context):
if event:
if event.get("status") == "failure":
return None

return event.get("response")

operator = MSGraphAsyncOperator(
task_id="users_delta",
conn_id="msgraph_api",
url="users/delta",
event_handler=custom_event_handler,
)

results, events = execute_operator(operator)

assert not results
Expand Down Expand Up @@ -209,7 +272,7 @@ def test_paginate_without_query_parameters(self):
)
context = mock_context(task=operator)
response = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
next_link, query_parameters = MSGraphAsyncOperator.paginate(operator, response, context)
next_link, query_parameters = MSGraphAsyncOperator.paginate(operator, response, **context)

assert next_link == response["@odata.nextLink"]
assert query_parameters is None
Expand All @@ -224,7 +287,31 @@ def test_paginate_with_context_query_parameters(self):
context = mock_context(task=operator)
response = load_json_from_resources(dirname(__file__), "..", "resources", "users.json")
response["@odata.count"] = 100
url, query_parameters = MSGraphAsyncOperator.paginate(operator, response, context)
url, query_parameters = MSGraphAsyncOperator.paginate(operator, response, **context)

assert url == "users"
assert query_parameters == {"$skip": 12, "$top": 12}

def test_execute_callable(self):
    """
    Check both calling conventions supported by ``execute_callable``.

    The deprecated ``(context, response)`` signature must still work but emit an
    ``AirflowProviderDeprecationWarning``; the new ``(response, **context)`` signature
    must work without emitting that warning.
    """
    import warnings

    message = "result_processor signature has changed, result parameter should be defined before context!"

    with pytest.warns(AirflowProviderDeprecationWarning, match=message):
        assert (
            execute_callable(
                lambda context, response: response,
                "response",
                {"execution_date": timezone.utcnow()},
                message,
            )
            == "response"
        )

    # Escalate the deprecation warning to an error so a regression that warns for the
    # new signature fails this test instead of passing silently.
    with warnings.catch_warnings():
        warnings.simplefilter("error", AirflowProviderDeprecationWarning)
        assert (
            execute_callable(
                lambda response, **context: response,
                "response",
                {"execution_date": timezone.utcnow()},
                message,
            )
            == "response"
        )
Loading
Loading