Skip to content

Commit

Permalink
Move Literal alias into TYPE_CHECKING block
Browse files Browse the repository at this point in the history
This allows us to always import Literal from typing, which has received
quite some bug fixes between different versions. Doing this in the
typing block avoids loading the runtime version at all, thus eliminating
differences between runtime Python versions.
  • Loading branch information
uranusjr committed Jan 2, 2025
1 parent f7da5e4 commit a7ae2e7
Show file tree
Hide file tree
Showing 10 changed files with 47 additions and 33 deletions.
10 changes: 5 additions & 5 deletions airflow/auth/managers/base_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,22 @@
from abc import abstractmethod
from collections.abc import Container, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from flask_appbuilder.menu import MenuItem
from sqlalchemy import select

from airflow.auth.managers.models.base_user import BaseUser
from airflow.auth.managers.models.resource_details import (
DagDetails,
)
from airflow.auth.managers.models.resource_details import DagDetails
from airflow.exceptions import AirflowException
from airflow.models import DagModel
from airflow.security.permissions import ACTION_CAN_ACCESS_MENU
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from typing import Literal

from fastapi import FastAPI
from flask import Blueprint
from sqlalchemy.orm import Session
Expand All @@ -58,7 +58,7 @@
from airflow.cli.cli_config import CLICommand
from airflow.www.security_manager import AirflowSecurityManagerV2

ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]
ResourceMethod = Literal["GET", "POST", "PUT", "DELETE", "MENU"]

T = TypeVar("T", bound=BaseUser)

Expand Down
9 changes: 5 additions & 4 deletions airflow/cli/commands/remote_commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import textwrap
from collections.abc import Generator
from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
from typing import TYPE_CHECKING, Protocol, Union, cast
from typing import TYPE_CHECKING, Protocol, cast

import pendulum
from pendulum.parsing.exceptions import ParserError
Expand All @@ -50,7 +50,6 @@
from airflow.settings import IS_EXECUTOR_CONTAINER, IS_K8S_EXECUTOR_POD
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_deps import SCHEDULER_QUEUED_DEPS
from airflow.typing_compat import Literal
from airflow.utils import cli as cli_utils, timezone
from airflow.utils.cli import (
get_dag,
Expand All @@ -70,13 +69,15 @@
from airflow.utils.types import DagRunTriggeredByType

if TYPE_CHECKING:
from typing import Literal

from sqlalchemy.orm.session import Session

from airflow.models.operator import Operator

log = logging.getLogger(__name__)
CreateIfNecessary = Literal[False, "db", "memory"]

CreateIfNecessary = Union[Literal[False], Literal["db"], Literal["memory"]]
log = logging.getLogger(__name__)


def _generate_temporary_run_id() -> str:
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.triggers.base import StartTriggerArgs
from airflow.typing_compat import Literal
from airflow.utils.context import context_update_for_unmapped
from airflow.utils.helpers import is_container, prevent_duplicates
from airflow.utils.task_instance_session import get_current_task_instance_session
Expand All @@ -62,6 +61,7 @@

if TYPE_CHECKING:
import datetime
from typing import Literal

import jinja2 # Slow import.
import pendulum
Expand Down Expand Up @@ -89,7 +89,7 @@

TaskStateChangeCallbackAttrType = Union[None, TaskStateChangeCallback, list[TaskStateChangeCallback]]

ValidationSource = Union[Literal["expand"], Literal["partial"]]
ValidationSource = Literal["expand", "partial"]


def validate_mapping_kwargs(op: type[BaseOperator], func: ValidationSource, value: dict[str, Any]) -> None:
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ def ensure_prefix(field):


if TYPE_CHECKING:
from typing import Literal
from urllib.parse import SplitResult

from airflow.decorators.base import TaskDecorator
from airflow.hooks.base import BaseHook
from airflow.sdk.definitions.asset import Asset
from airflow.typing_compat import Literal


class LazyDictWithCache(MutableMapping):
Expand Down Expand Up @@ -201,7 +201,7 @@ class ProviderInfo:

version: str
data: dict
package_or_source: Literal["source"] | Literal["package"]
package_or_source: Literal["source", "package"]

def __post_init__(self):
if self.package_or_source not in ("source", "package"):
Expand Down
10 changes: 6 additions & 4 deletions providers/src/airflow/providers/amazon/aws/utils/sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import annotations

import json
import logging
from typing import Any
from typing import TYPE_CHECKING, Any

import jsonpath_ng
import jsonpath_ng.ext
from typing_extensions import Literal

log = logging.getLogger(__name__)
if TYPE_CHECKING:
from typing import Literal

MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"]

MessageFilteringType = Literal["literal", "jsonpath", "jsonpath-ext"]
log = logging.getLogger(__name__)


def process_response(
Expand Down
4 changes: 3 additions & 1 deletion providers/src/airflow/providers/docker/decorators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
import os
from collections.abc import Sequence
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable

from airflow.decorators.base import DecoratedOperator, task_decorator_factory
from airflow.exceptions import AirflowException
from airflow.providers.common.compat.standard.utils import write_python_script
from airflow.providers.docker.operators.docker import DockerOperator

if TYPE_CHECKING:
from typing import Literal

from airflow.decorators.base import TaskDecorator
from airflow.utils.context import Context

Expand Down
6 changes: 4 additions & 2 deletions providers/src/airflow/providers/standard/operators/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
AIRFLOW_V_2_10_PLUS,
AIRFLOW_V_3_0_PLUS,
)
from airflow.typing_compat import Literal
from airflow.utils import hashlib_wrapper
from airflow.utils.context import context_copy_partial, context_merge
from airflow.utils.file import get_unique_dag_module_name
Expand All @@ -61,10 +60,14 @@
log = logging.getLogger(__name__)

if TYPE_CHECKING:
from typing import Literal

from pendulum.datetime import DateTime

from airflow.utils.context import Context

_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]


