Skip to content

Commit 77622f9

Browse files
committed
Add support for narrowing Literals using equality
This pull request (finally) adds support for narrowing expressions using Literal types by equality, instead of just identity. For example, the following "tagged union" pattern is now supported: ```python class Foo(TypedDict): key: Literal["A"] blah: int class Bar(TypedDict): key: Literal["B"] something: str x: Union[Foo, Bar] if x.key == "A": reveal_type(x) # Revealed type is 'Foo' else: reveal_type(x) # Revealed type is 'Bar' ``` Previously, this was possible to do only with Enum Literals and the `is` operator, which is perhaps not very intuitive. The main limitation with this pull request is that it'll perform narrowing only if either the LHS or RHS contains an explicit Literal type somewhere. If this limitation is not present, we end up breaking a decent amount of real-world code -- mostly tests -- that do something like this: ```python def some_test_case() -> None: worker = Worker() # Without the limitation, we narrow 'worker.state' to # Literal['ready'] in this assert... assert worker.state == 'ready' worker.start() # ...which subsequently causes this second assert to narrow # worker.state to <uninhabited>, causing the last line to be # unreachable. assert worker.state == 'running' worker.query() ``` I tried for several weeks to find a more intelligent way around this problem, but everything I tried ended up being either insufficient or super-hacky, so I gave up and went for this brute-force solution. The other main limitation is that we perform narrowing only if both the LHS and RHS do not define custom `__eq__` or `__ne__` methods, but this seems like a more reasonable one to me. Resolves #7944.
1 parent 91d5bf9 commit 77622f9

File tree

6 files changed

+491
-126
lines changed

6 files changed

+491
-126
lines changed

mypy/checker.py

Lines changed: 110 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from typing import (
88
Dict, Set, List, cast, Tuple, TypeVar, Union, Optional, NamedTuple, Iterator, Sequence,
9-
Mapping,
9+
Mapping, Callable
1010
)
1111
from typing_extensions import Final
1212

