diff --git a/mypy/typeops.py b/mypy/typeops.py index b26aa8b3ea73..eae6b0c39af5 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -5,7 +5,7 @@ since these may assume that MROs are ready. """ -from typing import cast, Optional, List, Sequence, Set +from typing import cast, Optional, List, Sequence, Set, Dict import sys from mypy.types import ( @@ -300,6 +300,11 @@ def make_simplified_union(items: Sequence[Type], * [int, int] -> int * [int, Any] -> Union[int, Any] (Any types are not simplified away!) * [Any, Any] -> Any + * [Literal[Foo.A], Literal[Foo.B]] -> Foo (assuming Foo is a enum with two variants A and B) + + Note that we only collapse enum literals into the original enum when all literal variants + are present. Since enums are effectively final and there are a fixed number of possible + variants, it's safe to treat those two types as equivalent. Note: This must NOT be used during semantic analysis, since TypeInfos may not be fully initialized. @@ -316,6 +321,8 @@ def make_simplified_union(items: Sequence[Type], from mypy.subtypes import is_proper_subtype + enums_found = {} # type: Dict[str, int] + enum_max_members = {} # type: Dict[str, int] removed = set() # type: Set[int] for i, ti in enumerate(items): if i in removed: continue @@ -327,13 +334,52 @@ def make_simplified_union(items: Sequence[Type], removed.add(j) cbt = cbt or tj.can_be_true cbf = cbf or tj.can_be_false + # if deleted subtypes had more general truthiness, use that if not ti.can_be_true and cbt: - items[i] = true_or_false(ti) + items[i] = ti = true_or_false(ti) elif not ti.can_be_false and cbf: - items[i] = true_or_false(ti) + items[i] = ti = true_or_false(ti) + + # Keep track of all enum Literal types we encounter, in case + # we can coalesce them together + if isinstance(ti, LiteralType) and ti.is_enum_literal(): + enum_name = ti.fallback.type.fullname() + if enum_name not in enum_max_members: + enum_max_members[enum_name] = len(get_enum_values(ti.fallback)) + enums_found[enum_name] = enums_found.get(enum_name, 0) + 1 + if isinstance(ti, Instance) and ti.type.is_enum: + enum_name = ti.type.fullname() + if enum_name not in enum_max_members: + enum_max_members[enum_name] = len(get_enum_values(ti)) + enums_found[enum_name] = enum_max_members[enum_name] + + enums_to_compress = {n for (n, c) in enums_found.items() if c >= enum_max_members[n]} + enums_encountered = set() # type: Set[str] + simplified_set = [] # type: List[ProperType] + for i, item in enumerate(items): + if i in removed: + continue + + # Try seeing if this is an enum or enum literal, and if it's + # one we should be collapsing away. + if isinstance(item, LiteralType): + instance = item.fallback # type: Optional[Instance] + elif isinstance(item, Instance): + instance = item + else: + instance = None + + if instance and instance.type.is_enum: + enum_name = instance.type.fullname() + if enum_name in enums_encountered: + continue + if enum_name in enums_to_compress: + simplified_set.append(instance) + enums_encountered.add(enum_name) + continue + simplified_set.append(item) - simplified_set = [items[i] for i in range(len(items)) if i not in removed] return UnionType.make_union(simplified_set, line, column) @@ -551,7 +597,7 @@ class Status(Enum): if isinstance(typ, UnionType): items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items] - return make_simplified_union(items) + return UnionType.make_union(items) elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname() == target_fullname: new_items = [] for name, symbol in typ.type.names.items(): @@ -566,7 +612,7 @@ class Status(Enum): # only using CPython, but we might as well for the sake of full correctness. if sys.version_info < (3, 7): new_items.sort(key=lambda lit: lit.value) - return make_simplified_union(new_items) + return UnionType.make_union(new_items) else: return typ @@ -578,7 +624,7 @@ def coerce_to_literal(typ: Type) -> ProperType: typ = get_proper_type(typ) if isinstance(typ, UnionType): new_items = [coerce_to_literal(item) for item in typ.items] - return make_simplified_union(new_items) + return UnionType.make_union(new_items) elif isinstance(typ, Instance): if typ.last_known_value: return typ.last_known_value diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 43355392098c..3ae8f59bf72c 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -629,6 +629,7 @@ elif x is Foo.C: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(x) # No output here: this branch is unreachable +reveal_type(x) # N: Revealed type is '__main__.Foo' if Foo.A is x: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -638,6 +639,7 @@ elif Foo.C is x: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(x) # No output here: this branch is unreachable +reveal_type(x) # N: Revealed type is '__main__.Foo' y: Foo if y is Foo.A: @@ -648,6 +650,7 @@ elif y is Foo.C: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(y) # No output here: this branch is unreachable +reveal_type(y) # N: Revealed type is '__main__.Foo' if Foo.A is y: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -657,6 +660,7 @@ elif Foo.C is y: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]' else: reveal_type(y) # No output here: this branch is unreachable +reveal_type(y) # N: Revealed type is '__main__.Foo' [builtins fixtures/bool.pyi] [case testEnumReachabilityChecksIndirect] @@ -686,6 +690,8 @@ if y is x: else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' if x is z: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -703,6 +709,8 @@ else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' reveal_type(z) # N: Revealed type is '__main__.Foo*' accepts_foo_a(z) +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(z) # N: Revealed type is '__main__.Foo*' if y is z: reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' @@ -718,6 +726,8 @@ if z is y: else: reveal_type(y) # No output: this branch is unreachable reveal_type(z) # No output: this branch is unreachable +reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]' +reveal_type(z) # N: Revealed type is '__main__.Foo*' [builtins fixtures/bool.pyi] [case testEnumReachabilityNoNarrowingForUnionMessiness] @@ -740,6 +750,8 @@ if x is y: else: reveal_type(x) # N: Revealed type is '__main__.Foo' reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' +reveal_type(x) # N: Revealed type is '__main__.Foo' +reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' if y is z: reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' @@ -747,6 +759,8 @@ if y is z: else: reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' +reveal_type(y) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B]]' +reveal_type(z) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]' [builtins fixtures/bool.pyi] [case testEnumReachabilityWithNone] @@ -764,16 +778,19 @@ if x: reveal_type(x) # N: Revealed type is '__main__.Foo' else: reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' +reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' if x is not None: reveal_type(x) # N: Revealed type is '__main__.Foo' else: reveal_type(x) # N: Revealed type is 'None' +reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' if x is Foo.A: reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]' else: reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]' +reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]' [builtins fixtures/bool.pyi] [case testEnumReachabilityWithMultipleEnums] @@ -793,18 +810,21 @@ if x1 is Foo.A: reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]' else: reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]' +reveal_type(x1) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' x2: Union[Foo, Bar] if x2 is Bar.A: reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]' else: reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]' +reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' x3: Union[Foo, Bar] if x3 is Foo.A or x3 is Bar.A: reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]' else: reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]' +reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' [builtins fixtures/bool.pyi] @@ -823,7 +843,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int: # E: Unsupported left operand type for + ("Empty") \ # N: Left operand is of type "Union[int, None, Empty]" if x is _empty: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + reveal_type(x) # N: Revealed type is '__main__.Empty' return 0 elif x is None: reveal_type(x) # N: Revealed type is 'None' @@ -870,7 +890,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int: # E: Unsupported left operand type for + ("Empty") \ # N: Left operand is of type "Union[int, None, Empty]" if x is _empty: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + reveal_type(x) # N: Revealed type is '__main__.Empty' return 0 elif x is None: reveal_type(x) # N: Revealed type is 'None' @@ -899,7 +919,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int: # E: Unsupported left operand type for + ("Empty") \ # N: Left operand is of type "Union[int, None, Empty]" if x is _empty: - reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]' + reveal_type(x) # N: Revealed type is '__main__.Empty' return 0 elif x is None: reveal_type(x) # N: Revealed type is 'None' @@ -908,3 +928,77 @@ def func(x: Union[int, None, Empty] = _empty) -> int: reveal_type(x) # N: Revealed type is 'builtins.int' return x + 2 [builtins fixtures/primitives.pyi] + +[case testEnumUnionCompression] +from typing import Union +from typing_extensions import Literal +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + C = 3 + +class Bar(Enum): + X = 1 + Y = 2 + +x1: Literal[Foo.A, Foo.B, Foo.B, Foo.B, 1, None] +assert x1 is not None +reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Foo.B], Literal[1]]' + +x2: Literal[1, Foo.A, Foo.B, Foo.C, None] +assert x2 is not None +reveal_type(x2) # N: Revealed type is 'Union[Literal[1], __main__.Foo]' + +x3: Literal[Foo.A, Foo.B, 1, Foo.C, Foo.C, Foo.C, None] +assert x3 is not None +reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, Literal[1]]' + +x4: Literal[Foo.A, Foo.B, Foo.C, Foo.C, Foo.C, None] +assert x4 is not None +reveal_type(x4) # N: Revealed type is '__main__.Foo' + +x5: Union[Literal[Foo.A], Foo, None] +assert x5 is not None +reveal_type(x5) # N: Revealed type is '__main__.Foo' + +x6: Literal[Foo.A, Bar.X, Foo.B, Bar.Y, Foo.C, None] +assert x6 is not None +reveal_type(x6) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]' + +# TODO: We should really simplify this down into just 'Bar' as well. +no_forcing: Literal[Bar.X, Bar.X, Bar.Y] +reveal_type(no_forcing) # N: Revealed type is 'Union[Literal[__main__.Bar.X], Literal[__main__.Bar.X], Literal[__main__.Bar.Y]]' + +[case testEnumUnionCompressionAssignment] +from typing_extensions import Literal +from enum import Enum + +class Foo(Enum): + A = 1 + B = 2 + +class Wrapper1: + def __init__(self, x: object, y: Foo) -> None: + if x: + if y is Foo.A: + pass + else: + pass + self.y = y + else: + self.y = y + reveal_type(self.y) # N: Revealed type is '__main__.Foo' + +class Wrapper2: + def __init__(self, x: object, y: Foo) -> None: + if x: + self.y = y + else: + if y is Foo.A: + pass + else: + pass + self.y = y + reveal_type(self.y) # N: Revealed type is '__main__.Foo'