diff --git a/mypy/build.py b/mypy/build.py
index b61178be4ed9..988a201116c0 100644
--- a/mypy/build.py
+++ b/mypy/build.py
@@ -327,12 +327,22 @@ def default_lib_path(data_dir: str, pyversion: Tuple[int, int],
                        ('data_json', str),  # path of .data.json
                        ('suppressed', List[str]),  # dependencies that weren't imported
                        ('flags', Optional[List[str]]),  # build flags
+                       ('dep_prios', List[int]),
                        ])
-# NOTE: dependencies + suppressed == all unreachable imports;
+# NOTE: dependencies + suppressed == all reachable imports;
 # suppressed contains those reachable imports that were prevented by
 # --silent-imports or simply not found.
 
+# Priorities used for imports. (Here, top-level includes inside a class.)
+# These are used to determine a more predictable order in which the
+# nodes in an import cycle are processed.
+PRI_HIGH = 5  # top-level "from X import blah"
+PRI_MED = 10  # top-level "import X"
+PRI_LOW = 20  # either form inside a function
+PRI_ALL = 99  # include all priorities
+
+
 class BuildManager:
     """This class holds shared state for building a mypy program.
 
@@ -393,12 +403,13 @@ def __init__(self, data_dir: str,
         self.missing_modules = set()  # type: Set[str]
 
     def all_imported_modules_in_file(self,
-                                     file: MypyFile) -> List[Tuple[str, int]]:
+                                     file: MypyFile) -> List[Tuple[int, str, int]]:
         """Find all reachable import statements in a file.
 
-        Return list of tuples (module id, import line number) for all modules
-        imported in file.
+        Return list of tuples (priority, module id, import line number)
+        for all modules imported in file; lower numbers == higher priority.
         """
+
         def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
             """Function to correct for relative imports."""
             file_id = file.fullname()
@@ -413,21 +424,23 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
 
             return new_id
 
-        res = []  # type: List[Tuple[str, int]]
+        res = []  # type: List[Tuple[int, str, int]]
         for imp in file.imports:
             if not imp.is_unreachable:
                 if isinstance(imp, Import):
+                    pri = PRI_MED if imp.is_top_level else PRI_LOW
                     for id, _ in imp.ids:
-                        res.append((id, imp.line))
+                        res.append((pri, id, imp.line))
                 elif isinstance(imp, ImportFrom):
                     cur_id = correct_rel_imp(imp)
                     pos = len(res)
                     all_are_submodules = True
                     # Also add any imported names that are submodules.
+                    pri = PRI_MED if imp.is_top_level else PRI_LOW
                     for name, __ in imp.names:
                         sub_id = cur_id + '.' + name
                         if self.is_module(sub_id):
-                            res.append((sub_id, imp.line))
+                            res.append((pri, sub_id, imp.line))
                         else:
                             all_are_submodules = False
                     # If all imported names are submodules, don't add
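To make the new tuple convention concrete, here is a rough sketch of the priority assignment (not mypy's code: import_priority and its string-typed form argument are hypothetical, and the hunks above inline the decision as the pri = ... expressions; for a from-import, the submodule entries it adds keep PRI_MED, as shown above):

    PRI_HIGH, PRI_MED, PRI_LOW = 5, 10, 20

    def import_priority(form, top_level):
        # Any import inside a function is low priority; at top level,
        # plain "import X" is medium and "from X import blah" records
        # X itself at high priority.
        if not top_level:
            return PRI_LOW
        return PRI_HIGH if form == 'from' else PRI_MED

    assert import_priority('import', top_level=True) == PRI_MED
    assert import_priority('from', top_level=True) == PRI_HIGH
    assert import_priority('import', top_level=False) == PRI_LOW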
@@ -436,9 +449,12 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
                     # cur_id is also a dependency, and we should
                     # insert it *before* any submodules.
                     if not all_are_submodules:
-                        res.insert(pos, ((cur_id, imp.line)))
+                        pri = PRI_HIGH if imp.is_top_level else PRI_LOW
+                        res.insert(pos, ((pri, cur_id, imp.line)))
                 elif isinstance(imp, ImportAll):
-                    res.append((correct_rel_imp(imp), imp.line))
+                    pri = PRI_HIGH if imp.is_top_level else PRI_LOW
+                    res.append((pri, correct_rel_imp(imp), imp.line))
+
         return res
 
     def is_module(self, id: str) -> bool:
@@ -754,16 +770,18 @@ def find_cache_meta(id: str, path: str, manager: BuildManager) -> Optional[Cache
                   data_json,
                   meta.get('suppressed', []),
                   meta.get('flags'),
+                  meta.get('dep_prios', []),
                   )
     if (m.id != id or m.path != path or
             m.mtime is None or m.size is None or
             m.dependencies is None or m.data_mtime is None):
         return None
 
-    # Metadata generated by older mypy version and no flags were saved
-    if m.flags is None:
+    # Ignore cache if generated by an older mypy version.
+    if m.flags is None or len(m.dependencies) != len(m.dep_prios):
         return None
 
+    # Ignore cache if (relevant) flags aren't the same.
     cached_flags = select_flags_affecting_cache(m.flags)
     current_flags = select_flags_affecting_cache(manager.flags)
     if cached_flags != current_flags:
@@ -802,6 +820,7 @@ def random_string():
 
 def write_cache(id: str, path: str, tree: MypyFile,
                 dependencies: List[str], suppressed: List[str],
+                dep_prios: List[int],
                 manager: BuildManager) -> None:
     """Write cache files for a module.
 
@@ -811,6 +830,7 @@ def write_cache(id: str, path: str, tree: MypyFile,
       tree: the fully checked module data
       dependencies: module IDs on which this module depends
       suppressed: module IDs which were suppressed as dependencies
+      dep_prios: priorities (parallel array to dependencies)
      manager: the build manager (for pyversion, log/trace)
    """
    path = os.path.abspath(path)
@@ -840,6 +860,7 @@ def write_cache(id: str, path: str, tree: MypyFile,
             'dependencies': dependencies,
             'suppressed': suppressed,
             'flags': manager.flags,
+            'dep_prios': dep_prios,
             }
     with open(meta_json_tmp, 'w') as f:
         json.dump(meta, f, sort_keys=True)
@@ -1012,6 +1033,7 @@ class State:
     tree = None  # type: Optional[MypyFile]
     dependencies = None  # type: List[str]
     suppressed = None  # type: List[str]  # Suppressed/missing dependencies
+    priorities = None  # type: Dict[str, int]
 
     # Map each dependency to the line number where it is first imported
     dep_line_map = None  # type: Dict[str, int]
@@ -1114,6 +1136,9 @@ def __init__(self,
             # compare them to the originals later.
             self.dependencies = list(self.meta.dependencies)
             self.suppressed = list(self.meta.suppressed)
+            assert len(self.meta.dependencies) == len(self.meta.dep_prios)
+            self.priorities = {id: pri
+                               for id, pri in zip(self.meta.dependencies, self.meta.dep_prios)}
             self.dep_line_map = {}
         else:
             # Parse the file (and then some) to get the dependencies.
@@ -1249,8 +1274,10 @@ def parse_file(self) -> None:
         # Also keep track of each dependency's source line.
         dependencies = []
         suppressed = []
+        priorities = {}  # type: Dict[str, int]  # id -> priority
         dep_line_map = {}  # type: Dict[str, int]  # id -> line
-        for id, line in manager.all_imported_modules_in_file(self.tree):
+        for pri, id, line in manager.all_imported_modules_in_file(self.tree):
+            priorities[id] = min(pri, priorities.get(id, PRI_ALL))
             if id == self.id:
                 continue
             # Omit missing modules, as otherwise we could not type-check
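One detail worth spelling out: a module imported several times keeps its strongest (numerically lowest) priority, with PRI_ALL serving as the "no entry yet" sentinel in the min() above. A minimal, self-contained sketch of the idiom:

    PRI_MED, PRI_LOW, PRI_ALL = 10, 20, 99

    priorities = {}  # id -> priority
    # 'x' is imported first inside a function, then again at top level:
    for pri, id in [(PRI_LOW, 'x'), (PRI_MED, 'x')]:
        priorities[id] = min(pri, priorities.get(id, PRI_ALL))
    assert priorities['x'] == PRI_MED  # the top-level import wins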
@@ -1281,6 +1308,7 @@
         # for differences (e.g. --silent-imports).
         self.dependencies = dependencies
         self.suppressed = suppressed
+        self.priorities = priorities
         self.dep_line_map = dep_line_map
         self.check_blockers()
@@ -1320,8 +1348,10 @@ def type_check(self) -> None:
 
     def write_cache(self) -> None:
         if self.path and INCREMENTAL in self.manager.flags and not self.manager.errors.is_errors():
+            dep_prios = [self.priorities.get(dep, PRI_HIGH) for dep in self.dependencies]
             write_cache(self.id, self.path, self.tree,
                         list(self.dependencies), list(self.suppressed),
+                        dep_prios,
                         self.manager)
@@ -1388,10 +1418,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
     # dependencies) to roots (those from which everything else can be
     # reached).
     for ascc in sccs:
-        # Sort the SCC's nodes in *reverse* order or encounter.
-        # This is a heuristic for handling import cycles.
+        # Order the SCC's nodes using a heuristic.
         # Note that ascc is a set, and scc is a list.
-        scc = sorted(ascc, key=lambda id: -graph[id].order)
+        scc = order_ascc(graph, ascc)
         # If builtins is in the list, move it last.  (This is a bit of
         # a hack, but it's necessary because the builtins module is
         # part of a small cycle involving at least {builtins, abc,
@@ -1400,6 +1429,12 @@
         if 'builtins' in ascc:
             scc.remove('builtins')
             scc.append('builtins')
+        if manager.flags.count(VERBOSE) >= 2:
+            for id in scc:
+                manager.trace("Priorities for %s:" % id,
+                              " ".join("%s:%d" % (x, graph[id].priorities[x])
+                                       for x in graph[id].dependencies
+                                       if x in ascc and x in graph[id].priorities))
         # Because the SCCs are presented in topological sort order, we
         # don't need to look at dependencies recursively for staleness
         # -- the immediate dependencies are sufficient.
@@ -1426,7 +1461,7 @@
             # cache file is newer than any scc node's cache file.
             oldest_in_scc = min(graph[id].meta.data_mtime for id in scc)
             newest_in_deps = 0 if not deps else max(graph[dep].meta.data_mtime for dep in deps)
-            if manager.flags.count(VERBOSE) >= 2:  # Dump all mtimes for extreme debugging.
+            if manager.flags.count(VERBOSE) >= 3:  # Dump all mtimes for extreme debugging.
                 all_ids = sorted(ascc | deps, key=lambda id: graph[id].meta.data_mtime)
                 for id in all_ids:
                     if id in scc:
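Before the new function itself, a toy illustration of the reduction it performs, with plain dicts standing in for mypy's State objects: keeping only arcs strictly below the worst priority present turns the A <--> B cycle from its docstring into a DAG.

    PRI_MED, PRI_LOW, PRI_ALL = 10, 20, 99
    # A has a top-level "import B"; B imports A only inside a function.
    deps = {'A': {'B': PRI_MED}, 'B': {'A': PRI_LOW}}

    def arcs(pri_max):
        # Keep only arcs with priority < pri_max, as deps_filtered() does.
        return {(mod, dep)
                for mod, targets in deps.items()
                for dep, pri in targets.items() if pri < pri_max}

    assert arcs(PRI_ALL) == {('A', 'B'), ('B', 'A')}  # a genuine cycle
    assert arcs(PRI_LOW) == {('A', 'B')}  # B -> A dropped: process B, then A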
@@ -1466,6 +1501,53 @@
             process_stale_scc(graph, scc)
 
 
+def order_ascc(graph: Graph, ascc: AbstractSet[str], pri_max: int = PRI_ALL) -> List[str]:
+    """Come up with the ideal processing order within an SCC.
+
+    Using the priorities assigned by all_imported_modules_in_file(),
+    try to reduce the cycle to a DAG, by omitting arcs representing
+    dependencies of lower priority.
+
+    In the simplest case, if we have A <--> B where A has a top-level
+    "import B" (medium priority) but B only has the reverse "import A"
+    inside a function (low priority), we turn the cycle into a DAG by
+    dropping the B --> A arc, which leaves only A --> B.
+
+    If all arcs have the same priority, we fall back to sorting by
+    reverse global order (the order in which modules were first
+    encountered).
+
+    The algorithm is recursive, as follows: when arcs of different
+    priorities are present, drop all arcs of the lowest priority,
+    identify SCCs in the resulting graph, and apply the algorithm to
+    each SCC thus found.  The recursion is bounded because at each
+    recursion the spread in priorities is (at least) one less.
+
+    In practice there are only a few priority levels (currently
+    N=3) and in the worst case we just carry out the same algorithm
+    for finding SCCs N times.  Thus the complexity is no worse than
+    the complexity of the original SCC-finding algorithm -- see
+    strongly_connected_components() below for a reference.
+    """
+    if len(ascc) == 1:
+        return [s for s in ascc]
+    pri_spread = set()
+    for id in ascc:
+        state = graph[id]
+        for dep in state.dependencies:
+            if dep in ascc:
+                pri = state.priorities.get(dep, PRI_HIGH)
+                if pri < pri_max:
+                    pri_spread.add(pri)
+    if len(pri_spread) == 1:
+        # Filtered dependencies are uniform -- order by global order.
+        return sorted(ascc, key=lambda id: -graph[id].order)
+    pri_max = max(pri_spread)
+    sccs = sorted_components(graph, ascc, pri_max)
+    # The recursion is bounded by the len(pri_spread) check above.
+    return [s for ss in sccs for s in order_ascc(graph, ss, pri_max)]
+
+
 def process_fresh_scc(graph: Graph, scc: List[str]) -> None:
     """Process the modules in one SCC from their cached data."""
     for id in scc:
@@ -1497,7 +1579,9 @@ def process_stale_scc(graph: Graph, scc: List[str]) -> None:
         graph[id].write_cache()
 
 
-def sorted_components(graph: Graph) -> List[AbstractSet[str]]:
+def sorted_components(graph: Graph,
+                      vertices: Optional[AbstractSet[str]] = None,
+                      pri_max: int = PRI_ALL) -> List[AbstractSet[str]]:
     """Return the graph's SCCs, topologically sorted by dependencies.
 
     The sort order is from leaves (nodes without dependencies) to
@@ -1507,9 +1591,9 @@ def sorted_components(graph: Graph,
     dependencies that aren't present in graph.keys() are ignored.
     """
     # Compute SCCs.
-    vertices = set(graph)
-    edges = {id: [dep for dep in st.dependencies if dep in graph]
-             for id, st in graph.items()}
+    if vertices is None:
+        vertices = set(graph)
+    edges = {id: deps_filtered(graph, vertices, id, pri_max) for id in vertices}
     sccs = list(strongly_connected_components(vertices, edges))
     # Topsort.
     sccsmap = {id: frozenset(scc) for scc in sccs for id in scc}
@@ -1517,7 +1601,7 @@
     for scc in sccs:
         deps = set()  # type: Set[AbstractSet[str]]
         for id in scc:
-            deps.update(sccsmap[x] for x in graph[id].dependencies if x in graph)
+            deps.update(sccsmap[x] for x in deps_filtered(graph, vertices, id, pri_max))
         data[frozenset(scc)] = deps
     res = []
     for ready in topsort(data):
@@ -1534,7 +1618,17 @@
     return res
 
 
-def strongly_connected_components(vertices: Set[str],
+def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: int) -> List[str]:
+    """Filter dependencies for id with pri < pri_max."""
+    if id not in vertices:
+        return []
+    state = graph[id]
+    return [dep
+            for dep in state.dependencies
+            if dep in vertices and state.priorities.get(dep, PRI_HIGH) < pri_max]
+
+
+def strongly_connected_components(vertices: AbstractSet[str],
                                   edges: Dict[str, List[str]]) -> Iterator[Set[str]]:
     """Compute Strongly Connected Components of a directed graph.
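Taken together, the build.py changes persist the priority map as a list parallel to dependencies. A minimal sketch of that round-trip, using plain data in place of State and CacheMeta:

    PRI_HIGH = 5
    dependencies = ['b', 'c']
    priorities = {'b': 10, 'c': 5}

    # Serialize, as in State.write_cache(); unknown deps default to PRI_HIGH:
    dep_prios = [priorities.get(dep, PRI_HIGH) for dep in dependencies]
    assert dep_prios == [10, 5]

    # Deserialize, as in State.__init__() when a cache meta is loaded:
    restored = {id: pri for id, pri in zip(dependencies, dep_prios)}
    assert restored == priorities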
diff --git a/mypy/nodes.py b/mypy/nodes.py
index 61213c732fb0..aaa814b1a8f1 100644
--- a/mypy/nodes.py
+++ b/mypy/nodes.py
@@ -224,6 +224,7 @@ def deserialize(cls, data: JsonDict) -> 'MypyFile':
 class ImportBase(Node):
     """Base class for all import statements."""
     is_unreachable = False
+    is_top_level = False  # Set by semanal.FirstPass
     # If an import replaces existing definitions, we construct dummy assignment
     # statements that assign the imported names to the names in the current scope,
     # for type checking purposes.  Example:
diff --git a/mypy/semanal.py b/mypy/semanal.py
index c4d6ac2a80d6..13cc87b0a994 100644
--- a/mypy/semanal.py
+++ b/mypy/semanal.py
@@ -2414,12 +2414,14 @@ def visit_import_from(self, node: ImportFrom) -> None:
         # We can't bind module names during the first pass, as the target module might be
         # unprocessed. However, we add dummy unbound imported names to the symbol table so
         # that we at least know that the name refers to a module.
+        node.is_top_level = True
         for name, as_name in node.names:
             imported_name = as_name or name
             if imported_name not in self.sem.globals:
                 self.sem.add_symbol(imported_name, SymbolTableNode(UNBOUND_IMPORTED, None), node)
 
     def visit_import(self, node: Import) -> None:
+        node.is_top_level = True
         # This is similar to visit_import_from -- see the comment there.
         for id, as_id in node.ids:
             imported_id = as_id or id
@@ -2429,6 +2431,9 @@ def visit_import(self, node: Import) -> None:
             # If the previous symbol is a variable, this should take precedence.
             self.sem.globals[imported_id] = SymbolTableNode(UNBOUND_IMPORTED, None)
 
+    def visit_import_all(self, node: ImportAll) -> None:
+        node.is_top_level = True
+
     def visit_while_stmt(self, s: WhileStmt) -> None:
         s.body.accept(self)
         if s.else_body:
diff --git a/mypy/test/testgraph.py b/mypy/test/testgraph.py
index 6f9897660acf..13cf7002b62c 100644
--- a/mypy/test/testgraph.py
+++ b/mypy/test/testgraph.py
@@ -4,7 +4,7 @@
 
 from mypy.myunit import Suite, assert_equal
 from mypy.build import BuildManager, State, TYPE_CHECK
-from mypy.build import topsort, strongly_connected_components, sorted_components
+from mypy.build import topsort, strongly_connected_components, sorted_components, order_ascc
 
 
 class GraphSuite(Suite):
@@ -30,7 +30,7 @@ def test_scc(self) -> None:
                            frozenset({'B', 'C'}),
                            frozenset({'D'})})
 
-    def test_sorted_components(self) -> None:
+    def _make_manager(self):
         manager = BuildManager(
             data_dir='',
             lib_path=[],
@@ -41,9 +41,27 @@ def test_sorted_components(self) -> None:
             custom_typing_module='',
             source_set=None,
             reports=None)
+        return manager
+
+    def test_sorted_components(self) -> None:
+        manager = self._make_manager()
         graph = {'a': State('a', None, 'import b, c', manager),
+                 'd': State('d', None, 'pass', manager),
                  'b': State('b', None, 'import c', manager),
                  'c': State('c', None, 'import b, d', manager),
-                 'd': State('d', None, 'pass', manager)}
+                 }
         res = sorted_components(graph)
         assert_equal(res, [frozenset({'d'}), frozenset({'c', 'b'}), frozenset({'a'})])
+
+    def test_order_ascc(self) -> None:
+        manager = self._make_manager()
+        graph = {'a': State('a', None, 'import b, c', manager),
+                 'd': State('d', None, 'def f(): import a', manager),
+                 'b': State('b', None, 'import c', manager),
+                 'c': State('c', None, 'import b, d', manager),
+                 }
+        res = sorted_components(graph)
+        assert_equal(res, [frozenset({'a', 'd', 'c', 'b'})])
+        ascc = res[0]
+        scc = order_ascc(graph, ascc)
+        assert_equal(scc, ['d', 'c', 'b', 'a'])
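Why test_order_ascc expects ['d', 'c', 'b', 'a']: a walkthrough of the arcs inside the single SCC (the priority annotations below are inferred from the module bodies, not spelled out in the test):

    # a -> b, a -> c   PRI_MED  (top-level "import b, c")
    # b -> c           PRI_MED  (top-level "import c")
    # c -> b, c -> d   PRI_MED  (top-level "import b, d")
    # d -> a           PRI_LOW  ("import a" inside f())
    #
    # Dropping the lone PRI_LOW arc d -> a leaves a DAG over {d}, {b, c}
    # and {a}: d has no remaining dependencies, {b, c} (still a cycle,
    # but with uniform priorities) depends on d through c, and a depends
    # on both b and c.  Processed leaves-first, that gives d, then c and
    # b (ordered by reverse order of first encounter), then a.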
diff --git a/test-data/unit/check-modules.test b/test-data/unit/check-modules.test
index d2f863e6205d..6674f91ee0e0 100644
--- a/test-data/unit/check-modules.test
+++ b/test-data/unit/check-modules.test
@@ -811,3 +811,39 @@ from a import x
 [file a/__init__.py]
 x = 0
 [out]
+
+
+-- Test stability under import cycles
+-- ----------------------------------
+
+-- The two tests are identical except one main has 'import x' and the other 'import y'.
+-- Previously (before build.order_ascc() was added) one of these would fail because the
+-- imports were processed in the (reverse) order in which the files were encountered.
+
+[case testImportCycleStability1]
+import x
+[file x.py]
+def f() -> str: return ''
+class Base:
+    attr = f()
+def foo():
+    import y
+[file y.py]
+import x
+class Sub(x.Base):
+    attr = x.Base.attr
+[out]
+
+[case testImportCycleStability2]
+import y
+[file x.py]
+def f() -> str: return ''
+class Base:
+    attr = f()
+def foo():
+    import y
+[file y.py]
+import x
+class Sub(x.Base):
+    attr = x.Base.attr
+[out]
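The same reduction explains the stability: in both cases y's top-level "import x" is PRI_MED, while x imports y only inside foo(), so that arc is PRI_LOW and order_ascc() drops it.

    # x <--> y arcs in both test cases:
    #     y -> x  PRI_MED  (top-level "import x" in y.py)
    #     x -> y  PRI_LOW  ("import y" inside foo() in x.py)
    # With x -> y dropped, x -- which defines Base and infers attr: str --
    # is always processed before y defines Sub(x.Base), regardless of
    # whether main says "import x" or "import y".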