Skip to content

Commit 8661ea2

Browse files
committed
Handle chained assignment and iterable unpacking assignment.
1 parent ad646b9 commit 8661ea2

File tree

2 files changed

+66
-15
lines changed

2 files changed

+66
-15
lines changed

mypy/semanal.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2357,22 +2357,39 @@ def fail_invalid_classvar(self, context: Context) -> None:
23572357
self.fail('ClassVar can only be used for assignments in class body', context)
23582358

23592359
def process_module_assignment(self, s: AssignmentStmt) -> None:
2360-
"""Check if s assigns a module an alias name; if yes, update symbol table."""
2361-
# TODO support more complex forms of module alias assignment
2362-
# (e.g. `x, y = (mod1, mod2)`) and aliases not in global scope
2363-
if (
2364-
len(s.lvalues) != 1
2365-
or not isinstance(s.lvalues[0], NameExpr)
2366-
or not isinstance(s.rvalue, NameExpr)
2367-
or not self.is_module_scope()
2368-
):
2360+
"""Check if s assigns a module an alias name; if so, update symbol table."""
2361+
# TODO support module alias assignment not in global scope
2362+
if not self.is_module_scope():
23692363
return
2370-
rnode = self.lookup(s.rvalue.name, s)
2371-
if rnode and rnode.kind == MODULE_REF:
2372-
lnode = self.lookup(s.lvalues[0].name, s)
2373-
if lnode:
2374-
lnode.kind = MODULE_REF
2375-
lnode.node = rnode.node
2364+
self._process_module_assignment(s.lvalues, s.rvalue, s)
2365+
2366+
def _process_module_assignment(
2367+
self,
2368+
lvals: List[Expression],
2369+
rval: Expression,
2370+
ctx: AssignmentStmt,
2371+
) -> None:
2372+
"""Propagate module references across assignments.
2373+
2374+
Recursively handles the simple form of iterable unpacking; doesn't
2375+
handle advanced unpacking with *rest.
2376+
2377+
"""
2378+
if all(isinstance(v, (TupleExpr, ListExpr)) for v in lvals + [rval]):
2379+
litemlists = [v.items for v in cast(List[Union[TupleExpr, ListExpr]], lvals)]
2380+
ritems = cast(Union[TupleExpr, ListExpr], rval).items
2381+
for rv, *lvs in zip(ritems, *litemlists):
2382+
self._process_module_assignment(lvs, rv, ctx)
2383+
elif isinstance(rval, NameExpr):
2384+
rnode = self.lookup(rval.name, ctx)
2385+
if rnode and rnode.kind == MODULE_REF:
2386+
for lval in lvals:
2387+
if not isinstance(lval, NameExpr):
2388+
continue
2389+
lnode = self.lookup(lval.name, ctx)
2390+
if lnode:
2391+
lnode.kind = MODULE_REF
2392+
lnode.node = rnode.node
23762393

23772394
def process_enum_call(self, s: AssignmentStmt) -> None:
23782395
"""Check if s defines an Enum; if yes, store the definition in symbol table."""

test-data/unit/check-modules.test

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1426,3 +1426,37 @@ reveal_type(m2.a) # E: Revealed type is 'builtins.str'
14261426
a = 'foo'
14271427

14281428
[builtins fixtures/module.pyi]
1429+
1430+
[case testChainedModuleAssignment]
1431+
import m
1432+
m3 = m2 = m
1433+
reveal_type(m2.a) # E: Revealed type is 'builtins.str'
1434+
reveal_type(m3.a) # E: Revealed type is 'builtins.str'
1435+
1436+
[file m.py]
1437+
a = 'foo'
1438+
1439+
[builtins fixtures/module.pyi]
1440+
1441+
[case testMultiModuleAssignment]
1442+
import m, n
1443+
m2, n2, (m3, n3) = m, n, [m, n]
1444+
reveal_type(m2.a) # E: Revealed type is 'builtins.str'
1445+
reveal_type(n2.b) # E: Revealed type is 'builtins.str'
1446+
reveal_type(m3.a) # E: Revealed type is 'builtins.str'
1447+
reveal_type(n3.b) # E: Revealed type is 'builtins.str'
1448+
1449+
[file m.py]
1450+
a = 'foo'
1451+
1452+
[file n.py]
1453+
b = 'bar'
1454+
1455+
[builtins fixtures/module.pyi]
1456+
1457+
[case testMultiAssignment]
1458+
m2, n2 = 'foo', 'bar'
1459+
reveal_type(m2) # E: Revealed type is 'builtins.str'
1460+
reveal_type(n2) # E: Revealed type is 'builtins.str'
1461+
1462+
[builtins fixtures/module.pyi]

0 commit comments

Comments
 (0)