Skip to content

Commit df63e3b

Browse files
committed
Narrow booleans to literals with identity check
1 parent 389a172 commit df63e3b

File tree

3 files changed

+56
-24
lines changed

3 files changed

+56
-24
lines changed

mypy/checker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4536,7 +4536,8 @@ def refine_identity_comparison_expression(self,
45364536

45374537
enum_name = None
45384538
target = get_proper_type(target)
4539-
if isinstance(target, LiteralType) and target.is_enum_literal():
4539+
if (isinstance(target, LiteralType) and
4540+
(target.is_enum_literal() or isinstance(target.value, bool))):
45404541
enum_name = target.fallback.type.fullname
45414542

45424543
target_type = [TypeRange(target, is_upper_bound=False)]

mypy/typeops.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -693,26 +693,32 @@ class Status(Enum):
693693
if isinstance(typ, UnionType):
694694
items = [try_expanding_enum_to_union(item, target_fullname) for item in typ.items]
695695
return make_simplified_union(items, contract_literals=False)
696-
elif isinstance(typ, Instance) and typ.type.is_enum and typ.type.fullname == target_fullname:
697-
new_items = []
698-
for name, symbol in typ.type.names.items():
699-
if not isinstance(symbol.node, Var):
700-
continue
701-
# Skip "_order_" and "__order__", since Enum will remove it
702-
if name in ("_order_", "__order__"):
703-
continue
704-
new_items.append(LiteralType(name, typ))
705-
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
706-
# insertion order only starting with Python 3.7. So, we sort these for older
707-
# versions of Python to help make tests deterministic.
708-
#
709-
# We could probably skip the sort for Python 3.6 since people probably run mypy
710-
# only using CPython, but we might as well for the sake of full correctness.
711-
if sys.version_info < (3, 7):
712-
new_items.sort(key=lambda lit: lit.value)
713-
return make_simplified_union(new_items, contract_literals=False)
714-
else:
715-
return typ
696+
elif isinstance(typ, Instance) and typ.type.fullname == target_fullname:
697+
if typ.type.is_enum:
698+
new_items = []
699+
for name, symbol in typ.type.names.items():
700+
if not isinstance(symbol.node, Var):
701+
continue
702+
# Skip "_order_" and "__order__", since Enum will remove it
703+
if name in ("_order_", "__order__"):
704+
continue
705+
new_items.append(LiteralType(name, typ))
706+
# SymbolTables are really just dicts, and dicts are guaranteed to preserve
707+
# insertion order only starting with Python 3.7. So, we sort these for older
708+
# versions of Python to help make tests deterministic.
709+
#
710+
# We could probably skip the sort for Python 3.6 since people probably run mypy
711+
# only using CPython, but we might as well for the sake of full correctness.
712+
if sys.version_info < (3, 7):
713+
new_items.sort(key=lambda lit: lit.value)
714+
return make_simplified_union(new_items, contract_literals=False)
715+
elif typ.type.fullname == "builtins.bool":
716+
return make_simplified_union(
717+
[LiteralType(True, typ), LiteralType(False, typ)],
718+
contract_literals=False
719+
)
720+
721+
return typ
716722

717723

718724
def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]:
@@ -730,9 +736,12 @@ def try_contracting_literals_in_union(types: Sequence[Type]) -> List[ProperType]
730736
for idx, typ in enumerate(proper_types):
731737
if isinstance(typ, LiteralType):
732738
fullname = typ.fallback.type.fullname
733-
if typ.fallback.type.is_enum:
739+
if typ.fallback.type.is_enum or isinstance(typ.value, bool):
734740
if fullname not in sum_types:
735-
sum_types[fullname] = (set(get_enum_values(typ.fallback)), [])
741+
sum_types[fullname] = (set(get_enum_values(typ.fallback))
742+
if typ.fallback.type.is_enum
743+
else set((True, False)),
744+
[])
736745
literals, indexes = sum_types[fullname]
737746
literals.discard(typ.value)
738747
indexes.append(idx)

test-data/unit/check-narrowing.test

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1026,8 +1026,30 @@ else:
10261026
if str_or_bool_literal is not True and str_or_bool_literal is not False:
10271027
reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.str"
10281028
else:
1029-
reveal_type(str_or_bool_literal) # N: Revealed type is "Union[Literal[False], Literal[True]]"
1029+
reveal_type(str_or_bool_literal) # N: Revealed type is "builtins.bool"
1030+
[builtins fixtures/primitives.pyi]
1031+
1032+
[case testNarrowingBooleanIdentityCheck]
1033+
# flags: --strict-optional
1034+
from typing import Optional
1035+
from typing_extensions import Literal
1036+
1037+
bool_val: bool
10301038

1039+
if bool_val is not False:
1040+
reveal_type(bool_val) # N: Revealed type is "Literal[True]"
1041+
else:
1042+
reveal_type(bool_val) # N: Revealed type is "Literal[False]"
1043+
1044+
opt_bool_val: Optional[bool]
1045+
1046+
if opt_bool_val is not None:
1047+
reveal_type(opt_bool_val) # N: Revealed type is "builtins.bool"
1048+
1049+
if opt_bool_val is not False:
1050+
reveal_type(opt_bool_val) # N: Revealed type is "Union[Literal[True], None]"
1051+
else:
1052+
reveal_type(opt_bool_val) # N: Revealed type is "Literal[False]"
10311053
[builtins fixtures/primitives.pyi]
10321054

10331055
[case testNarrowingTypedDictUsingEnumLiteral]

0 commit comments

Comments
 (0)