diff --git a/api_core/google/api_core/protobuf_helpers.py b/api_core/google/api_core/protobuf_helpers.py index 6031ff0e476f..fbe1a82db4d6 100644 --- a/api_core/google/api_core/protobuf_helpers.py +++ b/api_core/google/api_core/protobuf_helpers.py @@ -19,6 +19,8 @@ from google.protobuf.message import Message +_SENTINEL = object() + def from_any_pb(pb_type, any_pb): """Converts an ``Any`` protobuf to the specified message type. @@ -44,11 +46,13 @@ def from_any_pb(pb_type, any_pb): def check_oneof(**kwargs): - """Raise ValueError if more than one keyword argument is not none. + """Raise ValueError if more than one keyword argument is not ``None``. + Args: kwargs (dict): The keyword arguments sent to the function. + Raises: - ValueError: If more than one entry in kwargs is not none. + ValueError: If more than one entry in ``kwargs`` is not ``None``. """ # Sanity check: If no keyword arguments were sent, this is fine. if not kwargs: @@ -62,10 +66,12 @@ def check_oneof(**kwargs): def get_messages(module): - """Return a dictionary of message names and objects. + """Discovers all protobuf Message classes in a given import module. + Args: - module (module): A Python module; dir() will be run against this + module (module): A Python module; :func:`dir` will be run against this module to find Message subclasses. + Returns: dict[str, Message]: A dictionary with the Message class names as keys, and the Message subclasses themselves as values. @@ -76,3 +82,168 @@ def get_messages(module): if inspect.isclass(candidate) and issubclass(candidate, Message): answer[name] = candidate return answer + + +def _resolve_subkeys(key, separator='.'): + """Resolve a potentially nested key. + + If the key contains the ``separator`` (e.g. ``.``) then the key will be + split on the first instance of the subkey:: + + >>> _resolve_subkeys('a.b.c') + ('a', 'b.c') + >>> _resolve_subkeys('d|e|f', separator='|') + ('d', 'e|f') + + If not, the subkey will be :data:`None`:: + + >>> _resolve_subkeys('foo') + ('foo', None) + + Args: + key (str): A string that may or may not contain the separator. + separator (str): The namespace separator. Defaults to `.`. + + Returns: + Tuple[str, str]: The key and subkey(s). + """ + parts = key.split(separator, 1) + + if len(parts) > 1: + return parts + else: + return parts[0], None + + +def get(msg_or_dict, key, default=_SENTINEL): + """Retrieve a key's value from a protobuf Message or dictionary. + + Args: + mdg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the + object. + key (str): The key to retrieve from the object. + default (Any): If the key is not present on the object, and a default + is set, returns that default instead. A type-appropriate falsy + default is generally recommended, as protobuf messages almost + always have default values for unset values and it is not always + possible to tell the difference between a falsy value and an + unset one. If no default is set then :class:`KeyError` will be + raised if the key is not present in the object. + + Returns: + Any: The return value from the underlying Message or dict. + + Raises: + KeyError: If the key is not found. Note that, for unset values, + messages and dictionaries may not have consistent behavior. + TypeError: If ``msg_or_dict`` is not a Message or Mapping. + """ + # We may need to get a nested key. Resolve this. + key, subkey = _resolve_subkeys(key) + + # Attempt to get the value from the two types of objects we know about. + # If we get something else, complain. + if isinstance(msg_or_dict, Message): + answer = getattr(msg_or_dict, key, default) + elif isinstance(msg_or_dict, collections.Mapping): + answer = msg_or_dict.get(key, default) + else: + raise TypeError( + 'get() expected a dict or protobuf message, got {!r}.'.format( + type(msg_or_dict))) + + # If the object we got back is our sentinel, raise KeyError; this is + # a "not found" case. + if answer is _SENTINEL: + raise KeyError(key) + + # If a subkey exists, call this method recursively against the answer. + if subkey is not None and answer is not default: + return get(answer, subkey, default=default) + + return answer + + +def _set_field_on_message(msg, key, value): + """Set helper for protobuf Messages.""" + # Attempt to set the value on the types of objects we know how to deal + # with. + if isinstance(value, (collections.MutableSequence, tuple)): + # Clear the existing repeated protobuf message of any elements + # currently inside it. + while getattr(msg, key): + getattr(msg, key).pop() + + # Write our new elements to the repeated field. + for item in value: + if isinstance(item, collections.Mapping): + getattr(msg, key).add(**item) + else: + # protobuf's RepeatedCompositeContainer doesn't support + # append. + getattr(msg, key).extend([item]) + elif isinstance(value, collections.Mapping): + # Assign the dictionary values to the protobuf message. + for item_key, item_value in value.items(): + set(getattr(msg, key), item_key, item_value) + elif isinstance(value, Message): + getattr(msg, key).CopyFrom(value) + else: + setattr(msg, key, value) + + +def set(msg_or_dict, key, value): + """Set a key's value on a protobuf Message or dictionary. + + Args: + msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the + object. + key (str): The key to set. + value (Any): The value to set. + + Raises: + TypeError: If ``msg_or_dict`` is not a Message or dictionary. + """ + # Sanity check: Is our target object valid? + if not isinstance(msg_or_dict, (collections.MutableMapping, Message)): + raise TypeError( + 'set() expected a dict or protobuf message, got {!r}.'.format( + type(msg_or_dict))) + + # We may be setting a nested key. Resolve this. + basekey, subkey = _resolve_subkeys(key) + + # If a subkey exists, then get that object and call this method + # recursively against it using the subkey. + if subkey is not None: + if isinstance(msg_or_dict, collections.MutableMapping): + msg_or_dict.setdefault(basekey, {}) + set(get(msg_or_dict, basekey), subkey, value) + return + + if isinstance(msg_or_dict, collections.MutableMapping): + msg_or_dict[key] = value + else: + _set_field_on_message(msg_or_dict, key, value) + + +def setdefault(msg_or_dict, key, value): + """Set the key on a protobuf Message or dictionary to a given value if the + current value is falsy. + + Because protobuf Messages do not distinguish between unset values and + falsy ones particularly well (by design), this method treats any falsy + value (e.g. 0, empty list) as a target to be overwritten, on both Messages + and dictionaries. + + Args: + msg_or_dict (Union[~google.protobuf.message.Message, Mapping]): the + object. + key (str): The key on the object in question. + value (Any): The value to set. + + Raises: + TypeError: If ``msg_or_dict`` is not a Message or dictionary. + """ + if not get(msg_or_dict, key, default=None): + set(msg_or_dict, key, value) diff --git a/api_core/tests/unit/test_protobuf_helpers.py b/api_core/tests/unit/test_protobuf_helpers.py index b9aca76a9bcd..8f86aa4401d5 100644 --- a/api_core/tests/unit/test_protobuf_helpers.py +++ b/api_core/tests/unit/test_protobuf_helpers.py @@ -14,8 +14,11 @@ import pytest +from google.api import http_pb2 from google.api_core import protobuf_helpers +from google.longrunning import operations_pb2 from google.protobuf import any_pb2 +from google.protobuf import timestamp_pb2 from google.protobuf.message import Message from google.type import date_pb2 from google.type import timeofday_pb2 @@ -65,3 +68,165 @@ def test_get_messages(): # Ensure that no non-Message objects were exported. for value in answer.values(): assert issubclass(value, Message) + + +def test_get_dict_absent(): + with pytest.raises(KeyError): + assert protobuf_helpers.get({}, 'foo') + + +def test_get_dict_present(): + assert protobuf_helpers.get({'foo': 'bar'}, 'foo') == 'bar' + + +def test_get_dict_default(): + assert protobuf_helpers.get({}, 'foo', default='bar') == 'bar' + + +def test_get_dict_nested(): + assert protobuf_helpers.get({'foo': {'bar': 'baz'}}, 'foo.bar') == 'baz' + + +def test_get_dict_nested_default(): + assert protobuf_helpers.get({}, 'foo.baz', default='bacon') == 'bacon' + assert ( + protobuf_helpers.get({'foo': {}}, 'foo.baz', default='bacon') == + 'bacon') + + +def test_get_msg_sentinel(): + msg = timestamp_pb2.Timestamp() + with pytest.raises(KeyError): + assert protobuf_helpers.get(msg, 'foo') + + +def test_get_msg_present(): + msg = timestamp_pb2.Timestamp(seconds=42) + assert protobuf_helpers.get(msg, 'seconds') == 42 + + +def test_get_msg_default(): + msg = timestamp_pb2.Timestamp() + assert protobuf_helpers.get(msg, 'foo', default='bar') == 'bar' + + +def test_invalid_object(): + with pytest.raises(TypeError): + protobuf_helpers.get(object(), 'foo', 'bar') + + +def test_set_dict(): + mapping = {} + protobuf_helpers.set(mapping, 'foo', 'bar') + assert mapping == {'foo': 'bar'} + + +def test_set_msg(): + msg = timestamp_pb2.Timestamp() + protobuf_helpers.set(msg, 'seconds', 42) + assert msg.seconds == 42 + + +def test_set_dict_nested(): + mapping = {} + protobuf_helpers.set(mapping, 'foo.bar', 'baz') + assert mapping == {'foo': {'bar': 'baz'}} + + +def test_set_invalid_object(): + with pytest.raises(TypeError): + protobuf_helpers.set(object(), 'foo', 'bar') + + +def test_set_list(): + list_ops_response = operations_pb2.ListOperationsResponse() + + protobuf_helpers.set(list_ops_response, 'operations', [ + {'name': 'foo'}, + operations_pb2.Operation(name='bar'), + ]) + + assert len(list_ops_response.operations) == 2 + + for operation in list_ops_response.operations: + assert isinstance(operation, operations_pb2.Operation) + + assert list_ops_response.operations[0].name == 'foo' + assert list_ops_response.operations[1].name == 'bar' + + +def test_set_list_clear_existing(): + list_ops_response = operations_pb2.ListOperationsResponse( + operations=[{'name': 'baz'}], + ) + + protobuf_helpers.set(list_ops_response, 'operations', [ + {'name': 'foo'}, + operations_pb2.Operation(name='bar'), + ]) + + assert len(list_ops_response.operations) == 2 + for operation in list_ops_response.operations: + assert isinstance(operation, operations_pb2.Operation) + assert list_ops_response.operations[0].name == 'foo' + assert list_ops_response.operations[1].name == 'bar' + + +def test_set_msg_with_msg_field(): + rule = http_pb2.HttpRule() + pattern = http_pb2.CustomHttpPattern(kind='foo', path='bar') + + protobuf_helpers.set(rule, 'custom', pattern) + + assert rule.custom.kind == 'foo' + assert rule.custom.path == 'bar' + + +def test_set_msg_with_dict_field(): + rule = http_pb2.HttpRule() + pattern = {'kind': 'foo', 'path': 'bar'} + + protobuf_helpers.set(rule, 'custom', pattern) + + assert rule.custom.kind == 'foo' + assert rule.custom.path == 'bar' + + +def test_set_msg_nested_key(): + rule = http_pb2.HttpRule( + custom=http_pb2.CustomHttpPattern(kind='foo', path='bar')) + + protobuf_helpers.set(rule, 'custom.kind', 'baz') + + assert rule.custom.kind == 'baz' + assert rule.custom.path == 'bar' + + +def test_setdefault_dict_unset(): + mapping = {} + protobuf_helpers.setdefault(mapping, 'foo', 'bar') + assert mapping == {'foo': 'bar'} + + +def test_setdefault_dict_falsy(): + mapping = {'foo': None} + protobuf_helpers.setdefault(mapping, 'foo', 'bar') + assert mapping == {'foo': 'bar'} + + +def test_setdefault_dict_truthy(): + mapping = {'foo': 'bar'} + protobuf_helpers.setdefault(mapping, 'foo', 'baz') + assert mapping == {'foo': 'bar'} + + +def test_setdefault_pb2_falsy(): + operation = operations_pb2.Operation() + protobuf_helpers.setdefault(operation, 'name', 'foo') + assert operation.name == 'foo' + + +def test_setdefault_pb2_truthy(): + operation = operations_pb2.Operation(name='bar') + protobuf_helpers.setdefault(operation, 'name', 'foo') + assert operation.name == 'bar'