Skip to content

Improve annotations for various set methods #11403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions stdlib/_weakrefset.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import sys
from collections.abc import Iterable, Iterator, MutableSet
from collections.abc import Iterable, Iterator, MutableSet, Set as AbstractSet
from typing import Any, ClassVar, TypeVar, overload
from typing_extensions import Self

Expand All @@ -17,7 +17,7 @@ class WeakSet(MutableSet[_T]):
@overload
def __init__(self, data: Iterable[_T]) -> None: ...
def add(self, item: _T) -> None: ...
def discard(self, item: _T) -> None: ...
def discard(self, item: _T | None) -> None: ...
def copy(self) -> Self: ...
def remove(self, item: _T) -> None: ...
def update(self, other: Iterable[_T]) -> None: ...
Expand All @@ -26,10 +26,13 @@ class WeakSet(MutableSet[_T]):
def __len__(self) -> int: ...
def __iter__(self) -> Iterator[_T]: ...
def __ior__(self, other: Iterable[_T]) -> Self: ... # type: ignore[override,misc]
def difference(self, other: Iterable[_T]) -> Self: ...
def __sub__(self, other: Iterable[Any]) -> Self: ...
def difference_update(self, other: Iterable[Any]) -> None: ...
def __isub__(self, other: Iterable[Any]) -> Self: ...
def difference(self, other: Iterable[_T | None]) -> Self: ...
@overload # type: ignore[override]
def __sub__(self: AbstractSet[_S | None], other: Iterable[None]) -> WeakSet[_S]: ... # type: ignore[overload-overlap]
@overload
def __sub__(self, other: Iterable[_T | None]) -> Self: ...
def difference_update(self, other: Iterable[_T | None]) -> None: ...
def __isub__(self, other: Iterable[_T | None]) -> Self: ... # type: ignore[misc]
def intersection(self, other: Iterable[_T]) -> Self: ...
def __and__(self, other: Iterable[Any]) -> Self: ...
def intersection_update(self, other: Iterable[Any]) -> None: ...
Expand Down
16 changes: 11 additions & 5 deletions stdlib/builtins.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1166,9 +1166,9 @@ class set(MutableSet[_T]):
def __init__(self, iterable: Iterable[_T], /) -> None: ...
def add(self, element: _T, /) -> None: ...
def copy(self) -> set[_T]: ...
def difference(self, *s: Iterable[Any]) -> set[_T]: ...
def difference_update(self, *s: Iterable[Any]) -> None: ...
def discard(self, element: _T, /) -> None: ...
def difference(self, *s: Iterable[_T | None]) -> set[_T]: ...
def difference_update(self, *s: Iterable[_T | None]) -> None: ...
def discard(self, element: _T | None, /) -> None: ...
Copy link

@Bruntaz Bruntaz Feb 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe the correct signature for this should be def discard(self, element: Hashable, /) -> None: ....

