Skip to content

Modernize TopologicalSorter. #458

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/_pytask/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
scheduler = TopologicalSorter.from_dag(dag)
visited_nodes: set[str] = set()

for task_signature in scheduler.static_order():
while scheduler.is_active():
task_signature = scheduler.get_ready()[0]
if task_signature not in visited_nodes:
task = dag.nodes[task_signature]["task"]
have_changed = _have_task_or_neighbors_changed(session, dag, task)
Expand All @@ -117,6 +118,7 @@ def pytask_dag_select_execution_dag(session: Session, dag: nx.DiGraph) -> None:
dag.nodes[task_signature]["task"].markers.append(
Mark("skip_unchanged", (), {})
)
scheduler.done(task_signature)


def _have_task_or_neighbors_changed(
Expand Down
70 changes: 36 additions & 34 deletions src/_pytask/dag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,22 +62,27 @@ def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]:
class TopologicalSorter:
"""The topological sorter class.

This class allows to perform a topological sort
This class allows to perform a topological sort.

Attributes
----------
dag
Not the full DAG, but a reduced version that only considers tasks.
priorities
A dictionary mapping task names to a priority value: 1 for try first, 0 for the
default priority, and -1 for try last.

"""

dag: nx.DiGraph
dag_backup: nx.DiGraph
priorities: dict[str, int] = field(factory=dict)
_is_prepared: bool = False
_nodes_out: set[str] = field(factory=set)
_nodes_processing: set[str] = field(factory=set)
_nodes_done: set[str] = field(factory=set)

@classmethod
def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:
"""Instantiate from a DAG."""
if not dag.is_directed():
msg = "Only directed graphs have a topological order."
raise ValueError(msg)
cls.check_dag(dag)

tasks = [
dag.nodes[node]["task"] for node in dag.nodes if "task" in dag.nodes[node]
Expand All @@ -90,35 +95,46 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:
}
task_dag = nx.DiGraph(task_dict).reverse()

return cls(dag=task_dag, priorities=priorities, dag_backup=task_dag.copy())
return cls(dag=task_dag, priorities=priorities)

@classmethod
def from_dag_and_sorter(
cls, dag: nx.DiGraph, sorter: TopologicalSorter
) -> TopologicalSorter:
"""Instantiate a sorter from another sorter and a DAG."""
new_sorter = cls.from_dag(dag)
new_sorter.done(*sorter._nodes_done)
new_sorter._nodes_processing = sorter._nodes_processing
return new_sorter

@staticmethod
def check_dag(dag: nx.DiGraph) -> None:
if not dag.is_directed():
msg = "Only directed graphs have a topological order."
raise ValueError(msg)

def prepare(self) -> None:
"""Perform some checks before creating a topological ordering."""
try:
nx.algorithms.cycles.find_cycle(self.dag)
nx.algorithms.cycles.find_cycle(dag)
except nx.NetworkXNoCycle:
pass
else:
msg = "The DAG contains cycles."
raise ValueError(msg)

self._is_prepared = True

def get_ready(self, n: int = 1) -> list[str]:
"""Get up to ``n`` tasks which are ready."""
if not self._is_prepared:
msg = "The TopologicalSorter needs to be prepared."
raise ValueError(msg)
if not isinstance(n, int) or n < 1:
msg = "'n' must be an integer greater or equal than 1."
raise ValueError(msg)

ready_nodes = {v for v, d in self.dag.in_degree() if d == 0} - self._nodes_out
ready_nodes = {
v for v, d in self.dag.in_degree() if d == 0
} - self._nodes_processing
prioritized_nodes = sorted(
ready_nodes, key=lambda x: self.priorities.get(x, 0)
)[-n:]

self._nodes_out.update(prioritized_nodes)
self._nodes_processing.update(prioritized_nodes)

return prioritized_nodes

Expand All @@ -128,23 +144,9 @@ def is_active(self) -> bool:

def done(self, *nodes: str) -> None:
"""Mark some tasks as done."""
self._nodes_out = self._nodes_out - set(nodes)
self._nodes_processing = self._nodes_processing - set(nodes)
self.dag.remove_nodes_from(nodes)

def reset(self) -> None:
"""Reset an exhausted topological sorter."""
if self.dag_backup:
self.dag = self.dag_backup.copy()
self._is_prepared = False
self._nodes_out = set()

def static_order(self) -> Generator[str, None, None]:
"""Return a topological order of tasks as an iterable."""
self.prepare()
while self.is_active():
new_task = self.get_ready()[0]
yield new_task
self.done(new_task)
self._nodes_done.update(nodes)


