From 86ebb1671154baa6746fede415643c7e2d6ce463 Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski <marcelotryle@gmail.com>
Date: Wed, 15 Apr 2026 23:43:46 +0200
Subject: [PATCH 1/2] Add multipart header limits

---
 multipart/multipart.py | 55 ++++++++++++++++++++++++++++++++++-
 tests/test_multipart.py       | 19 ++++++++++++
 2 files changed, 73 insertions(+), 1 deletion(-)

Index: python_multipart-0.0.9/multipart/multipart.py
===================================================================
--- python_multipart-0.0.9.orig/multipart/multipart.py
+++ python_multipart-0.0.9/multipart/multipart.py
@@ -47,6 +47,8 @@ if TYPE_CHECKING:  # pragma: no cover
         UPLOAD_ERROR_ON_BAD_CTE: bool
         MAX_MEMORY_FILE_SIZE: int
         MAX_BODY_SIZE: float
+        MAX_HEADER_COUNT: int
+        MAX_HEADER_SIZE: int
 
     class FileConfig(TypedDict, total=False):
         UPLOAD_DIR: str | None
@@ -127,6 +129,23 @@ def ord_char(c):
 def join_bytes(b):
     return bytes(list(b))
 
+# fmt: off
+# Mask for ASCII characters that can be http tokens.
+# Per RFC7230 - 3.2.6, this is all alpha-numeric characters
+# and these: !#$%&'*+-.^_`|~
+TOKEN_CHARS_SET = frozenset(
+    b"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+    b"abcdefghijklmnopqrstuvwxyz"
+    b"0123456789"
+    b"!#$%&'*+-.^_`|~")
+# fmt: on
+
+DEFAULT_MAX_HEADER_COUNT = 8
+"""Default maximum number of headers allowed per multipart part."""
+
+DEFAULT_MAX_HEADER_SIZE = 4096 + 128
+"""Default maximum size of a single multipart header line, including syntax overhead."""
+
 
 def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]:
     """
@@ -1033,9 +1052,19 @@ class MultipartParser(BaseParser):
 
     :param max_size: The maximum size of body to parse.  Defaults to infinity -
                      i.e. unbounded.
+    :max_header_count: The maximum number of headers allowed per part.
+    :max_header_size: The maximum size of a single header line (excluding the trailing CRLF).
     """
 
-    def __init__(self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size=float("inf")):
+    def __init__(
+        self,
+        boundary: bytes | str,
+        callbacks: MultipartCallbacks = {},
+        max_size: float = float("inf"),
+        *,
+        max_header_count: int = DEFAULT_MAX_HEADER_COUNT,
+        max_header_size: int = DEFAULT_MAX_HEADER_SIZE,
+    ) -> None:
         # Initialize parser state.
         super().__init__()
         self.state = MultipartState.START
@@ -1048,6 +1077,12 @@ class MultipartParser(BaseParser):
         self.max_size = max_size
         self._current_size = 0
 
+        self.max_header_count = max_header_count
+        self._current_header_count = 0
+
+        self.max_header_size = max_header_size
+        self._current_header_size = 0
+
         # Setup marks.  These are used to track the state of data received.
         self.marks = {}
 
@@ -1115,10 +1150,18 @@ class MultipartParser(BaseParser):
         state = self.state
         index = self.index
         flags = self.flags
+        current_header_count = self._current_header_count
+        current_header_size = self._current_header_size
 
         # Our index defaults to 0.
         i = 0
 
+        def advance_header_size(amount: int = 1) -> None:
+            nonlocal current_header_size
+            current_header_size += amount
+            if current_header_size > self.max_header_size:
+                raise MultipartParseError("Maximum header size exceeded", offset=i)
+
         # Set a mark.
         def set_mark(name):
             self.marks[name] = i
@@ -1199,6 +1242,8 @@ class MultipartParser(BaseParser):
 
                     # Callback for the start of a part.
                     self.callback("part_begin")
+                    current_header_count = 0
+                    current_header_size = 0
 
                     # Move to the next character and state.
                     state = MultipartState.HEADER_FIELD_START
@@ -1220,6 +1265,12 @@ class MultipartParser(BaseParser):
                 # continue parsing our header field.
                 index = 0
 
+                if c != CR:
+                    current_header_count += 1
+                    if current_header_count > self.max_header_count:
+                        raise MultipartParseError("Maximum header count exceeded", offset=i)
+                    current_header_size = 0
+
                 # Set a mark of our header field.
                 set_mark("header_field")
 
@@ -1246,6 +1297,7 @@ class MultipartParser(BaseParser):
 
                 # If we've reached a colon, we're done with this header.
                 elif c == COLON:
+                    advance_header_size()
                     # A 0-length header is an error.
                     if index == 1:
                         msg = "Found 0-length header at %d" % (i,)
@@ -1260,20 +1312,17 @@ class MultipartParser(BaseParser):
                     # Move to parsing the header value.
                     state = MultipartState.HEADER_VALUE_START
 
