@@ -50,8 +50,7 @@ def execute_tasks_h5(
5050 future_queue : queue .Queue ,
5151 cache_directory : str ,
5252 execute_function : callable ,
53- cores_per_worker : int = 1 ,
54- cwd : Optional [str ] = None ,
53+ resource_dict : dict ,
5554 terminate_function : Optional [callable ] = None ,
5655 config_directory : Optional [str ] = None ,
5756 backend : Optional [str ] = None ,
@@ -62,9 +61,10 @@ def execute_tasks_h5(
6261 Args:
6362 future_queue (queue.Queue): The queue containing the tasks.
6463 cache_directory (str): The directory to store the HDF5 files.
65- cores_per_worker (int): The number of cores per worker.
64+ resource_dict (dict): A dictionary of resources required by the task. With the following keys:
65+ - cores (int): number of MPI cores to be used for each function call
66+ - cwd (str/None): current working directory where the parallel python task is executed
6667 execute_function (callable): The function to execute the tasks.
67- cwd (str/None): current working directory where the parallel python task is executed
6868 terminate_function (callable): The function to terminate the tasks.
6969 config_directory (str, optional): path to the config directory.
7070 backend (str, optional): name of the backend used to spawn tasks.
@@ -97,16 +97,15 @@ def execute_tasks_h5(
9797 memory_dict = memory_dict ,
9898 file_name_dict = file_name_dict ,
9999 )
100- resource_dict = task_dict ["resource_dict" ].copy ()
101- if "cores" not in resource_dict :
102- resource_dict ["cores" ] = cores_per_worker
103- if "cwd" not in resource_dict :
104- resource_dict ["cwd" ] = cwd
100+ task_resource_dict = task_dict ["resource_dict" ].copy ()
101+ task_resource_dict .update (
102+ {k : v for k , v in resource_dict .items () if k not in task_resource_dict }
103+ )
105104 task_key , data_dict = serialize_funct_h5 (
106105 fn = task_dict ["fn" ],
107106 fn_args = task_args ,
108107 fn_kwargs = task_kwargs ,
109- resource_dict = resource_dict ,
108+ resource_dict = task_resource_dict ,
110109 )
111110 if task_key not in memory_dict .keys ():
112111 if task_key + ".h5out" not in os .listdir (cache_directory ):
@@ -115,12 +114,12 @@ def execute_tasks_h5(
115114 process_dict [task_key ] = execute_function (
116115 command = _get_execute_command (
117116 file_name = file_name ,
118- cores = cores_per_worker ,
117+ cores = task_resource_dict [ "cores" ] ,
119118 ),
120119 task_dependent_lst = [
121120 process_dict [k ] for k in future_wait_key_lst
122121 ],
123- resource_dict = resource_dict ,
122+ resource_dict = task_resource_dict ,
124123 config_directory = config_directory ,
125124 backend = backend ,
126125 )
0 commit comments