@@ -3,8 +3,8 @@ import sys
3
3
import types
4
4
from builtins import type as Type # alias to avoid name clashes with fields named "type"
5
5
from collections .abc import Callable , Iterable , Mapping
6
- from typing import Any , Generic , Protocol , TypeVar , overload
7
- from typing_extensions import Literal , TypeAlias
6
+ from typing import Any , ClassVar , Generic , Protocol , TypeVar , overload
7
+ from typing_extensions import Literal , TypeAlias , TypeGuard
8
8
9
9
if sys .version_info >= (3 , 9 ):
10
10
from types import GenericAlias
@@ -30,6 +30,11 @@ __all__ = [
30
30
if sys .version_info >= (3 , 10 ):
31
31
__all__ += ["KW_ONLY" ]
32
32
33
+ class _DataclassInstance (Protocol ):
34
+ __dataclass_fields__ : ClassVar [dict [str , Field [Any ]]]
35
+
36
+ _DataclassT = TypeVar ("_DataclassT" , bound = _DataclassInstance )
37
+
33
38
# define _MISSING_TYPE as an enum within the type stubs,
34
39
# even though that is not really its type at runtime
35
40
# this allows us to use Literal[_MISSING_TYPE.MISSING]
@@ -44,13 +49,13 @@ if sys.version_info >= (3, 10):
44
49
class KW_ONLY : ...
45
50
46
51
@overload
47
- def asdict (obj : Any ) -> dict [str , Any ]: ...
52
+ def asdict (obj : _DataclassInstance ) -> dict [str , Any ]: ...
48
53
@overload
49
- def asdict (obj : Any , * , dict_factory : Callable [[list [tuple [str , Any ]]], _T ]) -> _T : ...
54
+ def asdict (obj : _DataclassInstance , * , dict_factory : Callable [[list [tuple [str , Any ]]], _T ]) -> _T : ...
50
55
@overload
51
- def astuple (obj : Any ) -> tuple [Any , ...]: ...
56
+ def astuple (obj : _DataclassInstance ) -> tuple [Any , ...]: ...
52
57
@overload
53
- def astuple (obj : Any , * , tuple_factory : Callable [[list [Any ]], _T ]) -> _T : ...
58
+ def astuple (obj : _DataclassInstance , * , tuple_factory : Callable [[list [Any ]], _T ]) -> _T : ...
54
59
55
60
if sys .version_info >= (3 , 8 ):
56
61
# cls argument is now positional-only
@@ -212,8 +217,13 @@ else:
212
217
metadata : Mapping [Any , Any ] | None = ...,
213
218
) -> Any : ...
214
219
215
- def fields (class_or_instance : Any ) -> tuple [Field [Any ], ...]: ...
216
- def is_dataclass (obj : Any ) -> bool : ...
220
+ def fields (class_or_instance : _DataclassInstance | type [_DataclassInstance ]) -> tuple [Field [Any ], ...]: ...
221
+ @overload
222
+ def is_dataclass (obj : _DataclassInstance | type [_DataclassInstance ]) -> Literal [True ]: ...
223
+ @overload
224
+ def is_dataclass (obj : type ) -> TypeGuard [type [_DataclassInstance ]]: ...
225
+ @overload
226
+ def is_dataclass (obj : object ) -> TypeGuard [_DataclassInstance | type [_DataclassInstance ]]: ...
217
227
218
228
class FrozenInstanceError (AttributeError ): ...
219
229
@@ -285,4 +295,4 @@ else:
285
295
frozen : bool = ...,
286
296
) -> type : ...
287
297
288
- def replace (__obj : _T , ** changes : Any ) -> _T : ...
298
+ def replace (__obj : _DataclassT , ** changes : Any ) -> _DataclassT : ...
0 commit comments