Skip to content

Commit c2dc532

Browse files
committed
Fix --strict-equality crash for instances of a class generic over a ParamSpec
1 parent 1feabc8 commit c2dc532

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed

mypy/meet.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from itertools import chain
34
from typing import Callable
45

56
from mypy import join
@@ -342,7 +343,15 @@ def _is_overlapping_types(left: Type, right: Type) -> bool:
342343
left_possible = get_possible_variants(left)
343344
right_possible = get_possible_variants(right)
344345

345-
# We start by checking multi-variant types like Unions first. We also perform
346+
# First handle a special case: comparing a `Parameters` to a `ParamSpecType`.
347+
# This should always be considered an overlapping equality check.
348+
# This needs to be done before we move on to other TypeVarLike comparisons.
349+
if (isinstance(left, Parameters) and isinstance(right, ParamSpecType)) or (
350+
isinstance(left, ParamSpecType) and isinstance(right, Parameters)
351+
):
352+
return True
353+
354+
# Now move on to checking multi-variant types like Unions. We also perform
346355
# the same logic if either type happens to be a TypeVar/ParamSpec/TypeVarTuple.
347356
#
348357
# Handling the TypeVarLikes now lets us simulate having them bind to the corresponding
@@ -451,6 +460,24 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
451460
elif isinstance(right, CallableType):
452461
right = right.fallback
453462

463+
if isinstance(left, Parameters):
464+
if not isinstance(right, Parameters):
465+
return False
466+
if len(left.arg_types) == len(right.arg_types):
467+
return all(
468+
_is_overlapping_types(left_arg, right_arg)
469+
for left_arg, right_arg in zip(left.arg_types, right.arg_types)
470+
)
471+
if not any(
472+
isinstance(arg, TypeVarLikeType) for arg in chain(left.arg_types, right.arg_types)
473+
):
474+
return False
475+
# TODO: Is this sound?
476+
return True
477+
if isinstance(right, Parameters):
478+
assert not isinstance(left, (Parameters, TypeVarLikeType))
479+
return False
480+
454481
if isinstance(left, LiteralType) and isinstance(right, LiteralType):
455482
if left.value == right.value:
456483
# If values are the same, we still need to check if fallbacks are overlapping,

test-data/unit/pythoneval.test

+36-2
Original file line numberDiff line numberDiff line change
@@ -1928,11 +1928,45 @@ _testStarUnpackNestedUnderscore.py:16: note: Revealed type is "builtins.list[bui
19281928
[case testStrictEqualitywithParamSpec]
19291929
# flags: --strict-equality
19301930
from typing import Generic
1931-
from typing_extensions import ParamSpec
1931+
from typing_extensions import Concatenate, ParamSpec
19321932

19331933
P = ParamSpec("P")
19341934

19351935
class Foo(Generic[P]): ...
1936+
class Bar(Generic[P]): ...
19361937

1937-
def check(foo1: Foo[[int]], foo2: Foo[[str]]) -> bool:
1938+
def bad1(foo1: Foo[[int]], foo2: Foo[[str]]) -> bool:
19381939
return foo1 == foo2
1940+
1941+
def bad2(foo1: Foo[[int, str]], foo2: Foo[[int, bytes]]) -> bool:
1942+
return foo1 == foo2
1943+
1944+
def bad3(foo1: Foo[[int]], foo2: Foo[[int, int]]) -> bool:
1945+
return foo1 == foo2
1946+
1947+
def bad4(foo: Foo[[int]], bar: Bar[[int]]) -> bool:
1948+
return foo == bar
1949+
1950+
def good1(foo1: Foo[[int]], foo2: Foo[[int]]) -> bool:
1951+
return foo1 == foo2
1952+
1953+
def good2(foo1: Foo[[int]], foo2: Foo[[bool]]) -> bool:
1954+
return foo1 == foo2
1955+
1956+
def good3(foo1: Foo[[int, int]], foo2: Foo[[bool, bool]]) -> bool:
1957+
return foo1 == foo2
1958+
1959+
def good4(foo1: Foo[[int]], foo2: Foo[P], *args: P.args, **kwargs: P.kwargs) -> bool:
1960+
return foo1 == foo2
1961+
1962+
def good5(foo1: Foo[P], foo2: Foo[[int, str, bytes]], *args: P.args, **kwargs: P.kwargs) -> bool:
1963+
return foo1 == foo2
1964+
1965+
def good6(foo1: Foo[Concatenate[int, P]], foo2: Foo[[int, str, bytes]], *args: P.args, **kwargs: P.kwargs) -> bool:
1966+
return foo1 == foo2
1967+
1968+
[out]
1969+
_testStrictEqualitywithParamSpec.py:11: error: Non-overlapping equality check (left operand type: "Foo[[int]]", right operand type: "Foo[[str]]")
1970+
_testStrictEqualitywithParamSpec.py:14: error: Non-overlapping equality check (left operand type: "Foo[[int, str]]", right operand type: "Foo[[int, bytes]]")
1971+
_testStrictEqualitywithParamSpec.py:17: error: Non-overlapping equality check (left operand type: "Foo[[int]]", right operand type: "Foo[[int, int]]")
1972+
_testStrictEqualitywithParamSpec.py:20: error: Non-overlapping equality check (left operand type: "Foo[[int]]", right operand type: "Bar[[int]]")

0 commit comments

Comments
 (0)