Skip to content

Commit 1828ba2

Browse files
authored
builtins.sum: Items in the iterable must support addition with int if no start value is given (#8000)
1 parent 7c47324 commit 1828ba2

File tree

3 files changed

+67
-7
lines changed

3 files changed

+67
-7
lines changed

stdlib/_typeshed/__init__.pyi

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,11 @@ SupportsRichComparisonT = TypeVar("SupportsRichComparisonT", bound=SupportsRichC
6969

7070
# Dunder protocols
7171

72-
class SupportsAdd(Protocol):
73-
def __add__(self, __x: Any) -> Any: ...
72+
class SupportsAdd(Protocol[_T_contra, _T_co]):
73+
def __add__(self, __x: _T_contra) -> _T_co: ...
74+
75+
class SupportsRAdd(Protocol[_T_contra, _T_co]):
76+
def __radd__(self, __x: _T_contra) -> _T_co: ...
7477

7578
class SupportsDivMod(Protocol[_T_contra, _T_co]):
7679
def __divmod__(self, __other: _T_contra) -> _T_co: ...

stdlib/builtins.pyi

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ from _typeshed import (
2020
SupportsKeysAndGetItem,
2121
SupportsLenAndGetItem,
2222
SupportsNext,
23+
SupportsRAdd,
2324
SupportsRDivMod,
2425
SupportsRichComparison,
2526
SupportsRichComparisonT,
@@ -1637,8 +1638,12 @@ def sorted(
16371638
@overload
16381639
def sorted(__iterable: Iterable[_T], *, key: Callable[[_T], SupportsRichComparison], reverse: bool = ...) -> list[_T]: ...
16391640

1640-
_SumT = TypeVar("_SumT", bound=SupportsAdd)
1641-
_SumS = TypeVar("_SumS", bound=SupportsAdd)
1641+
_AddableT1 = TypeVar("_AddableT1", bound=SupportsAdd[Any, Any])
1642+
_AddableT2 = TypeVar("_AddableT2", bound=SupportsAdd[Any, Any])
1643+
1644+
class _SupportsSumWithNoDefaultGiven(SupportsAdd[Any, Any], SupportsRAdd[int, Any], Protocol): ...
1645+
1646+
_SupportsSumNoDefaultT = TypeVar("_SupportsSumNoDefaultT", bound=_SupportsSumWithNoDefaultGiven)
16421647

16431648
# In general, the return type of `x + x` is *not* guaranteed to be the same type as x.
16441649
# However, we can't express that in the stub for `sum()`
@@ -1653,15 +1658,15 @@ else:
16531658
def sum(__iterable: Iterable[bool], __start: int = ...) -> int: ... # type: ignore[misc]
16541659

16551660
@overload
1656-
def sum(__iterable: Iterable[_SumT]) -> _SumT | Literal[0]: ...
1661+
def sum(__iterable: Iterable[_SupportsSumNoDefaultT]) -> _SupportsSumNoDefaultT | Literal[0]: ...
16571662

16581663
if sys.version_info >= (3, 8):
16591664
@overload
1660-
def sum(__iterable: Iterable[_SumT], start: _SumS) -> _SumT | _SumS: ...
1665+
def sum(__iterable: Iterable[_AddableT1], start: _AddableT2) -> _AddableT1 | _AddableT2: ...
16611666

16621667
else:
16631668
@overload
1664-
def sum(__iterable: Iterable[_SumT], __start: _SumS) -> _SumT | _SumS: ...
1669+
def sum(__iterable: Iterable[_AddableT1], __start: _AddableT2) -> _AddableT1 | _AddableT2: ...
16651670

16661671
# The argument to `vars()` has to have a `__dict__` attribute, so can't be annotated with `object`
16671672
# (A "SupportsDunderDict" protocol doesn't work)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# pyright: reportUnnecessaryTypeIgnoreComment=true
2+
3+
from typing import Any, List, Union
4+
from typing_extensions import Literal, assert_type
5+
6+
7+
class Foo:
8+
def __add__(self, other: Any) -> "Foo":
9+
return Foo()
10+
11+
12+
class Bar:
13+
def __radd__(self, other: Any) -> "Bar":
14+
return Bar()
15+
16+
17+
class Baz:
18+
def __add__(self, other: Any) -> "Baz":
19+
return Baz()
20+
21+
def __radd__(self, other: Any) -> "Baz":
22+
return Baz()
23+
24+
25+
assert_type(sum([2, 4]), int)
26+
assert_type(sum([3, 5], 4), int)
27+
28+
assert_type(sum([True, False]), int)
29+
assert_type(sum([True, False], True), int)
30+
31+
assert_type(sum([["foo"], ["bar"]], ["baz"]), List[str])
32+
33+
assert_type(sum([Foo(), Foo()], Foo()), Foo)
34+
assert_type(sum([Baz(), Baz()]), Union[Baz, Literal[0]])
35+
36+
# mypy and pyright infer the types differently for these, so we can't use assert_type
37+
# Just test that no error is emitted for any of these
38+
sum([("foo",), ("bar", "baz")], ()) # mypy: `tuple[str, ...]`; pyright: `tuple[()] | tuple[str] | tuple[str, str]`
39+
sum([5.6, 3.2]) # mypy: `float`; pyright: `float | Literal[0]`
40+
sum([2.5, 5.8], 5) # mypy: `float`; pyright: `float | int`
41+
42+
# These all fail at runtime
43+
sum("abcde") # type: ignore[arg-type]
44+
sum([["foo"], ["bar"]]) # type: ignore[list-item]
45+
sum([("foo",), ("bar", "baz")]) # type: ignore[list-item]
46+
sum([Foo(), Foo()]) # type: ignore[list-item]
47+
sum([Bar(), Bar()], Bar()) # type: ignore[call-overload]
48+
sum([Bar(), Bar()]) # type: ignore[list-item]
49+
50+
# TODO: these pass pyright with the current stubs, but mypy erroneously emits an error:
51+
# sum([3, Fraction(7, 22), complex(8, 0), 9.83])
52+
# sum([3, Decimal('0.98')])

0 commit comments

Comments
 (0)