Skip to content

Commit 5670ca2

Browse files
authored
Add SupportsRichComparison type to _typeshed (#6583)
Use it to improve types of `max()` and other functions. Also make some other tweaks to types related to comparison dunders. Fixes #6575
1 parent 968fd6d commit 5670ca2

File tree

9 files changed

+73
-69
lines changed

9 files changed

+73
-69
lines changed

stdlib/_bisect.pyi

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
11
import sys
2-
from _typeshed import SupportsLessThan
2+
from _typeshed import SupportsRichComparison
33
from typing import Callable, MutableSequence, Sequence, TypeVar
44

55
_T = TypeVar("_T")
66

77
if sys.version_info >= (3, 10):
88
def bisect_left(
9-
a: Sequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsLessThan] | None = ...
9+
a: Sequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsRichComparison] | None = ...
1010
) -> int: ...
1111
def bisect_right(
12-
a: Sequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsLessThan] | None = ...
12+
a: Sequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsRichComparison] | None = ...
1313
) -> int: ...
1414
def insort_left(
15-
a: MutableSequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsLessThan] | None = ...
15+
a: MutableSequence[_T],
16+
x: _T,
17+
lo: int = ...,
18+
hi: int | None = ...,
19+
*,
20+
key: Callable[[_T], SupportsRichComparison] | None = ...,
1621
) -> None: ...
1722
def insort_right(
18-
a: MutableSequence[_T], x: _T, lo: int = ..., hi: int | None = ..., *, key: Callable[[_T], SupportsLessThan] | None = ...
23+
a: MutableSequence[_T],
24+
x: _T,
25+
lo: int = ...,
26+
hi: int | None = ...,
27+
*,
28+
key: Callable[[_T], SupportsRichComparison] | None = ...,
1929
) -> None: ...
2030

2131
else:

