From bcb6cea4edde1d00ff4f0eece883dea951f66e1b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= <alex.gronholm@nextday.fi>
Date: Sun, 1 Mar 2026 16:49:59 +0200
Subject: [PATCH] Added the max_depth decoder parameter

---
 docs/versionhistory.rst |  2 ++
 source/decoder.c        | 26 +++++++++++++++-----------
 source/decoder.h        |  2 ++
 tests/test_decoder.py   | 15 +++++++++++++++
 4 files changed, 34 insertions(+), 11 deletions(-)

Index: cbor2-5.6.5/source/decoder.c
===================================================================
--- cbor2-5.6.5.orig/source/decoder.c
+++ cbor2-5.6.5/source/decoder.c
@@ -140,6 +140,7 @@ CBORDecoder_new(PyTypeObject *type, PyOb
         Py_INCREF(Py_None);
         self->object_hook = Py_None;
         self->str_errors = PyBytes_FromString("strict");
+        self->max_depth = CBOR2_DEFAULT_MAX_DEPTH;
         self->immutable = false;
         self->shared_index = -1;
         self->decode_depth = 0;
@@ -152,18 +153,18 @@ error:
 
 
 // CBORDecoder.__init__(self, fp=None, tag_hook=None, object_hook=None,
-//                      str_errors='strict')
+//                      str_errors='strict', *, max_depth=400)
 int
 CBORDecoder_init(CBORDecoderObject *self, PyObject *args, PyObject *kwargs)
 {
     static char *keywords[] = {
-        "fp", "tag_hook", "object_hook", "str_errors", NULL
+        "fp", "tag_hook", "object_hook", "str_errors", "max_depth", NULL
     };
     PyObject *fp = NULL, *tag_hook = NULL, *object_hook = NULL,
              *str_errors = NULL;
 
-    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOO", keywords,
-                &fp, &tag_hook, &object_hook, &str_errors))
+    if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|OOO$n", keywords,  // "$": max_depth is keyword-only, as in the signature comment above
+                &fp, &tag_hook, &object_hook, &str_errors, &self->max_depth))
         return -1;
 
     if (_CBORDecoder_set_fp(self, fp, NULL) == -1)
@@ -2033,9 +2034,17 @@ decode(CBORDecoderObject *self, DecodeOp
         self->shared_index = -1;
     }
 
+    if (self->decode_depth == self->max_depth) {  // nesting limit reached: refuse to descend further
+        PyErr_Format(
+            _CBOR2_CBORDecodeError,
+            "maximum container nesting depth (%zd) exceeded", self->max_depth);  // %zd: max_depth is a signed size type, not unsigned int
+        return NULL;
+    }
+
     if (Py_EnterRecursiveCall(" in CBORDecoder.decode"))
         return NULL;
 
