From be71ecaa113f642f03083bd5ed33af47c59308c8 Mon Sep 17 00:00:00 2001
From: tomasilluminati <tomas.illuminati@owasp.org>
Date: Sun, 19 Apr 2026 05:57:33 -0300
Subject: [PATCH] (fix): denial of service in twisted.names mitigation

---
 src/twisted/names/dns.py           | 116 +++++++++++++++++++++++++----
 src/twisted/names/test/test_dns.py |  91 ++++++++++++++++++++++
 2 files changed, 193 insertions(+), 14 deletions(-)

Index: Twisted-22.2.0/src/twisted/names/dns.py
===================================================================
--- Twisted-22.2.0.orig/src/twisted/names/dns.py
+++ Twisted-22.2.0/src/twisted/names/dns.py
@@ -10,6 +10,7 @@ Future Plans:
 """
 
 # System imports
+import contextvars
 import inspect
 import random
 import socket
@@ -125,6 +126,7 @@ __all__ = [
     "OP_UPDATE",
     "PORT",
     "AuthoritativeDomainError",
+    "DNSDecodeError",
     "DNSQueryTimeoutError",
     "DomainError",
 ]
@@ -425,6 +427,62 @@ def readPrecisely(file, l):
     return buff
 
 
+# Cap the total number of compression-pointer dereferences performed while
+# decoding a single DNS message.  A hostile peer can otherwise craft a packet
+# in which every record name chases a long compression chain, forcing O(N*M)
+# work and stalling the reactor.
+MAX_COMPRESSION_POINTERS_PER_MESSAGE = 1000
+
+
+class DNSDecodeError(ValueError):
+    """
+    Raised when a DNS message cannot be decoded because it violates a
+    protocol-level safety limit
+    """
+
+
+class _DecodeContext:
+    """
+    Mutable state shared between the L{IEncodable} decoders invoked while
+    reading a single DNS message.
+
+    The primary purpose is to bound the total number of compression-pointer
+    jumps taken across every name in the message, defending against packets
+    that fan out thousands of records pointing to deeply chained pointers.
+
+    @ivar jumps: The number of compression pointers followed so far.
+    @ivar maxJumps: The inclusive upper bound on C{jumps}.  Exceeding it
+        causes L{registerJump} to raise L{DNSDecodeError}.
+    """
+
+    __slots__ = ("jumps", "maxJumps")
+
+    def __init__(self, maxJumps: int = MAX_COMPRESSION_POINTERS_PER_MESSAGE) -> None:
+        self.jumps = 0
+        self.maxJumps = maxJumps
+
+    def registerJump(self) -> None:
+        """
+        Record that a compression pointer has been followed
+
+        @raise DNSDecodeError: if the cumulative number of jumps exceeds
+            L{maxJumps}
+        """
+        self.jumps += 1
+        if self.jumps > self.maxJumps:
+            raise DNSDecodeError(
+                "Too many compression pointers while decoding DNS message "
+                f"(limit is {self.maxJumps})"
+            )
+
+
+# Tracks state across nested calls without altering every record's signature.
+# L{Message.decode} manages the lifecycle per-message, while standalone decoders
+# default to a local context when C{_decodeContextVar} is C{None}
+
+_decodeContextVar = contextvars.ContextVar("_dnsDecodeContext", default=None)
+
+
 class IEncodable(Interface):
     """
     Interface for something which can be encoded to and decoded
@@ -572,7 +630,7 @@ class Name:
             strio.write(label)
         strio.write(b"\x00")
 
-    def decode(self, strio, length=None):
+    def decode(self, strio, length=None, context=None):
         """
         Decode a byte string into this Name.
 
@@ -580,12 +638,27 @@ class Name:
         @param strio: Bytes will be read from this file until the full Name
         is decoded.
 
+        @type context: L{_DecodeContext} or L{None}
+        @param context: Shared decoding state used to cap the total number
+            of compression-pointer jumps taken while decoding the enclosing
+            DNS message.  When L{None}, the context installed by
+            L{Message.decode} is used if one is active; otherwise a fresh,
+            call-local context is created so that direct callers remain
+            protected and backwards compatible.
+
         @raise EOFError: Raised when there are not enough bytes available
         from C{strio}.
 
-        @raise ValueError: Raised when the name cannot be decoded (for example,
-            because it contains a loop).
+        @raise ValueError: Raised when the name cannot be decoded because it
+            contains a compression loop.
+
+        @raise DNSDecodeError: Raised when the cumulative number of
+            compression-pointer jumps exceeds the configured limit.
         """
+        if context is None:
+            context = _decodeContextVar.get()
+            if context is None:
+                context = _DecodeContext()
         visited = set()
         self.name = b""
         off = 0
@@ -597,6 +670,7 @@ class Name:
                 return
             if (l >> 6) == 3:
                 new_off = (l & 63) << 8 | ord(readPrecisely(strio, 1))
+                context.registerJump()
                 if new_off in visited:
                     raise ValueError("Compression loop in encoded name")
                 visited.add(new_off)
@@ -2670,19 +2744,31 @@ class Message(tputil.FancyEqMixin):
         self.checkingDisabled = (byte4 >> 4) & 1
         self.rCode = byte4 & 0xF
 
