7
7
from decimal import Decimal
8
8
from itertools import filterfalse
9
9
from numbers import Number
10
+ from types import TracebackType
11
+ from typing import Any
12
+ from typing import Callable
13
+ from typing import Optional
14
+ from typing import overload
15
+ from typing import Pattern
16
+ from typing import Tuple
10
17
from typing import Union
11
18
12
19
from more_itertools .more import always_iterable
15
22
from _pytest .compat import STRING_TYPES
16
23
from _pytest .outcomes import fail
17
24
25
+ if False : # TYPE_CHECKING
26
+ from typing import Type # noqa: F401 (used in type string)
27
+
18
28
BASE_TYPE = (type , STRING_TYPES )
19
29
20
30
@@ -528,7 +538,32 @@ def _is_numpy_array(obj):
528
538
# builtin pytest.raises helper
529
539
530
540
531
- def raises (expected_exception , * args , match = None , ** kwargs ):
541
+ @overload
542
+ def raises (
543
+ expected_exception : Union ["Type[BaseException]" , Tuple ["Type[BaseException]" , ...]],
544
+ * ,
545
+ match : Optional [Union [str , Pattern ]] = ... # noqa: W504 (SyntaxError in Python 3.5)
546
+ ) -> "RaisesContext" :
547
+ ...
548
+
549
+
550
+ @overload
551
+ def raises (
552
+ expected_exception : Union ["Type[BaseException]" , Tuple ["Type[BaseException]" , ...]],
553
+ func : Callable ,
554
+ * args : Any ,
555
+ match : Optional [str ] = ...,
556
+ ** kwargs : Any
557
+ ) -> Optional [_pytest ._code .ExceptionInfo ]:
558
+ ...
559
+
560
+
561
+ def raises (
562
+ expected_exception : Union ["Type[BaseException]" , Tuple ["Type[BaseException]" , ...]],
563
+ * args : Any ,
564
+ match : Optional [Union [str , Pattern ]] = None ,
565
+ ** kwargs : Any
566
+ ) -> Union ["RaisesContext" , Optional [_pytest ._code .ExceptionInfo ]]:
532
567
r"""
533
568
Assert that a code block/function call raises ``expected_exception``
534
569
or raise a failure exception otherwise.
@@ -653,7 +688,9 @@ def raises(expected_exception, *args, match=None, **kwargs):
653
688
message = "DID NOT RAISE {}" .format (expected_exception )
654
689
655
690
if not args :
656
- return RaisesContext (
691
+ # Type ignore because **kwargs is passed intentionally to trigger
692
+ # a TypeError if provided.
693
+ return RaisesContext ( # type: ignore
657
694
expected_exception , message = message , match_expr = match , ** kwargs
658
695
)
659
696
else :
@@ -673,21 +710,35 @@ def raises(expected_exception, *args, match=None, **kwargs):
673
710
674
711
675
712
class RaisesContext :
676
- def __init__ (self , expected_exception , message , match_expr ):
713
+ def __init__ (
714
+ self ,
715
+ expected_exception : Union [
716
+ "Type[BaseException]" , Tuple ["Type[BaseException]" , ...]
717
+ ],
718
+ message : str ,
719
+ match_expr : Optional [Union [str , Pattern ]] = None ,
720
+ ) -> None :
677
721
self .expected_exception = expected_exception
678
722
self .message = message
679
723
self .match_expr = match_expr
680
- self .excinfo = None
724
+ self .excinfo = None # type: Optional[_pytest._code.ExceptionInfo]
681
725
682
- def __enter__ (self ):
726
+ def __enter__ (self ) -> _pytest . _code . ExceptionInfo :
683
727
self .excinfo = _pytest ._code .ExceptionInfo .for_later ()
684
728
return self .excinfo
685
729
686
- def __exit__ (self , * tp ):
730
+ def __exit__ (
731
+ self ,
732
+ exc_type : Optional ["Type[BaseException]" ],
733
+ exc_val : Optional [BaseException ],
734
+ exc_tb : Optional [TracebackType ],
735
+ ) -> bool :
687
736
__tracebackhide__ = True
688
- if tp [ 0 ] is None :
737
+ if exc_type is None :
689
738
fail (self .message )
690
- self .excinfo .__init__ (tp )
739
+ assert self .excinfo is not None
740
+ # Type ignored because mypy doesn't like calling __init__ directly like this.
741
+ self .excinfo .__init__ ((exc_type , exc_val , exc_tb )) # type: ignore
691
742
suppress_exception = issubclass (self .excinfo .type , self .expected_exception )
692
743
if self .match_expr is not None and suppress_exception :
693
744
self .excinfo .match (self .match_expr )
0 commit comments