@@ -50,7 +50,8 @@
5050
erase_def_to_union_or_bound, erase_to_union_or_bound, coerce_to_literal,
5151
try_getting_str_literals_from_type, try_getting_int_literals_from_type,
5252
tuple_fallback, is_singleton_type, try_expanding_enum_to_union,
53-
true_only, false_only, function_type, TypeVarExtractor,
53+
true_only, false_only, function_type, TypeVarExtractor, custom_special_method,
54+
is_literal_type_like,
5455
)
5556
from mypy import message_registry
5657
from mypy.subtypes import (
@@ -3844,20 +3845,59 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
38443845

38453846
partial_type_maps = []
38463847
for operator, expr_indices in simplified_operator_list:
3847-
if operator in {'is', 'is not'}:
3848-
if_map, else_map = self.refine_identity_comparison_expression(
3849-
operands,
3850-
operand_types,
3851-
expr_indices,
3852-
narrowable_operand_indices,
3853-
)
3854-
elif operator in {'==', '!='}:
3855-
if_map, else_map = self.refine_equality_comparison_expression(
3856-
operands,
3857-
operand_types,
3858-
expr_indices,
3859-
narrowable_operand_indices,
3860-
)
3848+
if operator in {'is', 'is not', '==', '!='}:
3849+
# is_valid_target:
3850+
# Controls which types we're allowed to narrow exprs to. Note that
3851+
# we cannot use 'is_literal_type_like' in both cases since doing
3852+
# 'x = 10000 + 1; x is 10001' is not always True in all Python impls.
3853+
#
3854+
# coerce_only_in_literal_context:
3855+
# If true, coerce types into literal types only if one or more of
3856+
# the provided exprs contains an explicit Literal type. This could
3857+
# technically be set to any arbitrary value, but it seems being liberal
3858+
# with narrowing when using 'is' and conservative when using '==' seems
3859+
# to break the least amount of real-world code.
3860+
#
3861+
# should_narrow_by_identity:
3862+
# Set to 'false' only if the user defines custom __eq__ or __ne__ methods
3863+
# that could cause identity-based narrowing to produce invalid results.
3864+
if operator in {'is', 'is not'}:
3865+
is_valid_target = is_singleton_type # type: Callable[[Type], bool]
3866+
coerce_only_in_literal_context = False
3867+
should_narrow_by_identity = True
3868+
else:
3869+
is_valid_target = is_exactly_literal_type
3870+
coerce_only_in_literal_context = True
3871+
3872+
def has_no_custom_eq_checks(t: Type) -> bool:
3873+
return not custom_special_method(t, '__eq__', check_all=False) \
3874+
and not custom_special_method(t, '__ne__', check_all=False)
3875+
expr_types = [operand_types[i] for i in expr_indices]
3876+
should_narrow_by_identity = all(map(has_no_custom_eq_checks, expr_types))
3877+
3878+
if_map = {} # type: TypeMap
3879+
else_map = {} # type: TypeMap
3880+
if should_narrow_by_identity:
3881+
if_map, else_map = self.refine_identity_comparison_expression(
3882+
operands,
3883+
operand_types,
3884+
expr_indices,
3885+
narrowable_operand_indices,
3886+
is_valid_target,
3887+
coerce_only_in_literal_context,
3888+
)
3889+
3890+
# Strictly speaking, we should also skip this check if the objects in the expr
3891+
# chain have custom __eq__ or __ne__ methods. But we (maybe optimistically)
3892+
# assume nobody would actually create a custom objects that considers itself
3893+
# equal to None.
3894+
if if_map == {} and else_map == {}:
3895+
if_map, else_map = self.refine_away_none_in_comparison(
3896+
operands,
3897+
operand_types,
3898+
expr_indices,
3899+
narrowable_operand_indices,
3900+
)
38613901
elif operator in {'in', 'not in'}:
38623902
assert len(expr_indices) == 2
38633903
left_index, right_index = expr_indices
@@ -4100,8 +4140,10 @@ def refine_identity_comparison_expression(self,
41004140
operand_types: List[Type],
41014141
chain_indices: List[int],
41024142
narrowable_operand_indices: Set[int],
4143+
is_valid_target: Callable[[ProperType], bool],
4144+
coerce_only_in_literal_context: bool,
41034145
) -> Tuple[TypeMap, TypeMap]:
4104-
"""Produces conditional type maps refining expressions used in an identity comparison.
4146+
"""Produces conditional type maps refining exprs used in an identity/equality comparison.
41054147
41064148
The 'operands' and 'operand_types' lists should be the full list of operands used
41074149
in the overall comparison expression. The 'chain_indices' list is the list of indices
@@ -4117,49 +4159,65 @@ def refine_identity_comparison_expression(self,
41174159
The 'narrowable_operand_indices' parameter is the set of all indices we are allowed
41184160
to refine the types of: that is, all operands that will potentially be a part of
41194161
the output TypeMaps.
4162+
4163+
Although this function could theoretically try setting the types of the operands
4164+
in the chains to the meet, doing that causes too many issues in real-world code.
4165+
Instead, we use 'is_valid_target' to identify which of the given chain types
4166+
we could plausibly use as the refined type for the expressions in the chain.
4167+
4168+
Similarly, 'coerce_only_in_literal_context' controls whether we should try coercing
4169+
expressions in the chain to a Literal type. Performing this coercion is sometimes
4170+
too aggressive of a narrowing, depending on context.
41204171
"""
4121-
singleton = None # type: Optional[ProperType]
4122-
possible_singleton_indices = []
4172+
should_coerce = True
4173+
if coerce_only_in_literal_context:
4174+
should_coerce = any(is_literal_type_like(operand_types[i]) for i in chain_indices)
4175+
4176+
target = None # type: Optional[Type]
4177+
possible_target_indices = []
41234178
for i in chain_indices:
4124-
coerced_type = coerce_to_literal(operand_types[i])
4125-
if not is_singleton_type(coerced_type):
4179+
expr_type = operand_types[i]
4180+
if should_coerce:
4181+
expr_type = coerce_to_literal(expr_type)
4182+
if not is_valid_target(get_proper_type(expr_type)):
41264183
continue
4127-
if singleton and not is_same_type(singleton, coerced_type):
4128-
# We have multiple disjoint singleton types. So the 'if' branch
4184+
if target and not is_same_type(target, expr_type):
4185+
# We have multiple disjoint target types. So the 'if' branch
41294186
# must be unreachable.
41304187
return None, {}
4131-
singleton = coerced_type
4132-
possible_singleton_indices.append(i)
4188+
target = expr_type
4189+
possible_target_indices.append(i)
41334190

4134-
# There's nothing we can currently infer if none of the operands are singleton types,
4191+
# There's nothing we can currently infer if none of the operands are valid targets,
41354192
# so we end early and infer nothing.
4136-
if singleton is None:
4193+
if target is None:
41374194
return {}, {}
41384195

4139-
# If possible, use an unassignable expression as the singleton.
4140-
# We skip refining the type of the singleton below, so ideally we'd
4196+
# If possible, use an unassignable expression as the target.
4197+
# We skip refining the type of the target below, so ideally we'd
41414198
# want to pick an expression we were going to skip anyways.
41424199
singleton_index = -1
4143-
for i in possible_singleton_indices:
4200+
for i in possible_target_indices:
41444201
if i not in narrowable_operand_indices:
41454202
singleton_index = i
41464203

41474204
# Oh well, give up and just arbitrarily pick the last item.
41484205
if singleton_index == -1:
4149-
singleton_index = possible_singleton_indices[-1]
4206+
singleton_index = possible_target_indices[-1]
41504207

41514208
enum_name = None
4152-
if isinstance(singleton, LiteralType) and singleton.is_enum_literal():
4153-
enum_name = singleton.fallback.type.fullname
4209+
target = get_proper_type(target)
4210+
if isinstance(target, LiteralType) and target.is_enum_literal():
4211+
enum_name = target.fallback.type.fullname
41544212

4155-
target_type = [TypeRange(singleton, is_upper_bound=False)]
4213+
target_type = [TypeRange(target, is_upper_bound=False)]
41564214

41574215
partial_type_maps = []
41584216
for i in chain_indices:
4159-
# If we try refining a singleton against itself, conditional_type_map
4217+
# If we try refining a type against itself, conditional_type_map
41604218
# will end up assuming that the 'else' branch is unreachable. This is
41614219
# typically not what we want: generally the user will intend for the
4162-
# singleton type to be some fixed 'sentinel' value and will want to refine
4220+
# target type to be some fixed 'sentinel' value and will want to refine
41634221
# the other exprs against this one instead.
41644222
if i == singleton_index:
41654223
continue
@@ -4177,17 +4235,16 @@ def refine_identity_comparison_expression(self,
41774235

41784236
return reduce_partial_type_maps(partial_type_maps)
41794237

4180-
def refine_equality_comparison_expression(self,
4181-
operands: List[Expression],
4182-
operand_types: List[Type],
4183-
chain_indices: List[int],
4184-
narrowable_operand_indices: Set[int],
4185-
) -> Tuple[TypeMap, TypeMap]:
4186-
"""Produces conditional type maps refining expressions used in an equality comparison.
4238+
def refine_away_none_in_comparison(self,
4239+
operands: List[Expression],
4240+
operand_types: List[Type],
4241+
chain_indices: List[int],
4242+
narrowable_operand_indices: Set[int],
4243+
) -> Tuple[TypeMap, TypeMap]:
4244+
"""Produces conditional type maps refining away None in an identity/equality chain.
41874245
4188-
For more details, see the docstring of 'refine_equality_comparison' up above.
4189-
The only difference is that this function is for refining equality operations
4190-
(e.g. 'a == b == c') instead of identity ('a is b is c').
4246+
For more details about what the different arguments mean, see the
4247+
docstring of 'refine_identity_comparison_expression' up above.
41914248
"""
41924249
non_optional_types = []
41934250
for i in chain_indices:
@@ -4662,7 +4719,7 @@ def is_literal_enum(type_map: Mapping[Expression, Type], n: Expression) -> bool:
46624719
return False
46634720

46644721
parent_type = get_proper_type(parent_type)
4665-
member_type = coerce_to_literal(member_type)
4722+
member_type = get_proper_type(coerce_to_literal(member_type))
46664723
if not isinstance(parent_type, FunctionLike) or not isinstance(member_type, LiteralType):
46674724
return False
46684725

@@ -5252,3 +5309,9 @@ def has_bool_item(typ: ProperType) -> bool:
52525309
return any(is_named_instance(item, 'builtins.bool')
52535310
for item in typ.items)
52545311
return False
5312+
5313+
5314+
# TODO: why can't we define this as an inline function?
5315+
# Does mypyc not support them?
5316+
def is_exactly_literal_type(t: Type) -> bool:
5317+
return isinstance(get_proper_type(t), LiteralType)

mypy/checkexpr.py

Lines changed: 3 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr,
3333
TypeAliasExpr, BackquoteExpr, EnumCallExpr, TypeAlias, SymbolNode, PlaceholderNode,
3434
ARG_POS, ARG_OPT, ARG_NAMED, ARG_STAR, ARG_STAR2, LITERAL_TYPE, REVEAL_TYPE,
35-
SYMBOL_FUNCBASE_TYPES
3635
)
3736
from mypy.literals import literal
3837
from mypy import nodes
@@ -51,15 +50,16 @@
5150
from mypy import erasetype
5251
from mypy.checkmember import analyze_member_access, type_object_type
5352
from mypy.argmap import ArgTypeExpander, map_actuals_to_formals, map_formals_to_actuals
54-
from mypy.checkstrformat import StringFormatterChecker, custom_special_method
53+
from mypy.checkstrformat import StringFormatterChecker
5554
from mypy.expandtype import expand_type, expand_type_by_instance, freshen_function_type_vars
5655
from mypy.util import split_module_names
5756
from mypy.typevars import fill_typevars
5857
from mypy.visitor import ExpressionVisitor
5958
from mypy.plugin import Plugin, MethodContext, MethodSigContext, FunctionContext
6059
from mypy.typeops import (
6160
tuple_fallback, make_simplified_union, true_only, false_only, erase_to_union_or_bound,
62-
function_type, callable_type, try_getting_str_literals
61+
function_type, callable_type, try_getting_str_literals, custom_special_method,
62+
is_literal_type_like,
6363
)
6464
import mypy.errorcodes as codes
6565

@@ -4196,24 +4196,6 @@ def merge_typevars_in_callables_by_name(
41964196
return output, variables
41974197

41984198

4199-
def is_literal_type_like(t: Optional[Type]) -> bool:
4200-
"""Returns 'true' if the given type context is potentially either a LiteralType,
4201-
a Union of LiteralType, or something similar.
4202-
"""
4203-
t = get_proper_type(t)
4204-
if t is None:
4205-
return False
4206-
elif isinstance(t, LiteralType):
4207-
return True
4208-
elif isinstance(t, UnionType):
4209-
return any(is_literal_type_like(item) for item in t.items)
4210-
elif isinstance(t, TypeVarType):
4211-
return (is_literal_type_like(t.upper_bound)
4212-
or any(is_literal_type_like(item) for item in t.values))
4213-
else:
4214-
return False
4215-
4216-
42174199
def try_getting_literal(typ: Type) -> ProperType:
42184200
"""If possible, get a more precise literal type for a given type."""
42194201
typ = get_proper_type(typ)
@@ -4235,29 +4217,6 @@ def is_expr_literal_type(node: Expression) -> bool:
42354217
return False
42364218

42374219

4238-
def custom_equality_method(typ: Type) -> bool:
4239-
"""Does this type have a custom __eq__() method?"""
4240-
typ = get_proper_type(typ)
4241-
if isinstance(typ, Instance):
4242-
method = typ.type.get('__eq__')
4243-
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
4244-
if method.node.info:
4245-
return not method.node.info.fullname.startswith('builtins.')
4246-
return False
4247-
if isinstance(typ, UnionType):
4248-
return any(custom_equality_method(t) for t in typ.items)
4249-
if isinstance(typ, TupleType):
4250-
return custom_equality_method(tuple_fallback(typ))
4251-
if isinstance(typ, CallableType) and typ.is_type_obj():
4252-
# Look up __eq__ on the metaclass for class objects.
4253-
return custom_equality_method(typ.fallback)
4254-
if isinstance(typ, AnyType):
4255-
# Avoid false positives in uncertain cases.
4256-
return True
4257-
# TODO: support other types (see ExpressionChecker.has_member())?
4258-
return False
4259-
4260-
42614220
def has_bytes_component(typ: Type, py2: bool = False) -> bool:
42624221
"""Is this one of builtin byte types, or a union that contains it?"""
42634222
typ = get_proper_type(typ)

mypy/checkstrformat.py

Lines changed: 3 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020
from mypy.types import (
2121
Type, AnyType, TupleType, Instance, UnionType, TypeOfAny, get_proper_type, TypeVarType,
22-
CallableType, LiteralType, get_proper_types
22+
LiteralType, get_proper_types
2323
)
2424
from mypy.nodes import (
2525
StrExpr, BytesExpr, UnicodeExpr, TupleExpr, DictExpr, Context, Expression, StarExpr, CallExpr,
2626
IndexExpr, MemberExpr, TempNode, ARG_POS, ARG_STAR, ARG_NAMED, ARG_STAR2,
27-
SYMBOL_FUNCBASE_TYPES, Decorator, Var, Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
27+
Node, MypyFile, ExpressionStmt, NameExpr, IntExpr
2828
)
2929
import mypy.errorcodes as codes
3030

@@ -35,7 +35,7 @@
3535
from mypy import message_registry
3636
from mypy.messages import MessageBuilder
3737
from mypy.maptype import map_instance_to_supertype
38-
from mypy.typeops import tuple_fallback
38+
from mypy.typeops import custom_special_method
3939
from mypy.subtypes import is_subtype
4040
from mypy.parse import parse
4141

@@ -961,32 +961,3 @@ def has_type_component(typ: Type, fullname: str) -> bool:
961961
elif isinstance(typ, UnionType):
962962
return any(has_type_component(t, fullname) for t in typ.relevant_items())
963963
return False
964-
965-
966-
def custom_special_method(typ: Type, name: str,
967-
check_all: bool = False) -> bool:
968-
"""Does this type have a custom special method such as __format__() or __eq__()?
969-
970-
If check_all is True ensure all items of a union have a custom method, not just some.
971-
"""
972-
typ = get_proper_type(typ)
973-
if isinstance(typ, Instance):
974-
method = typ.type.get(name)
975-
if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)):
976-
if method.node.info:
977-
return not method.node.info.fullname.startswith('builtins.')
978-
return False
979-
if isinstance(typ, UnionType):
980-
if check_all:
981-
return all(custom_special_method(t, name, check_all) for t in typ.items)
982-
return any(custom_special_method(t, name) for t in typ.items)
983-
if isinstance(typ, TupleType):
984-
return custom_special_method(tuple_fallback(typ), name)
985-
if isinstance(typ, CallableType) and typ.is_type_obj():
986-
# Look up __method__ on the metaclass for class objects.
987-
return custom_special_method(typ.fallback, name)
988-
if isinstance(typ, AnyType):
989-
# Avoid false positives in uncertain cases.
990-
return True
991-
# TODO: support other types (see ExpressionChecker.has_member())?
992-
return False

0 commit comments

Comments
 (0)