Skip to content

Commit

Permalink
Make overloads respect keyword-only args
Browse files Browse the repository at this point in the history
This commit resolves python#1907.

Specifically, it turned out that support for non-positional args in
overload was never implemented to begin with. Thankfully, it also turned
out the bulk of the logic we wanted was already implemented within
`mypy.subtypes.is_callable_subtype`. Rather then re-implementing that
code, this commit refactors that method to support any kind of check,
instead of specifically subtype checks.

This, as a side-effect, ended up making some partial progress towards
python#4159 -- this is because unlike
the existing checks, `mypy.subtypes.is_callable_subtype` *doesn't* erase
types and has better support for typevars in general.

The reason this commit does not fully remove type erasure from overload
checks is because the new implementation still calls
`mypy.meet.is_overlapping_types` which *does* perform erasure. But
fixing that seemed out-of-scope for this commit, so I stopped here.
  • Loading branch information
Michael0x2a committed Apr 23, 2018
1 parent 18a77cf commit 90615f1
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 90 deletions.
93 changes: 37 additions & 56 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt,
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr,
Import, ImportFrom, ImportAll, ImportBase,
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF,
ARG_POS, ARG_STAR, ARG_NAMED, ARG_NAMED_OPT, LITERAL_TYPE, MDEF, GDEF,
CONTRAVARIANT, COVARIANT, INVARIANT,
)
from mypy import nodes
Expand All @@ -39,7 +39,7 @@
from mypy import messages
from mypy.subtypes import (
is_subtype, is_equivalent, is_proper_subtype, is_more_precise,
restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_subtype,
restrict_subtype_away, is_subtype_ignoring_tvars, is_callable_compatible,
unify_generic_callable, find_member
)
from mypy.maptype import map_instance_to_supertype
Expand Down Expand Up @@ -437,7 +437,8 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:

assert isinstance(impl_type, CallableType)
assert isinstance(sig1, CallableType)
if not is_callable_subtype(impl_type, sig1, ignore_return=True):
if not is_callable_compatible(impl_type, sig1,
is_compat=is_subtype, ignore_return=True):
self.msg.overloaded_signatures_arg_specific(i + 1, defn.impl)
impl_type_subst = impl_type
if impl_type.variables:
Expand Down Expand Up @@ -3533,37 +3534,45 @@ def is_unsafe_overlapping_signatures(signature: Type, other: Type) -> bool:
"""
if isinstance(signature, CallableType):
if isinstance(other, CallableType):
# TODO varargs
# TODO keyword args
# TODO erasure
# TODO allow to vary covariantly

# Check if the argument counts are overlapping.
min_args = max(signature.min_args, other.min_args)
max_args = min(len(signature.arg_types), len(other.arg_types))
max_args = min(signature.max_positional_args(), other.max_positional_args())
if min_args > max_args:
# Argument counts are not overlapping.
return False
# Signatures are overlapping iff if they are overlapping for the
# smallest common argument count.
for i in range(min_args):
t1 = signature.arg_types[i]
t2 = other.arg_types[i]
if not is_overlapping_types(t1, t2):
return False

# If one of the corresponding argument do NOT overlap,
# then the signatures are not overlapping.
if not is_callable_compatible(signature, other,
is_compat=is_overlapping_types,
ignore_return=True,
check_args_covariantly=True):
# TODO: this check (unlike the others) will erase types due to
# how is_overlapping_type is implemented. This should be
# fixed to make this check consistent with the others.
return False

# All arguments types for the smallest common argument count are
# overlapping => the signature is overlapping. The overlapping is
# safe if the return types are identical.
if is_same_type(signature.ret_type, other.ret_type):
return False

# If the first signature has more general argument types, the
# latter will never be called
if is_more_general_arg_prefix(signature, other):
return False

# Special case: all args are subtypes, and returns are subtypes
if (all(is_proper_subtype(s, o)
for (s, o) in zip(signature.arg_types, other.arg_types)) and
is_proper_subtype(signature.ret_type, other.ret_type)):
if is_callable_compatible(signature, other,
is_compat=is_proper_subtype,
check_args_covariantly=True):
return False

# If the first signature is NOT more precise then the second,
# then the overlap is unsafe.
return not is_more_precise_signature(signature, other)
return True

Expand All @@ -3572,12 +3581,11 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool:
"""Does t have wider arguments than s?"""
# TODO should an overload with additional items be allowed to be more
# general than one with fewer items (or just one item)?
# TODO check argument kinds and otherwise make more general
if isinstance(t, CallableType):
if isinstance(s, CallableType):
t, s = unify_generic_callables(t, s)
return all(is_proper_subtype(args, argt)
for argt, args in zip(t.arg_types, s.arg_types))
return is_callable_compatible(t, s,
is_compat=is_proper_subtype,
ignore_return=True)
elif isinstance(t, FunctionLike):
if isinstance(s, FunctionLike):
if len(t.items()) == len(s.items()):
Expand All @@ -3586,29 +3594,6 @@ def is_more_general_arg_prefix(t: FunctionLike, s: FunctionLike) -> bool:
return False


def unify_generic_callables(t: CallableType,
s: CallableType) -> Tuple[CallableType,
CallableType]:
"""Make type variables in generic callables the same if possible.
Return updated callables. If we can't unify the type variables,
return the unmodified arguments.
"""
# TODO: Use this elsewhere when comparing generic callables.
if t.is_generic() and s.is_generic():
t_substitutions = {}
s_substitutions = {}
for tv1, tv2 in zip(t.variables, s.variables):
# Are these something we can unify?
if tv1.id != tv2.id and is_equivalent_type_var_def(tv1, tv2):
newdef = TypeVarDef.new_unification_variable(tv2)
t_substitutions[tv1.id] = TypeVarType(newdef)
s_substitutions[tv2.id] = TypeVarType(newdef)
return (cast(CallableType, expand_type(t, t_substitutions)),
cast(CallableType, expand_type(s, s_substitutions)))
return t, s


def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool:
"""Are type variable definitions equivalent?
Expand All @@ -3624,26 +3609,22 @@ def is_equivalent_type_var_def(tv1: TypeVarDef, tv2: TypeVarDef) -> bool:


