Skip to content

Commit 5a15958

Browse files
authored
Merge cb5867d into e2a4592
2 parents e2a4592 + cb5867d commit 5a15958

File tree

4 files changed

+69
-63
lines changed

4 files changed

+69
-63
lines changed

src/_pytask/dag.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,8 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
107107
scheduler = TopologicalSorter.from_dag(dag)
108108
visited_nodes: set[str] = set()
109109

110-
for task_signature in scheduler.static_order():
110+
while scheduler.is_active():
111+
task_signature = scheduler.get_ready()[0]
111112
if task_signature not in visited_nodes:
112113
task = dag.nodes[task_signature]["task"]
113114
have_changed = _have_task_or_neighbors_changed(session, dag, task)
@@ -117,6 +118,7 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
117118
dag.nodes[task_signature]["task"].markers.append(
118119
Mark("skip_unchanged", (), {})
119120
)
121+
scheduler.done(task_signature)
120122

121123

122124
def _have_task_or_neighbors_changed(

src/_pytask/dag_utils.py

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -62,22 +62,27 @@ def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]:
6262
class TopologicalSorter:
6363
"""The topological sorter class.
6464
65-
This class allows to perform a topological sort
65+
This class allows performing a topological sort.
66+
67+
Attributes
68+
----------
69+
dag
70+
Not the full DAG, but a reduced version that only considers tasks.
71+
priorities
72+
A dictionary of task names to a priority value. 1 for try first, 0 for the
73+
default priority, and -1 for try last.
6674
6775
"""
6876

6977
dag: nx.DiGraph
70-
dag_backup: nx.DiGraph
7178
priorities: dict[str, int] = field(factory=dict)
72-
_is_prepared: bool = False
73-
_nodes_out: set[str] = field(factory=set)
79+
_nodes_processing: set[str] = field(factory=set)
80+
_nodes_done: set[str] = field(factory=set)
7481

7582
@classmethod
7683
def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:
7784
"""Instantiate from a DAG."""
78-
if not dag.is_directed():
79-
msg = "Only directed graphs have a topological order."
80-
raise ValueError(msg)
85+
cls.check_dag(dag)
8186

8287
tasks = [
8388
dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node]
@@ -90,35 +95,46 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:
9095
}
9196
task_dag = nx.DiGraph(task_dict).reverse()
9297

93-
return cls(dag=task_dag, priorities=priorities, dag_backup=task_dag.copy())
98+
return cls(dag=task_dag, priorities=priorities)
99+
100+
@classmethod
101+
def from_dag_and_sorter(
102+
cls, dag: nx.DiGraph, sorter: TopologicalSorter
103+
) -> TopologicalSorter:
104+
"""Instantiate a sorter from another sorter and a DAG."""
105+
new_sorter = cls.from_dag(dag)
106+
new_sorter.done(*sorter._nodes_done)
107+
new_sorter._nodes_processing = sorter._nodes_processing
108+
return new_sorter
109+
110+
@staticmethod
111+
def check_dag(dag: nx.DiGraph) -> None:
112+
if not dag.is_directed():
113+
msg = "Only directed graphs have a topological order."
114+
raise ValueError(msg)
94115

95-
def prepare(self) -> None:
96-
"""Perform some checks before creating a topological ordering."""
97116
try:
98-
nx.algorithms.cycles.find_cycle(self.dag)
117+
nx.algorithms.cycles.find_cycle(dag)
99118
except nx.NetworkXNoCycle:
100119
pass
101120
else:
102121
msg = "The DAG contains cycles."
103122
raise ValueError(msg)
104123

105-
self._is_prepared = True
106-
107124
def get_ready(self, n: int = 1) -> list[str]:
108125
"""Get up to ``n`` tasks which are ready."""
109-
if not self._is_prepared:
110-
msg = "The TopologicalSorter needs to be prepared."
111-
raise ValueError(msg)
112126
if not isinstance(n, int) or n < 1:
113127
msg = "'n' must be an integer greater or equal than 1."
114128
raise ValueError(msg)
115129

116-
ready_nodes = {v for v, d in self.dag.in_degree() if d == 0} - self._nodes_out
130+
ready_nodes = {
131+
v for v, d in self.dag.in_degree() if d == 0
132+
} - self._nodes_processing
117133
prioritized_nodes = sorted(
118134
ready_nodes, key=lambda x: self.priorities.get(x, 0)
119135
)[-n:]
120136

121-
self._nodes_out.update(prioritized_nodes)
137+
self._nodes_processing.update(prioritized_nodes)
122138

123139
return prioritized_nodes
124140

@@ -128,23 +144,9 @@ def is_active(self) -> bool:
128144

129145
def done(self, *nodes: str) -> None:
130146
"""Mark some tasks as done."""
131-
self._nodes_out = self._nodes_out - set(nodes)
147+
self._nodes_processing = self._nodes_processing - set(nodes)
132148
self.dag.remove_nodes_from(nodes)
133-
134-
def reset(self) -> None:
135-
"""Reset an exhausted topological sorter."""
136-
if self.dag_backup:
137-
self.dag = self.dag_backup.copy()
138-
self._is_prepared = False
139-
self._nodes_out = set()
140-
141-
def static_order(self) -> Generator[str, None, None]:
142-
"""Return a topological order of tasks as an iterable."""
143-
self.prepare()
144-
while self.is_active():
145-
new_task = self.get_ready()[0]
146-
yield new_task
147-
self.done(new_task)
149+
self._nodes_done.update(nodes)
148150

