Skip to content

Commit

Permalink
Avoid type size explosion when expanding types (#17842)
Browse files Browse the repository at this point in the history
If TypedDict A has multiple items that refer to TypedDict B, don't
duplicate the types representing B during type expansion (or generally
when translating types). If TypedDicts are deeply nested, this could
result in lot of redundant type objects.

Example where this could matter (assume B is a big TypedDict):

```
class B(TypedDict):
    ...

class A(TypedDict):
    a: B
    b: B
    c: B
    ...
    z: B

```

Also deduplicate large unions. It's common to have aliases that are
defined as large unions, and again we want to avoid duplicating these
unions.

This may help with #17231, but this fix may not be sufficient.
  • Loading branch information
JukkaL authored Sep 27, 2024
1 parent 1995155 commit 26a77f9
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 8 deletions.
1 change: 1 addition & 0 deletions mypy/applytype.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def __init__(
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
seen_aliases: frozenset[TypeInfo] = frozenset(),
) -> None:
super().__init__()
self.poly_tvars = set(poly_tvars)
# This is a simplified version of TypeVarScope used during semantic analysis.
self.bound_tvars = bound_tvars
Expand Down
1 change: 1 addition & 0 deletions mypy/erasetype.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class TypeVarEraser(TypeTranslator):
"""Implementation of type erasure"""

def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None:
super().__init__()
self.erase_id = erase_id
self.replacement = replacement

Expand Down
19 changes: 17 additions & 2 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator):
variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value

def __init__(self, variables: Mapping[TypeVarId, Type]) -> None:
super().__init__()
self.variables = variables
self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {}

Expand Down Expand Up @@ -454,15 +455,25 @@ def visit_tuple_type(self, t: TupleType) -> Type:
return t.copy_modified(items=items, fallback=fallback)

def visit_typeddict_type(self, t: TypedDictType) -> Type:
if cached := self.get_cached(t):
return cached
fallback = t.fallback.accept(self)
assert isinstance(fallback, ProperType) and isinstance(fallback, Instance)
return t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback)
self.set_cached(t, result)
return result

def visit_literal_type(self, t: LiteralType) -> Type:
# TODO: Verify this implementation is correct
return t

def visit_union_type(self, t: UnionType) -> Type:
# Use cache to avoid O(n**2) or worse expansion of types during translation
# (only for large unions, since caching adds overhead)
use_cache = len(t.items) > 3
if use_cache and (cached := self.get_cached(t)):
return cached

expanded = self.expand_types(t.items)
# After substituting for type variables in t.items, some resulting types
# might be subtypes of others, however calling make_simplified_union()
Expand All @@ -475,7 +486,11 @@ def visit_union_type(self, t: UnionType) -> Type:
# otherwise a single item union of a type alias will break it. Note this should not
# cause infinite recursion since pathological aliases like A = Union[A, B] are
# banned at the semantic analysis level.
return get_proper_type(simplified)
result = get_proper_type(simplified)

if use_cache:
self.set_cached(t, result)
return result

