|
1 | 1 | # Stubs for warnings
|
2 | 2 |
|
3 |
| -from typing import Any, Dict, List, NamedTuple, Optional, overload, TextIO, Tuple, Type, Union |
4 |
| -from types import ModuleType, TracebackType |
| 3 | +import sys |
| 4 | +from typing import Any, Dict, List, NamedTuple, Optional, overload, TextIO, Tuple, Type, Union, ContextManager |
| 5 | +from types import ModuleType |
| 6 | + |
| 7 | +if sys.version_info >= (3, 8): |
| 8 | + from typing import Literal |
| 9 | +else: |
| 10 | + from typing_extensions import Literal |
5 | 11 |
|
6 | 12 | @overload
|
7 | 13 | def warn(message: str, category: Optional[Type[Warning]] = ..., stacklevel: int = ...) -> None: ...
|
@@ -37,10 +43,12 @@ class _Record(NamedTuple):
|
37 | 43 | file: Optional[TextIO]
|
38 | 44 | line: Optional[str]
|
39 | 45 |
|
40 |
| -class catch_warnings: |
41 |
| - def __init__(self, *, record: bool = ..., |
42 |
| - module: Optional[ModuleType] = ...) -> None: ... |
43 |
| - def __enter__(self) -> Optional[List[_Record]]: ... |
44 |
| - def __exit__(self, exc_type: Optional[Type[BaseException]], |
45 |
| - exc_val: Optional[BaseException], |
46 |
| - exc_tb: Optional[TracebackType]) -> None: ... |
| 46 | + |
| 47 | +@overload |
| 48 | +def catch_warnings(*, record: Literal[False] = ..., module: Optional[ModuleType] = ...) -> ContextManager[None]: ... |
| 49 | + |
| 50 | +@overload |
| 51 | +def catch_warnings(*, record: Literal[True], module: Optional[ModuleType] = ...) -> ContextManager[List[_Record]]: ... |
| 52 | + |
| 53 | +@overload |
| 54 | +def catch_warnings(*, record: bool, module: Optional[ModuleType] = ...) -> ContextManager[Optional[List[_Record]]]: ... |
0 commit comments