diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index d0f2b2a5e..a2735fc78 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -20,19 +20,8 @@ ) from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL -from uvicorn.protocols.http.flow_control import ( - CLOSE_HEADER, - HIGH_WATER_LIMIT, - FlowControl, - service_unavailable, -) -from uvicorn.protocols.utils import ( - get_client_addr, - get_local_addr, - get_path_with_query_string, - get_remote_addr, - is_ssl, -) +from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable +from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl from uvicorn.server import ServerState diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 60debaf8f..e54966609 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -21,19 +21,8 @@ ) from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL -from uvicorn.protocols.http.flow_control import ( - CLOSE_HEADER, - HIGH_WATER_LIMIT, - FlowControl, - service_unavailable, -) -from uvicorn.protocols.utils import ( - get_client_addr, - get_local_addr, - get_path_with_query_string, - get_remote_addr, - is_ssl, -) +from uvicorn.protocols.http.flow_control import CLOSE_HEADER, HIGH_WATER_LIMIT, FlowControl, service_unavailable +from uvicorn.protocols.utils import get_client_addr, get_local_addr, get_path_with_query_string, get_remote_addr, is_ssl from uvicorn.server import ServerState HEADER_RE = re.compile(b'[\x00-\x1f\x7f()<>@,;:[]={} \t\\"]') diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 6d098d5af..4b71240fd 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -7,14 +7,17 @@ from urllib.parse import unquote import websockets +import websockets.legacy.handshake from websockets.datastructures import Headers from websockets.exceptions import ConnectionClosed +from websockets.extensions.base import ServerExtensionFactory from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.legacy.server import HTTPResponse from websockets.server import WebSocketServerProtocol from websockets.typing import Subprotocol from uvicorn._types import ( + ASGI3Application, ASGISendEvent, WebSocketAcceptEvent, WebSocketCloseEvent, @@ -53,6 +56,7 @@ def is_serving(self) -> bool: class WebSocketProtocol(WebSocketServerProtocol): extra_headers: list[tuple[str, str]] + logger: logging.Logger | logging.LoggerAdapter[Any] def __init__( self, @@ -65,7 +69,7 @@ def __init__( config.load() self.config = config - self.app = config.loaded_app + self.app = cast(ASGI3Application, config.loaded_app) self.loop = _loop or asyncio.get_event_loop() self.root_path = config.root_path self.app_state = app_state @@ -92,7 +96,7 @@ def __init__( self.ws_server: Server = Server() # type: ignore[assignment] - extensions = [] + extensions: list[ServerExtensionFactory] = [] if self.config.ws_per_message_deflate: extensions.append(ServerPerMessageDeflateFactory()) @@ -147,10 +151,10 @@ def shutdown(self) -> None: self.send_500_response() self.transport.close() - def on_task_complete(self, task: asyncio.Task) -> None: + def on_task_complete(self, task: asyncio.Task[None]) -> None: self.tasks.discard(task) - async def process_request(self, path: str, headers: Headers) -> HTTPResponse | None: + async def process_request(self, path: str, request_headers: Headers) -> HTTPResponse | None: """ This hook is called to determine if the websocket should return an HTTP response and close. @@ -161,15 +165,15 @@ async def process_request(self, path: str, headers: Headers) -> HTTPResponse | N """ path_portion, _, query_string = path.partition("?") - websockets.legacy.handshake.check_request(headers) + websockets.legacy.handshake.check_request(request_headers) - subprotocols = [] - for header in headers.get_all("Sec-WebSocket-Protocol"): + subprotocols: list[str] = [] + for header in request_headers.get_all("Sec-WebSocket-Protocol"): subprotocols.extend([token.strip() for token in header.split(",")]) asgi_headers = [ (name.encode("ascii"), value.encode("ascii", errors="surrogateescape")) - for name, value in headers.raw_items() + for name, value in request_headers.raw_items() ] path = unquote(path_portion) full_path = self.root_path + path @@ -237,14 +241,13 @@ async def run_asgi(self) -> None: termination states. """ try: - result = await self.app(self.scope, self.asgi_receive, self.asgi_send) + result = await self.app(self.scope, self.asgi_receive, self.asgi_send) # type: ignore[func-returns-value] except ClientDisconnected: self.closed_event.set() self.transport.close() - except BaseException as exc: + except BaseException: self.closed_event.set() - msg = "Exception in ASGI application\n" - self.logger.error(msg, exc_info=exc) + self.logger.exception("Exception in ASGI application\n") if not self.handshake_started_event.is_set(): self.send_500_response() else: @@ -253,13 +256,11 @@ async def run_asgi(self) -> None: else: self.closed_event.set() if not self.handshake_started_event.is_set(): - msg = "ASGI callable returned without sending handshake." - self.logger.error(msg) + self.logger.error("ASGI callable returned without sending handshake.") self.send_500_response() self.transport.close() elif result is not None: - msg = "ASGI callable should return None, but returned '%s'." - self.logger.error(msg, result) + self.logger.error("ASGI callable should return None, but returned '%s'.", result) await self.handshake_completed_event.wait() self.transport.close() diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index c92625277..86c251a4f 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -3,7 +3,7 @@ import asyncio import logging import typing -from typing import Literal +from typing import Literal, cast from urllib.parse import unquote import wsproto @@ -13,6 +13,7 @@ from wsproto.utilities import LocalProtocolError, RemoteProtocolError from uvicorn._types import ( + ASGI3Application, ASGISendEvent, WebSocketAcceptEvent, WebSocketCloseEvent, @@ -46,7 +47,7 @@ def __init__( config.load() self.config = config - self.app = config.loaded_app + self.app = cast(ASGI3Application, config.loaded_app) self.loop = _loop or asyncio.get_event_loop() self.logger = logging.getLogger("uvicorn.error") self.root_path = config.root_path @@ -156,7 +157,7 @@ def shutdown(self) -> None: self.send_500_response() self.transport.close() - def on_task_complete(self, task: asyncio.Task) -> None: + def on_task_complete(self, task: asyncio.Task[None]) -> None: self.tasks.discard(task) # Event handlers @@ -220,7 +221,7 @@ def handle_ping(self, event: events.Ping) -> None: def send_500_response(self) -> None: if self.response_started or self.handshake_complete: return # we cannot send responses anymore - headers = [ + headers: list[tuple[bytes, bytes]] = [ (b"content-type", b"text/plain; charset=utf-8"), (b"connection", b"close"), ] @@ -230,7 +231,7 @@ def send_500_response(self) -> None: async def run_asgi(self) -> None: try: - result = await self.app(self.scope, self.receive, self.send) + result = await self.app(self.scope, self.receive, self.send) # type: ignore[func-returns-value] except ClientDisconnected: self.transport.close() except BaseException: @@ -239,13 +240,11 @@ async def run_asgi(self) -> None: self.transport.close() else: if not self.handshake_complete: - msg = "ASGI callable returned without completing handshake." - self.logger.error(msg) + self.logger.error("ASGI callable returned without completing handshake.") self.send_500_response() self.transport.close() elif result is not None: - msg = "ASGI callable should return None, but returned '%s'." - self.logger.error(msg, result) + self.logger.error("ASGI callable should return None, but returned '%s'.", result) self.transport.close() async def send(self, message: ASGISendEvent) -> None: