From 1e7f83ea1b1945065ce1b89051cd655e4b8de22d Mon Sep 17 00:00:00 2001
From: Protobuf Team Bot <protobuf-github-bot@google.com>
Date: Tue, 13 May 2025 14:42:18 -0700
Subject: [PATCH 2/2] Add recursion depth limits to pure python

PiperOrigin-RevId: 758382549
---
 python/google/protobuf/internal/decoder.py    | 35 ++++++++++-
 .../google/protobuf/internal/decoder_test.py  | 14 +++++
 .../google/protobuf/internal/message_test.py  | 60 +++++++++++++++++--
 .../protobuf/internal/self_recursive.proto    |  1 +
 4 files changed, 105 insertions(+), 5 deletions(-)

diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 89d829142..de0bc19a5 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -668,7 +668,13 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
         if value is None:
           value = field_dict.setdefault(key, new_default(message))
         # Read sub-message.
+        current_depth += 1
+        if current_depth > _recursion_limit:
+          raise _DecodeError(
+              'Error parsing message: too many levels of nesting.'
+          )
         pos = value.add()._InternalParse(buffer, pos, end, current_depth)
+        current_depth -= 1
         # Read end tag.
         new_pos = pos+end_tag_len
         if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -687,7 +693,11 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
       # Read sub-message.
+      current_depth += 1
+      if current_depth > _recursion_limit:
+        raise _DecodeError('Error parsing message: too many levels of nesting.')
       pos = value._InternalParse(buffer, pos, end, current_depth)
+      current_depth -= 1
       # Read end tag.
       new_pos = pos+end_tag_len
       if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -720,6 +730,11 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
         if new_pos > end:
           raise _DecodeError('Truncated message.')
         # Read sub-message.
