From e3b0b864c3d66c966da01aefb6a6fc2f0681ba5f Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 2 Jan 2025 17:58:35 +0800 Subject: [PATCH 1/2] Move Literal alias into TYPE_CHECKING block This allows us to always import Literal from typing, which has received quite some bug fixes between different versions. Doing this in the typing block avoids loading the runtime version at all, thus eliminating differences between runtime Python versions. --- airflow/auth/managers/base_auth_manager.py | 12 +++++----- .../commands/remote_commands/task_command.py | 9 ++++---- airflow/models/mappedoperator.py | 4 ++-- airflow/providers_manager.py | 4 ++-- .../airflow/providers/amazon/aws/utils/sqs.py | 10 +++++---- .../providers/docker/decorators/docker.py | 4 +++- .../providers/standard/operators/python.py | 6 +++-- .../providers/weaviate/hooks/weaviate.py | 2 +- .../core_api/routes/public/test_job.py | 22 +++++++++++-------- tests_common/_internals/capture_warnings.py | 9 +++++--- 10 files changed, 48 insertions(+), 34 deletions(-) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 345d22d536395..10be4385464c0 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -18,21 +18,21 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Container, Sequence -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from sqlalchemy import select from airflow.auth.managers.models.base_user import BaseUser -from airflow.auth.managers.models.resource_details import ( - DagDetails, -) +from airflow.auth.managers.models.resource_details import DagDetails from airflow.exceptions import AirflowException from airflow.models import DagModel from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: + from collections.abc import Container, Sequence + from typing import Literal + from fastapi import FastAPI from flask import Blueprint from flask_appbuilder.menu import MenuItem @@ -55,7 +55,7 @@ ) from airflow.cli.cli_config import CLICommand -ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"] + ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"] T = TypeVar("T", bound=BaseUser) diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py index ad3c26ff56eda..6c591801515d5 100644 --- a/airflow/cli/commands/remote_commands/task_command.py +++ b/airflow/cli/commands/remote_commands/task_command.py @@ -28,7 +28,7 @@ import textwrap from collections.abc import Generator from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress -from typing import TYPE_CHECKING, Protocol, Union, cast +from typing import TYPE_CHECKING, Protocol, cast import pendulum from pendulum.parsing.exceptions import ParserError @@ -50,7 +50,6 @@ from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS -from airflow.typing_compat import Literal from airflow.utils import cli as cli_utils, timezone from airflow.utils.cli import ( get_dag, @@ -70,13 +69,15 @@ from airflow.utils.types import DagRunTriggeredByType if TYPE_CHECKING: + from typing import Literal + from sqlalchemy.orm.session import Session from airflow.models.operator import Operator -log = logging.getLogger(__name__) + CreateIfNecessary = Literal[False, "db", "memory"] -CreateIfNecessary = Union[Literal[False], Literal["db"], Literal["memory"]] +log = logging.getLogger(__name__) def _generate_temporary_run_id() -> str: diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 4d362714794e3..f4728095037d4 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -53,7 +53,6 @@ from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded from airflow.triggers.base import StartTriggerArgs -from airflow.typing_compat import Literal from airflow.utils.context import context_update_for_unmapped from airflow.utils.helpers import is_container, prevent_duplicates from airflow.utils.task_instance_session import get_current_task_instance_session @@ -62,6 +61,7 @@ if TYPE_CHECKING: import datetime + from typing import Literal import jinja2 # Slow import. import pendulum @@ -89,7 +89,7 @@ TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] -ValidationSource = Union[Literal["expand"], Literal["partial"]] + ValidationSource = Literal["expand", "partial"] def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 8d9f93734d7eb..575306a840b79 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -86,12 +86,12 @@ def ensure_prefix(field): if TYPE_CHECKING: + from typing import Literal from urllib.parse import SplitResult from airflow.decorators.base import TaskDecorator from airflow.hooks.base import BaseHook from airflow.sdk.definitions.asset import Asset - from airflow.typing_compat import Literal class LazyDictWithCache(MutableMapping): @@ -201,7 +201,7 @@ class ProviderInfo: version: str data: dict - package_or_source: Literal["source"] | Literal["package"] + package_or_source: Literal["source", "package"] def __post_init__(self): if self.package_or_source not in ("source", "package"): diff --git a/providers/src/airflow/providers/amazon/aws/utils/sqs.py b/providers/src/airflow/providers/amazon/aws/utils/sqs.py index 293aa1b898d27..3c509454655a7 100644 --- a/providers/src/airflow/providers/amazon/aws/utils/sqs.py +++ b/providers/src/airflow/providers/amazon/aws/utils/sqs.py @@ -14,20 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# from __future__ import annotations import json import logging -from typing import Any +from typing import TYPE_CHECKING, Any import jsonpath_ng import jsonpath_ng.ext -from typing_extensions import Literal -log = logging.getLogger(__name__) +if TYPE_CHECKING: + from typing import Literal + MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"] -MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"] +log = logging.getLogger(__name__) def process_response( diff --git a/providers/src/airflow/providers/docker/decorators/docker.py b/providers/src/airflow/providers/docker/decorators/docker.py index 560028c16ed6c..77355ff03b2ca 100644 --- a/providers/src/airflow/providers/docker/decorators/docker.py +++ b/providers/src/airflow/providers/docker/decorators/docker.py @@ -20,7 +20,7 @@ import os from collections.abc import Sequence from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable from airflow.decorators.base import DecoratedOperator, task_decorator_factory from airflow.exceptions import AirflowException @@ -28,6 +28,8 @@ from airflow.providers.docker.operators.docker import DockerOperator if TYPE_CHECKING: + from typing import Literal + from airflow.decorators.base import TaskDecorator from airflow.utils.context import Context diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index 40a0cb7a9223e..25de405a80c9b 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -51,7 +51,6 @@ AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS, ) -from airflow.typing_compat import Literal from airflow.utils import hashlib_wrapper from airflow.utils.context import context_copy_partial, context_merge from airflow.utils.file import get_unique_dag_module_name @@ -61,10 +60,14 @@ log = logging.getLogger(__name__) if TYPE_CHECKING: + from typing import Literal + from pendulum.datetime import DateTime from airflow.utils.context import Context + _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"] + @cache def _parse_version_info(text: str) -> tuple[int, int, int, str, int]: @@ -343,7 +346,6 @@ def _load_cloudpickle(): return cloudpickle -_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"] _SERIALIZERS: dict[_SerializerTypeDef, Any] = { "pickle": lazy_object_proxy.Proxy(_load_pickle), "dill": lazy_object_proxy.Proxy(_load_dill), diff --git a/providers/src/airflow/providers/weaviate/hooks/weaviate.py b/providers/src/airflow/providers/weaviate/hooks/weaviate.py index 716bc3e10e6be..e49cc58f02420 100644 --- a/providers/src/airflow/providers/weaviate/hooks/weaviate.py +++ b/providers/src/airflow/providers/weaviate/hooks/weaviate.py @@ -749,7 +749,7 @@ def create_or_replace_document_objects( verbose: bool = False, ) -> Sequence[dict[str, UUID | str] | None]: """ - create or replace objects belonging to documents. + Create or replace objects belonging to documents. In real-world scenarios, information sources like Airflow docs, Stack Overflow, or other issues are considered 'documents' here. It's crucial to keep the database objects in sync with these sources. diff --git a/tests/api_fastapi/core_api/routes/public/test_job.py b/tests/api_fastapi/core_api/routes/public/test_job.py index f09c2b902c1c8..780d51f446785 100644 --- a/tests/api_fastapi/core_api/routes/public/test_job.py +++ b/tests/api_fastapi/core_api/routes/public/test_job.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Literal +from typing import TYPE_CHECKING import pytest @@ -28,15 +28,19 @@ from tests_common.test_utils.db import clear_db_jobs from tests_common.test_utils.format_datetime import from_datetime_to_zulu +if TYPE_CHECKING: + from typing import Literal + + TestCase = Literal[ + "should_report_success_for_one_working_scheduler", + "should_report_success_for_one_working_scheduler_with_hostname", + "should_report_success_for_ha_schedulers", + "should_ignore_not_running_jobs", + "should_raise_exception_for_multiple_scheduler_on_one_host", + ] + pytestmark = pytest.mark.db_test -TESTCASE_TYPE = Literal[ - "should_report_success_for_one_working_scheduler", - "should_report_success_for_one_working_scheduler_with_hostname", - "should_report_success_for_ha_schedulers", - "should_ignore_not_running_jobs", - "should_raise_exception_for_multiple_scheduler_on_one_host", -] TESTCASE_ONE_SCHEDULER = "should_report_success_for_one_working_scheduler" TESTCASE_ONE_SCHEDULER_WITH_HOSTNAME = "should_report_success_for_one_working_scheduler_with_hostname" TESTCASE_HA_SCHEDULERS = "should_report_success_for_ha_schedulers" @@ -107,7 +111,7 @@ def _setup_should_raise_exception_for_multiple_scheduler_on_one_host(self, sessi scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback) @provide_session - def setup(self, testcase: TESTCASE_TYPE, session=None) -> None: + def setup(self, testcase: TestCase, session=None) -> None: """ Setup testcase at runtime based on the `testcase` provided by `pytest.mark.parametrize`. """ diff --git a/tests_common/_internals/capture_warnings.py b/tests_common/_internals/capture_warnings.py index d0cb2a0361581..a8e719c24df69 100644 --- a/tests_common/_internals/capture_warnings.py +++ b/tests_common/_internals/capture_warnings.py @@ -28,12 +28,15 @@ from contextlib import contextmanager from dataclasses import asdict, dataclass from pathlib import Path -from typing import Callable +from typing import TYPE_CHECKING, Callable import pytest -from typing_extensions import Literal -WhenTypeDef = Literal["config", "collect", "runtest"] +if TYPE_CHECKING: + from typing import Literal + + WhenTypeDef = Literal["config", "collect", "runtest"] + TESTS_DIR = Path(__file__).parents[1].resolve() From 462d88a9eb97bf11a616ce22c8c8fbe2721f60c5 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 2 Jan 2025 21:35:15 +0800 Subject: [PATCH 2/2] Fix type alias import --- airflow/api_fastapi/app.py | 5 ++++- airflow/api_fastapi/core_api/security.py | 6 ++++-- airflow/auth/managers/base_auth_manager.py | 6 ++++-- airflow/auth/managers/simple/simple_auth_manager.py | 3 ++- .../providers/amazon/aws/auth_manager/aws_auth_manager.py | 7 ++++--- providers/src/airflow/providers/amazon/aws/sensors/sqs.py | 3 ++- providers/src/airflow/providers/amazon/aws/triggers/sqs.py | 3 ++- .../airflow/providers/fab/auth_manager/fab_auth_manager.py | 3 ++- tests/auth/managers/test_base_auth_manager.py | 3 ++- 9 files changed, 26 insertions(+), 13 deletions(-) diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py index 3c9aa39ae707e..9cbb190d41120 100644 --- a/airflow/api_fastapi/app.py +++ b/airflow/api_fastapi/app.py @@ -18,6 +18,7 @@ import logging from contextlib import AsyncExitStack, asynccontextmanager +from typing import TYPE_CHECKING from fastapi import FastAPI from starlette.routing import Mount @@ -31,10 +32,12 @@ init_views, ) from airflow.api_fastapi.execution_api.app import create_task_execution_api_app -from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.configuration import conf from airflow.exceptions import AirflowConfigException +if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import BaseAuthManager + log = logging.getLogger(__name__) app: FastAPI | None = None diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index 30470e9b5da55..7aaee4d0fae7f 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -17,19 +17,21 @@ from __future__ import annotations from functools import cache -from typing import Annotated, Any, Callable +from typing import TYPE_CHECKING, Annotated, Any, Callable from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from jwt import InvalidTokenError from airflow.api_fastapi.app import get_auth_manager -from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails from airflow.configuration import conf from airflow.utils.jwt_signer import JWTSigner +if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import ResourceMethod + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 10be4385464c0..6a9ef11e3d785 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -26,12 +26,12 @@ from airflow.auth.managers.models.resource_details import DagDetails from airflow.exceptions import AirflowException from airflow.models import DagModel +from airflow.typing_compat import Literal from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from collections.abc import Container, Sequence - from typing import Literal from fastapi import FastAPI from flask import Blueprint @@ -55,7 +55,9 @@ ) from airflow.cli.cli_config import CLICommand - ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"] +# This cannot be in the TYPE_CHECKING block since some providers import it globally. +# TODO: Move this inside once all providers drop Airflow 2.x support. +ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"] T = TypeVar("T", bound=BaseUser) diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py index 5c411a5202d97..6ac5f34258749 100644 --- a/airflow/auth/managers/simple/simple_auth_manager.py +++ b/airflow/auth/managers/simple/simple_auth_manager.py @@ -27,7 +27,7 @@ from flask import session, url_for from termcolor import colored -from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.simple.user import SimpleAuthManagerUser from airflow.auth.managers.simple.views.auth import SimpleAuthManagerAuthenticationViews from airflow.configuration import AIRFLOW_HOME, conf @@ -35,6 +35,7 @@ if TYPE_CHECKING: from flask_appbuilder.menu import MenuItem + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.resource_details import ( AccessView, AssetDetails, diff --git a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index ab82c6042ce3b..88f8cef8b76c4 100644 --- a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -25,7 +25,7 @@ from flask import session, url_for -from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.models.resource_details import ( AccessView, ConnectionDetails, @@ -53,6 +53,7 @@ if TYPE_CHECKING: from flask_appbuilder.menu import MenuItem + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.batch_apis import ( IsAuthorizedConnectionRequest, @@ -326,11 +327,11 @@ def filter_permitted_dag_ids( for method in ["GET", "PUT"]: if method in methods: request: IsAuthorizedRequest = { - "method": cast(ResourceMethod, method), + "method": cast("ResourceMethod", method), "entity_type": AvpEntities.DAG, "entity_id": dag_id, } - requests[dag_id][cast(ResourceMethod, method)] = request + requests[dag_id][cast("ResourceMethod", method)] = request requests_list.append(request) batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results( diff --git a/providers/src/airflow/providers/amazon/aws/sensors/sqs.py b/providers/src/airflow/providers/amazon/aws/sensors/sqs.py index 006c5bf2ad20b..016e4fbd6c6d6 100644 --- a/providers/src/airflow/providers/amazon/aws/sensors/sqs.py +++ b/providers/src/airflow/providers/amazon/aws/sensors/sqs.py @@ -30,10 +30,11 @@ from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger from airflow.providers.amazon.aws.utils import validate_execute_complete_event from airflow.providers.amazon.aws.utils.mixins import aws_template_fields -from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response +from airflow.providers.amazon.aws.utils.sqs import process_response if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection + from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType from airflow.utils.context import Context diff --git a/providers/src/airflow/providers/amazon/aws/triggers/sqs.py b/providers/src/airflow/providers/amazon/aws/triggers/sqs.py index 31f344b998209..28c0b509d28c0 100644 --- a/providers/src/airflow/providers/amazon/aws/triggers/sqs.py +++ b/providers/src/airflow/providers/amazon/aws/triggers/sqs.py @@ -22,11 +22,12 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sqs import SqsHook -from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response +from airflow.providers.amazon.aws.utils.sqs import process_response from airflow.triggers.base import BaseTrigger, TriggerEvent if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection + from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType class SqsSensorTrigger(BaseTrigger): diff --git a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 6c8bf51b7acf8..4c889a9c14e3f 100644 --- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -34,7 +34,7 @@ from starlette.middleware.wsgi import WSGIMiddleware from airflow import __version__ as airflow_version -from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.models.resource_details import ( AccessView, ConfigurationDetails, @@ -94,6 +94,7 @@ from airflow.version import version if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.base_user import BaseUser from airflow.cli.cli_config import ( CLICommand, diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index a6480e809a8e1..4406ae9d43607 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -21,7 +21,7 @@ import pytest -from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( ConnectionDetails, @@ -34,6 +34,7 @@ if TYPE_CHECKING: from flask_appbuilder.menu import MenuItem + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.resource_details import ( AccessView, AssetDetails,