From abb04a512496206de279225340ed022852fbf51f Mon Sep 17 00:00:00 2001 From: pgjones <philip.graham.jones@googlemail.com> Date: Mon, 23 Dec 2024 13:45:22 +0000 Subject: [PATCH] Support max_form_parts and max_form_memory_size These allow greater control over safer form parsing with the former limiting the number of parts and the latter limiting any individual (data) part's maximum size in bytes. The default values are taken from Flask. --- src/quart/app.py | 7 +++-- src/quart/formparser.py | 51 ++++++++++++++++++++++++----------- src/quart/wrappers/base.py | 9 ------- src/quart/wrappers/request.py | 47 ++++++++++++++++++++++++++++++++ tests/test_formparser.py | 10 +++++++ 5 files changed, 97 insertions(+), 27 deletions(-) diff --git a/src/quart/app.py b/src/quart/app.py index 03813e6..6a3e422 100644 --- a/src/quart/app.py +++ b/src/quart/app.py @@ -248,6 +248,8 @@ class Quart(App): "EXPLAIN_TEMPLATE_LOADING": False, "MAX_CONTENT_LENGTH": 16 * 1024 * 1024, # 16 MB Limit "MAX_COOKIE_SIZE": 4093, + "MAX_FORM_MEMORY_SIZE": 500_000, + "MAX_FORM_PARTS": 1_000, "PERMANENT_SESSION_LIFETIME": timedelta(days=31), # Replaces PREFERRED_URL_SCHEME to allow for WebSocket scheme "PREFER_SECURE_URLS": False, @@ -1130,8 +1132,9 @@ async def handle_websocket_exception( def log_exception( self, - exception_info: tuple[type, BaseException, TracebackType] - | tuple[None, None, None], + exception_info: ( + tuple[type, BaseException, TracebackType] | tuple[None, None, None] + ), ) -> None: """Log a exception to the :attr:`logger`. 
diff --git a/src/quart/formparser.py b/src/quart/formparser.py index e8ea6e1..50e7ee4 100644 --- a/src/quart/formparser.py +++ b/src/quart/formparser.py @@ -43,15 +43,20 @@ class FormDataParser: def __init__( self, - stream_factory: StreamFactory = default_stream_factory, - max_form_memory_size: int | None = None, - max_content_length: int | None = None, + *, cls: type[MultiDict] | None = MultiDict, + max_content_length: int | None = None, + max_form_memory_size: int | None = None, + max_form_parts: int | None = None, silent: bool = True, + stream_factory: StreamFactory = default_stream_factory, ) -> None: - self.stream_factory = stream_factory self.cls = cls + self.max_content_length = max_content_length + self.max_form_memory_size = max_form_memory_size + self.max_form_parts = max_form_parts self.silent = silent + self.stream_factory = stream_factory def get_parse_func( self, mimetype: str, options: dict[str, str] @@ -87,9 +92,12 @@ async def _parse_multipart( options: dict[str, str], ) -> tuple[MultiDict, MultiDict]: parser = MultiPartParser( - self.stream_factory, cls=self.cls, file_storage_cls=self.file_storage_class, + max_content_length=self.max_content_length, + max_form_memory_size=self.max_form_memory_size, + max_form_parts=self.max_form_parts, + stream_factory=self.stream_factory, ) boundary = options.get("boundary", "").encode("ascii") @@ -105,10 +113,14 @@ async def _parse_urlencoded( content_length: int | None, options: dict[str, str], ) -> tuple[MultiDict, MultiDict]: - form = parse_qsl( - (await body).decode(), - keep_blank_values=True, - ) + try: + form = parse_qsl( + (await body).decode(), + keep_blank_values=True, + max_num_fields=self.max_form_parts, + ) + except ValueError: + raise RequestEntityTooLarge() from None return self.cls(form), self.cls() parse_functions: dict[str, ParserFunc] = { @@ -121,17 +133,22 @@ async def _parse_urlencoded( class MultiPartParser: def __init__( self, - stream_factory: StreamFactory = default_stream_factory, - 
max_form_memory_size: int | None = None, - cls: type[MultiDict] = MultiDict, + *, buffer_size: int = 64 * 1024, + cls: type[MultiDict] = MultiDict, file_storage_cls: type[FileStorage] = FileStorage, + max_content_length: int | None = None, + max_form_memory_size: int | None = None, + max_form_parts: int | None = None, + stream_factory: StreamFactory = default_stream_factory, ) -> None: - self.max_form_memory_size = max_form_memory_size - self.stream_factory = stream_factory - self.cls = cls self.buffer_size = buffer_size + self.cls = cls self.file_storage_cls = file_storage_cls + self.max_content_length = max_content_length + self.max_form_memory_size = max_form_memory_size + self.max_form_parts = max_form_parts + self.stream_factory = stream_factory def fail(self, message: str) -> NoReturn: raise ValueError(message) @@ -172,7 +189,9 @@ async def parse( container: IO[bytes] | list[bytes] _write: Callable[[bytes], Any] - parser = MultipartDecoder(boundary, self.max_form_memory_size) + parser = MultipartDecoder( + boundary, self.max_content_length, max_parts=self.max_form_parts + ) fields = [] files = [] diff --git a/src/quart/wrappers/base.py b/src/quart/wrappers/base.py index 6f458ce..e33caab 100644 --- a/src/quart/wrappers/base.py +++ b/src/quart/wrappers/base.py @@ -8,7 +8,6 @@ from werkzeug.sansio.request import Request as SansIORequest from .. import json -from ..globals import current_app if TYPE_CHECKING: from ..routing import QuartRule # noqa @@ -73,14 +72,6 @@ def __init__( self.http_version = http_version self.scope = scope - @property - def max_content_length(self) -> int | None: - """Read-only view of the ``MAX_CONTENT_LENGTH`` config key.""" - if current_app: - return current_app.config["MAX_CONTENT_LENGTH"] - else: - return None - @property def endpoint(self) -> str | None: """Returns the corresponding endpoint matched for this request. 
diff --git a/src/quart/wrappers/request.py b/src/quart/wrappers/request.py index cf50b6f..1cf2b2f 100644 --- a/src/quart/wrappers/request.py +++ b/src/quart/wrappers/request.py @@ -141,6 +141,9 @@ class Request(BaseRequestWebsocket): body_class = Body form_data_parser_class = FormDataParser lock_class = asyncio.Lock + _max_content_length: int | None = None + _max_form_memory_size: int | None = None + _max_form_parts: int | None = None def __init__( self, @@ -189,6 +192,48 @@ def __init__( self._parsing_lock = self.lock_class() self._send_push_promise = send_push_promise + @property + def max_content_length(self) -> int | None: + if self._max_content_length is not None: + return self._max_content_length + + if current_app: + return current_app.config["MAX_CONTENT_LENGTH"] + + return None + + @max_content_length.setter + def max_content_length(self, value: int | None) -> None: + self._max_content_length = value + + @property + def max_form_memory_size(self) -> int | None: + if self._max_form_memory_size is not None: + return self._max_form_memory_size + + if current_app: + return current_app.config["MAX_FORM_MEMORY_SIZE"] + + return None + + @max_form_memory_size.setter + def max_form_memory_size(self, value: int | None) -> None: + self._max_form_memory_size = value + + @property + def max_form_parts(self) -> int | None: + if self._max_form_parts is not None: + return self._max_form_parts + + if current_app: + return current_app.config["MAX_FORM_PARTS"] + + return None + + @max_form_parts.setter + def max_form_parts(self, value: int | None) -> None: + self._max_form_parts = value + @property async def stream(self) -> NoReturn: raise NotImplementedError("Use body instead") @@ -284,6 +329,8 @@ async def files(self) -> MultiDict: def make_form_data_parser(self) -> FormDataParser: return self.form_data_parser_class( max_content_length=self.max_content_length, + max_form_memory_size=self.max_form_memory_size, + max_form_parts=self.max_form_parts, 
cls=self.parameter_storage_class, ) diff --git a/tests/test_formparser.py b/tests/test_formparser.py index c5e85f2..2238ea7 100644 --- a/tests/test_formparser.py +++ b/tests/test_formparser.py @@ -3,6 +3,7 @@ import pytest from werkzeug.exceptions import RequestEntityTooLarge +from quart.formparser import FormDataParser from quart.formparser import MultiPartParser from quart.wrappers.request import Body @@ -19,3 +20,12 @@ async def test_multipart_max_form_memory_size() -> None: with pytest.raises(RequestEntityTooLarge): await parser.parse(body, b"bound", 0) + + +async def test_formparser_max_num_parts() -> None: + parser = FormDataParser(max_form_parts=1) + body = Body(None, None) + body.set_result(b"param1=data1&param2=data2&param3=data3") + + with pytest.raises(RequestEntityTooLarge): + await parser.parse(body, "application/x-url-encoded", None)