|
10 | 10 | import collections.abc
|
11 | 11 | import copy
|
12 | 12 | import enum
|
| 13 | +import functools |
13 | 14 | import importlib
|
14 | 15 | import importlib.machinery
|
15 | 16 | import inspect
|
@@ -310,35 +311,23 @@ def _verify_exported_names(
|
310 | 311 | )
|
311 | 312 |
|
312 | 313 |
|
313 |
| -def _get_imported_symbol_names(runtime: types.ModuleType) -> frozenset[str] | None: |
314 |
| - """Retrieve the names in the global namespace which are known to be imported. |
| 314 | +@functools.lru_cache |
| 315 | +def _module_symbol_table(runtime: types.ModuleType) -> symtable.SymbolTable | None: |
| 316 | + """Retrieve the symbol table for the module (or None on failure). |
315 | 317 |
|
316 |
| - 1). Use inspect to retrieve the source code of the module |
317 |
| - 2). Use symtable to parse the source and retrieve names that are known to be imported |
318 |
| - from other modules. |
319 |
| -
|
320 |
| - If either of the above steps fails, return `None`. |
321 |
| -
|
322 |
| - Note that if a set of names is returned, |
323 |
| - it won't include names imported via `from foo import *` imports. |
| 318 | + 1) Use inspect to retrieve the source code of the module |
| 319 | + 2) Use symtable to parse the source (and use what symtable knows for its purposes) |
324 | 320 | """
|
325 | 321 | try:
|
326 | 322 | source = inspect.getsource(runtime)
|
327 | 323 | except (OSError, TypeError, SyntaxError):
|
328 | 324 | return None
|
329 | 325 |
|
330 |
| - if not source.strip(): |
331 |
| - # The source code for the module was an empty file, |
332 |
| - # no point in parsing it with symtable |
333 |
| - return frozenset() |
334 |
| - |
335 | 326 | try:
|
336 |
| - module_symtable = symtable.symtable(source, runtime.__name__, "exec") |
| 327 | + return symtable.symtable(source, runtime.__name__, "exec") |
337 | 328 | except SyntaxError:
|
338 | 329 | return None
|
339 | 330 |
|
340 |
| - return frozenset(sym.get_name() for sym in module_symtable.get_symbols() if sym.is_imported()) |
341 |
| - |
342 | 331 |
|
343 | 332 | @verify.register(nodes.MypyFile)
|
344 | 333 | def verify_mypyfile(
|
@@ -369,25 +358,37 @@ def verify_mypyfile(
|
369 | 358 | if not o.module_hidden and (not is_probably_private(m) or hasattr(runtime, m))
|
370 | 359 | }
|
371 | 360 |
|
372 |
| - imported_symbols = _get_imported_symbol_names(runtime) |
373 |
| - |
374 | 361 | def _belongs_to_runtime(r: types.ModuleType, attr: str) -> bool:
|
375 | 362 | """Heuristics to determine whether a name originates from another module."""
|
376 | 363 | obj = getattr(r, attr)
|
377 | 364 | if isinstance(obj, types.ModuleType):
|
378 | 365 | return False
|
379 |
| - if callable(obj): |
380 |
| - # It's highly likely to be a class or a function if it's callable, |
381 |
| - # so the __module__ attribute will give a good indication of which module it comes from |
| 366 | + |
| 367 | + symbol_table = _module_symbol_table(r) |
| 368 | + if symbol_table is not None: |
382 | 369 | try:
|
383 |
| - obj_mod = obj.__module__ |
384 |
| - except Exception: |
| 370 | + symbol = symbol_table.lookup(attr) |
| 371 | + except KeyError: |
385 | 372 | pass
|
386 | 373 | else:
|
387 |
| - if isinstance(obj_mod, str): |
388 |
| - return bool(obj_mod == r.__name__) |
389 |
| - if imported_symbols is not None: |
390 |
| - return attr not in imported_symbols |
| 374 | + if symbol.is_imported(): |
| 375 | + # symtable says we got this from another module |
| 376 | + return False |
| 377 | + # But we can't just return True here, because symtable doesn't know about symbols |
| 378 | + # that come from `from module import *` |
| 379 | + if symbol.is_assigned(): |
| 380 | + # symtable knows we assigned this symbol in the module |
| 381 | + return True |
| 382 | + |
| 383 | + # The __module__ attribute is unreliable for anything except functions and classes, |
| 384 | + # but it's our best guess at this point |
| 385 | + try: |
| 386 | + obj_mod = obj.__module__ |
| 387 | + except Exception: |
| 388 | + pass |
| 389 | + else: |
| 390 | + if isinstance(obj_mod, str): |
| 391 | + return bool(obj_mod == r.__name__) |
391 | 392 | return True
|
392 | 393 |
|
393 | 394 | runtime_public_contents = (
|
|
0 commit comments