From 3b0c1b873793da29ff46a42d448cda3b5f937c8a Mon Sep 17 00:00:00 2001
From: Jie Luo <jieluo@google.com>
Date: Mon, 28 Aug 2023 09:58:39 -0700
Subject: [PATCH 1/2] Fix a bug that strips options from descriptor.proto in
 Pure Python.

GetOptions on fields (which parse the _serialized_options) will be called for the first time of parse or serialize instead of Build time.

Note: GetOptions on messages are still called in Build time because of message_set_wire_format. If message options are needed in descriptor.proto, a parse error will be raised in GetOptions(). We can check the file to not invoke GetOptions() for descriptor.proto as long as message_set_wire_format not needed in descriptor.proto.

Other options except message options do not invoke GetOptions() in Build time

PiperOrigin-RevId: 560741182
---
 python/google/protobuf/descriptor.py          |  18 ++--
 python/google/protobuf/descriptor_pool.py     |   1 +
 .../protobuf/internal/python_message.py       | 102 +++++++++++++-----
 .../protobuf/internal/reflection_test.py      |  43 ++++++++
 .../protobuf/internal/unknown_fields_test.py  |  33 ------
 5 files changed, 128 insertions(+), 69 deletions(-)

diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py
index fcb87cab5..4c03a4669 100755
--- a/python/google/protobuf/descriptor.py
+++ b/python/google/protobuf/descriptor.py
@@ -176,14 +176,15 @@ class DescriptorBase(metaclass=DescriptorMetaclass):
       raise RuntimeError('Unknown options class name %s!' %
                          (self._options_class_name))
 
-    with _lock:
-      if self._serialized_options is None:
+    if self._serialized_options is None:
+      with _lock:
         self._options = options_class()
-      else:
-        self._options = _ParseOptions(options_class(),
-                                      self._serialized_options)
+    else:
+      options = _ParseOptions(options_class(), self._serialized_options)
+      with _lock:
+        self._options = options
 
-      return self._options
+    return self._options
 
 
 class _NestedDescriptorBase(DescriptorBase):
@@ -285,6 +286,7 @@ class Descriptor(_NestedDescriptorBase):
       oneofs_by_name (dict(str, OneofDescriptor)): Same objects as in
           :attr:`oneofs`, but indexed by "name" attribute.
       file (FileDescriptor): Reference to file descriptor.
+      is_map_entry: If the message type is a map entry.
 
   """
 
@@ -310,6 +312,7 @@ class Descriptor(_NestedDescriptorBase):
         serialized_start=None,
         serialized_end=None,
         syntax=None,
+        is_map_entry=False,
         create_key=None):
       _message.Message._CheckCalledFromGeneratedFile()
       return _message.default_pool.FindMessageTypeByName(full_name)
@@ -322,7 +325,7 @@ class Descriptor(_NestedDescriptorBase):
                serialized_options=None,
                is_extendable=True, extension_ranges=None, oneofs=None,
                file=None, serialized_start=None, serialized_end=None,  # pylint: disable=redefined-builtin
-               syntax=None, create_key=None):
+               syntax=None, is_map_entry=False, create_key=None):
     """Arguments to __init__() are as described in the description
     of Descriptor fields above.
 
@@ -372,6 +375,7 @@ class Descriptor(_NestedDescriptorBase):
     for oneof in self.oneofs:
       oneof.containing_type = self
     self.syntax = syntax or "proto2"
+    self._is_map_entry = is_map_entry
 
   @property
   def fields_by_camelcase_name(self):
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 1ebf11834..04bd7a3e9 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -897,6 +897,7 @@ class DescriptorPool(object):
         serialized_start=None,
         serialized_end=None,
         syntax=syntax,
+        is_map_entry=desc_proto.options.map_entry,
         # pylint: disable=protected-access
         create_key=descriptor._internal_create_key)
     for nested in desc.nested_types:
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index bf9acefd2..a72276d1a 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -182,11 +182,14 @@ class GeneratedProtocolMessageType(type):
           % (descriptor.full_name))
       return
 
