From 906d146f4f14d651e03354d990359417c92390bb Mon Sep 17 00:00:00 2001
From: Protobuf Team Bot <protobuf-github-bot@google.com>
Date: Tue, 4 Mar 2025 13:16:32 -0800
Subject: [PATCH 1/3] Internal pure python fixes

PiperOrigin-RevId: 733441339
---
 python/google/protobuf/internal/decoder.py    | 96 ++++++++++++++-----
 .../google/protobuf/internal/message_test.py  |  1 +
 .../protobuf/internal/python_message.py       |  5 +-
 3 files changed, 77 insertions(+), 25 deletions(-)

diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 845d77427..c9c09cd6c 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -212,7 +212,10 @@ def _SimpleDecoder(wire_type, decode_value):
   def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
     if is_packed:
       local_DecodeVarint = _DecodeVarint
-      def DecodePackedField(buffer, pos, end, message, field_dict):
+      def DecodePackedField(
+          buffer, pos, end, message, field_dict, current_depth=0
+      ):
+        del current_depth  # unused
         value = field_dict.get(key)
         if value is None:
           value = field_dict.setdefault(key, new_default(message))
@@ -227,11 +230,15 @@ def _SimpleDecoder(wire_type, decode_value):
           del value[-1]   # Discard corrupt value.
           raise _DecodeError('Packed element was truncated.')
         return pos
+
       return DecodePackedField
     elif is_repeated:
       tag_bytes = encoder.TagBytes(field_number, wire_type)
       tag_len = len(tag_bytes)
-      def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+      def DecodeRepeatedField(
+          buffer, pos, end, message, field_dict, current_depth=0
+      ):
+        del current_depth  # unused
         value = field_dict.get(key)
         if value is None:
           value = field_dict.setdefault(key, new_default(message))
@@ -246,14 +253,16 @@ def _SimpleDecoder(wire_type, decode_value):
             if new_pos > end:
               raise _DecodeError('Truncated message.')
             return new_pos
+
       return DecodeRepeatedField
     else:
-      def DecodeField(buffer, pos, end, message, field_dict):
+      def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
         (field_dict[key], pos) = decode_value(buffer, pos)
         if pos > end:
           del field_dict[key]  # Discard corrupt value.
           raise _DecodeError('Truncated message.')
         return pos
+
       return DecodeField
 
   return SpecificDecoder
@@ -387,7 +396,9 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
   enum_type = key.enum_type
   if is_packed:
     local_DecodeVarint = _DecodeVarint
-    def DecodePackedField(buffer, pos, end, message, field_dict):
+    def DecodePackedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
       """Decode serialized packed enum to its value and a new position.
 
       Args:
@@ -400,6 +411,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
       Returns:
         int, new position in serialized data.
       """
+      del current_depth  # unused
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -429,11 +441,14 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
           del message._unknown_fields[-1]
         raise _DecodeError('Packed element was truncated.')
       return pos
+
     return DecodePackedField
   elif is_repeated:
     tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
       """Decode serialized repeated enum to its value and a new position.
 
       Args:
@@ -446,6 +461,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
       Returns:
         int, new position in serialized data.
       """
+      del current_depth  # unused
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -468,9 +484,11 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
           if new_pos > end:
             raise _DecodeError('Truncated message.')
           return new_pos
+
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
       """Decode serialized repeated enum to its value and a new position.
 
       Args:
@@ -483,6 +501,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
       Returns:
         int, new position in serialized data.
       """
+      del current_depth  # unused
       value_start_pos = pos
       (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
       if pos > end:
@@ -499,6 +518,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
             (tag_bytes, buffer[value_start_pos:pos].tobytes()))
         # pylint: enable=protected-access
       return pos
+
     return DecodeField
 
 
@@ -567,7 +587,10 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
     tag_bytes = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
+      del current_depth  # unused
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -582,15 +605,19 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
           # Prediction failed.  Return.
           return new_pos
+
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+      del current_depth  # unused
       (size, pos) = local_DecodeVarint(buffer, pos)
       new_pos = pos + size
       if new_pos > end:
         raise _DecodeError('Truncated string.')
       field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
       return new_pos
+
     return DecodeField
 
 
@@ -604,7 +631,10 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
     tag_bytes = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
+      del current_depth  # unused
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -619,15 +649,19 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
           # Prediction failed.  Return.
           return new_pos
+
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
+      del current_depth  # unused
       (size, pos) = local_DecodeVarint(buffer, pos)
       new_pos = pos + size
       if new_pos > end:
         raise _DecodeError('Truncated string.')
       field_dict[key] = buffer[pos:new_pos].tobytes()
       return new_pos
+
     return DecodeField
 
 
@@ -643,7 +677,9 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
     tag_bytes = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_START_GROUP)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -652,7 +688,7 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
         if value is None:
           value = field_dict.setdefault(key, new_default(message))
         # Read sub-message.
-        pos = value.add()._InternalParse(buffer, pos, end)
+        pos = value.add()._InternalParse(buffer, pos, end, current_depth)
         # Read end tag.
         new_pos = pos+end_tag_len
         if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
@@ -662,19 +698,22 @@ def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
           # Prediction failed.  Return.
           return new_pos
+
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
       # Read sub-message.
