From 5891413accb0b14c8209c3fe17ddaf492d945e3f Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Thu, 28 Nov 2024 16:23:14 +0100 Subject: [PATCH 1/8] Move the join-based narrowing logic towards a union-based narrowing logic. More concretely: use `make_simplified_union` instead of `join_simple` in `ConditionalTypeBinder.update_from_options` --- mypy/binder.py | 17 ++++++++++++++++- test-data/unit/check-narrowing.test | 21 +++++++++++++++++++++ test-data/unit/check-redefine.test | 2 +- 3 files changed, 38 insertions(+), 2 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index 52ae9774e6d4..10aa11cc6206 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -10,6 +10,7 @@ from mypy.literals import Key, literal, literal_hash, subkeys from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var from mypy.subtypes import is_same_type, is_subtype +from mypy.typeops import make_simplified_union from mypy.types import ( AnyType, Instance, @@ -237,8 +238,22 @@ def update_from_options(self, frames: list[Frame]) -> bool: type = AnyType(TypeOfAny.from_another_any, source_any=declaration_type) else: for other in resulting_values[1:]: + assert other is not None - type = join_simple(self.declarations[key], type, other.type) + + if ( + isinstance(t1 := get_proper_type(type), TupleType) + and isinstance(t2 := get_proper_type(other.type), TupleType) + and (len(l1 := t1.items) == len(l2 := t2.items)) + and (find_unpack_in_list(l1) is None) + and (find_unpack_in_list(l2) is None) + ): + type = t1.copy_modified( + items=[make_simplified_union([i1, i2]) for i1, i2 in zip(l1, l2)] + ) + else: + type = make_simplified_union([type, other.type]) + # Try simplifying resulting type for unions involving variadic tuples. # Technically, everything is still valid without this step, but if we do # not do this, this may create long unions after exiting an if check like: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 285d56ff7e50..10f5d30e6c23 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2352,3 +2352,24 @@ def fn_while(arg: T) -> None: return None return None [builtins fixtures/primitives.pyi] + + +[case testNarrowingIsinstanceCreatesUnion] + +class A: ... +class B(A): y: int +class C(A): y: int +class D(C): ... +class E(C): ... +class F(C): ... + +def f(x: A): + if isinstance(x, B): ... + elif isinstance(x, D): ... + elif isinstance(x, E): ... + elif isinstance(x, F): ... + else: return + reveal_type(x) # N: Revealed type is "Union[__main__.B, __main__.D, __main__.E, __main__.F]" + reveal_type(x.y) # N: Revealed type is "builtins.int" + +[builtins fixtures/isinstance.pyi] diff --git a/test-data/unit/check-redefine.test b/test-data/unit/check-redefine.test index b7642d30efc8..e162bb73a206 100644 --- a/test-data/unit/check-redefine.test +++ b/test-data/unit/check-redefine.test @@ -321,7 +321,7 @@ def f() -> None: x = 1 if int(): x = '' - reveal_type(x) # N: Revealed type is "builtins.object" + reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.int]" x = '' reveal_type(x) # N: Revealed type is "builtins.str" if int(): From 0e9f99fdd0236f0001a8a4b45c9c0aceb4cd7954 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Nov 2024 16:44:07 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/binder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mypy/binder.py b/mypy/binder.py index 10aa11cc6206..3c8116bb135f 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -6,7 +6,6 @@ from typing_extensions import TypeAlias as _TypeAlias from mypy.erasetype import remove_instance_last_known_values -from mypy.join import join_simple from mypy.literals import Key, literal, literal_hash, subkeys from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var from mypy.subtypes import is_same_type, is_subtype From d70a727203d7e509f5e194a362cbaf867d7df21f Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sun, 1 Dec 2024 11:24:22 +0100 Subject: [PATCH 3/8] much softer approach --- mypy/binder.py | 17 +---------------- mypy/join.py | 11 +++++++++-- test-data/unit/check-narrowing.test | 3 +-- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index 10aa11cc6206..52ae9774e6d4 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -10,7 +10,6 @@ from mypy.literals import Key, literal, literal_hash, subkeys from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var from mypy.subtypes import is_same_type, is_subtype -from mypy.typeops import make_simplified_union from mypy.types import ( AnyType, Instance, @@ -238,22 +237,8 @@ def update_from_options(self, frames: list[Frame]) -> bool: type = AnyType(TypeOfAny.from_another_any, source_any=declaration_type) else: for other in resulting_values[1:]: - assert other is not None - - if ( - isinstance(t1 := get_proper_type(type), TupleType) - and isinstance(t2 := get_proper_type(other.type), TupleType) - and (len(l1 := t1.items) == len(l2 := t2.items)) - and (find_unpack_in_list(l1) is None) - and (find_unpack_in_list(l2) is None) - ): - type = t1.copy_modified( - items=[make_simplified_union([i1, i2]) for i1, i2 in zip(l1, l2)] - ) - else: - type = make_simplified_union([type, other.type]) - + type = join_simple(self.declarations[key], type, other.type) # Try simplifying resulting type for unions involving variadic tuples. # Technically, everything is still valid without this step, but if we do # not do this, this may create long unions after exiting an if check like: diff --git a/mypy/join.py b/mypy/join.py index 865dd073d081..c25acfc44df7 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -17,6 +17,7 @@ is_protocol_implementation, is_subtype, ) +from mypy.typeops import make_simplified_union from mypy.types import ( AnyType, CallableType, @@ -54,7 +55,8 @@ class InstanceJoiner: - def __init__(self) -> None: + def __init__(self, prefer_union_over_supertype: bool = False) -> None: + self.prefer_union_over_supertype: bool = prefer_union_over_supertype self.seen_instances: list[tuple[Instance, Instance]] = [] def join_instances(self, t: Instance, s: Instance) -> ProperType: @@ -164,6 +166,9 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType: if is_subtype(p, t): return join_types(t, p, self) + if self.prefer_union_over_supertype: + return make_simplified_union([t, s]) + # Compute the "best" supertype of t when joined with s. # The definition of "best" may evolve; for now it is the one with # the longest MRO. Ties are broken by using the earlier base. @@ -224,7 +229,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType: if isinstance(s, UnionType) and not isinstance(t, UnionType): s, t = t, s - value = t.accept(TypeJoinVisitor(s)) + value = t.accept( + TypeJoinVisitor(s, instance_joiner=InstanceJoiner(prefer_union_over_supertype=True)) + ) if declaration is None or is_subtype(value, declaration): return value diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 10f5d30e6c23..4ad6093b5aad 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2353,8 +2353,7 @@ def fn_while(arg: T) -> None: return None [builtins fixtures/primitives.pyi] - -[case testNarrowingIsinstanceCreatesUnion] +[case testNarrowingInstancesCreatesUnion] class A: ... class B(A): y: int From bf8e8b1056e2b2b3f2cbf1309ad88ffd9624a35a Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Sun, 1 Dec 2024 12:53:42 +0100 Subject: [PATCH 4/8] fix --- mypy/binder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mypy/binder.py b/mypy/binder.py index 2517d1bacc11..52ae9774e6d4 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -6,6 +6,7 @@ from typing_extensions import TypeAlias as _TypeAlias from mypy.erasetype import remove_instance_last_known_values +from mypy.join import join_simple from mypy.literals import Key, literal, literal_hash, subkeys from mypy.nodes import Expression, IndexExpr, MemberExpr, NameExpr, RefExpr, TypeInfo, Var from mypy.subtypes import is_same_type, is_subtype From 31a04014fac592ad0056a74d9d04c74c9642ab9c Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Mon, 2 Dec 2024 06:17:44 +0100 Subject: [PATCH 5/8] module import --- mypy/join.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mypy/join.py b/mypy/join.py index c25acfc44df7..8f88fead503b 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -17,7 +17,6 @@ is_protocol_implementation, is_subtype, ) -from mypy.typeops import make_simplified_union from mypy.types import ( AnyType, CallableType, @@ -167,7 +166,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType: return join_types(t, p, self) if self.prefer_union_over_supertype: - return make_simplified_union([t, s]) + return mypy.typeops.make_simplified_union([t, s]) # Compute the "best" supertype of t when joined with s. # The definition of "best" may evolve; for now it is the one with From 121d8e1e64676f1606e3f0f94273221d205dd61d Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Mon, 2 Dec 2024 06:27:01 +0100 Subject: [PATCH 6/8] add error message to assert --- mypy/join.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/join.py b/mypy/join.py index 8f88fead503b..ca67d815538b 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -612,7 +612,7 @@ def visit_tuple_type(self, t: TupleType) -> ProperType: fallback = self.instance_joiner.join_instances( mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t) ) - assert isinstance(fallback, Instance) + assert isinstance(fallback, Instance), f"s = {self.s}, t = {t}, f = {fallback}" items = self.join_tuples(self.s, t) if items is not None: return TupleType(items, fallback) From 2c1b1a10b0107b4ab2f566b0cfd0e91e148932d2 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Mon, 2 Dec 2024 22:30:19 +0100 Subject: [PATCH 7/8] TupleType fallbacks as Instances, not Unions --- mypy/join.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mypy/join.py b/mypy/join.py index ca67d815538b..867ee636997b 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -607,12 +607,17 @@ def visit_tuple_type(self, t: TupleType) -> ProperType: # * Joining with any Sequence also returns a Sequence: # Tuple[int, bool] + List[bool] becomes Sequence[int] if isinstance(self.s, TupleType): + if self.instance_joiner is None: self.instance_joiner = InstanceJoiner() + prefer_union = self.instance_joiner.prefer_union_over_supertype + self.instance_joiner.prefer_union_over_supertype = False fallback = self.instance_joiner.join_instances( mypy.typeops.tuple_fallback(self.s), mypy.typeops.tuple_fallback(t) ) - assert isinstance(fallback, Instance), f"s = {self.s}, t = {t}, f = {fallback}" + assert isinstance(fallback, Instance) + self.instance_joiner.prefer_union_over_supertype = prefer_union + items = self.join_tuples(self.s, t) if items is not None: return TupleType(items, fallback) From 43ed94e692eb56bec169f0cec75e13b5fce466c8 Mon Sep 17 00:00:00 2001 From: Christoph Tyralla Date: Mon, 2 Dec 2024 22:43:52 +0100 Subject: [PATCH 8/8] Add test case that would crash without the previous commit --- test-data/unit/check-narrowing.test | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 4ad6093b5aad..6a8aa5282a49 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2372,3 +2372,35 @@ def f(x: A): reveal_type(x.y) # N: Revealed type is "builtins.int" [builtins fixtures/isinstance.pyi] + +[case testNarrowingDoNotNarrowNamedTupleFallbacksToUnions] + +from typing import List, NamedTuple, Union + +class A(NamedTuple): + x: int +class B(NamedTuple): + x: int + y: int +class C(NamedTuple): + x: int + y: int + z: int + +def f() -> bool: ... + +def g() -> None: + l: List[Union[A, B, C]] + if f(): + assert isinstance(l[0], A) + reveal_type(l[0]) # N: Revealed type is "Tuple[builtins.int, fallback=__main__.A]" + elif f(): + assert isinstance(l[0], B) + reveal_type(l[0]) # N: Revealed type is "Tuple[builtins.int, builtins.int, fallback=__main__.B]" + else: + assert False + reveal_type(l[0]) # N: Revealed type is "Union[Tuple[builtins.int, fallback=__main__.A], \ + Tuple[builtins.int, builtins.int, fallback=__main__.B], \ + Tuple[builtins.int, builtins.int, builtins.int, fallback=__main__.C]]" + +[builtins fixtures/tuple.pyi]