def is_same_arg_prefix(t: CallableType, s: CallableType) -> bool:
# TODO check argument kinds
return all(is_same_type(argt, args)
for argt, args in zip(t.arg_types, s.arg_types))
return is_callable_compatible(t, s,
is_compat=is_same_type,
ignore_return=True,
check_args_covariantly=True,
ignore_pos_arg_names=True)


def is_more_precise_signature(t: CallableType, s: CallableType) -> bool:
"""Is t more precise than s?
A signature t is more precise than s if all argument types and the return
type of t are more precise than the corresponding types in s.
Assume that the argument kinds and names are compatible, and that the
argument counts are overlapping.
"""
# TODO generic function types
# Only consider the common prefix of argument types.
for argt, args in zip(t.arg_types, s.arg_types):
if not is_more_precise(argt, args):
return False
return is_more_precise(t.ret_type, s.ret_type)
return is_callable_compatible(t, s,
is_compat=is_more_precise,
check_args_covariantly=True)


def infer_operator_assignment_method(typ: Type, operator: str) -> Tuple[bool, str]:
Expand Down
4 changes: 3 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,9 @@ def find_matching_overload_item(overloaded: Overloaded, template: CallableType)
for item in items:
# Return type may be indeterminate in the template, so ignore it when performing a
# subtype check.
if mypy.subtypes.is_callable_subtype(item, template, ignore_return=True):
if mypy.subtypes.is_callable_compatible(item, template,
is_compat=mypy.subtypes.is_subtype,
ignore_return=True):
return item
# Fall back to the first item if we can't find a match. This is totally arbitrary --
# maybe we should just bail out at this point.
Expand Down
67 changes: 39 additions & 28 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,9 @@ def visit_type_var(self, left: TypeVarType) -> bool:
def visit_callable_type(self, left: CallableType) -> bool:
right = self.right
if isinstance(right, CallableType):
return is_callable_subtype(
return is_callable_compatible(
left, right,
is_compat=is_subtype,
ignore_pos_arg_names=self.ignore_pos_arg_names)
elif isinstance(right, Overloaded):
return all(is_subtype(left, item, self.check_type_parameter,
Expand Down Expand Up @@ -309,10 +310,12 @@ def visit_overloaded(self, left: Overloaded) -> bool:
else:
# If this one overlaps with the supertype in any way, but it wasn't
# an exact match, then it's a potential error.
if (is_callable_subtype(left_item, right_item, ignore_return=True,
ignore_pos_arg_names=self.ignore_pos_arg_names) or
is_callable_subtype(right_item, left_item, ignore_return=True,
ignore_pos_arg_names=self.ignore_pos_arg_names)):
if (is_callable_compatible(left_item, right_item,
is_compat=is_subtype, ignore_return=True,
ignore_pos_arg_names=self.ignore_pos_arg_names) or
is_callable_compatible(right_item, left_item,
is_compat=is_subtype, ignore_return=True,
ignore_pos_arg_names=self.ignore_pos_arg_names)):
# If this is an overload that's already been matched, there's no
# problem.
if left_item not in matched_overloads:
Expand Down Expand Up @@ -562,16 +565,22 @@ def non_method_protocol_members(tp: TypeInfo) -> List[str]:
return result


def is_callable_subtype(left: CallableType, right: CallableType,
ignore_return: bool = False,
ignore_pos_arg_names: bool = False,
use_proper_subtype: bool = False) -> bool:
"""Is left a subtype of right?"""
def is_callable_compatible(left: CallableType, right: CallableType,
*,
is_compat: Callable[[Type, Type], bool],
ignore_return: bool = False,
ignore_pos_arg_names: bool = False,
check_args_covariantly: bool = False) -> bool:
"""Is the left compatible with the right, using the provided compatibility check?
if use_proper_subtype:
is_compat = is_proper_subtype
else:
is_compat = is_subtype
If 'check_args_covariantly' is set to True, check if the left's args is
compatible with the right's instead of the other way around (contravariantly).
This function is mostly used to check if the left is a subtype of the right which
is why the default is to check the args covariantly. However, it's occasionally
useful to check the args using some other check, so we leave the variance
configurable.
"""

# If either function is implicitly typed, ignore positional arg names too
if left.implicit or right.implicit:
Expand Down Expand Up @@ -604,6 +613,9 @@ def is_callable_subtype(left: CallableType, right: CallableType,
if not ignore_return and not is_compat(left.ret_type, right.ret_type):
return False

if check_args_covariantly:
is_compat = flip_compat_check(is_compat)

if right.is_ellipsis_args:
return True

Expand Down Expand Up @@ -658,7 +670,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
right_by_position = right.argument_by_position(j)
assert right_by_position is not None
if not are_args_compatible(left_by_position, right_by_position,
ignore_pos_arg_names, use_proper_subtype):
ignore_pos_arg_names, is_compat):
return False
j += 1
continue
Expand All @@ -681,7 +693,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
right_by_name = right.argument_by_name(name)
assert right_by_name is not None
if not are_args_compatible(left_by_name, right_by_name,
ignore_pos_arg_names, use_proper_subtype):
ignore_pos_arg_names, is_compat):
return False
continue

Expand All @@ -690,7 +702,7 @@ def is_callable_subtype(left: CallableType, right: CallableType,
if left_arg is None:
return False

if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names, use_proper_subtype):
if not are_args_compatible(left_arg, right_arg, ignore_pos_arg_names, is_compat):
return False

done_with_positional = False
Expand Down Expand Up @@ -742,7 +754,7 @@ def are_args_compatible(
left: FormalArgument,
right: FormalArgument,
ignore_pos_arg_names: bool,
use_proper_subtype: bool) -> bool:
is_compat: Callable[[Type, Type], bool]) -> bool:
# If right has a specific name it wants this argument to be, left must
# have the same.
if right.name is not None and left.name != right.name:
Expand All @@ -753,18 +765,20 @@ def are_args_compatible(
if right.pos is not None and left.pos != right.pos:
return False
# Left must have a more general type
if use_proper_subtype:
if not is_proper_subtype(right.typ, left.typ):
return False
else:
if not is_subtype(right.typ, left.typ):
return False
if not is_compat(right.typ, left.typ):
return False
# If right's argument is optional, left's must also be.
if not right.required and left.required:
return False
return True


def flip_compat_check(is_compat: Callable[[Type, Type], bool]) -> Callable[[Type, Type], bool]:
def new_is_compat(left: Type, right: Type) -> bool:
return is_compat(right, left)
return new_is_compat


def unify_generic_callable(type: CallableType, target: CallableType,
ignore_return: bool) -> Optional[CallableType]:
"""Try to unify a generic callable type with another callable type.
Expand Down Expand Up @@ -907,10 +921,7 @@ def visit_type_var(self, left: TypeVarType) -> bool:
def visit_callable_type(self, left: CallableType) -> bool:
right = self.right
if isinstance(right, CallableType):
return is_callable_subtype(
left, right,
ignore_pos_arg_names=False,
use_proper_subtype=True)
return is_callable_compatible(left, right, is_compat=is_proper_subtype)
elif isinstance(right, Overloaded):
return all(is_proper_subtype(left, item)
for item in right.items())
Expand Down
7 changes: 7 additions & 0 deletions mypy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,13 @@ def max_fixed_args(self) -> int:
n -= 1
return n

def max_positional_args(self) -> int:
"""Returns the number of positional args.
This includes *arg and **kwargs but excludes keyword-only args."""
blacklist = (ARG_NAMED, ARG_NAMED_OPT)
return len([kind not in blacklist for kind in self.arg_kinds])

def corresponding_argument(self, model: FormalArgument) -> Optional[FormalArgument]:
"""Return the argument in this function that corresponds to `model`"""

Expand Down
Loading

0 comments on commit 90615f1

Please sign in to comment.