Adds support for basic union math with overloads
This commit adds support for very basic and simple union math when
calling overloaded functions, resolving python#4576.

As a side effect, this change also fixes a bug where calling overloaded
functions can sometimes silently infer a return type of 'Any' and
slightly modifies the semantics of how mypy handles overlaps in
overloaded functions.

Details on specific changes made:

1.  The new algorithm works by modifying checkexpr.overload_call_targets
    to return all possible matches, rather than just one.

    We start by trying the first matching signature. If there was some
    error, we (conservatively) attempt to union all of the matching
    signatures together and repeat the typechecking process.

    If it doesn't seem like it's possible to combine the matching
    signatures in a sound way, we end and just output the errors we
    obtained from typechecking the first match.

    The "signature-unioning" code is currently deliberately very
    conservative. I figured it was better to start small and attempt to
    handle only basic cases like python#1943 and relax the restrictions later
    as needed. For more details on this algorithm, see the comments in
    checkexpr.union_overload_matches.

2.  This change incidentally resolves any bugs related to how calling
    an overloaded function can sometimes silently infer a return type
    of Any. Previously, if a function call caused an overload to be
    less precise than a previous one, we gave up and returned a silent
    Any.

    This change removes this case altogether and only infers Any if
    either (a) the caller's arguments explicitly contain Any or (b)
    there was some error.

    For example, see python#3295 and python#1322 -- I believe this pull request touches
    on and maybe resolves (??) those two issues.

3.  As a result, this caused a few errors in mypy where code was
    relying on this "silently infer Any" behavior -- see the changes in
    checker.py and semanal.py. Both files were using expressions of the
    form `zip(*iterable)`, which ended up having a type of `Any` under
    the old algorithm. The new algorithm will instead infer
    `Iterable[Tuple[Any, ...]]` which actually matches the stubs in
    typeshed.

4.  Many of the attrs tests were also relying on the same behavior.
    Specifically, these changes cause the attr stubs in
    `test-data/unit/lib-stub` to no longer work. It seemed that expressions
    of the form `a = attr.ib()` were evaluated to 'Any' not because of a
    stub, but because of the 'silent Any' bug.

    I couldn't find a clean way of fixing the stubs to infer the correct
    thing under this new behavior, so I just gave up and removed the
    overloads altogether. I think this is fine though -- it seems like
    the attrs plugin infers the correct type for us anyways, regardless
    of what the stubs say.

    If this pull request is accepted, I plan on submitting a similar
    pull request to the stubs in typeshed.

5.  This pull request also probably touches on
    python/typing#253. We still require the
    overloads to be written from the narrowest to the most general and
    disallow overlapping signatures.

    However, if a *call* now causes overlaps, we try the "union"
    algorithm described above and default to selecting the first
    matching overload instead of giving up.
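
As a rough illustration of the user-facing effect described above, here is a minimal sketch of a call that needs union math. The names `f` and `double` are made up for this example and do not come from the commit. Neither overload accepts `Union[int, str]` on its own, so the call in `double` is the kind of case (compare python#1943) that the signature-unioning step is meant to accept, with an inferred result of `Union[int, str]`.

```python
from typing import Union, overload


@overload
def f(x: int) -> int: ...
@overload
def f(x: str) -> str: ...
def f(x: Union[int, str]) -> Union[int, str]:
    # Single runtime implementation shared by both overloads.
    return x * 2


def double(value: Union[int, str]) -> Union[int, str]:
    # No single overload matches Union[int, str]. With union math the two
    # matching signatures are combined into one callable taking
    # Union[int, str], so this call is expected to type-check.
    return f(value)
```

At runtime the overload stubs are replaced by the final definition, so the snippet also runs as ordinary Python.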
Michael0x2a committed Apr 2, 2018
1 parent 21cd8e2 commit 8f84256
Showing 9 changed files with 278 additions and 85 deletions.
4 changes: 2 additions & 2 deletions mypy/checker.py
@@ -1798,8 +1798,8 @@ def check_multi_assignment_from_union(self, lvalues: List[Expression], rvalue: E
expr = expr.expr
types, declared_types = zip(*items)
self.binder.assign_type(expr,
UnionType.make_simplified_union(types),
UnionType.make_simplified_union(declared_types),
UnionType.make_simplified_union(list(types)),
UnionType.make_simplified_union(list(declared_types)),
False)
for union, lv in zip(union_types, self.flatten_lvalues(lvalues)):
# Properly store the inferred types.
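
A note on why the `list(...)` wrappers appear in this hunk: `zip(*items)` yields tuples, so once mypy stops inferring `Any` for the unpacked values they no longer pass as lists. The sketch below uses made-up data (`items`, `names`, `counts`) and only demonstrates the runtime shape; that the callee's parameter is declared as a list is an assumption based on the change itself.

```python
from typing import List, Tuple

# Hypothetical data standing in for the (type, declared_type) pairs.
items: List[Tuple[str, int]] = [("a", 1), ("b", 2)]

# Transposing with zip(*...) produces one tuple per column.
names, counts = zip(*items)

# Converting explicitly gives a real list for APIs that ask for one.
names_list = list(names)
counts_list = list(counts)
```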
209 changes: 151 additions & 58 deletions mypy/checkexpr.py
@@ -611,10 +611,63 @@ def check_call(self, callee: Type, args: List[Expression],
arg_types = self.infer_arg_types_in_context(None, args)
self.msg.enable_errors()

target = self.overload_call_target(arg_types, arg_kinds, arg_names,
callee, context,
messages=arg_messages)
return self.check_call(target, args, arg_kinds, context, arg_names,
overload_messages = arg_messages.copy()
targets = self.overload_call_targets(arg_types, arg_kinds, arg_names,
callee, context,
messages=overload_messages)

# If there are multiple targets, that means that there were
# either multiple possible matches or the types were overlapping in some
# way. In either case, we default to picking the first match and
# see what happens if we try using it.
#
# Note: if we pass in an argument that inherits from two overloaded
# types, we default to picking the first match. For example:
#
# class A: pass
# class B: pass
# class C(A, B): pass
#
# @overload
# def f(x: A) -> int: ...
# @overload
# def f(x: B) -> str: ...
# def f(x): ...
#
# reveal_type(f(C())) # Will be 'int', not 'Union[int, str]'
#
# It's unclear if this is really the best thing to do, but multiple
# inheritance is rare. See the docstring of mypy.meet.is_overlapping_types
# for more about this.

original_output = self.check_call(targets[0], args, arg_kinds, context, arg_names,
arg_messages=overload_messages,
callable_name=callable_name,
object_type=object_type)

if not overload_messages.is_errors() or len(targets) == 1:
# If there were no errors or if there was only one match, we can end now.
#
# Note that if we have only one target, there's nothing else we
# can try doing. In that case, we just give up and return early
# and skip the below steps.
arg_messages.add_errors(overload_messages)
return original_output

# Otherwise, we attempt to synthesize together a new callable by combining
# together the different matches by union-ing together their arguments
# and return type.

targets = cast(List[CallableType], targets)
unioned_callable = self.union_overload_matches(targets)
if unioned_callable is None:
# If it was not possible to actually combine together the
# callables in a sound way, we give up and return the original
# error message.
arg_messages.add_errors(overload_messages)
return original_output

return self.check_call(unioned_callable, args, arg_kinds, context, arg_names,
arg_messages=arg_messages,
callable_name=callable_name,
object_type=object_type)
@@ -1089,83 +1142,123 @@ def check_arg(self, caller_type: Type, original_caller_type: Type,
(callee_type.item.type.is_abstract or callee_type.item.type.is_protocol) and
# ...except for classmethod first argument
not caller_type.is_classmethod_class):
self.msg.concrete_only_call(callee_type, context)
messages.concrete_only_call(callee_type, context)
elif not is_subtype(caller_type, callee_type):
if self.chk.should_suppress_optional_error([caller_type, callee_type]):
return
messages.incompatible_argument(n, m, callee, original_caller_type,
caller_kind, context)
if (isinstance(original_caller_type, (Instance, TupleType, TypedDictType)) and
isinstance(callee_type, Instance) and callee_type.type.is_protocol):
self.msg.report_protocol_problems(original_caller_type, callee_type, context)
messages.report_protocol_problems(original_caller_type, callee_type, context)
if (isinstance(callee_type, CallableType) and
isinstance(original_caller_type, Instance)):
call = find_member('__call__', original_caller_type, original_caller_type)
if call:
self.msg.note_call(original_caller_type, call, context)

def overload_call_target(self, arg_types: List[Type], arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
overload: Overloaded, context: Context,
messages: Optional[MessageBuilder] = None) -> Type:
"""Infer the correct overload item to call with given argument types.
The return value may be CallableType or AnyType (if an unique item
could not be determined).
messages.note_call(original_caller_type, call, context)

def overload_call_targets(self, arg_types: List[Type], arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
overload: Overloaded, context: Context,
messages: Optional[MessageBuilder] = None) -> Sequence[Type]:
"""Infer all possible overload targets to call with given argument types.
The list is guaranteed to be one of the following:
1. A List[CallableType] of length 1 if we were able to find an
unambiguous best match.
2. A List[AnyType] of length 1 if we were unable to find any match
or discovered the match was ambiguous due to conflicting Any types.
3. A List[CallableType] of length 2 or more if there were multiple
plausible matches. The matches are returned in the order they
were defined.
"""
messages = messages or self.msg
# TODO: For overlapping signatures we should try to get a more precise
# result than 'Any'.
match = [] # type: List[CallableType]
best_match = 0
for typ in overload.items():
similarity = self.erased_signature_similarity(arg_types, arg_kinds, arg_names,
typ, context=context)
if similarity > 0 and similarity >= best_match:
if (match and not is_same_type(match[-1].ret_type,
typ.ret_type) and
(not mypy.checker.is_more_precise_signature(match[-1], typ)
or (any(isinstance(arg, AnyType) for arg in arg_types)
and any_arg_causes_overload_ambiguity(
match + [typ], arg_types, arg_kinds, arg_names)))):
# Ambiguous return type. Either the function overload is
# overlapping (which we don't handle very well here) or the
# caller has provided some Any argument types; in either
# case we'll fall back to Any. It's okay to use Any types
# in calls.
#
# Overlapping overload items are generally fine if the
# overlapping is only possible when there is multiple
# inheritance, as this is rare. See docstring of
# mypy.meet.is_overlapping_types for more about this.
#
# Note that there is no ambiguity if the items are
# covariant in both argument types and return types with
# respect to type precision. We'll pick the best/closest
# match.
#
# TODO: Consider returning a union type instead if the
# overlapping is NOT due to Any types?
return AnyType(TypeOfAny.special_form)
else:
match.append(typ)
if (match and not is_same_type(match[-1].ret_type, typ.ret_type)
and any(isinstance(arg, AnyType) for arg in arg_types)
and any_arg_causes_overload_ambiguity(
match + [typ], arg_types, arg_kinds, arg_names)):
# Ambiguous return type. The caller has provided some
# Any argument types (which are okay to use in calls),
# so we fall back to returning 'Any'.
return [AnyType(TypeOfAny.special_form)]
match.append(typ)
best_match = max(best_match, similarity)
if not match:

if len(match) == 0:
if not self.chk.should_suppress_optional_error(arg_types):
messages.no_variant_matches_arguments(overload, arg_types, context)
return AnyType(TypeOfAny.from_error)
return [AnyType(TypeOfAny.from_error)]
elif len(match) == 1:
return match
else:
if len(match) == 1:
return match[0]
else:
# More than one signature matches. Pick the first *non-erased*
# matching signature, or default to the first one if none
# match.
for m in match:
if self.match_signature_types(arg_types, arg_kinds, arg_names, m,
context=context):
return m
return match[0]
# More than one signature matches or the signatures are
# overlapping. In either case, we return all of the matching
# signatures and let the caller decide what to do with them.
out = [m for m in match if self.match_signature_types(
arg_types, arg_kinds, arg_names, m, context=context)]
return out if len(out) >= 1 else match

def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]:
"""Accepts a list of overload signatures and attempts to combine them together into a
new CallableType consisting of the union of all of the given arguments and return types.
Returns None if it is not possible to combine the different callables together in a
sound manner."""

new_args: List[List[Type]] = [[] for _ in range(len(callables[0].arg_types))]

expected_names = callables[0].arg_names
expected_kinds = callables[0].arg_kinds

for target in callables:
if target.arg_names != expected_names or target.arg_kinds != expected_kinds:
# We conservatively end if the overloads do not have the exact same signature.
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
return None

for i, arg in enumerate(target.arg_types):
new_args[i].append(arg)

union_count = 0
final_args = []
for args in new_args:
new_type = UnionType.make_simplified_union(args)
union_count += 1 if isinstance(new_type, UnionType) else 0
final_args.append(new_type)

# TODO: Modify this check to be less conservative.
#
# Currently, we permit only one union in the arguments because if we allow
# multiple, we can't always guarantee the synthesized callable will be correct.
#
# For example, suppose we had the following two overloads:
#
# @overload
# def f(x: A, y: B) -> None: ...
# @overload
# def f(x: B, y: A) -> None: ...
#
# If we continued and synthesize "def f(x: Union[A,B], y: Union[A,B]) -> None: ...",
# then we'd incorrectly accept calls like "f(A(), A())" when they really ought to
# be rejected.
#
# However, that means we'll also give up if the original overloads contained
# any unions. This is likely unnecessary -- we only really need to give up if
# there is more than one *synthesized* union argument.
if union_count >= 2:
return None

return callables[0].copy_modified(
arg_types=final_args,
ret_type=UnionType.make_simplified_union([t.ret_type for t in callables]),
implicit=True,
from_overloads=True)

def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int],
arg_names: Optional[Sequence[Optional[str]]],
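
The conservative rule in `union_overload_matches` can be modelled outside mypy. The following is a toy sketch, not the mypy implementation: `TypeSet`, `Signature`, and `union_signatures` are hypothetical names, and a "type" is just a set of type names, where a set with more than one member plays the role of a union. It mirrors the two bail-out conditions in the diff: mismatched signature shapes, and two or more union-typed arguments after merging.

```python
from typing import FrozenSet, List, Optional, Sequence, Tuple

# Hypothetical stand-ins: a set of names models a (possibly union) type.
TypeSet = FrozenSet[str]
# (argument names, argument types, return type)
Signature = Tuple[Tuple[str, ...], Tuple[TypeSet, ...], TypeSet]


def union_signatures(sigs: Sequence[Signature]) -> Optional[Signature]:
    """Merge signatures argument by argument, or return None if unsafe."""
    names = sigs[0][0]
    merged_args: List[TypeSet] = list(sigs[0][1])
    merged_ret: TypeSet = sigs[0][2]
    for sig_names, sig_args, sig_ret in sigs[1:]:
        # Give up unless every overload has the same argument names.
        if sig_names != names or len(sig_args) != len(merged_args):
            return None
        merged_args = [a | b for a, b in zip(merged_args, sig_args)]
        merged_ret = merged_ret | sig_ret
    # Allow at most one union-typed argument; with two or more, the merged
    # signature could accept combinations that no single overload accepts.
    if sum(1 for arg in merged_args if len(arg) > 1) >= 2:
        return None
    return names, tuple(merged_args), merged_ret


def t(*names: str) -> TypeSet:
    return frozenset(names)


# f(x: int) -> int and f(x: str) -> str merge cleanly.
simple = union_signatures([(("x",), (t("int"),), t("int")),
                           (("x",), (t("str"),), t("str"))])

# f(x: A, y: B) and f(x: B, y: A) are rejected: merging would accept f(A(), A()).
crossed = union_signatures([(("x", "y"), (t("A"), t("B")), t("None")),
                            (("x", "y"), (t("B"), t("A")), t("None"))])
```

As in the real code, the toy version also gives up when an original overload already contains a union in a second argument, which is the over-conservatism the TODO comment points out.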
14 changes: 12 additions & 2 deletions mypy/messages.py
@@ -629,8 +629,19 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type:
expected_type = callee.arg_types[m - 1]
except IndexError: # Varargs callees
expected_type = callee.arg_types[-1]

arg_type_str, expected_type_str = self.format_distinctly(
arg_type, expected_type, bare=True)
expected_type_str = self.quote_type_string(expected_type_str)

if callee.from_overloads and isinstance(expected_type, UnionType):
expected_formatted = []
for e in expected_type.items:
type_str = self.format_distinctly(arg_type, e, bare=True)[1]
expected_formatted.append(self.quote_type_string(type_str))
expected_type_str = 'one of {} based on available overloads'.format(
', '.join(expected_formatted))

if arg_kind == ARG_STAR:
arg_type_str = '*' + arg_type_str
elif arg_kind == ARG_STAR2:
@@ -645,8 +656,7 @@ def incompatible_argument(self, n: int, m: int, callee: CallableType, arg_type:
arg_label = '"{}"'.format(arg_name)

msg = 'Argument {} {}has incompatible type {}; expected {}'.format(
arg_label, target, self.quote_type_string(arg_type_str),
self.quote_type_string(expected_type_str))
arg_label, target, self.quote_type_string(arg_type_str), expected_type_str)
if isinstance(arg_type, Instance) and isinstance(expected_type, Instance):
notes = append_invariance_notes(notes, arg_type, expected_type)
self.fail(msg, context)
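
To show roughly what the new wording looks like, here is a simplified sketch of the message construction. It is not the `MessageBuilder` code: `quote`, `expected_description`, and `incompatible_argument` below are hypothetical helpers, the double-quote style is an assumption, and the argument label is reduced to a fixed "Argument 1".

```python
from typing import List


def quote(type_str: str) -> str:
    # Assumed quoting style; stands in for quote_type_string.
    return '"{}"'.format(type_str)


def expected_description(union_items: List[str], from_overloads: bool) -> str:
    if from_overloads and len(union_items) > 1:
        # Synthesized callable: list each overload's parameter type separately.
        return 'one of {} based on available overloads'.format(
            ', '.join(quote(item) for item in union_items))
    if len(union_items) > 1:
        # An ordinary union written by the user is shown as a single type.
        return quote('Union[{}]'.format(', '.join(union_items)))
    return quote(union_items[0])


def incompatible_argument(arg_type: str, union_items: List[str],
                          from_overloads: bool) -> str:
    return 'Argument 1 has incompatible type {}; expected {}'.format(
        quote(arg_type), expected_description(union_items, from_overloads))


overload_msg = incompatible_argument('bytes', ['int', 'str'], True)
plain_msg = incompatible_argument('bytes', ['int', 'str'], False)
```

The point of the separate branch is that a user never wrote `Union[int, str]` in the overloads, so naming the individual alternatives is less surprising than reporting a union they cannot find in their code.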
2 changes: 1 addition & 1 deletion mypy/semanal.py
@@ -2870,7 +2870,7 @@ def process_module_assignment(self, lvals: List[Lvalue], rval: Expression,
# about the length mismatch in type-checking.
elementwise_assignments = zip(rval.items, *[v.items for v in seq_lvals])
for rv, *lvs in elementwise_assignments:
self.process_module_assignment(lvs, rv, ctx)
self.process_module_assignment(list(lvs), rv, ctx)
elif isinstance(rval, RefExpr):
rnode = self.lookup_type_node(rval)
if rnode and rnode.kind == MODULE_REF:
11 changes: 10 additions & 1 deletion mypy/types.py
@@ -660,6 +660,8 @@ class CallableType(FunctionLike):
special_sig = None # type: Optional[str]
# Was this callable generated by analyzing Type[...] instantiation?
from_type_type = False # type: bool
# Was this callable generated by synthesizing multiple overloads?
from_overloads = False # type: bool

bound_args = None # type: List[Optional[Type]]

@@ -679,6 +681,7 @@ def __init__(self,
is_classmethod_class: bool = False,
special_sig: Optional[str] = None,
from_type_type: bool = False,
from_overloads: bool = False,
bound_args: Optional[List[Optional[Type]]] = None,
) -> None:
assert len(arg_types) == len(arg_kinds) == len(arg_names)
@@ -703,6 +706,7 @@ def __init__(self,
self.is_classmethod_class = is_classmethod_class
self.special_sig = special_sig
self.from_type_type = from_type_type
self.from_overloads = from_overloads
self.bound_args = bound_args or []
super().__init__(line, column)

@@ -718,8 +722,10 @@ def copy_modified(self,
line: int = _dummy,
column: int = _dummy,
is_ellipsis_args: bool = _dummy,
implicit: bool = _dummy,
special_sig: Optional[str] = _dummy,
from_type_type: bool = _dummy,
from_overloads: bool = _dummy,
bound_args: List[Optional[Type]] = _dummy) -> 'CallableType':
return CallableType(
arg_types=arg_types if arg_types is not _dummy else self.arg_types,
@@ -734,10 +740,11 @@ def copy_modified(self,
column=column if column is not _dummy else self.column,
is_ellipsis_args=(
is_ellipsis_args if is_ellipsis_args is not _dummy else self.is_ellipsis_args),
implicit=self.implicit,
implicit=implicit if implicit is not _dummy else self.implicit,
is_classmethod_class=self.is_classmethod_class,
special_sig=special_sig if special_sig is not _dummy else self.special_sig,
from_type_type=from_type_type if from_type_type is not _dummy else self.from_type_type,
from_overloads=from_overloads if from_overloads is not _dummy else self.from_overloads,
bound_args=bound_args if bound_args is not _dummy else self.bound_args,
)

@@ -889,6 +896,7 @@ def serialize(self) -> JsonDict:
'is_ellipsis_args': self.is_ellipsis_args,
'implicit': self.implicit,
'is_classmethod_class': self.is_classmethod_class,
'from_overloads': self.from_overloads,
'bound_args': [(None if t is None else t.serialize())
for t in self.bound_args],
}
@@ -907,6 +915,7 @@ def deserialize(cls, data: JsonDict) -> 'CallableType':
is_ellipsis_args=data['is_ellipsis_args'],
implicit=data['implicit'],
is_classmethod_class=data['is_classmethod_class'],
from_overloads=data['from_overloads'],
bound_args=[(None if t is None else deserialize_type(t))
for t in data['bound_args']],
)
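
The `copy_modified` changes rely on a sentinel default so that "not passed" can be told apart from an explicit `False`. Below is a minimal sketch of that pattern with a hypothetical `ToyCallable` class; only the `_dummy` name and the `x if x is not _dummy else self.x` idiom are taken from the diff.

```python
from typing import Any

# Unique sentinel: no caller-supplied value can be identical to it.
_dummy: Any = object()


class ToyCallable:
    def __init__(self, ret_type: str, implicit: bool = False,
                 from_overloads: bool = False) -> None:
        self.ret_type = ret_type
        self.implicit = implicit
        self.from_overloads = from_overloads

    def copy_modified(self, ret_type: str = _dummy, implicit: bool = _dummy,
                      from_overloads: bool = _dummy) -> 'ToyCallable':
        # Each field is replaced only if the caller actually passed it.
        return ToyCallable(
            ret_type=ret_type if ret_type is not _dummy else self.ret_type,
            implicit=implicit if implicit is not _dummy else self.implicit,
            from_overloads=(from_overloads if from_overloads is not _dummy
                            else self.from_overloads),
        )


original = ToyCallable('int', implicit=True)
unioned = original.copy_modified(ret_type='Union[int, str]', from_overloads=True)
explicit = original.copy_modified(implicit=False)
```

Using `None` or a truthiness check instead would make it impossible to reset `implicit` to `False` through `copy_modified`, which is why an identity comparison against a private object is used.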