Skip to content

Commit

Permalink
Python protobuf: kwargs initialization now allows dict() for proto ma…
Browse files Browse the repository at this point in the history
…p values.

It is now possible to initialize a field of type `map<string, Message>` with nested dict objects:
```
storage_pb2.Directory(
  name='/home/user',
  owner='user',
  content={
    '.bashrc': dict(    # <=== dict() allowed here
      size=1234,
      permissions='PRIVATE',
    ),
  },
)
```

PiperOrigin-RevId: 719437601
  • Loading branch information
protobuf-github-bot authored and copybara-github committed Jan 24, 2025
1 parent 39f13b0 commit 6235687
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 36 deletions.
4 changes: 4 additions & 0 deletions python/google/protobuf/internal/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,6 +2550,10 @@ def testMapMessageFieldConstruction(self):
msg2 = map_unittest_pb2.TestMap(
map_string_foreign_message=msg1.map_string_foreign_message)
self.assertEqual(42, msg2.map_string_foreign_message['test'].c)
msg3 = map_unittest_pb2.TestMap(
map_string_foreign_message={'test': dict(c=42)}
)
self.assertEqual(42, msg3.map_string_foreign_message['test'].c)

def testMapFieldRaisesCorrectError(self):
# Should raise a TypeError when given a non-iterable.
Expand Down
6 changes: 5 additions & 1 deletion python/google/protobuf/internal/python_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,11 @@ def init(self, **kwargs):
if _IsMapField(field):
if _IsMessageMapField(field):
for key in field_value:
field_copy[key].MergeFrom(field_value[key])
item_value = field_value[key]
if isinstance(item_value, dict):
field_copy[key].__init__(**item_value)
else:
field_copy[key].MergeFrom(item_value)
else:
field_copy.update(field_value)
else:
Expand Down
75 changes: 43 additions & 32 deletions python/google/protobuf/pyext/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1001,17 +1001,32 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
if (source_value.get() == nullptr || dest_value.get() == nullptr) {
return -1;
}
ScopedPyObjectPtr ok(PyObject_CallMethod(
dest_value.get(), "MergeFrom", "O", source_value.get()));
if (ok.get() == nullptr) {
if (!PyObject_TypeCheck(dest_value.get(), CMessage_Type)) {
PyErr_Format(PyExc_SystemError,
"Unexpectedly, a map of messages contains a "
"non-message value: %s",
Py_TYPE(dest_value.get())->tp_name);
return -1;
}
CMessage* target_message =
reinterpret_cast<CMessage*>(dest_value.get());
if (PyDict_Check(source_value.get())) {
if (InitAttributes(target_message, nullptr, source_value.get()) <
0) {
return -1;
}
} else {
ScopedPyObjectPtr ok(PyObject_CallMethod(
dest_value.get(), "MergeFrom", "O", source_value.get()));
if (ok.get() == nullptr) {
return -1;
}
}
}
} else {
ScopedPyObjectPtr function_return;
function_return.reset(
ScopedPyObjectPtr ok(
PyObject_CallMethod(map.get(), "update", "O", value));
if (function_return.get() == nullptr) {
if (ok.get() == nullptr) {
return -1;
}
}
Expand Down Expand Up @@ -1111,33 +1126,29 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
return -1;
}
}
} else {
if (PyObject_TypeCheck(value, CMessage_Type)) {
ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
if (merged == nullptr) {
return -1;
}
} else {
if (descriptor->message_type()->well_known_type() !=
Descriptor::WELLKNOWNTYPE_UNSPECIFIED &&
PyObject_HasAttrString(reinterpret_cast<PyObject*>(cmessage),
"_internal_assign")) {
AssureWritable(cmessage);
ScopedPyObjectPtr ok(
PyObject_CallMethod(reinterpret_cast<PyObject*>(cmessage),
"_internal_assign", "O", value));
if (ok.get() == nullptr) {
return -1;
}
} else {
PyErr_Format(PyExc_TypeError,
"Parameter to initialize message field must be "
"dict or instance of same class: expected %s got %s.",
std::string(descriptor->full_name()).c_str(),
Py_TYPE(value)->tp_name);
return -1;
}
} else if (PyObject_TypeCheck(value, CMessage_Type)) {
ScopedPyObjectPtr merged(MergeFrom(cmessage, value));
if (merged == nullptr) {
return -1;
}
} else if (descriptor->message_type()->well_known_type() !=
Descriptor::WELLKNOWNTYPE_UNSPECIFIED &&
PyObject_HasAttrString(reinterpret_cast<PyObject*>(cmessage),
"_internal_assign")) {
AssureWritable(cmessage);
ScopedPyObjectPtr ok(
PyObject_CallMethod(reinterpret_cast<PyObject*>(cmessage),
"_internal_assign", "O", value));
if (ok.get() == nullptr) {
return -1;
}
} else {
PyErr_Format(PyExc_TypeError,
"Parameter to initialize message field must be "
"dict or instance of same class: expected %s got %s.",
std::string(descriptor->full_name()).c_str(),
Py_TYPE(value)->tp_name);
return -1;
}
} else {
ScopedPyObjectPtr new_val;
Expand Down
11 changes: 8 additions & 3 deletions python/message.c
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,14 @@ static bool PyUpb_Message_LookupName(PyUpb_Message* self, PyObject* py_name,
static bool PyUpb_Message_InitMessageMapEntry(PyObject* dst, PyObject* src) {
if (!src || !dst) return false;

PyObject* ok = PyObject_CallMethod(dst, "CopyFrom", "O", src);
if (!ok) return false;
Py_DECREF(ok);
if (PyDict_Check(src)) {
bool ok = PyUpb_Message_InitAttributes(dst, NULL, src) >= 0;
if (!ok) return false;
} else {
PyObject* ok = PyObject_CallMethod(dst, "CopyFrom", "O", src);
if (!ok) return false;
Py_DECREF(ok);
}

return true;
}
Expand Down

0 comments on commit 6235687

Please sign in to comment.