149151

150152
def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:

src/_pytask/execute.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,21 +72,21 @@ def pytask_execute_log_start(session: Session) -> None:
7272
@hookimpl(trylast=True)
7373
def pytask_execute_create_scheduler(session: Session) -> TopologicalSorter:
7474
"""Create a scheduler based on topological sorting."""
75-
scheduler = TopologicalSorter.from_dag(session.dag)
76-
scheduler.prepare()
77-
return scheduler
75+
return TopologicalSorter.from_dag(session.dag)
7876

7977

8078
@hookimpl
8179
def pytask_execute_build(session: Session) -> bool | None:
8280
"""Execute tasks."""
8381
if isinstance(session.scheduler, TopologicalSorter):
84-
for name in session.scheduler.static_order():
85-
task = session.dag.nodes[name]["task"]
82+
while session.scheduler.is_active():
83+
task_name = session.scheduler.get_ready()[0]
84+
task = session.dag.nodes[task_name]["task"]
8685
report = session.hook.pytask_execute_task_protocol(
8786
session=session, task=task
8887
)
8988
session.execution_reports.append(report)
89+
session.scheduler.done(task_name)
9090

9191
if session.should_stop:
9292
return True

tests/test_dag_utils.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
@pytest.fixture()
1818
def dag():
19+
"""Create a dag with five nodes in a line."""
1920
dag = nx.DiGraph()
2021
for i in range(4):
2122
task = Task(base_name=str(i), path=Path(), function=None)
@@ -29,7 +30,12 @@ def dag():
2930

3031
@pytest.mark.unit()
3132
def test_sort_tasks_topologically(dag):
32-
topo_ordering = list(TopologicalSorter.from_dag(dag).static_order())
33+
sorter = TopologicalSorter.from_dag(dag)
34+
topo_ordering = []
35+
while sorter.is_active():
36+
task_name = sorter.get_ready()[0]
37+
topo_ordering.append(task_name)
38+
sorter.done(task_name)
3339
topo_names = [dag.nodes[sig]["task"].name for sig in topo_ordering]
3440
assert topo_names == [f".::{i}" for i in range(5)]
3541

@@ -166,37 +172,33 @@ def test_raise_error_for_cycle_in_graph(dag):
166172
"115f685b0af2aef0c7317a0b48562f34cfb7a622549562bd3d34d4d948b4fdab",
167173
"55c6cef62d3e62d5f8fc65bb846e66d8d0d3ca60608c04f6f7b095ea073a7dcf",
168174
)
169-
scheduler = TopologicalSorter.from_dag(dag)
170175
with pytest.raises(ValueError, match="The DAG contains cycles."):
171-
scheduler.prepare()
172-
173-
174-
@pytest.mark.unit()
175-
def test_raise_if_topological_sorter_is_not_prepared(dag):
176-
scheduler = TopologicalSorter.from_dag(dag)
177-
with pytest.raises(ValueError, match="The TopologicalSorter needs to be prepared."):
178-
scheduler.get_ready(1)
176+
TopologicalSorter.from_dag(dag)
179177

180178

181179
@pytest.mark.unit()
182180
def test_ask_for_invalid_number_of_ready_tasks(dag):
183181
scheduler = TopologicalSorter.from_dag(dag)
184-
scheduler.prepare()
185182
with pytest.raises(ValueError, match="'n' must be"):
186183
scheduler.get_ready(0)
187184

188185

189186
@pytest.mark.unit()
190-
def test_reset_topological_sorter(dag):
191-
scheduler = TopologicalSorter.from_dag(dag)
192-
scheduler.prepare()
193-
name = scheduler.get_ready()[0]
194-
scheduler.done(name)
195-
196-
assert scheduler._is_prepared
197-
assert name not in scheduler.dag.nodes
187+
def test_instantiate_sorter_from_other_sorter(dag):
188+
name_to_sig = {dag.nodes[sig]["task"].name: sig for sig in dag.nodes}
198189

199-
scheduler.reset()
200-
201-
assert not scheduler._is_prepared
202-
assert name in scheduler.dag.nodes
190+
scheduler = TopologicalSorter.from_dag(dag)
191+
for _ in range(2):
192+
task_name = scheduler.get_ready()[0]
193+
scheduler.done(task_name)
194+
assert scheduler._nodes_done == {name_to_sig[name] for name in (".::0", ".::1")}
195+
196+
task = Task(base_name="5", path=Path(), function=None)
197+
dag.add_node(task.signature, task=Task(base_name="5", path=Path(), function=None))
198+
dag.add_edge(name_to_sig[".::4"], task.signature)
199+
200+
new_scheduler = TopologicalSorter.from_dag_and_sorter(dag, scheduler)
201+
while new_scheduler.is_active():
202+
task_name = new_scheduler.get_ready()[0]
203+
new_scheduler.done(task_name)
204+
assert new_scheduler._nodes_done == set(name_to_sig.values()) | {task.signature}

0 commit comments

Comments
 (0)