From 6f8c93eec7f134b8210a7578b110c4f246d6429b Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski <marcelotryle@gmail.com>
Date: Sat, 10 Feb 2024 13:47:52 +0100
Subject: [PATCH] Add `TypedDict` callbacks

---
 .gitignore              |  2 +-
 multipart/multipart.py  | 75 ++++++++++++++++++++++++++++-------------
 tests/test_multipart.py | 16 ++++-----
 3 files changed, 60 insertions(+), 33 deletions(-)

diff --git a/.gitignore b/.gitignore
index 546cf0a..f7f7b71 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,7 +28,7 @@ lib64
 pip-log.txt
 
 # Unit test / coverage reports
-.coverage.*
+.coverage*
 .tox
 nosetests.xml
 
diff --git a/multipart/multipart.py b/multipart/multipart.py
index 651bfc1..ac2648e 100644
--- a/multipart/multipart.py
+++ b/multipart/multipart.py
@@ -9,11 +9,38 @@
 from enum import IntEnum
 from io import BytesIO
 from numbers import Number
-from typing import Dict, Tuple, Union
+from typing import TYPE_CHECKING
 
 from .decoders import Base64Decoder, QuotedPrintableDecoder
 from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
 
+if TYPE_CHECKING:  # pragma: no cover
+    from typing import Callable, TypedDict
+
+    class QuerystringCallbacks(TypedDict, total=False):
+        on_field_start: Callable[[], None]
+        on_field_name: Callable[[bytes, int, int], None]
+        on_field_data: Callable[[bytes, int, int], None]
+        on_field_end: Callable[[], None]
+        on_end: Callable[[], None]
+
+    class OctetStreamCallbacks(TypedDict, total=False):
+        on_start: Callable[[], None]
+        on_data: Callable[[bytes, int, int], None]
+        on_end: Callable[[], None]
+
+    class MultipartCallbacks(TypedDict, total=False):
+        on_part_begin: Callable[[], None]
+        on_part_data: Callable[[bytes, int, int], None]
+        on_part_end: Callable[[], None]
+        on_headers_begin: Callable[[], None]
+        on_header_field: Callable[[bytes, int, int], None]
+        on_header_value: Callable[[bytes, int, int], None]
+        on_header_end: Callable[[], None]
+        on_headers_finished: Callable[[], None]
+        on_end: Callable[[], None]
+
+
 # Unique missing object.
 _missing = object()
 
@@ -86,7 +113,7 @@ def join_bytes(b):
     return bytes(list(b))
 
 
-def parse_options_header(value: Union[str, bytes]) -> Tuple[bytes, Dict[bytes, bytes]]:
+def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]:
     """
     Parses a Content-Type header into a value in the following format:
         (content_type, {parameters})
@@ -148,15 +175,15 @@ class Field:
     :param name: the name of the form field
     """
 
-    def __init__(self, name):
+    def __init__(self, name: str):
         self._name = name
-        self._value = []
+        self._value: list[bytes] = []
 
         # We cache the joined version of _value for speed.
         self._cache = _missing
 
     @classmethod
-    def from_value(klass, name, value):
+    def from_value(cls, name: str, value: bytes | None) -> Field:
         """Create an instance of a :class:`Field`, and set the corresponding
         value - either None or an actual value.  This method will also
         finalize the Field itself.
@@ -166,7 +193,7 @@ def from_value(klass, name, value):
                       None
         """
 
-        f = klass(name)
+        f = cls(name)
         if value is None:
             f.set_none()
         else:
@@ -174,14 +201,14 @@ def from_value(klass, name, value):
         f.finalize()
         return f
 
-    def write(self, data):
+    def write(self, data: bytes) -> int:
         """Write some data into the form field.
 
         :param data: a bytestring
         """
         return self.on_data(data)
 
-    def on_data(self, data):
+    def on_data(self, data: bytes) -> int:
         """This method is a callback that will be called whenever data is
         written to the Field.
 
@@ -191,16 +218,16 @@ def on_data(self, data):
         self._cache = _missing
         return len(data)
 
-    def on_end(self):
+    def on_end(self) -> None:
         """This method is called whenever the Field is finalized."""
         if self._cache is _missing:
             self._cache = b"".join(self._value)
 
-    def finalize(self):
+    def finalize(self) -> None:
         """Finalize the form field."""
         self.on_end()
 
-    def close(self):
+    def close(self) -> None:
         """Close the Field object.  This will free any underlying cache."""
         # Free our value array.
         if self._cache is _missing:
@@ -208,7 +235,7 @@ def close(self):
 
         del self._value
 
-    def set_none(self):
+    def set_none(self) -> None:
         """Some fields in a querystring can possibly have a value of None - for
         example, the string "foo&bar=&baz=asdf" will have a field with the
         name "foo" and value None, one with name "bar" and value "", and one
