Skip to content

Commit 9ce8a55

Browse files
committed
Recombine complete union of enum literals into original type (#9063)
1 parent 259e0cf commit 9ce8a55

File tree

3 files changed

+77
-9
lines changed

3 files changed

+77
-9
lines changed

mypy/join.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,9 @@ def visit_literal_type(self, t: LiteralType) -> ProperType:
310310
if isinstance(self.s, LiteralType):
311311
if t == self.s:
312312
return t
313-
else:
314-
return join_types(self.s.fallback, t.fallback)
313+
if self.s.fallback.type.is_enum and t.fallback.type.is_enum:
314+
return mypy.typeops.make_simplified_union([self.s, t])
315+
return join_types(self.s.fallback, t.fallback)
315316
else:
316317
return join_types(self.s, t.fallback)
317318

mypy/typeops.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
since these may assume that MROs are ready.
66
"""
77

8-
from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar
8+
from typing import cast, Optional, List, Sequence, Set, Iterable, TypeVar, Dict, Tuple, Any
99
from typing_extensions import Type as TypingType
10+
import itertools
1011
import sys
1112

1213
from mypy.types import (
@@ -313,7 +314,8 @@ def callable_corresponding_argument(typ: CallableType,
313314

314315
def make_simplified_union(items: Sequence[Type],
315316
line: int = -1, column: int = -1,
316-
*, keep_erased: bool = False) -> ProperType:
317+
*, keep_erased: bool = False,
318+
contract_literals: bool = True) -> ProperType:
317319
"""Build union type with redundant union items removed.
318320
319321
If only a single item remains, this may return a non-union type.
@@ -361,6 +363,9 @@ def make_simplified_union(items: Sequence[Type],
361363
items[i] = true_or_false(ti)
362364

363365
simplified_set = [items[i] for i in range(len(items)) if i not in removed]
366+
if contract_literals and any(isinstance(item, LiteralType) for item in simplified_set):
367+
return UnionType.make_union(
368+
try_contracting_literals_in_union(simplified_set), line, column)
364369
return UnionType.make_union(simplified_set, line, column)
365370

366371

@@ -637,7 +642,7 @@ class Status(Enum):
637642

638643
if isinstance(typ, UnionType):
639644
items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
640-
return make_simplified_union(items)
645+
return make_simplified_union(items, contract_literals=False)
641646
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname:
642647
new_items = []
643648
for name, symbol in typ.type.names.items():
@@ -655,11 +660,39 @@ class Status(Enum):
655660
# only using CPython, but we might as well for the sake of full correctness.
656661
if sys.version_info < (3, 7):
657662
new_items.sort(key=lambda lit: lit.value)
658-
return make_simplified_union(new_items)
663+
return make_simplified_union(new_items, contract_literals=False)
659664
else:
660665
return typ
661666

662667

668+
def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperType]:
669+
"""Contracts any literal types back into a sum type if possible.
670+
671+
Will replace the first instance of the literal with the sum type and
672+
remove all others.
673+
674+
if we call `try_contracting_union(Literal[Color.RED, Color.BLUE, Color.YELLOW])`,
675+
this function will return Color.
676+
"""
677+
sum_types = {} # type: Dict[str, Tuple[Set[Any], List[int]]]
678+
marked_for_deletion = set()
679+
for idx, typ in enumerate(types):
680+
if isinstance(typ, LiteralType):
681+
fullname = typ.fallback.type.fullname
682+
if typ.fallback.type.is_enum:
683+
if fullname not in sum_types:
684+
sum_types[fullname] = (set(get_enum_values(typ.fallback)), [])
685+
literals, indexes = sum_types[fullname]
686+
literals.discard(typ.value)
687+
indexes.append(idx)
688+
if not literals:
689+
first, *rest = indexes
690+
types[first] = typ.fallback
691+
marked_for_deletion |= set(rest)
692+
return list(itertools.compress(types, [(i not in marked_for_deletion)
693+
for i in range(len(types))]))
694+
695+
663696
def coerce_to_literal(typ: Type) -> Type:
664697
"""Recursively converts any Instances that have a last_known_value or are
665698
instances of enum types with a single value into the corresponding LiteralType.

test-data/unit/check-enum.test

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -632,6 +632,7 @@ elif x is Foo.C:
632632
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
633633
else:
634634
reveal_type(x) # No output here: this branch is unreachable
635+
reveal_type(y) # N: Revealed type is '__main__.Foo'
635636

636637
if Foo.A is x:
637638
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -641,6 +642,7 @@ elif Foo.C is x:
641642
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.C]'
642643
else:
643644
reveal_type(x) # No output here: this branch is unreachable
645+
reveal_type(y) # N: Revealed type is '__main__.Foo'
644646

645647
y: Foo
646648
if y is Foo.A:
@@ -651,6 +653,7 @@ elif y is Foo.C:
651653
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
652654
else:
653655
reveal_type(y) # No output here: this branch is unreachable
656+
reveal_type(y) # N: Revealed type is '__main__.Foo'
654657

655658
if Foo.A is y:
656659
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -660,6 +663,7 @@ elif Foo.C is y:
660663
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.C]'
661664
else:
662665
reveal_type(y) # No output here: this branch is unreachable
666+
reveal_type(y) # N: Revealed type is '__main__.Foo'
663667
[builtins fixtures/bool.pyi]
664668

665669
[case testEnumReachabilityChecksWithOrdering]
@@ -734,12 +738,14 @@ if x is y:
734738
else:
735739
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
736740
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
741+
reveal_type(x) # N: Revealed type is '__main__.Foo'
737742
if y is x:
738743
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
739744
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
740745
else:
741746
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
742747
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
748+
reveal_type(x) # N: Revealed type is '__main__.Foo'
743749

744750
if x is z:
745751
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -749,6 +755,7 @@ else:
749755
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
750756
reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?'
751757
accepts_foo_a(z)
758+
reveal_type(x) # N: Revealed type is '__main__.Foo'
752759
if z is x:
753760
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
754761
reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?'
@@ -757,6 +764,7 @@ else:
757764
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
758765
reveal_type(z) # N: Revealed type is 'Literal[__main__.Foo.A]?'
759766
accepts_foo_a(z)
767+
reveal_type(x) # N: Revealed type is '__main__.Foo'
760768

761769
if y is z:
762770
reveal_type(y) # N: Revealed type is 'Literal[__main__.Foo.A]'
@@ -828,6 +836,7 @@ if x is Foo.A:
828836
reveal_type(x) # N: Revealed type is 'Literal[__main__.Foo.A]'
829837
else:
830838
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C], None]'
839+
reveal_type(x) # N: Revealed type is 'Union[__main__.Foo, None]'
831840
[builtins fixtures/bool.pyi]
832841

833842
[case testEnumReachabilityWithMultipleEnums]
@@ -847,18 +856,21 @@ if x1 is Foo.A:
847856
reveal_type(x1) # N: Revealed type is 'Literal[__main__.Foo.A]'
848857
else:
849858
reveal_type(x1) # N: Revealed type is 'Union[Literal[__main__.Foo.B], __main__.Bar]'
859+
reveal_type(x1) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
850860

851861
x2: Union[Foo, Bar]
852862
if x2 is Bar.A:
853863
reveal_type(x2) # N: Revealed type is 'Literal[__main__.Bar.A]'
854864
else:
855865
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, Literal[__main__.Bar.B]]'
866+
reveal_type(x2) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
856867

857868
x3: Union[Foo, Bar]
858869
if x3 is Foo.A or x3 is Bar.A:
859870
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.A], Literal[__main__.Bar.A]]'
860871
else:
861872
reveal_type(x3) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Bar.B]]'
873+
reveal_type(x3) # N: Revealed type is 'Union[__main__.Foo, __main__.Bar]'
862874

863875
[builtins fixtures/bool.pyi]
864876

@@ -877,7 +889,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
877889
# E: Unsupported left operand type for + ("Empty") \
878890
# N: Left operand is of type "Union[int, None, Empty]"
879891
if x is _empty:
880-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]'
892+
reveal_type(x) # N: Revealed type is '__main__.Empty'
881893
return 0
882894
elif x is None:
883895
reveal_type(x) # N: Revealed type is 'None'
@@ -924,7 +936,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
924936
# E: Unsupported left operand type for + ("Empty") \
925937
# N: Left operand is of type "Union[int, None, Empty]"
926938
if x is _empty:
927-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]'
939+
reveal_type(x) # N: Revealed type is '__main__.Empty'
928940
return 0
929941
elif x is None:
930942
reveal_type(x) # N: Revealed type is 'None'
@@ -953,7 +965,7 @@ def func(x: Union[int, None, Empty] = _empty) -> int:
953965
# E: Unsupported left operand type for + ("Empty") \
954966
# N: Left operand is of type "Union[int, None, Empty]"
955967
if x is _empty:
956-
reveal_type(x) # N: Revealed type is 'Literal[__main__.Empty.token]'
968+
reveal_type(x) # N: Revealed type is '__main__.Empty'
957969
return 0
958970
elif x is None:
959971
reveal_type(x) # N: Revealed type is 'None'
@@ -1162,3 +1174,25 @@ class Comparator(enum.Enum):
11621174

11631175
reveal_type(Comparator.__foo__) # N: Revealed type is 'builtins.dict[builtins.str, builtins.int]'
11641176
[builtins fixtures/dict.pyi]
1177+
1178+
[case testEnumNarrowedToTwoLiterals]
1179+
# Regression test: two literals of an enum would be joined
1180+
# as the full type, regardless of the amount of elements
1181+
# the enum contains.
1182+
from enum import Enum
1183+
from typing import Union
1184+
from typing_extensions import Literal
1185+
1186+
class Foo(Enum):
1187+
A = 1
1188+
B = 2
1189+
C = 3
1190+
1191+
def f(x: Foo):
1192+
if x is Foo.A:
1193+
return x
1194+
if x is Foo.B:
1195+
pass
1196+
reveal_type(x) # N: Revealed type is 'Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]'
1197+
1198+
[builtins fixtures/bool.pyi]

0 commit comments

Comments
 (0)