Skip to content

Commit 717a263

Browse files
authored
stubtest: adjust symtable logic (#16823)
Fixes python/typeshed#11318
1 parent 3838bff commit 717a263

File tree

2 files changed

+48
-29
lines changed

2 files changed

+48
-29
lines changed

mypy/stubtest.py

+30-29
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import collections.abc
1111
import copy
1212
import enum
13+
import functools
1314
import importlib
1415
import importlib.machinery
1516
import inspect
@@ -310,35 +311,23 @@ def _verify_exported_names(
310311
)
311312

312313

313-
def _get_imported_symbol_names(runtime: types.ModuleType) -> frozenset[str] | None:
314-
"""Retrieve the names in the global namespace which are known to be imported.
314+
@functools.lru_cache
315+
def _module_symbol_table(runtime: types.ModuleType) -> symtable.SymbolTable | None:
316+
"""Retrieve the symbol table for the module (or None on failure).
315317
316-
1). Use inspect to retrieve the source code of the module
317-
2). Use symtable to parse the source and retrieve names that are known to be imported
318-
from other modules.
319-
320-
If either of the above steps fails, return `None`.
321-
322-
Note that if a set of names is returned,
323-
it won't include names imported via `from foo import *` imports.
318+
1) Use inspect to retrieve the source code of the module
319+
2) Use symtable to parse the source (and use what symtable knows for its purposes)
324320
"""
325321
try:
326322
source = inspect.getsource(runtime)
327323
except (OSError, TypeError, SyntaxError):
328324
return None
329325

330-
if not source.strip():
331-
# The source code for the module was an empty file,
332-
# no point in parsing it with symtable
333-
return frozenset()
334-
335326
try:
336-
module_symtable = symtable.symtable(source, runtime.__name__, "exec")
327+
return symtable.symtable(source, runtime.__name__, "exec")
337328
except SyntaxError:
338329
return None
339330

340-
return frozenset(sym.get_name() for sym in module_symtable.get_symbols() if sym.is_imported())
341-
342331

343332
@verify.register(nodes.MypyFile)
344333
def verify_mypyfile(
@@ -369,25 +358,37 @@ def verify_mypyfile(
369358
if not o.module_hidden and (not is_probably_private(m) or hasattr(runtime, m))
370359
}
371360

372-
imported_symbols = _get_imported_symbol_names(runtime)
373-
374361
def _belongs_to_runtime(r: types.ModuleType, attr: str) -> bool:
375362
"""Heuristics to determine whether a name originates from another module."""
376363
obj = getattr(r, attr)
377364
if isinstance(obj, types.ModuleType):
378365
return False
379-
if callable(obj):
380-
# It's highly likely to be a class or a function if it's callable,
381-
# so the __module__ attribute will give a good indication of which module it comes from
366+
367+
symbol_table = _module_symbol_table(r)
368+
if symbol_table is not None:
382369
try:
383-
obj_mod = obj.__module__
384-
except Exception:
370+
symbol = symbol_table.lookup(attr)
371+
except KeyError:
385372
pass
386373
else:
387-
if isinstance(obj_mod, str):
388-
return bool(obj_mod == r.__name__)
389-
if imported_symbols is not None:
390-
return attr not in imported_symbols
374+
if symbol.is_imported():
375+
# symtable says we got this from another module
376+
return False
377+
# But we can't just return True here, because symtable doesn't know about symbols
378+
# that come from `from module import *`
379+
if symbol.is_assigned():
380+
# symtable knows we assigned this symbol in the module
381+
return True
382+
383+
# The __module__ attribute is unreliable for anything except functions and classes,
384+
# but it's our best guess at this point
385+
try:
386+
obj_mod = obj.__module__
387+
except Exception:
388+
pass
389+
else:
390+
if isinstance(obj_mod, str):
391+
return bool(obj_mod == r.__name__)
391392
return True
392393

393394
runtime_public_contents = (

mypy/test/teststubtest.py

+18
Original file line numberDiff line numberDiff line change
@@ -1285,6 +1285,24 @@ def test_missing_no_runtime_all(self) -> Iterator[Case]:
12851285
yield Case(stub="", runtime="from json.scanner import NUMBER_RE", error=None)
12861286
yield Case(stub="", runtime="from string import ascii_letters", error=None)
12871287

1288+
@collect_cases
1289+
def test_missing_no_runtime_all_terrible(self) -> Iterator[Case]:
1290+
yield Case(
1291+
stub="",
1292+
runtime="""
1293+
import sys
1294+
import types
1295+
import __future__
1296+
_m = types.SimpleNamespace()
1297+
_m.annotations = __future__.annotations
1298+
sys.modules["_terrible_stubtest_test_module"] = _m
1299+
1300+
from _terrible_stubtest_test_module import *
1301+
assert annotations
1302+
""",
1303+
error=None,
1304+
)
1305+
12881306
@collect_cases
12891307
def test_non_public_1(self) -> Iterator[Case]:
12901308
yield Case(

0 commit comments

Comments
 (0)