Skip to content

Commit 8e82171

Browse files
authored
Improve type narrowing for walrus operator in conditional statements (#11202)
Authored-by: @kprzybyla <>
1 parent d469295 commit 8e82171

File tree

5 files changed

+184
-27
lines changed

5 files changed

+184
-27
lines changed

mypy/binder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
from mypy.erasetype import remove_instance_last_known_values
1414
from mypy.nodes import Expression, Var, RefExpr
1515
from mypy.literals import Key, literal, literal_hash, subkeys
16-
from mypy.nodes import IndexExpr, MemberExpr, NameExpr
16+
from mypy.nodes import IndexExpr, MemberExpr, AssignmentExpr, NameExpr
1717

1818

19-
BindableExpression = Union[IndexExpr, MemberExpr, NameExpr]
19+
BindableExpression = Union[IndexExpr, MemberExpr, AssignmentExpr, NameExpr]
2020

2121

2222
class Frame:
@@ -136,7 +136,7 @@ def _get(self, key: Key, index: int = -1) -> Optional[Type]:
136136
return None
137137

138138
def put(self, expr: Expression, typ: Type) -> None:
139-
if not isinstance(expr, (IndexExpr, MemberExpr, NameExpr)):
139+
if not isinstance(expr, (IndexExpr, MemberExpr, AssignmentExpr, NameExpr)):
140140
return
141141
if not literal(expr):
142142
return

mypy/checker.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4286,12 +4286,10 @@ def find_isinstance_check_helper(self, node: Expression) -> Tuple[TypeMap, TypeM
42864286
type_map = self.type_map
42874287
if is_true_literal(node):
42884288
return {}, None
4289-
elif is_false_literal(node):
4289+
if is_false_literal(node):
42904290
return None, {}
4291-
elif isinstance(node, CallExpr):
4292-
self._check_for_truthy_type(type_map[node], node)
4293-
if len(node.args) == 0:
4294-
return {}, {}
4291+
4292+
if isinstance(node, CallExpr) and len(node.args) != 0:
42954293
expr = collapse_walrus(node.args[0])
42964294
if refers_to_fullname(node.callee, 'builtins.isinstance'):
42974295
if len(node.args) != 2: # the error will be reported elsewhere
@@ -4472,21 +4470,27 @@ def has_no_custom_eq_checks(t: Type) -> bool:
44724470

44734471
return reduce_conditional_maps(partial_type_maps)
44744472
elif isinstance(node, AssignmentExpr):
4475-
return self.find_isinstance_check_helper(node.target)
4476-
elif isinstance(node, RefExpr):
4477-
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
4478-
# respectively
4479-
vartype = type_map[node]
4480-
self._check_for_truthy_type(vartype, node)
4481-
if_type: Type = true_only(vartype)
4482-
else_type: Type = false_only(vartype)
4483-
ref: Expression = node
4484-
if_map = ({ref: if_type} if not isinstance(get_proper_type(if_type), UninhabitedType)
4485-
else None)
4486-
else_map = ({ref: else_type} if not isinstance(get_proper_type(else_type),
4487-
UninhabitedType)
4488-
else None)
4489-
return if_map, else_map
4473+
if_map = {}
4474+
else_map = {}
4475+
4476+
if_assignment_map, else_assignment_map = self.find_isinstance_check_helper(node.target)
4477+
4478+
if if_assignment_map is not None:
4479+
if_map.update(if_assignment_map)
4480+
if else_assignment_map is not None:
4481+
else_map.update(else_assignment_map)
4482+
4483+
if_condition_map, else_condition_map = self.find_isinstance_check_helper(node.value)
4484+
4485+
if if_condition_map is not None:
4486+
if_map.update(if_condition_map)
4487+
if else_condition_map is not None:
4488+
else_map.update(else_condition_map)
4489+
4490+
return (
4491+
(None if if_assignment_map is None or if_condition_map is None else if_map),
4492+
(None if else_assignment_map is None or else_condition_map is None else else_map),
4493+
)
44904494
elif isinstance(node, OpExpr) and node.op == 'and':
44914495
left_if_vars, left_else_vars = self.find_isinstance_check_helper(node.left)
44924496
right_if_vars, right_else_vars = self.find_isinstance_check_helper(node.right)
@@ -4507,8 +4511,24 @@ def has_no_custom_eq_checks(t: Type) -> bool:
45074511
left, right = self.find_isinstance_check_helper(node.expr)
45084512
return right, left
45094513

4510-
# Not a supported isinstance check
4511-
return {}, {}
4514+
# Restrict the type of the variable to True-ish/False-ish in the if and else branches
4515+
# respectively
4516+
vartype = type_map[node]
4517+
self._check_for_truthy_type(vartype, node)
4518+
if_type = true_only(vartype) # type: Type
4519+
else_type = false_only(vartype) # type: Type
4520+
ref = node # type: Expression
4521+
if_map = (
4522+
{ref: if_type}
4523+
if not isinstance(get_proper_type(if_type), UninhabitedType)
4524+
else None
4525+
)
4526+
else_map = (
4527+
{ref: else_type}
4528+
if not isinstance(get_proper_type(else_type), UninhabitedType)
4529+
else None
4530+
)
4531+
return if_map, else_map
45124532

45134533
def propagate_up_typemap_info(self,
45144534
existing_types: Mapping[Expression, Type],

mypy/literals.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ def literal(e: Expression) -> int:
6161
elif isinstance(e, (MemberExpr, UnaryExpr, StarExpr)):
6262
return literal(e.expr)
6363

64+
elif isinstance(e, AssignmentExpr):
65+
return literal(e.target)
66+
6467
elif isinstance(e, IndexExpr):
6568
if literal(e.index) == LITERAL_YES:
6669
return literal(e.base)
@@ -160,8 +163,8 @@ def visit_index_expr(self, e: IndexExpr) -> Optional[Key]:
160163
return ('Index', literal_hash(e.base), literal_hash(e.index))
161164
return None
162165

163-
def visit_assignment_expr(self, e: AssignmentExpr) -> None:
164-
return None
166+
def visit_assignment_expr(self, e: AssignmentExpr) -> Optional[Key]:
167+
return literal_hash(e.target)
165168

166169
def visit_call_expr(self, e: CallExpr) -> None:
167170
return None

test-data/unit/check-python38.test

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,57 @@ reveal_type(z2) # E: Name "z2" is not defined # N: Revealed type is "Any"
384384

385385
[builtins fixtures/isinstancelist.pyi]
386386

387+
[case testWalrusConditionalTypeBinder]
388+
from typing import Union
389+
from typing_extensions import Literal
390+
391+
class Good:
392+
@property
393+
def is_good(self) -> Literal[True]: ...
394+
395+
class Bad:
396+
@property
397+
def is_good(self) -> Literal[False]: ...
398+
399+
def get_thing() -> Union[Good, Bad]: ...
400+
401+
if (thing := get_thing()).is_good:
402+
reveal_type(thing) # N: Revealed type is "__main__.Good"
403+
else:
404+
reveal_type(thing) # N: Revealed type is "__main__.Bad"
405+
[builtins fixtures/property.pyi]
406+
407+
[case testWalrusConditionalTypeCheck]
408+
# flags: --strict-optional
409+
from typing import Optional
410+
411+
maybe_str: Optional[str]
412+
413+
if (is_str := maybe_str is not None):
414+
reveal_type(is_str) # N: Revealed type is "builtins.bool"
415+
reveal_type(maybe_str) # N: Revealed type is "builtins.str"
416+
else:
417+
reveal_type(is_str) # N: Revealed type is "builtins.bool"
418+
reveal_type(maybe_str) # N: Revealed type is "None"
419+
420+
reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"
421+
[builtins fixtures/bool.pyi]
422+
423+
[case testWalrusConditionalTypeCheck2]
424+
from typing import Optional
425+
426+
maybe_str: Optional[str]
427+
428+
if (x := maybe_str) is not None:
429+
reveal_type(x) # N: Revealed type is "builtins.str"
430+
reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"
431+
else:
432+
reveal_type(x) # N: Revealed type is "None"
433+
reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"
434+
435+
reveal_type(maybe_str) # N: Revealed type is "Union[builtins.str, None]"
436+
[builtins fixtures/bool.pyi]
437+
387438
[case testWalrusPartialTypes]
388439
from typing import List
389440

@@ -400,6 +451,77 @@ def check_partial_list() -> None:
400451
reveal_type(z) # N: Revealed type is "builtins.list[builtins.int]"
401452
[builtins fixtures/list.pyi]
402453

454+
[case testWalrusAssignmentAndConditionScopeForLiteral]
455+
# flags: --warn-unreachable
456+
457+
if (x := 0):
458+
reveal_type(x) # E: Statement is unreachable
459+
else:
460+
reveal_type(x) # N: Revealed type is "builtins.int"
461+
462+
reveal_type(x) # N: Revealed type is "builtins.int"
463+
464+
[case testWalrusAssignmentAndConditionScopeForProperty]
465+
# flags: --warn-unreachable
466+
467+
from typing_extensions import Literal
468+
469+
class PropertyWrapper:
470+
@property
471+
def f(self) -> str: ...
472+
@property
473+
def always_false(self) -> Literal[False]: ...
474+
475+
wrapper = PropertyWrapper()
476+
477+
if x := wrapper.f:
478+
reveal_type(x) # N: Revealed type is "builtins.str"
479+
else:
480+
reveal_type(x) # N: Revealed type is "builtins.str"
481+
482+
reveal_type(x) # N: Revealed type is "builtins.str"
483+
484+
if y := wrapper.always_false:
485+
reveal_type(y) # E: Statement is unreachable
486+
else:
487+
reveal_type(y) # N: Revealed type is "Literal[False]"
488+
489+
reveal_type(y) # N: Revealed type is "Literal[False]"
490+
[builtins fixtures/property.pyi]
491+
492+
[case testWalrusAssignmentAndConditionScopeForFunction]
493+
# flags: --warn-unreachable
494+
495+
from typing_extensions import Literal
496+
497+
def f() -> str: ...
498+
499+
if x := f():
500+
reveal_type(x) # N: Revealed type is "builtins.str"
501+
else:
502+
reveal_type(x) # N: Revealed type is "builtins.str"
503+
504+
reveal_type(x) # N: Revealed type is "builtins.str"
505+
506+
def always_false() -> Literal[False]: ...
507+
508+
if y := always_false():
509+
reveal_type(y) # E: Statement is unreachable
510+
else:
511+
reveal_type(y) # N: Revealed type is "Literal[False]"
512+
513+
reveal_type(y) # N: Revealed type is "Literal[False]"
514+
515+
def always_false_with_parameter(x: int) -> Literal[False]: ...
516+
517+
if z := always_false_with_parameter(5):
518+
reveal_type(z) # E: Statement is unreachable
519+
else:
520+
reveal_type(z) # N: Revealed type is "Literal[False]"
521+
522+
reveal_type(z) # N: Revealed type is "Literal[False]"
523+
[builtins fixtures/tuple.pyi]
524+
403525
[case testWalrusExpr]
404526
def func() -> None:
405527
foo = Foo()

test-data/unit/check-unreachable-code.test

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1378,3 +1378,15 @@ def f(t: T) -> None:
13781378
except BaseException as e:
13791379
pass
13801380
[builtins fixtures/dict.pyi]
1381+
1382+
1383+
[case testUnreachableLiteral]
1384+
# flags: --warn-unreachable
1385+
from typing_extensions import Literal
1386+
1387+
def nope() -> Literal[False]: ...
1388+
1389+
def f() -> None:
1390+
if nope():
1391+
x = 1 # E: Statement is unreachable
1392+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)