From 623568724601c20f04386977976544c5d85e0378 Mon Sep 17 00:00:00 2001 From: Protobuf Team Bot Date: Fri, 24 Jan 2025 14:38:04 -0800 Subject: [PATCH] Python protobuf: kwargs initialization now allows dict() for proto map values. It is now possible to initialize a field of type `map` with nested dict objects: ``` storage_pb2.Directory( name='/home/user', owner='user', content={ '.bashrc': dict( # <=== dict() allowed here size=1234, permissions='PRIVATE', ), }, ) ``` PiperOrigin-RevId: 719437601 --- .../google/protobuf/internal/message_test.py | 4 + .../protobuf/internal/python_message.py | 6 +- python/google/protobuf/pyext/message.cc | 75 +++++++++++-------- python/message.c | 11 ++- 4 files changed, 60 insertions(+), 36 deletions(-) diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index dfd948c6ce2a4..36833e1448929 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -2550,6 +2550,10 @@ def testMapMessageFieldConstruction(self): msg2 = map_unittest_pb2.TestMap( map_string_foreign_message=msg1.map_string_foreign_message) self.assertEqual(42, msg2.map_string_foreign_message['test'].c) + msg3 = map_unittest_pb2.TestMap( + map_string_foreign_message={'test': dict(c=42)} + ) + self.assertEqual(42, msg3.map_string_foreign_message['test'].c) def testMapFieldRaisesCorrectError(self): # Should raise a TypeError when given a non-iterable. diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 61766099a03ae..23ac28e36d2e0 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -523,7 +523,11 @@ def init(self, **kwargs): if _IsMapField(field): if _IsMessageMapField(field): for key in field_value: - field_copy[key].MergeFrom(field_value[key]) + item_value = field_value[key] + if isinstance(item_value, dict): + field_copy[key].__init__(**item_value) + else: + field_copy[key].MergeFrom(item_value) else: field_copy.update(field_value) else: diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 531638bb166c4..6d656fc38c94c 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -1001,17 +1001,32 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { if (source_value.get() == nullptr || dest_value.get() == nullptr) { return -1; } - ScopedPyObjectPtr ok(PyObject_CallMethod( - dest_value.get(), "MergeFrom", "O", source_value.get())); - if (ok.get() == nullptr) { + if (!PyObject_TypeCheck(dest_value.get(), CMessage_Type)) { + PyErr_Format(PyExc_SystemError, + "Unexpectedly, a map of messages contains a " + "non-message value: %s", + Py_TYPE(dest_value.get())->tp_name); return -1; } + CMessage* target_message = + reinterpret_cast(dest_value.get()); + if (PyDict_Check(source_value.get())) { + if (InitAttributes(target_message, nullptr, source_value.get()) < + 0) { + return -1; + } + } else { + ScopedPyObjectPtr ok(PyObject_CallMethod( + dest_value.get(), "MergeFrom", "O", source_value.get())); + if (ok.get() == nullptr) { + return -1; + } + } } } else { - ScopedPyObjectPtr function_return; - function_return.reset( + ScopedPyObjectPtr ok( PyObject_CallMethod(map.get(), "update", "O", value)); - if (function_return.get() == nullptr) { + if (ok.get() == nullptr) { return -1; } } @@ -1111,33 +1126,29 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { return -1; } } - } else { - if (PyObject_TypeCheck(value, CMessage_Type)) { - ScopedPyObjectPtr merged(MergeFrom(cmessage, value)); - if (merged == nullptr) { - return -1; - } - } else { - if (descriptor->message_type()->well_known_type() != - Descriptor::WELLKNOWNTYPE_UNSPECIFIED && - PyObject_HasAttrString(reinterpret_cast(cmessage), - "_internal_assign")) { - AssureWritable(cmessage); - ScopedPyObjectPtr ok( - PyObject_CallMethod(reinterpret_cast(cmessage), - "_internal_assign", "O", value)); - if (ok.get() == nullptr) { - return -1; - } - } else { - PyErr_Format(PyExc_TypeError, - "Parameter to initialize message field must be " - "dict or instance of same class: expected %s got %s.", - std::string(descriptor->full_name()).c_str(), - Py_TYPE(value)->tp_name); - return -1; - } + } else if (PyObject_TypeCheck(value, CMessage_Type)) { + ScopedPyObjectPtr merged(MergeFrom(cmessage, value)); + if (merged == nullptr) { + return -1; } + } else if (descriptor->message_type()->well_known_type() != + Descriptor::WELLKNOWNTYPE_UNSPECIFIED && + PyObject_HasAttrString(reinterpret_cast(cmessage), + "_internal_assign")) { + AssureWritable(cmessage); + ScopedPyObjectPtr ok( + PyObject_CallMethod(reinterpret_cast(cmessage), + "_internal_assign", "O", value)); + if (ok.get() == nullptr) { + return -1; + } + } else { + PyErr_Format(PyExc_TypeError, + "Parameter to initialize message field must be " + "dict or instance of same class: expected %s got %s.", + std::string(descriptor->full_name()).c_str(), + Py_TYPE(value)->tp_name); + return -1; } } else { ScopedPyObjectPtr new_val; diff --git a/python/message.c b/python/message.c index 691ddcf5ac43a..3fe7480a6f60c 100644 --- a/python/message.c +++ b/python/message.c @@ -318,9 +318,14 @@ static bool PyUpb_Message_LookupName(PyUpb_Message* self, PyObject* py_name, static bool PyUpb_Message_InitMessageMapEntry(PyObject* dst, PyObject* src) { if (!src || !dst) return false; - PyObject* ok = PyObject_CallMethod(dst, "CopyFrom", "O", src); - if (!ok) return false; - Py_DECREF(ok); + if (PyDict_Check(src)) { + bool ok = PyUpb_Message_InitAttributes(dst, NULL, src) >= 0; + if (!ok) return false; + } else { + PyObject* ok = PyObject_CallMethod(dst, "CopyFrom", "O", src); + if (!ok) return false; + Py_DECREF(ok); + } return true; }