-    cls._decoders_by_tag = {}
+    cls._message_set_decoders_by_tag = {}
+    cls._fields_by_tag = {}
     if (descriptor.has_options and
         descriptor.GetOptions().message_set_wire_format):
-      cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
-          decoder.MessageSetItemDecoder(descriptor), None)
+      cls._message_set_decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
+          decoder.MessageSetItemDecoder(descriptor),
+          None,
+      )
 
     # Attach stuff to each FieldDescriptor for quick lookup later on.
     for field in descriptor.fields:
@@ -272,16 +275,36 @@ def _IsMessageSetExtension(field):
 
 def _IsMapField(field):
   return (field.type == _FieldDescriptor.TYPE_MESSAGE and
-          field.message_type.has_options and
-          field.message_type.GetOptions().map_entry)
+          field.message_type._is_map_entry)
 
 
 def _IsMessageMapField(field):
   value_type = field.message_type.fields_by_name['value']
   return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
 
-
 def _AttachFieldHelpers(cls, field_descriptor):
+  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
+  field_descriptor._default_constructor = _DefaultValueConstructorForField(
+      field_descriptor
+  )
+
+  def AddFieldByTag(wiretype, is_packed):
+    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
+    cls._fields_by_tag[tag_bytes] = (field_descriptor, is_packed)
+
+  AddFieldByTag(
+      type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False
+  )
+
+  if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
+    # To support wire compatibility of adding packed = true, add a decoder for
+    # packed values regardless of the field's options.
+    AddFieldByTag(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
+
+
+def _MaybeAddEncoder(cls, field_descriptor):
+  if hasattr(field_descriptor, '_encoder'):
+    return
   is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
   is_map_entry = _IsMapField(field_descriptor)
   is_packed = field_descriptor.is_packed
@@ -301,11 +324,17 @@ def _AttachFieldHelpers(cls, field_descriptor):
 
   field_descriptor._encoder = field_encoder
   field_descriptor._sizer = sizer
-  field_descriptor._default_constructor = _DefaultValueConstructorForField(
-      field_descriptor)
 
-  def AddDecoder(wiretype, is_packed):
-    tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
+
+def _MaybeAddDecoder(cls, field_descriptor):
+  if hasattr(field_descriptor, '_decoders'):
+    return
+
+  is_repeated = field_descriptor.label == _FieldDescriptor.LABEL_REPEATED
+  is_map_entry = _IsMapField(field_descriptor)
+  field_descriptor._decoders = {}
+
+  def AddDecoder(is_packed):
     decode_type = field_descriptor.type
     if (decode_type == _FieldDescriptor.TYPE_ENUM and
         not field_descriptor.enum_type.is_closed):
@@ -337,15 +366,14 @@ def _AttachFieldHelpers(cls, field_descriptor):
           field_descriptor, field_descriptor._default_constructor,
           not field_descriptor.has_presence)
 
-    cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
+    field_descriptor._decoders[is_packed] = field_decoder
 
-  AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
-             False)
+  AddDecoder(False)
 
   if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
     # To support wire compatibility of adding packed = true, add a decoder for
     # packed values regardless of the field's options.
-    AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
+    AddDecoder(True)
 
 
 def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
@@ -1031,12 +1059,17 @@ def _AddByteSizeMethod(message_descriptor, cls):
 
     size = 0
     descriptor = self.DESCRIPTOR
-    if descriptor.GetOptions().map_entry:
+    if descriptor._is_map_entry:
       # Fields of map entry should always be serialized.
-      size = descriptor.fields_by_name['key']._sizer(self.key)
-      size += descriptor.fields_by_name['value']._sizer(self.value)
+      key_field = descriptor.fields_by_name['key']
+      _MaybeAddEncoder(cls, key_field)
+      size = key_field._sizer(self.key)
+      value_field = descriptor.fields_by_name['value']
+      _MaybeAddEncoder(cls, value_field)
+      size += value_field._sizer(self.value)
     else:
       for field_descriptor, field_value in self.ListFields():
+        _MaybeAddEncoder(cls, field_descriptor)
         size += field_descriptor._sizer(field_value)
       for tag_bytes, value_bytes in self._unknown_fields:
         size += len(tag_bytes) + len(value_bytes)
