Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: HTTP body field messages with enums or recursive fields #1201

Merged
merged 1 commit into from
Feb 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 37 additions & 20 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,29 +97,42 @@ def map(self) -> bool:

@utils.cached_property
def mock_value_original_type(self) -> Union[bool, str, bytes, int, float, Dict[str, Any], List[Any], None]:
# Return messages as dicts and let the message ctor handle the conversion.
if self.message:
if self.map:
# Not worth the hassle, just return an empty map.
return {}
visited_messages = set()

msg_dict = {
f.name: f.mock_value_original_type
for f in self.message.fields.values()
}
def recursive_mock_original_type(field):
if field.message:
# Return messages as dicts and let the message ctor handle the conversion.
if field.message in visited_messages:
return {}

return [msg_dict] if self.repeated else msg_dict
visited_messages.add(field.message)
if field.map:
# Not worth the hassle, just return an empty map.
return {}

answer = self.primitive_mock() or None
msg_dict = {
f.name: recursive_mock_original_type(f)
for f in field.message.fields.values()
}

# If this is a repeated field, then the mock answer should
# be a list.
if self.repeated:
first_item = self.primitive_mock(suffix=1) or None
second_item = self.primitive_mock(suffix=2) or None
answer = [first_item, second_item]
return [msg_dict] if field.repeated else msg_dict

return answer
if field.enum:
# First Truthy value, fallback to the first value
return next((v for v in field.type.values if v.number), field.type.values[0]).number

answer = field.primitive_mock() or None

# If this is a repeated field, then the mock answer should
# be a list.
if field.repeated:
first_item = field.primitive_mock(suffix=1) or None
second_item = field.primitive_mock(suffix=2) or None
answer = [first_item, second_item]

return answer

return recursive_mock_original_type(self)

@utils.cached_property
def mock_value(self) -> str:
Expand Down Expand Up @@ -887,8 +900,12 @@ class HttpRule:
def path_fields(self, method: "Method") -> List[Tuple[Field, str, str]]:
"""return list of (name, template) tuples extracted from uri."""
input = method.input
return [(input.get_field(*match.group("name").split(".")), match.group("name"), match.group("template"))
for match in path_template._VARIABLE_RE.finditer(self.uri)]
return [
(input.get_field(*match.group("name").split(".")),
match.group("name"), match.group("template"))
for match in path_template._VARIABLE_RE.finditer(self.uri)
if match.group("name")
]

def sample_request(self, method: "Method") -> Dict[str, Any]:
"""return json dict for sample request matching the uri template."""
Expand Down
2 changes: 1 addition & 1 deletion test_utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def make_enum(
name: str,
package: str = 'foo.bar.v1',
module: str = 'baz',
values: typing.Tuple[str, int] = (),
values: typing.Sequence[typing.Tuple[str, int]] = (),
meta: metadata.Metadata = None,
options: desc.EnumOptions = None,
) -> wrappers.EnumType:
Expand Down
48 changes: 48 additions & 0 deletions tests/fragments/test_non_primitive_body.proto
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ service SmallCompute {
post: "/computation/v1/first_name/{first_name}/last_name/{last_name}"
};
};

rpc EnumBody(EnumBodyRequest) returns (EnumBodyResponse) {
option (google.api.http) = {
body: "resource"
post: "/enum_body/v1/names/{name}"
};
}

rpc RecursiveBody(RecursiveBodyRequest) returns (RecursiveBodyResponse) {
option (google.api.http) = {
body: "resource"
post: "/recursive_body/v1/names/{name}"
};
}
}

message SerialNumber {
Expand All @@ -50,4 +64,38 @@ message MethodRequest {

message MethodResponse {
string name = 1;
}

message EnumBodyRequest {
message Resource{
enum Ordering {
UNKNOWN = 0;
CHRONOLOGICAL = 1;
ALPHABETICAL = 2;
DIFFICULTY = 3;
}

Ordering ordering = 1;
}

string name = 1;
Resource resource = 2;
}

message EnumBodyResponse {
string data = 1;
}

message RecursiveBodyRequest {
message Resource {
int32 depth = 1;
Resource child_resource = 2;
}

string name = 1;
Resource resource = 2;
}

message RecursiveBodyResponse {
string data = 1;
}
40 changes: 37 additions & 3 deletions tests/unit/schema/wrappers/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from test_utils.test_utils import (
make_field,
make_message,
make_enum,
)


Expand Down Expand Up @@ -343,7 +344,41 @@ def test_mock_value_original_type_message():
assert entry_field.mock_value_original_type == {}


def test_mock_value_recursive():
def test_mock_value_original_type_enum():
mollusc_field = make_field(
name="class",
enum=make_enum(
name="Class",
values=[
("UNKNOWN", 0),
("GASTROPOD", 1),
("BIVALVE", 2),
("CEPHALOPOD", 3),
],
),
)

assert mollusc_field.mock_value_original_type == 1

empty_field = make_field(
name="empty",
enum=make_enum(
name="Empty",
values=[("UNKNOWN", 0)],
),
)

assert empty_field.mock_value_original_type == 0


@pytest.mark.parametrize(
"mock_method,expected",
[
("mock_value", "ac_turtle.Turtle(turtle=ac_turtle.Turtle(turtle=turtle.Turtle(turtle=None)))"),
("mock_value_original_type", {"turtle": {}}),
],
)
def test_mock_value_recursive(mock_method, expected):
# The elaborate setup is an unfortunate requirement.
file_pb = descriptor_pb2.FileDescriptorProto(
name="turtle.proto",
Expand All @@ -367,8 +402,7 @@ def test_mock_value_recursive():
turtle_field = my_api.messages["animalia.chordata.v2.Turtle"].fields["turtle"]

# If not handled properly, this will run forever and eventually OOM.
actual = turtle_field.mock_value
expected = "ac_turtle.Turtle(turtle=ac_turtle.Turtle(turtle=turtle.Turtle(turtle=None)))"
actual = getattr(turtle_field, mock_method)
assert actual == expected


Expand Down