@@ -62,22 +62,27 @@ def node_and_neighbors(dag: nx.DiGraph, node: str) -> Iterable[str]:
62
62
class TopologicalSorter :
63
63
"""The topological sorter class.
64
64
65
- This class allows to perform a topological sort
65
+ This class allows to perform a topological sort#
66
+
67
+ Attributes
68
+ ----------
69
+ dag
70
+ Not the full DAG, but a reduced version that only considers tasks.
71
+ priorities
72
+ A dictionary of task names to a priority value. 1 for try first, 0 for the
73
+ default priority and, -1 for try last.
66
74
67
75
"""
68
76
69
77
dag : nx .DiGraph
70
- dag_backup : nx .DiGraph
71
78
priorities : dict [str , int ] = field (factory = dict )
72
- _is_prepared : bool = False
73
- _nodes_out : set [str ] = field (factory = set )
79
+ _nodes_processing : set [ str ] = field ( factory = set )
80
+ _nodes_done : set [str ] = field (factory = set )
74
81
75
82
@classmethod
76
83
def from_dag (cls , dag : nx .DiGraph ) -> TopologicalSorter :
77
84
"""Instantiate from a DAG."""
78
- if not dag .is_directed ():
79
- msg = "Only directed graphs have a topological order."
80
- raise ValueError (msg )
85
+ cls .check_dag (dag )
81
86
82
87
tasks = [
83
88
dag .nodes [node ]["task" ] for node in dag .nodes if "task" in dag .nodes [node ]
@@ -90,35 +95,46 @@ def from_dag(cls, dag: nx.DiGraph) -> TopologicalSorter:
90
95
}
91
96
task_dag = nx .DiGraph (task_dict ).reverse ()
92
97
93
- return cls (dag = task_dag , priorities = priorities , dag_backup = task_dag .copy ())
98
+ return cls (dag = task_dag , priorities = priorities )
99
+
100
+ @classmethod
101
+ def from_dag_and_sorter (
102
+ cls , dag : nx .DiGraph , sorter : TopologicalSorter
103
+ ) -> TopologicalSorter :
104
+ """Instantiate a sorter from another sorter and a DAG."""
105
+ new_sorter = cls .from_dag (dag )
106
+ new_sorter .done (* sorter ._nodes_done )
107
+ new_sorter ._nodes_processing = sorter ._nodes_processing
108
+ return new_sorter
109
+
110
+ @staticmethod
111
+ def check_dag (dag : nx .DiGraph ) -> None :
112
+ if not dag .is_directed ():
113
+ msg = "Only directed graphs have a topological order."
114
+ raise ValueError (msg )
94
115
95
- def prepare (self ) -> None :
96
- """Perform some checks before creating a topological ordering."""
97
116
try :
98
- nx .algorithms .cycles .find_cycle (self . dag )
117
+ nx .algorithms .cycles .find_cycle (dag )
99
118
except nx .NetworkXNoCycle :
100
119
pass
101
120
else :
102
121
msg = "The DAG contains cycles."
103
122
raise ValueError (msg )
104
123
105
- self ._is_prepared = True
106
-
107
124
def get_ready (self , n : int = 1 ) -> list [str ]:
108
125
"""Get up to ``n`` tasks which are ready."""
109
- if not self ._is_prepared :
110
- msg = "The TopologicalSorter needs to be prepared."
111
- raise ValueError (msg )
112
126
if not isinstance (n , int ) or n < 1 :
113
127
msg = "'n' must be an integer greater or equal than 1."
114
128
raise ValueError (msg )
115
129
116
- ready_nodes = {v for v , d in self .dag .in_degree () if d == 0 } - self ._nodes_out
130
+ ready_nodes = {
131
+ v for v , d in self .dag .in_degree () if d == 0
132
+ } - self ._nodes_processing
117
133
prioritized_nodes = sorted (
118
134
ready_nodes , key = lambda x : self .priorities .get (x , 0 )
119
135
)[- n :]
120
136
121
- self ._nodes_out .update (prioritized_nodes )
137
+ self ._nodes_processing .update (prioritized_nodes )
122
138
123
139
return prioritized_nodes
124
140
@@ -128,23 +144,9 @@ def is_active(self) -> bool:
128
144
129
145
def done (self , * nodes : str ) -> None :
130
146
"""Mark some tasks as done."""
131
- self ._nodes_out = self ._nodes_out - set (nodes )
147
+ self ._nodes_processing = self ._nodes_processing - set (nodes )
132
148
self .dag .remove_nodes_from (nodes )
133
-
134
- def reset (self ) -> None :
135
- """Reset an exhausted topological sorter."""
136
- if self .dag_backup :
137
- self .dag = self .dag_backup .copy ()
138
- self ._is_prepared = False
139
- self ._nodes_out = set ()
140
-
141
- def static_order (self ) -> Generator [str , None , None ]:
142
- """Return a topological order of tasks as an iterable."""
143
- self .prepare ()
144
- while self .is_active ():
145
- new_task = self .get_ready ()[0 ]
146
- yield new_task
147
- self .done (new_task )
149
+ self ._nodes_done .update (nodes )
148
150
149
151
150
152
def _extract_priorities_from_tasks (tasks : list [PTask ]) -> dict [str , int ]:
0 commit comments