+        current_depth += 1
+        if current_depth > _recursion_limit:
+          raise _DecodeError(
+              'Error parsing message: too many levels of nesting.'
+          )
         if (
             value.add()._InternalParse(buffer, pos, new_pos, current_depth)
             != new_pos
@@ -727,6 +742,7 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
           # The only reason _InternalParse would return early is if it
           # encountered an end-group tag.
           raise _DecodeError('Unexpected end-group tag.')
+        current_depth -= 1
         # Predict that the next tag is another copy of the same repeated field.
         pos = new_pos + tag_len
         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -746,10 +762,14 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
       if new_pos > end:
         raise _DecodeError('Truncated message.')
       # Read sub-message.
+      current_depth += 1
+      if current_depth > _recursion_limit:
+        raise _DecodeError('Error parsing message: too many levels of nesting.')
       if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
         # The only reason _InternalParse would return early is if it encountered
         # an end-group tag.
         raise _DecodeError('Unexpected end-group tag.')
+      current_depth -= 1
       return new_pos
 
     return DecodeField
@@ -984,6 +1004,15 @@ def _SkipGroup(buffer, pos, end):
     pos = new_pos
 
 
+DEFAULT_RECURSION_LIMIT = 100
+_recursion_limit = DEFAULT_RECURSION_LIMIT
+
+
+def SetRecursionLimit(new_limit):
+  global _recursion_limit
+  _recursion_limit = new_limit
+
+
 def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
   """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position."""
 
@@ -1017,7 +1046,11 @@ def _DecodeUnknownField(
     data = buffer[pos:pos+size].tobytes()
     pos += size
   elif wire_type == wire_format.WIRETYPE_START_GROUP:
-    (data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
+    current_depth += 1
+    if current_depth >= _recursion_limit:
+      raise _DecodeError('Error parsing message: too many levels of nesting.')
+    data, pos = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
+    current_depth -= 1
   elif wire_type == wire_format.WIRETYPE_END_GROUP:
     return (0, -1)
   else:
diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py
index f801b6e76..11e6465b6 100644
--- a/python/google/protobuf/internal/decoder_test.py
+++ b/python/google/protobuf/internal/decoder_test.py
@@ -11,8 +11,10 @@
 import io
 import unittest
 
+from google.protobuf import message
 from google.protobuf.internal import decoder
 from google.protobuf.internal import testing_refleaks
+from google.protobuf.internal import wire_format
 
 
 _INPUT_BYTES = b'\x84r\x12'
@@ -52,6 +54,18 @@ class DecoderTest(unittest.TestCase):
     size = decoder._DecodeVarint(input_io)
     self.assertEqual(size, None)
 
+  def test_decode_unknown_group_field_too_many_levels(self):
+    data = memoryview(b'\023' * 5_000_000)
+    self.assertRaisesRegex(
+        message.DecodeError,
+        'Error parsing message',
+        decoder._DecodeUnknownField,
+        data,
+        1,
+        wire_format.WIRETYPE_START_GROUP,
+        1
+    )
+
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 48e6df806..6facb8135 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -36,6 +36,7 @@ from google.protobuf.internal import enum_type_wrapper
 from google.protobuf.internal import more_extensions_pb2
 from google.protobuf.internal import more_messages_pb2
 from google.protobuf.internal import packed_field_test_pb2
+from google.protobuf.internal import self_recursive_pb2
 from google.protobuf.internal import test_proto3_optional_pb2
 from google.protobuf.internal import test_util
 from google.protobuf.internal import testing_refleaks
@@ -1339,6 +1340,52 @@ class MessageTest(unittest.TestCase):
     self.assertNotIn('oneof_string', m)
 
 
+@testing_refleaks.TestCase
+class TestRecursiveGroup(unittest.TestCase):
+
+  def _MakeRecursiveGroupMessage(self, n):
+    msg = self_recursive_pb2.SelfRecursive()
+    sub = msg
+    for _ in range(n):
+      sub = sub.sub_group
+    sub.i = 1
+    return msg.SerializeToString()
+
+  def testRecursiveGroups(self):
+    recurse_msg = self_recursive_pb2.SelfRecursive()
+    data = self._MakeRecursiveGroupMessage(100)
+    recurse_msg.ParseFromString(data)
+    self.assertTrue(recurse_msg.HasField('sub_group'))
+
+  def testRecursiveGroupsException(self):
+    if api_implementation.Type() != 'python':
+      api_implementation._c_module.SetAllowOversizeProtos(False)
+    recurse_msg = self_recursive_pb2.SelfRecursive()
+    data = self._MakeRecursiveGroupMessage(300)
+    with self.assertRaises(message.DecodeError) as context:
+      recurse_msg.ParseFromString(data)
+    self.assertIn('Error parsing message', str(context.exception))
+    if api_implementation.Type() == 'python':
+      self.assertIn('too many levels of nesting', str(context.exception))
+
+  def testRecursiveGroupsUnknownFields(self):
+    if api_implementation.Type() != 'python':
+      api_implementation._c_module.SetAllowOversizeProtos(False)
+    test_msg = unittest_pb2.TestAllTypes()
+    data = self._MakeRecursiveGroupMessage(300)  # unknown to test_msg
+    with self.assertRaises(message.DecodeError) as context:
+      test_msg.ParseFromString(data)
+    self.assertIn(
+        'Error parsing message',
+        str(context.exception),
+    )
+    if api_implementation.Type() == 'python':
+      self.assertIn('too many levels of nesting', str(context.exception))
+      decoder.SetRecursionLimit(310)
+      test_msg.ParseFromString(data)
+      decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
+
+
 # Class to test proto2-only features (required, extensions, etc.)
 @testing_refleaks.TestCase
 class Proto2Test(unittest.TestCase):
@@ -2722,8 +2769,6 @@ class PackedFieldTest(unittest.TestCase):
     self.assertEqual(golden_data, message.SerializeToString())
 
 
-@unittest.skipIf(api_implementation.Type() == 'python',
-                 'explicit tests of the C++ implementation')
 @testing_refleaks.TestCase
 class OversizeProtosTest(unittest.TestCase):
 
@@ -2740,16 +2785,23 @@ class OversizeProtosTest(unittest.TestCase):
     msg.ParseFromString(self.GenerateNestedProto(100))
 
   def testAssertOversizeProto(self):
-    api_implementation._c_module.SetAllowOversizeProtos(False)
+    if api_implementation.Type() != 'python':
+      api_implementation._c_module.SetAllowOversizeProtos(False)
     msg = unittest_pb2.TestRecursiveMessage()
     with self.assertRaises(message.DecodeError) as context:
       msg.ParseFromString(self.GenerateNestedProto(101))
     self.assertIn('Error parsing message', str(context.exception))
 
   def testSucceedOversizeProto(self):
-    api_implementation._c_module.SetAllowOversizeProtos(True)
+
+    if api_implementation.Type() == 'python':
+      decoder.SetRecursionLimit(310)
+    else:
+      api_implementation._c_module.SetAllowOversizeProtos(True)
+
     msg = unittest_pb2.TestRecursiveMessage()
     msg.ParseFromString(self.GenerateNestedProto(101))
+    decoder.SetRecursionLimit(decoder.DEFAULT_RECURSION_LIMIT)
 
 
 if __name__ == '__main__':
diff --git a/python/google/protobuf/internal/self_recursive.proto b/python/google/protobuf/internal/self_recursive.proto
index 20bc2b4d3..d2a7f004b 100644
--- a/python/google/protobuf/internal/self_recursive.proto
+++ b/python/google/protobuf/internal/self_recursive.proto
@@ -12,6 +12,7 @@ package google.protobuf.python.internal;
 message SelfRecursive {
   SelfRecursive sub = 1;
   int32 i = 2;
+  SelfRecursive sub_group = 3 [features.message_encoding = DELIMITED];
 }
 
 message IndirectRecursive {
-- 
2.51.1

