From 32802c7b1d2eab9bd4f09b60ba125f9a89bf287c Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Tue, 10 Apr 2018 08:49:35 -0700 Subject: [PATCH] Make overloads respect keyword-only args This commit resolves https://github.com/python/mypy/issues/1907. Specifically, it turned out that support for non-positional args in overload was never implemented to begin with. Thankfully, it also turned out the bulk of the logic we wanted was already implemented within `mypy.subtypes.is_callable_subtype`. Rather then re-implementing that code, this commit refactors that method to support any kind of check, instead of specifically subtype checks. This, as a side-effect, ended up making some partial progress towards https://github.com/python/mypy/issues/4159 -- this is because unlike the existing checks, `mypy.subtypes.is_callable_subtype` *doesn't* erase types and has better support for typevars in general. The reason this commit does not fully remove type erasure from overload checks is because the new implementation still calls `mypy.meet.is_overlapping_types` which *does* perform erasure. But fixing that seemed out-of-scope for this commit, so I stopped here. --- mypy/checker.py | 93 +++++++++++---------------- mypy/constraints.py | 4 +- mypy/subtypes.py | 67 +++++++++++-------- mypy/types.py | 7 ++ test-data/unit/check-overloading.test | 58 +++++++++++++++-- 5 files changed, 139 insertions(+), 90 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 208919ee59db0..653fddacd56d7 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -20,7 +20,7 @@ Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt, ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr, Import, ImportFrom, ImportAll, ImportBase, - ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF, + ARG_POS, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, LITERAL_TYPE, MDEF, GDEF, CONTRAVARIANT, COVARIANT, INVARIANT, ) from mypy import nodes @@ -39,7 +39,7 @@ from mypy import messages from mypy.subtypes import ( is_subtype, is_equivalent, is_proper_subtype, is_more_precise, - restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_subtype, + restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_compatible, unify_generic_callable, find_member ) from mypy.maptype import map_instance_to_supertype @@ -437,7 +437,8 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: assert isinstance(impl_type, CallableType) assert isinstance(sig1, CallableType) - if not is_callable_subtype(impl_type, sig1, ignore_return=True): + if not is_callable_compatible(impl_type, sig1, + is_compat=is_subtype, ignore_return=True): self.msg.overloaded_signatures_arg_specific(i + 1, defn.impl) impl_type_subst = impl_type if impl_type.variables: @@ -3530,37 +3531,45 @@ def is_unsafe_overlapping_signatures(signature: Type, other: Type) -> bool: """ if isinstance(signature, CallableType): if isinstance(other, CallableType): - # TODO varargs - # TODO keyword args - # TODO erasure # TODO allow to vary covariantly + # Check if the argument counts are overlapping. min_args = max(signature.min_args, other.min_args) - max_args = min(len(signature.arg_types), len(other.arg_types)) + max_args = min(signature.max_positional_args(), other.max_positional_args()) if min_args > max_args: # Argument counts are not overlapping. return False - # Signatures are overlapping iff if they are overlapping for the - # smallest common argument count. - for i in range(min_args): - t1 = signature.arg_types[i] - t2 = other.arg_types[i] - if not is_overlapping_types(t1, t2): - return False + + # If one of the corresponding argument do NOT overlap, + # then the signatures are not overlapping. + if not is_callable_compatible(signature, other, + is_compat=is_overlapping_types, + ignore_return=True, + check_args_covariantly=True): + # TODO: this check (unlike the others) will erase types due to + # how is_overlapping_type is implemented. This should be + # fixed to make this check consistent with the others. + return False + # All arguments types for the smallest common argument count are # overlapping => the signature is overlapping. The overlapping is # safe if the return types are identical. if is_same_type(signature.ret_type, other.ret_type): return False + # If the first signature has more general argument types, the # latter will never be called if is_more_general_arg_prefix(signature, other): return False + # Special case: all args are subtypes, and returns are subtypes - if (all(is_proper_subtype(s, o) - for (s, o) in zip(signature.arg_types, other.arg_types)) and - is_proper_subtype(signature.ret_type, other.ret_type)): + if is_callable_compatible(signature, other, + is_compat=is_proper_subtype, + check_args_covariantly=True): return False + + # If the first signature is NOT more precise then the second, + # then the overlap is unsafe. return not is_more_precise_signature(signature, other) return True @@ -3569,12 +3578,11 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool: """Does t have wider arguments than s?""" # TODO should an overload with additional items be allowed to be more # general than one with fewer items (or just one item)? - # TODO check argument kinds and otherwise make more general if isinstance(t, CallableType): if isinstance(s, CallableType): - t, s = unify_generic_callables(t, s) - return all(is_proper_subtype(args, argt) - for argt, args in zip(t.arg_types, s.arg_types)) + return is_callable_compatible(t, s, + is_compat=is_proper_subtype, + ignore_return=True) elif isinstance(t, FunctionLike): if isinstance(s, FunctionLike): if len(t.items()) == len(s.items()): @@ -3583,29 +3591,6 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool: return False -def unify_generic_callables(t: CallableType, - s: CallableType) -> Tuple[CallableType, - CallableType]: - """Make type variables in generic callables the same if possible. - - Return updated callables. If we can't unify the type variables, - return the unmodified arguments. - """ - # TODO: Use this elsewhere when comparing generic callables. - if t.is_generic() and s.is_generic(): - t_substitutions = {} - s_substitutions = {} - for tv1, tv2 in zip(t.variables, s.variables): - # Are these something we can unify? - if tv1.id != tv2.id and is_equivalent_type_var_def(tv1, tv2): - newdef = TypeVarDef.new_unification_variable(tv2) - t_substitutions[tv1.id] = TypeVarType(newdef) - s_substitutions[tv2.id] = TypeVarType(newdef) - return (cast(CallableType, expand_type(t, t_substitutions)), - cast(CallableType, expand_type(s, s_substitutions))) - return t, s - - def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool: """Are type variable definitions equivalent? @@ -3621,9 +3606,11 @@ def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool: def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool: - # TODO check argument kinds - return all(is_same_type(argt, args) - for argt, args in zip(t.arg_types, s.arg_types)) + return is_callable_compatible(t, s, + is_compat=is_same_type, + ignore_return=True, + check_args_covariantly=True, + ignore_pos_arg_names=True) def is_more_precise_signature(t: CallableType, s: CallableType) -> bool: @@ -3631,16 +3618,10 @@ def is_more_precise_signature(t: CallableType, s: CallableType) -> bool: A signature t is more precise than s if all argument types and the return type of t are more precise than the corresponding types in s. - - Assume that the argument kinds and names are compatible, and that the - argument counts are overlapping. """ - # TODO generic function types - # Only consider the common prefix of argument types. - for argt, args in zip(t.arg_types, s.arg_types): - if not is_more_precise(argt, args): - return False - return is_more_precise(t.ret_type, s.ret_type) + return is_callable_compatible(t, s, + is_compat=is_more_precise, + check_args_covariantly=True) def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, str]: diff --git a/mypy/constraints.py b/mypy/constraints.py index 92a1f35b999b8..67d42cf8c9237 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -525,7 +525,9 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType) for item in items: # Return type may be indeterminate in the template, so ignore it when performing a # subtype check. - if mypy.subtypes.is_callable_subtype(item, template, ignore_return=True): + if mypy.subtypes.is_callable_compatible(item, template, + is_compat=mypy.subtypes.is_subtype, + ignore_return=True): return item # Fall back to the first item if we can't find a match. This is totally arbitrary -- # maybe we should just bail out at this point. diff --git a/mypy/subtypes.py b/mypy/subtypes.py index a514ecce127d5..767e3cb2b8c50 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -202,8 +202,9 @@ def visit_type_var(self, left: TypeVarType) -> bool: def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): - return is_callable_subtype( + return is_callable_compatible( left, right, + is_compat=is_subtype, ignore_pos_arg_names=self.ignore_pos_arg_names) elif isinstance(right, Overloaded): return all(is_subtype(left, item, self.check_type_parameter, @@ -309,10 +310,12 @@ def visit_overloaded(self, left: Overloaded) -> bool: else: # If this one overlaps with the supertype in any way, but it wasn't # an exact match, then it's a potential error. - if (is_callable_subtype(left_item, right_item, ignore_return=True, - ignore_pos_arg_names=self.ignore_pos_arg_names) or - is_callable_subtype(right_item, left_item, ignore_return=True, - ignore_pos_arg_names=self.ignore_pos_arg_names)): + if (is_callable_compatible(left_item, right_item, + is_compat=is_subtype, ignore_return=True, + ignore_pos_arg_names=self.ignore_pos_arg_names) or + is_callable_compatible(right_item, left_item, + is_compat=is_subtype, ignore_return=True, + ignore_pos_arg_names=self.ignore_pos_arg_names)): # If this is an overload that's already been matched, there's no # problem. if left_item not in matched_overloads: @@ -562,16 +565,22 @@ def non_method_protocol_members(tp: TypeInfo) -> List[str]: return result -def is_callable_subtype(left: CallableType, right: CallableType, - ignore_return: bool = False, - ignore_pos_arg_names: bool = False, - use_proper_subtype: bool = False) -> bool: - """Is left a subtype of right?""" +def is_callable_compatible(left: CallableType, right: CallableType, + *, + is_compat: Callable[[Type, Type], bool], + ignore_return: bool = False, + ignore_pos_arg_names: bool = False, + check_args_covariantly: bool = False) -> bool: + """Is the left compatible with the right, using the provided compatibility check? - if use_proper_subtype: - is_compat = is_proper_subtype - else: - is_compat = is_subtype + If 'check_args_covariantly' is set to True, check if the left's args is + compatible with the right's instead of the other way around (contravariantly). + + This function is mostly used to check if the left is a subtype of the right which + is why the default is to check the args covariantly. However, it's occasionally + useful to check the args using some other check, so we leave the variance + configurable. + """ # If either function is implicitly typed, ignore positional arg names too if left.implicit or right.implicit: @@ -604,6 +613,9 @@ def is_callable_subtype(left: CallableType, right: CallableType, if not ignore_return and not is_compat(left.ret_type, right.ret_type): return False + if check_args_covariantly: + is_compat = flip_compat_check(is_compat) + if right.is_ellipsis_args: return True @@ -658,7 +670,7 @@ def is_callable_subtype(left: CallableType, right: CallableType, right_by_position = right.argument_by_position(j) assert right_by_position is not None if not are_args_compatible(left_by_position, right_by_position, - ignore_pos_arg_names, use_proper_subtype): + ignore_pos_arg_names, is_compat): return False j += 1 continue @@ -681,7 +693,7 @@ def is_callable_subtype(left: CallableType, right: CallableType, right_by_name = right.argument_by_name(name) assert right_by_name is not None if not are_args_compatible(left_by_name, right_by_name, - ignore_pos_arg_names, use_proper_subtype): + ignore_pos_arg_names, is_compat): return False continue @@ -690,7 +702,7 @@ def is_callable_subtype(left: CallableType, right: CallableType, if left_arg is None: return False - if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names, use_proper_subtype): + if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names, is_compat): return False done_with_positional = False @@ -742,7 +754,7 @@ def are_args_compatible( left: FormalArgument, right: FormalArgument, ignore_pos_arg_names: bool, - use_proper_subtype: bool) -> bool: + is_compat: Callable[[Type, Type], bool]) -> bool: # If right has a specific name it wants this argument to be, left must # have the same. if right.name is not None and left.name != right.name: @@ -753,18 +765,20 @@ def are_args_compatible( if right.pos is not None and left.pos != right.pos: return False # Left must have a more general type - if use_proper_subtype: - if not is_proper_subtype(right.typ, left.typ): - return False - else: - if not is_subtype(right.typ, left.typ): - return False + if not is_compat(right.typ, left.typ): + return False # If right's argument is optional, left's must also be. if not right.required and left.required: return False return True +def flip_compat_check(is_compat: Callable[[Type, Type], bool]) -> Callable[[Type, Type], bool]: + def new_is_compat(left: Type, right: Type) -> bool: + return is_compat(right, left) + return new_is_compat + + def unify_generic_callable(type: CallableType, target: CallableType, ignore_return: bool) -> Optional[CallableType]: """Try to unify a generic callable type with another callable type. @@ -907,10 +921,7 @@ def visit_type_var(self, left: TypeVarType) -> bool: def visit_callable_type(self, left: CallableType) -> bool: right = self.right if isinstance(right, CallableType): - return is_callable_subtype( - left, right, - ignore_pos_arg_names=False, - use_proper_subtype=True) + return is_callable_compatible(left, right, is_compat=is_proper_subtype) elif isinstance(right, Overloaded): return all(is_proper_subtype(left, item) for item in right.items()) diff --git a/mypy/types.py b/mypy/types.py index 96b5849ffee2a..867b29117be53 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -773,6 +773,13 @@ def max_fixed_args(self) -> int: n -= 1 return n + def max_positional_args(self) -> int: + """Returns the number of positional args. + + This includes *arg and **kwargs but excludes keyword-only args.""" + blacklist = (ARG_NAMED, ARG_NAMED_OPT) + return len([kind not in blacklist for kind in self.arg_kinds]) + def corresponding_argument(self, model: FormalArgument) -> Optional[FormalArgument]: """Return the argument in this function that corresponds to `model`""" diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 88eabcaf4a619..ae86401b4769f 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -993,9 +993,16 @@ T = TypeVar('T', bound=str) def f(x: Sequence[T]) -> None: pass @overload def f(x: Sequence[int]) -> int: pass -# These are considered overlapping despite the bound on T due to runtime type erasure. -[out] -tmp/foo.pyi:4: error: Overloaded function signatures 1 and 2 overlap with incompatible return types + +@overload +def g(x: Sequence[T]) -> None: pass +@overload +def g(x: Sequence[str]) -> int: pass + +@overload +def h(x: Sequence[str]) -> int: pass +@overload +def h(x: Sequence[T]) -> None: pass [case testOverlapWithTypeVarsWithValues] from foo import * @@ -1026,16 +1033,21 @@ g(1, 'foo') g(1, 'foo', b'bar') # E: Value of type variable "AnyStr" of "g" cannot be "object" [builtins fixtures/primitives.pyi] -[case testBadOverlapWithTypeVarsWithValues] +[case testOverlapWithTypeVarsWithValuesOrdering] from foo import * [file foo.pyi] from typing import overload, TypeVar AnyStr = TypeVar('AnyStr', bytes, str) @overload -def f(x: AnyStr) -> None: pass # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +def f(x: AnyStr) -> None: pass @overload def f(x: str) -> bool: pass + +@overload +def g(x: str) -> bool: pass # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def g(x: AnyStr) -> None: pass [builtins fixtures/primitives.pyi] [case testOverlappingOverloadCounting] @@ -1490,3 +1502,39 @@ class Child4(ParentWithDynamicImpl): [builtins fixtures/tuple.pyi] +[case testOverloadWithNonPositionalArgs] +from typing import overload + +class A: ... +class B: ... +class C: ... + +@overload +def foo(*, p1: A, p2: B = B()) -> A: ... +@overload +def foo(*, p2: B = B()) -> B: ... +def foo(p1, p2=None): ... + +reveal_type(foo()) # E: Revealed type is '__main__.B' +reveal_type(foo(p2=B())) # E: Revealed type is '__main__.B' +reveal_type(foo(p1=A())) # E: Revealed type is '__main__.A' + +[case testOverloadWithNonPositionalArgsIgnoresOrder] +from typing import overload + +class A: ... +class B(A): ... +class X: ... +class Y: ... + +@overload +def f(*, p1: X, p2: A) -> X: ... +@overload +def f(*, p2: B, p1: X) -> Y: ... +def f(*, p1, p2): ... + +@overload +def g(*, p1: X, p2: B) -> X: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types +@overload +def g(*, p2: A, p1: X) -> Y: ... +def g(*, p1, p2): ...