From abb04a512496206de279225340ed022852fbf51f Mon Sep 17 00:00:00 2001
From: pgjones <philip.graham.jones@googlemail.com>
Date: Mon, 23 Dec 2024 13:45:22 +0000
Subject: [PATCH] Support max_form_parts and max_form_memory_size

These allow greater control over safer form parsing with the former
limiting the number of parts and the latter limiting any individual
(data) parts maximum size in bytes. The default values are taken from
Flask.
---
 src/quart/app.py              |  7 +++--
 src/quart/formparser.py       | 51 ++++++++++++++++++++++++-----------
 src/quart/wrappers/base.py    |  9 -------
 src/quart/wrappers/request.py | 47 ++++++++++++++++++++++++++++++++
 tests/test_formparser.py      | 10 +++++++
 5 files changed, 97 insertions(+), 27 deletions(-)

diff --git a/src/quart/app.py b/src/quart/app.py
index 03813e6..6a3e422 100644
--- a/src/quart/app.py
+++ b/src/quart/app.py
@@ -248,6 +248,8 @@ class Quart(App):
             "EXPLAIN_TEMPLATE_LOADING": False,
             "MAX_CONTENT_LENGTH": 16 * 1024 * 1024,  # 16 MB Limit
             "MAX_COOKIE_SIZE": 4093,
+            "MAX_FORM_MEMORY_SIZE": 500_000,
+            "MAX_FORM_PARTS": 1_000,
             "PERMANENT_SESSION_LIFETIME": timedelta(days=31),
             # Replaces PREFERRED_URL_SCHEME to allow for WebSocket scheme
             "PREFER_SECURE_URLS": False,
