Skip to content

Commit 538621e

Browse files
authored
unittest.case: tighter annotations for various assertions (#8077)
1 parent bae4fd2 commit 538621e

File tree

3 files changed

+108
-58
lines changed

3 files changed

+108
-58
lines changed

stdlib/_typeshed/__init__.pyi

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,21 +50,23 @@ class SupportsAnext(Protocol[_T_co]):
5050

5151
# Comparison protocols
5252

53-
class SupportsDunderLT(Protocol):
54-
def __lt__(self, __other: Any) -> bool: ...
53+
class SupportsDunderLT(Protocol[_T_contra]):
54+
def __lt__(self, __other: _T_contra) -> bool: ...
5555

56-
class SupportsDunderGT(Protocol):
57-
def __gt__(self, __other: Any) -> bool: ...
56+
class SupportsDunderGT(Protocol[_T_contra]):
57+
def __gt__(self, __other: _T_contra) -> bool: ...
5858

59-
class SupportsDunderLE(Protocol):
60-
def __le__(self, __other: Any) -> bool: ...
59+
class SupportsDunderLE(Protocol[_T_contra]):
60+
def __le__(self, __other: _T_contra) -> bool: ...
6161

62-
class SupportsDunderGE(Protocol):
63-
def __ge__(self, __other: Any) -> bool: ...
62+
class SupportsDunderGE(Protocol[_T_contra]):
63+
def __ge__(self, __other: _T_contra) -> bool: ...
6464

65-
class SupportsAllComparisons(SupportsDunderLT, SupportsDunderGT, SupportsDunderLE, SupportsDunderGE, Protocol): ...
65+
class SupportsAllComparisons(
66+
SupportsDunderLT[Any], SupportsDunderGT[Any], SupportsDunderLE[Any], SupportsDunderGE[Any], Protocol
67+
): ...
6668

67-
SupportsRichComparison: TypeAlias = SupportsDunderLT | SupportsDunderGT
69+
SupportsRichComparison: TypeAlias = SupportsDunderLT[Any] | SupportsDunderGT[Any]
6870
SupportsRichComparisonT = TypeVar("SupportsRichComparisonT", bound=SupportsRichComparison) # noqa: Y001
6971

7072
# Dunder protocols

stdlib/unittest/case.pyi

Lines changed: 45 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import sys
33
import unittest.result
4-
from _typeshed import Self, SupportsDunderGE, SupportsSub
4+
from _typeshed import Self, SupportsDunderGE, SupportsDunderGT, SupportsDunderLE, SupportsDunderLT, SupportsSub
55
from collections.abc import Callable, Container, Iterable, Mapping, Sequence, Set as AbstractSet
66
from contextlib import AbstractContextManager
77
from types import TracebackType
@@ -17,14 +17,18 @@ from typing import (
1717
SupportsAbs,
1818
SupportsRound,
1919
TypeVar,
20+
Union,
2021
overload,
2122
)
22-
from typing_extensions import ParamSpec
23+
from typing_extensions import ParamSpec, TypeAlias
2324
from warnings import WarningMessage
2425

2526
if sys.version_info >= (3, 9):
2627
from types import GenericAlias
2728

29+
if sys.version_info >= (3, 10):
30+
from types import UnionType
31+
2832
_T = TypeVar("_T")
2933
_S = TypeVar("_S", bound=SupportsSub[Any, Any])
3034
_E = TypeVar("_E", bound=BaseException)
@@ -75,7 +79,12 @@ def skipUnless(condition: object, reason: str) -> Callable[[_FT], _FT]: ...
7579
class SkipTest(Exception):
7680
def __init__(self, reason: str) -> None: ...
7781

78-
class _SupportsAbsAndDunderGE(SupportsDunderGE, SupportsAbs[Any], Protocol): ...
82+
class _SupportsAbsAndDunderGE(SupportsDunderGE[Any], SupportsAbs[Any], Protocol): ...
83+
84+
if sys.version_info >= (3, 10):
85+
_IsInstanceClassInfo: TypeAlias = Union[type, UnionType, tuple[type | UnionType | tuple[Any, ...], ...]]
86+
else:
87+
_IsInstanceClassInfo: TypeAlias = Union[type, tuple[type | tuple[Any, ...], ...]]
7988

8089
class TestCase:
8190
failureException: type[BaseException]
@@ -105,18 +114,30 @@ class TestCase:
105114
def assertNotEqual(self, first: Any, second: Any, msg: Any = ...) -> None: ...
106115
def assertTrue(self, expr: Any, msg: Any = ...) -> None: ...
107116
def assertFalse(self, expr: Any, msg: Any = ...) -> None: ...
108-
def assertIs(self, expr1: Any, expr2: Any, msg: Any = ...) -> None: ...
109-
def assertIsNot(self, expr1: Any, expr2: Any, msg: Any = ...) -> None: ...
110-
def assertIsNone(self, obj: Any, msg: Any = ...) -> None: ...
111-
def assertIsNotNone(self, obj: Any, msg: Any = ...) -> None: ...
117+
def assertIs(self, expr1: object, expr2: object, msg: Any = ...) -> None: ...
118+
def assertIsNot(self, expr1: object, expr2: object, msg: Any = ...) -> None: ...
119+
def assertIsNone(self, obj: object, msg: Any = ...) -> None: ...
120+
def assertIsNotNone(self, obj: object, msg: Any = ...) -> None: ...
112121
def assertIn(self, member: Any, container: Iterable[Any] | Container[Any], msg: Any = ...) -> None: ...
113122
def assertNotIn(self, member: Any, container: Iterable[Any] | Container[Any], msg: Any = ...) -> None: ...
114-
def assertIsInstance(self, obj: Any, cls: type | tuple[type, ...], msg: Any = ...) -> None: ...
115-
def assertNotIsInstance(self, obj: Any, cls: type | tuple[type, ...], msg: Any = ...) -> None: ...
116-
def assertGreater(self, a: Any, b: Any, msg: Any = ...) -> None: ...
117-
def assertGreaterEqual(self, a: Any, b: Any, msg: Any = ...) -> None: ...
118-
def assertLess(self, a: Any, b: Any, msg: Any = ...) -> None: ...
119-
def assertLessEqual(self, a: Any, b: Any, msg: Any = ...) -> None: ...
123+
def assertIsInstance(self, obj: object, cls: _IsInstanceClassInfo, msg: Any = ...) -> None: ...
124+
def assertNotIsInstance(self, obj: object, cls: _IsInstanceClassInfo, msg: Any = ...) -> None: ...
125+
@overload
126+
def assertGreater(self, a: SupportsDunderGT[_T], b: _T, msg: Any = ...) -> None: ...
127+
@overload
128+
def assertGreater(self, a: _T, b: SupportsDunderLT[_T], msg: Any = ...) -> None: ...
129+
@overload
130+
def assertGreaterEqual(self, a: SupportsDunderGE[_T], b: _T, msg: Any = ...) -> None: ...
131+
@overload
132+
def assertGreaterEqual(self, a: _T, b: SupportsDunderLE[_T], msg: Any = ...) -> None: ...
133+
@overload
134+
def assertLess(self, a: SupportsDunderLT[_T], b: _T, msg: Any = ...) -> None: ...
135+
@overload
136+
def assertLess(self, a: _T, b: SupportsDunderGT[_T], msg: Any = ...) -> None: ...
137+
@overload
138+
def assertLessEqual(self, a: SupportsDunderLT[_T], b: _T, msg: Any = ...) -> None: ...
139+
@overload
140+
def assertLessEqual(self, a: _T, b: SupportsDunderGT[_T], msg: Any = ...) -> None: ...
120141
# `assertRaises`, `assertRaisesRegex`, and `assertRaisesRegexp`
121142
# are not using `ParamSpec` intentionally,
122143
# because they might be used with explicitly wrong arg types to raise some error in tests.
@@ -249,45 +270,21 @@ class TestCase:
249270
def _formatMessage(self, msg: str | None, standardMsg: str) -> str: ... # undocumented
250271
def _getAssertEqualityFunc(self, first: Any, second: Any) -> Callable[..., None]: ... # undocumented
251272
if sys.version_info < (3, 12):
252-
def failUnlessEqual(self, first: Any, second: Any, msg: Any = ...) -> None: ...
253-
def assertEquals(self, first: Any, second: Any, msg: Any = ...) -> None: ...
254-
def failIfEqual(self, first: Any, second: Any, msg: Any = ...) -> None: ...
255-
def assertNotEquals(self, first: Any, second: Any, msg: Any = ...) -> None: ...
256-
def failUnless(self, expr: bool, msg: Any = ...) -> None: ...
257-
def assert_(self, expr: bool, msg: Any = ...) -> None: ...
258-
def failIf(self, expr: bool, msg: Any = ...) -> None: ...
259-
@overload
260-
def failUnlessRaises( # type: ignore[misc]
261-
self,
262-
exception: type[BaseException] | tuple[type[BaseException], ...],
263-
callable: Callable[_P, object] = ...,
264-
*args: _P.args,
265-
**kwargs: _P.kwargs,
266-
) -> None: ...
267-
@overload
268-
def failUnlessRaises(self, exception: type[_E] | tuple[type[_E], ...], msg: Any = ...) -> _AssertRaisesContext[_E]: ...
273+
failUnlessEqual = assertEqual
274+
assertEquals = assertEqual
275+
failIfEqual = assertNotEqual
276+
assertNotEquals = assertNotEqual
277+
failUnless = assertTrue
278+
assert_ = assertTrue
279+
failIf = assertFalse
280+
failUnlessRaises = assertRaises
269281
failUnlessAlmostEqual = assertAlmostEqual
270282
assertAlmostEquals = assertAlmostEqual
271283
failIfAlmostEqual = assertNotAlmostEqual
272284
assertNotAlmostEquals = assertNotAlmostEqual
273-
def assertRegexpMatches(self, text: AnyStr, regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ...
274-
def assertNotRegexpMatches(self, text: AnyStr, regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ...
275-
@overload
276-
def assertRaisesRegexp( # type: ignore[misc]
277-
self,
278-
exception: type[BaseException] | tuple[type[BaseException], ...],
279-
expected_regex: str | bytes | Pattern[str] | Pattern[bytes],
280-
callable: Callable[..., object],
281-
*args: Any,
282-
**kwargs: Any,
283-
) -> None: ...
284-
@overload
285-
def assertRaisesRegexp(
286-
self,
287-
exception: type[_E] | tuple[type[_E], ...],
288-
expected_regex: str | bytes | Pattern[str] | Pattern[bytes],
289-
msg: Any = ...,
290-
) -> _AssertRaisesContext[_E]: ...
285+
assertRegexpMatches = assertRegex
286+
assertNotRegexpMatches = assertNotRegex
287+
assertRaisesRegexp = assertRaisesRegex
291288
def assertDictContainsSubset(
292289
self, subset: Mapping[Any, Any], dictionary: Mapping[Any, Any], msg: object = ...
293290
) -> None: ...

test_cases/stdlib/test_unittest.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
case = unittest.TestCase()
99

10+
###
11+
# Tests for assertAlmostEqual
12+
###
13+
1014
case.assertAlmostEqual(2.4, 2.41)
1115
case.assertAlmostEqual(Fraction(49, 50), Fraction(48, 50))
1216
case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1))
@@ -20,10 +24,57 @@
2024
case.assertAlmostEqual("foo", "bar") # type: ignore
2125
case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore
2226

27+
###
28+
# Tests for assertNotAlmostEqual
29+
###
30+
2331
case.assertNotAlmostEqual(Fraction(49, 50), Fraction(48, 50))
2432
case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1))
2533
case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), None, "foo", timedelta(hours=1))
2634