-      pos = value._InternalParse(buffer, pos, end)
+      pos = value._InternalParse(buffer, pos, end, current_depth)
       # Read end tag.
       new_pos = pos+end_tag_len
       if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
         raise _DecodeError('Missing group end tag.')
       return new_pos
+
     return DecodeField
 
 
@@ -688,7 +727,9 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
     tag_bytes = encoder.TagBytes(field_number,
                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
     tag_len = len(tag_bytes)
-    def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+    def DecodeRepeatedField(
+        buffer, pos, end, message, field_dict, current_depth=0
+    ):
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -699,7 +740,10 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
         if new_pos > end:
           raise _DecodeError('Truncated message.')
         # Read sub-message.
-        if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
+        if (
+            value.add()._InternalParse(buffer, pos, new_pos, current_depth)
+            != new_pos
+        ):
           # The only reason _InternalParse would return early is if it
           # encountered an end-group tag.
           raise _DecodeError('Unexpected end-group tag.')
@@ -708,9 +752,11 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
           # Prediction failed.  Return.
           return new_pos
+
     return DecodeRepeatedField
   else:
-    def DecodeField(buffer, pos, end, message, field_dict):
+
+    def DecodeField(buffer, pos, end, message, field_dict, current_depth=0):
       value = field_dict.get(key)
       if value is None:
         value = field_dict.setdefault(key, new_default(message))
@@ -720,11 +766,12 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
       if new_pos > end:
         raise _DecodeError('Truncated message.')
       # Read sub-message.
-      if value._InternalParse(buffer, pos, new_pos) != new_pos:
+      if value._InternalParse(buffer, pos, new_pos, current_depth) != new_pos:
         # The only reason _InternalParse would return early is if it encountered
         # an end-group tag.
         raise _DecodeError('Unexpected end-group tag.')
       return new_pos
+
     return DecodeField
 
 
@@ -831,7 +878,8 @@ def MapDecoder(field_descriptor, new_default, is_message_map):
   # Can't read _concrete_class yet; might not be initialized.
   message_type = field_descriptor.message_type
 
-  def DecodeMap(buffer, pos, end, message, field_dict):
+  def DecodeMap(buffer, pos, end, message, field_dict, current_depth=0):
+    del current_depth  # Unused.
     submsg = message_type._concrete_class()
     value = field_dict.get(key)
     if value is None:
@@ -914,7 +962,7 @@ def _SkipGroup(buffer, pos, end):
     pos = new_pos
 
 
-def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
+def _DecodeUnknownFieldSet(buffer, pos, end_pos=None, current_depth=0):
   """Decode UnknownFieldSet.  Returns the UnknownFieldSet and new position."""
 
   unknown_field_set = containers.UnknownFieldSet()
@@ -924,14 +972,16 @@ def _DecodeUnknownFieldSet(buffer, pos, end_pos=None):
     field_number, wire_type = wire_format.UnpackTag(tag)
     if wire_type == wire_format.WIRETYPE_END_GROUP:
       break
-    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
+    (data, pos) = _DecodeUnknownField(buffer, pos, wire_type, current_depth)
     # pylint: disable=protected-access
     unknown_field_set._add(field_number, wire_type, data)
 
   return (unknown_field_set, pos)
 
 
-def _DecodeUnknownField(buffer, pos, wire_type):
+def _DecodeUnknownField(
+    buffer, pos, wire_type, current_depth=0
+):
   """Decode a unknown field.  Returns the UnknownField and new position."""
 
   if wire_type == wire_format.WIRETYPE_VARINT:
@@ -945,7 +995,7 @@ def _DecodeUnknownField(buffer, pos, wire_type):
     data = buffer[pos:pos+size]
     pos += size
   elif wire_type == wire_format.WIRETYPE_START_GROUP:
-    (data, pos) = _DecodeUnknownFieldSet(buffer, pos)
+    (data, pos) = _DecodeUnknownFieldSet(buffer, pos, None, current_depth)
   elif wire_type == wire_format.WIRETYPE_END_GROUP:
     return (0, -1)
   else:
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 9de3e7a01..1ef6cfca5 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -79,6 +79,7 @@ from google.protobuf import descriptor_pool
 from google.protobuf import message_factory
 from google.protobuf import text_format
 from google.protobuf.internal import api_implementation
+from google.protobuf.internal import decoder
 from google.protobuf.internal import encoder
 from google.protobuf.internal import more_extensions_pb2
 from google.protobuf.internal import packed_field_test_pb2
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 803ae8484..720671214 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -1141,7 +1141,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
   local_SkipField = decoder.SkipField
   decoders_by_tag = cls._decoders_by_tag
 
-  def InternalParse(self, buffer, pos, end):
+  def InternalParse(self, buffer, pos, end, current_depth=0):
     """Create a message from serialized bytes.
 
     Args:
@@ -1190,10 +1190,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
             (tag_bytes, buffer[old_pos:new_pos].tobytes()))
         pos = new_pos
       else:
-        pos = field_decoder(buffer, new_pos, end, self, field_dict)
+        pos = field_decoder(buffer, new_pos, end, self, field_dict, current_depth)
         if field_desc:
           self._UpdateOneofState(field_desc)
     return pos
+
   cls._InternalParse = InternalParse
 
 
-- 
2.53.0

