From f302e7e9d74097401a42bdb240ff30a92e565360 Mon Sep 17 00:00:00 2001
From: Ben Darnell <ben@bendarnell.com>
Date: Sun, 13 Sep 2015 13:15:06 -0400
Subject: [PATCH] Make HTTPHeaders a subclass of MutableMapping ABC instead of
 dict.

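This simplifies the implementation since MutableMapping is designed
for subclassing while dict has many special cases that need to be
overridden. In particular, this change fixes the setdefault()
method.

As carried here, the patch also builds the combined (comma-joined)
value of a header lazily and caches it in _combined_cache, so that
repeated HTTPHeaders.add() calls for the same name scale linearly.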

Fixes #1500.
---
 tornado/httputil.py           | 55 ++++++++++-------------------------
 tornado/test/httputil_test.py | 10 +++++++
 2 files changed, 26 insertions(+), 39 deletions(-)

Index: tornado-4.2.1/tornado/httputil.py
===================================================================
--- tornado-4.2.1.orig/tornado/httputil.py
+++ tornado-4.2.1/tornado/httputil.py
@@ -96,7 +96,7 @@ class _NormalizedHeaderCache(dict):
 _normalized_headers = _NormalizedHeaderCache(1000)
 
 
-class HTTPHeaders(dict):
+class HTTPHeaders(collections.MutableMapping):
     """A dictionary that maintains ``Http-Header-Case`` for all keys.
 
     Supports multiple values per key via a pair of new methods,
@@ -125,10 +125,14 @@ class HTTPHeaders(dict):
     Set-Cookie: C=D
     """
     def __init__(self, *args, **kwargs):
-        # Don't pass args or kwargs to dict.__init__, as it will bypass
-        # our __setitem__
-        dict.__init__(self)
+        # Formally, HTTP headers are a mapping from a field name to a "combined field value",
+        # which may be constructed from multiple field lines by joining them with commas.
+        # In practice, however, some headers (notably Set-Cookie) do not follow this convention,
+        # so we maintain a mapping from field name to a list of field lines in self._as_list.
+        # self._combined_cache is a cache of the combined field values derived from self._as_list
+        # on demand (and cleared whenever the list is modified).
         self._as_list = {}
+        self._combined_cache = {}
         self._last_key = None
         if (len(args) == 1 and len(kwargs) == 0 and
                 isinstance(args[0], HTTPHeaders)):
@@ -146,10 +150,7 @@ class HTTPHeaders(dict):
         norm_name = _normalized_headers[name]
         self._last_key = norm_name
         if norm_name in self:
-            # bypass our override of __setitem__ since it modifies _as_list
-            dict.__setitem__(self, norm_name,
-                             native_str(self[norm_name]) + ',' +
-                             native_str(value))
+            self._combined_cache.pop(norm_name, None)
             self._as_list[norm_name].append(value)
         else:
             self[norm_name] = value
@@ -181,8 +182,7 @@ class HTTPHeaders(dict):
             # continuation of a multi-line header
             new_part = ' ' + line.lstrip()
             self._as_list[self._last_key][-1] += new_part
-            dict.__setitem__(self, self._last_key,
-                             self[self._last_key] + new_part)
+            self._combined_cache.pop(self._last_key, None)
         else:
             name, value = line.split(":", 1)
             self.add(name, value.strip())
@@ -201,32 +201,36 @@ class HTTPHeaders(dict):
                 h.parse_line(line)
         return h
 
-    # dict implementation overrides
+    # MutableMapping abstract method implementations.
 
     def __setitem__(self, name, value):
         norm_name = _normalized_headers[name]
-        dict.__setitem__(self, norm_name, value)
+        self._combined_cache[norm_name] = value
         self._as_list[norm_name] = [value]
 
+    def __contains__(self, name):
+        # Check _as_list directly to avoid the expensive concatenation in
+        # __getitem__; normalize first so lookups stay case-insensitive.
+        if not isinstance(name, str):
+            return False
+        return _normalized_headers[name] in self._as_list
+
     def __getitem__(self, name):
-        return dict.__getitem__(self, _normalized_headers[name])
+        header = _normalized_headers[name]
+        if header not in self._combined_cache:
+            self._combined_cache[header] = ",".join(self._as_list[header])
+        return self._combined_cache[header]
 
     def __delitem__(self, name):
         norm_name = _normalized_headers[name]
-        dict.__delitem__(self, norm_name)
+        self._combined_cache.pop(norm_name, None)  # may be absent after add()
         del self._as_list[norm_name]
 
-    def __contains__(self, name):
-        norm_name = _normalized_headers[name]
-        return dict.__contains__(self, norm_name)
-
-    def get(self, name, default=None):
-        return dict.get(self, _normalized_headers[name], default)
+    def __len__(self):
+        return len(self._as_list)
 
-    def update(self, *args, **kwargs):
-        # dict.update bypasses our __setitem__
-        for k, v in dict(*args, **kwargs).items():
-            self[k] = v
+    def __iter__(self):
+        return iter(self._as_list)
 
     def copy(self):
         # default implementation returns dict(self), not the subclass
Index: tornado-4.2.1/tornado/test/httputil_test.py
===================================================================
--- tornado-4.2.1.orig/tornado/test/httputil_test.py
+++ tornado-4.2.1/tornado/test/httputil_test.py
@@ -301,7 +301,31 @@ Foo: even
             self.assertIsNot(headers, h1)
             self.assertIsNot(headers.get_list('A'), h1.get_list('A'))
 
+    def test_linear_performance(self):
+        def f(n):
+            start = time.time()
+            headers = HTTPHeaders()
+            for i in range(n):
+                headers.add("X-Foo", "bar")
+            return time.time() - start
 
+        # This runs under 50ms on my laptop as of 2025-12-09.
+        d1 = f(10000)
+        d2 = f(100000)
+        if d2 / d1 > 20:
+            # d2 should be about 10x d1 but allow a wide margin for variability.
+            self.fail("HTTPHeaders.add() does not scale linearly: %s vs %s" % (d1, d2))
+
+    def test_setdefault(self):
+        headers = HTTPHeaders()
+        headers['foo'] = 'bar'
+        # If a value is present, setdefault returns it without changes.
+        self.assertEqual(headers.setdefault('foo', 'baz'), 'bar')
+        self.assertEqual(headers['foo'], 'bar')
+        # If a value is not present, setdefault sets it for future use.
+        self.assertEqual(headers.setdefault('quux', 'xyzzy'), 'xyzzy')
+        self.assertEqual(headers['quux'], 'xyzzy')
+        self.assertEqual(sorted(headers.get_all()), [('Foo', 'bar'), ('Quux', 'xyzzy')])
 
 class FormatTimestampTest(unittest.TestCase):
     # Make sure that all the input types are supported.
