diff --git a/stdlib/dataclasses.pyi b/stdlib/dataclasses.pyi index 560147f9e96b..c4a3d7aa5632 100644 --- a/stdlib/dataclasses.pyi +++ b/stdlib/dataclasses.pyi @@ -3,8 +3,8 @@ import sys import types from builtins import type as Type # alias to avoid name clashes with fields named "type" from collections.abc import Callable, Iterable, Mapping -from typing import Any, Generic, Protocol, TypeVar, overload -from typing_extensions import Literal, TypeAlias +from typing import Any, ClassVar, Generic, Protocol, TypeVar, overload +from typing_extensions import Literal, TypeAlias, TypeGuard if sys.version_info >= (3, 9): from types import GenericAlias @@ -30,6 +30,11 @@ __all__ = [ if sys.version_info >= (3, 10): __all__ += ["KW_ONLY"] +class _DataclassInstance(Protocol): + __dataclass_fields__: ClassVar[dict[str, Field[Any]]] + +_DataclassT = TypeVar("_DataclassT", bound=_DataclassInstance) + # define _MISSING_TYPE as an enum within the type stubs, # even though that is not really its type at runtime # this allows us to use Literal[_MISSING_TYPE.MISSING] @@ -44,13 +49,13 @@ if sys.version_info >= (3, 10): class KW_ONLY: ... @overload -def asdict(obj: Any) -> dict[str, Any]: ... +def asdict(obj: _DataclassInstance) -> dict[str, Any]: ... @overload -def asdict(obj: Any, *, dict_factory: Callable[[list[tuple[str, Any]]], _T]) -> _T: ... +def asdict(obj: _DataclassInstance, *, dict_factory: Callable[[list[tuple[str, Any]]], _T]) -> _T: ... @overload -def astuple(obj: Any) -> tuple[Any, ...]: ... +def astuple(obj: _DataclassInstance) -> tuple[Any, ...]: ... @overload -def astuple(obj: Any, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ... +def astuple(obj: _DataclassInstance, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ... if sys.version_info >= (3, 8): # cls argument is now positional-only @@ -212,8 +217,13 @@ else: metadata: Mapping[Any, Any] | None = ..., ) -> Any: ... -def fields(class_or_instance: Any) -> tuple[Field[Any], ...]: ... -def is_dataclass(obj: Any) -> bool: ... +def fields(class_or_instance: _DataclassInstance | type[_DataclassInstance]) -> tuple[Field[Any], ...]: ... +@overload +def is_dataclass(obj: _DataclassInstance | type[_DataclassInstance]) -> Literal[True]: ... +@overload +def is_dataclass(obj: type) -> TypeGuard[type[_DataclassInstance]]: ... +@overload +def is_dataclass(obj: object) -> TypeGuard[_DataclassInstance | type[_DataclassInstance]]: ... class FrozenInstanceError(AttributeError): ... @@ -285,4 +295,4 @@ else: frozen: bool = ..., ) -> type: ... -def replace(__obj: _T, **changes: Any) -> _T: ... +def replace(__obj: _DataclassT, **changes: Any) -> _DataclassT: ... diff --git a/test_cases/stdlib/check_dataclasses.py b/test_cases/stdlib/check_dataclasses.py new file mode 100644 index 000000000000..211d63d65422 --- /dev/null +++ b/test_cases/stdlib/check_dataclasses.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +import dataclasses as dc +from typing import Any, Dict, Tuple, Type +from typing_extensions import assert_type + + +@dc.dataclass +class Foo: + attr: str + + +assert_type(dc.fields(Foo), Tuple[dc.Field[Any], ...]) + +# Mypy correctly emits errors on these +# due to the fact it's a dataclass class, not an instance. +# Pyright, however, handles ClassVar members in protocols differently. +# See https://github.com/microsoft/pyright/issues/4339 +# +# dc.asdict(Foo) +# dc.astuple(Foo) +# dc.replace(Foo) + +if dc.is_dataclass(Foo): + # The inferred type doesn't change + # if it's already known to be a subtype of type[_DataclassInstance] + assert_type(Foo, Type[Foo]) + +f = Foo(attr="attr") + +assert_type(dc.fields(f), Tuple[dc.Field[Any], ...]) +assert_type(dc.asdict(f), Dict[str, Any]) +assert_type(dc.astuple(f), Tuple[Any, ...]) +assert_type(dc.replace(f, attr="new"), Foo) + +if dc.is_dataclass(f): + # The inferred type doesn't change + # if it's already known to be a subtype of _DataclassInstance + assert_type(f, Foo) + + +def test_other_isdataclass_overloads(x: type, y: object) -> None: + # TODO: pyright correctly emits an error on this, but mypy does not -- why? + # dc.fields(x) + + dc.fields(y) # type: ignore + + dc.asdict(x) # type: ignore + dc.asdict(y) # type: ignore + + dc.astuple(x) # type: ignore + dc.astuple(y) # type: ignore + + dc.replace(x) # type: ignore + dc.replace(y) # type: ignore + + if dc.is_dataclass(x): + assert_type(dc.fields(x), Tuple[dc.Field[Any], ...]) + # These should cause type checkers to emit errors + # due to the fact it's a dataclass class, not an instance + dc.asdict(x) # type: ignore + dc.astuple(x) # type: ignore + dc.replace(x) # type: ignore + + if dc.is_dataclass(y): + assert_type(dc.fields(y), Tuple[dc.Field[Any], ...]) + + # Mypy corrextly emits an error on these due to the fact we don't know + # whether it's a dataclass class or a dataclass instance. + # Pyright, however, handles ClassVar members in protocols differently. + # See https://github.com/microsoft/pyright/issues/4339 + # + # dc.asdict(y) + # dc.astuple(y) + # dc.replace(y) + + if dc.is_dataclass(y) and not isinstance(y, type): + assert_type(dc.fields(y), Tuple[dc.Field[Any], ...]) + assert_type(dc.asdict(y), Dict[str, Any]) + assert_type(dc.astuple(y), Tuple[Any, ...]) + dc.replace(y)