diff --git a/executorlib/task_scheduler/interactive/blockallocation.py b/executorlib/task_scheduler/interactive/blockallocation.py index 8cde6535..7955899d 100644 --- a/executorlib/task_scheduler/interactive/blockallocation.py +++ b/executorlib/task_scheduler/interactive/blockallocation.py @@ -240,8 +240,10 @@ def _execute_multiple_tasks( future_queue.join() break elif "fn" in task_dict and "future" in task_dict: + f = task_dict.pop("future") execute_task_dict( task_dict=task_dict, + future_obj=f, interface=interface, cache_directory=cache_directory, cache_key=cache_key, diff --git a/executorlib/task_scheduler/interactive/onetoone.py b/executorlib/task_scheduler/interactive/onetoone.py index b3ffddbd..3b631565 100644 --- a/executorlib/task_scheduler/interactive/onetoone.py +++ b/executorlib/task_scheduler/interactive/onetoone.py @@ -1,4 +1,5 @@ import queue +from concurrent.futures import Future from threading import Thread from typing import Optional @@ -186,6 +187,7 @@ def _wrap_execute_task_in_separate_process( dictionary containing the future objects and the number of cores they require """ resource_dict = task_dict.pop("resource_dict").copy() + f = task_dict.pop("future") if "cores" not in resource_dict or ( resource_dict["cores"] == 1 and executor_kwargs["cores"] >= 1 ): @@ -197,7 +199,7 @@ def _wrap_execute_task_in_separate_process( max_cores=max_cores, max_workers=max_workers, ) - active_task_dict[task_dict["future"]] = slots_required + active_task_dict[f] = slots_required task_kwargs = executor_kwargs.copy() task_kwargs.update(resource_dict) task_kwargs.update( @@ -205,6 +207,7 @@ def _wrap_execute_task_in_separate_process( "task_dict": task_dict, "spawner": spawner, "hostname_localhost": hostname_localhost, + "future_obj": f, } ) process = Thread( @@ -217,6 +220,7 @@ def _wrap_execute_task_in_separate_process( def _execute_task_in_thread( task_dict: dict, + future_obj: Future, cores: int = 1, spawner: type[BaseSpawner] = MpiExecSpawner, hostname_localhost: Optional[bool] = None, @@ -233,6 +237,7 @@ def _execute_task_in_thread( Args: task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys {"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}} + future_obj (Future): A Future representing the given call. cores (int): defines the total number of MPI ranks to use spawner (BaseSpawner): Spawner to start process on selected compute resources hostname_localhost (boolean): use localhost instead of the hostname to establish the zmq connection. In the @@ -253,6 +258,7 @@ def _execute_task_in_thread( """ execute_task_dict( task_dict=task_dict, + future_obj=future_obj, interface=interface_bootup( command_lst=get_interactive_execute_command( cores=cores, diff --git a/executorlib/task_scheduler/interactive/shared.py b/executorlib/task_scheduler/interactive/shared.py index 8be9076f..68dd68d6 100644 --- a/executorlib/task_scheduler/interactive/shared.py +++ b/executorlib/task_scheduler/interactive/shared.py @@ -2,6 +2,7 @@ import os import queue import time +from concurrent.futures import Future from typing import Optional from executorlib.standalone.interactive.communication import SocketInterface @@ -10,6 +11,7 @@ def execute_task_dict( task_dict: dict, + future_obj: Future, interface: SocketInterface, cache_directory: Optional[str] = None, cache_key: Optional[str] = None, @@ -21,6 +23,7 @@ def execute_task_dict( Args: task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys {"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}} + future_obj (Future): A Future representing the given call. interface (SocketInterface): socket interface for zmq communication cache_directory (str, optional): The directory to store cache files. Defaults to "executorlib_cache". cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be @@ -28,25 +31,37 @@ def execute_task_dict( error_log_file (str): Name of the error log file to use for storing exceptions raised by the Python functions submitted to the Executor. """ - if error_log_file is not None: - task_dict["error_log_file"] = error_log_file - if cache_directory is None: - _execute_task_without_cache(interface=interface, task_dict=task_dict) - else: - _execute_task_with_cache( - interface=interface, - task_dict=task_dict, - cache_directory=cache_directory, - cache_key=cache_key, - ) + if not future_obj.done() and future_obj.set_running_or_notify_cancel(): + if error_log_file is not None: + task_dict["error_log_file"] = error_log_file + if cache_directory is None: + _execute_task_without_cache( + interface=interface, task_dict=task_dict, future_obj=future_obj + ) + else: + _execute_task_with_cache( + interface=interface, + task_dict=task_dict, + cache_directory=cache_directory, + cache_key=cache_key, + future_obj=future_obj, + ) def task_done(future_queue: queue.Queue): + """ + Mark the current task as done in the current queue. + + Args: + future_queue (queue): Queue of task dictionaries waiting for execution. + """ with contextlib.suppress(ValueError): future_queue.task_done() -def _execute_task_without_cache(interface: SocketInterface, task_dict: dict): +def _execute_task_without_cache( + interface: SocketInterface, task_dict: dict, future_obj: Future +): """ Execute the task in the task_dict by communicating it via the interface. @@ -54,19 +69,19 @@ def _execute_task_without_cache(interface: SocketInterface, task_dict: dict): interface (SocketInterface): socket interface for zmq communication task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys {"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}} + future_obj (Future): A Future representing the given call. """ - f = task_dict.pop("future") - if not f.done() and f.set_running_or_notify_cancel(): - try: - f.set_result(interface.send_and_receive_dict(input_dict=task_dict)) - except Exception as thread_exception: - interface.shutdown(wait=True) - f.set_exception(exception=thread_exception) + try: + future_obj.set_result(interface.send_and_receive_dict(input_dict=task_dict)) + except Exception as thread_exception: + interface.shutdown(wait=True) + future_obj.set_exception(exception=thread_exception) def _execute_task_with_cache( interface: SocketInterface, task_dict: dict, + future_obj: Future, cache_directory: str, cache_key: Optional[str] = None, ): @@ -77,6 +92,7 @@ def _execute_task_with_cache( interface (SocketInterface): socket interface for zmq communication task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys {"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}} + future_obj (Future): A Future representing the given call. cache_directory (str): The directory to store cache files. cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be overwritten by setting the cache_key. @@ -92,19 +108,16 @@ def _execute_task_with_cache( ) file_name = os.path.abspath(os.path.join(cache_directory, task_key + "_o.h5")) if file_name not in get_cache_files(cache_directory=cache_directory): - f = task_dict.pop("future") - if f.set_running_or_notify_cancel(): - try: - time_start = time.time() - result = interface.send_and_receive_dict(input_dict=task_dict) - data_dict["output"] = result - data_dict["runtime"] = time.time() - time_start - dump(file_name=file_name, data_dict=data_dict) - f.set_result(result) - except Exception as thread_exception: - interface.shutdown(wait=True) - f.set_exception(exception=thread_exception) + try: + time_start = time.time() + result = interface.send_and_receive_dict(input_dict=task_dict) + data_dict["output"] = result + data_dict["runtime"] = time.time() - time_start + dump(file_name=file_name, data_dict=data_dict) + future_obj.set_result(result) + except Exception as thread_exception: + interface.shutdown(wait=True) + future_obj.set_exception(exception=thread_exception) else: _, _, result = get_output(file_name=file_name) - future = task_dict["future"] - future.set_result(result) + future_obj.set_result(result)