|
10 | 10 | from types import TracebackType
|
11 | 11 | from typing import Any
|
12 | 12 | from typing import Callable
|
| 13 | +from typing import cast |
| 14 | +from typing import Generic |
13 | 15 | from typing import Optional
|
14 | 16 | from typing import overload
|
15 | 17 | from typing import Pattern
|
16 | 18 | from typing import Tuple
|
| 19 | +from typing import TypeVar |
17 | 20 | from typing import Union
|
18 | 21 |
|
19 | 22 | from more_itertools.more import always_iterable
|
@@ -537,33 +540,35 @@ def _is_numpy_array(obj):
|
537 | 540 |
|
538 | 541 | # builtin pytest.raises helper
|
539 | 542 |
|
| 543 | +_E = TypeVar("_E", bound=BaseException) |
| 544 | + |
540 | 545 |
|
541 | 546 | @overload
|
542 | 547 | def raises(
|
543 |
| - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 548 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
544 | 549 | *,
|
545 | 550 | match: Optional[Union[str, Pattern]] = ... # noqa: W504 (SyntaxError in Python 3.5)
|
546 |
| -) -> "RaisesContext": |
| 551 | +) -> "RaisesContext[_E]": |
547 | 552 | ...
|
548 | 553 |
|
549 | 554 |
|
550 | 555 | @overload
|
551 | 556 | def raises(
|
552 |
| - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 557 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
553 | 558 | func: Callable,
|
554 | 559 | *args: Any,
|
555 | 560 | match: Optional[str] = ...,
|
556 | 561 | **kwargs: Any
|
557 |
| -) -> Optional[_pytest._code.ExceptionInfo]: |
| 562 | +) -> Optional[_pytest._code.ExceptionInfo[_E]]: |
558 | 563 | ...
|
559 | 564 |
|
560 | 565 |
|
561 | 566 | def raises(
|
562 |
| - expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]], |
| 567 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
563 | 568 | *args: Any,
|
564 | 569 | match: Optional[Union[str, Pattern]] = None,
|
565 | 570 | **kwargs: Any
|
566 |
| -) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]: |
| 571 | +) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]: |
567 | 572 | r"""
|
568 | 573 | Assert that a code block/function call raises ``expected_exception``
|
569 | 574 | or raise a failure exception otherwise.
|
@@ -702,28 +707,30 @@ def raises(
|
702 | 707 | try:
|
703 | 708 | func(*args[1:], **kwargs)
|
704 | 709 | except expected_exception:
|
705 |
| - return _pytest._code.ExceptionInfo.from_current() |
| 710 | + # Cast to narrow the type to expected_exception (_E). |
| 711 | + return cast( |
| 712 | + _pytest._code.ExceptionInfo[_E], |
| 713 | + _pytest._code.ExceptionInfo.from_current(), |
| 714 | + ) |
706 | 715 | fail(message)
|
707 | 716 |
|
708 | 717 |
|
709 | 718 | raises.Exception = fail.Exception # type: ignore
|
710 | 719 |
|
711 | 720 |
|
712 |
| -class RaisesContext: |
| 721 | +class RaisesContext(Generic[_E]): |
713 | 722 | def __init__(
|
714 | 723 | self,
|
715 |
| - expected_exception: Union[ |
716 |
| - "Type[BaseException]", Tuple["Type[BaseException]", ...] |
717 |
| - ], |
| 724 | + expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]], |
718 | 725 | message: str,
|
719 | 726 | match_expr: Optional[Union[str, Pattern]] = None,
|
720 | 727 | ) -> None:
|
721 | 728 | self.expected_exception = expected_exception
|
722 | 729 | self.message = message
|
723 | 730 | self.match_expr = match_expr
|
724 |
| - self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo] |
| 731 | + self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]] |
725 | 732 |
|
726 |
| - def __enter__(self) -> _pytest._code.ExceptionInfo: |
| 733 | + def __enter__(self) -> _pytest._code.ExceptionInfo[_E]: |
727 | 734 | self.excinfo = _pytest._code.ExceptionInfo.for_later()
|
728 | 735 | return self.excinfo
|
729 | 736 |
|
|
0 commit comments