Skip to content

Commit 32ebe32

Browse files
authored
Use a TypeGuard for dataclasses.is_dataclass(); refine asdict(), astuple(), fields(), replace() (#9362)
1 parent c216b74 commit 32ebe32

File tree

2 files changed

+100
-9
lines changed

2 files changed

+100
-9
lines changed

stdlib/dataclasses.pyi

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import sys
33
import types
44
from builtins import type as Type # alias to avoid name clashes with fields named "type"
55
from collections.abc import Callable, Iterable, Mapping
6-
from typing import Any, Generic, Protocol, TypeVar, overload
7-
from typing_extensions import Literal, TypeAlias
6+
from typing import Any, ClassVar, Generic, Protocol, TypeVar, overload
7+
from typing_extensions import Literal, TypeAlias, TypeGuard
88

99
if sys.version_info >= (3, 9):
1010
from types import GenericAlias
@@ -30,6 +30,11 @@ __all__ = [
3030
if sys.version_info >= (3, 10):
3131
__all__ += ["KW_ONLY"]
3232

33+
class _DataclassInstance(Protocol):
34+
__dataclass_fields__: ClassVar[dict[str, Field[Any]]]
35+
36+
_DataclassT = TypeVar("_DataclassT", bound=_DataclassInstance)
37+
3338
# define _MISSING_TYPE as an enum within the type stubs,
3439
# even though that is not really its type at runtime
3540
# this allows us to use Literal[_MISSING_TYPE.MISSING]
@@ -44,13 +49,13 @@ if sys.version_info >= (3, 10):
4449
class KW_ONLY: ...
4550

4651
@overload
47-
def asdict(obj: Any) -> dict[str, Any]: ...
52+
def asdict(obj: _DataclassInstance) -> dict[str, Any]: ...
4853
@overload
49-
def asdict(obj: Any, *, dict_factory: Callable[[list[tuple[str, Any]]], _T]) -> _T: ...
54+
def asdict(obj: _DataclassInstance, *, dict_factory: Callable[[list[tuple[str, Any]]], _T]) -> _T: ...
5055
@overload
51-
def astuple(obj: Any) -> tuple[Any, ...]: ...
56+
def astuple(obj: _DataclassInstance) -> tuple[Any, ...]: ...
5257
@overload
53-
def astuple(obj: Any, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ...
58+
def astuple(obj: _DataclassInstance, *, tuple_factory: Callable[[list[Any]], _T]) -> _T: ...
5459

5560
if sys.version_info >= (3, 8):
5661
# cls argument is now positional-only
@@ -212,8 +217,13 @@ else:
212217
metadata: Mapping[Any, Any] | None = ...,
213218
) -> Any: ...
214219

215-
def fields(class_or_instance: Any) -> tuple[Field[Any], ...]: ...
216-
def is_dataclass(obj: Any) -> bool: ...
220+
def fields(class_or_instance: _DataclassInstance | type[_DataclassInstance]) -> tuple[Field[Any], ...]: ...
221+
@overload
222+
def is_dataclass(obj: _DataclassInstance | type[_DataclassInstance]) -> Literal[True]: ...
223+
@overload
224+
def is_dataclass(obj: type) -> TypeGuard[type[_DataclassInstance]]: ...
225+
@overload
226+
def is_dataclass(obj: object) -> TypeGuard[_DataclassInstance | type[_DataclassInstance]]: ...
217227

218228
class FrozenInstanceError(AttributeError): ...
219229

@@ -285,4 +295,4 @@ else:
285295
frozen: bool = ...,
286296
) -> type: ...
287297

288-
def replace(__obj: _T, **changes: Any) -> _T: ...
298+
def replace(__obj: _DataclassT, **changes: Any) -> _DataclassT: ...
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
from __future__ import annotations
2+
3+
import dataclasses as dc
4+
from typing import Any, Dict, Tuple, Type
5+
from typing_extensions import assert_type
6+
7+
8+
@dc.dataclass
9+
class Foo:
10+
attr: str
11+
12+
13+
assert_type(dc.fields(Foo), Tuple[dc.Field[Any], ...])
14+
15+
# Mypy correctly emits errors on these
16+
# due to the fact it's a dataclass class, not an instance.
17+
# Pyright, however, handles ClassVar members in protocols differently.
18+
# See https://github.com/microsoft/pyright/issues/4339
19+
#
20+
# dc.asdict(Foo)
21+
# dc.astuple(Foo)
22+
# dc.replace(Foo)
23+
24+
if dc.is_dataclass(Foo):
25+
# The inferred type doesn't change
26+
# if it's already known to be a subtype of type[_DataclassInstance]
27+
assert_type(Foo, Type[Foo])
28+
29+
f = Foo(attr="attr")
30+
31+
assert_type(dc.fields(f), Tuple[dc.Field[Any], ...])
32+
assert_type(dc.asdict(f), Dict[str, Any])
33+
assert_type(dc.astuple(f), Tuple[Any, ...])
34+
assert_type(dc.replace(f, attr="new"), Foo)
35+
36+
if dc.is_dataclass(f):
37+
# The inferred type doesn't change
38+
# if it's already known to be a subtype of _DataclassInstance
39+
assert_type(f, Foo)
40+
41+
42+
def test_other_isdataclass_overloads(x: type, y: object) -> None:
43+
# TODO: pyright correctly emits an error on this, but mypy does not -- why?
44+
# dc.fields(x)
45+
46+
dc.fields(y) # type: ignore
47+
48+
dc.asdict(x) # type: ignore
49+
dc.asdict(y) # type: ignore
50+
51+
dc.astuple(x) # type: ignore
52+
dc.astuple(y) # type: ignore
53+
54+
dc.replace(x) # type: ignore
55+
dc.replace(y) # type: ignore
56+
57+
if dc.is_dataclass(x):
58+
assert_type(dc.fields(x), Tuple[dc.Field[Any], ...])
59+
# These should cause type checkers to emit errors
60+
# due to the fact it's a dataclass class, not an instance
61+
dc.asdict(x) # type: ignore
62+
dc.astuple(x) # type: ignore
63+
dc.replace(x) # type: ignore
64+
65+
if dc.is_dataclass(y):
66+
assert_type(dc.fields(y), Tuple[dc.Field[Any], ...])
67+
68+
# Mypy corrextly emits an error on these due to the fact we don't know
69+
# whether it's a dataclass class or a dataclass instance.
70+
# Pyright, however, handles ClassVar members in protocols differently.
71+
# See https://github.com/microsoft/pyright/issues/4339
72+
#
73+
# dc.asdict(y)
74+
# dc.astuple(y)
75+
# dc.replace(y)
76+
77+
if dc.is_dataclass(y) and not isinstance(y, type):
78+
assert_type(dc.fields(y), Tuple[dc.Field[Any], ...])
79+
assert_type(dc.asdict(y), Dict[str, Any])
80+
assert_type(dc.astuple(y), Tuple[Any, ...])
81+
dc.replace(y)

0 commit comments

Comments
 (0)