From 2cd96d761a05ad2fba7fa3bb78bc9bb631844208 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 21 Oct 2023 17:28:31 +0200 Subject: [PATCH 1/4] Simplify TopologicalSorter. --- src/_pytask/dag_utils.py | 28 ++++++++-------------------- src/_pytask/execute.py | 4 +--- tests/test_dag_utils.py | 27 +-------------------------- 3 files changed, 10 insertions(+), 49 deletions(-) diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 650cf75c..3c74c7fd 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -69,15 +69,12 @@ class TopologicalSorter: dag: nx.DiGraph dag_backup: nx.DiGraph priorities: dict[str, int] = field(factory=dict) - _is_prepared: bool = False _nodes_out: set[str] = field(factory=set) @classmethod def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter: """Instantiate from a DAG.""" - if not dag.is_directed(): - msg = "Only directed graphs have a topological order." - raise ValueError(msg) + cls.check_dag(dag) tasks = [ dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node] @@ -90,23 +87,22 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter: return cls(dag=task_dag, priorities=priorities, dag_backup=task_dag.copy()) - def prepare(self) -> None: - """Perform some checks before creating a topological ordering.""" + @staticmethod + def check_dag(dag: nx.DiGraph) -> None: + if not dag.is_directed(): + msg = "Only directed graphs have a topological order." + raise ValueError(msg) + try: - nx.algorithms.cycles.find_cycle(self.dag) + nx.algorithms.cycles.find_cycle(dag) except nx.NetworkXNoCycle: pass else: msg = "The DAG contains cycles." raise ValueError(msg) - self._is_prepared = True - def get_ready(self, n: int = 1) -> list[str]: """Get up to ``n`` tasks which are ready.""" - if not self._is_prepared: - msg = "The TopologicalSorter needs to be prepared." - raise ValueError(msg) if not isinstance(n, int) or n < 1: msg = "'n' must be an integer greater or equal than 1." 
raise ValueError(msg) @@ -129,16 +125,8 @@ def done(self, *nodes: str) -> None: self._nodes_out = self._nodes_out - set(nodes) self.dag.remove_nodes_from(nodes) - def reset(self) -> None: - """Reset an exhausted topological sorter.""" - if self.dag_backup: - self.dag = self.dag_backup.copy() - self._is_prepared = False - self._nodes_out = set() - def static_order(self) -> Generator[str, None, None]: """Return a topological order of tasks as an iterable.""" - self.prepare() while self.is_active(): new_task = self.get_ready()[0] yield new_task diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index e468340f..5ba15aac 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -75,9 +75,7 @@ def pytask_execute_log_start(session: Session) -> None: @hookimpl(trylast=True) def pytask_execute_create_scheduler(session: Session) -> TopologicalSorter: """Create a scheduler based on topological sorting.""" - scheduler = TopologicalSorter.from_dag(session.dag) - scheduler.prepare() - return scheduler + return TopologicalSorter.from_dag(session.dag) @hookimpl diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index 3b25a8f0..aa675daa 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -140,37 +140,12 @@ def test_raise_error_for_undirected_graphs(dag): @pytest.mark.unit() def test_raise_error_for_cycle_in_graph(dag): dag.add_edge(".::4", ".::1") - scheduler = TopologicalSorter.from_dag(dag) with pytest.raises(ValueError, match="The DAG contains cycles."): - scheduler.prepare() - - -@pytest.mark.unit() -def test_raise_if_topological_sorter_is_not_prepared(dag): - scheduler = TopologicalSorter.from_dag(dag) - with pytest.raises(ValueError, match="The TopologicalSorter needs to be prepared."): - scheduler.get_ready(1) + TopologicalSorter.from_dag(dag) @pytest.mark.unit() def test_ask_for_invalid_number_of_ready_tasks(dag): scheduler = TopologicalSorter.from_dag(dag) - scheduler.prepare() with pytest.raises(ValueError, match="'n' must 
be"): scheduler.get_ready(0) - - -@pytest.mark.unit() -def test_reset_topological_sorter(dag): - scheduler = TopologicalSorter.from_dag(dag) - scheduler.prepare() - name = scheduler.get_ready()[0] - scheduler.done(name) - - assert scheduler._is_prepared - assert name not in scheduler.dag.nodes - - scheduler.reset() - - assert not scheduler._is_prepared - assert name in scheduler.dag.nodes From 6f2aaef48a175a3b59d93bedbeaee98f6eedfe43 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 21 Oct 2023 17:50:28 +0200 Subject: [PATCH 2/4] Remove static_order. --- src/_pytask/dag.py | 4 +++- src/_pytask/dag_utils.py | 7 ------- src/_pytask/execute.py | 6 ++++-- tests/test_dag_utils.py | 7 ++++++- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/_pytask/dag.py b/src/_pytask/dag.py index b67ec49e..43d537fe 100644 --- a/src/_pytask/dag.py +++ b/src/_pytask/dag.py @@ -122,7 +122,8 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None: scheduler = TopologicalSorter.from_dag(dag) visited_nodes: set[str] = set() - for task_name in scheduler.static_order(): + while scheduler.is_active(): + task_name = scheduler.get_ready()[0] if task_name not in visited_nodes: task = dag.nodes[task_name]["task"] have_changed = _have_task_or_neighbors_changed(session, dag, task) @@ -132,6 +133,7 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None: dag.nodes[task_name]["task"].markers.append( Mark("skip_unchanged", (), {}) ) + scheduler.done(task_name) @hookimpl diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 3c74c7fd..3218bfeb 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -125,13 +125,6 @@ def done(self, *nodes: str) -> None: self._nodes_out = self._nodes_out - set(nodes) self.dag.remove_nodes_from(nodes) - def static_order(self) -> Generator[str, None, None]: - """Return a topological order of tasks as an iterable.""" - while self.is_active(): - new_task = 
self.get_ready()[0] - yield new_task - self.done(new_task) - def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: """Extract priorities from tasks. diff --git a/src/_pytask/execute.py b/src/_pytask/execute.py index 5ba15aac..b2ad6bf8 100644 --- a/src/_pytask/execute.py +++ b/src/_pytask/execute.py @@ -82,12 +82,14 @@ def pytask_execute_create_scheduler(session: Session) -> TopologicalSorter: def pytask_execute_build(session: Session) -> bool | None: """Execute tasks.""" if isinstance(session.scheduler, TopologicalSorter): - for name in session.scheduler.static_order(): - task = session.dag.nodes[name]["task"] + while session.scheduler.is_active(): + task_name = session.scheduler.get_ready()[0] + task = session.dag.nodes[task_name]["task"] report = session.hook.pytask_execute_task_protocol( session=session, task=task ) session.execution_reports.append(report) + session.scheduler.done(task_name) if session.should_stop: return True diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index aa675daa..8512f394 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -29,7 +29,12 @@ def dag(): @pytest.mark.unit() def test_sort_tasks_topologically(dag): - topo_ordering = list(TopologicalSorter.from_dag(dag).static_order()) + dag = TopologicalSorter.from_dag(dag) + topo_ordering = [] + while dag.is_active(): + task_name = dag.get_ready()[0] + topo_ordering.append(task_name) + dag.done(task_name) assert topo_ordering == [f".::{i}" for i in range(5)] From e37abd6bf8a272ae143b9e0894c2f702d54d6f65 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sat, 21 Oct 2023 17:52:33 +0200 Subject: [PATCH 3/4] Remove backup. 
--- src/_pytask/dag_utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 3218bfeb..29e5748e 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -67,7 +67,6 @@ class TopologicalSorter: """ dag: nx.DiGraph - dag_backup: nx.DiGraph priorities: dict[str, int] = field(factory=dict) _nodes_out: set[str] = field(factory=set) @@ -85,7 +84,7 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter: task_dict = {name: nx.ancestors(dag, name) & task_names for name in task_names} task_dag = nx.DiGraph(task_dict).reverse() - return cls(dag=task_dag, priorities=priorities, dag_backup=task_dag.copy()) + return cls(dag=task_dag, priorities=priorities) @staticmethod def check_dag(dag: nx.DiGraph) -> None: From 789e0d8e438f57e7352373379b822c8d67f300b1 Mon Sep 17 00:00:00 2001 From: Tobias Raabe Date: Sun, 22 Oct 2023 14:01:06 +0200 Subject: [PATCH 4/4] Add instantiation from another sorter and a DAG. --- src/_pytask/dag_utils.py | 32 +++++++++++++++++++++++++++----- tests/test_dag_utils.py | 19 +++++++++++++++++++ 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/src/_pytask/dag_utils.py b/src/_pytask/dag_utils.py index 29e5748e..83bb1464 100644 --- a/src/_pytask/dag_utils.py +++ b/src/_pytask/dag_utils.py @@ -62,13 +62,22 @@ def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]: class TopologicalSorter: """The topological sorter class. - This class allows to perform a topological sort + This class allows to perform a topological sort. + + Attributes + ---------- + dag + Not the full DAG, but a reduced version that only considers tasks. + priorities + A dictionary of task names to a priority value. 1 for try first, 0 for the + default priority, and -1 for try last. 
""" dag: nx.DiGraph priorities: dict[str, int] = field(factory=dict) - _nodes_out: set[str] = field(factory=set) + _nodes_processing: set[str] = field(factory=set) + _nodes_done: set[str] = field(factory=set) @classmethod def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter: @@ -86,6 +95,16 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter: return cls(dag=task_dag, priorities=priorities) + @classmethod + def from_dag_and_sorter( + cls, dag: nx.DiGraph, sorter: TopologicalSorter + ) -> TopologicalSorter: + """Instantiate a sorter from another sorter and a DAG.""" + new_sorter = cls.from_dag(dag) + new_sorter.done(*sorter._nodes_done) + new_sorter._nodes_processing = sorter._nodes_processing + return new_sorter + @staticmethod def check_dag(dag: nx.DiGraph) -> None: if not dag.is_directed(): @@ -106,12 +125,14 @@ def get_ready(self, n: int = 1) -> list[str]: msg = "'n' must be an integer greater or equal than 1." raise ValueError(msg) - ready_nodes = {v for v, d in self.dag.in_degree() if d == 0} - self._nodes_out + ready_nodes = { + v for v, d in self.dag.in_degree() if d == 0 + } - self._nodes_processing prioritized_nodes = sorted( ready_nodes, key=lambda x: self.priorities.get(x, 0) )[-n:] - self._nodes_out.update(prioritized_nodes) + self._nodes_processing.update(prioritized_nodes) return prioritized_nodes @@ -121,8 +142,9 @@ def is_active(self) -> bool: def done(self, *nodes: str) -> None: """Mark some tasks as done.""" - self._nodes_out = self._nodes_out - set(nodes) + self._nodes_processing = self._nodes_processing - set(nodes) self.dag.remove_nodes_from(nodes) + self._nodes_done.update(nodes) def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]: diff --git a/tests/test_dag_utils.py b/tests/test_dag_utils.py index 8512f394..9b840957 100644 --- a/tests/test_dag_utils.py +++ b/tests/test_dag_utils.py @@ -16,6 +16,7 @@ @pytest.fixture() def dag(): + """Create a dag with five nodes in a line.""" dag = nx.DiGraph() for i in range(4): 
dag.add_node(f".::{i}", task=Task(base_name=str(i), path=Path(), function=None)) @@ -154,3 +155,21 @@ def test_ask_for_invalid_number_of_ready_tasks(dag): scheduler = TopologicalSorter.from_dag(dag) with pytest.raises(ValueError, match="'n' must be"): scheduler.get_ready(0) + + +@pytest.mark.unit() +def test_instantiate_sorter_from_other_sorter(dag): + scheduler = TopologicalSorter.from_dag(dag) + for _ in range(2): + task_name = scheduler.get_ready()[0] + scheduler.done(task_name) + assert scheduler._nodes_done == {".::0", ".::1"} + + dag.add_node(".::5", task=Task(base_name="5", path=Path(), function=None)) + dag.add_edge(".::4", ".::5") + + new_scheduler = TopologicalSorter.from_dag_and_sorter(dag, scheduler) + while new_scheduler.is_active(): + task_name = new_scheduler.get_ready()[0] + new_scheduler.done(task_name) + assert new_scheduler._nodes_done == {".::0", ".::1", ".::2", ".::3", ".::4", ".::5"}