def _extract_priorities_from_tasks(tasks: list[PTask]) -> dict[str, int]:
Expand Down
10 changes: 5 additions & 5 deletions src/_pytask/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,21 @@ def pytask_execute_log_start(session: Session) -> None:
@hookimpl(trylast=True)
def pytask_execute_create_scheduler(session: Session) -> TopologicalSorter:
"""Create a scheduler based on topological sorting."""
scheduler = TopologicalSorter.from_dag(session.dag)
scheduler.prepare()
return scheduler
return TopologicalSorter.from_dag(session.dag)


@hookimpl
def pytask_execute_build(session: Session) -> bool | None:
"""Execute tasks."""
if isinstance(session.scheduler, TopologicalSorter):
for name in session.scheduler.static_order():
task = session.dag.nodes[name]["task"]
while session.scheduler.is_active():
task_name = session.scheduler.get_ready()[0]
task = session.dag.nodes[task_name]["task"]
report = session.hook.pytask_execute_task_protocol(
session=session, task=task
)
session.execution_reports.append(report)
session.scheduler.done(task_name)

if session.should_stop:
return True
Expand Down
48 changes: 25 additions & 23 deletions tests/test_dag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

@pytest.fixture()
def dag():
"""Create a dag with five nodes in a line."""
dag = nx.DiGraph()
for i in range(4):
task = Task(base_name=str(i), path=Path(), function=None)
Expand All @@ -29,7 +30,12 @@ def dag():

@pytest.mark.unit()
def test_sort_tasks_topologically(dag):
topo_ordering = list(TopologicalSorter.from_dag(dag).static_order())
sorter = TopologicalSorter.from_dag(dag)
topo_ordering = []
while sorter.is_active():
task_name = sorter.get_ready()[0]
topo_ordering.append(task_name)
sorter.done(task_name)
topo_names = [dag.nodes[sig]["task"].name for sig in topo_ordering]
assert topo_names == [f".::{i}" for i in range(5)]

Expand Down Expand Up @@ -166,37 +172,33 @@ def test_raise_error_for_cycle_in_graph(dag):
"115f685b0af2aef0c7317a0b48562f34cfb7a622549562bd3d34d4d948b4fdab",
"55c6cef62d3e62d5f8fc65bb846e66d8d0d3ca60608c04f6f7b095ea073a7dcf",
)
scheduler = TopologicalSorter.from_dag(dag)
with pytest.raises(ValueError, match="The DAG contains cycles."):
scheduler.prepare()


@pytest.mark.unit()
def test_raise_if_topological_sorter_is_not_prepared(dag):
scheduler = TopologicalSorter.from_dag(dag)
with pytest.raises(ValueError, match="The TopologicalSorter needs to be prepared."):
scheduler.get_ready(1)
TopologicalSorter.from_dag(dag)


@pytest.mark.unit()
def test_ask_for_invalid_number_of_ready_tasks(dag):
scheduler = TopologicalSorter.from_dag(dag)
scheduler.prepare()
with pytest.raises(ValueError, match="'n' must be"):
scheduler.get_ready(0)


@pytest.mark.unit()
def test_reset_topological_sorter(dag):
scheduler = TopologicalSorter.from_dag(dag)
scheduler.prepare()
name = scheduler.get_ready()[0]
scheduler.done(name)

assert scheduler._is_prepared
assert name not in scheduler.dag.nodes
def test_instantiate_sorter_from_other_sorter(dag):
name_to_sig = {dag.nodes[sig]["task"].name: sig for sig in dag.nodes}

scheduler.reset()

assert not scheduler._is_prepared
assert name in scheduler.dag.nodes
scheduler = TopologicalSorter.from_dag(dag)
for _ in range(2):
task_name = scheduler.get_ready()[0]
scheduler.done(task_name)
assert scheduler._nodes_done == {name_to_sig[name] for name in (".::0", ".::1")}

task = Task(base_name="5", path=Path(), function=None)
dag.add_node(task.signature, task=Task(base_name="5", path=Path(), function=None))
dag.add_edge(name_to_sig[".::4"], task.signature)

new_scheduler = TopologicalSorter.from_dag_and_sorter(dag, scheduler)
while new_scheduler.is_active():
task_name = new_scheduler.get_ready()[0]
new_scheduler.done(task_name)
assert new_scheduler._nodes_done == set(name_to_sig.values()) | {task.signature}