From 3729ca5b848f222e11ef0f1d351e31678369076d Mon Sep 17 00:00:00 2001
From: Protobuf Team Bot <protobuf-github-bot@google.com>
Date: Tue, 13 May 2025 14:42:18 -0700
Subject: [PATCH 2/3] Add recursion depth limits to pure python

PiperOrigin-RevId: 758382549
---
 python/google/protobuf/internal/decoder.py    | 35 ++++++++-
 .../google/protobuf/internal/decoder_test.py  | 71 +++++++++++++++++++
 .../google/protobuf/internal/message_test.py  | 58 ++++++++++++++-
 .../protobuf/internal/self_recursive.proto    | 24 +++++++
 4 files changed, 184 insertions(+), 4 deletions(-)
 create mode 100644 python/google/protobuf/internal/decoder_test.py
 create mode 100644 python/google/protobuf/internal/self_recursive.proto

diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index c9c09cd6c..e9b569f07 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -688,7 +688,13 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
         if value is None:
           value = field_dict.setdefault(key, new_default(message))
         # Read sub-message.
+        current_depth += 1
+        if current_depth > _recursion_limit:
+          raise _DecodeError(
+              'Error parsing message: too many levels of nesting.'
+          )
         pos = value.add()._InternalParse(buffer, pos, end, current_depth)
+        current_depth -= 1
         # Read end tag.
         new_pos = pos+end_tag_len
         if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -707,7 +713,11 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
       # Read sub-message.
+      current_depth += 1
+      if current_depth > _recursion_limit:
+        raise _DecodeError('Error parsing message: too many levels of nesting.')
       pos = value._InternalParse(buffer, pos, end, current_depth)
+      current_depth -= 1
       # Read end tag.
       new_pos = pos+end_tag_len
       if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -740,6 +750,11 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
         if new_pos > end:
           raise _DecodeError('Truncated message.')
         # Read sub-message.