-        self.queries = []
-        for i in range(nqueries):
-            q = Query()
-            try:
-                q.decode(strio)
-            except EOFError:
-                return
-            self.queries.append(q)
+        # A single shared counter bounds the total compression-pointer work
+        # performed across every name in this message.  It is installed on
+        # the context variable so nested record decoders pick it up without
+        # needing to thread it through each signature.
+        token = _decodeContextVar.set(_DecodeContext())
+        try:
+            self.queries = []
+            for i in range(nqueries):
+                q = Query()
+                try:
+                    q.decode(strio)
+                except EOFError:
+                    return
+                self.queries.append(q)
 
-        items = ((self.answers, nans), (self.authority, nns), (self.additional, nadd))
+            items = (
+                (self.answers, nans),
+                (self.authority, nns),
+                (self.additional, nadd),
+            )
 
-        for (l, n) in items:
-            self.parseRecords(l, n, strio)
+            for l, n in items:
+                self.parseRecords(l, n, strio)
+        finally:
+            _decodeContextVar.reset(token)
 
     def parseRecords(self, list, num, strio):
         for i in range(num):
Index: Twisted-22.2.0/src/twisted/names/test/test_dns.py
===================================================================
--- Twisted-22.2.0.orig/src/twisted/names/test/test_dns.py
+++ Twisted-22.2.0/src/twisted/names/test/test_dns.py
@@ -347,6 +347,60 @@ class NameTests(unittest.TestCase):
         stream = BytesIO(b"\xc0\x00")
         self.assertRaises(ValueError, name.decode, stream)
 
+    def test_rejectTooManyCompressionPointers(self):
+        """
+        L{Name.decode} raises L{dns.DNSDecodeError} when the number of
+        compression-pointer dereferences taken for a single message exceeds
+        the limit carried by the shared L{dns._DecodeContext}.
+        """
+        # Five distinct pointers chained end-to-end, terminated by a zero
+        # label byte.  With a maxJumps of three the fourth dereference must
+        # trip the safety limit.
+        payload = b"\xc0\x02\xc0\x04\xc0\x06\xc0\x08\x00"
+        context = dns._DecodeContext(maxJumps=3)
+        self.assertRaises(
+            dns.DNSDecodeError,
+            dns.Name().decode,
+            BytesIO(payload),
+            None,
+            context,
+        )
+
+    def test_compressionPointerCounterIsShared(self):
+        """
+        The L{dns._DecodeContext} counter accumulates across successive
+        L{Name.decode} calls, so that a message whose individual names are
+        each within bounds is still rejected when their aggregate exceeds
+        the configured limit.
+        """
+        payload = b"\xc0\x02\xc0\x04\x00"
+        context = dns._DecodeContext(maxJumps=3)
+
+        stream = BytesIO(payload)
+        dns.Name().decode(stream, context=context)
+        self.assertEqual(context.jumps, 2)
+
+        stream.seek(0)
+        self.assertRaises(
+            dns.DNSDecodeError,
+            dns.Name().decode,
+            stream,
+            None,
+            context,
+        )
+
+    def test_decodeWithoutContextIsBackwardsCompatible(self):
+        """
+        L{Name.decode} continues to work when called without a context,
+        using a fresh per-call counter so existing callers are unaffected.
+        """
+        name = dns.Name()
+        stream = BytesIO()
+        dns.Name(b"example.org").encode(stream)
+        stream.seek(0)
+        name.decode(stream)
+        self.assertEqual(name.name, b"example.org")
+
     def test_equality(self):
         """
         L{Name} instances are equal as long as they have the same value for
@@ -756,6 +810,43 @@ class MessageTests(unittest.SynchronousT
         """
         self.assertEqual(dns.Message().authenticData, 0)
 
+    def test_rejectCompressionPointerFlood(self):
+        """
+        L{Message.decode} installs a shared compression-pointer counter and
+        raises L{dns.DNSDecodeError} when the aggregate number of pointer
+        dereferences across every record in the message exceeds
+        L{dns.MAX_COMPRESSION_POINTERS_PER_MESSAGE}.
+        """
+        chainLength = 100
+        numRecords = 8000
+        header = struct.pack(
+            "!H2B4H", 0x1234, 0x80, 0x00, 0, numRecords, 0, 0
+        )
+
+        # Long compression chain inside the RDATA of an unknown
+        # record so that subsequent records can aim pointers at it.
+        owner = b"\x04rrrr\x00"
+        chainBase = len(header) + len(owner) + 10
+        chain = bytearray()
+        for i in range(chainLength):
+            chain += struct.pack("!H", 0xC000 | (chainBase + 2 * (i + 1)))
+        chain += b"\x04test\x00"
+
+        firstRecord = (
+            owner
+            + struct.pack("!HHIH", 999, 1, 0, len(chain))
+            + bytes(chain)
+        )
+        followupRecord = (
+            struct.pack("!H", 0xC000 | chainBase)
+            + struct.pack("!HHIH", 1, 1, 0, 4)
+            + b"\x00\x00\x00\x00"
+        )
+        payload = header + firstRecord + followupRecord * (numRecords - 1)
+
+        message = dns.Message()
+        self.assertRaises(dns.DNSDecodeError, message.decode, BytesIO(payload))
+
     def test_authenticDataOverride(self):
         """
         L{dns.Message.__init__} accepts a C{authenticData} argument which