set.discard will only raise an exception if an unhashable type is passed into it. Any other input is valid (just if it doesn't match _T, it will do nothing instead of raising like set.remove():

>>> a = {1, 2, 3}
>>> a.discard(1)
>>> a
{2, 3}
>>> a.discard("Hello")
>>> a
{2, 3}
>>> a.discard([])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: unhashable type: 'list'

def intersection(self, *s: Iterable[Any]) -> set[_T]: ...
def intersection_update(self, *s: Iterable[Any]) -> None: ...
def isdisjoint(self, s: Iterable[Any], /) -> bool: ...
Expand All @@ -1186,8 +1186,11 @@ class set(MutableSet[_T]):
def __iand__(self, value: AbstractSet[object], /) -> Self: ...
def __or__(self, value: AbstractSet[_S], /) -> set[_T | _S]: ...
def __ior__(self, value: AbstractSet[_T], /) -> Self: ... # type: ignore[override,misc]
@overload # type: ignore[override]
def __sub__(self: AbstractSet[_S | None], value: AbstractSet[None], /) -> set[_S]: ...
@overload
def __sub__(self, value: AbstractSet[_T | None], /) -> set[_T]: ...
def __isub__(self, value: AbstractSet[object], /) -> Self: ...
def __isub__(self, value: AbstractSet[_T | None], /) -> Self: ... # type: ignore[misc]
def __xor__(self, value: AbstractSet[_S], /) -> set[_T | _S]: ...
def __ixor__(self, value: AbstractSet[_T], /) -> Self: ... # type: ignore[override,misc]
def __le__(self, value: AbstractSet[object], /) -> bool: ...
Expand Down Expand Up @@ -1217,7 +1220,10 @@ class frozenset(AbstractSet[_T_co]):
def __iter__(self) -> Iterator[_T_co]: ...
def __and__(self, value: AbstractSet[_T_co], /) -> frozenset[_T_co]: ...
def __or__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ...
def __sub__(self, value: AbstractSet[_T_co], /) -> frozenset[_T_co]: ...
@overload # type: ignore[override]
def __sub__(self: AbstractSet[_S | None], value: AbstractSet[None], /) -> frozenset[_S]: ...
@overload
def __sub__(self, value: AbstractSet[_T_co | None], /) -> frozenset[_T_co]: ...
def __xor__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ...
def __le__(self, value: AbstractSet[object], /) -> bool: ...
def __lt__(self, value: AbstractSet[object], /) -> bool: ...
Expand Down
9 changes: 6 additions & 3 deletions stdlib/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,10 @@ class AbstractSet(Collection[_T_co]):
def __ge__(self, other: AbstractSet[Any]) -> bool: ...
def __and__(self, other: AbstractSet[Any]) -> AbstractSet[_T_co]: ...
def __or__(self, other: AbstractSet[_T]) -> AbstractSet[_T_co | _T]: ...
def __sub__(self, other: AbstractSet[Any]) -> AbstractSet[_T_co]: ...
@overload
def __sub__(self: AbstractSet[_S | None], other: AbstractSet[None]) -> AbstractSet[_S]: ...
@overload
def __sub__(self, other: AbstractSet[_T_co | None]) -> AbstractSet[_T_co]: ...
def __xor__(self, other: AbstractSet[_T]) -> AbstractSet[_T_co | _T]: ...
def __eq__(self, other: object) -> bool: ...
def isdisjoint(self, other: Iterable[Any]) -> bool: ...
Expand All @@ -644,15 +647,15 @@ class MutableSet(AbstractSet[_T]):
@abstractmethod
def add(self, value: _T) -> None: ...
@abstractmethod
def discard(self, value: _T) -> None: ...
def discard(self, value: _T | None) -> None: ...
# Mixin methods
def clear(self) -> None: ...
def pop(self) -> _T: ...
def remove(self, value: _T) -> None: ...
def __ior__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc]
def __iand__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ...
def __ixor__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc]
def __isub__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ...
def __isub__(self, it: AbstractSet[_T | None]) -> typing_extensions.Self: ... # type: ignore[misc]

class MappingView(Sized):
def __init__(self, mapping: Mapping[Any, Any]) -> None: ... # undocumented
Expand Down
24 changes: 24 additions & 0 deletions test_cases/stdlib/builtins/check_set.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import FrozenSet, Set
from typing_extensions import assert_type

# We special case `AbstractSet[None] in set.__sub__ and frozenset.__sub__
# so that it can be used for narrowing `set[T|None]` to `set[T]`
x = {"foo", "bar", None}
y = frozenset(x)
assert_type(x - {None}, Set[str])
assert_type(y - {None}, FrozenSet[str])

# For most other cases of set subtraction, we're pretty restrictive about what's allowed.
# `set[T] - set[S]` is an error, even though it won't cause an exception at runtime,
# as it will always be a useless no-op
{"foo", "bar"} - {1, 2} # type: ignore

# But subtracting set[T|None] from set[T] is allowed, as a convenience;
# this comes up a lot in real-life code:
assert_type({"foo", "bar"} - {"foo", None}, Set[str])
x = {"foo", "bar"}
x.difference_update({"foo", "bar", None})
name: str | None = "foo"
x.discard(name)