diff --git a/stdlib/_weakrefset.pyi b/stdlib/_weakrefset.pyi index b55318528208..78ea42a91837 100644 --- a/stdlib/_weakrefset.pyi +++ b/stdlib/_weakrefset.pyi @@ -1,5 +1,5 @@ import sys -from collections.abc import Iterable, Iterator, MutableSet +from collections.abc import Iterable, Iterator, MutableSet, Set as AbstractSet from typing import Any, ClassVar, TypeVar, overload from typing_extensions import Self @@ -17,7 +17,7 @@ class WeakSet(MutableSet[_T]): @overload def __init__(self, data: Iterable[_T]) -> None: ... def add(self, item: _T) -> None: ... - def discard(self, item: _T) -> None: ... + def discard(self, item: _T | None) -> None: ... def copy(self) -> Self: ... def remove(self, item: _T) -> None: ... def update(self, other: Iterable[_T]) -> None: ... @@ -26,10 +26,13 @@ class WeakSet(MutableSet[_T]): def __len__(self) -> int: ... def __iter__(self) -> Iterator[_T]: ... def __ior__(self, other: Iterable[_T]) -> Self: ... # type: ignore[override,misc] - def difference(self, other: Iterable[_T]) -> Self: ... - def __sub__(self, other: Iterable[Any]) -> Self: ... - def difference_update(self, other: Iterable[Any]) -> None: ... - def __isub__(self, other: Iterable[Any]) -> Self: ... + def difference(self, other: Iterable[_T | None]) -> Self: ... + @overload # type: ignore[override] + def __sub__(self: AbstractSet[_S | None], other: Iterable[None]) -> WeakSet[_S]: ... # type: ignore[overload-overlap] + @overload + def __sub__(self, other: Iterable[_T | None]) -> Self: ... + def difference_update(self, other: Iterable[_T | None]) -> None: ... + def __isub__(self, other: Iterable[_T | None]) -> Self: ... # type: ignore[misc] def intersection(self, other: Iterable[_T]) -> Self: ... def __and__(self, other: Iterable[Any]) -> Self: ... def intersection_update(self, other: Iterable[Any]) -> None: ... diff --git a/stdlib/builtins.pyi b/stdlib/builtins.pyi index 0a6dc57b05b8..bc509b011c2f 100644 --- a/stdlib/builtins.pyi +++ b/stdlib/builtins.pyi @@ -1166,9 +1166,9 @@ class set(MutableSet[_T]): def __init__(self, iterable: Iterable[_T], /) -> None: ... def add(self, element: _T, /) -> None: ... def copy(self) -> set[_T]: ... - def difference(self, *s: Iterable[Any]) -> set[_T]: ... - def difference_update(self, *s: Iterable[Any]) -> None: ... - def discard(self, element: _T, /) -> None: ... + def difference(self, *s: Iterable[_T | None]) -> set[_T]: ... + def difference_update(self, *s: Iterable[_T | None]) -> None: ... + def discard(self, element: _T | None, /) -> None: ... def intersection(self, *s: Iterable[Any]) -> set[_T]: ... def intersection_update(self, *s: Iterable[Any]) -> None: ... def isdisjoint(self, s: Iterable[Any], /) -> bool: ... @@ -1186,8 +1186,11 @@ class set(MutableSet[_T]): def __iand__(self, value: AbstractSet[object], /) -> Self: ... def __or__(self, value: AbstractSet[_S], /) -> set[_T | _S]: ... def __ior__(self, value: AbstractSet[_T], /) -> Self: ... # type: ignore[override,misc] + @overload # type: ignore[override] + def __sub__(self: AbstractSet[_S | None], value: AbstractSet[None], /) -> set[_S]: ... + @overload def __sub__(self, value: AbstractSet[_T | None], /) -> set[_T]: ... - def __isub__(self, value: AbstractSet[object], /) -> Self: ... + def __isub__(self, value: AbstractSet[_T | None], /) -> Self: ... # type: ignore[misc] def __xor__(self, value: AbstractSet[_S], /) -> set[_T | _S]: ... def __ixor__(self, value: AbstractSet[_T], /) -> Self: ... # type: ignore[override,misc] def __le__(self, value: AbstractSet[object], /) -> bool: ... @@ -1217,7 +1220,10 @@ class frozenset(AbstractSet[_T_co]): def __iter__(self) -> Iterator[_T_co]: ... def __and__(self, value: AbstractSet[_T_co], /) -> frozenset[_T_co]: ... def __or__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ... - def __sub__(self, value: AbstractSet[_T_co], /) -> frozenset[_T_co]: ... + @overload # type: ignore[override] + def __sub__(self: AbstractSet[_S | None], value: AbstractSet[None], /) -> frozenset[_S]: ... + @overload + def __sub__(self, value: AbstractSet[_T_co | None], /) -> frozenset[_T_co]: ... def __xor__(self, value: AbstractSet[_S], /) -> frozenset[_T_co | _S]: ... def __le__(self, value: AbstractSet[object], /) -> bool: ... def __lt__(self, value: AbstractSet[object], /) -> bool: ... diff --git a/stdlib/typing.pyi b/stdlib/typing.pyi index 7c1b171a730b..0c5014883cd7 100644 --- a/stdlib/typing.pyi +++ b/stdlib/typing.pyi @@ -635,7 +635,10 @@ class AbstractSet(Collection[_T_co]): def __ge__(self, other: AbstractSet[Any]) -> bool: ... def __and__(self, other: AbstractSet[Any]) -> AbstractSet[_T_co]: ... def __or__(self, other: AbstractSet[_T]) -> AbstractSet[_T_co | _T]: ... - def __sub__(self, other: AbstractSet[Any]) -> AbstractSet[_T_co]: ... + @overload + def __sub__(self: AbstractSet[_S | None], other: AbstractSet[None]) -> AbstractSet[_S]: ... + @overload + def __sub__(self, other: AbstractSet[_T_co | None]) -> AbstractSet[_T_co]: ... def __xor__(self, other: AbstractSet[_T]) -> AbstractSet[_T_co | _T]: ... def __eq__(self, other: object) -> bool: ... def isdisjoint(self, other: Iterable[Any]) -> bool: ... @@ -644,7 +647,7 @@ class MutableSet(AbstractSet[_T]): @abstractmethod def add(self, value: _T) -> None: ... @abstractmethod - def discard(self, value: _T) -> None: ... + def discard(self, value: _T | None) -> None: ... # Mixin methods def clear(self) -> None: ... def pop(self) -> _T: ... @@ -652,7 +655,7 @@ class MutableSet(AbstractSet[_T]): def __ior__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] def __iand__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... def __ixor__(self, it: AbstractSet[_T]) -> typing_extensions.Self: ... # type: ignore[override,misc] - def __isub__(self, it: AbstractSet[Any]) -> typing_extensions.Self: ... + def __isub__(self, it: AbstractSet[_T | None]) -> typing_extensions.Self: ... # type: ignore[misc] class MappingView(Sized): def __init__(self, mapping: Mapping[Any, Any]) -> None: ... # undocumented diff --git a/test_cases/stdlib/builtins/check_set.py b/test_cases/stdlib/builtins/check_set.py new file mode 100644 index 000000000000..4dd003bf30f5 --- /dev/null +++ b/test_cases/stdlib/builtins/check_set.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import FrozenSet, Set +from typing_extensions import assert_type + +# We special case `AbstractSet[None] in set.__sub__ and frozenset.__sub__ +# so that it can be used for narrowing `set[T|None]` to `set[T]` +x = {"foo", "bar", None} +y = frozenset(x) +assert_type(x - {None}, Set[str]) +assert_type(y - {None}, FrozenSet[str]) + +# For most other cases of set subtraction, we're pretty restrictive about what's allowed. +# `set[T] - set[S]` is an error, even though it won't cause an exception at runtime, +# as it will always be a useless no-op +{"foo", "bar"} - {1, 2} # type: ignore + +# But subtracting set[T|None] from set[T] is allowed, as a convenience; +# this comes up a lot in real-life code: +assert_type({"foo", "bar"} - {"foo", None}, Set[str]) +x = {"foo", "bar"} +x.difference_update({"foo", "bar", None}) +name: str | None = "foo" +x.discard(name)