@@ -218,7 +245,7 @@ def set_none(self):
         self._cache = None
 
     @property
-    def field_name(self):
+    def field_name(self) -> str:
         """This property returns the name of the field."""
         return self._name
 
@@ -230,13 +257,13 @@ def value(self):
 
         return self._cache
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, Field):
             return self.field_name == other.field_name and self.value == other.value
         else:
             return NotImplemented
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         if len(self.value) > 97:
             # We get the repr, and then insert three dots before the final
             # quote.
@@ -553,7 +580,7 @@ class BaseParser:
     def __init__(self):
         self.logger = logging.getLogger(__name__)
 
-    def callback(self, name, data=None, start=None, end=None):
+    def callback(self, name: str, data=None, start=None, end=None):
         """This function calls a provided callback with some data.  If the
         callback is not set, will do nothing.
 
@@ -584,7 +611,7 @@ def callback(self, name, data=None, start=None, end=None):
             self.logger.debug("Calling %s with no data", name)
             func()
 
-    def set_callback(self, name, new_func):
+    def set_callback(self, name: str, new_func):
         """Update the function for a callback.  Removes from the callbacks dict
         if new_func is None.
 
@@ -637,7 +664,7 @@ class OctetStreamParser(BaseParser):
                      i.e. unbounded.
     """
 
-    def __init__(self, callbacks={}, max_size=float("inf")):
+    def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")):
         super().__init__()
         self.callbacks = callbacks
         self._started = False
@@ -647,7 +674,7 @@ def __init__(self, callbacks={}, max_size=float("inf")):
         self.max_size = max_size
         self._current_size = 0
 
-    def write(self, data):
+    def write(self, data: bytes):
         """Write some data to the parser, which will perform size verification,
         and then pass the data to the underlying callback.
 
@@ -732,7 +759,9 @@ class QuerystringParser(BaseParser):
                      i.e. unbounded.
     """
 
-    def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")):
+    state: QuerystringState
+
+    def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing=False, max_size=float("inf")):
         super().__init__()
         self.state = QuerystringState.BEFORE_FIELD
         self._found_sep = False
@@ -748,7 +777,7 @@ def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")):
         # Should parsing be strict?
         self.strict_parsing = strict_parsing
 
-    def write(self, data):
+    def write(self, data: bytes):
         """Write some data to the parser, which will perform size verification,
         parse into either a field name or value, and then pass the
         corresponding data to the underlying callback.  If an error is
@@ -780,7 +809,7 @@ def write(self, data):
 
         return l
 
-    def _internal_write(self, data, length):
+    def _internal_write(self, data: bytes, length: int):
         state = self.state
         strict_parsing = self.strict_parsing
         found_sep = self._found_sep
@@ -989,7 +1018,7 @@ class MultipartParser(BaseParser):
                      i.e. unbounded.
     """
 
-    def __init__(self, boundary, callbacks={}, max_size=float("inf")):
+    def __init__(self, boundary, callbacks: MultipartCallbacks = {}, max_size=float("inf")):
         # Initialize parser state.
         super().__init__()
         self.state = MultipartState.START
diff --git a/tests/test_multipart.py b/tests/test_multipart.py
index b9cba86..16db5b3 100644
--- a/tests/test_multipart.py
+++ b/tests/test_multipart.py
@@ -333,9 +333,9 @@ def on_field_end():
             del name_buffer[:]
             del data_buffer[:]
 
-        callbacks = {"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end}
-
-        self.p = QuerystringParser(callbacks)
+        self.p = QuerystringParser(
+            callbacks={"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end}
+        )
 
     def test_simple_querystring(self):
         self.p.write(b"foo=bar")
@@ -464,18 +464,16 @@ def setUp(self):
         self.started = 0
         self.finished = 0
 
-        def on_start():
+        def on_start() -> None:
             self.started += 1
 
-        def on_data(data, start, end):
+        def on_data(data: bytes, start: int, end: int) -> None:
             self.d.append(data[start:end])
 
-        def on_end():
+        def on_end() -> None:
             self.finished += 1
 
-        callbacks = {"on_start": on_start, "on_data": on_data, "on_end": on_end}
-
-        self.p = OctetStreamParser(callbacks)
+        self.p = OctetStreamParser(callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end})
 
     def assert_data(self, data, finalize=True):
         self.assertEqual(b"".join(self.d), data)