stdlib/_operator.pyi

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
from _typeshed import SupportsAnyComparison
23
from typing import (
34
Any,
45
AnyStr,
@@ -14,7 +15,6 @@ from typing import (
1415
SupportsAbs,
1516
Tuple,
1617
TypeVar,
17-
Union,
1818
overload,
1919
)
2020
from typing_extensions import ParamSpec, SupportsIndex, final
@@ -35,29 +35,13 @@ class _SupportsNeg(Protocol[_T_co]):
3535
class _SupportsPos(Protocol[_T_co]):
3636
def __pos__(self) -> _T_co: ...
3737

38-
# Different to _typeshed.SupportsLessThan
39-
class _SupportsLT(Protocol):
40-
def __lt__(self, __other: Any) -> Any: ...
41-
42-
class _SupportsGT(Protocol):
43-
def __gt__(self, __other: Any) -> Any: ...
44-
45-
class _SupportsLE(Protocol):
46-
def __le__(self, __other: Any) -> Any: ...
47-
48-
class _SupportsGE(Protocol):
49-
def __ge__(self, __other: Any) -> Any: ...
50-
51-
# We get false-positive errors if e.g. `lt` does not have the same signature as `le`,
52-
# so a broad union type is required for all four comparison methods
53-
_SupportsComparison = Union[_SupportsLE, _SupportsGE, _SupportsGT, _SupportsLT]
54-
55-
def lt(__a: _SupportsComparison, __b: _SupportsComparison) -> Any: ...
56-
def le(__a: _SupportsComparison, __b: _SupportsComparison) -> Any: ...
38+
# All four comparison functions must have the same signature, or we get false-positive errors
39+
def lt(__a: SupportsAnyComparison, __b: SupportsAnyComparison) -> Any: ...
40+
def le(__a: SupportsAnyComparison, __b: SupportsAnyComparison) -> Any: ...
5741
def eq(__a: object, __b: object) -> Any: ...
5842
def ne(__a: object, __b: object) -> Any: ...
59-
def ge(__a: _SupportsComparison, __b: _SupportsComparison) -> Any: ...
60-
def gt(__a: _SupportsComparison, __b: _SupportsComparison) -> Any: ...
43+
def ge(__a: SupportsAnyComparison, __b: SupportsAnyComparison) -> Any: ...
44+
def gt(__a: SupportsAnyComparison, __b: SupportsAnyComparison) -> Any: ...
6145
def not_(__a: object) -> bool: ...
6246
def truth(__a: object) -> bool: ...
6347
def is_(__a: object, __b: object) -> bool: ...

stdlib/_typeshed/__init__.pyi

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,25 @@ class SupportsNext(Protocol[_T_co]):
3535
class SupportsAnext(Protocol[_T_co]):
3636
def __anext__(self) -> Awaitable[_T_co]: ...
3737

38-
class SupportsLessThan(Protocol):
39-
def __lt__(self, __other: Any) -> bool: ...
38+
# Comparison protocols
4039

41-
SupportsLessThanT = TypeVar("SupportsLessThanT", bound=SupportsLessThan) # noqa: Y001
40+
class SupportsDunderLT(Protocol):
41+
def __lt__(self, __other: Any) -> Any: ...
4242

43-
class SupportsGreaterThan(Protocol):
44-
def __gt__(self, __other: Any) -> bool: ...
43+
class SupportsDunderGT(Protocol):
44+
def __gt__(self, __other: Any) -> Any: ...
4545

46-
SupportsGreaterThanT = TypeVar("SupportsGreaterThanT", bound=SupportsGreaterThan) # noqa: Y001
46+
class SupportsDunderLE(Protocol):
47+
def __le__(self, __other: Any) -> Any: ...
48+
49+
class SupportsDunderGE(Protocol):
50+
def __ge__(self, __other: Any) -> Any: ...
51+
52+
class SupportsAllComparisons(SupportsDunderLT, SupportsDunderGT, SupportsDunderLE, SupportsDunderGE, Protocol): ...
53+
54+
SupportsRichComparison = Union[SupportsDunderLT, SupportsDunderGT]
55+
SupportsRichComparisonT = TypeVar("SupportsRichComparisonT", bound=SupportsRichComparison) # noqa: Y001
56+
SupportsAnyComparison = Union[SupportsDunderLE, SupportsDunderGE, SupportsDunderGT, SupportsDunderLT]
4757

4858
class SupportsDivMod(Protocol[_T_contra, _T_co]):
4959
def __divmod__(self, __other: _T_contra) -> _T_co: ...

stdlib/builtins.pyi

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,12 @@ from _typeshed import (
1313
StrOrBytesPath,
1414
SupportsAnext,
1515
SupportsDivMod,
16-
SupportsGreaterThan,
17-
SupportsGreaterThanT,
1816
SupportsKeysAndGetItem,
1917
SupportsLenAndGetItem,
20-
SupportsLessThan,
21-
SupportsLessThanT,
2218
SupportsNext,
2319
SupportsRDivMod,
20+
SupportsRichComparison,
21+
SupportsRichComparisonT,
2422
SupportsTrunc,
2523
SupportsWrite,
2624
)
@@ -783,9 +781,9 @@ class list(MutableSequence[_T], Generic[_T]):
783781
def remove(self, __value: _T) -> None: ...
784782
# Signature of `list.sort` should be kept inline with `collections.UserList.sort()`
785783
@overload
786-
def sort(self: list[SupportsLessThanT], *, key: None = ..., reverse: bool = ...) -> None: ...
784+
def sort(self: list[SupportsRichComparisonT], *, key: None = ..., reverse: bool = ...) -> None: ...
787785
@overload
788-
def sort(self, *, key: Callable[[_T], SupportsLessThan], reverse: bool = ...) -> None: ...
786+
def sort(self, *, key: Callable[[_T], SupportsRichComparison], reverse: bool = ...) -> None: ...
789787
def __len__(self) -> int: ...
790788
def __iter__(self) -> Iterator[_T]: ...
791789
def __str__(self) -> str: ...
@@ -1143,32 +1141,32 @@ class map(Iterator[_S], Generic[_S]):
11431141

11441142
@overload
11451143
def max(
1146-
__arg1: SupportsGreaterThanT, __arg2: SupportsGreaterThanT, *_args: SupportsGreaterThanT, key: None = ...
1147-
) -> SupportsGreaterThanT: ...
1144+
__arg1: SupportsRichComparisonT, __arg2: SupportsRichComparisonT, *_args: SupportsRichComparisonT, key: None = ...
1145+
) -> SupportsRichComparisonT: ...
11481146
@overload
1149-
def max(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsGreaterThan]) -> _T: ...
1147+
def max(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsRichComparisonT]) -> _T: ...
11501148
@overload
1151-
def max(__iterable: Iterable[SupportsGreaterThanT], *, key: None = ...) -> SupportsGreaterThanT: ...
1149+
def max(__iterable: Iterable[SupportsRichComparisonT], *, key: None = ...) -> SupportsRichComparisonT: ...
11521150
@overload
1153-
def max(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsGreaterThan]) -> _T: ...
1151+
def max(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsRichComparison]) -> _T: ...
11541152
@overload
1155-
def max(__iterable: Iterable[SupportsGreaterThanT], *, key: None = ..., default: _T) -> SupportsGreaterThanT | _T: ...
1153+
def max(__iterable: Iterable[SupportsRichComparisonT], *, key: None = ..., default: _T) -> SupportsRichComparisonT | _T: ...
11561154
@overload
1157-
def max(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsGreaterThan], default: _T2) -> _T1 | _T2: ...
1155+
def max(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsRichComparison], default: _T2) -> _T1 | _T2: ...
11581156
@overload
11591157
def min(
1160-
__arg1: SupportsLessThanT, __arg2: SupportsLessThanT, *_args: SupportsLessThanT, key: None = ...
1161-
) -> SupportsLessThanT: ...
1158+
__arg1: SupportsRichComparisonT, __arg2: SupportsRichComparisonT, *_args: SupportsRichComparisonT, key: None = ...
1159+
) -> SupportsRichComparisonT: ...
11621160
@overload
1163-
def min(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsLessThan]) -> _T: ...
1161+
def min(__arg1: _T, __arg2: _T, *_args: _T, key: Callable[[_T], SupportsRichComparison]) -> _T: ...
11641162
@overload
1165-
def min(__iterable: Iterable[SupportsLessThanT], *, key: None = ...) -> SupportsLessThanT: ...
1163+
def min(__iterable: Iterable[SupportsRichComparisonT], *, key: None = ...) -> SupportsRichComparisonT: ...
11661164
@overload
1167-
def min(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsLessThan]) -> _T: ...
1165+
def min(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsRichComparison]) -> _T: ...
11681166
@overload
1169-
def min(__iterable: Iterable[SupportsLessThanT], *, key: None = ..., default: _T) -> SupportsLessThanT | _T: ...
1167+
def min(__iterable: Iterable[SupportsRichComparisonT], *, key: None = ..., default: _T) -> SupportsRichComparisonT | _T: ...
11701168
@overload
1171-
def min(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsLessThan], default: _T2) -> _T1 | _T2: ...
1169+
def min(__iterable: Iterable[_T1], *, key: Callable[[_T1], SupportsRichComparison], default: _T2) -> _T1 | _T2: ...
11721170
@overload
11731171
def next(__i: SupportsNext[_T]) -> _T: ...
11741172
@overload
@@ -1379,9 +1377,11 @@ def round(number: SupportsRound[_T], ndigits: SupportsIndex) -> _T: ...
13791377
# for why arg 3 of `setattr` should be annotated with `Any` and not `object`
13801378
def setattr(__obj: object, __name: str, __value: Any) -> None: ...
13811379
@overload
1382-
def sorted(__iterable: Iterable[SupportsLessThanT], *, key: None = ..., reverse: bool = ...) -> list[SupportsLessThanT]: ...
1380+
def sorted(
1381+
__iterable: Iterable[SupportsRichComparisonT], *, key: None = ..., reverse: bool = ...
1382+
) -> list[SupportsRichComparisonT]: ...
13831383
@overload
1384-
def sorted(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsLessThan], reverse: bool = ...) -> list[_T]: ...
1384+
def sorted(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsRichComparison], reverse: bool = ...) -> list[_T]: ...
13851385

