From b703d53b774960326b8d91963304bac3ca5d533c Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Thu, 9 Jan 2025 13:34:05 +0800 Subject: [PATCH] Move Literal alias into TYPE_CHECKING block (#45345) --- airflow/api_fastapi/app.py | 5 ++++- airflow/api_fastapi/core_api/security.py | 6 +++-- airflow/auth/managers/base_auth_manager.py | 12 +++++----- .../managers/simple/simple_auth_manager.py | 3 ++- .../commands/remote_commands/task_command.py | 9 ++++---- airflow/models/mappedoperator.py | 4 ++-- airflow/providers_manager.py | 4 ++-- .../aws/auth_manager/aws_auth_manager.py | 7 +++--- .../providers/amazon/aws/sensors/sqs.py | 3 ++- .../providers/amazon/aws/triggers/sqs.py | 3 ++- .../airflow/providers/amazon/aws/utils/sqs.py | 10 +++++---- .../providers/docker/decorators/docker.py | 4 +++- .../fab/auth_manager/fab_auth_manager.py | 3 ++- .../providers/standard/operators/python.py | 6 +++-- .../providers/weaviate/hooks/weaviate.py | 2 +- .../core_api/routes/public/test_job.py | 22 +++++++++++-------- tests/auth/managers/test_base_auth_manager.py | 3 ++- tests_common/_internals/capture_warnings.py | 9 +++++--- 18 files changed, 71 insertions(+), 44 deletions(-) diff --git a/airflow/api_fastapi/app.py b/airflow/api_fastapi/app.py index 3c9aa39ae707e..9cbb190d41120 100644 --- a/airflow/api_fastapi/app.py +++ b/airflow/api_fastapi/app.py @@ -18,6 +18,7 @@ import logging from contextlib import AsyncExitStack, asynccontextmanager +from typing import TYPE_CHECKING from fastapi import FastAPI from starlette.routing import Mount @@ -31,10 +32,12 @@ init_views, ) from airflow.api_fastapi.execution_api.app import create_task_execution_api_app -from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.configuration import conf from airflow.exceptions import AirflowConfigException +if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import BaseAuthManager + log = logging.getLogger(__name__) app: FastAPI | None = None diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index 30470e9b5da55..7aaee4d0fae7f 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -17,19 +17,21 @@ from __future__ import annotations from functools import cache -from typing import Annotated, Any, Callable +from typing import TYPE_CHECKING, Annotated, Any, Callable from fastapi import Depends, HTTPException from fastapi.security import OAuth2PasswordBearer from jwt import InvalidTokenError from airflow.api_fastapi.app import get_auth_manager -from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails from airflow.configuration import conf from airflow.utils.jwt_signer import JWTSigner +if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import ResourceMethod + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 345d22d536395..6a9ef11e3d785 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -18,21 +18,21 @@ from __future__ import annotations from abc import abstractmethod -from collections.abc import Container, Sequence -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar +from typing import TYPE_CHECKING, Any, Generic, TypeVar from sqlalchemy import select from airflow.auth.managers.models.base_user import BaseUser -from airflow.auth.managers.models.resource_details import ( - DagDetails, -) +from airflow.auth.managers.models.resource_details import DagDetails from airflow.exceptions import AirflowException from airflow.models import DagModel +from airflow.typing_compat import Literal from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: + from collections.abc import Container, Sequence + from fastapi import FastAPI from flask import Blueprint from flask_appbuilder.menu import MenuItem @@ -55,6 +55,8 @@ ) from airflow.cli.cli_config import CLICommand +# This cannot be in the TYPE_CHECKING block since some providers import it globally. +# TODO: Move this inside once all providers drop Airflow 2.x support. ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"] T = TypeVar("T", bound=BaseUser) diff --git a/airflow/auth/managers/simple/simple_auth_manager.py b/airflow/auth/managers/simple/simple_auth_manager.py index 5c411a5202d97..6ac5f34258749 100644 --- a/airflow/auth/managers/simple/simple_auth_manager.py +++ b/airflow/auth/managers/simple/simple_auth_manager.py @@ -27,7 +27,7 @@ from flask import session, url_for from termcolor import colored -from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.simple.user import SimpleAuthManagerUser from airflow.auth.managers.simple.views.auth import SimpleAuthManagerAuthenticationViews from airflow.configuration import AIRFLOW_HOME, conf @@ -35,6 +35,7 @@ if TYPE_CHECKING: from flask_appbuilder.menu import MenuItem + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.resource_details import ( AccessView, AssetDetails, diff --git a/airflow/cli/commands/remote_commands/task_command.py b/airflow/cli/commands/remote_commands/task_command.py index ad3c26ff56eda..6c591801515d5 100644 --- a/airflow/cli/commands/remote_commands/task_command.py +++ b/airflow/cli/commands/remote_commands/task_command.py @@ -28,7 +28,7 @@ import textwrap from collections.abc import Generator from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress -from typing import TYPE_CHECKING, Protocol, Union, cast +from typing import TYPE_CHECKING, Protocol, cast import pendulum from pendulum.parsing.exceptions import ParserError @@ -50,7 +50,6 @@ from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS -from airflow.typing_compat import Literal from airflow.utils import cli as cli_utils, timezone from airflow.utils.cli import ( get_dag, @@ -70,13 +69,15 @@ from airflow.utils.types import DagRunTriggeredByType if TYPE_CHECKING: + from typing import Literal + from sqlalchemy.orm.session import Session from airflow.models.operator import Operator -log = logging.getLogger(__name__) + CreateIfNecessary = Literal[False, "db", "memory"] -CreateIfNecessary = Union[Literal[False], Literal["db"], Literal["memory"]] +log = logging.getLogger(__name__) def _generate_temporary_run_id() -> str: diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 4d362714794e3..f4728095037d4 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -53,7 +53,6 @@ from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded from airflow.triggers.base import StartTriggerArgs -from airflow.typing_compat import Literal from airflow.utils.context import context_update_for_unmapped from airflow.utils.helpers import is_container, prevent_duplicates from airflow.utils.task_instance_session import get_current_task_instance_session @@ -62,6 +61,7 @@ if TYPE_CHECKING: import datetime + from typing import Literal import jinja2 # Slow import. import pendulum @@ -89,7 +89,7 @@ TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]] -ValidationSource = Union[Literal["expand"], Literal["partial"]] + ValidationSource = Literal["expand", "partial"] def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None: diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py index 8d9f93734d7eb..575306a840b79 100644 --- a/airflow/providers_manager.py +++ b/airflow/providers_manager.py @@ -86,12 +86,12 @@ def ensure_prefix(field): if TYPE_CHECKING: + from typing import Literal from urllib.parse import SplitResult from airflow.decorators.base import TaskDecorator from airflow.hooks.base import BaseHook from airflow.sdk.definitions.asset import Asset - from airflow.typing_compat import Literal class LazyDictWithCache(MutableMapping): @@ -201,7 +201,7 @@ class ProviderInfo: version: str data: dict - package_or_source: Literal["source"] | Literal["package"] + package_or_source: Literal["source", "package"] def __post_init__(self): if self.package_or_source not in ("source", "package"): diff --git a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py index ab82c6042ce3b..88f8cef8b76c4 100644 --- a/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py +++ b/providers/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py @@ -25,7 +25,7 @@ from flask import session, url_for -from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.models.resource_details import ( AccessView, ConnectionDetails, @@ -53,6 +53,7 @@ if TYPE_CHECKING: from flask_appbuilder.menu import MenuItem + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.batch_apis import ( IsAuthorizedConnectionRequest, @@ -326,11 +327,11 @@ def filter_permitted_dag_ids( for method in ["GET", "PUT"]: if method in methods: request: IsAuthorizedRequest = { - "method": cast(ResourceMethod, method), + "method": cast("ResourceMethod", method), "entity_type": AvpEntities.DAG, "entity_id": dag_id, } - requests[dag_id][cast(ResourceMethod, method)] = request + requests[dag_id][cast("ResourceMethod", method)] = request requests_list.append(request) batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results( diff --git a/providers/src/airflow/providers/amazon/aws/sensors/sqs.py b/providers/src/airflow/providers/amazon/aws/sensors/sqs.py index 006c5bf2ad20b..016e4fbd6c6d6 100644 --- a/providers/src/airflow/providers/amazon/aws/sensors/sqs.py +++ b/providers/src/airflow/providers/amazon/aws/sensors/sqs.py @@ -30,10 +30,11 @@ from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger from airflow.providers.amazon.aws.utils import validate_execute_complete_event from airflow.providers.amazon.aws.utils.mixins import aws_template_fields -from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response +from airflow.providers.amazon.aws.utils.sqs import process_response if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection + from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType from airflow.utils.context import Context diff --git a/providers/src/airflow/providers/amazon/aws/triggers/sqs.py b/providers/src/airflow/providers/amazon/aws/triggers/sqs.py index 31f344b998209..28c0b509d28c0 100644 --- a/providers/src/airflow/providers/amazon/aws/triggers/sqs.py +++ b/providers/src/airflow/providers/amazon/aws/triggers/sqs.py @@ -22,11 +22,12 @@ from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.sqs import SqsHook -from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response +from airflow.providers.amazon.aws.utils.sqs import process_response from airflow.triggers.base import BaseTrigger, TriggerEvent if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection + from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType class SqsSensorTrigger(BaseTrigger): diff --git a/providers/src/airflow/providers/amazon/aws/utils/sqs.py b/providers/src/airflow/providers/amazon/aws/utils/sqs.py index 293aa1b898d27..3c509454655a7 100644 --- a/providers/src/airflow/providers/amazon/aws/utils/sqs.py +++ b/providers/src/airflow/providers/amazon/aws/utils/sqs.py @@ -14,20 +14,22 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# from __future__ import annotations import json import logging -from typing import Any +from typing import TYPE_CHECKING, Any import jsonpath_ng import jsonpath_ng.ext -from typing_extensions import Literal -log = logging.getLogger(__name__) +if TYPE_CHECKING: + from typing import Literal + MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"] -MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"] +log = logging.getLogger(__name__) def process_response( diff --git a/providers/src/airflow/providers/docker/decorators/docker.py b/providers/src/airflow/providers/docker/decorators/docker.py index 560028c16ed6c..77355ff03b2ca 100644 --- a/providers/src/airflow/providers/docker/decorators/docker.py +++ b/providers/src/airflow/providers/docker/decorators/docker.py @@ -20,7 +20,7 @@ import os from collections.abc import Sequence from tempfile import TemporaryDirectory -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable from airflow.decorators.base import DecoratedOperator, task_decorator_factory from airflow.exceptions import AirflowException @@ -28,6 +28,8 @@ from airflow.providers.docker.operators.docker import DockerOperator if TYPE_CHECKING: + from typing import Literal + from airflow.decorators.base import TaskDecorator from airflow.utils.context import Context diff --git a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py index 6c8bf51b7acf8..4c889a9c14e3f 100644 --- a/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py +++ b/providers/src/airflow/providers/fab/auth_manager/fab_auth_manager.py @@ -34,7 +34,7 @@ from starlette.middleware.wsgi import WSGIMiddleware from airflow import __version__ as airflow_version -from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.models.resource_details import ( AccessView, ConfigurationDetails, @@ -94,6 +94,7 @@ from airflow.version import version if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.base_user import BaseUser from airflow.cli.cli_config import ( CLICommand, diff --git a/providers/src/airflow/providers/standard/operators/python.py b/providers/src/airflow/providers/standard/operators/python.py index 40a0cb7a9223e..25de405a80c9b 100644 --- a/providers/src/airflow/providers/standard/operators/python.py +++ b/providers/src/airflow/providers/standard/operators/python.py @@ -51,7 +51,6 @@ AIRFLOW_V_2_10_PLUS, AIRFLOW_V_3_0_PLUS, ) -from airflow.typing_compat import Literal from airflow.utils import hashlib_wrapper from airflow.utils.context import context_copy_partial, context_merge from airflow.utils.file import get_unique_dag_module_name @@ -61,10 +60,14 @@ log = logging.getLogger(__name__) if TYPE_CHECKING: + from typing import Literal + from pendulum.datetime import DateTime from airflow.utils.context import Context + _SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"] + @cache def _parse_version_info(text: str) -> tuple[int, int, int, str, int]: @@ -343,7 +346,6 @@ def _load_cloudpickle(): return cloudpickle -_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"] _SERIALIZERS: dict[_SerializerTypeDef, Any] = { "pickle": lazy_object_proxy.Proxy(_load_pickle), "dill": lazy_object_proxy.Proxy(_load_dill), diff --git a/providers/src/airflow/providers/weaviate/hooks/weaviate.py b/providers/src/airflow/providers/weaviate/hooks/weaviate.py index 716bc3e10e6be..e49cc58f02420 100644 --- a/providers/src/airflow/providers/weaviate/hooks/weaviate.py +++ b/providers/src/airflow/providers/weaviate/hooks/weaviate.py @@ -749,7 +749,7 @@ def create_or_replace_document_objects( verbose: bool = False, ) -> Sequence[dict[str, UUID | str] | None]: """ - create or replace objects belonging to documents. + Create or replace objects belonging to documents. In real-world scenarios, information sources like Airflow docs, Stack Overflow, or other issues are considered 'documents' here. It's crucial to keep the database objects in sync with these sources. diff --git a/tests/api_fastapi/core_api/routes/public/test_job.py b/tests/api_fastapi/core_api/routes/public/test_job.py index f09c2b902c1c8..780d51f446785 100644 --- a/tests/api_fastapi/core_api/routes/public/test_job.py +++ b/tests/api_fastapi/core_api/routes/public/test_job.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Literal +from typing import TYPE_CHECKING import pytest @@ -28,15 +28,19 @@ from tests_common.test_utils.db import clear_db_jobs from tests_common.test_utils.format_datetime import from_datetime_to_zulu +if TYPE_CHECKING: + from typing import Literal + + TestCase = Literal[ + "should_report_success_for_one_working_scheduler", + "should_report_success_for_one_working_scheduler_with_hostname", + "should_report_success_for_ha_schedulers", + "should_ignore_not_running_jobs", + "should_raise_exception_for_multiple_scheduler_on_one_host", + ] + pytestmark = pytest.mark.db_test -TESTCASE_TYPE = Literal[ - "should_report_success_for_one_working_scheduler", - "should_report_success_for_one_working_scheduler_with_hostname", - "should_report_success_for_ha_schedulers", - "should_ignore_not_running_jobs", - "should_raise_exception_for_multiple_scheduler_on_one_host", -] TESTCASE_ONE_SCHEDULER = "should_report_success_for_one_working_scheduler" TESTCASE_ONE_SCHEDULER_WITH_HOSTNAME = "should_report_success_for_one_working_scheduler_with_hostname" TESTCASE_HA_SCHEDULERS = "should_report_success_for_ha_schedulers" @@ -107,7 +111,7 @@ def _setup_should_raise_exception_for_multiple_scheduler_on_one_host(self, sessi scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback) @provide_session - def setup(self, testcase: TESTCASE_TYPE, session=None) -> None: + def setup(self, testcase: TestCase, session=None) -> None: """ Setup testcase at runtime based on the `testcase` provided by `pytest.mark.parametrize`. """ diff --git a/tests/auth/managers/test_base_auth_manager.py b/tests/auth/managers/test_base_auth_manager.py index a6480e809a8e1..4406ae9d43607 100644 --- a/tests/auth/managers/test_base_auth_manager.py +++ b/tests/auth/managers/test_base_auth_manager.py @@ -21,7 +21,7 @@ import pytest -from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod +from airflow.auth.managers.base_auth_manager import BaseAuthManager from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( ConnectionDetails, @@ -34,6 +34,7 @@ if TYPE_CHECKING: from flask_appbuilder.menu import MenuItem + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.models.resource_details import ( AccessView, AssetDetails, diff --git a/tests_common/_internals/capture_warnings.py b/tests_common/_internals/capture_warnings.py index d0cb2a0361581..a8e719c24df69 100644 --- a/tests_common/_internals/capture_warnings.py +++ b/tests_common/_internals/capture_warnings.py @@ -28,12 +28,15 @@ from contextlib import contextmanager from dataclasses import asdict, dataclass from pathlib import Path -from typing import Callable +from typing import TYPE_CHECKING, Callable import pytest -from typing_extensions import Literal -WhenTypeDef = Literal["config", "collect", "runtest"] +if TYPE_CHECKING: + from typing import Literal + + WhenTypeDef = Literal["config", "collect", "runtest"] + TESTS_DIR = Path(__file__).parents[1].resolve()