From bd77d7d9f07103ec57f21f991739c784b8a2cd22 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 29 Feb 2024 11:16:42 +0100 Subject: [PATCH 1/3] Enforce `__future__.annotations` (#2483) --- pyproject.toml | 11 +++-------- starlette/applications.py | 12 ++++++------ starlette/authentication.py | 10 ++-------- starlette/config.py | 6 +++--- starlette/convertors.py | 4 +++- starlette/datastructures.py | 12 ++++++------ starlette/formparsers.py | 2 +- starlette/middleware/authentication.py | 7 ++++--- starlette/middleware/base.py | 16 ++++++++-------- starlette/middleware/cors.py | 4 +++- starlette/middleware/errors.py | 6 +++--- starlette/middleware/exceptions.py | 13 ++++++++----- starlette/middleware/sessions.py | 8 +++++--- starlette/middleware/trustedhost.py | 4 +++- starlette/middleware/wsgi.py | 8 +++++--- starlette/requests.py | 10 +++++----- starlette/routing.py | 12 +++++------- starlette/templating.py | 2 +- starlette/websockets.py | 4 ++-- tests/conftest.py | 6 ++++-- tests/middleware/test_base.py | 15 +++++++-------- tests/test_authentication.py | 6 ++++-- tests/test_formparsers.py | 10 ++++++---- tests/test_requests.py | 16 +++++++++------- tests/test_responses.py | 6 ++++-- tests/test_routing.py | 6 ++++-- tests/test_templates.py | 4 +++- tests/test_websockets.py | 2 +- 28 files changed, 118 insertions(+), 104 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3cf7c0d407..679deaade1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,9 +9,7 @@ description = "The little ASGI library that shines." readme = "README.md" license = "BSD-3-Clause" requires-python = ">=3.8" -authors = [ - { name = "Tom Christie", email = "tom@tomchristie.com" }, -] +authors = [{ name = "Tom Christie", email = "tom@tomchristie.com" }] classifiers = [ "Development Status :: 3 - Alpha", "Environment :: Web Environment", @@ -52,7 +50,7 @@ Source = "https://github.com/encode/starlette" path = "starlette/__init__.py" [tool.ruff.lint] -select = ["E", "F", "I"] +select = ["E", "F", "I", "FA", "UP"] [tool.ruff.lint.isort] combine-as-imports = true @@ -83,10 +81,7 @@ filterwarnings = [ ] [tool.coverage.run] -source_pkgs = [ - "starlette", - "tests", -] +source_pkgs = ["starlette", "tests"] [tool.coverage.report] exclude_lines = [ diff --git a/starlette/applications.py b/starlette/applications.py index 1a4e3d264f..913fd4c9db 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -79,7 +79,7 @@ def __init__( {} if exception_handlers is None else dict(exception_handlers) ) self.user_middleware = [] if middleware is None else list(middleware) - self.middleware_stack: typing.Optional[ASGIApp] = None + self.middleware_stack: ASGIApp | None = None def build_middleware_stack(self) -> ASGIApp: debug = self.debug @@ -133,7 +133,7 @@ def host(self, host: str, app: ASGIApp, name: str | None = None) -> None: def add_middleware( self, - middleware_class: typing.Type[_MiddlewareClass[P]], + middleware_class: type[_MiddlewareClass[P]], *args: P.args, **kwargs: P.kwargs, ) -> None: @@ -143,7 +143,7 @@ def add_middleware( def add_exception_handler( self, - exc_class_or_status_code: int | typing.Type[Exception], + exc_class_or_status_code: int | type[Exception], handler: ExceptionHandler, ) -> None: # pragma: no cover self.exception_handlers[exc_class_or_status_code] = handler @@ -159,8 +159,8 @@ def add_route( self, path: str, route: typing.Callable[[Request], typing.Awaitable[Response] | Response], - methods: typing.Optional[typing.List[str]] = None, - name: typing.Optional[str] = None, + methods: list[str] | None = None, + name: str | None = None, include_in_schema: bool = True, ) -> None: # pragma: no cover self.router.add_route( @@ -176,7 +176,7 @@ def add_websocket_route( self.router.add_websocket_route(path, route, name=name) def exception_handler( - self, exc_class_or_status_code: int | typing.Type[Exception] + self, exc_class_or_status_code: int | type[Exception] ) -> typing.Callable: # type: ignore[type-arg] warnings.warn( "The `exception_handler` decorator is deprecated, and will be removed in version 1.0.0. " # noqa: E501 diff --git a/starlette/authentication.py b/starlette/authentication.py index e26a8a3881..f2586a0427 100644 --- a/starlette/authentication.py +++ b/starlette/authentication.py @@ -75,10 +75,7 @@ async def async_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: if not has_required_scope(request, scopes_list): if redirect is not None: orig_request_qparam = urlencode({"next": str(request.url)}) - next_url = "{redirect_path}?{orig_request}".format( - redirect_path=request.url_for(redirect), - orig_request=orig_request_qparam, - ) + next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" return RedirectResponse(url=next_url, status_code=303) raise HTTPException(status_code=status_code) return await func(*args, **kwargs) @@ -95,10 +92,7 @@ def sync_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> typing.Any: if not has_required_scope(request, scopes_list): if redirect is not None: orig_request_qparam = urlencode({"next": str(request.url)}) - next_url = "{redirect_path}?{orig_request}".format( - redirect_path=request.url_for(redirect), - orig_request=orig_request_qparam, - ) + next_url = f"{request.url_for(redirect)}?{orig_request_qparam}" return RedirectResponse(url=next_url, status_code=303) raise HTTPException(status_code=status_code) return func(*args, **kwargs) diff --git a/starlette/config.py b/starlette/config.py index d222a0a627..5b9813beac 100644 --- a/starlette/config.py +++ b/starlette/config.py @@ -17,7 +17,7 @@ class EnvironError(Exception): class Environ(typing.MutableMapping[str, str]): def __init__(self, environ: typing.MutableMapping[str, str] = os.environ): self._environ = environ - self._has_been_read: typing.Set[str] = set() + self._has_been_read: set[str] = set() def __getitem__(self, key: str) -> str: self._has_been_read.add(key) @@ -60,7 +60,7 @@ def __init__( ) -> None: self.environ = environ self.env_prefix = env_prefix - self.file_values: typing.Dict[str, str] = {} + self.file_values: dict[str, str] = {} if env_file is not None: if not os.path.isfile(env_file): warnings.warn(f"Config file '{env_file}' not found.") @@ -118,7 +118,7 @@ def get( raise KeyError(f"Config '{key}' is missing, and has no default.") def _read_file(self, file_name: str | Path) -> dict[str, str]: - file_values: typing.Dict[str, str] = {} + file_values: dict[str, str] = {} with open(file_name) as input_file: for line in input_file.readlines(): line = line.strip() diff --git a/starlette/convertors.py b/starlette/convertors.py index 3b12ac7a0c..2d8ab53beb 100644 --- a/starlette/convertors.py +++ b/starlette/convertors.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import math import typing import uuid @@ -74,7 +76,7 @@ def to_string(self, value: uuid.UUID) -> str: return str(value) -CONVERTOR_TYPES: typing.Dict[str, Convertor[typing.Any]] = { +CONVERTOR_TYPES: dict[str, Convertor[typing.Any]] = { "str": StringConvertor(), "path": PathConvertor(), "int": IntegerConvertor(), diff --git a/starlette/datastructures.py b/starlette/datastructures.py index e430d09b6b..54b5e54f3b 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -150,7 +150,7 @@ def replace_query_params(self, **kwargs: typing.Any) -> URL: query = urlencode([(str(key), str(value)) for key, value in kwargs.items()]) return self.replace(query=query) - def remove_query_params(self, keys: str | typing.Sequence[str]) -> "URL": + def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL: if isinstance(keys, str): keys = [keys] params = MultiDict(parse_qsl(self.query, keep_blank_values=True)) @@ -178,7 +178,7 @@ class URLPath(str): Used by the routing to return `url_path_for` matches. """ - def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath": + def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath: assert protocol in ("http", "websocket", "") return str.__new__(cls, path) @@ -251,13 +251,13 @@ def __str__(self) -> str: class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]): - _dict: typing.Dict[_KeyType, _CovariantValueType] + _dict: dict[_KeyType, _CovariantValueType] def __init__( self, *args: ImmutableMultiDict[_KeyType, _CovariantValueType] | typing.Mapping[_KeyType, _CovariantValueType] - | typing.Iterable[typing.Tuple[_KeyType, _CovariantValueType]], + | typing.Iterable[tuple[_KeyType, _CovariantValueType]], **kwargs: typing.Any, ) -> None: assert len(args) < 2, "Too many arguments." @@ -599,7 +599,7 @@ def __setitem__(self, key: str, value: str) -> None: set_key = key.lower().encode("latin-1") set_value = value.encode("latin-1") - found_indexes: "typing.List[int]" = [] + found_indexes: list[int] = [] for idx, (item_key, item_value) in enumerate(self._list): if item_key == set_key: found_indexes.append(idx) @@ -619,7 +619,7 @@ def __delitem__(self, key: str) -> None: """ del_key = key.lower().encode("latin-1") - pop_indexes: "typing.List[int]" = [] + pop_indexes: list[int] = [] for idx, (item_key, item_value) in enumerate(self._list): if item_key == del_key: pop_indexes.append(idx) diff --git a/starlette/formparsers.py b/starlette/formparsers.py index e2a95e53fe..2e12c7faac 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -91,7 +91,7 @@ async def parse(self) -> FormData: field_name = b"" field_value = b"" - items: list[tuple[str, typing.Union[str, UploadFile]]] = [] + items: list[tuple[str, str | UploadFile]] = [] # Feed the parser with data from the request. async for chunk in self.stream: diff --git a/starlette/middleware/authentication.py b/starlette/middleware/authentication.py index 21f0974343..966c639bb6 100644 --- a/starlette/middleware/authentication.py +++ b/starlette/middleware/authentication.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing from starlette.authentication import ( @@ -16,9 +18,8 @@ def __init__( self, app: ASGIApp, backend: AuthenticationBackend, - on_error: typing.Optional[ - typing.Callable[[HTTPConnection, AuthenticationError], Response] - ] = None, + on_error: typing.Callable[[HTTPConnection, AuthenticationError], Response] + | None = None, ) -> None: self.app = app self.backend = backend diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index ad3ffcfeef..4e5054d7a2 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing import anyio @@ -92,9 +94,7 @@ async def wrapped_receive(self) -> Message: class BaseHTTPMiddleware: - def __init__( - self, app: ASGIApp, dispatch: typing.Optional[DispatchFunction] = None - ) -> None: + def __init__(self, app: ASGIApp, dispatch: DispatchFunction | None = None) -> None: self.app = app self.dispatch_func = self.dispatch if dispatch is None else dispatch @@ -108,7 +108,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: response_sent = anyio.Event() async def call_next(request: Request) -> Response: - app_exc: typing.Optional[Exception] = None + app_exc: Exception | None = None send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]] recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] send_stream, recv_stream = anyio.create_memory_object_stream() @@ -203,10 +203,10 @@ def __init__( self, content: ContentStream, status_code: int = 200, - headers: typing.Optional[typing.Mapping[str, str]] = None, - media_type: typing.Optional[str] = None, - background: typing.Optional[BackgroundTask] = None, - info: typing.Optional[typing.Mapping[str, typing.Any]] = None, + headers: typing.Mapping[str, str] | None = None, + media_type: str | None = None, + background: BackgroundTask | None = None, + info: typing.Mapping[str, typing.Any] | None = None, ) -> None: self._info = info super().__init__(content, status_code, headers, media_type, background) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 5c9bfa6840..4b8e97bc9d 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import re import typing @@ -18,7 +20,7 @@ def __init__( allow_methods: typing.Sequence[str] = ("GET",), allow_headers: typing.Sequence[str] = (), allow_credentials: bool = False, - allow_origin_regex: typing.Optional[str] = None, + allow_origin_regex: str | None = None, expose_headers: typing.Sequence[str] = (), max_age: int = 600, ) -> None: diff --git a/starlette/middleware/errors.py b/starlette/middleware/errors.py index c6336160ca..e9eba62b0b 100644 --- a/starlette/middleware/errors.py +++ b/starlette/middleware/errors.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import html import inspect import traceback @@ -137,9 +139,7 @@ class ServerErrorMiddleware: def __init__( self, app: ASGIApp, - handler: typing.Optional[ - typing.Callable[[Request, Exception], typing.Any] - ] = None, + handler: typing.Callable[[Request, Exception], typing.Any] | None = None, debug: bool = False, ) -> None: self.app = app diff --git a/starlette/middleware/exceptions.py b/starlette/middleware/exceptions.py index 0124f5c8f3..b2bf88dbfe 100644 --- a/starlette/middleware/exceptions.py +++ b/starlette/middleware/exceptions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing from starlette._exception_handler import ( @@ -16,9 +18,10 @@ class ExceptionMiddleware: def __init__( self, app: ASGIApp, - handlers: typing.Optional[ - typing.Mapping[typing.Any, typing.Callable[[Request, Exception], Response]] - ] = None, + handlers: typing.Mapping[ + typing.Any, typing.Callable[[Request, Exception], Response] + ] + | None = None, debug: bool = False, ) -> None: self.app = app @@ -34,7 +37,7 @@ def __init__( def add_exception_handler( self, - exc_class_or_status_code: typing.Union[int, typing.Type[Exception]], + exc_class_or_status_code: int | type[Exception], handler: typing.Callable[[Request, Exception], Response], ) -> None: if isinstance(exc_class_or_status_code, int): @@ -53,7 +56,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: self._status_handlers, ) - conn: typing.Union[Request, WebSocket] + conn: Request | WebSocket if scope["type"] == "http": conn = Request(scope, receive, send) else: diff --git a/starlette/middleware/sessions.py b/starlette/middleware/sessions.py index 1093717b43..5855912cac 100644 --- a/starlette/middleware/sessions.py +++ b/starlette/middleware/sessions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import typing from base64 import b64decode, b64encode @@ -14,13 +16,13 @@ class SessionMiddleware: def __init__( self, app: ASGIApp, - secret_key: typing.Union[str, Secret], + secret_key: str | Secret, session_cookie: str = "session", - max_age: typing.Optional[int] = 14 * 24 * 60 * 60, # 14 days, in seconds + max_age: int | None = 14 * 24 * 60 * 60, # 14 days, in seconds path: str = "/", same_site: typing.Literal["lax", "strict", "none"] = "lax", https_only: bool = False, - domain: typing.Optional[str] = None, + domain: str | None = None, ) -> None: self.app = app self.signer = itsdangerous.TimestampSigner(str(secret_key)) diff --git a/starlette/middleware/trustedhost.py b/starlette/middleware/trustedhost.py index e84e6876a0..59e5273633 100644 --- a/starlette/middleware/trustedhost.py +++ b/starlette/middleware/trustedhost.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing from starlette.datastructures import URL, Headers @@ -11,7 +13,7 @@ class TrustedHostMiddleware: def __init__( self, app: ASGIApp, - allowed_hosts: typing.Optional[typing.Sequence[str]] = None, + allowed_hosts: typing.Sequence[str] | None = None, www_redirect: bool = True, ) -> None: if allowed_hosts is None: diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 2ce83b0740..c9a7e13281 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import io import math import sys @@ -16,7 +18,7 @@ ) -def build_environ(scope: Scope, body: bytes) -> typing.Dict[str, typing.Any]: +def build_environ(scope: Scope, body: bytes) -> dict[str, typing.Any]: """ Builds a scope and request body into a WSGI environ object. """ @@ -117,7 +119,7 @@ async def sender(self, send: Send) -> None: def start_response( self, status: str, - response_headers: typing.List[typing.Tuple[str, str]], + response_headers: list[tuple[str, str]], exc_info: typing.Any = None, ) -> None: self.exc_info = exc_info @@ -140,7 +142,7 @@ def start_response( def wsgi( self, - environ: typing.Dict[str, typing.Any], + environ: dict[str, typing.Any], start_response: typing.Callable[..., typing.Any], ) -> None: for chunk in self.app(environ, start_response): diff --git a/starlette/requests.py b/starlette/requests.py index 4af63bfc1f..b27e8e1e26 100644 --- a/starlette/requests.py +++ b/starlette/requests.py @@ -43,7 +43,7 @@ def cookie_parser(cookie_string: str) -> dict[str, str]: Note: we are explicitly _NOT_ using `SimpleCookie.load` because it is based on an outdated spec and will fail on lots of input we want to support """ - cookie_dict: typing.Dict[str, str] = {} + cookie_dict: dict[str, str] = {} for chunk in cookie_string.split(";"): if "=" in chunk: key, val = chunk.split("=", 1) @@ -135,7 +135,7 @@ def path_params(self) -> dict[str, typing.Any]: @property def cookies(self) -> dict[str, str]: if not hasattr(self, "_cookies"): - cookies: typing.Dict[str, str] = {} + cookies: dict[str, str] = {} cookie_header = self.headers.get("cookie") if cookie_header: @@ -197,7 +197,7 @@ async def empty_send(message: Message) -> typing.NoReturn: class Request(HTTPConnection): - _form: typing.Optional[FormData] + _form: FormData | None def __init__( self, scope: Scope, receive: Receive = empty_receive, send: Send = empty_send @@ -240,7 +240,7 @@ async def stream(self) -> typing.AsyncGenerator[bytes, None]: async def body(self) -> bytes: if not hasattr(self, "_body"): - chunks: "typing.List[bytes]" = [] + chunks: list[bytes] = [] async for chunk in self.stream(): chunks.append(chunk) self._body = b"".join(chunks) @@ -309,7 +309,7 @@ async def is_disconnected(self) -> bool: async def send_push_promise(self, path: str) -> None: if "http.response.push" in self.scope.get("extensions", {}): - raw_headers: "typing.List[typing.Tuple[bytes, bytes]]" = [] + raw_headers: list[tuple[bytes, bytes]] = [] for name in SERVER_PUSH_HEADERS_TO_COPY: for value in self.headers.getlist(name): raw_headers.append( diff --git a/starlette/routing.py b/starlette/routing.py index b5467bb05c..92cdf2be8b 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -57,9 +57,7 @@ def iscoroutinefunction_or_partial(obj: typing.Any) -> bool: # pragma: no cover def request_response( - func: typing.Callable[ - [Request], typing.Union[typing.Awaitable[Response], Response] - ], + func: typing.Callable[[Request], typing.Awaitable[Response] | Response], ) -> ASGIApp: """ Takes a function or coroutine `func(request) -> response`, @@ -255,7 +253,7 @@ def __init__( self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> tuple[Match, Scope]: - path_params: "typing.Dict[str, typing.Any]" + path_params: dict[str, typing.Any] if scope["type"] == "http": route_path = get_route_path(scope) match = self.path_regex.match(route_path) @@ -344,7 +342,7 @@ def __init__( self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> tuple[Match, Scope]: - path_params: "typing.Dict[str, typing.Any]" + path_params: dict[str, typing.Any] if scope["type"] == "websocket": route_path = get_route_path(scope) match = self.path_regex.match(route_path) @@ -417,8 +415,8 @@ def __init__( def routes(self) -> list[BaseRoute]: return getattr(self._base_app, "routes", []) - def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: - path_params: "typing.Dict[str, typing.Any]" + def matches(self, scope: Scope) -> tuple[Match, Scope]: + path_params: dict[str, typing.Any] if scope["type"] in ("http", "websocket"): root_path = scope.get("root_path", "") route_path = get_route_path(scope) diff --git a/starlette/templating.py b/starlette/templating.py index fe31ab5ee4..2dc3a5930d 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -129,7 +129,7 @@ def _create_env( def _setup_env_defaults(self, env: jinja2.Environment) -> None: @pass_context def url_for( - context: typing.Dict[str, typing.Any], + context: dict[str, typing.Any], name: str, /, **path_params: typing.Any, diff --git a/starlette/websockets.py b/starlette/websockets.py index 955063fa17..53ab5a70c8 100644 --- a/starlette/websockets.py +++ b/starlette/websockets.py @@ -17,7 +17,7 @@ class WebSocketState(enum.Enum): class WebSocketDisconnect(Exception): - def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: + def __init__(self, code: int = 1000, reason: str | None = None) -> None: self.code = code self.reason = reason or "" @@ -95,7 +95,7 @@ async def send(self, message: Message) -> None: self.application_state = WebSocketState.DISCONNECTED try: await self._send(message) - except IOError: + except OSError: self.application_state = WebSocketState.DISCONNECTED raise WebSocketDisconnect(code=1006) elif self.application_state == WebSocketState.RESPONSE: diff --git a/tests/conftest.py b/tests/conftest.py index 1a61664d17..724ca65d3c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import functools -from typing import Any, Callable, Dict, Literal +from typing import Any, Callable, Literal import pytest @@ -11,7 +13,7 @@ @pytest.fixture def test_client_factory( anyio_backend_name: Literal["asyncio", "trio"], - anyio_backend_options: Dict[str, Any], + anyio_backend_options: dict[str, Any], ) -> TestClientFactory: # anyio_backend_name defined by: # https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 6e5e42b944..2176404d82 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextvars from contextlib import AsyncExitStack from typing import ( @@ -5,9 +7,6 @@ AsyncGenerator, Callable, Generator, - List, - Type, - Union, ) import anyio @@ -241,7 +240,7 @@ async def dispatch( ) def test_contextvars( test_client_factory: TestClientFactory, - middleware_cls: Type[_MiddlewareClass[Any]], + middleware_cls: type[_MiddlewareClass[Any]], ) -> None: # this has to be an async endpoint because Starlette calls run_in_threadpool # on sync endpoints which has it's own set of peculiarities w.r.t propagating @@ -318,7 +317,7 @@ async def send(message: Message) -> None: async def test_do_not_block_on_background_tasks() -> None: request_body_sent = False response_complete = anyio.Event() - events: List[Union[str, Message]] = [] + events: list[str | Message] = [] async def sleep_and_set() -> None: events.append("Background task started") @@ -766,7 +765,7 @@ async def dispatch( call_next: RequestResponseEndpoint, ) -> Response: expected = b"1" - response: Union[Response, None] = None + response: Response | None = None async for chunk in request.stream(): assert chunk == expected if expected == b"1": @@ -783,7 +782,7 @@ async def rcv() -> AsyncGenerator[Message, None]: yield {"type": "http.request", "body": b"3"} await anyio.sleep(float("inf")) - sent: List[Message] = [] + sent: list[Message] = [] async def send(msg: Message) -> None: sent.append(msg) @@ -1000,7 +999,7 @@ def test_pr_1519_comment_1236166180_example() -> None: """ https://github.com/encode/starlette/pull/1519#issuecomment-1236166180 """ - bodies: List[bytes] = [] + bodies: list[bytes] = [] class LogRequestBodySize(BaseHTTPMiddleware): async def dispatch( diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 27b0337620..ecddda75ed 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import base64 import binascii -from typing import Any, Awaitable, Callable, Optional, Tuple +from typing import Any, Awaitable, Callable from urllib.parse import urlencode import pytest @@ -31,7 +33,7 @@ class BasicAuth(AuthenticationBackend): async def authenticate( self, request: HTTPConnection, - ) -> Optional[Tuple[AuthCredentials, SimpleUser]]: + ) -> tuple[AuthCredentials, SimpleUser] | None: if "Authorization" not in request.headers: return None diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 4f0cd430d3..ed2226878b 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import typing from contextlib import nullcontext as does_not_raise @@ -29,7 +31,7 @@ def __bool__(self) -> bool: async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = await request.form() - output: typing.Dict[str, typing.Any] = {} + output: dict[str, typing.Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() @@ -49,7 +51,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = await request.form() - output: typing.Dict[str, typing.List[typing.Any]] = {} + output: dict[str, list[typing.Any]] = {} for key, value in data.multi_items(): if key not in output: output[key] = [] @@ -73,7 +75,7 @@ async def multi_items_app(scope: Scope, receive: Receive, send: Send) -> None: async def app_with_headers(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = await request.form() - output: typing.Dict[str, typing.Any] = {} + output: dict[str, typing.Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() @@ -108,7 +110,7 @@ def make_app_max_parts(max_files: int = 1000, max_fields: int = 1000) -> ASGIApp async def app(scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope, receive) data = await request.form(max_files=max_files, max_fields=max_fields) - output: typing.Dict[str, typing.Any] = {} + output: dict[str, typing.Any] = {} for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() diff --git a/tests/test_requests.py b/tests/test_requests.py index b3ce3a04ad..d8e2e94773 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import sys -from typing import Any, Callable, Dict, Iterator, List, Optional +from typing import Any, Callable, Iterator import anyio import pytest @@ -72,7 +74,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ({}, None), ], ) -def test_request_client(scope: Scope, expected_client: Optional[Address]) -> None: +def test_request_client(scope: Scope, expected_client: Address | None) -> None: scope.update({"type": "http"}) # required by Request's constructor client = Request(scope).client assert client == expected_client @@ -239,7 +241,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_request_disconnect( anyio_backend_name: str, - anyio_backend_options: Dict[str, Any], + anyio_backend_options: dict[str, Any], ) -> None: """ If a client disconnect occurs while reading request body @@ -391,7 +393,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ) def test_cookies_edge_cases( set_cookie: str, - expected: Dict[str, str], + expected: dict[str, str], test_client_factory: TestClientFactory, ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: @@ -430,7 +432,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ) def test_cookies_invalid( set_cookie: str, - expected: Dict[str, str], + expected: dict[str, str], test_client_factory: TestClientFactory, ) -> None: """ @@ -542,7 +544,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: ], ) @pytest.mark.anyio -async def test_request_rcv(messages: List[Message]) -> None: +async def test_request_rcv(messages: list[Message]) -> None: messages = messages.copy() async def rcv() -> Message: @@ -557,7 +559,7 @@ async def rcv() -> Message: @pytest.mark.anyio async def test_request_stream_called_twice() -> None: - messages: List[Message] = [ + messages: list[Message] = [ {"type": "http.request", "body": b"1", "more_body": True}, {"type": "http.request", "body": b"2", "more_body": True}, {"type": "http.request", "body": b"3"}, diff --git a/tests/test_responses.py b/tests/test_responses.py index 57a5949018..a3cdcadcf2 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import datetime as dt import os import time from http.cookies import SimpleCookie from pathlib import Path -from typing import AsyncIterator, Callable, Iterator, Union +from typing import AsyncIterator, Callable, Iterator import anyio import pytest @@ -160,7 +162,7 @@ def test_streaming_response_custom_iterable( ) -> None: async def app(scope: Scope, receive: Receive, send: Send) -> None: class CustomAsyncIterable: - async def __aiter__(self) -> AsyncIterator[Union[str, bytes]]: + async def __aiter__(self) -> AsyncIterator[str | bytes]: for i in range(5): yield str(i + 1) diff --git a/tests/test_routing.py b/tests/test_routing.py index 8c3f16639a..b75fc47f02 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import functools import json @@ -762,7 +764,7 @@ def test_lifespan_state_unsupported( @contextlib.asynccontextmanager async def lifespan( app: ASGIApp, - ) -> typing.AsyncGenerator[typing.Dict[str, str], None]: + ) -> typing.AsyncGenerator[dict[str, str], None]: yield {"foo": "bar"} app = Router( @@ -787,7 +789,7 @@ def test_lifespan_state_async_cm(test_client_factory: TestClientFactory) -> None class State(typing.TypedDict): count: int - items: typing.List[int] + items: list[int] async def hello_world(request: Request) -> Response: # modifications to the state should not leak across requests diff --git a/tests/test_templates.py b/tests/test_templates.py index ab0b38a910..95e392ed5a 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import typing from pathlib import Path @@ -46,7 +48,7 @@ def test_calls_context_processors( async def homepage(request: Request) -> Response: return templates.TemplateResponse(request, "index.html") - def hello_world_processor(request: Request) -> typing.Dict[str, str]: + def hello_world_processor(request: Request) -> dict[str, str]: return {"username": "World"} app = Starlette( diff --git a/tests/test_websockets.py b/tests/test_websockets.py index c4b6c16bdb..854c269143 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -273,7 +273,7 @@ async def send(message: Message) -> None: return # Simulate the exception the server would send to the application when the # client disconnects. - raise IOError + raise OSError with pytest.raises(WebSocketDisconnect) as ctx: await app({"type": "websocket", "path": "/"}, receive, send) From 39dccd911251dfa886775493ed6cefcb15549bc2 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Thu, 29 Feb 2024 13:55:04 +0100 Subject: [PATCH 2/3] Revert "Turn `scope["client"]` to `None` on `TestClient` (#2377)" (#2525) * Revert "Turn `scope["client"]` to `None` on `TestClient` (#2377)" This reverts commit 483849a466a2bfc121f5a367339e1aa3ed20344b. * format * Add type hints --------- Co-authored-by: Marcelo Trylesinski --- starlette/testclient.py | 4 ++-- tests/test_testclient.py | 13 +++++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index d076331c18..f17d4e8923 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -298,7 +298,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "scheme": scheme, "query_string": query.encode(), "headers": headers, - "client": None, + "client": ["testclient", 50000], "server": [host, port], "subprotocols": subprotocols, "state": self.app_state.copy(), @@ -317,7 +317,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "scheme": scheme, "query_string": query.encode(), "headers": headers, - "client": None, + "client": ["testclient", 50000], "server": [host, port], "extensions": {"http.response.debug": {}}, "state": self.app_state.copy(), diff --git a/tests/test_testclient.py b/tests/test_testclient.py index e8956cd30a..4ed1ced9a3 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -274,6 +274,19 @@ async def asgi(receive: Receive, send: Send) -> None: assert websocket.should_close.is_set() +def test_client(test_client_factory: TestClientFactory) -> None: + async def app(scope: Scope, receive: Receive, send: Send) -> None: + client = scope.get("client") + assert client is not None + host, port = client + response = JSONResponse({"host": host, "port": port}) + await response(scope, receive, send) + + client = test_client_factory(app) + response = client.get("/") + assert response.json() == {"host": "testclient", "port": 50000} + + @pytest.mark.parametrize("param", ("2020-07-14T00:00:00+00:00", "España", "voilà")) def test_query_params(test_client_factory: TestClientFactory, param: str) -> None: def homepage(request: Request) -> Response: From 85d35737c7053bee489c438467b18a9108b23b93 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 2 Mar 2024 05:35:35 -0700 Subject: [PATCH 3/3] Bump the python-packages group with 7 updates (#2532) * Bump the python-packages group with 7 updates Bumps the python-packages group with 7 updates: | Package | From | To | | --- | --- | --- | | [coverage](https://github.com/nedbat/coveragepy) | `7.4.1` | `7.4.3` | | [ruff](https://github.com/astral-sh/ruff) | `0.1.15` | `0.3.0` | | [typing-extensions](https://github.com/python/typing_extensions) | `4.9.0` | `4.10.0` | | [pytest](https://github.com/pytest-dev/pytest) | `8.0.0` | `8.0.2` | | [mkdocs-material](https://github.com/squidfunk/mkdocs-material) | `9.5.6` | `9.5.12` | | [build](https://github.com/pypa/build) | `1.0.3` | `1.1.1` | | [twine](https://github.com/pypa/twine) | `4.0.2` | `5.0.0` | Updates `coverage` from 7.4.1 to 7.4.3 - [Release notes](https://github.com/nedbat/coveragepy/releases) - [Changelog](https://github.com/nedbat/coveragepy/blob/master/CHANGES.rst) - [Commits](https://github.com/nedbat/coveragepy/compare/7.4.1...7.4.3) Updates `ruff` from 0.1.15 to 0.3.0 - [Release notes](https://github.com/astral-sh/ruff/releases) - [Changelog](https://github.com/astral-sh/ruff/blob/main/CHANGELOG.md) - [Commits](https://github.com/astral-sh/ruff/compare/v0.1.15...v0.3.0) Updates `typing-extensions` from 4.9.0 to 4.10.0 - [Release notes](https://github.com/python/typing_extensions/releases) - [Changelog](https://github.com/python/typing_extensions/blob/main/CHANGELOG.md) - [Commits](https://github.com/python/typing_extensions/compare/4.9.0...4.10.0) Updates `pytest` from 8.0.0 to 8.0.2 - [Release notes](https://github.com/pytest-dev/pytest/releases) - [Changelog](https://github.com/pytest-dev/pytest/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pytest-dev/pytest/compare/8.0.0...8.0.2) Updates `mkdocs-material` from 9.5.6 to 9.5.12 - [Release notes](https://github.com/squidfunk/mkdocs-material/releases) - [Changelog](https://github.com/squidfunk/mkdocs-material/blob/master/CHANGELOG) - [Commits](https://github.com/squidfunk/mkdocs-material/compare/9.5.6...9.5.12) Updates `build` from 1.0.3 to 1.1.1 - [Release notes](https://github.com/pypa/build/releases) - [Changelog](https://github.com/pypa/build/blob/main/CHANGELOG.rst) - [Commits](https://github.com/pypa/build/compare/1.0.3...1.1.1) Updates `twine` from 4.0.2 to 5.0.0 - [Release notes](https://github.com/pypa/twine/releases) - [Changelog](https://github.com/pypa/twine/blob/main/docs/changelog.rst) - [Commits](https://github.com/pypa/twine/compare/4.0.2...5.0.0) --- updated-dependencies: - dependency-name: coverage dependency-type: direct:production update-type: version-update:semver-patch dependency-group: python-packages - dependency-name: ruff dependency-type: direct:production update-type: version-update:semver-minor dependency-group: python-packages - dependency-name: typing-extensions dependency-type: direct:production update-type: version-update:semver-minor dependency-group: python-packages - dependency-name: pytest dependency-type: direct:production update-type: version-update:semver-patch dependency-group: python-packages - dependency-name: mkdocs-material dependency-type: direct:production update-type: version-update:semver-patch dependency-group: python-packages - dependency-name: build dependency-type: direct:production update-type: version-update:semver-minor dependency-group: python-packages - dependency-name: twine dependency-type: direct:production update-type: version-update:semver-major dependency-group: python-packages ... Signed-off-by: dependabot[bot] * Update requirements.txt --------- Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Marcelo Trylesinski --- requirements.txt | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/requirements.txt b/requirements.txt index d864321a65..5652a865e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,22 +2,22 @@ -e .[full] # Testing -coverage==7.4.1 +coverage==7.4.3 importlib-metadata==7.0.1 mypy==1.8.0 ruff==0.1.15 -typing_extensions==4.9.0 +typing_extensions==4.10.0 types-contextvars==2.4.7.3 types-PyYAML==6.0.12.12 types-dataclasses==0.6.6 -pytest==8.0.0 +pytest==8.0.2 trio==0.24.0 # Documentation mkdocs==1.5.3 -mkdocs-material==9.5.6 +mkdocs-material==9.5.12 mkautodoc==0.2.0 # Packaging -build==1.0.3 -twine==4.0.2 +build==1.1.1 +twine==5.0.0