13861386
if sys.version_info >= (3, 8):
13871387
@overload

stdlib/collections/__init__.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22
from _collections_abc import dict_items, dict_keys, dict_values
3-
from _typeshed import Self, SupportsKeysAndGetItem, SupportsLessThan, SupportsLessThanT
3+
from _typeshed import Self, SupportsKeysAndGetItem, SupportsRichComparison, SupportsRichComparisonT
44
from typing import Any, Dict, Generic, NoReturn, Tuple, Type, TypeVar, overload
55
from typing_extensions import SupportsIndex, final
66

@@ -99,9 +99,9 @@ class UserList(MutableSequence[_T]):
9999
def reverse(self) -> None: ...
100100
# All arguments are passed to `list.sort` at runtime, so the signature should be kept in line with `list.sort`.
101101
@overload
102-
def sort(self: UserList[SupportsLessThanT], *, key: None = ..., reverse: bool = ...) -> None: ...
102+
def sort(self: UserList[SupportsRichComparisonT], *, key: None = ..., reverse: bool = ...) -> None: ...
103103
@overload
104-
def sort(self, *, key: Callable[[_T], SupportsLessThan], reverse: bool = ...) -> None: ...
104+
def sort(self, *, key: Callable[[_T], SupportsRichComparison], reverse: bool = ...) -> None: ...
105105
def extend(self, other: Iterable[_T]) -> None: ...
106106

107107
_UserStringT = TypeVar("_UserStringT", bound=UserString)

stdlib/functools.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22
import types
3-
from _typeshed import SupportsItems, SupportsLessThan
3+
from _typeshed import SupportsAllComparisons, SupportsItems
44
from typing import Any, Callable, Generic, Hashable, Iterable, NamedTuple, Sequence, Sized, Tuple, Type, TypeVar, overload
55
from typing_extensions import final
66

