Skip to content

Commit

Permalink
Move Literal alias into TYPE_CHECKING block (#45345)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Jan 9, 2025
1 parent c2feef6 commit b703d53
Show file tree
Hide file tree
Showing 18 changed files with 71 additions and 44 deletions.
5 changes: 4 additions & 1 deletion airflow/api_fastapi/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import logging
from contextlib import AsyncExitStack, asynccontextmanager
from typing import TYPE_CHECKING

from fastapi import FastAPI
from starlette.routing import Mount
Expand All @@ -31,10 +32,12 @@
init_views,
)
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.configuration import conf
from airflow.exceptions import AirflowConfigException

if TYPE_CHECKING:
from airflow.auth.managers.base_auth_manager import BaseAuthManager

log = logging.getLogger(__name__)

app: FastAPI | None = None
Expand Down
6 changes: 4 additions & 2 deletions airflow/api_fastapi/core_api/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@
from __future__ import annotations

from functools import cache
from typing import Annotated, Any, Callable
from typing import TYPE_CHECKING, Annotated, Any, Callable

from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer
from jwt import InvalidTokenError

from airflow.api_fastapi.app import get_auth_manager
from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails
from airflow.configuration import conf
from airflow.utils.jwt_signer import JWTSigner

if TYPE_CHECKING:
from airflow.auth.managers.base_auth_manager import ResourceMethod

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


Expand Down
12 changes: 7 additions & 5 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@
from __future__ import annotations

from abc import abstractmethod
from collections.abc import Container, Sequence
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from sqlalchemy import select

from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
DagDetails,
)
from airflow.auth.managers.models.resource_details import DagDetails
from airflow.exceptions import AirflowException
from airflow.models import DagModel
from airflow.typing_compat import Literal
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from collections.abc import Container, Sequence

from fastapi import FastAPI
from flask import Blueprint
from flask_appbuilder.menu import MenuItem
Expand All @@ -55,6 +55,8 @@
)
from airflow.cli.cli_config import CLICommand

# This cannot be in the TYPE_CHECKING block since some providers import it globally.
# TODO: Move this inside once all providers drop Airflow 2.x support.
ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]

T = TypeVar("T", bound=BaseUser)
Expand Down
3 changes: 2 additions & 1 deletion airflow/auth/managers/simple/simple_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@
from flask import session, url_for
from termcolor import colored

from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.auth.managers.simple.user import SimpleAuthManagerUser
from airflow.auth.managers.simple.views.auth import SimpleAuthManagerAuthenticationViews
from airflow.configuration import AIRFLOW_HOME, conf

if TYPE_CHECKING:
from flask_appbuilder.menu import MenuItem

from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.resource_details import (
AccessView,
AssetDetails,
Expand Down
9 changes: 5 additions & 4 deletions airflow/cli/commands/remote_commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import textwrap
from collections.abc import Generator
from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
from typing import TYPE_CHECKING, Protocol, Union, cast
from typing import TYPE_CHECKING, Protocol, cast

import pendulum
from pendulum.parsing.exceptions import ParserError
Expand All @@ -50,7 +50,6 @@
from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
from airflow.typing_compat import Literal
from airflow.utils import cli as cli_utils, timezone
from airflow.utils.cli import (
get_dag,
Expand All @@ -70,13 +69,15 @@
from airflow.utils.types import DagRunTriggeredByType

if TYPE_CHECKING:
from typing import Literal

from sqlalchemy.orm.session import Session

from airflow.models.operator import Operator

log = logging.getLogger(__name__)
CreateIfNecessary = Literal[False, "db", "memory"]

CreateIfNecessary = Union[Literal[False], Literal["db"], Literal["memory"]]
log = logging.getLogger(__name__)


def _generate_temporary_run_id() -> str:
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.triggers.base import StartTriggerArgs
from airflow.typing_compat import Literal
from airflow.utils.context import context_update_for_unmapped
from airflow.utils.helpers import is_container, prevent_duplicates
from airflow.utils.task_instance_session import get_current_task_instance_session
Expand All @@ -62,6 +61,7 @@

if TYPE_CHECKING:
import datetime
from typing import Literal

import jinja2 # Slow import.
import pendulum
Expand Down Expand Up @@ -89,7 +89,7 @@

TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]]

ValidationSource = Union[Literal["expand"], Literal["partial"]]
ValidationSource = Literal["expand", "partial"]


def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ def ensure_prefix(field):


if TYPE_CHECKING:
from typing import Literal
from urllib.parse import SplitResult

from airflow.decorators.base import TaskDecorator
from airflow.hooks.base import BaseHook
from airflow.sdk.definitions.asset import Asset
from airflow.typing_compat import Literal


class LazyDictWithCache(MutableMapping):
Expand Down Expand Up @@ -201,7 +201,7 @@ class ProviderInfo:

version: str
data: dict
package_or_source: Literal["source"] | Literal["package"]
package_or_source: Literal["source", "package"]

