From 86ebb1671154baa6746fede415643c7e2d6ce463 Mon Sep 17 00:00:00 2001
From: Marcelo Trylesinski <marcelotryle@gmail.com>
Date: Wed, 15 Apr 2026 23:43:46 +0200
Subject: [PATCH 1/2] Add multipart header limits

---
 python_multipart/multipart.py | 55 ++++++++++++++++++++++++++++++++++-
 tests/test_multipart.py       | 19 ++++++++++++
 2 files changed, 73 insertions(+), 1 deletion(-)

Index: python_multipart-0.0.20/python_multipart/multipart.py
===================================================================
--- python_multipart-0.0.20.orig/python_multipart/multipart.py
+++ python_multipart-0.0.20/python_multipart/multipart.py
@@ -52,6 +52,8 @@ if TYPE_CHECKING:  # pragma: no cover
         UPLOAD_ERROR_ON_BAD_CTE: bool
         MAX_MEMORY_FILE_SIZE: int
         MAX_BODY_SIZE: float
+        MAX_HEADER_COUNT: int
+        MAX_HEADER_SIZE: int
 
     class FileConfig(TypedDict, total=False):
         UPLOAD_DIR: str | bytes | None
@@ -163,6 +165,12 @@ TOKEN_CHARS_SET = frozenset(
     b"!#$%&'*+-.^_`|~")
 # fmt: on
 
+DEFAULT_MAX_HEADER_COUNT = 8
+"""Default maximum number of headers allowed per multipart part."""
+
+DEFAULT_MAX_HEADER_SIZE = 4096 + 128
+"""Default maximum size of a single multipart header line, including syntax overhead."""
+
 
 def parse_options_header(value: str | bytes | None) -> tuple[bytes, dict[bytes, bytes]]:
     """Parses a Content-Type header into a value in the following format: (content_type, {parameters})."""
@@ -974,10 +982,18 @@ class MultipartParser(BaseParser):
         boundary: The multipart boundary.  This is required, and must match what is given in the HTTP request - usually in the Content-Type header.
         callbacks: A dictionary of callbacks.  See the documentation for [`BaseParser`][python_multipart.BaseParser].
         max_size: The maximum size of body to parse.  Defaults to infinity - i.e. unbounded.