@@ -45,7 +45,7 @@ WRAPPER_UPDATES: Sequence[str]
4545
def update_wrapper(wrapper: _T, wrapped: _AnyCallable, assigned: Sequence[str] = ..., updated: Sequence[str] = ...) -> _T: ...
4646
def wraps(wrapped: _AnyCallable, assigned: Sequence[str] = ..., updated: Sequence[str] = ...) -> Callable[[_T], _T]: ...
4747
def total_ordering(cls: Type[_T]) -> Type[_T]: ...
48-
def cmp_to_key(mycmp: Callable[[_T, _T], int]) -> Callable[[_T], SupportsLessThan]: ...
48+
def cmp_to_key(mycmp: Callable[[_T, _T], int]) -> Callable[[_T], SupportsAllComparisons]: ...
4949

5050
class partial(Generic[_T]):
5151
func: Callable[..., _T]

stdlib/genericpath.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from _typeshed import BytesPath, StrOrBytesPath, StrPath, SupportsLessThanT
2+
from _typeshed import BytesPath, StrOrBytesPath, StrPath, SupportsRichComparisonT
33
from typing import Sequence, Tuple, overload
44
from typing_extensions import Literal
55

@@ -11,9 +11,9 @@ def commonprefix(m: Sequence[StrPath]) -> str: ...
1111
@overload
1212
def commonprefix(m: Sequence[BytesPath]) -> bytes | Literal[""]: ...
1313
@overload
14-
def commonprefix(m: Sequence[list[SupportsLessThanT]]) -> Sequence[SupportsLessThanT]: ...
14+
def commonprefix(m: Sequence[list[SupportsRichComparisonT]]) -> Sequence[SupportsRichComparisonT]: ...
1515
@overload
16-
def commonprefix(m: Sequence[Tuple[SupportsLessThanT, ...]]) -> Sequence[SupportsLessThanT]: ...
16+
def commonprefix(m: Sequence[Tuple[SupportsRichComparisonT, ...]]) -> Sequence[SupportsRichComparisonT]: ...
1717
def exists(path: StrOrBytesPath | int) -> bool: ...
1818
def getsize(filename: StrOrBytesPath | int) -> int: ...
1919
def isfile(path: StrOrBytesPath | int) -> bool: ...

stdlib/heapq.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from _typeshed import SupportsLessThan
1+
from _typeshed import SupportsRichComparison
22
from typing import Any, Callable, Iterable, TypeVar
33

44
_T = TypeVar("_T")
@@ -9,6 +9,6 @@ def heappushpop(__heap: list[_T], __item: _T) -> _T: ...
99
def heapify(__heap: list[Any]) -> None: ...
1010
def heapreplace(__heap: list[_T], __item: _T) -> _T: ...
1111
def merge(*iterables: Iterable[_T], key: Callable[[_T], Any] | None = ..., reverse: bool = ...) -> Iterable[_T]: ...
12-
def nlargest(n: int, iterable: Iterable[_T], key: Callable[[_T], SupportsLessThan] | None = ...) -> list[_T]: ...
13-
def nsmallest(n: int, iterable: Iterable[_T], key: Callable[[_T], SupportsLessThan] | None = ...) -> list[_T]: ...
12+
def nlargest(n: int, iterable: Iterable[_T], key: Callable[[_T], SupportsRichComparison] | None = ...) -> list[_T]: ...
13+
def nsmallest(n: int, iterable: Iterable[_T], key: Callable[[_T], SupportsRichComparison] | None = ...) -> list[_T]: ...
1414
def _heapify_max(__x: list[Any]) -> None: ... # undocumented

stdlib/statistics.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from _typeshed import SupportsLessThanT
2+
from _typeshed import SupportsRichComparisonT
33
from decimal import Decimal
44
from fractions import Fraction
55
from typing import Any, Hashable, Iterable, NamedTuple, Sequence, SupportsFloat, Type, TypeVar, Union
@@ -27,8 +27,8 @@ else:
2727
def harmonic_mean(data: Iterable[_NumberT]) -> _NumberT: ...
2828

2929
def median(data: Iterable[_NumberT]) -> _NumberT: ...
30-
def median_low(data: Iterable[SupportsLessThanT]) -> SupportsLessThanT: ...
31-
def median_high(data: Iterable[SupportsLessThanT]) -> SupportsLessThanT: ...
30+
def median_low(data: Iterable[SupportsRichComparisonT]) -> SupportsRichComparisonT: ...
31+
def median_high(data: Iterable[SupportsRichComparisonT]) -> SupportsRichComparisonT: ...
3232
def median_grouped(data: Iterable[_NumberT], interval: _NumberT = ...) -> _NumberT: ...
3333
def mode(data: Iterable[_HashableT]) -> _HashableT: ...
3434

0 commit comments

Comments
 (0)