From 35bc1a25f84953b47d53d2c95cfc63f6ad31ba6e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 23 Aug 2022 00:44:50 +0100 Subject: [PATCH] Allow using TypedDict for more precise typing of **kwds (#13471) Fixes #4441 This uses a different approach than the initial attempt, but I re-used some of the test cases from the older PR. The initial idea was to eagerly expand the signature of the function during semantic analysis, but it didn't work well with fine-grained mode and also mypy in general relies on function definition and its type being consistent (and rewriting `FuncDef` sounds too sketchy). So instead I add a boolean flag to `CallableType` to indicate whether type of `**kwargs` is each item type or the "packed" type. I also add few helpers and safety net in form of a `NewType()`, but in general I am surprised how few places needed normalizing the signatures (because most relevant code paths go through `check_callable_call()` and/or `is_callable_compatible()`). Currently `Unpack[...]` is hidden behind `--enable-incomplete-features`, so this will be too, but IMO this part is 99% complete (you can see even some more exotic use cases like generic TypedDicts and callback protocols in test cases). --- mypy/checker.py | 16 +- mypy/checkexpr.py | 4 + mypy/constraints.py | 6 +- mypy/join.py | 18 +- mypy/meet.py | 4 + mypy/messages.py | 5 +- mypy/semanal.py | 26 +++ mypy/subtypes.py | 25 ++- mypy/typeanal.py | 2 +- mypy/types.py | 52 ++++- mypyc/test-data/run-functions.test | 15 ++ mypyc/test/test_run.py | 2 + test-data/unit/check-incremental.test | 23 +++ test-data/unit/check-python38.test | 10 + test-data/unit/check-varargs.test | 286 ++++++++++++++++++++++++++ test-data/unit/fine-grained.test | 32 +++ 16 files changed, 505 insertions(+), 21 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 08f053d321330..670a08a8e2be9 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -728,9 +728,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None: # This is to match the direction the implementation's return # needs to be compatible in. if impl_type.variables: - impl = unify_generic_callable( - impl_type, - sig1, + impl: CallableType | None = unify_generic_callable( + # Normalize both before unifying + impl_type.with_unpacked_kwargs(), + sig1.with_unpacked_kwargs(), ignore_return=False, return_constraint_direction=SUPERTYPE_OF, ) @@ -1165,7 +1166,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: str | None) -> # builtins.tuple[T] is typing.Tuple[T, ...] arg_type = self.named_generic_type("builtins.tuple", [arg_type]) elif typ.arg_kinds[i] == nodes.ARG_STAR2: - if not isinstance(arg_type, ParamSpecType): + if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs: arg_type = self.named_generic_type( "builtins.dict", [self.str_type(), arg_type] ) @@ -1887,6 +1888,13 @@ def check_override( if fail: emitted_msg = False + + # Normalize signatures, so we get better diagnostics. + if isinstance(override, (CallableType, Overloaded)): + override = override.with_unpacked_kwargs() + if isinstance(original, (CallableType, Overloaded)): + original = original.with_unpacked_kwargs() + if ( isinstance(override, CallableType) and isinstance(original, CallableType) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index cb542ee5300ba..b1bb5f2b3cc2c 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1322,6 +1322,8 @@ def check_callable_call( See the docstring of check_call for more information. """ + # Always unpack **kwargs before checking a call. + callee = callee.with_unpacked_kwargs() if callable_name is None and callee.name: callable_name = callee.name ret_type = get_proper_type(callee.ret_type) @@ -2057,6 +2059,8 @@ def check_overload_call( context: Context, ) -> tuple[Type, Type]: """Checks a call to an overloaded function.""" + # Normalize unpacked kwargs before checking the call. + callee = callee.with_unpacked_kwargs() arg_types = self.infer_arg_types_in_empty_context(args) # Step 1: Filter call targets to remove ones where the argument counts don't match plausible_targets = self.plausible_overload_call_targets( diff --git a/mypy/constraints.py b/mypy/constraints.py index f9cc68a0a7eb8..e0cb3245fdf66 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -735,9 +735,13 @@ def infer_constraints_from_protocol_members( return res def visit_callable_type(self, template: CallableType) -> list[Constraint]: + # Normalize callables before matching against each other. + # Note that non-normalized callables can be created in annotations + # using e.g. callback protocols. + template = template.with_unpacked_kwargs() if isinstance(self.actual, CallableType): res: list[Constraint] = [] - cactual = self.actual + cactual = self.actual.with_unpacked_kwargs() param_spec = template.param_spec() if param_spec is None: # FIX verify argument counts diff --git a/mypy/join.py b/mypy/join.py index 123488c54ef60..68cd02e40d17a 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Tuple + import mypy.typeops from mypy.maptype import map_instance_to_supertype from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT @@ -141,7 +143,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType: def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType: """Return a simple least upper bound given the declared type.""" - # TODO: check infinite recursion for aliases here. + # TODO: check infinite recursion for aliases here? declaration = get_proper_type(declaration) s = get_proper_type(s) t = get_proper_type(t) @@ -172,6 +174,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType: if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType): s, t = t, s + # Meets/joins require callable type normalization. + s, t = normalize_callables(s, t) + value = t.accept(TypeJoinVisitor(s)) if declaration is None or is_subtype(value, declaration): return value @@ -229,6 +234,9 @@ def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) elif isinstance(t, PlaceholderType): return AnyType(TypeOfAny.from_error) + # Meets/joins require callable type normalization. + s, t = normalize_callables(s, t) + # Use a visitor to handle non-trivial cases. return t.accept(TypeJoinVisitor(s, instance_joiner)) @@ -528,6 +536,14 @@ def is_better(t: Type, s: Type) -> bool: return False +def normalize_callables(s: ProperType, t: ProperType) -> Tuple[ProperType, ProperType]: + if isinstance(s, (CallableType, Overloaded)): + s = s.with_unpacked_kwargs() + if isinstance(t, (CallableType, Overloaded)): + t = t.with_unpacked_kwargs() + return s, t + + def is_similar_callables(t: CallableType, s: CallableType) -> bool: """Return True if t and s have identical numbers of arguments, default arguments and varargs. diff --git a/mypy/meet.py b/mypy/meet.py index 21637f57f2334..ab47ae2894940 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -78,6 +78,10 @@ def meet_types(s: Type, t: Type) -> ProperType: return t if isinstance(s, UnionType) and not isinstance(t, UnionType): s, t = t, s + + # Meets/joins require callable type normalization. + s, t = join.normalize_callables(s, t) + return t.accept(TypeMeetVisitor(s)) diff --git a/mypy/messages.py b/mypy/messages.py index d93541e94c9ce..b4c203058ddc1 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -2391,7 +2391,10 @@ def [T <: int] f(self, x: int, y: T) -> None name = tp.arg_names[i] if name: s += name + ": " - s += format_type_bare(tp.arg_types[i]) + type_str = format_type_bare(tp.arg_types[i]) + if tp.arg_kinds[i] == ARG_STAR2 and tp.unpack_kwargs: + type_str = f"Unpack[{type_str}]" + s += type_str if tp.arg_kinds[i].is_optional(): s += " = ..." diff --git a/mypy/semanal.py b/mypy/semanal.py index fccd7ffae4bfd..baffbec5dc057 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -263,6 +263,7 @@ TypeVarLikeType, TypeVarType, UnboundType, + UnpackType, get_proper_type, get_proper_types, invalid_recursive_alias, @@ -832,6 +833,8 @@ def analyze_func_def(self, defn: FuncDef) -> None: self.defer(defn) return assert isinstance(result, ProperType) + if isinstance(result, CallableType): + result = self.remove_unpack_kwargs(defn, result) defn.type = result self.add_type_alias_deps(analyzer.aliases_used) self.check_function_signature(defn) @@ -874,6 +877,29 @@ def analyze_func_def(self, defn: FuncDef) -> None: defn.type = defn.type.copy_modified(ret_type=ret_type) self.wrapped_coro_return_types[defn] = defn.type + def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType: + if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2: + return typ + last_type = get_proper_type(typ.arg_types[-1]) + if not isinstance(last_type, UnpackType): + return typ + last_type = get_proper_type(last_type.type) + if not isinstance(last_type, TypedDictType): + self.fail("Unpack item in ** argument must be a TypedDict", defn) + new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)] + return typ.copy_modified(arg_types=new_arg_types) + overlap = set(typ.arg_names) & set(last_type.items) + # It is OK for TypedDict to have a key named 'kwargs'. + overlap.discard(typ.arg_names[-1]) + if overlap: + overlapped = ", ".join([f'"{name}"' for name in overlap]) + self.fail(f"Overlap between argument names and ** TypedDict items: {overlapped}", defn) + new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)] + return typ.copy_modified(arg_types=new_arg_types) + # OK, everything looks right now, mark the callable type as using unpack. + new_arg_types = typ.arg_types[:-1] + [last_type] + return typ.copy_modified(arg_types=new_arg_types, unpack_kwargs=True) + def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None: """Check basic signature validity and tweak annotation of self/cls argument.""" # Only non-static methods are special. diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 9e84e25695ddc..d9a16ea049701 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -38,6 +38,7 @@ Instance, LiteralType, NoneType, + NormalizedCallableType, Overloaded, Parameters, ParamSpecType, @@ -591,8 +592,10 @@ def visit_unpack_type(self, left: UnpackType) -> bool: return False def visit_parameters(self, left: Parameters) -> bool: - right = self.right - if isinstance(right, Parameters) or isinstance(right, CallableType): + if isinstance(self.right, Parameters) or isinstance(self.right, CallableType): + right = self.right + if isinstance(right, CallableType): + right = right.with_unpacked_kwargs() return are_parameters_compatible( left, right, @@ -636,7 +639,7 @@ def visit_callable_type(self, left: CallableType) -> bool: elif isinstance(right, Parameters): # this doesn't check return types.... but is needed for is_equivalent return are_parameters_compatible( - left, + left.with_unpacked_kwargs(), right, is_compat=self._is_subtype, ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names, @@ -1213,6 +1216,10 @@ def g(x: int) -> int: ... If the 'some_check' function is also symmetric, the two calls would be equivalent whether or not we check the args covariantly. """ + # Normalize both types before comparing them. + left = left.with_unpacked_kwargs() + right = right.with_unpacked_kwargs() + if is_compat_return is None: is_compat_return = is_compat @@ -1277,8 +1284,8 @@ def g(x: int) -> int: ... def are_parameters_compatible( - left: Parameters | CallableType, - right: Parameters | CallableType, + left: Parameters | NormalizedCallableType, + right: Parameters | NormalizedCallableType, *, is_compat: Callable[[Type, Type], bool], ignore_pos_arg_names: bool = False, @@ -1499,11 +1506,11 @@ def new_is_compat(left: Type, right: Type) -> bool: def unify_generic_callable( - type: CallableType, - target: CallableType, + type: NormalizedCallableType, + target: NormalizedCallableType, ignore_return: bool, return_constraint_direction: int | None = None, -) -> CallableType | None: +) -> NormalizedCallableType | None: """Try to unify a generic callable type with another callable type. Return unified CallableType if successful; otherwise, return None. @@ -1540,7 +1547,7 @@ def report(*args: Any) -> None: ) if had_errors: return None - return applied + return cast(NormalizedCallableType, applied) def try_restrict_literal_union(t: UnionType, s: Type) -> list[Type] | None: diff --git a/mypy/typeanal.py b/mypy/typeanal.py index ae1920e234bb4..44e8e7f6ee9d9 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -538,7 +538,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ elif fullname in ("typing.Unpack", "typing_extensions.Unpack"): # We don't want people to try to use this yet. if not self.options.enable_incomplete_features: - self.fail('"Unpack" is not supported by mypy yet', t) + self.fail('"Unpack" is not supported yet, use --enable-incomplete-features', t) return AnyType(TypeOfAny.from_error) return UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column) return None diff --git a/mypy/types.py b/mypy/types.py index cfb6c62de1470..82e09c2d40b34 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -11,6 +11,7 @@ Dict, Iterable, NamedTuple, + NewType, Sequence, TypeVar, Union, @@ -1561,6 +1562,9 @@ def __eq__(self, other: object) -> bool: return NotImplemented +CT = TypeVar("CT", bound="CallableType") + + class CallableType(FunctionLike): """Type of a non-overloaded callable object (such as function).""" @@ -1590,6 +1594,7 @@ class CallableType(FunctionLike): "type_guard", # T, if -> TypeGuard[T] (ret_type is bool in this case). "from_concatenate", # whether this callable is from a concatenate object # (this is used for error messages) + "unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable? ) def __init__( @@ -1613,6 +1618,7 @@ def __init__( def_extras: dict[str, Any] | None = None, type_guard: Type | None = None, from_concatenate: bool = False, + unpack_kwargs: bool = False, ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -1653,9 +1659,10 @@ def __init__( else: self.def_extras = {} self.type_guard = type_guard + self.unpack_kwargs = unpack_kwargs def copy_modified( - self, + self: CT, arg_types: Bogus[Sequence[Type]] = _dummy, arg_kinds: Bogus[list[ArgKind]] = _dummy, arg_names: Bogus[list[str | None]] = _dummy, @@ -1674,8 +1681,9 @@ def copy_modified( def_extras: Bogus[dict[str, Any]] = _dummy, type_guard: Bogus[Type | None] = _dummy, from_concatenate: Bogus[bool] = _dummy, - ) -> CallableType: - return CallableType( + unpack_kwargs: Bogus[bool] = _dummy, + ) -> CT: + return type(self)( arg_types=arg_types if arg_types is not _dummy else self.arg_types, arg_kinds=arg_kinds if arg_kinds is not _dummy else self.arg_kinds, arg_names=arg_names if arg_names is not _dummy else self.arg_names, @@ -1698,6 +1706,7 @@ def copy_modified( from_concatenate=( from_concatenate if from_concatenate is not _dummy else self.from_concatenate ), + unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs, ) def var_arg(self) -> FormalArgument | None: @@ -1889,6 +1898,27 @@ def expand_param_spec( variables=[*variables, *self.variables], ) + def with_unpacked_kwargs(self) -> NormalizedCallableType: + if not self.unpack_kwargs: + return NormalizedCallableType(self.copy_modified()) + last_type = get_proper_type(self.arg_types[-1]) + assert isinstance(last_type, ProperType) and isinstance(last_type, TypedDictType) + extra_kinds = [ + ArgKind.ARG_NAMED if name in last_type.required_keys else ArgKind.ARG_NAMED_OPT + for name in last_type.items + ] + new_arg_kinds = self.arg_kinds[:-1] + extra_kinds + new_arg_names = self.arg_names[:-1] + list(last_type.items) + new_arg_types = self.arg_types[:-1] + list(last_type.items.values()) + return NormalizedCallableType( + self.copy_modified( + arg_kinds=new_arg_kinds, + arg_names=new_arg_names, + arg_types=new_arg_types, + unpack_kwargs=False, + ) + ) + def __hash__(self) -> int: # self.is_type_obj() will fail if self.fallback.type is a FakeInfo if isinstance(self.fallback.type, FakeInfo): @@ -1940,6 +1970,7 @@ def serialize(self) -> JsonDict: "def_extras": dict(self.def_extras), "type_guard": self.type_guard.serialize() if self.type_guard is not None else None, "from_concatenate": self.from_concatenate, + "unpack_kwargs": self.unpack_kwargs, } @classmethod @@ -1962,9 +1993,16 @@ def deserialize(cls, data: JsonDict) -> CallableType: deserialize_type(data["type_guard"]) if data["type_guard"] is not None else None ), from_concatenate=data["from_concatenate"], + unpack_kwargs=data["unpack_kwargs"], ) +# This is a little safety net to prevent reckless special-casing of callables +# that can potentially break Unpack[...] with **kwargs. +# TODO: use this in more places in checkexpr.py etc? +NormalizedCallableType = NewType("NormalizedCallableType", CallableType) + + class Overloaded(FunctionLike): """Overloaded function type T1, ... Tn, where each Ti is CallableType. @@ -2009,6 +2047,9 @@ def with_name(self, name: str) -> Overloaded: def get_name(self) -> str | None: return self._items[0].name + def with_unpacked_kwargs(self) -> Overloaded: + return Overloaded([i.with_unpacked_kwargs() for i in self.items]) + def accept(self, visitor: TypeVisitor[T]) -> T: return visitor.visit_overloaded(self) @@ -2917,7 +2958,10 @@ def visit_callable_type(self, t: CallableType) -> str: name = t.arg_names[i] if name: s += name + ": " - s += t.arg_types[i].accept(self) + type_str = t.arg_types[i].accept(self) + if t.arg_kinds[i] == ARG_STAR2 and t.unpack_kwargs: + type_str = f"Unpack[{type_str}]" + s += type_str if t.arg_kinds[i].is_optional(): s += " =" diff --git a/mypyc/test-data/run-functions.test b/mypyc/test-data/run-functions.test index b6277c9e8ec41..a32af4c16dcc8 100644 --- a/mypyc/test-data/run-functions.test +++ b/mypyc/test-data/run-functions.test @@ -1235,3 +1235,18 @@ def g() -> None: a.pop() g() + +[case testIncompleteFeatureUnpackKwargsCompiled] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +def foo(**kwargs: Unpack[Person]) -> None: + print(kwargs["name"]) + +# This is not really supported yet, just test that we behave reasonably. +foo(name='Jennifer', age=38) +[out] +Jennifer diff --git a/mypyc/test/test_run.py b/mypyc/test/test_run.py index 62168ff4bb009..28892f8c39209 100644 --- a/mypyc/test/test_run.py +++ b/mypyc/test/test_run.py @@ -184,6 +184,8 @@ def run_case_step(self, testcase: DataDrivenTestCase, incremental_step: int) -> options.export_types = True options.preserve_asts = True options.incremental = self.separate + if "IncompleteFeature" in testcase.name: + options.enable_incomplete_features = True # Avoid checking modules/packages named 'unchecked', to provide a way # to test interacting with code we don't have types for. diff --git a/test-data/unit/check-incremental.test b/test-data/unit/check-incremental.test index 44452e2072b38..28497cb12c7bb 100644 --- a/test-data/unit/check-incremental.test +++ b/test-data/unit/check-incremental.test @@ -5957,3 +5957,26 @@ s: str = td["value"] [out] [out2] tmp/b.py:3: error: Incompatible types in assignment (expression has type "int", variable has type "str") + +[case testUnpackKwargsSerialize] +import m +[file lib.py] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +def foo(**kwargs: Unpack[Person]): + ... + +[file m.py] +from lib import foo +foo(name='Jennifer', age=38) +[file m.py.2] +from lib import foo +foo(name='Jennifer', age="38") +[builtins fixtures/dict.pyi] +[out] +[out2] +tmp/m.py:2: error: Argument "age" to "foo" has incompatible type "str"; expected "int" diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index deded7a52f724..0579d29843dc7 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -576,3 +576,13 @@ class Bar: def f(self, a: Optional[str] = None, /, *, b: bool = False) -> None: ... [builtins fixtures/bool.pyi] + +[case testUnpackWithDuplicateNamePositionalOnly] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int +def foo(name: str, /, **kwargs: Unpack[Person]) -> None: # Allowed + ... +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/check-varargs.test b/test-data/unit/check-varargs.test index 4dc10c9f7489c..ac68e20028a74 100644 --- a/test-data/unit/check-varargs.test +++ b/test-data/unit/check-varargs.test @@ -760,3 +760,289 @@ bar(*good3) bar(*bad1) # E: Argument 1 to "bar" has incompatible type "*I[str]"; expected "float" bar(*bad2) # E: List or tuple expected as variadic arguments [builtins fixtures/dict.pyi] + +-- Keyword arguments unpacking + +[case testUnpackKwargsReveal] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int +def foo(arg: bool, **kwargs: Unpack[Person]) -> None: ... + +reveal_type(foo) # N: Revealed type is "def (arg: builtins.bool, **kwargs: Unpack[TypedDict('__main__.Person', {'name': builtins.str, 'age': builtins.int})])" +[builtins fixtures/dict.pyi] + +[case testUnpackOutsideOfKwargs] +from typing_extensions import Unpack, TypedDict +class Person(TypedDict): + name: str + age: int + +def foo(x: Unpack[Person]) -> None: # E: TypedDict('__main__.Person', {'name': builtins.str, 'age': builtins.int}) cannot be unpacked (must be tuple or TypeVarTuple) + ... +def bar(x: int, *args: Unpack[Person]) -> None: # E: TypedDict('__main__.Person', {'name': builtins.str, 'age': builtins.int}) cannot be unpacked (must be tuple or TypeVarTuple) + ... +def baz(**kwargs: Unpack[Person]) -> None: # OK + ... +[builtins fixtures/dict.pyi] + +[case testUnpackWithoutTypedDict] +from typing_extensions import Unpack + +def foo(**kwargs: Unpack[dict]) -> None: # E: Unpack item in ** argument must be a TypedDict + ... +[builtins fixtures/dict.pyi] + +[case testUnpackWithDuplicateKeywords] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int +def foo(name: str, **kwargs: Unpack[Person]) -> None: # E: Overlap between argument names and ** TypedDict items: "name" + ... +[builtins fixtures/dict.pyi] + +[case testUnpackWithDuplicateKeywordKwargs] +from typing_extensions import Unpack, TypedDict +from typing import Dict, List + +class Spec(TypedDict): + args: List[int] + kwargs: Dict[int, int] +def foo(**kwargs: Unpack[Spec]) -> None: # Allowed + ... +foo(args=[1], kwargs={"2": 3}) # E: Dict entry 0 has incompatible type "str": "int"; expected "int": "int" +[builtins fixtures/dict.pyi] + +[case testUnpackKwargsNonIdentifier] +from typing_extensions import Unpack, TypedDict + +Weird = TypedDict("Weird", {"@": int}) + +def foo(**kwargs: Unpack[Weird]) -> None: + reveal_type(kwargs["@"]) # N: Revealed type is "builtins.int" +foo(**{"@": 42}) +foo(**{"no": "way"}) # E: Argument 1 to "foo" has incompatible type "**Dict[str, str]"; expected "int" +[builtins fixtures/dict.pyi] + +[case testUnpackKwargsEmpty] +from typing_extensions import Unpack, TypedDict + +Empty = TypedDict("Empty", {}) + +def foo(**kwargs: Unpack[Empty]) -> None: # N: "foo" defined here + reveal_type(kwargs) # N: Revealed type is "TypedDict('__main__.Empty', {})" +foo() +foo(x=1) # E: Unexpected keyword argument "x" for "foo" +[builtins fixtures/dict.pyi] + +[case testUnpackTypedDictTotality] +from typing_extensions import Unpack, TypedDict + +class Circle(TypedDict, total=True): + radius: int + color: str + x: int + y: int + +def foo(**kwargs: Unpack[Circle]): + ... +foo(x=0, y=0, color='orange') # E: Missing named argument "radius" for "foo" + +class Square(TypedDict, total=False): + side: int + color: str + +def bar(**kwargs: Unpack[Square]): + ... +bar(side=12) +[builtins fixtures/dict.pyi] + +[case testUnpackUnexpectedKeyword] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict, total=False): + name: str + age: int + +def foo(**kwargs: Unpack[Person]) -> None: # N: "foo" defined here + ... +foo(name='John', age=42, department='Sales') # E: Unexpected keyword argument "department" for "foo" +foo(name='Jennifer', age=38) +[builtins fixtures/dict.pyi] + +[case testUnpackKeywordTypes] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +def foo(**kwargs: Unpack[Person]): + ... +foo(name='John', age='42') # E: Argument "age" to "foo" has incompatible type "str"; expected "int" +foo(name='Jennifer', age=38) +[builtins fixtures/dict.pyi] + +[case testUnpackKeywordTypesTypedDict] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +class LegacyPerson(TypedDict): + name: str + age: str + +def foo(**kwargs: Unpack[Person]) -> None: + ... +lp = LegacyPerson(name="test", age="42") +foo(**lp) # E: Argument "age" to "foo" has incompatible type "str"; expected "int" +[builtins fixtures/dict.pyi] + +[case testFunctionBodyWithUnpackedKwargs] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +def foo(**kwargs: Unpack[Person]) -> int: + name: str = kwargs['name'] + age: str = kwargs['age'] # E: Incompatible types in assignment (expression has type "int", variable has type "str") + department: str = kwargs['department'] # E: TypedDict "Person" has no key "department" + return kwargs['age'] +[builtins fixtures/dict.pyi] + +[case testUnpackKwargsOverrides] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +class Base: + def foo(self, **kwargs: Unpack[Person]) -> None: ... +class SubGood(Base): + def foo(self, *, name: str, age: int, extra: bool = False) -> None: ... +class SubBad(Base): + def foo(self, *, name: str, age: str) -> None: ... # E: Argument 2 of "foo" is incompatible with supertype "Base"; supertype defines the argument type as "int" \ + # N: This violates the Liskov substitution principle \ + # N: See https://mypy.readthedocs.io/en/stable/common_issues.html#incompatible-overrides +[builtins fixtures/dict.pyi] + +[case testUnpackKwargsOverridesTypedDict] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +class PersonExtra(Person, total=False): + extra: bool + +class Unrelated(TypedDict): + baz: int + +class Base: + def foo(self, **kwargs: Unpack[Person]) -> None: ... +class SubGood(Base): + def foo(self, **kwargs: Unpack[PersonExtra]) -> None: ... +class SubBad(Base): + def foo(self, **kwargs: Unpack[Unrelated]) -> None: ... # E: Signature of "foo" incompatible with supertype "Base" \ + # N: Superclass: \ + # N: def foo(*, name: str, age: int) -> None \ + # N: Subclass: \ + # N: def foo(self, *, baz: int) -> None +[builtins fixtures/dict.pyi] + +[case testUnpackKwargsGeneric] +from typing import Generic, TypeVar +from typing_extensions import Unpack, TypedDict + +T = TypeVar("T") +class Person(TypedDict, Generic[T]): + name: str + value: T + +def foo(**kwargs: Unpack[Person[T]]) -> T: ... +reveal_type(foo(name="test", value=42)) # N: Revealed type is "builtins.int" +[builtins fixtures/dict.pyi] + +[case testUnpackKwargsInference] +from typing import Generic, TypeVar, Protocol +from typing_extensions import Unpack, TypedDict + +T_contra = TypeVar("T_contra", contravariant=True) +class CBPerson(Protocol[T_contra]): + def __call__(self, **kwargs: Unpack[Person[T_contra]]) -> None: ... + +T = TypeVar("T") +class Person(TypedDict, Generic[T]): + name: str + value: T + +def test(cb: CBPerson[T]) -> T: ... + +def foo(*, name: str, value: int) -> None: ... +reveal_type(test(foo)) # N: Revealed type is "builtins.int" +[builtins fixtures/dict.pyi] + +[case testUnpackKwargsOverload] +from typing import Any, overload +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +class Fruit(TypedDict): + sort: str + taste: int + +@overload +def foo(**kwargs: Unpack[Person]) -> int: ... +@overload +def foo(**kwargs: Unpack[Fruit]) -> str: ... +def foo(**kwargs: Any) -> Any: + ... + +reveal_type(foo(sort="test", taste=999)) # N: Revealed type is "builtins.str" +[builtins fixtures/dict.pyi] + +[case testUnpackKwargsJoin] +from typing_extensions import Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +def foo(*, name: str, age: int) -> None: ... +def bar(**kwargs: Unpack[Person]) -> None: ... + +reveal_type([foo, bar]) # N: Revealed type is "builtins.list[def (*, name: builtins.str, age: builtins.int)]" +reveal_type([bar, foo]) # N: Revealed type is "builtins.list[def (*, name: builtins.str, age: builtins.int)]" +[builtins fixtures/dict.pyi] + +[case testUnpackKwargsParamSpec] +from typing import Callable, Any, TypeVar, List +from typing_extensions import ParamSpec, Unpack, TypedDict + +class Person(TypedDict): + name: str + age: int + +P = ParamSpec('P') +T = TypeVar('T') + +def dec(f: Callable[P, T]) -> Callable[P, List[T]]: ... + +@dec +def g(**kwargs: Unpack[Person]) -> int: ... + +reveal_type(g) # N: Revealed type is "def (*, name: builtins.str, age: builtins.int) -> builtins.list[builtins.int]" +[builtins fixtures/dict.pyi] diff --git a/test-data/unit/fine-grained.test b/test-data/unit/fine-grained.test index aa53c6482449e..8ef04562abbf6 100644 --- a/test-data/unit/fine-grained.test +++ b/test-data/unit/fine-grained.test @@ -9818,3 +9818,35 @@ x: str [builtins fixtures/dataclasses.pyi] [out] == + +[case testUnpackKwargsUpdateFine] +# flags: --enable-incomplete-features +import m +[file shared.py] +from typing_extensions import TypedDict + +class Person(TypedDict): + name: str + age: int + +[file shared.py.2] +from typing_extensions import TypedDict + +class Person(TypedDict): + name: str + age: str + +[file lib.py] +from typing_extensions import Unpack +from shared import Person + +def foo(**kwargs: Unpack[Person]): + ... +[file m.py] +from lib import foo +foo(name='Jennifer', age=38) + +[builtins fixtures/dict.pyi] +[out] +== +m.py:2: error: Argument "age" to "foo" has incompatible type "int"; expected "str"