+        max_header_count: The maximum number of headers allowed per part.
+        max_header_size: The maximum size of a single header line (excluding the trailing CRLF).
     """  # noqa: E501
 
     def __init__(
-        self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf")
+        self,
+        boundary: bytes | str,
+        callbacks: MultipartCallbacks = {},
+        max_size: float = float("inf"),
+        *,
+        max_header_count: int = DEFAULT_MAX_HEADER_COUNT,
+        max_header_size: int = DEFAULT_MAX_HEADER_SIZE,
     ) -> None:
         # Initialize parser state.
         super().__init__()
@@ -991,6 +1007,12 @@ class MultipartParser(BaseParser):
         self.max_size = max_size
         self._current_size = 0
 
+        self.max_header_count = max_header_count
+        self._current_header_count = 0
+
+        self.max_header_size = max_header_size
+        self._current_header_size = 0
+
         # Setup marks.  These are used to track the state of data received.
         self.marks: dict[str, int] = {}
 
@@ -1044,10 +1066,18 @@ class MultipartParser(BaseParser):
         state = self.state
         index = self.index
         flags = self.flags
+        current_header_count = self._current_header_count
+        current_header_size = self._current_header_size
 
         # Our index defaults to 0.
         i = 0
 
+        def advance_header_size(amount: int = 1) -> None:
+            nonlocal current_header_size
+            current_header_size += amount
+            if current_header_size > self.max_header_size:
+                raise MultipartParseError("Maximum header size exceeded", offset=i)
+
         # Set a mark.
         def set_mark(name: str) -> None:
             self.marks[name] = i
@@ -1152,6 +1182,8 @@ class MultipartParser(BaseParser):
 
                     # Callback for the start of a part.
                     self.callback("part_begin")
+                    current_header_count = 0
+                    current_header_size = 0
 
                     # Move to the next character and state.
                     state = MultipartState.HEADER_FIELD_START
@@ -1173,6 +1205,12 @@ class MultipartParser(BaseParser):
                 # continue parsing our header field.
                 index = 0
 
+                if c != CR:
+                    current_header_count += 1
+                    if current_header_count > self.max_header_count:
+                        raise MultipartParseError("Maximum header count exceeded", offset=i)
+                    current_header_size = 0
+
                 # Set a mark of our header field.
                 set_mark("header_field")
 
@@ -1202,6 +1240,7 @@ class MultipartParser(BaseParser):
 
                 # If we've reached a colon, we're done with this header.
                 if c == COLON:
+                    advance_header_size()
                     # A 0-length header is an error.
                     if index == 1:
                         msg = "Found 0-length header at %d" % (i,)
@@ -1222,10 +1261,13 @@ class MultipartParser(BaseParser):
                     e = MultipartParseError(msg)
                     e.offset = i
                     raise e
+                else:
+                    advance_header_size()
 
             elif state == MultipartState.HEADER_VALUE_START:
                 # Skip leading spaces.
                 if c == SPACE:
+                    advance_header_size()
                     i += 1
                     continue
 
@@ -1242,7 +1284,10 @@ class MultipartParser(BaseParser):
                 if c == CR:
                     data_callback("header_value", i)
                     self.callback("header_end")
+                    current_header_size = 0
                     state = MultipartState.HEADER_VALUE_ALMOST_DONE
+                else:
+                    advance_header_size()
 
             elif state == MultipartState.HEADER_VALUE_ALMOST_DONE:
                 # The last character should be a LF.  If not, it's an error.
@@ -1366,6 +1411,8 @@ class MultipartParser(BaseParser):
                             # a part, and are starting a new one.
                             self.callback("part_end")
                             self.callback("part_begin")
+                            current_header_count = 0
+                            current_header_size = 0
 
                             # Move to parsing new headers.
                             index = 0
@@ -1455,6 +1502,8 @@ class MultipartParser(BaseParser):
         self.state = state
         self.index = index
         self.flags = flags
+        self._current_header_count = current_header_count
+        self._current_header_size = current_header_size
 
         # Return our data length to indicate no errors, and that we processed
         # all of it.
@@ -1513,6 +1562,8 @@ class FormParser:
     #: Note: all file sizes should be in bytes.
     DEFAULT_CONFIG: FormParserConfig = {
         "MAX_BODY_SIZE": float("inf"),
+        "MAX_HEADER_COUNT": DEFAULT_MAX_HEADER_COUNT,
+        "MAX_HEADER_SIZE": DEFAULT_MAX_HEADER_SIZE,
         "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024,
         "UPLOAD_DIR": None,
         "UPLOAD_KEEP_FILENAME": False,
@@ -1748,6 +1799,8 @@ class FormParser:
                     "on_end": _on_end,
                 },
                 max_size=self.config["MAX_BODY_SIZE"],
+                max_header_count=self.config["MAX_HEADER_COUNT"],
+                max_header_size=self.config["MAX_HEADER_SIZE"],
             )
 
         else:
Index: python_multipart-0.0.20/tests/test_multipart.py
===================================================================
--- python_multipart-0.0.20.orig/tests/test_multipart.py
+++ python_multipart-0.0.20/tests/test_multipart.py
@@ -1072,6 +1072,18 @@ class TestFormParser(unittest.TestCase):
         with self.assertRaisesRegex(MultipartParseError, "Expected boundary character %r, got %r" % (b"b"[0], b"B"[0])):
             self.f.write(data)
 
+    def test_multipart_header_count_limit(self) -> None:
+        self.make("poc")
+        payload = b'--poc\r\nContent-Disposition: form-data; name="x"\r\n' + (b"X-A: 1\r\n" * 8)
+        with self.assertRaisesRegex(MultipartParseError, "Maximum header count exceeded"):
+            self.f.write(payload)
+
+    def test_multipart_header_size_limit(self) -> None:
+        self.make("poc")
+        payload = b'--poc\r\nContent-Disposition: form-data; name="x"\r\n' + b"X-A: " + (b"a" * (4096 + 124))
+        with self.assertRaisesRegex(MultipartParseError, "Maximum header size exceeded"):
+            self.f.write(payload)
+
     def test_octet_stream(self) -> None:
         files: list[File] = []
 
Index: python_multipart-0.0.20/python_multipart/exceptions.py
===================================================================
--- python_multipart-0.0.20.orig/python_multipart/exceptions.py
+++ python_multipart-0.0.20/python_multipart/exceptions.py
@@ -9,7 +9,9 @@ class ParseError(FormParserError):
 
     #: This is the offset in the input data chunk (*NOT* the overall stream) in
     #: which the parse error occurred.  It will be -1 if not specified.
-    offset = -1
+    def __init__(self, message, offset=-1):
+        super().__init__(message)
+        self.offset = offset
 
 
 class MultipartParseError(ParseError):