@@ -1079,14 +1112,17 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
       deterministic = bool(deterministic)
 
     descriptor = self.DESCRIPTOR
-    if descriptor.GetOptions().map_entry:
+    if descriptor._is_map_entry:
       # Fields of map entry should always be serialized.
-      descriptor.fields_by_name['key']._encoder(
-          write_bytes, self.key, deterministic)
-      descriptor.fields_by_name['value']._encoder(
-          write_bytes, self.value, deterministic)
+      key_field = descriptor.fields_by_name['key']
+      _MaybeAddEncoder(cls, key_field)
+      key_field._encoder(write_bytes, self.key, deterministic)
+      value_field = descriptor.fields_by_name['value']
+      _MaybeAddEncoder(cls, value_field)
+      value_field._encoder(write_bytes, self.value, deterministic)
     else:
       for field_descriptor, field_value in self.ListFields():
+        _MaybeAddEncoder(cls, field_descriptor)
         field_descriptor._encoder(write_bytes, field_value, deterministic)
       for tag_bytes, value_bytes in self._unknown_fields:
         write_bytes(tag_bytes)
@@ -1114,7 +1150,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
 
   local_ReadTag = decoder.ReadTag
   local_SkipField = decoder.SkipField
-  decoders_by_tag = cls._decoders_by_tag
+  fields_by_tag = cls._fields_by_tag
+  message_set_decoders_by_tag = cls._message_set_decoders_by_tag
 
   def InternalParse(self, buffer, pos, end):
     """Create a message from serialized bytes.
@@ -1137,8 +1174,14 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
     unknown_field_set = self._unknown_field_set
     while pos != end:
       (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
-      field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
-      if field_decoder is None:
+      field_decoder, field_des = message_set_decoders_by_tag.get(
+          tag_bytes, (None, None)
+      )
+      if field_decoder:
+        pos = field_decoder(buffer, new_pos, end, self, field_dict)
+        continue
+      field_des, is_packed = fields_by_tag.get(tag_bytes, (None, None))
+      if field_des is None:
         if not self._unknown_fields:   # pylint: disable=protected-access
           self._unknown_fields = []    # pylint: disable=protected-access
         if unknown_field_set is None:
@@ -1167,9 +1210,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
             (tag_bytes, buffer[old_pos:new_pos].tobytes()))
         pos = new_pos
       else:
+        _MaybeAddDecoder(cls, field_des)
+        field_decoder = field_des._decoders[is_packed]
         pos = field_decoder(buffer, new_pos, end, self, field_dict)
-        if field_desc:
-          self._UpdateOneofState(field_desc)
+        if field_des.containing_oneof:
+          self._UpdateOneofState(field_des)
     return pos
   cls._InternalParse = InternalParse
 
@@ -1205,8 +1250,7 @@ def _AddIsInitializedMethod(message_descriptor, cls):
     for field, value in list(self._fields.items()):  # dict can change size!
       if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
         if field.label == _FieldDescriptor.LABEL_REPEATED:
-          if (field.message_type.has_options and
-              field.message_type.GetOptions().map_entry):
+          if (field.message_type._is_map_entry):
             continue
           for element in value:
             if not element.IsInitialized():
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 0708d51e6..c5a600fa9 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -2053,6 +2053,49 @@ class Proto2ReflectionTest(unittest.TestCase):
       # dependency on the C++ logging code.
       self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception))
 
+  def testDescriptorProtoHasFileOptions(self):
+    self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
+    self.assertEqual(
+        descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
+        'com.google.protobuf',
+    )
+
+  def testDescriptorProtoHasFieldOptions(self):
+    self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
+    self.assertEqual(
+        descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
+        'com.google.protobuf',
+    )
+    packed_desc = (
+        descriptor_pb2.SourceCodeInfo.DESCRIPTOR.nested_types_by_name.get(
+            'Location'
+        ).fields_by_name.get('path')
+    )
+    self.assertTrue(packed_desc.has_options)
+    self.assertTrue(packed_desc.GetOptions().packed)
+
+  def testDescriptorProtoHasFeatureOptions(self):
+    self.assertTrue(descriptor_pb2.DESCRIPTOR.has_options)
+    self.assertEqual(
+        descriptor_pb2.DESCRIPTOR.GetOptions().java_package,
+        'com.google.protobuf',
+    )
+    presence_desc = descriptor_pb2.FeatureSet.DESCRIPTOR.fields_by_name.get(
+        'field_presence'
+    )
+    self.assertTrue(presence_desc.has_options)
+    self.assertEqual(
+        presence_desc.GetOptions().retention,
+        descriptor_pb2.FieldOptions.OptionRetention.RETENTION_RUNTIME,
+    )
+    self.assertListsEqual(
+        presence_desc.GetOptions().targets,
+        [
+            descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FIELD,
+            descriptor_pb2.FieldOptions.OptionTargetType.TARGET_TYPE_FILE,
+        ],
+    )
+
   def testStringUTF8Serialization(self):
     proto = message_set_extensions_pb2.TestMessageSet()
     extension_message = message_set_extensions_pb2.TestMessageSetExtension2
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index ec1aa1b45..9a8d7d751 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -177,25 +177,6 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
     self.empty_message = unittest_pb2.TestEmptyMessage()
     self.empty_message.ParseFromString(self.all_fields_data)
 
-  # InternalCheckUnknownField() is an additional Pure Python check which checks
-  # a detail of unknown fields. It cannot be used by the C++
-  # implementation because some protect members are called.
-  # The test is added for historical reasons. It is not necessary as
-  # serialized string is checked.
-  # TODO(jieluo): Remove message._unknown_fields.
-  def InternalCheckUnknownField(self, name, expected_value):
-    if api_implementation.Type() != 'python':
-      return
-    field_descriptor = self.descriptor.fields_by_name[name]
-    wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
-    field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
-    result_dict = {}
-    for tag_bytes, value in self.empty_message._unknown_fields:
-      if tag_bytes == field_tag:
-        decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
-        decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
-    self.assertEqual(expected_value, result_dict[field_descriptor])
-
   def CheckUnknownField(self, name, unknown_field_set, expected_value):
     field_descriptor = self.descriptor.fields_by_name[name]
     expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
@@ -223,50 +204,36 @@ class UnknownFieldsAccessorsTest(unittest.TestCase):
     self.CheckUnknownField('optional_nested_enum',
                            unknown_field_set,
                            self.all_fields.optional_nested_enum)
-    self.InternalCheckUnknownField('optional_nested_enum',
-                                   self.all_fields.optional_nested_enum)
 
     # Test repeated enum.
     self.CheckUnknownField('repeated_nested_enum',
                            unknown_field_set,
                            self.all_fields.repeated_nested_enum)
-    self.InternalCheckUnknownField('repeated_nested_enum',
-                                   self.all_fields.repeated_nested_enum)
 
     # Test varint.
     self.CheckUnknownField('optional_int32',
                            unknown_field_set,
                            self.all_fields.optional_int32)
-    self.InternalCheckUnknownField('optional_int32',
-                                   self.all_fields.optional_int32)
 
     # Test fixed32.
     self.CheckUnknownField('optional_fixed32',
                            unknown_field_set,
                            self.all_fields.optional_fixed32)
-    self.InternalCheckUnknownField('optional_fixed32',
-                                   self.all_fields.optional_fixed32)
 
     # Test fixed64.
     self.CheckUnknownField('optional_fixed64',
                            unknown_field_set,
                            self.all_fields.optional_fixed64)
-    self.InternalCheckUnknownField('optional_fixed64',
-                                   self.all_fields.optional_fixed64)
 
     # Test length delimited.
     self.CheckUnknownField('optional_string',
                            unknown_field_set,
                            self.all_fields.optional_string.encode('utf-8'))
-    self.InternalCheckUnknownField('optional_string',
-                                   self.all_fields.optional_string)
 
     # Test group.
     self.CheckUnknownField('optionalgroup',
                            unknown_field_set,
                            (17, 0, 117))
-    self.InternalCheckUnknownField('optionalgroup',
-                                   self.all_fields.optionalgroup)
 
     self.assertEqual(98, len(unknown_field_set))
 
-- 
2.49.0

