Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid type size explosion when expanding types #17842

Merged
merged 1 commit into from
Sep 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def __init__(
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
seen_aliases: frozenset[TypeInfo] = frozenset(),
) -> None:
super().__init__()
self.poly_tvars = set(poly_tvars)
# This is a simplified version of TypeVarScope used during semantic analysis.
self.bound_tvars = bound_tvars
Expand Down
1 change: 1 addition & 0 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class TypeVarEraser(TypeTranslator):
"""Implementation of type erasure"""

def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None:
super().__init__()
self.erase_id = erase_id
self.replacement = replacement

Expand Down
19 changes: 17 additions & 2 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value

def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
super().__init__()
self.variables = variables
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}

Expand Down Expand Up @@ -454,15 +455,25 @@ def visit_tuple_type(self, t: TupleType) -> Type:
return t.copy_modified(items=items, fallback=fallback)

def visit_typeddict_type(self, t: TypedDictType) -> Type:
if cached := self.get_cached(t):
return cached
fallback = t.fallback.accept(self)
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
return t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
self.set_cached(t, result)
return result

def visit_literal_type(self, t: LiteralType) -> Type:
# TODO: Verify this implementation is correct
return t

def visit_union_type(self, t: UnionType) -> Type:
# Use cache to avoid O(n**2) or worse expansion of types during translation
# (only for large unions, since caching adds overhead)
use_cache = len(t.items) > 3
if use_cache and (cached := self.get_cached(t)):
return cached

expanded = self.expand_types(t.items)
# After substituting for type variables in t.items, some resulting types
# might be subtypes of others, however calling make_simplified_union()
Expand All @@ -475,7 +486,11 @@ def visit_union_type(self, t: UnionType) -> Type:
# otherwise a single item union of a type alias will break it. Note this should not
# cause infinite recursion since pathological aliases like A = Union[A, B] are
# banned at the semantic analysis level.
return get_proper_type(simplified)
result = get_proper_type(simplified)

if use_cache:
self.set_cached(t, result)
return result

def visit_partial_type(self, t: PartialType) -> Type:
return t
Expand Down
2 changes: 2 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,8 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
if isinstance(right, Instance):
return self._is_subtype(left.fallback, right)
elif isinstance(right, TypedDictType):
if left == right:
return True # Fast path
if not left.names_are_wider_than(right):
return False
for name, l, r in left.zip(right):
Expand Down
36 changes: 34 additions & 2 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,26 @@ class TypeTranslator(TypeVisitor[Type]):

Subclass this and override some methods to implement a non-trivial
transformation.

We cache the results of certain translations to avoid
massively expanding the sizes of types.
"""

def __init__(self, cache: dict[Type, Type] | None = None) -> None:
# For deduplication of results
self.cache = cache

def get_cached(self, t: Type) -> Type | None:
if self.cache is None:
return None
return self.cache.get(t)

def set_cached(self, orig: Type, new: Type) -> None:
if self.cache is None:
# Minor optimization: construct lazily
self.cache = {}
self.cache[orig] = new

def visit_unbound_type(self, t: UnboundType) -> Type:
return t

Expand Down Expand Up @@ -251,28 +269,42 @@ def visit_tuple_type(self, t: TupleType) -> Type:
)

def visit_typeddict_type(self, t: TypedDictType) -> Type:
# Use cache to avoid O(n**2) or worse expansion of types during translation
if cached := self.get_cached(t):
return cached
items = {item_name: item_type.accept(self) for (item_name, item_type) in t.items.items()}
return TypedDictType(
result = TypedDictType(
items,
t.required_keys,
# TODO: This appears to be unsafe.
cast(Any, t.fallback.accept(self)),
t.line,
t.column,
)
self.set_cached(t, result)
return result

def visit_literal_type(self, t: LiteralType) -> Type:
fallback = t.fallback.accept(self)
assert isinstance(fallback, Instance) # type: ignore[misc]
return LiteralType(value=t.value, fallback=fallback, line=t.line, column=t.column)

def visit_union_type(self, t: UnionType) -> Type:
return UnionType(
# Use cache to avoid O(n**2) or worse expansion of types during translation
# (only for large unions, since caching adds overhead)
use_cache = len(t.items) > 3
if use_cache and (cached := self.get_cached(t)):
return cached

result = UnionType(
self.translate_types(t.items),
t.line,
t.column,
uses_pep604_syntax=t.uses_pep604_syntax,
)
if use_cache:
self.set_cached(t, result)
return result

def translate_types(self, types: Iterable[Type]) -> list[Type]:
return [t.accept(self) for t in types]
Expand Down
2 changes: 2 additions & 0 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2271,6 +2271,7 @@ def __init__(
lookup: Callable[[str, Context], SymbolTableNode | None],
scope: TypeVarLikeScope,
) -> None:
super().__init__()
self.seen_nodes = seen_nodes
self.lookup = lookup
self.scope = scope
Expand Down Expand Up @@ -2660,6 +2661,7 @@ class TypeVarDefaultTranslator(TrivialSyntheticTypeTranslator):
def __init__(
self, api: SemanticAnalyzerInterface, tvar_expr_name: str, context: Context
) -> None:
super().__init__()
self.api = api
self.tvar_expr_name = tvar_expr_name
self.context = context
Expand Down
13 changes: 9 additions & 4 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _expand_once(self) -> Type:

def _partial_expansion(self, nothing_args: bool = False) -> tuple[ProperType, bool]:
# Private method mostly for debugging and testing.
unroller = UnrollAliasVisitor(set())
unroller = UnrollAliasVisitor(set(), {})
if nothing_args:
alias = self.copy_modified(args=[UninhabitedType()] * len(self.args))
else:
Expand Down Expand Up @@ -2586,7 +2586,8 @@ def __hash__(self) -> int:
def __eq__(self, other: object) -> bool:
if not isinstance(other, TypedDictType):
return NotImplemented

if self is other:
return True
return (
frozenset(self.items.keys()) == frozenset(other.items.keys())
and all(
Expand Down Expand Up @@ -3507,7 +3508,11 @@ def visit_type_list(self, t: TypeList) -> Type:


class UnrollAliasVisitor(TrivialSyntheticTypeTranslator):
def __init__(self, initial_aliases: set[TypeAliasType]) -> None:
def __init__(
self, initial_aliases: set[TypeAliasType], cache: dict[Type, Type] | None
) -> None:
assert cache is not None
super().__init__(cache)
self.recursed = False
self.initial_aliases = initial_aliases

Expand All @@ -3519,7 +3524,7 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# A = Tuple[B, B]
# B = int
# will not be detected as recursive on the second encounter of B.
subvisitor = UnrollAliasVisitor(self.initial_aliases | {t})
subvisitor = UnrollAliasVisitor(self.initial_aliases | {t}, self.cache)
result = get_proper_type(t).accept(subvisitor)
if subvisitor.recursed:
self.recursed = True
Expand Down
Loading