From 68656e8770402abf503d0c70085950449b7e4307 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Thu, 7 Nov 2019 14:58:05 -0800 Subject: [PATCH 1/3] Force enum literals to simplify when inferring unions While working on overhauling https://github.com/python/mypy/pull/7169, I discovered that simply "deconstructing" enums into unions leads to some false positives in some real-world code. This is an existing problem, but it became more prominent as I worked on improving type inference in the above PR. Here's a simplified example of one such problem I ran into: ``` from enum import Enum class Foo(Enum): A = 1 B = 2 class Wrapper: def __init__(self, x: bool, y: Foo) -> None: if x: if y is Foo.A: # 'y' is of type Literal[Foo.A] here pass else: # ...and of type Literal[Foo.B] here pass # We join these two types after the if/else to end up with # Literal[Foo.A, Foo.B] self.y = y else: # ...and so this fails! 'Foo' is not considered a subtype of # 'Literal[Foo.A, Foo.B]' self.y = y ``` I considered three different ways of fixing this: 1. Modify our various type comparison operations (`is_same`, `is_subtype`, `is_proper_subtype`, etc.) to consider `Foo` and `Literal[Foo.A, Foo.B]` equivalent. 2. Modify the 'join' logic so that when we join enum literals, we check and see if we can merge them back into the original class, undoing the "deconstruction". 3. Modify the `make_simplified_union` logic to do the reconstruction instead. I rejected the first two options: the first approach is the most sound one, but seemed complicated to implement. We have a lot of different type comparison operations and attempting to modify them all seems error-prone. I also didn't really like the idea of having two equally valid representations of the same type, and would rather push mypy to always standardize on one, just from a usability point of view. The second option seemed workable but limited to me. 
Modifying join would fix the specific example above, but I wasn't confident that was the only place we'd need to patch. So I went with modifying `make_simplified_union` instead. The main disadvantage of this approach is that we still get false positives when working with Unions that come directly from the semantic analysis phase. For example, we still get an error with the following program: x: Literal[Foo.A, Foo.B] y: Foo # Error, we still think 'x' is of type 'Literal[Foo.A, Foo.B]' x = y But I think this is an acceptable tradeoff for now: I can't imagine too many people running into this. But if they do, we can always explore finding a way of simplifying unions after the semantic analysis phase or bite the bullet and implement approach 1. --- mypy/checker.py | 2 +- mypy/typeops.py | 54 ++++++++++++++++-- test-data/unit/check-enum.test | 100 ++++++++++++++++++++++++++++++++- 3 files changed, 148 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 557ceb8a71c0..651aeade9cda 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -48,7 +48,7 @@ from mypy.typeops import ( map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, erase_def_to_union_or_bound, erase_to_union_or_bound, - true_only, false_only, function_type, + true_only, false_only, function_type, get_enum_values, ) from mypy import message_registry from mypy.subtypes import ( diff --git a/mypy/typeops.py b/mypy/typeops.py index b26aa8b3ea73..15b793f8c306 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,7 +5,7 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set +from typing import cast, Optional, List, Sequence, Set, Dict import sys from mypy.types import ( @@ -300,6 +300,11 @@ def make_simplified_union(items: Sequence[Type], * [int, int] -> int * [int, Any] -> Union[int, Any] (Any types are not simplified away!) 
* [Any, Any] -> Any + * [Literal[Foo.A], Literal[Foo.B]] -> Foo (assuming Foo is a enum with two variants A and B) + + Note that we only collapse enum literals into the original enum when all literal variants + are present. Since enums are effectively final and there are a fixed number of possible + variants, it's safe to treat those two types as equivalent. Note: This must NOT be used during semantic analysis, since TypeInfos may not be fully initialized. @@ -316,6 +321,8 @@ def make_simplified_union(items: Sequence[Type], from mypy.subtypes import is_proper_subtype + enums_found = {} # type: Dict[str, int] + enum_max_members = {} # type: Dict[str, int] removed = set() # type: Set[int] for i, ti in enumerate(items): if i in removed: continue @@ -327,13 +334,52 @@ def make_simplified_union(items: Sequence[Type], removed.add(j) cbt = cbt or tj.can_be_true cbf = cbf or tj.can_be_false + # if deleted subtypes had more general truthiness, use that if not ti.can_be_true and cbt: - items[i] = true_or_false(ti) + items[i] = ti = true_or_false(ti) elif not ti.can_be_false and cbf: - items[i] = true_or_false(ti) + items[i] = ti = true_or_false(ti) + + # Keep track of all enum Literal types we encounter, in case + # we can coalesce them together + if isinstance(ti, LiteralType) and ti.is_enum_literal(): + enum_name = ti.fallback.type.fullname() + if enum_name not in enum_max_members: + enum_max_members[enum_name] = len(get_enum_values(ti.fallback)) + enums_found[enum_name] = enums_found.get(enum_name, 0) + 1 + if isinstance(ti, Instance) and ti.type.is_enum: + enum_name = ti.type.fullname() + if enum_name not in enum_max_members: + enum_max_members[enum_name] = len(get_enum_values(ti)) + enums_found[enum_name] = enum_max_members[enum_name] + + enums_to_compress = {n for (n, c) in enums_found.items() if c >= enum_max_members[n]} + enums_encountered = set() # type: Set[str] + simplified_set = [] # type: List[ProperType] + for i, item in enumerate(items): + if i in removed: + 
continue + + # Try seeing if this is an enum or enum literal, and if it's + # one we should be collapsing away. + if isinstance(item, LiteralType): + instance = item.fallback # type: Optional[Instance] + elif isinstance(item, Instance): + instance = item + else: + instance = None + + if instance and instance.type.is_enum: + enum_name = instance.type.fullname() + if enum_name in enums_encountered: + continue + if enum_name in enums_to_compress: + simplified_set.append(instance) + enums_encountered.add(enum_name) + continue + simplified_set.append(item) - simplified_set = [items[i] for i in range(len(items)) if i not in removed] return UnionType.make_union(simplified_set, line, column) diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 43355392098c..3ae8f59bf72c 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -629,6 +629,7 @@ elif x is Foo.C: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(x) # No output here: this branch is unreachable +reveal_type(x) # N: Revealed type is '__main__.Foo' if Foo.A is x: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -638,6 +639,7 @@ elif Foo.C is x: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(x) # No output here: this branch is unreachable +reveal_type(x) # N: Revealed type is '__main__.Foo' y: Foo if y is Foo.A: @@ -648,6 +650,7 @@ elif y is Foo.C: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(y) # No output here: this branch is unreachable +reveal_type(y) # N: Revealed type is '__main__.Foo' if Foo.A is y: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -657,6 +660,7 @@ elif Foo.C is y: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(y) # No output here: this branch is unreachable +reveal_type(y) # N: Revealed type is '__main__.Foo' [builtins fixtures/bool.pyi] [case 
testEnumReachabilityChecksIndirect] @@ -686,6 +690,8 @@ if y is x: else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' if x is z: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -703,6 +709,8 @@ else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' reveal_type(z) # N: Revealed type is '__main__.Foo*' accepts_foo_a(z) +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(z) # N: Revealed type is '__main__.Foo*' if y is z: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -718,6 +726,8 @@ if z is y: else: reveal_type(y) # No output: this branch is unreachable reveal_type(z) # No output: this branch is unreachable +reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +reveal_type(z) # N: Revealed type is '__main__.Foo*' [builtins fixtures/bool.pyi] [case testEnumReachabilityNoNarrowingForUnionMessiness] @@ -740,6 +750,8 @@ if x is y: else: reveal_type(x) # N: Revealed type is '__main__.Foo' reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' if y is z: reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' @@ -747,6 +759,8 @@ if y is z: else: reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' +reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' +reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], 
Literal[__main__.Foo.C]]' [builtins fixtures/bool.pyi] [case testEnumReachabilityWithNone] @@ -764,16 +778,19 @@ if x: reveal_type(x) # N: Revealed type is '__main__.Foo' else: reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' +reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' if x is not None: reveal_type(x) # N: Revealed type is '__main__.Foo' else: reveal_type(x) # N: Revealed type is 'None' +reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' if x is Foo.A: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]' +reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' [builtins fixtures/bool.pyi] [case testEnumReachabilityWithMultipleEnums] @@ -793,18 +810,21 @@ if x1 is Foo.A: reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]' else: reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]' +reveal_type(x1) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' x2: Union[Foo, Bar] if x2 is Bar.A: reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]' else: reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]' +reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' x3: Union[Foo, Bar] if x3 is Foo.A or x3 is Bar.A: reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]' else: reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]' +reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' [builtins fixtures/bool.pyi] @@ -823,7 +843,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int: # E: Unsupported left operand type for + ("Empty") \ # N: Left operand is of type "Union[int, None, Empty]" if x is _empty: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + reveal_type(x) # N: 
Revealed type is '__main__.Empty' return 0 elif x is None: reveal_type(x) # N: Revealed type is 'None' @@ -870,7 +890,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int: # E: Unsupported left operand type for + ("Empty") \ # N: Left operand is of type "Union[int, None, Empty]" if x is _empty: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + reveal_type(x) # N: Revealed type is '__main__.Empty' return 0 elif x is None: reveal_type(x) # N: Revealed type is 'None' @@ -899,7 +919,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int: # E: Unsupported left operand type for + ("Empty") \ # N: Left operand is of type "Union[int, None, Empty]" if x is _empty: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + reveal_type(x) # N: Revealed type is '__main__.Empty' return 0 elif x is None: reveal_type(x) # N: Revealed type is 'None' @@ -908,3 +928,77 @@ def func(x: Union[int, None, Empty] = _empty) -> int: reveal_type(x) # N: Revealed type is 'builtins.int' return x + 2 [builtins fixtures/primitives.pyi] + +[case testEnumUnionCompression] +from typing import Union +from typing_extensions import Literal +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +class Bar(Enum): + X = 1 + Y = 2 + +x1: Literal[Foo.A, Foo.B, Foo.B, Foo.B, 1, None] +assert x1 is not None +reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[1]]' + +x2: Literal[1, Foo.A, Foo.B, Foo.C, None] +assert x2 is not None +reveal_type(x2) # N: Revealed type is 'Union[Literal[1], __main__.Foo]' + +x3: Literal[Foo.A, Foo.B, 1, Foo.C, Foo.C, Foo.C, None] +assert x3 is not None +reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, Literal[1]]' + +x4: Literal[Foo.A, Foo.B, Foo.C, Foo.C, Foo.C, None] +assert x4 is not None +reveal_type(x4) # N: Revealed type is '__main__.Foo' + +x5: Union[Literal[Foo.A], Foo, None] +assert x5 is not None +reveal_type(x5) # N: Revealed type is '__main__.Foo' 
+ +x6: Literal[Foo.A, Bar.X, Foo.B, Bar.Y, Foo.C, None] +assert x6 is not None +reveal_type(x6) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' + +# TODO: We should really simplify this down into just 'Bar' as well. +no_forcing: Literal[Bar.X, Bar.X, Bar.Y] +reveal_type(no_forcing) # N: Revealed type is 'Union[Literal[__main__.Bar.X], Literal[__main__.Bar.X], Literal[__main__.Bar.Y]]' + +[case testEnumUnionCompressionAssignment] +from typing_extensions import Literal +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + +class Wrapper1: + def __init__(self, x: object, y: Foo) -> None: + if x: + if y is Foo.A: + pass + else: + pass + self.y = y + else: + self.y = y + reveal_type(self.y) # N: Revealed type is '__main__.Foo' + +class Wrapper2: + def __init__(self, x: object, y: Foo) -> None: + if x: + self.y = y + else: + if y is Foo.A: + pass + else: + pass + self.y = y + reveal_type(self.y) # N: Revealed type is '__main__.Foo' From d350cce37857f1e342ceb627b7844e44961ca742 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Fri, 8 Nov 2019 17:41:50 -0800 Subject: [PATCH 2/3] Fix bad rebase --- mypy/typeops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mypy/typeops.py b/mypy/typeops.py index 15b793f8c306..eae6b0c39af5 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -597,7 +597,7 @@ class Status(Enum): if isinstance(typ, UnionType): items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items] - return make_simplified_union(items) + return UnionType.make_union(items) elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname: new_items = [] for name, symbol in typ.type.names.items(): @@ -612,7 +612,7 @@ class Status(Enum): # only using CPython, but we might as well for the sake of full correctness. 
if sys.version_info < (3, 7): new_items.sort(key=lambda lit: lit.value) - return make_simplified_union(new_items) + return UnionType.make_union(new_items) else: return typ @@ -624,7 +624,7 @@ def coerce_to_literal(typ: Type) -> ProperType: typ = get_proper_type(typ) if isinstance(typ, UnionType): new_items = [coerce_to_literal(item) for item in typ.items] - return make_simplified_union(new_items) + return UnionType.make_union(new_items) elif isinstance(typ, Instance): if typ.last_known_value: return typ.last_known_value From 78cbcc6f17f58d018f3870060045ddb8ad3b3b29 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Sat, 9 Nov 2019 19:39:11 -0800 Subject: [PATCH 3/3] Fix flake8 --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 651aeade9cda..557ceb8a71c0 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -48,7 +48,7 @@ from mypy.typeops import ( map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union, erase_def_to_union_or_bound, erase_to_union_or_bound, - true_only, false_only, function_type, get_enum_values, + true_only, false_only, function_type, ) from mypy import message_registry from mypy.subtypes import (