Skip to content

Commit 9e68959

Browse files
authored
Improve for loop index variable type narrowing (python#18014)
Preserve the literal type of index expressions a bit longer (until the next assignment) to support TypedDict lookups. ```py from typing import TypedDict class X(TypedDict): hourly: int daily: int def func(x: X) -> None: for var in ("hourly", "daily"): print(x[var]) ``` Closes python#9230
1 parent 60d1b37 commit 9e68959

File tree

5 files changed

+59
-4
lines changed

5 files changed

+59
-4
lines changed

mypy/checker.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3175,6 +3175,14 @@ def check_assignment(
31753175
# Don't use type binder for definitions of special forms, like named tuples.
31763176
if not (isinstance(lvalue, NameExpr) and lvalue.is_special_form):
31773177
self.binder.assign_type(lvalue, rvalue_type, lvalue_type, False)
3178+
if (
3179+
isinstance(lvalue, NameExpr)
3180+
and isinstance(lvalue.node, Var)
3181+
and lvalue.node.is_inferred
3182+
and lvalue.node.is_index_var
3183+
and lvalue_type is not None
3184+
):
3185+
lvalue.node.type = remove_instance_last_known_values(lvalue_type)
31783186

31793187
elif index_lvalue:
31803188
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
@@ -3184,6 +3192,7 @@ def check_assignment(
31843192
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
31853193
if not (
31863194
inferred.is_final
3195+
or inferred.is_index_var
31873196
or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
31883197
):
31893198
rvalue_type = remove_instance_last_known_values(rvalue_type)

mypy/nodes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,7 @@ def is_dynamic(self) -> bool:
969969
"is_classvar",
970970
"is_abstract_var",
971971
"is_final",
972+
"is_index_var",
972973
"final_unset_in_class",
973974
"final_set_in_init",
974975
"explicit_self_type",
@@ -1005,6 +1006,7 @@ class Var(SymbolNode):
10051006
"is_classvar",
10061007
"is_abstract_var",
10071008
"is_final",
1009+
"is_index_var",
10081010
"final_unset_in_class",
10091011
"final_set_in_init",
10101012
"is_suppressed_import",
@@ -1039,6 +1041,7 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
10391041
self.is_settable_property = False
10401042
self.is_classvar = False
10411043
self.is_abstract_var = False
1044+
self.is_index_var = False
10421045
# Set to true when this variable refers to a module we were unable to
10431046
# parse for some reason (eg a silenced module)
10441047
self.is_suppressed_import = False

mypy/semanal.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4225,6 +4225,7 @@ def analyze_lvalue(
42254225
is_final: bool = False,
42264226
escape_comprehensions: bool = False,
42274227
has_explicit_value: bool = False,
4228+
is_index_var: bool = False,
42284229
) -> None:
42294230
"""Analyze an lvalue or assignment target.
42304231
@@ -4235,6 +4236,7 @@ def analyze_lvalue(
42354236
escape_comprehensions: If we are inside a comprehension, set the variable
42364237
in the enclosing scope instead. This implements
42374238
https://www.python.org/dev/peps/pep-0572/#scope-of-the-target
4239+
is_index_var: If lval is the index variable in a for loop
42384240
"""
42394241
if escape_comprehensions:
42404242
assert isinstance(lval, NameExpr), "assignment expression target must be NameExpr"
@@ -4245,6 +4247,7 @@ def analyze_lvalue(
42454247
is_final,
42464248
escape_comprehensions,
42474249
has_explicit_value=has_explicit_value,
4250+
is_index_var=is_index_var,
42484251
)
42494252
elif isinstance(lval, MemberExpr):
42504253
self.analyze_member_lvalue(lval, explicit_type, is_final, has_explicit_value)
@@ -4271,6 +4274,7 @@ def analyze_name_lvalue(
42714274
is_final: bool,
42724275
escape_comprehensions: bool,
42734276
has_explicit_value: bool,
4277+
is_index_var: bool,
42744278
) -> None:
42754279
"""Analyze an lvalue that targets a name expression.
42764280
@@ -4309,7 +4313,9 @@ def analyze_name_lvalue(
43094313

43104314
if (not existing or isinstance(existing.node, PlaceholderNode)) and not outer:
43114315
# Define new variable.
4312-
var = self.make_name_lvalue_var(lvalue, kind, not explicit_type, has_explicit_value)
4316+
var = self.make_name_lvalue_var(
4317+
lvalue, kind, not explicit_type, has_explicit_value, is_index_var
4318+
)
43134319
added = self.add_symbol(name, var, lvalue, escape_comprehensions=escape_comprehensions)
43144320
# Only bind expression if we successfully added name to symbol table.
43154321
if added:
@@ -4361,7 +4367,12 @@ def is_alias_for_final_name(self, name: str) -> bool:
43614367
return existing is not None and is_final_node(existing.node)
43624368

43634369
def make_name_lvalue_var(
4364-
self, lvalue: NameExpr, kind: int, inferred: bool, has_explicit_value: bool
4370+
self,
4371+
lvalue: NameExpr,
4372+
kind: int,
4373+
inferred: bool,
4374+
has_explicit_value: bool,
4375+
is_index_var: bool,
43654376
) -> Var:
43664377
"""Return a Var node for an lvalue that is a name expression."""
43674378
name = lvalue.name
@@ -4380,6 +4391,7 @@ def make_name_lvalue_var(
43804391
v._fullname = name
43814392
v.is_ready = False # Type not inferred yet
43824393
v.has_explicit_value = has_explicit_value
4394+
v.is_index_var = is_index_var
43834395
return v
43844396

43854397
def make_name_lvalue_point_to_existing_def(
@@ -5290,7 +5302,7 @@ def visit_for_stmt(self, s: ForStmt) -> None:
52905302
s.expr.accept(self)
52915303

52925304
# Bind index variables and check if they define new names.
5293-
self.analyze_lvalue(s.index, explicit_type=s.index_type is not None)
5305+
self.analyze_lvalue(s.index, explicit_type=s.index_type is not None, is_index_var=True)
52945306
if s.index_type:
52955307
if self.is_classvar(s.index_type):
52965308
self.fail_invalid_classvar(s.index)

test-data/unit/check-inference.test

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1238,6 +1238,35 @@ class B: pass
12381238
[builtins fixtures/for.pyi]
12391239
[out]
12401240

1241+
[case testForStatementIndexNarrowing]
1242+
from typing_extensions import TypedDict
1243+
1244+
class X(TypedDict):
1245+
hourly: int
1246+
daily: int
1247+
1248+
x: X
1249+
for a in ("hourly", "daily"):
1250+
reveal_type(a) # N: Revealed type is "Union[Literal['hourly']?, Literal['daily']?]"
1251+
reveal_type(x[a]) # N: Revealed type is "builtins.int"
1252+
reveal_type(a.upper()) # N: Revealed type is "builtins.str"
1253+
c = a
1254+
reveal_type(c) # N: Revealed type is "builtins.str"
1255+
a = "monthly"
1256+
reveal_type(a) # N: Revealed type is "builtins.str"
1257+
a = "yearly"
1258+
reveal_type(a) # N: Revealed type is "builtins.str"
1259+
a = 1 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
1260+
reveal_type(a) # N: Revealed type is "builtins.str"
1261+
d = a
1262+
reveal_type(d) # N: Revealed type is "builtins.str"
1263+
1264+
b: str
1265+
for b in ("hourly", "daily"):
1266+
reveal_type(b) # N: Revealed type is "builtins.str"
1267+
reveal_type(b.upper()) # N: Revealed type is "builtins.str"
1268+
[builtins fixtures/for.pyi]
1269+
12411270

12421271
-- Regression tests
12431272
-- ----------------

test-data/unit/fixtures/for.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ class type: pass
1212
class tuple(Generic[t]):
1313
def __iter__(self) -> Iterator[t]: pass
1414
class function: pass
15+
class ellipsis: pass
1516
class bool: pass
1617
class int: pass # for convenience
17-
class str: pass # for convenience
18+
class str: # for convenience
19+
def upper(self) -> str: ...
1820

1921
class list(Iterable[t], Generic[t]):
2022
def __iter__(self) -> Iterator[t]: pass

0 commit comments

Comments
 (0)