def __post_init__(self):
if self.package_or_source not in ("source", "package"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from flask import session, url_for

from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.auth.managers.models.resource_details import (
AccessView,
ConnectionDetails,
Expand Down Expand Up @@ -53,6 +53,7 @@
if TYPE_CHECKING:
from flask_appbuilder.menu import MenuItem

from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.batch_apis import (
IsAuthorizedConnectionRequest,
Expand Down Expand Up @@ -326,11 +327,11 @@ def filter_permitted_dag_ids(
for method in ["GET", "PUT"]:
if method in methods:
request: IsAuthorizedRequest = {
"method": cast(ResourceMethod, method),
"method": cast("ResourceMethod", method),
"entity_type": AvpEntities.DAG,
"entity_id": dag_id,
}
requests[dag_id][cast(ResourceMethod, method)] = request
requests[dag_id][cast("ResourceMethod", method)] = request
requests_list.append(request)

batch_is_authorized_results = self.avp_facade.get_batch_is_authorized_results(
Expand Down
3 changes: 2 additions & 1 deletion providers/src/airflow/providers/amazon/aws/sensors/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
from airflow.providers.amazon.aws.triggers.sqs import SqsSensorTrigger
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response
from airflow.providers.amazon.aws.utils.sqs import process_response

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType
from airflow.utils.context import Context


Expand Down
3 changes: 2 additions & 1 deletion providers/src/airflow/providers/amazon/aws/triggers/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.sqs import SqsHook
from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType, process_response
from airflow.providers.amazon.aws.utils.sqs import process_response
from airflow.triggers.base import BaseTrigger, TriggerEvent

if TYPE_CHECKING:
from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection
from airflow.providers.amazon.aws.utils.sqs import MessageFilteringType


class SqsSensorTrigger(BaseTrigger):
Expand Down
10 changes: 6 additions & 4 deletions providers/src/airflow/providers/amazon/aws/utils/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import annotations

import json
import logging
from typing import Any
from typing import TYPE_CHECKING, Any

import jsonpath_ng
import jsonpath_ng.ext
from typing_extensions import Literal

log = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing import Literal

MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"]

MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"]
log = logging.getLogger(__name__)


def process_response(
Expand Down
4 changes: 3 additions & 1 deletion providers/src/airflow/providers/docker/decorators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
import os
from collections.abc import Sequence
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable

from airflow.decorators.base import DecoratedOperator, task_decorator_factory
from airflow.exceptions import AirflowException
from airflow.providers.common.compat.standard.utils import write_python_script
from airflow.providers.docker.operators.docker import DockerOperator

if TYPE_CHECKING:
from typing import Literal

from airflow.decorators.base import TaskDecorator
from airflow.utils.context import Context

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from starlette.middleware.wsgi import WSGIMiddleware

from airflow import __version__ as airflow_version
from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod
from airflow.auth.managers.base_auth_manager import BaseAuthManager
from airflow.auth.managers.models.resource_details import (
AccessView,
ConfigurationDetails,
Expand Down Expand Up @@ -94,6 +94,7 @@
from airflow.version import version

if TYPE_CHECKING:
from airflow.auth.managers.base_auth_manager import ResourceMethod
from airflow.auth.managers.models.base_user import BaseUser
from airflow.cli.cli_config import (
CLICommand,
Expand Down
6 changes: 4 additions & 2 deletions providers/src/airflow/providers/standard/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
AIRFLOW_V_2_10_PLUS,
AIRFLOW_V_3_0_PLUS,
)
from airflow.typing_compat import Literal
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
Expand All @@ -61,10 +60,14 @@
log = logging.getLogger(__name__)

if TYPE_CHECKING:
from typing import Literal

from pendulum.datetime import DateTime

from airflow.utils.context import Context

_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]


@cache
def _parse_version_info(text: str) -> tuple[int, int, int, str, int]:
Expand Down Expand Up @@ -343,7 +346,6 @@ def _load_cloudpickle():
return cloudpickle


_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]
_SERIALIZERS: dict[_SerializerTypeDef, Any] = {
"pickle": lazy_object_proxy.Proxy(_load_pickle),
"dill": lazy_object_proxy.Proxy(_load_dill),
Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/weaviate/hooks/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def create_or_replace_document_objects(
verbose: bool = False,
) -> Sequence[dict[str, UUID | str] | None]:
"""
create or replace objects belonging to documents.
Create or replace objects belonging to documents.
In real-world scenarios, information sources like Airflow docs, Stack Overflow, or other issues
are considered 'documents' here. It's crucial to keep the database objects in sync with these sources.
Expand Down
22 changes: 13 additions & 9 deletions tests/api_fastapi/core_api/routes/public/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING

import pytest

Expand All @@ -28,15 +28,19 @@
from tests_common.test_utils.db import clear_db_jobs
from tests_common.test_utils.format_datetime import from_datetime_to_zulu

if TYPE_CHECKING:
from typing import Literal

TestCase = Literal[
"should_report_success_for_one_working_scheduler",
"should_report_success_for_one_working_scheduler_with_hostname",
"should_report_success_for_ha_schedulers",
"should_ignore_not_running_jobs",
"should_raise_exception_for_multiple_scheduler_on_one_host",
]

pytestmark = pytest.mark.db_test

TESTCASE_TYPE = Literal[
"should_report_success_for_one_working_scheduler",
"should_report_success_for_one_working_scheduler_with_hostname",
"should_report_success_for_ha_schedulers",
"should_ignore_not_running_jobs",
"should_raise_exception_for_multiple_scheduler_on_one_host",
]
TESTCASE_ONE_SCHEDULER = "should_report_success_for_one_working_scheduler"
TESTCASE_ONE_SCHEDULER_WITH_HOSTNAME = "should_report_success_for_one_working_scheduler_with_hostname"
TESTCASE_HA_SCHEDULERS = "should_report_success_for_ha_schedulers"
Expand Down Expand Up @@ -107,7 +111,7 @@ def _setup_should_raise_exception_for_multiple_scheduler_on_one_host(self, sessi
scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback)

@provide_session
def setup(self, testcase: TESTCASE_TYPE, session=None) -> None:
def setup(self, testcase: TestCase, session=None) -> None:
"""
Setup testcase at runtime based on the `testcase` provided by `pytest.mark.parametrize`.
"""
Expand Down
Loading

0 comments on commit b703d53

Please sign in to comment.