diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py
index f8f9fb9e..07ab23b5 100644
--- a/src/_pytask/dag.py
+++ b/src/_pytask/dag.py
@@ -107,7 +107,8 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
     scheduler = TopologicalSorter.from_dag(dag)
     visited_nodes: set[str] = set()
 
-    for task_signature in scheduler.static_order():
+    while scheduler.is_active():
+        task_signature = scheduler.get_ready()[0]
         if task_signature not in visited_nodes:
             task = dag.nodes[task_signature]["task"]
             have_changed = _have_task_or_neighbors_changed(session, dag, task)
@@ -117,6 +118,7 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
                 dag.nodes[task_signature]["task"].markers.append(
                     Mark("skip_unchanged", (), {})
                 )
+        scheduler.done(task_signature)
 
 
 def _have_task_or_neighbors_changed(
diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py
index ded08d3e..cb359e0c 100644
--- a/src/_pytask/dag_utils.py
+++ b/src/_pytask/dag_utils.py
@@ -62,22 +62,27 @@ def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]:
 class TopologicalSorter:
     """The topological sorter class.
 
-    This class allows to perform a topological sort
+    This class allows to perform a topological sort.
+
+    Attributes
+    ----------
+    dag
+        Not the full DAG, but a reduced version that only considers tasks.
+    priorities
+        A dictionary of task names to a priority value. 1 for try first, 0 for the
+        default priority, and -1 for try last.
 
     """
 
     dag: nx.DiGraph
-    dag_backup: nx.DiGraph
     priorities: dict[str, int] = field(factory=dict)
-    _is_prepared: bool = False
-    _nodes_out: set[str] = field(factory=set)
+    _nodes_processing: set[str] = field(factory=set)
+    _nodes_done: set[str] = field(factory=set)
 
     @classmethod
     def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:
         """Instantiate from a DAG."""
-        if not dag.is_directed():
-            msg = "Only directed graphs have a topological order."
-            raise ValueError(msg)
+        cls.check_dag(dag)
 
         tasks = [
             dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node]
@@ -90,35 +95,47 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:
         }
         task_dag = nx.DiGraph(task_dict).reverse()
 
-        return cls(dag=task_dag, priorities=priorities, dag_backup=task_dag.copy())
+        return cls(dag=task_dag, priorities=priorities)
+
+    @classmethod
+    def from_dag_and_sorter(
+        cls, dag: nx.DiGraph, sorter: TopologicalSorter
+    ) -> TopologicalSorter:
+        """Instantiate a sorter from another sorter and a DAG."""
+        new_sorter = cls.from_dag(dag)
+        new_sorter.done(*sorter._nodes_done)
+        new_sorter._nodes_processing = set(sorter._nodes_processing)
+        return new_sorter
+
+    @staticmethod
+    def check_dag(dag: nx.DiGraph) -> None:
+        """Raise an error if the DAG is undirected or contains cycles."""
+        if not dag.is_directed():
+            msg = "Only directed graphs have a topological order."
+            raise ValueError(msg)
 
-    def prepare(self) -> None:
-        """Perform some checks before creating a topological ordering."""
         try:
-            nx.algorithms.cycles.find_cycle(self.dag)
+            nx.algorithms.cycles.find_cycle(dag)
         except nx.NetworkXNoCycle:
             pass
         else:
             msg = "The DAG contains cycles."
             raise ValueError(msg)
 
-        self._is_prepared = True
-
     def get_ready(self, n: int = 1) -> list[str]:
         """Get up to ``n`` tasks which are ready."""
-        if not self._is_prepared:
-            msg = "The TopologicalSorter needs to be prepared."
-            raise ValueError(msg)
         if not isinstance(n, int) or n < 1:
             msg = "'n' must be an integer greater or equal than 1."
             raise ValueError(msg)
 
-        ready_nodes = {v for v, d in self.dag.in_degree() if d == 0} - self._nodes_out
+        ready_nodes = {
+            v for v, d in self.dag.in_degree() if d == 0
+        } - self._nodes_processing
         prioritized_nodes = sorted(
             ready_nodes, key=lambda x: self.priorities.get(x, 0)
         )[-n:]
 
-        self._nodes_out.update(prioritized_nodes)
+        self._nodes_processing.update(prioritized_nodes)
 
         return prioritized_nodes
 
@@ -128,23 +145,9 @@ def is_active(self) -> bool:
 
     def done(self, *nodes: str) -> None:
         """Mark some tasks as done."""
-        self._nodes_out = self._nodes_out - set(nodes)
+        self._nodes_processing = self._nodes_processing - set(nodes)
         self.dag.remove_nodes_from(nodes)
-
-    def reset(self) -> None:
-        """Reset an exhausted topological sorter."""
-        if self.dag_backup:
-            self.dag = self.dag_backup.copy()
-        self._is_prepared = False
-        self._nodes_out = set()
-
-    def static_order(self) -> Generator[str, None, None]:
-        """Return a topological order of tasks as an iterable."""
-        self.prepare()
-        while self.is_active():
-            new_task = self.get_ready()[0]
-            yield new_task
-            self.done(new_task)
+        self._nodes_done.update(nodes)
 
 
 def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py
