diff --git a/src/quart/app.py b/src/quart/app.py index 03813e6..6a3e422 100644 --- a/src/quart/app.py +++ b/src/quart/app.py @@ -248,6 +248,8 @@ class Quart(App): "EXPLAIN_TEMPLATE_LOADING": False, "MAX_CONTENT_LENGTH": 16 * 1024 * 1024, # 16 MB Limit "MAX_COOKIE_SIZE": 4093, + "MAX_FORM_MEMORY_SIZE": 500_000, + "MAX_FORM_PARTS": 1_000, "PERMANENT_SESSION_LIFETIME": timedelta(days=31), # Replaces PREFERRED_URL_SCHEME to allow for WebSocket scheme "PREFER_SECURE_URLS": False, @@ -1130,8 +1132,9 @@ async def handle_websocket_exception( def log_exception( self, - exception_info: tuple[type, BaseException, TracebackType] - | tuple[None, None, None], + exception_info: ( + tuple[type, BaseException, TracebackType] | tuple[None, None, None] + ), ) -> None: """Log a exception to the :attr:`logger`. diff --git a/src/quart/formparser.py b/src/quart/formparser.py index e8ea6e1..50e7ee4 100644 --- a/src/quart/formparser.py +++ b/src/quart/formparser.py @@ -43,15 +43,20 @@ class FormDataParser: def __init__( self, - stream_factory: StreamFactory = default_stream_factory, - max_form_memory_size: int | None = None, - max_content_length: int | None = None, + *, cls: type[MultiDict] | None = MultiDict, + max_content_length: int | None = None, + max_form_memory_size: int | None = None, + max_form_parts: int | None = None, silent: bool = True, + stream_factory: StreamFactory = default_stream_factory, ) -> None: - self.stream_factory = stream_factory self.cls = cls + self.max_content_length = max_content_length + self.max_form_memory_size = max_form_memory_size + self.max_form_parts = max_form_parts self.silent = silent + self.stream_factory = stream_factory def get_parse_func( self, mimetype: str, options: dict[str, str] @@ -87,9 +92,12 @@ async def _parse_multipart( options: dict[str, str], ) -> tuple[MultiDict, MultiDict]: parser = MultiPartParser( - self.stream_factory, cls=self.cls, file_storage_cls=self.file_storage_class, + max_content_length=self.max_content_length, + max_form_memory_size=self.max_form_memory_size, + max_form_parts=self.max_form_parts, + stream_factory=self.stream_factory, ) boundary = options.get("boundary", "").encode("ascii") @@ -105,10 +113,14 @@ async def _parse_urlencoded( content_length: int | None, options: dict[str, str], ) -> tuple[MultiDict, MultiDict]: - form = parse_qsl( - (await body).decode(), - keep_blank_values=True, - ) + try: + form = parse_qsl( + (await body).decode(), + keep_blank_values=True, + max_num_fields=self.max_form_parts, + ) + except ValueError: + raise RequestEntityTooLarge() from None return self.cls(form), self.cls() parse_functions: dict[str, ParserFunc] = { @@ -121,17 +133,22 @@ async def _parse_urlencoded( class MultiPartParser: def __init__( self, - stream_factory: StreamFactory = default_stream_factory, - max_form_memory_size: int | None = None, - cls: type[MultiDict] = MultiDict, + *, buffer_size: int = 64 * 1024, + cls: type[MultiDict] = MultiDict, file_storage_cls: type[FileStorage] = FileStorage, + max_content_length: int | None = None, + max_form_memory_size: int | None = None, + max_form_parts: int | None = None, + stream_factory: StreamFactory = default_stream_factory, ) -> None: - self.max_form_memory_size = max_form_memory_size - self.stream_factory = stream_factory - self.cls = cls self.buffer_size = buffer_size + self.cls = cls self.file_storage_cls = file_storage_cls + self.max_content_length = max_content_length + self.max_form_memory_size = max_form_memory_size + self.max_form_parts = max_form_parts + self.stream_factory = stream_factory def fail(self, message: str) -> NoReturn: raise ValueError(message) @@ -172,7 +189,9 @@ async def parse( container: IO[bytes] | list[bytes] _write: Callable[[bytes], Any] - parser = MultipartDecoder(boundary, self.max_form_memory_size) + parser = MultipartDecoder( + boundary, self.max_content_length, max_parts=self.max_form_parts + ) fields = [] files = [] diff --git a/src/quart/wrappers/base.py b/src/quart/wrappers/base.py index 6f458ce..e33caab 100644 --- a/src/quart/wrappers/base.py +++ b/src/quart/wrappers/base.py @@ -8,7 +8,6 @@ from werkzeug.sansio.request import Request as SansIORequest from .. import json -from ..globals import current_app if TYPE_CHECKING: from ..routing import QuartRule # noqa @@ -73,14 +72,6 @@ def __init__( self.http_version = http_version self.scope = scope - @property - def max_content_length(self) -> int | None: - """Read-only view of the ``MAX_CONTENT_LENGTH`` config key.""" - if current_app: - return current_app.config["MAX_CONTENT_LENGTH"] - else: - return None - @property def endpoint(self) -> str | None: """Returns the corresponding endpoint matched for this request. diff --git a/src/quart/wrappers/request.py b/src/quart/wrappers/request.py index cf50b6f..1cf2b2f 100644 --- a/src/quart/wrappers/request.py +++ b/src/quart/wrappers/request.py @@ -141,6 +141,9 @@ class Request(BaseRequestWebsocket): body_class = Body form_data_parser_class = FormDataParser lock_class = asyncio.Lock + _max_content_length: int | None = None + _max_form_memory_size: int | None = None + _max_form_parts: int | None = None def __init__( self, @@ -189,6 +192,48 @@ def __init__( self._parsing_lock = self.lock_class() self._send_push_promise = send_push_promise + @property + def max_content_length(self) -> int | None: + if self._max_content_length is not None: + return self._max_content_length + + if current_app: + return current_app.config["MAX_CONTENT_LENGTH"] + + return None + + @max_content_length.setter + def max_content_length(self, value: int | None) -> None: + self._max_content_length = value + + @property + def max_form_memory_size(self) -> int | None: + if self._max_form_memory_size is not None: + return self._max_form_memory_size + + if current_app: + return current_app.config["MAX_FORM_MEMORY_SIZE"] + + return None + + @max_form_memory_size.setter + def max_form_memory_size(self, value: int | None) -> None: + self._max_form_memory_size = value + + @property + def max_form_parts(self) -> int | None: + if self._max_form_parts is not None: + return self._max_form_parts + + if current_app: + return current_app.config["MAX_FORM_PARTS"] + + return None + + @max_form_parts.setter + def max_form_parts(self, value: int | None) -> None: + self._max_form_parts = value + @property async def stream(self) -> NoReturn: raise NotImplementedError("Use body instead") @@ -284,6 +329,8 @@ async def files(self) -> MultiDict: def make_form_data_parser(self) -> FormDataParser: return self.form_data_parser_class( max_content_length=self.max_content_length, + max_form_memory_size=self.max_form_memory_size, + max_form_parts=self.max_form_parts, cls=self.parameter_storage_class, ) diff --git a/tests/test_formparser.py b/tests/test_formparser.py index c5e85f2..2238ea7 100644 --- a/tests/test_formparser.py +++ b/tests/test_formparser.py @@ -3,6 +3,7 @@ import pytest from werkzeug.exceptions import RequestEntityTooLarge +from quart.formparser import FormDataParser from quart.formparser import MultiPartParser from quart.wrappers.request import Body @@ -19,3 +20,12 @@ async def test_multipart_max_form_memory_size() -> None: with pytest.raises(RequestEntityTooLarge): await parser.parse(body, b"bound", 0) + + +async def test_formparser_max_num_parts() -> None: + parser = FormDataParser(max_form_parts=1) + body = Body(None, None) + body.set_result(b"param1=data1¶m2=data2¶m3=data3") + + with pytest.raises(RequestEntityTooLarge): + await parser.parse(body, "application/x-url-encoded", None)