+        current_depth += 1
+        if current_depth > _recursion_limit:
+          raise _DecodeError(
+              'Error parsing message: too many levels of nesting.'
+          )
         if (
             value.add()._InternalParse(buffer, pos, new_pos, current_depth)
             != new_pos
@@ -747,6 +762,7 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
           # The only reason _InternalParse would return early is if it
           # encountered an end-group tag.
           raise _DecodeError('Unexpected end-group tag.')
+        current_depth -= 1
         # Predict that the next tag is another copy of the same repeated field.
         pos = new_pos + tag_len
         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -766,10 +782,14 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
       if new_pos > end:
         raise _DecodeError('Truncated message.')
       # Read sub-message.
+      current_depth += 1
+      if current_depth > _recursion_limit:
+        raise _DecodeError('Error parsing message: too many levels of nesting.')
       if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
         # The only reason _InternalParse would return early is if it encountered
         # an end-group tag.
         raise _DecodeError('Unexpected end-group tag.')
+      current_depth -= 1
       return new_pos
 
     return DecodeField
@@ -962,6 +982,26 @@ def _SkipGroup(buffer, pos, end):
     pos = new_pos
 
 
+# Default maximum nesting depth accepted when parsing groups and sub-messages.
+DEFAULT_RECURSION_LIMIT = 100
+# Limit currently in effect; module-wide state changed via SetRecursionLimit().
+_recursion_limit = DEFAULT_RECURSION_LIMIT
+
+
+def SetRecursionLimit(new_limit):
+  """Sets the nesting depth limit enforced by the decoders in this module.
+
+  The setting is module-wide: it replaces the value that the decoders compare
+  their current nesting depth against before descending into a sub-message.
+
+  Args:
+    new_limit: The new maximum nesting depth. Pass DEFAULT_RECURSION_LIMIT to
+      restore the default.
+  """
+  global _recursion_limit
+  _recursion_limit = new_limit
+
+
 def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
   """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position."""
 
@@ -995,7 +1024,11 @@ def _DecodeUnknownField(
     data = buffer[pos:pos+size]
     pos += size
   elif wire_type == wire_format.WIRETYPE_START_GROUP:
-    (data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
+    current_depth += 1
+    if current_depth >= _recursion_limit:
+      raise _DecodeError('Error parsing message: too many levels of nesting.')
+    data, pos = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
+    current_depth -= 1
   elif wire_type == wire_format.WIRETYPE_END_GROUP:
     return (0, -1)
   else:
diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py
new file mode 100644
index 000000000..11e6465b6
--- /dev/null
+++ b/python/google/protobuf/internal/decoder_test.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc.  All rights reserved.
+#
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""Test decoder."""
+
+import io
+import unittest
+
+from google.protobuf import message
+from google.protobuf.internal import decoder
+from google.protobuf.internal import testing_refleaks
+from google.protobuf.internal import wire_format
+
+
+# Two back-to-back varints: b'\x84r' encodes 14596 and b'\x12' encodes 18.
+_INPUT_BYTES = b'\x84r\x12'
+_EXPECTED = (14596, 18)
+
+
+@testing_refleaks.TestCase
+class DecoderTest(unittest.TestCase):
+  """Tests for the low-level helpers in google.protobuf.internal.decoder."""
+
+  def test_decode_varint_bytes(self):
+    """Decodes consecutive varints from bytes and checks the new position."""
+    (size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 0)
+    self.assertEqual(size, _EXPECTED[0])
+    self.assertEqual(pos, 2)
+
+    (size, pos) = decoder._DecodeVarint(_INPUT_BYTES, 2)
+    self.assertEqual(size, _EXPECTED[1])
+    self.assertEqual(pos, 3)
+
+  def test_decode_varint_bytes_empty(self):
+    """Decoding from an empty bytes object raises IndexError."""
+    with self.assertRaises(IndexError) as context:
+      decoder._DecodeVarint(b'', 0)
+    self.assertIn('index out of range', str(context.exception))
+
+  def test_decode_varint_bytesio(self):
+    """Reads varints from a BytesIO until the decoder returns None."""
+    index = 0
+    input_io = io.BytesIO(_INPUT_BYTES)
+    while True:
+      size = decoder._DecodeVarint(input_io)
+      if size is None:
+        break
+      self.assertEqual(size, _EXPECTED[index])
+      index += 1
+    self.assertEqual(index, len(_EXPECTED))
+
+  def test_decode_varint_bytesio_empty(self):
+    """Decoding from an empty BytesIO returns None."""
+    input_io = io.BytesIO(b'')
+    size = decoder._DecodeVarint(input_io)
+    self.assertIsNone(size)
+
+  def test_decode_unknown_group_field_too_many_levels(self):
+    """Deeply nested unknown groups are rejected with DecodeError."""
+    # Every byte is a START_GROUP tag (field 2, wire type 3), so each byte
+    # opens one more level of nested unknown group.
+    data = memoryview(b'\023' * 5_000_000)
+    self.assertRaisesRegex(
+        message.DecodeError,
+        'Error parsing message',
+        decoder._DecodeUnknownField,
+        data,
+        1,
+        wire_format.WIRETYPE_START_GROUP,
+        1
+    )
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 1ef6cfca5..8afc523ed 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -83,6 +83,7 @@ from google.protobuf.internal import decoder
 from google.protobuf.internal import encoder
 from google.protobuf.internal import more_extensions_pb2
 from google.protobuf.internal import packed_field_test_pb2
+from google.protobuf.internal import self_recursive_pb2
 from google.protobuf.internal import test_util
 from google.protobuf.internal import testing_refleaks
 from google.protobuf import message
@@ -1300,6 +1301,61 @@ class MessageTest(unittest.TestCase):
     self.assertEqual(True, m.repeated_bool[0])
 
 
+@testing_refleaks.TestCase
+class TestRecursiveGroup(unittest.TestCase):
+  """Checks nesting-depth limits when parsing recursive delimited fields."""
+
+  def _MakeRecursiveGroupMessage(self, n):
+    """Returns a serialized SelfRecursive with n nested sub_group levels."""
+    msg = self_recursive_pb2.SelfRecursive()
+    sub = msg
+    for _ in range(n):
+      sub = sub.sub_group
+    sub.i = 1
+    return msg.SerializeToString()
+
+  def testRecursiveGroups(self):
+    """A message nested 100 levels deep still parses."""
+    recurse_msg = self_recursive_pb2.SelfRecursive()
+    data = self._MakeRecursiveGroupMessage(100)
+    recurse_msg.ParseFromString(data)
+    self.assertTrue(recurse_msg.HasField('sub_group'))
+
+  def testRecursiveGroupsException(self):
+    """A message nested 300 levels deep is rejected with DecodeError."""
+    if api_implementation.Type() != 'python':
+      api_implementation._c_module.SetAllowOversizeProtos(False)
+    recurse_msg = self_recursive_pb2.SelfRecursive()
+    data = self._MakeRecursiveGroupMessage(300)
+    with self.assertRaises(message.DecodeError) as context:
+      recurse_msg.ParseFromString(data)
+    self.assertIn('Error parsing message', str(context.exception))
+    if api_implementation.Type() == 'python':
+      self.assertIn('too many levels of nesting', str(context.exception))
+
+  def testRecursiveGroupsUnknownFields(self):
+    """The limit also applies when the nested groups are unknown fields."""
+    if api_implementation.Type() != 'python':
+      api_implementation._c_module.SetAllowOversizeProtos(False)
+    test_msg = unittest_pb2.TestAllTypes()
+    data = self._MakeRecursiveGroupMessage(300)  # unknown to test_msg
+    with self.assertRaises(message.DecodeError) as context:
+      test_msg.ParseFromString(data)
+    self.assertIn(
+        'Error parsing message',
+        str(context.exception),
+    )
+    if api_implementation.Type() == 'python':
+      self.assertIn('too many levels of nesting', str(context.exception))
+      decoder.SetRecursionLimit(310)
+      try:
+        test_msg.ParseFromString(data)
+      finally:
+        # The limit is module-wide state; restore it even if parsing fails so
+        # later tests are not affected.
+        decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
+
+
 # Class to test proto2-only features (required, extensions, etc.)
 @testing_refleaks.TestCase
 class Proto2Test(unittest.TestCase):
@@ -2578,7 +2625,8 @@ class OversizeProtosTest(unittest.TestCase):
 
   def testAssertOversizeProto(self):
     from google.protobuf.pyext._message import SetAllowOversizeProtos
-    SetAllowOversizeProtos(False)
+    if api_implementation.Type() != 'python':
+      api_implementation._c_module.SetAllowOversizeProtos(False)
     q = self.proto_cls()
     try:
       q.ParseFromString(self.p_serialized)
@@ -2587,10 +2635,17 @@ class OversizeProtosTest(unittest.TestCase):
 
   def testSucceedOversizeProto(self):
     from google.protobuf.pyext._message import SetAllowOversizeProtos
-    SetAllowOversizeProtos(True)
-    q = self.proto_cls()
-    q.ParseFromString(self.p_serialized)
-    self.assertEqual(self.p.field.payload, q.field.payload)
+    if api_implementation.Type() == 'python':
+      decoder.SetRecursionLimit(310)
+    else:
+      api_implementation._c_module.SetAllowOversizeProtos(True)
+    try:
+      q = self.proto_cls()
+      q.ParseFromString(self.p_serialized)
+      self.assertEqual(self.p.field.payload, q.field.payload)
+    finally:
+      # Module-wide setting: restore it even if a step above fails.
+      decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/python/google/protobuf/internal/self_recursive.proto b/python/google/protobuf/internal/self_recursive.proto
new file mode 100644
index 000000000..d2a7f004b
--- /dev/null
+++ b/python/google/protobuf/internal/self_recursive.proto
@@ -0,0 +1,34 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2024 Google Inc.  All rights reserved.
+//
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file or at
+// https://developers.google.com/open-source/licenses/bsd
+
+// Self-referential message types used by the Python tests that exercise
+// nesting-depth limits in the parser.
+//
+// NOTE(review): this file uses Editions syntax; confirm the protoc used by the
+// branch this patch targets accepts edition 2023.
+
+edition = "2023";
+
+package google.protobuf.python.internal;
+
+// Message that can contain itself, both length-prefixed and delimited.
+message SelfRecursive {
+  SelfRecursive sub = 1;
+  int32 i = 2;
+  // Uses DELIMITED (group-style) encoding on the wire.
+  SelfRecursive sub_group = 3 [features.message_encoding = DELIMITED];
+}
+
+// Recurses through IntermediateRecursive rather than directly.
+message IndirectRecursive {
+  IntermediateRecursive intermediate = 1;
+}
+
+// Intermediate hop that points back at IndirectRecursive.
+message IntermediateRecursive {
+  IndirectRecursive indirect = 1;
+}
-- 
2.53.0