2735
case.assertNotAlmostEqual(2.4, 2.41, places=9, delta=0.02) # type: ignore
2836
case.assertNotAlmostEqual("foo", "bar") # type: ignore
2937
case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore
38+
39+
###
40+
# Tests for assertGreater
41+
###
42+
43+
44+
class Spam:
45+
def __lt__(self, other: object) -> bool:
46+
return True
47+
48+
49+
class Eggs:
50+
def __gt__(self, other: object) -> bool:
51+
return True
52+
53+
54+
class Ham:
55+
def __lt__(self, other: "Ham") -> bool:
56+
if not isinstance(other, Ham):
57+
return NotImplemented
58+
return True
59+
60+
61+
class Bacon:
62+
def __gt__(self, other: "Bacon") -> bool:
63+
if not isinstance(other, Bacon):
64+
return NotImplemented
65+
return True
66+
67+
68+
case.assertGreater(5.8, 3)
69+
case.assertGreater(Decimal("4.5"), Fraction(3, 2))
70+
case.assertGreater(Fraction(3, 2), 0.9)
71+
case.assertGreater(Eggs(), object())
72+
case.assertGreater(object(), Spam())
73+
case.assertGreater(Ham(), Ham())
74+
case.assertGreater(Bacon(), Bacon())
75+
76+
case.assertGreater(object(), object()) # type: ignore
77+
case.assertGreater(datetime(1999, 1, 2), 1) # type: ignore
78+
case.assertGreater(Spam(), Eggs()) # type: ignore
79+
case.assertGreater(Ham(), Bacon()) # type: ignore
80+
case.assertGreater(Bacon(), Ham()) # type: ignore

0 commit comments

Comments
 (0)