def visit_partial_type(self, t: PartialType) -> Type:
return t
Expand Down
2 changes: 2 additions & 0 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,8 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
if isinstance(right, Instance):
return self._is_subtype(left.fallback, right)
elif isinstance(right, TypedDictType):
if left == right:
return True # Fast path
if not left.names_are_wider_than(right):
return False
for name, l, r in left.zip(right):
Expand Down
36 changes: 34 additions & 2 deletions mypy/type_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,26 @@ class TypeTranslator(TypeVisitor[Type]):
Subclass this and override some methods to implement a non-trivial
transformation.
We cache the results of certain translations to avoid
massively expanding the sizes of types.
"""

def __init__(self, cache: dict[Type, Type] | None = None) -> None:
# For deduplication of results
self.cache = cache

def get_cached(self, t: Type) -> Type | None:
if self.cache is None:
return None
return self.cache.get(t)

def set_cached(self, orig: Type, new: Type) -> None:
if self.cache is None:
# Minor optimization: construct lazily
self.cache = {}
self.cache[orig] = new

def visit_unbound_type(self, t: UnboundType) -> Type:
return t

Expand Down Expand Up @@ -251,28 +269,42 @@ def visit_tuple_type(self, t: TupleType) -> Type:
)

def visit_typeddict_type(self, t: TypedDictType) -> Type:
# Use cache to avoid O(n**2) or worse expansion of types during translation
if cached := self.get_cached(t):
return cached
items = {item_name: item_type.accept(self) for (item_name, item_type) in t.items.items()}
return TypedDictType(
result = TypedDictType(
items,
t.required_keys,
# TODO: This appears to be unsafe.
cast(Any, t.fallback.accept(self)),
t.line,
t.column,
)
self.set_cached(t, result)
return result

def visit_literal_type(self, t: LiteralType) -> Type:
fallback = t.fallback.accept(self)
assert isinstance(fallback, Instance) # type: ignore[misc]
return LiteralType(value=t.value, fallback=fallback, line=t.line, column=t.column)

def visit_union_type(self, t: UnionType) -> Type:
return UnionType(
# Use cache to avoid O(n**2) or worse expansion of types during translation
# (only for large unions, since caching adds overhead)
use_cache = len(t.items) > 3
if use_cache and (cached := self.get_cached(t)):
return cached

result = UnionType(
self.translate_types(t.items),
t.line,
t.column,
uses_pep604_syntax=t.uses_pep604_syntax,
)
if use_cache:
self.set_cached(t, result)
return result

def translate_types(self, types: Iterable[Type]) -> list[Type]:
return [t.accept(self) for t in types]
Expand Down
2 changes: 2 additions & 0 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2271,6 +2271,7 @@ def __init__(
lookup: Callable[[str, Context], SymbolTableNode | None],
scope: TypeVarLikeScope,
) -> None:
super().__init__()
self.seen_nodes = seen_nodes
self.lookup = lookup
self.scope = scope
Expand Down Expand Up @@ -2660,6 +2661,7 @@ class TypeVarDefaultTranslator(TrivialSyntheticTypeTranslator):
def __init__(
self, api: SemanticAnalyzerInterface, tvar_expr_name: str, context: Context
) -> None:
super().__init__()
self.api = api
self.tvar_expr_name = tvar_expr_name
self.context = context
Expand Down
13 changes: 9 additions & 4 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def _expand_once(self) -> Type:

def _partial_expansion(self, nothing_args: bool = False) -> tuple[ProperType, bool]:
# Private method mostly for debugging and testing.
unroller = UnrollAliasVisitor(set())
unroller = UnrollAliasVisitor(set(), {})
if nothing_args:
alias = self.copy_modified(args=[UninhabitedType()] * len(self.args))
else:
Expand Down Expand Up @@ -2586,7 +2586,8 @@ def __hash__(self) -> int:
def __eq__(self, other: object) -> bool:
if not isinstance(other, TypedDictType):
return NotImplemented

if self is other:
return True
return (
frozenset(self.items.keys()) == frozenset(other.items.keys())
and all(
Expand Down Expand Up @@ -3507,7 +3508,11 @@ def visit_type_list(self, t: TypeList) -> Type:


class UnrollAliasVisitor(TrivialSyntheticTypeTranslator):
def __init__(self, initial_aliases: set[TypeAliasType]) -> None:
def __init__(
self, initial_aliases: set[TypeAliasType], cache: dict[Type, Type] | None
) -> None:
assert cache is not None
super().__init__(cache)
self.recursed = False
self.initial_aliases = initial_aliases

Expand All @@ -3519,7 +3524,7 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type:
# A = Tuple[B, B]
# B = int
# will not be detected as recursive on the second encounter of B.
subvisitor = UnrollAliasVisitor(self.initial_aliases | {t})
subvisitor = UnrollAliasVisitor(self.initial_aliases | {t}, self.cache)
result = get_proper_type(t).accept(subvisitor)
if subvisitor.recursed:
self.recursed = True
Expand Down

0 comments on commit 26a77f9

Please sign in to comment.