From 26a77f9d373a49fa3796082fb55a54517a876364 Mon Sep 17 00:00:00 2001 From: Jukka Lehtosalo Date: Fri, 27 Sep 2024 17:03:59 +0100 Subject: [PATCH] Avoid type size explosion when expanding types (#17842) If TypedDict A has multiple items that refer to TypedDict B, don't duplicate the types representing B during type expansion (or generally when translating types). If TypedDicts are deeply nested, this could result in a lot of redundant type objects. Example where this could matter (assume B is a big TypedDict): ``` class B(TypedDict): ... class A(TypedDict): a: B b: B c: B ... z: B ``` Also deduplicate large unions. It's common to have aliases that are defined as large unions, and again we want to avoid duplicating these unions. This may help with #17231, but this fix may not be sufficient. --- mypy/applytype.py | 1 + mypy/erasetype.py | 1 + mypy/expandtype.py | 19 +++++++++++++++++-- mypy/subtypes.py | 2 ++ mypy/type_visitor.py | 36 ++++++++++++++++++++++++++++++++++-- mypy/typeanal.py | 2 ++ mypy/types.py | 13 +++++++++---- 7 files changed, 66 insertions(+), 8 deletions(-) diff --git a/mypy/applytype.py b/mypy/applytype.py index 783748cd8a5eb..e88947cc64304 100644 --- a/mypy/applytype.py +++ b/mypy/applytype.py @@ -215,6 +215,7 @@ def __init__( bound_tvars: frozenset[TypeVarLikeType] = frozenset(), seen_aliases: frozenset[TypeInfo] = frozenset(), ) -> None: + super().__init__() self.poly_tvars = set(poly_tvars) # This is a simplified version of TypeVarScope used during semantic analysis. 
self.bound_tvars = bound_tvars diff --git a/mypy/erasetype.py b/mypy/erasetype.py index 5d95b221af156..222e7f2a6d7ab 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -161,6 +161,7 @@ class TypeVarEraser(TypeTranslator): """Implementation of type erasure""" def __init__(self, erase_id: Callable[[TypeVarId], bool], replacement: Type) -> None: + super().__init__() self.erase_id = erase_id self.replacement = replacement diff --git a/mypy/expandtype.py b/mypy/expandtype.py index 9336be54437b9..b2040ec074c32 100644 --- a/mypy/expandtype.py +++ b/mypy/expandtype.py @@ -179,6 +179,7 @@ class ExpandTypeVisitor(TrivialSyntheticTypeTranslator): variables: Mapping[TypeVarId, Type] # TypeVar id -> TypeVar value def __init__(self, variables: Mapping[TypeVarId, Type]) -> None: + super().__init__() self.variables = variables self.recursive_tvar_guard: dict[TypeVarId, Type | None] = {} @@ -454,15 +455,25 @@ def visit_tuple_type(self, t: TupleType) -> Type: return t.copy_modified(items=items, fallback=fallback) def visit_typeddict_type(self, t: TypedDictType) -> Type: + if cached := self.get_cached(t): + return cached fallback = t.fallback.accept(self) assert isinstance(fallback, ProperType) and isinstance(fallback, Instance) - return t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback) + result = t.copy_modified(item_types=self.expand_types(t.items.values()), fallback=fallback) + self.set_cached(t, result) + return result def visit_literal_type(self, t: LiteralType) -> Type: # TODO: Verify this implementation is correct return t def visit_union_type(self, t: UnionType) -> Type: + # Use cache to avoid O(n**2) or worse expansion of types during translation + # (only for large unions, since caching adds overhead) + use_cache = len(t.items) > 3 + if use_cache and (cached := self.get_cached(t)): + return cached + expanded = self.expand_types(t.items) # After substituting for type variables in t.items, some resulting types # might be subtypes of 
others, however calling make_simplified_union() @@ -475,7 +486,11 @@ def visit_union_type(self, t: UnionType) -> Type: # otherwise a single item union of a type alias will break it. Note this should not # cause infinite recursion since pathological aliases like A = Union[A, B] are # banned at the semantic analysis level. - return get_proper_type(simplified) + result = get_proper_type(simplified) + + if use_cache: + self.set_cached(t, result) + return result def visit_partial_type(self, t: PartialType) -> Type: return t diff --git a/mypy/subtypes.py b/mypy/subtypes.py index c76b3569fdd4c..608d098791a9c 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -886,6 +886,8 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: if isinstance(right, Instance): return self._is_subtype(left.fallback, right) elif isinstance(right, TypedDictType): + if left == right: + return True # Fast path if not left.names_are_wider_than(right): return False for name, l, r in left.zip(right): diff --git a/mypy/type_visitor.py b/mypy/type_visitor.py index 59e13d12485cc..38e4c5ba0d01c 100644 --- a/mypy/type_visitor.py +++ b/mypy/type_visitor.py @@ -181,8 +181,26 @@ class TypeTranslator(TypeVisitor[Type]): Subclass this and override some methods to implement a non-trivial transformation. + + We cache the results of certain translations to avoid + massively expanding the sizes of types. 
""" + def __init__(self, cache: dict[Type, Type] | None = None) -> None: + # For deduplication of results + self.cache = cache + + def get_cached(self, t: Type) -> Type | None: + if self.cache is None: + return None + return self.cache.get(t) + + def set_cached(self, orig: Type, new: Type) -> None: + if self.cache is None: + # Minor optimization: construct lazily + self.cache = {} + self.cache[orig] = new + def visit_unbound_type(self, t: UnboundType) -> Type: return t @@ -251,8 +269,11 @@ def visit_tuple_type(self, t: TupleType) -> Type: ) def visit_typeddict_type(self, t: TypedDictType) -> Type: + # Use cache to avoid O(n**2) or worse expansion of types during translation + if cached := self.get_cached(t): + return cached items = {item_name: item_type.accept(self) for (item_name, item_type) in t.items.items()} - return TypedDictType( + result = TypedDictType( items, t.required_keys, # TODO: This appears to be unsafe. @@ -260,6 +281,8 @@ def visit_typeddict_type(self, t: TypedDictType) -> Type: t.line, t.column, ) + self.set_cached(t, result) + return result def visit_literal_type(self, t: LiteralType) -> Type: fallback = t.fallback.accept(self) @@ -267,12 +290,21 @@ def visit_literal_type(self, t: LiteralType) -> Type: return LiteralType(value=t.value, fallback=fallback, line=t.line, column=t.column) def visit_union_type(self, t: UnionType) -> Type: - return UnionType( + # Use cache to avoid O(n**2) or worse expansion of types during translation + # (only for large unions, since caching adds overhead) + use_cache = len(t.items) > 3 + if use_cache and (cached := self.get_cached(t)): + return cached + + result = UnionType( self.translate_types(t.items), t.line, t.column, uses_pep604_syntax=t.uses_pep604_syntax, ) + if use_cache: + self.set_cached(t, result) + return result def translate_types(self, types: Iterable[Type]) -> list[Type]: return [t.accept(self) for t in types] diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 274b4b893a98a..6c94390c23dce 100644 
--- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -2271,6 +2271,7 @@ def __init__( lookup: Callable[[str, Context], SymbolTableNode | None], scope: TypeVarLikeScope, ) -> None: + super().__init__() self.seen_nodes = seen_nodes self.lookup = lookup self.scope = scope @@ -2660,6 +2661,7 @@ class TypeVarDefaultTranslator(TrivialSyntheticTypeTranslator): def __init__( self, api: SemanticAnalyzerInterface, tvar_expr_name: str, context: Context ) -> None: + super().__init__() self.api = api self.tvar_expr_name = tvar_expr_name self.context = context diff --git a/mypy/types.py b/mypy/types.py index 78244d0f9cf4c..b1e57b2f6a861 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -357,7 +357,7 @@ def _expand_once(self) -> Type: def _partial_expansion(self, nothing_args: bool = False) -> tuple[ProperType, bool]: # Private method mostly for debugging and testing. - unroller = UnrollAliasVisitor(set()) + unroller = UnrollAliasVisitor(set(), {}) if nothing_args: alias = self.copy_modified(args=[UninhabitedType()] * len(self.args)) else: @@ -2586,7 +2586,8 @@ def __hash__(self) -> int: def __eq__(self, other: object) -> bool: if not isinstance(other, TypedDictType): return NotImplemented - + if self is other: + return True return ( frozenset(self.items.keys()) == frozenset(other.items.keys()) and all( @@ -3507,7 +3508,11 @@ def visit_type_list(self, t: TypeList) -> Type: class UnrollAliasVisitor(TrivialSyntheticTypeTranslator): - def __init__(self, initial_aliases: set[TypeAliasType]) -> None: + def __init__( + self, initial_aliases: set[TypeAliasType], cache: dict[Type, Type] | None + ) -> None: + assert cache is not None + super().__init__(cache) self.recursed = False self.initial_aliases = initial_aliases @@ -3519,7 +3524,7 @@ def visit_type_alias_type(self, t: TypeAliasType) -> Type: # A = Tuple[B, B] # B = int # will not be detected as recursive on the second encounter of B. 
- subvisitor = UnrollAliasVisitor(self.initial_aliases | {t}) + subvisitor = UnrollAliasVisitor(self.initial_aliases | {t}, self.cache) result = get_proper_type(t).accept(subvisitor) if subvisitor.recursed: self.recursed = True