diff --git a/executorlib/interactive/shared.py b/executorlib/interactive/shared.py index 60279aa0..4ce64cce 100644 --- a/executorlib/interactive/shared.py +++ b/executorlib/interactive/shared.py @@ -483,6 +483,8 @@ def get_result(arg: Union[list[Future], Future]) -> Any: return arg.result() elif isinstance(arg, list): return [get_result(arg=el) for el in arg] + elif isinstance(arg, dict): + return {k: get_result(arg=v) for k, v in arg.items()} else: return arg @@ -510,6 +512,8 @@ def find_future_in_list(lst): future_lst.append(el) elif isinstance(el, list): find_future_in_list(lst=el) + elif isinstance(el, dict): + find_future_in_list(lst=el.values()) find_future_in_list(lst=task_dict["args"]) find_future_in_list(lst=task_dict["kwargs"].values()) diff --git a/executorlib/standalone/plot.py b/executorlib/standalone/plot.py index 7f58f0a6..320df56d 100644 --- a/executorlib/standalone/plot.py +++ b/executorlib/standalone/plot.py @@ -39,9 +39,28 @@ def add_element(arg, link_to, label=""): "label": label, } ) - elif isinstance(arg, list) and all(isinstance(a, Future) for a in arg): - for a in arg: - add_element(arg=a, link_to=link_to, label=label) + elif isinstance(arg, list) and any(isinstance(a, Future) for a in arg): + lst_no_future = [a if not isinstance(a, Future) else "$" for a in arg] + node_id = len(node_lst) + node_lst.append( + {"name": str(lst_no_future), "id": node_id, "shape": "circle"} + ) + edge_lst.append({"start": node_id, "end": link_to, "label": label}) + for i, a in enumerate(arg): + if isinstance(a, Future): + add_element(arg=a, link_to=node_id, label="ind: " + str(i)) + elif isinstance(arg, dict) and any(isinstance(a, Future) for a in arg.values()): + dict_no_future = { + kt: vt if not isinstance(vt, Future) else "$" for kt, vt in arg.items() + } + node_id = len(node_lst) + node_lst.append( + {"name": str(dict_no_future), "id": node_id, "shape": "circle"} + ) + edge_lst.append({"start": node_id, "end": link_to, "label": label}) + for kt, vt in arg.items(): + if isinstance(vt, Future): + add_element(arg=vt, link_to=node_id, label="key: " + kt) else: node_id = len(node_lst) node_lst.append({"name": str(arg), "id": node_id, "shape": "circle"}) @@ -92,6 +111,11 @@ def convert_arg(arg, future_hash_inverse_dict): convert_arg(arg=a, future_hash_inverse_dict=future_hash_inverse_dict) for a in arg ] + elif isinstance(arg, dict): + return { + k: convert_arg(arg=v, future_hash_inverse_dict=future_hash_inverse_dict) + for k, v in arg.items() + } else: return arg diff --git a/tests/test_dependencies_executor.py b/tests/test_dependencies_executor.py index aaf21a7d..0aa1d835 100644 --- a/tests/test_dependencies_executor.py +++ b/tests/test_dependencies_executor.py @@ -38,6 +38,10 @@ def merge(lst): return sum(lst) +def return_input_dict(input_dict): + return input_dict + + def raise_error(): raise RuntimeError @@ -130,6 +134,14 @@ def test_many_to_one(self): ) self.assertEqual(future_sum.result(), 15) + def test_future_input_dict(self): + with SingleNodeExecutor() as exe: + fs = exe.submit( + return_input_dict, + input_dict={"a": exe.submit(sum, [2, 2])}, + ) + self.assertEqual(fs.result()["a"], 4) + class TestExecutorErrors(unittest.TestCase): def test_block_allocation_false_one_worker(self): diff --git a/tests/test_plot_dependency.py b/tests/test_plot_dependency.py index 3b79f960..8b8f4741 100644 --- a/tests/test_plot_dependency.py +++ b/tests/test_plot_dependency.py @@ -39,6 +39,10 @@ def merge(lst): return sum(lst) +def return_input_dict(input_dict): + return input_dict + + @unittest.skipIf( skip_graphviz_test, "graphviz is not installed, so the plot_dependency_graph tests are skipped.", @@ -124,8 +128,25 @@ def test_many_to_one_plot(self): v: k for k, v in exe._future_hash_dict.items() }, ) - self.assertEqual(len(nodes), 18) - self.assertEqual(len(edges), 21) + self.assertEqual(len(nodes), 19) + self.assertEqual(len(edges), 22) + + def test_future_input_dict(self): + with SingleNodeExecutor(plot_dependency_graph=True) as exe: + exe.submit( + return_input_dict, + input_dict={"a": exe.submit(sum, [2, 2])}, + ) + self.assertEqual(len(exe._future_hash_dict), 2) + self.assertEqual(len(exe._task_hash_dict), 2) + nodes, edges = generate_nodes_and_edges( + task_hash_dict=exe._task_hash_dict, + future_hash_inverse_dict={ + v: k for k, v in exe._future_hash_dict.items() + }, + ) + self.assertEqual(len(nodes), 4) + self.assertEqual(len(edges), 3) @unittest.skipIf( @@ -197,8 +218,8 @@ def test_many_to_one_plot(self): v: k for k, v in exe._future_hash_dict.items() }, ) - self.assertEqual(len(nodes), 18) - self.assertEqual(len(edges), 21) + self.assertEqual(len(nodes), 19) + self.assertEqual(len(edges), 22) @unittest.skipIf( @@ -266,5 +287,5 @@ def test_many_to_one_plot(self): v: k for k, v in exe._future_hash_dict.items() }, ) - self.assertEqual(len(nodes), 18) - self.assertEqual(len(edges), 21) + self.assertEqual(len(nodes), 19) + self.assertEqual(len(edges), 22) diff --git a/tests/test_plot_dependency_flux.py b/tests/test_plot_dependency_flux.py index 75c06f2b..7c1e2e58 100644 --- a/tests/test_plot_dependency_flux.py +++ b/tests/test_plot_dependency_flux.py @@ -106,8 +106,8 @@ def test_many_to_one_plot(self): v: k for k, v in exe._future_hash_dict.items() }, ) - self.assertEqual(len(nodes), 18) - self.assertEqual(len(edges), 21) + self.assertEqual(len(nodes), 19) + self.assertEqual(len(edges), 22) @unittest.skipIf( @@ -175,5 +175,5 @@ def test_many_to_one_plot(self): v: k for k, v in exe._future_hash_dict.items() }, ) - self.assertEqual(len(nodes), 18) - self.assertEqual(len(edges), 21) + self.assertEqual(len(nodes), 19) + self.assertEqual(len(edges), 22)