Skip to content

Commit 7a2798c

Browse files
committed
Make enum type compatible with union of all enum item literals
For example, consider this enum: ``` class E(Enum): A = 1 B = 1 ``` This PR makes `E` compatible with `Literal[E.A, E.B]`. Also fix mutation of the argument list in `try_contracting_literals_in_union`. This fixes some regressions introduced in #9097.
1 parent 7189a23 commit 7a2798c

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

mypy/subtypes.py

+12
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,18 @@ def _is_subtype(left: Type, right: Type,
121121
ignore_declared_variance=ignore_declared_variance,
122122
ignore_promotions=ignore_promotions)
123123
for item in right.items)
124+
# Recombine rhs literal types, to make an enum type a subtype
125+
# of a union of all enum items as literal types. Only do it if
126+
# the previous check didn't succeed, since recombining can be
127+
# expensive.
128+
if not is_subtype_of_item and isinstance(left, Instance) and left.type.is_enum:
129+
right = UnionType(mypy.typeops.try_contracting_literals_in_union(right.items))
130+
is_subtype_of_item = any(is_subtype(orig_left, item,
131+
ignore_type_params=ignore_type_params,
132+
ignore_pos_arg_names=ignore_pos_arg_names,
133+
ignore_declared_variance=ignore_declared_variance,
134+
ignore_promotions=ignore_promotions)
135+
for item in right.items)
124136
# However, if 'left' is a type variable T, T might also have
125137
# an upper bound which is itself a union. This case will be
126138
# handled below by the SubtypeVisitor. We have to check both

mypy/typeops.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ class Status(Enum):
715715
return typ
716716

717717

718-
def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperType]:
718+
def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]:
719719
"""Contracts any literal types back into a sum type if possible.
720720
721721
Will replace the first instance of the literal with the sum type and
@@ -724,9 +724,10 @@ def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperTyp
724724
if we call `try_contracting_union(Literal[Color.RED, Color.BLUE, Color.YELLOW])`,
725725
this function will return Color.
726726
"""
727+
proper_types = [get_proper_type(typ) for typ in types]
727728
sum_types = {} # type: Dict[str, Tuple[Set[Any], List[int]]]
728729
marked_for_deletion = set()
729-
for idx, typ in enumerate(types):
730+
for idx, typ in enumerate(proper_types):
730731
if isinstance(typ, LiteralType):
731732
fullname = typ.fallback.type.fullname
732733
if typ.fallback.type.is_enum:
@@ -737,10 +738,10 @@ def try_contracting_literals_in_union(types: List[ProperType]) -> List[ProperTyp
737738
indexes.append(idx)
738739
if not literals:
739740
first, *rest = indexes
740-
types[first] = typ.fallback
741+
proper_types[first] = typ.fallback
741742
marked_for_deletion |= set(rest)
742-
return list(itertools.compress(types, [(i not in marked_for_deletion)
743-
for i in range(len(types))]))
743+
return list(itertools.compress(proper_types, [(i not in marked_for_deletion)
744+
for i in range(len(proper_types))]))
744745

745746

746747
def coerce_to_literal(typ: Type) -> Type:

test-data/unit/check-enum.test

+16
Original file line numberDiff line numberDiff line change
@@ -1333,3 +1333,19 @@ def f(x: Foo):
13331333
reveal_type(x) # N: Revealed type is "Union[Literal[__main__.Foo.B], Literal[__main__.Foo.C]]"
13341334

13351335
[builtins fixtures/bool.pyi]
1336+
1337+
[case testEnumTypeCompatibleWithLiteralUnion]
1338+
from enum import Enum
1339+
from typing_extensions import Literal
1340+
1341+
class E(Enum):
1342+
A = 1
1343+
B = 2
1344+
C = 3
1345+
1346+
e: E
1347+
a: Literal[E.A, E.B, E.C] = e
1348+
b: Literal[E.A, E.B] = e # E: Incompatible types in assignment (expression has type "E", variable has type "Union[Literal[E.A], Literal[E.B]]")
1349+
c: Literal[E.A, E.C] = e # E: Incompatible types in assignment (expression has type "E", variable has type "Union[Literal[E.A], Literal[E.C]]")
1350+
b = a # E: Incompatible types in assignment (expression has type "Union[Literal[E.A], Literal[E.B], Literal[E.C]]", variable has type "Union[Literal[E.A], Literal[E.B]]")
1351+
[builtins fixtures/bool.pyi]

0 commit comments

Comments
 (0)