diff --git a/stdlib/_typeshed/__init__.pyi b/stdlib/_typeshed/__init__.pyi index aef6b553ea92..162c40522224 100644 --- a/stdlib/_typeshed/__init__.pyi +++ b/stdlib/_typeshed/__init__.pyi @@ -75,6 +75,9 @@ class SupportsAdd(Protocol[_T_contra, _T_co]): class SupportsRAdd(Protocol[_T_contra, _T_co]): def __radd__(self, __x: _T_contra) -> _T_co: ... +class SupportsSub(Protocol[_T_contra, _T_co]): + def __sub__(self, __x: _T_contra) -> _T_co: ... + class SupportsDivMod(Protocol[_T_contra, _T_co]): def __divmod__(self, __other: _T_contra) -> _T_co: ... diff --git a/stdlib/unittest/case.pyi b/stdlib/unittest/case.pyi index 7f2928233c10..15b573edeebb 100644 --- a/stdlib/unittest/case.pyi +++ b/stdlib/unittest/case.pyi @@ -1,12 +1,24 @@ -import datetime import logging import sys import unittest.result -from _typeshed import Self +from _typeshed import Self, SupportsDunderGE, SupportsSub from collections.abc import Callable, Container, Iterable, Mapping, Sequence, Set as AbstractSet from contextlib import AbstractContextManager from types import TracebackType -from typing import Any, AnyStr, ClassVar, Generic, NamedTuple, NoReturn, Pattern, TypeVar, overload +from typing import ( + Any, + AnyStr, + ClassVar, + Generic, + NamedTuple, + NoReturn, + Pattern, + Protocol, + SupportsAbs, + SupportsRound, + TypeVar, + overload, +) from typing_extensions import ParamSpec from warnings import WarningMessage @@ -14,6 +26,7 @@ if sys.version_info >= (3, 9): from types import GenericAlias _T = TypeVar("_T") +_S = TypeVar("_S", bound=SupportsSub[Any, Any]) _E = TypeVar("_E", bound=BaseException) _FT = TypeVar("_FT", bound=Callable[..., Any]) _P = ParamSpec("_P") @@ -62,6 +75,8 @@ def skipUnless(condition: object, reason: str) -> Callable[[_FT], _FT]: ... class SkipTest(Exception): def __init__(self, reason: str) -> None: ... +class _SupportsAbsAndDunderGE(SupportsDunderGE, SupportsAbs[Any], Protocol): ... + class TestCase: failureException: type[BaseException] longMessage: bool @@ -165,33 +180,35 @@ class TestCase: self, logger: str | logging.Logger | None = ..., level: int | str | None = ... ) -> _AssertLogsContext[None]: ... + @overload + def assertAlmostEqual(self, first: _S, second: _S, places: None, msg: Any, delta: _SupportsAbsAndDunderGE) -> None: ... @overload def assertAlmostEqual( - self, first: float, second: float, places: int | None = ..., msg: Any = ..., delta: float | None = ... + self, first: _S, second: _S, places: None = ..., msg: Any = ..., *, delta: _SupportsAbsAndDunderGE ) -> None: ... @overload def assertAlmostEqual( self, - first: datetime.datetime, - second: datetime.datetime, + first: SupportsSub[_T, SupportsAbs[SupportsRound[object]]], + second: _T, places: int | None = ..., msg: Any = ..., - delta: datetime.timedelta | None = ..., + delta: None = ..., ) -> None: ... @overload - def assertNotAlmostEqual(self, first: float, second: float, *, msg: Any = ...) -> None: ... + def assertNotAlmostEqual(self, first: _S, second: _S, places: None, msg: Any, delta: _SupportsAbsAndDunderGE) -> None: ... @overload - def assertNotAlmostEqual(self, first: float, second: float, places: int | None = ..., msg: Any = ...) -> None: ... - @overload - def assertNotAlmostEqual(self, first: float, second: float, *, msg: Any = ..., delta: float | None = ...) -> None: ... + def assertNotAlmostEqual( + self, first: _S, second: _S, places: None = ..., msg: Any = ..., *, delta: _SupportsAbsAndDunderGE + ) -> None: ... @overload def assertNotAlmostEqual( self, - first: datetime.datetime, - second: datetime.datetime, + first: SupportsSub[_T, SupportsAbs[SupportsRound[object]]], + second: _T, places: int | None = ..., msg: Any = ..., - delta: datetime.timedelta | None = ..., + delta: None = ..., ) -> None: ... def assertRegex(self, text: AnyStr, expected_regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ... def assertNotRegex(self, text: AnyStr, unexpected_regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ... @@ -249,14 +266,10 @@ class TestCase: ) -> None: ... @overload def failUnlessRaises(self, exception: type[_E] | tuple[type[_E], ...], msg: Any = ...) -> _AssertRaisesContext[_E]: ... - def failUnlessAlmostEqual(self, first: float, second: float, places: int = ..., msg: Any = ...) -> None: ... - def assertAlmostEquals( - self, first: float, second: float, places: int = ..., msg: Any = ..., delta: float = ... - ) -> None: ... - def failIfAlmostEqual(self, first: float, second: float, places: int = ..., msg: Any = ...) -> None: ... - def assertNotAlmostEquals( - self, first: float, second: float, places: int = ..., msg: Any = ..., delta: float = ... - ) -> None: ... + failUnlessAlmostEqual = assertAlmostEqual + assertAlmostEquals = assertAlmostEqual + failIfAlmostEqual = assertNotAlmostEqual + assertNotAlmostEquals = assertNotAlmostEqual def assertRegexpMatches(self, text: AnyStr, regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ... def assertNotRegexpMatches(self, text: AnyStr, regex: AnyStr | Pattern[AnyStr], msg: Any = ...) -> None: ... @overload diff --git a/test_cases/stdlib/test_unittest.py b/test_cases/stdlib/test_unittest.py new file mode 100644 index 000000000000..086a1aef6211 --- /dev/null +++ b/test_cases/stdlib/test_unittest.py @@ -0,0 +1,29 @@ +# pyright: reportUnnecessaryTypeIgnoreComment=true + +import unittest +from datetime import datetime, timedelta +from decimal import Decimal +from fractions import Fraction + +case = unittest.TestCase() + +case.assertAlmostEqual(2.4, 2.41) +case.assertAlmostEqual(Fraction(49, 50), Fraction(48, 50)) +case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1)) +case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), None, "foo", timedelta(hours=1)) +case.assertAlmostEqual(Decimal("1.1"), Decimal("1.11")) +case.assertAlmostEqual(2.4, 2.41, places=8) +case.assertAlmostEqual(2.4, 2.41, delta=0.02) +case.assertAlmostEqual(2.4, 2.41, None, "foo", 0.02) + +case.assertAlmostEqual(2.4, 2.41, places=9, delta=0.02) # type: ignore[call-overload] +case.assertAlmostEqual("foo", "bar") # type: ignore[call-overload] +case.assertAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore[arg-type] + +case.assertNotAlmostEqual(Fraction(49, 50), Fraction(48, 50)) +case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), delta=timedelta(hours=1)) +case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1), None, "foo", timedelta(hours=1)) + +case.assertNotAlmostEqual(2.4, 2.41, places=9, delta=0.02) # type: ignore[call-overload] +case.assertNotAlmostEqual("foo", "bar") # type: ignore[call-overload] +case.assertNotAlmostEqual(datetime(1999, 1, 2), datetime(1999, 1, 2, microsecond=1)) # type: ignore[arg-type]