@cache
def _parse_version_info(text: str) -> tuple[int, int, int, str, int]:
Expand Down Expand Up @@ -343,7 +346,6 @@ def _load_cloudpickle():
return cloudpickle


_SerializerTypeDef = Literal["pickle", "cloudpickle", "dill"]
_SERIALIZERS: dict[_SerializerTypeDef, Any] = {
"pickle": lazy_object_proxy.Proxy(_load_pickle),
"dill": lazy_object_proxy.Proxy(_load_dill),
Expand Down
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/weaviate/hooks/weaviate.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def create_or_replace_document_objects(
verbose: bool = False,
) -> Sequence[dict[str, UUID | str] | None]:
"""
create or replace objects belonging to documents.
Create or replace objects belonging to documents.
In real-world scenarios, information sources like Airflow docs, Stack Overflow, or other issues
are considered 'documents' here. It's crucial to keep the database objects in sync with these sources.
Expand Down
22 changes: 13 additions & 9 deletions tests/api_fastapi/core_api/routes/public/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import Literal
from typing import TYPE_CHECKING

import pytest

Expand All @@ -28,15 +28,19 @@
from tests_common.test_utils.db import clear_db_jobs
from tests_common.test_utils.format_datetime import from_datetime_to_zulu

if TYPE_CHECKING:
from typing import Literal

TestCase = Literal[
"should_report_success_for_one_working_scheduler",
"should_report_success_for_one_working_scheduler_with_hostname",
"should_report_success_for_ha_schedulers",
"should_ignore_not_running_jobs",
"should_raise_exception_for_multiple_scheduler_on_one_host",
]

pytestmark = pytest.mark.db_test

TESTCASE_TYPE = Literal[
"should_report_success_for_one_working_scheduler",
"should_report_success_for_one_working_scheduler_with_hostname",
"should_report_success_for_ha_schedulers",
"should_ignore_not_running_jobs",
"should_raise_exception_for_multiple_scheduler_on_one_host",
]
TESTCASE_ONE_SCHEDULER = "should_report_success_for_one_working_scheduler"
TESTCASE_ONE_SCHEDULER_WITH_HOSTNAME = "should_report_success_for_one_working_scheduler_with_hostname"
TESTCASE_HA_SCHEDULERS = "should_report_success_for_ha_schedulers"
Expand Down Expand Up @@ -107,7 +111,7 @@ def _setup_should_raise_exception_for_multiple_scheduler_on_one_host(self, sessi
scheduler_job.heartbeat(heartbeat_callback=job_runner.heartbeat_callback)

@provide_session
def setup(self, testcase: TESTCASE_TYPE, session=None) -> None:
def setup(self, testcase: TestCase, session=None) -> None:
"""
Setup testcase at runtime based on the `testcase` provided by `pytest.mark.parametrize`.
"""
Expand Down
9 changes: 6 additions & 3 deletions tests_common/_internals/capture_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@
from contextlib import contextmanager
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Callable
from typing import TYPE_CHECKING, Callable

import pytest
from typing_extensions import Literal

WhenTypeDef = Literal["config", "collect", "runtest"]
if TYPE_CHECKING:
from typing import Literal

WhenTypeDef = Literal["config", "collect", "runtest"]

TESTS_DIR = Path(__file__).parents[1].resolve()


Expand Down

0 comments on commit a7ae2e7

Please sign in to comment.