From 4af46c931c64a4e32c37faec9ce531cca57e4cf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Tue, 16 Jan 2024 16:51:30 +0100 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Update=20`root=5Fpath`=20h?= =?UTF-8?q?andling=20(from=20`--root-path`=20CLI=20option)=20to=20include?= =?UTF-8?q?=20the=20root=20path=20prefix=20in=20the=20full=20ASGI=20`path`?= =?UTF-8?q?=20as=20per=20the=20ASGI=20spec=20(#2213)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ Update root-path handling to include it in the full path as per the ASGI spe, related to Starlette 0.35.0 * ✅ Update tests for root_path, ensure it's added to the prefix of the path in the ASGI scope * ♻️ Update the (deprecated) WSGIMiddleware to follow closely the ASGI spec * ✅ Update tests for WSGIMiddleware * 🎨 Fix format in tests * Update tests/protocols/test_http.py Co-authored-by: Marcelo Trylesinski * Update uvicorn/protocols/http/httptools_impl.py Co-authored-by: Marcelo Trylesinski --------- Co-authored-by: Marcelo Trylesinski --- tests/middleware/test_wsgi.py | 7 ++++--- tests/protocols/test_http.py | 13 ++++++++----- uvicorn/middleware/wsgi.py | 8 ++++++-- uvicorn/protocols/http/h11_impl.py | 7 +++++-- uvicorn/protocols/http/httptools_impl.py | 6 ++++-- uvicorn/protocols/websockets/websockets_impl.py | 7 +++++-- uvicorn/protocols/websockets/wsproto_impl.py | 7 +++++-- 7 files changed, 37 insertions(+), 18 deletions(-) diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 34730ec92..6d9c17d29 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -122,11 +122,11 @@ def test_build_environ_encoding() -> None: scope: "HTTPScope" = { "asgi": {"version": "3.0", "spec_version": "2.0"}, "scheme": "http", - "raw_path": b"/\xe6\x96\x87", + "raw_path": b"/\xe6\x96\x87%2Fall", "type": "http", "http_version": "1.1", "method": "GET", - "path": "/文", + "path": "/文/all", "root_path": "/文", "client": None, "server": None, @@ -140,5 +140,6 @@ def test_build_environ_encoding() -> None: "more_body": False, } environ = wsgi.build_environ(scope, message, io.BytesIO(b"")) - assert environ["PATH_INFO"] == "/文".encode("utf8").decode("latin-1") + assert environ["SCRIPT_NAME"] == "/文".encode("utf8").decode("latin-1") + assert environ["PATH_INFO"] == "/all".encode("utf8").decode("latin-1") assert environ["HTTP_KEY"] == "value1,value2" diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index ca06b33a6..a422404f9 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -630,15 +630,18 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable async def test_root_path(http_protocol_cls: "Type[HttpToolsProtocol | H11Protocol]"): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "http" - path = scope.get("root_path", "") + scope["path"] - response = Response("Path: " + path, media_type="text/plain") + root_path = scope.get("root_path", "") + path = scope["path"] + response = Response( + f"root_path={root_path} path={path}", media_type="text/plain" + ) await response(scope, receive, send) protocol = get_connected_protocol(app, http_protocol_cls, root_path="/app") protocol.data_received(SIMPLE_GET_REQUEST) await protocol.loop.run_one() assert b"HTTP/1.1 200 OK" in protocol.transport.buffer - assert b"Path: /app/" in protocol.transport.buffer + assert b"root_path=/app path=/app/" in protocol.transport.buffer @pytest.mark.anyio @@ -647,8 +650,8 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable assert scope["type"] == "http" path = scope["path"] raw_path = scope.get("raw_path", None) - assert "/one/two" == path - assert b"/one%2Ftwo" == raw_path + assert "/app/one/two" == path + assert b"/app/one%2Ftwo" == raw_path response = Response("Done", media_type="text/plain") await response(scope, receive, send) diff --git a/uvicorn/middleware/wsgi.py b/uvicorn/middleware/wsgi.py index 381eca68e..b181e0f16 100644 --- a/uvicorn/middleware/wsgi.py +++ b/uvicorn/middleware/wsgi.py @@ -28,10 +28,14 @@ def build_environ( """ Builds a scope and request message into a WSGI environ object. """ + script_name = scope.get("root_path", "").encode("utf8").decode("latin1") + path_info = scope["path"].encode("utf8").decode("latin1") + if path_info.startswith(script_name): + path_info = path_info[len(script_name) :] environ = { "REQUEST_METHOD": scope["method"], - "SCRIPT_NAME": "", - "PATH_INFO": scope["path"].encode("utf8").decode("latin1"), + "SCRIPT_NAME": script_name, + "PATH_INFO": path_info, "QUERY_STRING": scope["query_string"].decode("ascii"), "SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"], "wsgi.version": (1, 0), diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index c6b2f2781..bee83122d 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -198,6 +198,9 @@ def handle_events(self) -> None: elif isinstance(event, h11.Request): self.headers = [(key.lower(), value) for key, value in event.headers] raw_path, _, query_string = event.target.partition(b"?") + path = unquote(raw_path.decode("ascii")) + full_path = self.root_path + path + full_raw_path = self.root_path.encode("ascii") + raw_path self.scope = { "type": "http", "asgi": { @@ -210,8 +213,8 @@ def handle_events(self) -> None: "scheme": self.scheme, # type: ignore[typeddict-item] "method": event.method.decode("ascii"), "root_path": self.root_path, - "path": unquote(raw_path.decode("ascii")), - "raw_path": raw_path, + "path": full_path, + "raw_path": full_raw_path, "query_string": query_string, "headers": self.headers, "state": self.app_state.copy(), diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 99c868c31..90a3bd9ff 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -260,8 +260,10 @@ def on_headers_complete(self) -> None: path = raw_path.decode("ascii") if "%" in path: path = urllib.parse.unquote(path) - self.scope["path"] = path - self.scope["raw_path"] = raw_path + full_path = self.root_path + path + full_raw_path = self.root_path.encode("ascii") + raw_path + self.scope["path"] = full_path + self.scope["raw_path"] = full_raw_path self.scope["query_string"] = parsed_url.query or b"" # Handle 503 responses when 'limit_concurrency' is exceeded. diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 3f04c1dd5..880d9214d 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -184,6 +184,9 @@ async def process_request( (name.encode("ascii"), value.encode("ascii", errors="surrogateescape")) for name, value in headers.raw_items() ] + path = unquote(path_portion) + full_path = self.root_path + path + full_raw_path = self.root_path.encode("ascii") + path_portion.encode("ascii") self.scope = { "type": "websocket", @@ -193,8 +196,8 @@ async def process_request( "server": self.server, "client": self.client, "root_path": self.root_path, - "path": unquote(path_portion), - "raw_path": path_portion.encode("ascii"), + "path": full_path, + "raw_path": full_raw_path, "query_string": query_string.encode("ascii"), "headers": asgi_headers, "subprotocols": subprotocols, diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 7929f3a91..2409e90ef 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -164,6 +164,9 @@ def handle_connect(self, event: events.Request) -> None: headers = [(b"host", event.host.encode())] headers += [(key.lower(), value) for key, value in event.extra_headers] raw_path, _, query_string = event.target.partition("?") + path = unquote(raw_path) + full_path = self.root_path + path + full_raw_path = self.root_path.encode("ascii") + raw_path.encode("ascii") self.scope: "WebSocketScope" = { "type": "websocket", "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, @@ -172,8 +175,8 @@ def handle_connect(self, event: events.Request) -> None: "server": self.server, "client": self.client, "root_path": self.root_path, - "path": unquote(raw_path), - "raw_path": raw_path.encode("ascii"), + "path": full_path, + "raw_path": full_raw_path, "query_string": query_string.encode("ascii"), "headers": headers, "subprotocols": event.subprotocols,