+                elif c not in TOKEN_CHARS_SET:
+                    msg = "Found invalid character %r in header at %d" % (c, i)
+                    self.logger.warning(msg)
+                    raise MultipartParseError(msg, offset=i)
                 else:
-                    # Lower-case this character, and ensure that it is in fact
-                    # a valid letter.  If not, it's an error.
-                    cl = lower_char(c)
-                    if cl < LOWER_A or cl > LOWER_Z:
-                        msg = "Found non-alphanumeric character %r in " "header at %d" % (c, i)
-                        self.logger.warning(msg)
-                        e = MultipartParseError(msg)
-                        e.offset = i
-                        raise e
+                    advance_header_size()
 
             elif state == MultipartState.HEADER_VALUE_START:
                 # Skip leading spaces.
                 if c == SPACE:
+                    advance_header_size()
                     i += 1
                     continue
 
@@ -1290,7 +1339,10 @@ class MultipartParser(BaseParser):
                 if c == CR:
                     data_callback("header_value")
                     self.callback("header_end")
+                    current_header_size = 0
                     state = MultipartState.HEADER_VALUE_ALMOST_DONE
+                else:
+                    advance_header_size()
 
             elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
                 # The last character should be a LF.  If not, it's an error.
@@ -1412,6 +1464,8 @@ class MultipartParser(BaseParser):
                             # a part, and are starting a new one.
                             self.callback("part_end")
                             self.callback("part_begin")
+                            current_header_count = 0
+                            current_header_size = 0
 
                             # Move to parsing new headers.
                             index = 0
@@ -1499,6 +1553,8 @@ class MultipartParser(BaseParser):
         self.state = state
         self.index = index
         self.flags = flags
+        self._current_header_count = current_header_count
+        self._current_header_size = current_header_size
 
         # Return our data length to indicate no errors, and that we processed
         # all of it.
@@ -1576,6 +1632,8 @@ class FormParser:
     #: Note: all file sizes should be in bytes.
     DEFAULT_CONFIG: FormParserConfig = {
         "MAX_BODY_SIZE": float("inf"),
+        "MAX_HEADER_COUNT": DEFAULT_MAX_HEADER_COUNT,
+        "MAX_HEADER_SIZE": DEFAULT_MAX_HEADER_SIZE,
         "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024,
         "UPLOAD_DIR": None,
         "UPLOAD_KEEP_FILENAME": False,
@@ -1797,6 +1855,8 @@ class FormParser:
                     "on_end": on_end,
                 },
                 max_size=self.config["MAX_BODY_SIZE"],
+                max_header_count=self.config["MAX_HEADER_COUNT"],
+                max_header_size=self.config["MAX_HEADER_SIZE"],
             )
 
         else:
Index: python_multipart-0.0.9/tests/test_multipart.py
===================================================================
--- python_multipart-0.0.9.orig/tests/test_multipart.py
+++ python_multipart-0.0.9/tests/test_multipart.py
@@ -1023,6 +1023,18 @@ class TestFormParser(unittest.TestCase):
         with self.assertRaises(MultipartParseError):
             i = self.f.write(data)
 
+    def test_multipart_header_count_limit(self) -> None:
+        self.make("poc")
+        payload = b'--poc\r\nContent-Disposition: form-data; name="x"\r\n' + (b"X-A: 1\r\n" * 8)
+        with self.assertRaisesRegex(MultipartParseError, "Maximum header count exceeded"):
+            self.f.write(payload)
+
+    def test_multipart_header_size_limit(self) -> None:
+        self.make("poc")
+        payload = b'--poc\r\nContent-Disposition: form-data; name="x"\r\n' + b"X-A: " + (b"a" * (4096 + 125))
+        with self.assertRaisesRegex(MultipartParseError, "Maximum header size exceeded"):
+            self.f.write(payload)
+
     def test_octet_stream(self):
         files = []
 
Index: python_multipart-0.0.9/multipart/exceptions.py
===================================================================
--- python_multipart-0.0.9.orig/multipart/exceptions.py
+++ python_multipart-0.0.9/multipart/exceptions.py
@@ -9,7 +9,9 @@ class ParseError(FormParserError):
 
     #: This is the offset in the input data chunk (*NOT* the overall stream) in
     #: which the parse error occurred.  It will be -1 if not specified.
-    offset = -1
+    def __init__(self, message, offset=-1):
+        super().__init__(message)
+        self.offset = offset
 
 
 class MultipartParseError(ParseError):
Index: python_multipart-0.0.9/tests/test_data/http/bad_header_char.http
===================================================================
--- python_multipart-0.0.9.orig/tests/test_data/http/bad_header_char.http
+++ python_multipart-0.0.9/tests/test_data/http/bad_header_char.http
@@ -1,5 +1,5 @@
 ------WebKitFormBoundaryTkr3kCBQlBe1nrhc
-Content-999position: form-data; name="field"
+Content-<<<position: form-data; name="field"
 
 This is a test.
 ------WebKitFormBoundaryTkr3kCBQlBe1nrhc--
\ No newline at end of file