@@ -1130,8 +1132,9 @@ async def handle_websocket_exception(
 
     def log_exception(
         self,
-        exception_info: tuple[type, BaseException, TracebackType]
-        | tuple[None, None, None],
+        exception_info: (
+            tuple[type, BaseException, TracebackType] | tuple[None, None, None]
+        ),
     ) -> None:
         """Log a exception to the :attr:`logger`.
 
diff --git a/src/quart/formparser.py b/src/quart/formparser.py
index e8ea6e1..50e7ee4 100644
--- a/src/quart/formparser.py
+++ b/src/quart/formparser.py
@@ -43,15 +43,20 @@ class FormDataParser:
 
     def __init__(
         self,
-        stream_factory: StreamFactory = default_stream_factory,
-        max_form_memory_size: int | None = None,
-        max_content_length: int | None = None,
+        *,
         cls: type[MultiDict] | None = MultiDict,
+        max_content_length: int | None = None,
+        max_form_memory_size: int | None = None,
+        max_form_parts: int | None = None,
         silent: bool = True,
+        stream_factory: StreamFactory = default_stream_factory,
     ) -> None:
-        self.stream_factory = stream_factory
         self.cls = cls
+        self.max_content_length = max_content_length
+        self.max_form_memory_size = max_form_memory_size
+        self.max_form_parts = max_form_parts
         self.silent = silent
+        self.stream_factory = stream_factory
 
     def get_parse_func(
         self, mimetype: str, options: dict[str, str]
@@ -87,9 +92,12 @@ async def _parse_multipart(
         options: dict[str, str],
     ) -> tuple[MultiDict, MultiDict]:
         parser = MultiPartParser(
-            self.stream_factory,
             cls=self.cls,
             file_storage_cls=self.file_storage_class,
+            max_content_length=self.max_content_length,
+            max_form_memory_size=self.max_form_memory_size,
+            max_form_parts=self.max_form_parts,
+            stream_factory=self.stream_factory,
         )
         boundary = options.get("boundary", "").encode("ascii")
 
@@ -105,10 +113,14 @@ async def _parse_urlencoded(
         content_length: int | None,
         options: dict[str, str],
     ) -> tuple[MultiDict, MultiDict]:
-        form = parse_qsl(
-            (await body).decode(),
-            keep_blank_values=True,
-        )
+        try:
+            form = parse_qsl(
+                (await body).decode(),
+                keep_blank_values=True,
+                max_num_fields=self.max_form_parts,
+            )
+        except ValueError:
+            raise RequestEntityTooLarge() from None
         return self.cls(form), self.cls()
 
     parse_functions: dict[str, ParserFunc] = {
@@ -121,17 +133,22 @@ async def _parse_urlencoded(
 class MultiPartParser:
     def __init__(
         self,
-        stream_factory: StreamFactory = default_stream_factory,
-        max_form_memory_size: int | None = None,
-        cls: type[MultiDict] = MultiDict,
+        *,
         buffer_size: int = 64 * 1024,
+        cls: type[MultiDict] = MultiDict,
         file_storage_cls: type[FileStorage] = FileStorage,
+        max_content_length: int | None = None,
+        max_form_memory_size: int | None = None,
+        max_form_parts: int | None = None,
+        stream_factory: StreamFactory = default_stream_factory,
     ) -> None:
-        self.max_form_memory_size = max_form_memory_size
-        self.stream_factory = stream_factory
-        self.cls = cls
         self.buffer_size = buffer_size
+        self.cls = cls
         self.file_storage_cls = file_storage_cls
+        self.max_content_length = max_content_length
+        self.max_form_memory_size = max_form_memory_size
+        self.max_form_parts = max_form_parts
+        self.stream_factory = stream_factory
 
     def fail(self, message: str) -> NoReturn:
         raise ValueError(message)
@@ -172,7 +189,9 @@ async def parse(
         container: IO[bytes] | list[bytes]
         _write: Callable[[bytes], Any]
 
-        parser = MultipartDecoder(boundary, self.max_form_memory_size)
+        parser = MultipartDecoder(
+            boundary, self.max_content_length, max_parts=self.max_form_parts
+        )
 
         fields = []
         files = []
diff --git a/src/quart/wrappers/base.py b/src/quart/wrappers/base.py
index 6f458ce..e33caab 100644
--- a/src/quart/wrappers/base.py
+++ b/src/quart/wrappers/base.py
@@ -8,7 +8,6 @@
 from werkzeug.sansio.request import Request as SansIORequest
 
 from .. import json
-from ..globals import current_app
 
 if TYPE_CHECKING:
     from ..routing import QuartRule  # noqa
@@ -73,14 +72,6 @@ def __init__(
         self.http_version = http_version
         self.scope = scope
 
-    @property
-    def max_content_length(self) -> int | None:
-        """Read-only view of the ``MAX_CONTENT_LENGTH`` config key."""
-        if current_app:
-            return current_app.config["MAX_CONTENT_LENGTH"]
-        else:
-            return None
-
     @property
     def endpoint(self) -> str | None:
         """Returns the corresponding endpoint matched for this request.
diff --git a/src/quart/wrappers/request.py b/src/quart/wrappers/request.py
index cf50b6f..1cf2b2f 100644
--- a/src/quart/wrappers/request.py
+++ b/src/quart/wrappers/request.py
@@ -141,6 +141,9 @@ class Request(BaseRequestWebsocket):
     body_class = Body
     form_data_parser_class = FormDataParser
     lock_class = asyncio.Lock
+    _max_content_length: int | None = None
+    _max_form_memory_size: int | None = None
+    _max_form_parts: int | None = None
 
     def __init__(
         self,
@@ -189,6 +192,48 @@ def __init__(
         self._parsing_lock = self.lock_class()
         self._send_push_promise = send_push_promise
 
+    @property
+    def max_content_length(self) -> int | None:
+        if self._max_content_length is not None:
+            return self._max_content_length
+
+        if current_app:
+            return current_app.config["MAX_CONTENT_LENGTH"]
+
+        return None
+
+    @max_content_length.setter
+    def max_content_length(self, value: int | None) -> None:
+        self._max_content_length = value
+
+    @property
+    def max_form_memory_size(self) -> int | None:
+        if self._max_form_memory_size is not None:
+            return self._max_form_memory_size
+
+        if current_app:
+            return current_app.config["MAX_FORM_MEMORY_SIZE"]
+
+        return None
+
+    @max_form_memory_size.setter
+    def max_form_memory_size(self, value: int | None) -> None:
+        self._max_form_memory_size = value
+
+    @property
+    def max_form_parts(self) -> int | None:
+        if self._max_form_parts is not None:
+            return self._max_form_parts
+
+        if current_app:
+            return current_app.config["MAX_FORM_PARTS"]
+
+        return None
+
+    @max_form_parts.setter
+    def max_form_parts(self, value: int | None) -> None:
+        self._max_form_parts = value
+
     @property
     async def stream(self) -> NoReturn:
         raise NotImplementedError("Use body instead")
@@ -284,6 +329,8 @@ async def files(self) -> MultiDict:
     def make_form_data_parser(self) -> FormDataParser:
         return self.form_data_parser_class(
             max_content_length=self.max_content_length,
+            max_form_memory_size=self.max_form_memory_size,
+            max_form_parts=self.max_form_parts,
             cls=self.parameter_storage_class,
         )
 
diff --git a/tests/test_formparser.py b/tests/test_formparser.py
index c5e85f2..2238ea7 100644
--- a/tests/test_formparser.py
+++ b/tests/test_formparser.py
@@ -3,6 +3,7 @@
 import pytest
 from werkzeug.exceptions import RequestEntityTooLarge
 
+from quart.formparser import FormDataParser
 from quart.formparser import MultiPartParser
 from quart.wrappers.request import Body
 
@@ -19,3 +20,12 @@ async def test_multipart_max_form_memory_size() -> None:
 
     with pytest.raises(RequestEntityTooLarge):
         await parser.parse(body, b"bound", 0)
+
+
+async def test_formparser_max_num_parts() -> None:
+    parser = FormDataParser(max_form_parts=1)
+    body = Body(None, None)
+    body.set_result(b"param1=data1&param2=data2&param3=data3")
+
+    with pytest.raises(RequestEntityTooLarge):
+        await parser.parse(body, "application/x-url-encoded", None)