Skip to content

Commit 9881f44

Browse files
committed
Make ExceptionInfo generic in the exception type
This way, in with pytest.raises(ValueError) as cm: ... cm.value is a ValueError and not a BaseException.
1 parent adbbbfc commit 9881f44

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

src/_pytest/_code/code.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
from inspect import CO_VARKEYWORDS
77
from traceback import format_exception_only
88
from types import TracebackType
9+
from typing import Generic
910
from typing import Optional
1011
from typing import Pattern
1112
from typing import Tuple
13+
from typing import TypeVar
1214
from typing import Union
1315
from weakref import ref
1416

@@ -379,22 +381,25 @@ def recursionindex(self):
379381
)
380382

381383

384+
_E = TypeVar("_E", bound=BaseException)
385+
386+
382387
@attr.s(repr=False)
383-
class ExceptionInfo:
388+
class ExceptionInfo(Generic[_E]):
384389
""" wraps sys.exc_info() objects and offers
385390
help for navigating the traceback.
386391
"""
387392

388393
_assert_start_repr = "AssertionError('assert "
389394

390-
_excinfo = attr.ib(
391-
type=Optional[Tuple["Type[BaseException]", BaseException, TracebackType]]
392-
)
395+
_excinfo = attr.ib(type=Optional[Tuple["Type[_E]", "_E", TracebackType]])
393396
_striptext = attr.ib(type=str, default="")
394397
_traceback = attr.ib(type=Optional[Traceback], default=None)
395398

396399
@classmethod
397-
def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo":
400+
def from_current(
401+
cls, exprinfo: Optional[str] = None
402+
) -> "ExceptionInfo[BaseException]":
398403
"""returns an ExceptionInfo matching the current traceback
399404
400405
.. warning::
@@ -422,21 +427,21 @@ def from_current(cls, exprinfo: Optional[str] = None) -> "ExceptionInfo":
422427
return cls(tup, _striptext)
423428

424429
@classmethod
425-
def for_later(cls) -> "ExceptionInfo":
430+
def for_later(cls) -> "ExceptionInfo[_E]":
426431
"""return an unfilled ExceptionInfo
427432
"""
428433
return cls(None)
429434

430435
@property
431-
def type(self) -> "Type[BaseException]":
436+
def type(self) -> "Type[_E]":
432437
"""the exception class"""
433438
assert (
434439
self._excinfo is not None
435440
), ".type can only be used after the context manager exits"
436441
return self._excinfo[0]
437442

438443
@property
439-
def value(self) -> BaseException:
444+
def value(self) -> _E:
440445
"""the exception value"""
441446
assert (
442447
self._excinfo is not None

src/_pytest/python_api.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,13 @@
1010
from types import TracebackType
1111
from typing import Any
1212
from typing import Callable
13+
from typing import cast
14+
from typing import Generic
1315
from typing import Optional
1416
from typing import overload
1517
from typing import Pattern
1618
from typing import Tuple
19+
from typing import TypeVar
1720
from typing import Union
1821

1922
from more_itertools.more import always_iterable
@@ -537,33 +540,35 @@ def _is_numpy_array(obj):
537540

538541
# builtin pytest.raises helper
539542

543+
_E = TypeVar("_E", bound=BaseException)
544+
540545

541546
@overload
542547
def raises(
543-
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
548+
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
544549
*,
545550
match: Optional[Union[str, Pattern]] = ... # noqa: W504 (SyntaxError in Python 3.5)
546-
) -> "RaisesContext":
551+
) -> "RaisesContext[_E]":
547552
...
548553

549554

550555
@overload
551556
def raises(
552-
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
557+
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
553558
func: Callable,
554559
*args: Any,
555560
match: Optional[str] = ...,
556561
**kwargs: Any
557-
) -> Optional[_pytest._code.ExceptionInfo]:
562+
) -> Optional[_pytest._code.ExceptionInfo[_E]]:
558563
...
559564

560565

561566
def raises(
562-
expected_exception: Union["Type[BaseException]", Tuple["Type[BaseException]", ...]],
567+
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
563568
*args: Any,
564569
match: Optional[Union[str, Pattern]] = None,
565570
**kwargs: Any
566-
) -> Union["RaisesContext", Optional[_pytest._code.ExceptionInfo]]:
571+
) -> Union["RaisesContext[_E]", Optional[_pytest._code.ExceptionInfo[_E]]]:
567572
r"""
568573
Assert that a code block/function call raises ``expected_exception``
569574
or raise a failure exception otherwise.
@@ -702,28 +707,30 @@ def raises(
702707
try:
703708
func(*args[1:], **kwargs)
704709
except expected_exception:
705-
return _pytest._code.ExceptionInfo.from_current()
710+
# Cast to narrow the type to expected_exception (_E).
711+
return cast(
712+
_pytest._code.ExceptionInfo[_E],
713+
_pytest._code.ExceptionInfo.from_current(),
714+
)
706715
fail(message)
707716

708717

709718
raises.Exception = fail.Exception # type: ignore
710719

711720

712-
class RaisesContext:
721+
class RaisesContext(Generic[_E]):
713722
def __init__(
714723
self,
715-
expected_exception: Union[
716-
"Type[BaseException]", Tuple["Type[BaseException]", ...]
717-
],
724+
expected_exception: Union["Type[_E]", Tuple["Type[_E]", ...]],
718725
message: str,
719726
match_expr: Optional[Union[str, Pattern]] = None,
720727
) -> None:
721728
self.expected_exception = expected_exception
722729
self.message = message
723730
self.match_expr = match_expr
724-
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
731+
self.excinfo = None # type: Optional[_pytest._code.ExceptionInfo[_E]]
725732

726-
def __enter__(self) -> _pytest._code.ExceptionInfo:
733+
def __enter__(self) -> _pytest._code.ExceptionInfo[_E]:
727734
self.excinfo = _pytest._code.ExceptionInfo.for_later()
728735
return self.excinfo
729736

0 commit comments

Comments
 (0)