Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix checking multiple assignments based on tuple unpacking involving partially initialised variables (Fixes #12915). #14440

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 78 additions & 131 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@
NoneType,
Overloaded,
PartialType,
PlaceholderType,
ProperType,
StarType,
TupleType,
Expand Down Expand Up @@ -338,8 +339,6 @@ class TypeChecker(NodeVisitor[None], CheckerPluginInterface):
# Used for collecting inferred attribute types so that they can be checked
# for consistency.
inferred_attribute_types: dict[Var, Type] | None = None
# Don't infer partial None types if we are processing assignment from Union
no_partial_types: bool = False

# The set of all dependencies (suppressed or not) that this module accesses, either
# directly or indirectly.
Expand Down Expand Up @@ -3375,7 +3374,6 @@ def check_multi_assignment(
context: Context,
infer_lvalue_type: bool = True,
rv_type: Type | None = None,
undefined_rvalue: bool = False,
) -> None:
"""Check the assignment of one rvalue to a number of lvalues."""

Expand All @@ -3386,12 +3384,6 @@ def check_multi_assignment(
if isinstance(rvalue_type, TypeVarLikeType):
rvalue_type = get_proper_type(rvalue_type.upper_bound)

if isinstance(rvalue_type, UnionType):
# If this is an Optional type in non-strict Optional code, unwrap it.
relevant_items = rvalue_type.relevant_items()
if len(relevant_items) == 1:
rvalue_type = get_proper_type(relevant_items[0])

if isinstance(rvalue_type, AnyType):
for lv in lvalues:
if isinstance(lv, StarExpr):
Expand All @@ -3402,7 +3394,7 @@ def check_multi_assignment(
self.check_assignment(lv, temp_node, infer_lvalue_type)
elif isinstance(rvalue_type, TupleType):
self.check_multi_assignment_from_tuple(
lvalues, rvalue, rvalue_type, context, undefined_rvalue, infer_lvalue_type
lvalues, rvalue, rvalue_type, context, infer_lvalue_type
)
elif isinstance(rvalue_type, UnionType):
self.check_multi_assignment_from_union(
Expand All @@ -3424,64 +3416,89 @@ def check_multi_assignment_from_union(
infer_lvalue_type: bool,
) -> None:
"""Check assignment to multiple lvalue targets when rvalue type is a Union[...].

For example:

t: Union[Tuple[int, int], Tuple[str, str]]
t: Union[Tuple[int, int], Tuple[str, float]]
x, y = t
reveal_type(x) # Union[int, str]

The idea in this case is to process the assignment for every item of the union.
Important note: the types are collected in two places, 'union_types' contains
inferred types for first assignments, 'assignments' contains the narrowed types
for binder.
The idea is to check each single assignment by constructing a union of the
relevant rvalue types:

x = Union[int, str]
y = Union[int, float]
"""
self.no_partial_types = True
transposed: tuple[list[Type], ...] = tuple([] for _ in self.flatten_lvalues(lvalues))
# Notify binder that we want to defer bindings and instead collect types.
with self.binder.accumulate_type_assignments() as assignments:
for item in rvalue_type.items:
# Type check the assignment separately for each union item and collect
# the inferred lvalue types for each union item.
self.check_multi_assignment(
lvalues,
rvalue,
context,
infer_lvalue_type=infer_lvalue_type,
rv_type=item,
undefined_rvalue=True,
)
for t, lv in zip(transposed, self.flatten_lvalues(lvalues)):
# We can access _type_maps directly since temporary type maps are
# only created within expressions.
t.append(self._type_maps[0].pop(lv, AnyType(TypeOfAny.special_form)))
union_types = tuple(make_simplified_union(col) for col in transposed)
for expr, items in assignments.items():
# Bind a union of types collected in 'assignments' to every expression.
if isinstance(expr, StarExpr):
expr = expr.expr

# TODO: See todo in binder.py, ConditionalTypeBinder.assign_type
# It's unclear why the 'declared_type' param is sometimes 'None'
clean_items: list[tuple[Type, Type]] = []
for type, declared_type in items:
assert declared_type is not None
clean_items.append((type, declared_type))

types, declared_types = zip(*clean_items)
self.binder.assign_type(
expr,
make_simplified_union(list(types)),
make_simplified_union(list(declared_types)),
False,
# if `rvalue_type` is Optional type in non-strict Optional code, unwap it:
relevant_items = rvalue_type.relevant_items()
if len(relevant_items) == 1:
self.check_multi_assignment(
lvalues, rvalue, context, infer_lvalue_type, relevant_items[0]
)
for union, lv in zip(union_types, self.flatten_lvalues(lvalues)):
# Properly store the inferred types.
_1, _2, inferred = self.check_lvalue(lv)
if inferred:
self.set_inferred_type(inferred, lv, union)
return

# cases like a, *b, c require special care
star_idx = next((i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), None)

def handle_star_index(orig_types: list[Type]) -> list[Type]:
if star_idx is not None:
orig_types[star_idx] = self.named_generic_type(
"builtins.list", [orig_types[star_idx]]
)
return orig_types

# collect the relevant rvalue types
nmb_subitems = len(lvalues)
items: list[list[Type]] = []
for idx, item in enumerate(rvalue_type.items):
item = get_proper_type(item)
if isinstance(item, TupleType):
delta = len(item.items) - nmb_subitems
if star_idx is None:
if delta == 0: # a, b = x, y
items.append(item.items.copy())
else: # a, b = x, y, z or a, b, c = x, y
self.msg.wrong_number_values_to_unpack(
len(item.items), nmb_subitems, context
)
elif delta < -1: # a, b, c, *d = x, y
self.msg.wrong_number_values_to_unpack(
len(item.items), nmb_subitems - 1, context
)
elif delta == -1: # a, b, *c = x, y
items.append(item.items.copy())
# to be removed after transposing:
items[-1].insert(star_idx, PlaceholderType("temp", [], -1))
elif delta == 0: # a, b, *c = x, y, z
items.append(handle_star_index(item.items.copy()))
else: # a, *b = x, y, z
union = make_simplified_union(item.items[star_idx : star_idx + delta + 1])
subitems = item.items[:star_idx] + [union] + item.items[star_idx + delta + 1 :]
items.append(handle_star_index(subitems))
elif isinstance(item, AnyType):
items.append(handle_star_index(nmb_subitems * [cast(Type, item)]))
elif isinstance(item, Instance):
if item.type.fullname == "builtins.str":
self.msg.unpacking_strings_disallowed(context)
elif self.type_is_iterable(item):
items.append(handle_star_index(nmb_subitems * [self.iterable_item_type(item)]))
else:
self.msg.type_not_iterable(item, context)
else:
self.store_type(lv, union)
self.no_partial_types = False
self.msg.type_not_iterable(item, context)

# construct the unions and perform the single assignment checks
items_transposed = zip(*items)
for lvalue, subitems_ in zip(lvalues, items_transposed):
subitems = []
for item in subitems_:
item = get_proper_type(item)
if not isinstance(item, PlaceholderType):
subitems.append(item)
uniontype = make_simplified_union(subitems)
if isinstance(lvalue, StarExpr):
lvalue = lvalue.expr
self.check_assignment(lvalue, self.temp_node(uniontype, context), infer_lvalue_type)

def flatten_lvalues(self, lvalues: list[Expression]) -> list[Expression]:
res: list[Expression] = []
Expand All @@ -3500,7 +3517,6 @@ def check_multi_assignment_from_tuple(
rvalue: Expression,
rvalue_type: TupleType,
context: Context,
undefined_rvalue: bool,
infer_lvalue_type: bool = True,
) -> None:
if self.check_rvalue_count_in_assignment(lvalues, len(rvalue_type.items), context):
Expand All @@ -3512,34 +3528,6 @@ def check_multi_assignment_from_tuple(
star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None
right_lvs = lvalues[star_index + 1 :]

if not undefined_rvalue:
# Infer rvalue again, now in the correct type context.
lvalue_type = self.lvalue_type_for_inference(lvalues, rvalue_type)
reinferred_rvalue_type = get_proper_type(
self.expr_checker.accept(rvalue, lvalue_type)
)

if isinstance(reinferred_rvalue_type, UnionType):
# If this is an Optional type in non-strict Optional code, unwrap it.
relevant_items = reinferred_rvalue_type.relevant_items()
if len(relevant_items) == 1:
reinferred_rvalue_type = get_proper_type(relevant_items[0])
if isinstance(reinferred_rvalue_type, UnionType):
self.check_multi_assignment_from_union(
lvalues, rvalue, reinferred_rvalue_type, context, infer_lvalue_type
)
return
if isinstance(reinferred_rvalue_type, AnyType):
# We can get Any if the current node is
# deferred. Doing more inference in deferred nodes
# is hard, so give up for now. We can also get
# here if reinferring types above changes the
# inferred return type for an overloaded function
# to be ambiguous.
return
assert isinstance(reinferred_rvalue_type, TupleType)
rvalue_type = reinferred_rvalue_type

left_rv_types, star_rv_types, right_rv_types = self.split_around_star(
rvalue_type.items, star_index, len(lvalues)
)
Expand All @@ -3555,44 +3543,6 @@ def check_multi_assignment_from_tuple(
for lv, rv_type in zip(right_lvs, right_rv_types):
self.check_assignment(lv, self.temp_node(rv_type, context), infer_lvalue_type)

def lvalue_type_for_inference(self, lvalues: list[Lvalue], rvalue_type: TupleType) -> Type:
star_index = next(
(i for i, lv in enumerate(lvalues) if isinstance(lv, StarExpr)), len(lvalues)
)
left_lvs = lvalues[:star_index]
star_lv = cast(StarExpr, lvalues[star_index]) if star_index != len(lvalues) else None
right_lvs = lvalues[star_index + 1 :]
left_rv_types, star_rv_types, right_rv_types = self.split_around_star(
rvalue_type.items, star_index, len(lvalues)
)

type_parameters: list[Type] = []

def append_types_for_inference(lvs: list[Expression], rv_types: list[Type]) -> None:
for lv, rv_type in zip(lvs, rv_types):
sub_lvalue_type, index_expr, inferred = self.check_lvalue(lv)
if sub_lvalue_type and not isinstance(sub_lvalue_type, PartialType):
type_parameters.append(sub_lvalue_type)
else: # index lvalue
# TODO Figure out more precise type context, probably
# based on the type signature of the _set method.
type_parameters.append(rv_type)

append_types_for_inference(left_lvs, left_rv_types)

if star_lv:
sub_lvalue_type, index_expr, inferred = self.check_lvalue(star_lv.expr)
if sub_lvalue_type and not isinstance(sub_lvalue_type, PartialType):
type_parameters.extend([sub_lvalue_type] * len(star_rv_types))
else: # index lvalue
# TODO Figure out more precise type context, probably
# based on the type signature of the _set method.
type_parameters.extend(star_rv_types)

append_types_for_inference(right_lvs, right_rv_types)

return TupleType(type_parameters, self.named_type("builtins.tuple"))

def split_around_star(
self, items: list[T], star_index: int, length: int
) -> tuple[list[T], list[T], list[T]]:
Expand Down Expand Up @@ -3702,10 +3652,7 @@ def infer_variable_type(
"""Infer the type of initialized variables from initializer type."""
if isinstance(init_type, DeletedType):
self.msg.deleted_as_rvalue(init_type, context)
elif (
not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final)
and not self.no_partial_types
):
elif not is_valid_inferred_type(init_type, is_lvalue_final=name.is_final):
# We cannot use the type of the initialization expression for full type
# inference (it's not specific enough), but we might be able to give
# partial type which will be made more specific later. A partial type
Expand Down
8 changes: 3 additions & 5 deletions test-data/unit/check-inference-context.test
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,8 @@ if int():
ab, ao = f(b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
if int():
ao, ab = f(b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")

if int():
ao, ao = f(b)
ao, ao = f(b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
if int():
ab, ab = f(b)
if int():
Expand Down Expand Up @@ -199,11 +198,10 @@ if int():
ao, ab, ab, ab = h(b, b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
if int():
ab, ab, ao, ab = h(b, b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")

if int():
ao, ab, ab = f(b, b)
ao, ab, ab = f(b, b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
if int():
ab, ab, ao = g(b, b)
ab, ab, ao = g(b, b) # E: Incompatible types in assignment (expression has type "A[B]", variable has type "A[object]")
if int():
ab, ab, ab, ab = h(b, b)

Expand Down
48 changes: 48 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -1919,6 +1919,54 @@ class C:
a = 42
[out]

[case testDefinePartiallyInitialisedVariableDuringTupleUnpacking]
# flags: --strict-optional
from typing import Tuple, Union

t1: Union[Tuple[None], Tuple[str]]
x1 = None
x1, = t1
reveal_type(x1) # N: Revealed type is "Union[None, builtins.str]"

t2: Union[Tuple[str], Tuple[None]]
x2 = None
x2, = t2
reveal_type(x2) # N: Revealed type is "Union[builtins.str, None]"

t3: Union[Tuple[int], Tuple[str]]
x3 = None
x3, = t3
reveal_type(x3) # N: Revealed type is "Union[builtins.int, builtins.str]"

def f() -> Union[
Tuple[None, None, None, int, int, int, int, int, int],
Tuple[None, None, None, int, int, int, str, str, str]
]: ...
a1 = None
b1 = None
c1 = None
a2: object
b2: object
c2: object
a1, a2, a3, b1, b2, b3, c1, c2, c3 = f()
reveal_type(a1) # N: Revealed type is "None"
reveal_type(a2) # N: Revealed type is "None"
reveal_type(a3) # N: Revealed type is "None"
reveal_type(b1) # N: Revealed type is "builtins.int"
reveal_type(b2) # N: Revealed type is "builtins.int"
reveal_type(b3) # N: Revealed type is "builtins.int"
reveal_type(c1) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(c2) # N: Revealed type is "Union[builtins.int, builtins.str]"
reveal_type(c3) # N: Revealed type is "Union[builtins.int, builtins.str]"

tt: Tuple[Union[Tuple[None], Tuple[str], Tuple[int]]]
z = None
z, = tt[0]
reveal_type(z) # N: Revealed type is "Union[None, builtins.str, builtins.int]"

[builtins fixtures/tuple.pyi]


-- More partial type errors
-- ------------------------

Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/check-tuples.test
Original file line number Diff line number Diff line change
Expand Up @@ -1054,7 +1054,8 @@ def g(x: T) -> Tuple[T, T]:
return (x, x)

z = 1
x, y = g(z) # E: Argument 1 to "g" has incompatible type "int"; expected "Tuple[B1, B2]"
x, y = g(z) # E: Incompatible types in assignment (expression has type "int", variable has type "Tuple[A, ...]") \
# E: Incompatible types in assignment (expression has type "int", variable has type "Tuple[Union[B1, C], Union[B2, C]]")
[builtins fixtures/tuple.pyi]
[out]

Expand Down
Loading