Skip to content

Commit

Permalink
fix: Change to Set vs FrozenSet and thread the same set through (#1125)
Browse files Browse the repository at this point in the history
Co-authored-by: Anthonios Partheniou <[email protected]>
  • Loading branch information
tmc and parthea authored Sep 7, 2023
1 parent 0982f3b commit 723efca
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 27 deletions.
11 changes: 9 additions & 2 deletions gapic/schema/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,7 @@ def proto(self) -> Proto:
if not self.file_to_generate:
return naive

visited_messages: Set[wrappers.MessageType] = set()
# Return a context-aware proto object.
return dataclasses.replace(
naive,
Expand All @@ -754,13 +755,19 @@ def proto(self) -> Proto:
for k, v in naive.all_enums.items()
),
all_messages=collections.OrderedDict(
(k, v.with_context(collisions=naive.names))
(k, v.with_context(
collisions=naive.names,
visited_messages=visited_messages,
))
for k, v in naive.all_messages.items()
),
services=collections.OrderedDict(
# Note: services bind to themselves because services get their
# own output files.
(k, v.with_context(collisions=v.names))
(k, v.with_context(
collisions=v.names,
visited_messages=visited_messages,
))
for k, v in naive.services.items()
),
meta=naive.meta.with_context(collisions=naive.names),
Expand Down
8 changes: 4 additions & 4 deletions gapic/schema/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import dataclasses
import re
from typing import FrozenSet, Tuple, Optional
from typing import FrozenSet, Set, Tuple, Optional

from google.protobuf import descriptor_pb2

Expand All @@ -54,7 +54,7 @@ class Address(BaseAddress):
api_naming: naming.Naming = dataclasses.field(
default_factory=naming.NewNaming,
)
collisions: FrozenSet[str] = dataclasses.field(default_factory=frozenset)
collisions: Set[str] = dataclasses.field(default_factory=set)

def __eq__(self, other) -> bool:
# We don't want to use api_naming or collisions to determine equality,
Expand Down Expand Up @@ -351,7 +351,7 @@ def resolve(self, selector: str) -> str:
return f'{".".join(self.package)}.{selector}'
return selector

def with_context(self, *, collisions: FrozenSet[str]) -> 'Address':
def with_context(self, *, collisions: Set[str]) -> 'Address':
"""Return a derivative of this address with the provided context.
This method is used to address naming collisions. The returned
Expand Down Expand Up @@ -390,7 +390,7 @@ def doc(self):
return '\n\n'.join(self.documentation.leading_detached_comments)
return ''

def with_context(self, *, collisions: FrozenSet[str]) -> 'Metadata':
def with_context(self, *, collisions: Set[str]) -> 'Metadata':
"""Return a derivative of this metadata with the provided context.
This method is used to address naming collisions. The returned
Expand Down
69 changes: 48 additions & 21 deletions gapic/schema/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,8 @@ def type(self) -> Union['MessageType', 'EnumType', 'PrimitiveType']:
def with_context(
self,
*,
collisions: FrozenSet[str],
visited_messages: FrozenSet["MessageType"],
collisions: Set[str],
visited_messages: Optional[Set["MessageType"]] = None,
) -> 'Field':
"""Return a derivative of this field with the provided context.
Expand All @@ -368,7 +368,7 @@ def with_context(
self,
message=self.message.with_context(
collisions=collisions,
skip_fields=self.message in visited_messages,
skip_fields=self.message in visited_messages if visited_messages else False,
visited_messages=visited_messages,
) if self.message else None,
enum=self.enum.with_context(collisions=collisions)
Expand Down Expand Up @@ -631,7 +631,7 @@ def path_regex_str(self) -> str:
return parsing_regex_str

def get_field(self, *field_path: str,
collisions: FrozenSet[str] = frozenset()) -> Field:
collisions: Optional[Set[str]] = None) -> Field:
"""Return a field arbitrarily deep in this message's structure.
This method recursively traverses the message tree to return the
Expand Down Expand Up @@ -672,7 +672,7 @@ def get_field(self, *field_path: str,
if len(field_path) == 1:
return cursor.with_context(
collisions=collisions,
visited_messages=frozenset({self}),
visited_messages=set({self}),
)

# Quick check: If cursor is a repeated field, then raise an exception.
Expand All @@ -698,9 +698,9 @@ def get_field(self, *field_path: str,
return cursor.message.get_field(*field_path[1:], collisions=collisions)

def with_context(self, *,
collisions: FrozenSet[str],
collisions: Set[str],
skip_fields: bool = False,
visited_messages: FrozenSet["MessageType"] = frozenset(),
visited_messages: Optional[Set["MessageType"]] = None,
) -> 'MessageType':
"""Return a derivative of this message with the provided context.
Expand All @@ -712,7 +712,8 @@ def with_context(self, *,
underlying fields. This provides for an "exit" in the case of circular
references.
"""
visited_messages = visited_messages | {self}
visited_messages = visited_messages or set()
visited_messages.add(self)
return dataclasses.replace(
self,
fields={
Expand Down Expand Up @@ -777,7 +778,7 @@ def ident(self) -> metadata.Address:
"""Return the identifier data to be used in templates."""
return self.meta.address

def with_context(self, *, collisions: FrozenSet[str]) -> 'EnumType':
def with_context(self, *, collisions: Set[str]) -> 'EnumType':
"""Return a derivative of this enum with the provided context.
This method is used to address naming collisions. The returned
Expand Down Expand Up @@ -871,7 +872,10 @@ class ExtendedOperationInfo:
request_type: MessageType
operation_type: MessageType

def with_context(self, *, collisions: FrozenSet[str]) -> 'ExtendedOperationInfo':
def with_context(self, *,
collisions: Set[str],
visited_messages: Optional[Set["MessageType"]] = None,
) -> 'ExtendedOperationInfo':
"""Return a derivative of this OperationInfo with the provided context.
This method is used to address naming collisions. The returned
Expand All @@ -881,10 +885,12 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'ExtendedOperationInfo'
return self if not collisions else dataclasses.replace(
self,
request_type=self.request_type.with_context(
collisions=collisions
collisions=collisions,
visited_messages=visited_messages,
),
operation_type=self.operation_type.with_context(
collisions=collisions,
visited_messages=visited_messages,
),
)

Expand All @@ -895,7 +901,10 @@ class OperationInfo:
response_type: MessageType
metadata_type: MessageType

def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo':
def with_context(self, *,
collisions: Set[str],
visited_messages: Optional[Set["MessageType"]] = None,
) -> 'OperationInfo':
"""Return a derivative of this OperationInfo with the provided context.
This method is used to address naming collisions. The returned
Expand All @@ -905,10 +914,12 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'OperationInfo':
return dataclasses.replace(
self,
response_type=self.response_type.with_context(
collisions=collisions
collisions=collisions,
visited_messages=visited_messages,
),
metadata_type=self.metadata_type.with_context(
collisions=collisions
collisions=collisions,
visited_messages=visited_messages,
),
)

Expand Down Expand Up @@ -1533,7 +1544,10 @@ def void(self) -> bool:
"""Return True if this method has no return value, False otherwise."""
return self.output.ident.proto == 'google.protobuf.Empty'

def with_context(self, *, collisions: FrozenSet[str]) -> 'Method':
def with_context(self, *,
collisions: Set[str],
visited_messages: Optional[Set["MessageType"]] = None,
) -> 'Method':
"""Return a derivative of this method with the provided context.
This method is used to address naming collisions. The returned
Expand All @@ -1543,21 +1557,29 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Method':
maybe_lro = None
if self.lro:
maybe_lro = self.lro.with_context(
collisions=collisions
collisions=collisions,
visited_messages=visited_messages,
) if collisions else self.lro

maybe_extended_lro = (
self.extended_lro.with_context(
collisions=collisions
collisions=collisions,
visited_messages=visited_messages,
) if self.extended_lro else None
)

return dataclasses.replace(
self,
lro=maybe_lro,
extended_lro=maybe_extended_lro,
input=self.input.with_context(collisions=collisions),
output=self.output.with_context(collisions=collisions),
input=self.input.with_context(
collisions=collisions,
visited_messages=visited_messages,
),
output=self.output.with_context(
collisions=collisions,
visited_messages=visited_messages,
),
meta=self.meta.with_context(collisions=collisions),
)

Expand Down Expand Up @@ -1842,7 +1864,10 @@ def operation_polling_method(self) -> Optional[Method]:
None
)

def with_context(self, *, collisions: FrozenSet[str]) -> 'Service':
def with_context(self, *,
collisions: Set[str],
visited_messages: Optional[Set["MessageType"]] = None,
) -> 'Service':
"""Return a derivative of this service with the provided context.
This method is used to address naming collisions. The returned
Expand All @@ -1855,7 +1880,9 @@ def with_context(self, *, collisions: FrozenSet[str]) -> 'Service':
k: v.with_context(
# A method's flattened fields create additional names
# that may conflict with module imports.
collisions=collisions | frozenset(v.flattened_fields.keys()))
collisions=collisions | set(v.flattened_fields.keys()),
visited_messages=visited_messages,
)
for k, v in self.methods.items()
},
meta=self.meta.with_context(collisions=collisions),
Expand Down

0 comments on commit 723efca

Please sign in to comment.