Skip to content

Commit

Permalink
Add other instances of Context type hints (apache#45657)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil authored and HariGS-DB committed Jan 16, 2025
1 parent b979b24 commit f0f78ef
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 25 deletions.
8 changes: 5 additions & 3 deletions airflow/models/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import copy
import json
import logging
from collections.abc import ItemsView, Iterable, Mapping, MutableMapping, ValuesView
from collections.abc import ItemsView, Iterable, MutableMapping, ValuesView
from typing import TYPE_CHECKING, Any, ClassVar

from airflow.exceptions import AirflowException, ParamValidationError
Expand All @@ -29,6 +29,7 @@

if TYPE_CHECKING:
from airflow.models.operator import Operator
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -293,10 +294,11 @@ def __init__(self, current_dag: DAG, name: str, default: Any = NOTSET):
def iter_references(self) -> Iterable[tuple[Operator, str]]:
return ()

def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any:
def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
"""Pull DagParam value from DagRun context. This method is run during ``op.execute()``."""
with contextlib.suppress(KeyError):
return context["dag_run"].conf[self._name]
if context["dag_run"].conf:
return context["dag_run"].conf[self._name]
if self._default is not NOTSET:
return self._default
with contextlib.suppress(KeyError):
Expand Down
11 changes: 8 additions & 3 deletions providers/src/airflow/providers/cncf/kubernetes/operators/pod.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import re
import shlex
import string
from collections.abc import Container, Iterable, Mapping, Sequence
from collections.abc import Container, Iterable, Sequence
from contextlib import AbstractContextManager
from enum import Enum
from functools import cached_property
Expand Down Expand Up @@ -90,7 +90,12 @@
from pendulum import DateTime

from airflow.providers.cncf.kubernetes.secret import Secret
from airflow.utils.context import Context

try:
from airflow.sdk.definitions.context import Context
except ImportError:
# TODO: Remove once provider drops support for Airflow 2
from airflow.utils.context import Context

alphanum_lower = string.ascii_lowercase + string.digits

Expand Down Expand Up @@ -423,7 +428,7 @@ def _incluster_namespace(self):
def _render_nested_template_fields(
self,
content: Any,
context: Mapping[str, Any],
context: Context,
jinja_env: jinja2.Environment,
seen_oids: set,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# under the License.
from __future__ import annotations

from collections.abc import Mapping
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand All @@ -37,7 +36,11 @@
if TYPE_CHECKING:
import jinja2

from airflow.utils.context import Context
try:
from airflow.sdk.definitions.context import Context
except ImportError:
# TODO: Remove once provider drops support for Airflow 2
from airflow.utils.context import Context


class SparkKubernetesOperator(KubernetesPodOperator):
Expand Down Expand Up @@ -129,7 +132,7 @@ def __init__(
def _render_nested_template_fields(
self,
content: Any,
context: Mapping[str, Any],
context: Context,
jinja_env: jinja2.Environment,
seen_oids: set,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from collections.abc import (
Collection,
Iterable,
Mapping,
)
from typing import (
TYPE_CHECKING,
Expand All @@ -43,6 +42,7 @@
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.operator import Operator
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG

DEFAULT_OWNER: str = "airflow"
Expand Down Expand Up @@ -291,7 +291,7 @@ def _do_render_template_fields(
self,
parent: Any,
template_fields: Iterable[str],
context: Mapping[str, Any],
context: Context,
jinja_env: jinja2.Environment,
seen_oids: set[int],
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions task_sdk/src/airflow/sdk/definitions/_internal/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
from __future__ import annotations

from abc import abstractmethod
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from airflow.models.operator import Operator
from airflow.sdk.definitions.baseoperator import BaseOperator
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.edges import EdgeModifier

# TODO: Should this all just live on DAGNode?
Expand Down Expand Up @@ -132,7 +133,7 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
"""
raise NotImplementedError

def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any:
def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
"""
Resolve this value for runtime.
Expand Down
13 changes: 6 additions & 7 deletions task_sdk/src/airflow/sdk/definitions/_internal/templater.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@
from airflow.utils.helpers import render_template_as_native, render_template_to_string

if TYPE_CHECKING:
from collections.abc import Mapping

from airflow.models.operator import Operator
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG


Expand All @@ -51,7 +50,7 @@ class LiteralValue(ResolveMixin):
def iter_references(self) -> Iterable[tuple[Operator, str]]:
return ()

def resolve(self, context: Mapping[str, Any], *, include_xcom: bool = True) -> Any:
def resolve(self, context: Context, *, include_xcom: bool = True) -> Any:
return self.value


Expand Down Expand Up @@ -113,7 +112,7 @@ def _do_render_template_fields(
self,
parent: Any,
template_fields: Iterable[str],
context: Mapping[str, Any],
context: Context,
jinja_env: jinja2.Environment,
seen_oids: set[int],
) -> None:
Expand All @@ -136,7 +135,7 @@ def _render(self, template, context, dag=None) -> Any:
def render_template(
self,
content: Any,
context: Mapping[str, Any],
context: Context,
jinja_env: jinja2.Environment | None = None,
seen_oids: set[int] | None = None,
) -> Any:
Expand Down Expand Up @@ -199,7 +198,7 @@ def render_template(
return value

def _render_object_storage_path(
self, value: ObjectStoragePath, context: Mapping[str, Any], jinja_env: jinja2.Environment
self, value: ObjectStoragePath, context: Context, jinja_env: jinja2.Environment
) -> ObjectStoragePath:
serialized_path = value.serialize()
path_version = value.__version__
Expand All @@ -209,7 +208,7 @@ def _render_object_storage_path(
def _render_nested_template_fields(
self,
value: Any,
context: Mapping[str, Any],
context: Context,
jinja_env: jinja2.Environment,
seen_oids: set[int],
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions task_sdk/src/airflow/sdk/definitions/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import inspect
import sys
import warnings
from collections.abc import Collection, Iterable, Mapping, Sequence
from collections.abc import Collection, Iterable, Sequence
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from functools import total_ordering, wraps
Expand Down Expand Up @@ -68,6 +68,7 @@
import jinja2

from airflow.models.xcom_arg import XComArg
from airflow.sdk.definitions.context import Context
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.serialization.enums import DagAttributeTypes
Expand Down Expand Up @@ -1244,7 +1245,7 @@ def inherits_from_empty_operator(self):

def render_template_fields(
self,
context: Mapping[str, Any], # TODO: Change to `Context` once we have it
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def run(ti: RuntimeTaskInstance, log: Logger):
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def _execute_task(context: Mapping[str, Any], task: BaseOperator):
def _execute_task(context: Context, task: BaseOperator):
"""Execute Task (optionally with a Timeout) and push Xcom results."""
from airflow.exceptions import AirflowTaskTimeout

Expand Down
4 changes: 2 additions & 2 deletions tests_common/test_utils/mock_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from collections.abc import Mapping, Sequence
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any

import attr
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(self, arg1: str = "", arg2: NestedFields | None = None, **kwargs):
def _render_nested_template_fields(
self,
content: Any,
context: Mapping[str, Any],
context: Context,
jinja_env: jinja2.Environment,
seen_oids: set,
) -> None:
Expand Down

0 comments on commit f0f78ef

Please sign in to comment.