+    self->decode_depth++;
     if (fp_read(self, &lead.byte, 1) == 0) {
         switch (lead.major) {
             case 0: ret = decode_uint(self, lead.subtype);       break;
@@ -2051,6 +2060,8 @@ decode(CBORDecoderObject *self, DecodeOp
     }
 
     Py_LeaveRecursiveCall();
+    self->decode_depth--;
+
     if (options & DECODE_IMMUTABLE)
         self->immutable = old_immutable;
     if (options & DECODE_UNSHARED)
@@ -2075,10 +2086,7 @@ PyObject *
 CBORDecoder_decode(CBORDecoderObject *self)
 {
     PyObject *ret;
-    self->decode_depth++;
     ret = decode(self, DECODE_NORMAL);
-    self->decode_depth--;
-    assert(self->decode_depth >= 0);
     if (self->decode_depth == 0) {
         clear_shareable_state(self);
     }
@@ -2095,7 +2103,6 @@ CBORDecoder_decode_from_bytes(CBORDecode
     if (!_CBOR2_BytesIO && _CBOR2_init_BytesIO() == -1)
         return NULL;
 
-    self->decode_depth++;
     save_read = self->read;
     buf = PyObject_CallFunctionObjArgs(_CBOR2_BytesIO, data, NULL);
     if (buf) {
@@ -2107,8 +2114,6 @@ CBORDecoder_decode_from_bytes(CBORDecode
         Py_DECREF(buf);
     }
     self->read = save_read;
-    self->decode_depth--;
-    assert(self->decode_depth >= 0);
     if (self->decode_depth == 0) {
         clear_shareable_state(self);
     }
Index: cbor2-5.6.5/source/decoder.h
===================================================================
--- cbor2-5.6.5.orig/source/decoder.h
+++ cbor2-5.6.5/source/decoder.h
@@ -11,6 +11,7 @@ typedef struct {
     PyObject *shareables;
     PyObject *stringref_namespace;
     PyObject *str_errors;
+    Py_ssize_t max_depth;  // filled via the "n" (Py_ssize_t) converter in CBORDecoder_init; same type as decode_depth
     bool immutable;
     Py_ssize_t shared_index;
     Py_ssize_t decode_depth;
@@ -21,3 +22,5 @@ extern PyTypeObject CBORDecoderType;
 PyObject * CBORDecoder_new(PyTypeObject *, PyObject *, PyObject *);
 int CBORDecoder_init(CBORDecoderObject *, PyObject *, PyObject *);
 PyObject * CBORDecoder_decode(CBORDecoderObject *);
+
+#define CBOR2_DEFAULT_MAX_DEPTH 400
Index: cbor2-5.6.5/tests/test_decoder.py
===================================================================
--- cbor2-5.6.5.orig/tests/test_decoder.py
+++ cbor2-5.6.5/tests/test_decoder.py
@@ -123,6 +123,21 @@ def test_load(impl):
         assert impl.load(fp=stream) == 1
 
 
+class TestMaximumDepth:
+    def test_default(self, impl) -> None:
+        with pytest.raises(
+            impl.CBORDecodeError,
+            match="maximum container nesting depth \\(400\\) exceeded",
+        ):
+            impl.loads(b"\x81" * 401 + b"\x80")
+
+    def test_explicit(self, impl) -> None:
+        with pytest.raises(
+            impl.CBORDecodeError, match=r"maximum container nesting depth \(9\) exceeded"
+        ):
+            impl.loads(b"\x81" * 10 + b"\x80", max_depth=9)
+
+
 @pytest.mark.parametrize(
     "payload, expected",
     [
Index: cbor2-5.6.5/cbor2/_decoder.py
===================================================================
--- cbor2-5.6.5.orig/cbor2/_decoder.py
+++ cbor2-5.6.5/cbor2/_decoder.py
@@ -12,6 +12,7 @@ from typing import IO, TYPE_CHECKING, An
 
 from ._types import (
     CBORDecodeEOF,
+    CBORDecodeError,
     CBORDecodeValueError,
     CBORSimpleValue,
     CBORTag,
@@ -60,6 +61,7 @@ class CBORDecoder:
         "_immutable",
         "_str_errors",
         "_stringref_namespace",
+        "_max_depth",
         "_decode_depth",
     )
 
@@ -72,6 +74,8 @@ class CBORDecoder:
         tag_hook: Callable[[CBORDecoder, CBORTag], Any] | None = None,
         object_hook: Callable[[CBORDecoder, dict[Any, Any]], Any] | None = None,
         str_errors: Literal["strict", "error", "replace"] = "strict",
+        *,
+        max_depth: int = 400,
     ):
         """
         :param fp:
@@ -90,6 +94,8 @@ class CBORDecoder:
         :param str_errors:
             determines how to handle unicode decoding errors (see the `Error Handlers`_
             section in the standard library documentation for details)
+        :param max_depth:
+            the maximum allowed container nesting depth
 
         .. _Error Handlers: https://docs.python.org/3/library/codecs.html#error-handlers
 
@@ -102,6 +108,7 @@ class CBORDecoder:
         self._shareables: list[object] = []
         self._stringref_namespace: list[str | bytes] | None = None
         self._immutable = False
+        self._max_depth = max_depth
         self._decode_depth = 0
 
     @property
@@ -209,13 +216,24 @@ class CBORDecoder:
 
         return data
 
-    def _decode(self, immutable: bool = False, unshared: bool = False) -> Any:
+    def decode(self, immutable: bool = False, unshared: bool = False) -> Any:
+        """
+        Decode the next value from the stream.
+
+        :raises CBORDecodeError: if there is any problem decoding the stream
+
+        """
+        if self._decode_depth > self._max_depth:  # NOTE(review): the C decoder tests decode_depth == max_depth, i.e. one level stricter - confirm which boundary is intended
+            raise CBORDecodeError(f"maximum container nesting depth ({self._max_depth}) exceeded")
+
         if immutable:
             old_immutable = self._immutable
             self._immutable = True
         if unshared:
             old_index = self._share_index
             self._share_index = None
+
+        self._decode_depth += 1
         try:
             initial_byte = self.read(1)[0]
             major_type = initial_byte >> 5
@@ -228,34 +246,12 @@ class CBORDecoder:
             if unshared:
                 self._share_index = old_index
 
-    @contextmanager
-    def _decoding_context(self):
-        """
-        Context manager for tracking decode depth and clearing shared state.
-
-        Shared state is cleared at the end of each top-level decode to prevent
-        shared references from leaking between independent decode operations.
-        Nested calls (from hooks) must preserve the state.
-        """
-        self._decode_depth += 1
-        try:
-            yield
-        finally:
             self._decode_depth -= 1
             assert self._decode_depth >= 0
             if self._decode_depth == 0:
                 self._shareables.clear()
                 self._share_index = None
 
-    def decode(self) -> object:
-        """
-        Decode the next value from the stream.
-
-        :raises CBORDecodeError: if there is any problem decoding the stream
-        """
-        with self._decoding_context():
-            return self._decode()
-
     def decode_from_bytes(self, buf: bytes) -> object:
         """
         Wrap the given bytestring as a file and call :meth:`decode` with it as
@@ -265,13 +261,12 @@ class CBORDecoder:
         object needs to be decoded separately from the rest but while still
         taking advantage of the shared value registry.
         """
-        with self._decoding_context():
-            with BytesIO(buf) as fp:
-                old_fp = self.fp
-                self.fp = fp
-                retval = self._decode()
-                self.fp = old_fp
-                return retval
+        with BytesIO(buf) as fp:
+            old_fp = self.fp
+            self.fp = fp
+            retval = self.decode()
+            self.fp = old_fp
+            return retval
 
     @overload
     def _decode_length(self, subtype: int) -> int: ...
@@ -422,7 +417,7 @@ class CBORDecoder:
             if not self._immutable:
                 self.set_shareable(items)
             while True:
-                value = self._decode()
+                value = self.decode()
                 if value is break_marker:
                     break
                 else:
@@ -436,7 +431,7 @@ class CBORDecoder:
                 self.set_shareable(items)
 
             for index in range(length):
-                items.append(self._decode())
+                items.append(self.decode())
 
         if self._immutable:
             items_tuple = tuple(items)
@@ -453,17 +448,17 @@ class CBORDecoder:
             dictionary: dict[Any, Any] = {}
             self.set_shareable(dictionary)
             while True:
-                key = self._decode(immutable=True, unshared=True)
+                key = self.decode(immutable=True, unshared=True)
                 if key is break_marker:
                     break
                 else:
-                    dictionary[key] = self._decode(unshared=True)
+                    dictionary[key] = self.decode(unshared=True)
         else:
             dictionary = {}
             self.set_shareable(dictionary)
             for _ in range(length):
-                key = self._decode(immutable=True, unshared=True)
-                dictionary[key] = self._decode(unshared=True)
+                key = self.decode(immutable=True, unshared=True)
+                dictionary[key] = self.decode(unshared=True)
 
         if self._object_hook:
             dictionary = self._object_hook(self, dictionary)
@@ -483,7 +478,7 @@ class CBORDecoder:
 
         tag = CBORTag(tagnum, None)
         self.set_shareable(tag)
-        tag.value = self._decode(unshared=True)
+        tag.value = self.decode(unshared=True)
         if self._tag_hook:
             tag = self._tag_hook(self, tag)
 
@@ -508,17 +503,17 @@ class CBORDecoder:
     #
     def decode_epoch_date(self) -> date:
         # Semantic tag 100
-        value = self._decode()
+        value = self.decode()
         return self.set_shareable(date.fromordinal(value + 719163))
 
     def decode_date_string(self) -> date:
         # Semantic tag 1004
-        value = self._decode()
+        value = self.decode()
         return self.set_shareable(date.fromisoformat(value))
 
     def decode_datetime_string(self) -> datetime:
         # Semantic tag 0
-        value = self._decode()
+        value = self.decode()
         match = timestamp_re.match(value)
         if match:
             (
@@ -566,7 +561,7 @@ class CBORDecoder:
 
     def decode_epoch_datetime(self) -> datetime:
         # Semantic tag 1
-        value = self._decode()
+        value = self.decode()
 
         try:
             tmp = datetime.fromtimestamp(value, timezone.utc)
@@ -579,7 +574,7 @@ class CBORDecoder:
         # Semantic tag 2
         from binascii import hexlify
 
-        value = self._decode()
+        value = self.decode()
         if not isinstance(value, bytes):
             raise CBORDecodeValueError("invalid bignum value " + str(value))
 
@@ -594,7 +589,7 @@ class CBORDecoder:
         from decimal import Decimal
 
         try:
-            exp, sig = self._decode()
+            exp, sig = self.decode()
         except (TypeError, ValueError) as e:
             raise CBORDecodeValueError("Incorrect tag 4 payload") from e
         tmp = Decimal(sig).as_tuple()
@@ -605,7 +600,7 @@ class CBORDecoder:
         from decimal import Decimal
 
         try:
-            exp, sig = self._decode()
+            exp, sig = self.decode()
         except (TypeError, ValueError) as e:
             raise CBORDecodeValueError("Incorrect tag 5 payload") from e
 
@@ -616,7 +611,7 @@ class CBORDecoder:
         if self._stringref_namespace is None:
             raise CBORDecodeValueError("string reference outside of namespace")
 
-        index: int = self._decode()
+        index: int = self.decode()
         try:
             value = self._stringref_namespace[index]
         except IndexError:
@@ -630,13 +625,13 @@ class CBORDecoder:
         self._share_index = len(self._shareables)
         self._shareables.append(None)
         try:
-            return self._decode()
+            return self.decode()
         finally:
             self._share_index = old_index
 
     def decode_sharedref(self) -> Any:
         # Semantic tag 29
-        value = self._decode(unshared=True)
+        value = self.decode(unshared=True)
         try:
             shared = self._shareables[value]
         except IndexError:
@@ -651,7 +646,7 @@ class CBORDecoder:
         # Semantic tag 30
         from fractions import Fraction
 
-        inputval = self._decode(immutable=True, unshared=True)
+        inputval = self.decode(immutable=True, unshared=True)
         try:
             value = Fraction(*inputval)
         except (TypeError, ZeroDivisionError) as exc:
@@ -667,7 +662,7 @@ class CBORDecoder:
     def decode_regexp(self) -> re.Pattern[str]:
         # Semantic tag 35
         try:
-            value = re.compile(self._decode())
+            value = re.compile(self.decode())
         except re.error as exc:
             raise CBORDecodeValueError("error decoding regular expression") from exc
 
@@ -678,7 +673,7 @@ class CBORDecoder:
         from email.parser import Parser
 
         try:
-            value = Parser().parsestr(self._decode())
+            value = Parser().parsestr(self.decode())
         except TypeError as exc:
             raise CBORDecodeValueError("error decoding MIME message") from exc
 
@@ -689,7 +684,7 @@ class CBORDecoder:
         from uuid import UUID
 
         try:
-            value = UUID(bytes=self._decode())
+            value = UUID(bytes=self.decode())
         except (TypeError, ValueError) as exc:
             raise CBORDecodeValueError("error decoding UUID value") from exc
 
@@ -699,16 +694,16 @@ class CBORDecoder:
         # Semantic tag 256
         old_namespace = self._stringref_namespace
         self._stringref_namespace = []
-        value = self._decode()
+        value = self.decode()
         self._stringref_namespace = old_namespace
         return value
 
     def decode_set(self) -> set[Any] | frozenset[Any]:
         # Semantic tag 258
         if self._immutable:
-            return self.set_shareable(frozenset(self._decode(immutable=True)))
+            return self.set_shareable(frozenset(self.decode(immutable=True)))
         else:
-            return self.set_shareable(set(self._decode(immutable=True)))
+            return self.set_shareable(set(self.decode(immutable=True)))
 
     def decode_ipaddress(self) -> IPv4Address | IPv6Address | CBORTag:
         # Semantic tag 260
@@ -741,7 +736,7 @@ class CBORDecoder:
 
     def decode_self_describe_cbor(self) -> Any:
         # Semantic tag 55799
-        return self._decode()
+        return self.decode()
 
     #
     # Special decoders (major tag 7)
@@ -813,6 +808,8 @@ def loads(
     tag_hook: Callable[[CBORDecoder, CBORTag], Any] | None = None,
     object_hook: Callable[[CBORDecoder, dict[Any, Any]], Any] | None = None,
     str_errors: Literal["strict", "error", "replace"] = "strict",
+    *,
+    max_depth: int = 400,
 ) -> Any:
     """
     Deserialize an object from a bytestring.
@@ -831,6 +828,8 @@ def loads(
     :param str_errors:
         determines how to handle unicode decoding errors (see the `Error Handlers`_
         section in the standard library documentation for details)
+    :param max_depth:
+        the maximum allowed container nesting depth
     :return:
         the deserialized object
 
@@ -839,7 +838,7 @@ def loads(
     """
     with BytesIO(s) as fp:
         return CBORDecoder(
-            fp, tag_hook=tag_hook, object_hook=object_hook, str_errors=str_errors
+            fp, tag_hook=tag_hook, object_hook=object_hook, str_errors=str_errors, max_depth=max_depth
         ).decode()
 
 
@@ -848,6 +847,8 @@ def load(
     tag_hook: Callable[[CBORDecoder, CBORTag], Any] | None = None,
     object_hook: Callable[[CBORDecoder, dict[Any, Any]], Any] | None = None,
     str_errors: Literal["strict", "error", "replace"] = "strict",
+    *,
+    max_depth: int = 400,
 ) -> Any:
     """
     Deserialize an object from an open file.
@@ -866,6 +867,8 @@ def load(
     :param str_errors:
         determines how to handle unicode decoding errors (see the `Error Handlers`_
         section in the standard library documentation for details)
+    :param max_depth:
+        the maximum allowed container nesting depth
     :return:
         the deserialized object
 
@@ -873,5 +876,5 @@ def load(
 
     """
     return CBORDecoder(
-        fp, tag_hook=tag_hook, object_hook=object_hook, str_errors=str_errors
+        fp, tag_hook=tag_hook, object_hook=object_hook, str_errors=str_errors, max_depth=max_depth
     ).decode()
