diff --git a/executorlib/interactive/create.py b/executorlib/interactive/create.py
index 4c357e47..016174a8 100644
--- a/executorlib/interactive/create.py
+++ b/executorlib/interactive/create.py
@@ -83,14 +83,15 @@ def create_executor(
                                     of the individual function.
         init_function (None): optional function to preset arguments for functions which are submitted later
     """
-    check_init_function(block_allocation=block_allocation, init_function=init_function)
     if flux_executor is not None and backend != "flux_allocation":
         backend = "flux_allocation"
-    check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
-    cores_per_worker = resource_dict.get("cores", 1)
-    resource_dict["cache_directory"] = cache_directory
-    resource_dict["hostname_localhost"] = hostname_localhost
     if backend == "flux_allocation":
+        check_init_function(
+            block_allocation=block_allocation, init_function=init_function
+        )
+        check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
+        resource_dict["cache_directory"] = cache_directory
+        resource_dict["hostname_localhost"] = hostname_localhost
         check_oversubscribe(
             oversubscribe=resource_dict.get("openmpi_oversubscribe", False)
         )
@@ -100,40 +101,41 @@ def create_executor(
         return create_flux_allocation_executor(
             max_workers=max_workers,
             max_cores=max_cores,
-            cores_per_worker=cores_per_worker,
+            cache_directory=cache_directory,
             resource_dict=resource_dict,
             flux_executor=flux_executor,
             flux_executor_pmi_mode=flux_executor_pmi_mode,
             flux_executor_nesting=flux_executor_nesting,
             flux_log_files=flux_log_files,
+            hostname_localhost=hostname_localhost,
             block_allocation=block_allocation,
             init_function=init_function,
         )
     elif backend == "slurm_allocation":
+        check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
         check_executor(executor=flux_executor)
         check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
         check_flux_log_files(flux_log_files=flux_log_files)
         return create_slurm_allocation_executor(
             max_workers=max_workers,
             max_cores=max_cores,
-            cores_per_worker=cores_per_worker,
+            cache_directory=cache_directory,
             resource_dict=resource_dict,
+            hostname_localhost=hostname_localhost,
             block_allocation=block_allocation,
             init_function=init_function,
         )
     elif backend == "local":
+        check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
         check_executor(executor=flux_executor)
         check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
         check_flux_log_files(flux_log_files=flux_log_files)
-        check_gpus_per_worker(gpus_per_worker=resource_dict.get("gpus_per_core", 0))
-        check_command_line_argument_lst(
-            command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
-        )
         return create_local_executor(
             max_workers=max_workers,
             max_cores=max_cores,
-            cores_per_worker=cores_per_worker,
+            cache_directory=cache_directory,
             resource_dict=resource_dict,
+            hostname_localhost=hostname_localhost,
             block_allocation=block_allocation,
             init_function=init_function,
         )
@@ -146,15 +148,25 @@ def create_executor(
 def create_flux_allocation_executor(
     max_workers: Optional[int] = None,
     max_cores: Optional[int] = None,
-    cores_per_worker: int = 1,
+    cache_directory: Optional[str] = None,
     resource_dict: dict = {},
     flux_executor=None,
     flux_executor_pmi_mode: Optional[str] = None,
     flux_executor_nesting: bool = False,
     flux_log_files: bool = False,
+    hostname_localhost: Optional[bool] = None,
     block_allocation: bool = False,
     init_function: Optional[Callable] = None,
 ) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
+    check_init_function(block_allocation=block_allocation, init_function=init_function)
+    check_pmi(backend="flux_allocation", pmi=flux_executor_pmi_mode)
+    cores_per_worker = resource_dict.get("cores", 1)
+    resource_dict["cache_directory"] = cache_directory
+    resource_dict["hostname_localhost"] = hostname_localhost
+    check_oversubscribe(oversubscribe=resource_dict.get("openmpi_oversubscribe", False))
+    check_command_line_argument_lst(
+        command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
+    )
     if "openmpi_oversubscribe" in resource_dict.keys():
         del resource_dict["openmpi_oversubscribe"]
     if "slurm_cmd_args" in resource_dict.keys():
@@ -193,11 +205,16 @@ def create_flux_allocation_executor(
 def create_slurm_allocation_executor(
     max_workers: Optional[int] = None,
     max_cores: Optional[int] = None,
-    cores_per_worker: int = 1,
+    cache_directory: Optional[str] = None,
     resource_dict: dict = {},
+    hostname_localhost: Optional[bool] = None,
     block_allocation: bool = False,
     init_function: Optional[Callable] = None,
 ) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
+    check_init_function(block_allocation=block_allocation, init_function=init_function)
+    cores_per_worker = resource_dict.get("cores", 1)
+    resource_dict["cache_directory"] = cache_directory
+    resource_dict["hostname_localhost"] = hostname_localhost
     if block_allocation:
         resource_dict["init_function"] = init_function
         max_workers = validate_number_of_cores(
@@ -228,11 +245,21 @@ def create_slurm_allocation_executor(
 def create_local_executor(
     max_workers: Optional[int] = None,
     max_cores: Optional[int] = None,
-    cores_per_worker: int = 1,
+    cache_directory: Optional[str] = None,
     resource_dict: dict = {},
+    hostname_localhost: Optional[bool] = None,
     block_allocation: bool = False,
     init_function: Optional[Callable] = None,
 ) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
+    check_init_function(block_allocation=block_allocation, init_function=init_function)
+    cores_per_worker = resource_dict.get("cores", 1)
+    resource_dict["cache_directory"] = cache_directory
+    resource_dict["hostname_localhost"] = hostname_localhost
+
+    check_gpus_per_worker(gpus_per_worker=resource_dict.get("gpus_per_core", 0))
+    check_command_line_argument_lst(
+        command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
+    )
    if "threads_per_core" in resource_dict.keys():
        del resource_dict["threads_per_core"]
    if "gpus_per_core" in resource_dict.keys():
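Below is a minimal usage sketch of the refactored entry point, assuming the `create_executor` signature matches the keyword arguments the diff forwards to the backend-specific constructors (`cache_directory` and `hostname_localhost` replacing the removed `cores_per_worker` argument, with per-worker cores now read from `resource_dict`). The worker function, the `resource_dict` values, and the use of the standard `concurrent.futures` submit/shutdown interface are illustrative assumptions, not part of this diff.

```python
# Hypothetical usage sketch of the refactored create_executor() for the "local"
# backend; keyword arguments mirror those dispatched to create_local_executor()
# in the diff, while the resource_dict values below are illustrative only.
from executorlib.interactive.create import create_executor


def add(a, b):
    return a + b


exe = create_executor(
    backend="local",
    max_cores=2,
    cache_directory=None,          # new parameter, replaces cores_per_worker
    resource_dict={"cores": 1},    # per-worker cores now come from resource_dict
    hostname_localhost=None,       # new parameter, forwarded via resource_dict
    block_allocation=False,
    init_function=None,
)
# Assuming the returned executor follows the concurrent.futures.Executor
# interface (submit/shutdown), as executorlib executors are designed to do:
future = exe.submit(add, 1, 2)
print(future.result())             # -> 3
exe.shutdown(wait=True)
```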