From 876327cc353f5f3f0ec0ed6e1eea7e0986b8e00e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 13 Nov 2023 19:31:05 +0000 Subject: [PATCH 1/4] Fix crash on strict-equality with recursive types --- mypy/checkexpr.py | 19 +++++++++++++++---- test-data/unit/check-expressions.test | 24 ++++++++++++++++++++++++ test-data/unit/fixtures/list.pyi | 1 + 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index c87d1f6cd31c..ea7ed0a6a669 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3618,6 +3618,7 @@ def dangerous_comparison( left: Type, right: Type, original_container: Type | None = None, + seen_types: set[tuple[Type, Type]] | None = None, *, prefer_literal: bool = True, ) -> bool: @@ -3639,6 +3640,12 @@ def dangerous_comparison( if not self.chk.options.strict_equality: return False + if seen_types is None: + seen_types = set() + if (left, right) in seen_types: + return False + seen_types.add((left, right)) + left, right = get_proper_types((left, right)) # We suppress the error if there is a custom __eq__() method on either @@ -3694,17 +3701,21 @@ def dangerous_comparison( abstract_set = self.chk.lookup_typeinfo("typing.AbstractSet") left = map_instance_to_supertype(left, abstract_set) right = map_instance_to_supertype(right, abstract_set) - return self.dangerous_comparison(left.args[0], right.args[0]) + return self.dangerous_comparison( + left.args[0], right.args[0], seen_types=seen_types + ) elif left.type.has_base("typing.Mapping") and right.type.has_base("typing.Mapping"): # Similar to above: Mapping ignores the classes, it just compares items. abstract_map = self.chk.lookup_typeinfo("typing.Mapping") left = map_instance_to_supertype(left, abstract_map) right = map_instance_to_supertype(right, abstract_map) return self.dangerous_comparison( - left.args[0], right.args[0] - ) or self.dangerous_comparison(left.args[1], right.args[1]) + left.args[0], right.args[0], seen_types=seen_types + ) or self.dangerous_comparison(left.args[1], right.args[1], seen_types=seen_types) elif left_name in ("builtins.list", "builtins.tuple") and right_name == left_name: - return self.dangerous_comparison(left.args[0], right.args[0]) + return self.dangerous_comparison( + left.args[0], right.args[0], seen_types=seen_types + ) elif left_name in OVERLAPPING_BYTES_ALLOWLIST and right_name in ( OVERLAPPING_BYTES_ALLOWLIST ): diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 4ac5512580d2..7db74827d4e0 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2378,6 +2378,30 @@ assert a == b [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] +[case testStrictEqualityWithRecursiveMapTypes] +# flags: --strict-equality +from typing import Dict + +R = Dict[str, R] + +a: R +b: R +assert a == b +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + +[case testStrictEqualityWithRecursiveListTypes] +# flags: --strict-equality +from typing import List, Union + +R = List[Union[str, R]] + +a: R +b: R +assert a == b +[builtins fixtures/list.pyi] +[typing fixtures/typing-full.pyi] + [case testUnimportedHintAny] def f(x: Any) -> None: # E: Name "Any" is not defined \ # N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any") diff --git a/test-data/unit/fixtures/list.pyi b/test-data/unit/fixtures/list.pyi index 90fbabe8bc92..3dcdf18b2faa 100644 --- a/test-data/unit/fixtures/list.pyi +++ b/test-data/unit/fixtures/list.pyi @@ -6,6 +6,7 @@ T = TypeVar('T') class object: def __init__(self) -> None: pass + def __eq__(self, other: object) -> bool: pass class type: pass class ellipsis: pass From 33afea80d3de10b98c4fa6bdc02d8b3283ac264f Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 14 Nov 2023 20:33:24 +0000 Subject: [PATCH 2/4] Update mypy/checkexpr.py Co-authored-by: Alex Waygood --- mypy/checkexpr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ea7ed0a6a669..da61833bbe5b 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3617,9 +3617,9 @@ def dangerous_comparison( self, left: Type, right: Type, + *, original_container: Type | None = None, seen_types: set[tuple[Type, Type]] | None = None, - *, prefer_literal: bool = True, ) -> bool: """Check for dangerous non-overlapping comparisons like 42 == 'no'. From 80c4d7b29dab1340cc8f41dfa688811ffec95a3f Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 15 Nov 2023 00:25:12 +0000 Subject: [PATCH 3/4] Fix another similar crash --- mypy/meet.py | 8 ++++++++ test-data/unit/check-expressions.test | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/mypy/meet.py b/mypy/meet.py index 610185d6bbbf..1c5f45c4502d 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -262,6 +262,7 @@ def is_overlapping_types( ignore_promotions: bool = False, prohibit_none_typevar_overlap: bool = False, ignore_uninhabited: bool = False, + seen_types: set[tuple[Type, Type]] | None = None, ) -> bool: """Can a value of type 'left' also be of type 'right' or vice-versa? @@ -275,6 +276,12 @@ def is_overlapping_types( # A type guard forces the new type even if it doesn't overlap the old. return True + if seen_types is None: + seen_types = set() + if (left, right) in seen_types: + return True + seen_types.add((left, right)) + left, right = get_proper_types((left, right)) def _is_overlapping_types(left: Type, right: Type) -> bool: @@ -287,6 +294,7 @@ def _is_overlapping_types(left: Type, right: Type) -> bool: ignore_promotions=ignore_promotions, prohibit_none_typevar_overlap=prohibit_none_typevar_overlap, ignore_uninhabited=ignore_uninhabited, + seen_types=seen_types, ) # We should never encounter this type. diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 7db74827d4e0..8fe68365e5ac 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2387,6 +2387,10 @@ R = Dict[str, R] a: R b: R assert a == b + +R2 = Dict[int, R2] +c: R2 +assert a == c # E: Non-overlapping equality check (left operand type: "Dict[str, R]", right operand type: "Dict[int, R2]") [builtins fixtures/dict.pyi] [typing fixtures/typing-full.pyi] @@ -2399,6 +2403,10 @@ R = List[Union[str, R]] a: R b: R assert a == b + +R2 = List[Union[int, R2]] +c: R2 +assert a == c [builtins fixtures/list.pyi] [typing fixtures/typing-full.pyi] From 7c216e905347d062944708be885945892ea6dfcd Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 15 Nov 2023 00:51:33 +0000 Subject: [PATCH 4/4] Fix overlapping types --- mypy/meet.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mypy/meet.py b/mypy/meet.py index 1c5f45c4502d..df8b960cdf3f 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -280,21 +280,23 @@ def is_overlapping_types( seen_types = set() if (left, right) in seen_types: return True - seen_types.add((left, right)) + if isinstance(left, TypeAliasType) and isinstance(right, TypeAliasType): + seen_types.add((left, right)) left, right = get_proper_types((left, right)) def _is_overlapping_types(left: Type, right: Type) -> bool: """Encode the kind of overlapping check to perform. - This function mostly exists so we don't have to repeat keyword arguments everywhere.""" + This function mostly exists, so we don't have to repeat keyword arguments everywhere. + """ return is_overlapping_types( left, right, ignore_promotions=ignore_promotions, prohibit_none_typevar_overlap=prohibit_none_typevar_overlap, ignore_uninhabited=ignore_uninhabited, - seen_types=seen_types, + seen_types=seen_types.copy(), ) # We should never encounter this type.