Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up finding function type variables #16562

Merged
merged 2 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,9 +224,9 @@
from mypy.tvar_scope import TypeVarLikeScope
from mypy.typeanal import (
SELF_TYPE_NAMES,
FindTypeVarVisitor,
TypeAnalyser,
TypeVarLikeList,
TypeVarLikeQuery,
analyze_type_alias,
check_for_explicit_any,
detect_diverging_alias,
Expand Down Expand Up @@ -2034,6 +2034,11 @@ def analyze_unbound_tvar_impl(
assert isinstance(sym.node, TypeVarExpr)
return t.name, sym.node

def find_type_var_likes(self, t: Type) -> TypeVarLikeList:
visitor = FindTypeVarVisitor(self, self.tvar_scope)
t.accept(visitor)
return visitor.type_var_likes

def get_all_bases_tvars(
self, base_type_exprs: list[Expression], removed: list[int]
) -> TypeVarLikeList:
Expand All @@ -2046,7 +2051,7 @@ def get_all_bases_tvars(
except TypeTranslationError:
# This error will be caught later.
continue
base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope))
base_tvars = self.find_type_var_likes(base)
tvars.extend(base_tvars)
return remove_dups(tvars)

Expand All @@ -2064,7 +2069,7 @@ def get_and_bind_all_tvars(self, type_exprs: list[Expression]) -> list[TypeVarLi
except TypeTranslationError:
# This error will be caught later.
continue
base_tvars = base.accept(TypeVarLikeQuery(self, self.tvar_scope))
base_tvars = self.find_type_var_likes(base)
tvars.extend(base_tvars)
tvars = remove_dups(tvars) # Variables are defined in order of textual appearance.
tvar_defs = []
Expand Down Expand Up @@ -3489,7 +3494,7 @@ def analyze_alias(
)
return None, [], set(), [], False

found_type_vars = typ.accept(TypeVarLikeQuery(self, self.tvar_scope))
found_type_vars = self.find_type_var_likes(typ)
tvar_defs: list[TypeVarLikeType] = []
namespace = self.qualified_name(name)
with self.tvar_scope_frame(self.tvar_scope.class_frame(namespace)):
Expand Down
252 changes: 167 additions & 85 deletions mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,32 +1570,32 @@ def tvar_scope_frame(self) -> Iterator[None]:
yield
self.tvar_scope = old_scope

def find_type_var_likes(self, t: Type, include_callables: bool = True) -> TypeVarLikeList:
return t.accept(
TypeVarLikeQuery(self.api, self.tvar_scope, include_callables=include_callables)
)

def infer_type_variables(self, type: CallableType) -> list[tuple[str, TypeVarLikeExpr]]:
"""Return list of unique type variables referred to in a callable."""
names: list[str] = []
tvars: list[TypeVarLikeExpr] = []
def find_type_var_likes(self, t: Type) -> TypeVarLikeList:
visitor = FindTypeVarVisitor(self.api, self.tvar_scope)
t.accept(visitor)
return visitor.type_var_likes

def infer_type_variables(
self, type: CallableType
) -> tuple[list[tuple[str, TypeVarLikeExpr]], bool]:
"""Infer type variables from a callable.

Return tuple with these items:
- list of unique type variables referred to in a callable
- whether there is a reference to the Self type
"""
visitor = FindTypeVarVisitor(self.api, self.tvar_scope)
for arg in type.arg_types:
for name, tvar_expr in self.find_type_var_likes(arg):
if name not in names:
names.append(name)
tvars.append(tvar_expr)
arg.accept(visitor)

# When finding type variables in the return type of a function, don't
# look inside Callable types. Type variables only appearing in
# functions in the return type belong to those functions, not the
# function we're currently analyzing.
for name, tvar_expr in self.find_type_var_likes(type.ret_type, include_callables=False):
if name not in names:
names.append(name)
tvars.append(tvar_expr)
visitor.include_callables = False
type.ret_type.accept(visitor)

if not names:
return [] # Fast path
return list(zip(names, tvars))
return visitor.type_var_likes, visitor.has_self_type

def bind_function_type_variables(
self, fun_type: CallableType, defn: Context
Expand All @@ -1615,10 +1615,7 @@ def bind_function_type_variables(
binding = self.tvar_scope.bind_new(var.name, var_expr)
defs.append(binding)
return defs, has_self_type
typevars = self.infer_type_variables(fun_type)
has_self_type = find_self_type(
fun_type, lambda name: self.api.lookup_qualified(name, defn, suppress_errors=True)
)
typevars, has_self_type = self.infer_type_variables(fun_type)
# Do not define a new type variable if already defined in scope.
typevars = [
(name, tvar) for name, tvar in typevars if not self.is_defined_type_var(name, defn)
Expand Down Expand Up @@ -2062,67 +2059,6 @@ def flatten_tvars(lists: list[list[T]]) -> list[T]:
return result


class TypeVarLikeQuery(TypeQuery[TypeVarLikeList]):
"""Find TypeVar and ParamSpec references in an unbound type."""

def __init__(
self,
api: SemanticAnalyzerCoreInterface,
scope: TypeVarLikeScope,
*,
include_callables: bool = True,
) -> None:
super().__init__(flatten_tvars)
self.api = api
self.scope = scope
self.include_callables = include_callables
# Only include type variables in type aliases args. This would be anyway
# that case if we expand (as target variables would be overridden with args)
# and it may cause infinite recursion on invalid (diverging) recursive aliases.
self.skip_alias_target = True

def _seems_like_callable(self, type: UnboundType) -> bool:
if not type.args:
return False
return isinstance(type.args[0], (EllipsisType, TypeList, ParamSpecType))

def visit_unbound_type(self, t: UnboundType) -> TypeVarLikeList:
name = t.name
node = None
# Special case P.args and P.kwargs for ParamSpecs only.
if name.endswith("args"):
if name.endswith(".args") or name.endswith(".kwargs"):
base = ".".join(name.split(".")[:-1])
n = self.api.lookup_qualified(base, t)
if n is not None and isinstance(n.node, ParamSpecExpr):
node = n
name = base
if node is None:
node = self.api.lookup_qualified(name, t)
if (
node
and isinstance(node.node, TypeVarLikeExpr)
and self.scope.get_binding(node) is None
):
assert isinstance(node.node, TypeVarLikeExpr)
return [(name, node.node)]
elif not self.include_callables and self._seems_like_callable(t):
return []
elif node and node.fullname in LITERAL_TYPE_NAMES:
return []
elif node and node.fullname in ANNOTATED_TYPE_NAMES and t.args:
# Don't query the second argument to Annotated for TypeVars
return self.query_types([t.args[0]])
else:
return super().visit_unbound_type(t)

def visit_callable_type(self, t: CallableType) -> TypeVarLikeList:
if self.include_callables:
return super().visit_callable_type(t)
else:
return []


class DivergingAliasDetector(TrivialSyntheticTypeTranslator):
"""See docstring of detect_diverging_alias() for details."""

Expand Down Expand Up @@ -2359,3 +2295,149 @@ def unknown_unpack(t: Type) -> bool:
if isinstance(unpacked, AnyType) and unpacked.type_of_any == TypeOfAny.special_form:
return True
return False


class FindTypeVarVisitor(SyntheticTypeVisitor[None]):
"""Type visitor that looks for type variable types and self types."""

def __init__(self, api: SemanticAnalyzerCoreInterface, scope: TypeVarLikeScope) -> None:
self.api = api
self.scope = scope
self.type_var_likes: list[tuple[str, TypeVarLikeExpr]] = []
self.has_self_type = False
self.seen_aliases: set[TypeAliasType] | None = None
self.include_callables = True

def _seems_like_callable(self, type: UnboundType) -> bool:
if not type.args:
return False
return isinstance(type.args[0], (EllipsisType, TypeList, ParamSpecType))

def visit_unbound_type(self, t: UnboundType) -> None:
name = t.name
node = None

# Special case P.args and P.kwargs for ParamSpecs only.
if name.endswith("args"):
if name.endswith(".args") or name.endswith(".kwargs"):
base = ".".join(name.split(".")[:-1])
n = self.api.lookup_qualified(base, t)
if n is not None and isinstance(n.node, ParamSpecExpr):
node = n
name = base
if node is None:
node = self.api.lookup_qualified(name, t)
if node and node.fullname in SELF_TYPE_NAMES:
self.has_self_type = True
if (
node
and isinstance(node.node, TypeVarLikeExpr)
and self.scope.get_binding(node) is None
):
if (name, node.node) not in self.type_var_likes:
self.type_var_likes.append((name, node.node))
elif not self.include_callables and self._seems_like_callable(t):
if find_self_type(
t, lambda name: self.api.lookup_qualified(name, t, suppress_errors=True)
):
self.has_self_type = True
return
elif node and node.fullname in LITERAL_TYPE_NAMES:
return
elif node and node.fullname in ANNOTATED_TYPE_NAMES and t.args:
# Don't query the second argument to Annotated for TypeVars
self.process_types([t.args[0]])
elif t.args:
self.process_types(t.args)

def visit_type_list(self, t: TypeList) -> None:
self.process_types(t.items)

def visit_callable_argument(self, t: CallableArgument) -> None:
t.typ.accept(self)

def visit_any(self, t: AnyType) -> None:
pass

def visit_uninhabited_type(self, t: UninhabitedType) -> None:
pass

def visit_none_type(self, t: NoneType) -> None:
pass

def visit_erased_type(self, t: ErasedType) -> None:
pass

def visit_deleted_type(self, t: DeletedType) -> None:
pass

def visit_type_var(self, t: TypeVarType) -> None:
self.process_types([t.upper_bound, t.default] + t.values)

def visit_param_spec(self, t: ParamSpecType) -> None:
self.process_types([t.upper_bound, t.default])

def visit_type_var_tuple(self, t: TypeVarTupleType) -> None:
self.process_types([t.upper_bound, t.default])

def visit_unpack_type(self, t: UnpackType) -> None:
self.process_types([t.type])

def visit_parameters(self, t: Parameters) -> None:
self.process_types(t.arg_types)

def visit_partial_type(self, t: PartialType) -> None:
pass

def visit_instance(self, t: Instance) -> None:
self.process_types(t.args)

def visit_callable_type(self, t: CallableType) -> None:
# FIX generics
self.process_types(t.arg_types)
t.ret_type.accept(self)

def visit_tuple_type(self, t: TupleType) -> None:
self.process_types(t.items)

def visit_typeddict_type(self, t: TypedDictType) -> None:
self.process_types(list(t.items.values()))

def visit_raw_expression_type(self, t: RawExpressionType) -> None:
pass

def visit_literal_type(self, t: LiteralType) -> None:
pass

def visit_union_type(self, t: UnionType) -> None:
self.process_types(t.items)

def visit_overloaded(self, t: Overloaded) -> None:
self.process_types(t.items) # type: ignore[arg-type]

def visit_type_type(self, t: TypeType) -> None:
t.item.accept(self)

def visit_ellipsis_type(self, t: EllipsisType) -> None:
pass

def visit_placeholder_type(self, t: PlaceholderType) -> None:
return self.process_types(t.args)

def visit_type_alias_type(self, t: TypeAliasType) -> None:
# Skip type aliases in already visited types to avoid infinite recursion.
if self.seen_aliases is None:
self.seen_aliases = set()
elif t in self.seen_aliases:
return
self.seen_aliases.add(t)
self.process_types(t.args)

def process_types(self, types: list[Type] | tuple[Type, ...]) -> None:
# Redundant type check helps mypyc.
if isinstance(types, list):
for t in types:
t.accept(self)
else:
for t in types:
t.accept(self)