diff --git a/pympipool/shared/executor.py b/pympipool/shared/executor.py index 4d1865f1..4d1688c6 100644 --- a/pympipool/shared/executor.py +++ b/pympipool/shared/executor.py @@ -520,11 +520,17 @@ def _update_futures_in_input(args: tuple, kwargs: dict): Returns: tuple, dict: arguments and keyword arguments with each future object in them being evaluated """ - args = [arg if not isinstance(arg, Future) else arg.result() for arg in args] - kwargs = { - key: value if not isinstance(value, Future) else value.result() - for key, value in kwargs.items() - } + + def get_result(arg): + if isinstance(arg, Future): + return arg.result() + elif isinstance(arg, list): + return [get_result(arg=el) for el in arg] + else: + return arg + + args = [get_result(arg=arg) for arg in args] + kwargs = {key: get_result(arg=value) for key, value in kwargs.items()} return args, kwargs @@ -539,9 +545,17 @@ def _get_future_objects_from_input(task_dict: dict): Returns: list, boolean: list of future objects and boolean flag if all future objects are already done """ - future_lst = [arg for arg in task_dict["args"] if isinstance(arg, Future)] + [ - value for value in task_dict["kwargs"].values() if isinstance(value, Future) - ] + future_lst = [] + + def find_future_in_list(lst): + for el in lst: + if isinstance(el, Future): + future_lst.append(el) + elif isinstance(el, list): + find_future_in_list(lst=el) + + find_future_in_list(lst=task_dict["args"]) + find_future_in_list(lst=task_dict["kwargs"].values()) boolean_flag = len([future for future in future_lst if future.done()]) == len( future_lst ) diff --git a/pympipool/shared/plot.py b/pympipool/shared/plot.py index 8c9ccaa6..3ed3fe6a 100644 --- a/pympipool/shared/plot.py +++ b/pympipool/shared/plot.py @@ -6,51 +6,60 @@ def generate_nodes_and_edges( task_hash_dict: dict, future_hash_inverse_dict: dict -) -> Tuple[list]: +) -> Tuple[list, list]: node_lst, edge_lst = [], [] hash_id_dict = {} + + def add_element(arg, link_to, label=""): + if isinstance(arg, Future): + edge_lst.append( + { + "start": hash_id_dict[future_hash_inverse_dict[arg]], + "end": link_to, + "label": label, + } + ) + elif isinstance(arg, list) and all([isinstance(a, Future) for a in arg]): + for a in arg: + add_element(arg=a, link_to=link_to, label=label) + else: + node_id = len(node_lst) + node_lst.append({"name": str(arg), "id": node_id, "shape": "circle"}) + edge_lst.append({"start": node_id, "end": link_to, "label": label}) + for k, v in task_hash_dict.items(): hash_id_dict[k] = len(node_lst) - node_lst.append({"name": v["fn"].__name__, "id": hash_id_dict[k]}) + node_lst.append( + {"name": v["fn"].__name__, "id": hash_id_dict[k], "shape": "box"} + ) for k, task_dict in task_hash_dict.items(): for arg in task_dict["args"]: - if not isinstance(arg, Future): - node_id = len(node_lst) - node_lst.append({"name": str(arg), "id": node_id}) - edge_lst.append({"start": node_id, "end": hash_id_dict[k], "label": ""}) - else: - edge_lst.append( - { - "start": hash_id_dict[future_hash_inverse_dict[arg]], - "end": hash_id_dict[k], - "label": "", - } - ) + add_element(arg=arg, link_to=hash_id_dict[k], label="") + for kw, v in task_dict["kwargs"].items(): - if not isinstance(v, Future): - node_id = len(node_lst) - node_lst.append({"name": str(v), "id": node_id}) - edge_lst.append( - {"start": node_id, "end": hash_id_dict[k], "label": str(kw)} - ) - else: - edge_lst.append( - { - "start": hash_id_dict[future_hash_inverse_dict[v]], - "end": hash_id_dict[k], - "label": str(kw), - } - ) + add_element(arg=v, link_to=hash_id_dict[k], label=str(kw)) + return node_lst, edge_lst def generate_task_hash(task_dict: dict, future_hash_inverse_dict: dict) -> bytes: + def convert_arg(arg, future_hash_inverse_dict): + if isinstance(arg, Future): + return future_hash_inverse_dict[arg] + elif isinstance(arg, list): + return [ + convert_arg(arg=a, future_hash_inverse_dict=future_hash_inverse_dict) + for a in arg + ] + else: + return arg + args_for_hash = [ - arg if not isinstance(arg, Future) else future_hash_inverse_dict[arg] + convert_arg(arg=arg, future_hash_inverse_dict=future_hash_inverse_dict) for arg in task_dict["args"] ] kwargs_for_hash = { - k: v if not isinstance(v, Future) else future_hash_inverse_dict[v] + k: convert_arg(arg=v, future_hash_inverse_dict=future_hash_inverse_dict) for k, v in task_dict["kwargs"].items() } return cloudpickle.dumps( @@ -65,7 +74,7 @@ def draw(node_lst: list, edge_lst: list): graph = nx.DiGraph() for node in node_lst: - graph.add_node(node["id"], label=node["name"]) + graph.add_node(node["id"], label=node["name"], shape=node["shape"]) for edge in edge_lst: graph.add_edge(edge["start"], edge["end"], label=edge["label"]) svg = nx.nx_agraph.to_agraph(graph).draw(prog="dot", format="svg") diff --git a/tests/test_dependencies_executor.py b/tests/test_dependencies_executor.py index 833a1425..9ec50088 100644 --- a/tests/test_dependencies_executor.py +++ b/tests/test_dependencies_executor.py @@ -24,6 +24,21 @@ def add_function(parameter_1, parameter_2): return parameter_1 + parameter_2 +def generate_tasks(length): + sleep(0.2) + return range(length) + + +def calc_from_lst(lst, ind, parameter): + sleep(0.2) + return lst[ind] + parameter + + +def merge(lst): + sleep(0.2) + return sum(lst) + + class TestExecutorWithDependencies(unittest.TestCase): def test_executor(self): with Executor(max_cores=1, backend="local", hostname_localhost=True) as exe: @@ -34,7 +49,7 @@ def test_executor(self): @unittest.skipIf( skip_graphviz_test, - "graphviz is not installed, so the plot_dependency_graph test is skipped.", + "graphviz is not installed, so the plot_dependency_graph tests are skipped.", ) def test_executor_dependency_plot(self): with Executor( @@ -108,3 +123,81 @@ def test_dependency_steps(self): self.assertTrue(fs1.done()) self.assertTrue(fs2.done()) q.put({"shutdown": True, "wait": True}) + + def test_many_to_one(self): + length = 5 + parameter = 1 + with Executor(max_cores=2, backend="local", hostname_localhost=True) as exe: + cloudpickle_register(ind=1) + future_lst = exe.submit( + generate_tasks, + length=length, + resource_dict={"cores": 1}, + ) + lst = [] + for i in range(length): + lst.append( + exe.submit( + calc_from_lst, + lst=future_lst, + ind=i, + parameter=parameter, + resource_dict={"cores": 1}, + ) + ) + future_sum = exe.submit( + merge, + lst=lst, + resource_dict={"cores": 1}, + ) + self.assertEqual(future_sum.result(), 15) + + @unittest.skipIf( + skip_graphviz_test, + "graphviz is not installed, so the plot_dependency_graph tests are skipped.", + ) + def test_many_to_one_plot(self): + length = 5 + parameter = 1 + with Executor( + max_cores=2, + backend="local", + hostname_localhost=True, + plot_dependency_graph=True, + ) as exe: + cloudpickle_register(ind=1) + future_lst = exe.submit( + generate_tasks, + length=length, + resource_dict={"cores": 1}, + ) + lst = [] + for i in range(length): + lst.append( + exe.submit( + calc_from_lst, + lst=future_lst, + ind=i, + parameter=parameter, + resource_dict={"cores": 1}, + ) + ) + future_sum = exe.submit( + merge, + lst=lst, + resource_dict={"cores": 1}, + ) + self.assertTrue(future_lst.done()) + for l in lst: + self.assertTrue(l.done()) + self.assertTrue(future_sum.done()) + self.assertEqual(len(exe._future_hash_dict), 7) + self.assertEqual(len(exe._task_hash_dict), 7) + nodes, edges = generate_nodes_and_edges( + task_hash_dict=exe._task_hash_dict, + future_hash_inverse_dict={ + v: k for k, v in exe._future_hash_dict.items() + }, + ) + self.assertEqual(len(nodes), 18) + self.assertEqual(len(edges), 21)