index c6641710..b7dbd5f8 100644
--- a/src/_pytask/execute.py
+++ b/src/_pytask/execute.py
@@ -72,21 +72,21 @@ def pytask_execute_log_start(session: Session) -> None:
 @hookimpl(trylast=True)
 def pytask_execute_create_scheduler(session: Session) -> TopologicalSorter:
     """Create a scheduler based on topological sorting."""
-    scheduler = TopologicalSorter.from_dag(session.dag)
-    scheduler.prepare()
-    return scheduler
+    return TopologicalSorter.from_dag(session.dag)
 
 
 @hookimpl
 def pytask_execute_build(session: Session) -> bool | None:
     """Execute tasks."""
     if isinstance(session.scheduler, TopologicalSorter):
-        for name in session.scheduler.static_order():
-            task = session.dag.nodes[name]["task"]
+        while session.scheduler.is_active():
+            task_name = session.scheduler.get_ready()[0]
+            task = session.dag.nodes[task_name]["task"]
             report = session.hook.pytask_execute_task_protocol(
                 session=session, task=task
             )
             session.execution_reports.append(report)
+            session.scheduler.done(task_name)
 
             if session.should_stop:
                 return True
diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py
index f6c685dd..158322fa 100644
--- a/tests/test_dag_utils.py
+++ b/tests/test_dag_utils.py
@@ -16,6 +16,7 @@
 
 @pytest.fixture()
 def dag():
+    """Create a dag with five nodes in a line."""
     dag = nx.DiGraph()
     for i in range(4):
         task = Task(base_name=str(i), path=Path(), function=None)
@@ -29,7 +30,12 @@ def dag():
 
 @pytest.mark.unit()
 def test_sort_tasks_topologically(dag):
-    topo_ordering = list(TopologicalSorter.from_dag(dag).static_order())
+    sorter = TopologicalSorter.from_dag(dag)
+    topo_ordering = []
+    while sorter.is_active():
+        task_name = sorter.get_ready()[0]
+        topo_ordering.append(task_name)
+        sorter.done(task_name)
     topo_names = [dag.nodes[sig]["task"].name for sig in topo_ordering]
     assert topo_names == [f".::{i}" for i in range(5)]
 
@@ -166,37 +172,33 @@ def test_raise_error_for_cycle_in_graph(dag):
         "115f685b0af2aef0c7317a0b48562f34cfb7a622549562bd3d34d4d948b4fdab",
         "55c6cef62d3e62d5f8fc65bb846e66d8d0d3ca60608c04f6f7b095ea073a7dcf",
     )
-    scheduler = TopologicalSorter.from_dag(dag)
     with pytest.raises(ValueError, match="The DAG contains cycles."):
-        scheduler.prepare()
-
-
-@pytest.mark.unit()
-def test_raise_if_topological_sorter_is_not_prepared(dag):
-    scheduler = TopologicalSorter.from_dag(dag)
-    with pytest.raises(ValueError, match="The TopologicalSorter needs to be prepared."):
-        scheduler.get_ready(1)
+        TopologicalSorter.from_dag(dag)
 
 
 @pytest.mark.unit()
 def test_ask_for_invalid_number_of_ready_tasks(dag):
     scheduler = TopologicalSorter.from_dag(dag)
-    scheduler.prepare()
     with pytest.raises(ValueError, match="'n' must be"):
         scheduler.get_ready(0)
 
 
 @pytest.mark.unit()
-def test_reset_topological_sorter(dag):
-    scheduler = TopologicalSorter.from_dag(dag)
-    scheduler.prepare()
-    name = scheduler.get_ready()[0]
-    scheduler.done(name)
-
-    assert scheduler._is_prepared
-    assert name not in scheduler.dag.nodes
+def test_instantiate_sorter_from_other_sorter(dag):
+    name_to_sig = {dag.nodes[sig]["task"].name: sig for sig in dag.nodes}
 
-    scheduler.reset()
-
-    assert not scheduler._is_prepared
-    assert name in scheduler.dag.nodes
+    scheduler = TopologicalSorter.from_dag(dag)
+    for _ in range(2):
+        task_name = scheduler.get_ready()[0]
+        scheduler.done(task_name)
+    assert scheduler._nodes_done == {name_to_sig[name] for name in (".::0", ".::1")}
+
+    task = Task(base_name="5", path=Path(), function=None)
+    dag.add_node(task.signature, task=task)
+    dag.add_edge(name_to_sig[".::4"], task.signature)
+
+    new_scheduler = TopologicalSorter.from_dag_and_sorter(dag, scheduler)
+    while new_scheduler.is_active():
+        task_name = new_scheduler.get_ready()[0]
+        new_scheduler.done(task_name)
+    assert new_scheduler._nodes_done == set(name_to_sig.values()) | {task.signature}