57 changes: 42 additions & 15 deletions executorlib/interactive/create.py
@@ -83,14 +83,15 @@ def create_executor(
of the individual function.
init_function (None): optional function to preset arguments for functions which are submitted later
"""
check_init_function(block_allocation=block_allocation, init_function=init_function)
if flux_executor is not None and backend != "flux_allocation":
backend = "flux_allocation"
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
cores_per_worker = resource_dict.get("cores", 1)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost
if backend == "flux_allocation":
check_init_function(
block_allocation=block_allocation, init_function=init_function
)
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost
check_oversubscribe(
oversubscribe=resource_dict.get("openmpi_oversubscribe", False)
)
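
Aside: the docstring above describes init_function as a way to preset arguments for functions submitted later, when block_allocation is enabled. A minimal, hypothetical sketch of that pattern (make_init and add are illustrative names, not part of this PR; the injection semantics are assumed from the docstring):

def make_init():
    # Assumed behavior: runs once per worker when block_allocation=True,
    # and the returned dict entries are passed as arguments to functions
    # submitted afterwards.
    return {"offset": 10}

def add(x, offset):
    return x + offset

exe = create_executor(
    backend="local",
    block_allocation=True,
    init_function=make_init,
)
fut = exe.submit(add, 5)  # offset=10 is supplied by make_init
print(fut.result())       # 15, under the assumed semantics
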
@@ -100,40 +101,41 @@ def create_executor(
return create_flux_allocation_executor(
max_workers=max_workers,
max_cores=max_cores,
cores_per_worker=cores_per_worker,
cache_directory=cache_directory,
resource_dict=resource_dict,
flux_executor=flux_executor,
flux_executor_pmi_mode=flux_executor_pmi_mode,
flux_executor_nesting=flux_executor_nesting,
flux_log_files=flux_log_files,
hostname_localhost=hostname_localhost,
block_allocation=block_allocation,
init_function=init_function,
)
elif backend == "slurm_allocation":
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
check_executor(executor=flux_executor)
check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
check_flux_log_files(flux_log_files=flux_log_files)
return create_slurm_allocation_executor(
max_workers=max_workers,
max_cores=max_cores,
cores_per_worker=cores_per_worker,
cache_directory=cache_directory,
resource_dict=resource_dict,
hostname_localhost=hostname_localhost,
block_allocation=block_allocation,
init_function=init_function,
)
elif backend == "local":
check_pmi(backend=backend, pmi=flux_executor_pmi_mode)
check_executor(executor=flux_executor)
check_nested_flux_executor(nested_flux_executor=flux_executor_nesting)
check_flux_log_files(flux_log_files=flux_log_files)
check_gpus_per_worker(gpus_per_worker=resource_dict.get("gpus_per_core", 0))
check_command_line_argument_lst(
command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
)
return create_local_executor(
max_workers=max_workers,
max_cores=max_cores,
cores_per_worker=cores_per_worker,
cache_directory=cache_directory,
resource_dict=resource_dict,
hostname_localhost=hostname_localhost,
block_allocation=block_allocation,
init_function=init_function,
)
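
With this change, the per-worker core count comes from resource_dict (key "cores", defaulting to 1) rather than a separate cores_per_worker argument, and create_executor merely validates inputs and dispatches to the matching backend factory. A hedged sketch of a call against the new signature (all values illustrative):

executor = create_executor(
    backend="local",             # or "flux_allocation" / "slurm_allocation"
    max_cores=4,
    resource_dict={"cores": 2},  # replaces the removed cores_per_worker argument
    block_allocation=False,
)
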
@@ -146,15 +148,25 @@ def create_executor(
def create_flux_allocation_executor(
max_workers: Optional[int] = None,
max_cores: Optional[int] = None,
cores_per_worker: int = 1,
cache_directory: Optional[str] = None,
Contributor review comment:
🛠️ Refactor suggestion

Fix mutable default arguments in function signatures.

Using mutable default arguments (resource_dict = {}) can lead to unexpected behavior if the dictionary is modified between function calls, as the same dictionary instance is shared across all calls.

Replace the mutable defaults with None and initialize within the functions:

def create_flux_allocation_executor(
    max_workers: Optional[int] = None,
    max_cores: Optional[int] = None,
    cache_directory: Optional[str] = None,
-   resource_dict: dict = {},
+   resource_dict: Optional[dict] = None,
    flux_executor=None,
    ...
):
+   if resource_dict is None:
+       resource_dict = {}

def create_slurm_allocation_executor(
    max_workers: Optional[int] = None,
    max_cores: Optional[int] = None,
    cache_directory: Optional[str] = None,
-   resource_dict: dict = {},
+   resource_dict: Optional[dict] = None,
    ...
):
+   if resource_dict is None:
+       resource_dict = {}

def create_local_executor(
    max_workers: Optional[int] = None,
    max_cores: Optional[int] = None,
    cache_directory: Optional[str] = None,
-   resource_dict: dict = {},
+   resource_dict: Optional[dict] = None,
    ...
):
+   if resource_dict is None:
+       resource_dict = {}

Also applies to: 208-208, 248-248
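
The pitfall the comment describes is easy to reproduce in isolation; a standalone sketch, independent of executorlib:

from typing import Optional

def configure(resource_dict: dict = {}):
    # The default dict is created once at definition time and shared
    # across every call that omits the argument.
    resource_dict.setdefault("cores", 0)
    resource_dict["cores"] += 1
    return resource_dict

print(configure())  # {'cores': 1}
print(configure())  # {'cores': 2}, state leaked from the first call

def configure_fixed(resource_dict: Optional[dict] = None):
    # The suggested fix: default to None and build a fresh dict per call.
    if resource_dict is None:
        resource_dict = {}
    resource_dict.setdefault("cores", 0)
    resource_dict["cores"] += 1
    return resource_dict

print(configure_fixed())  # {'cores': 1}
print(configure_fixed())  # {'cores': 1}, no shared state
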

resource_dict: dict = {},
flux_executor=None,
flux_executor_pmi_mode: Optional[str] = None,
flux_executor_nesting: bool = False,
flux_log_files: bool = False,
hostname_localhost: Optional[bool] = None,
block_allocation: bool = False,
init_function: Optional[Callable] = None,
) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
check_init_function(block_allocation=block_allocation, init_function=init_function)
check_pmi(backend="flux_allocation", pmi=flux_executor_pmi_mode)
cores_per_worker = resource_dict.get("cores", 1)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost
check_oversubscribe(oversubscribe=resource_dict.get("openmpi_oversubscribe", False))
check_command_line_argument_lst(
command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
)
if "openmpi_oversubscribe" in resource_dict.keys():
del resource_dict["openmpi_oversubscribe"]
if "slurm_cmd_args" in resource_dict.keys():
@@ -193,11 +205,16 @@ def create_flux_allocation_executor(
def create_slurm_allocation_executor(
max_workers: Optional[int] = None,
max_cores: Optional[int] = None,
cores_per_worker: int = 1,
cache_directory: Optional[str] = None,
resource_dict: dict = {},
hostname_localhost: Optional[bool] = None,
block_allocation: bool = False,
init_function: Optional[Callable] = None,
) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
check_init_function(block_allocation=block_allocation, init_function=init_function)
cores_per_worker = resource_dict.get("cores", 1)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost
if block_allocation:
resource_dict["init_function"] = init_function
max_workers = validate_number_of_cores(
@@ -228,11 +245,21 @@ def create_slurm_allocation_executor(
def create_local_executor(
max_workers: Optional[int] = None,
max_cores: Optional[int] = None,
cores_per_worker: int = 1,
cache_directory: Optional[str] = None,
resource_dict: dict = {},
hostname_localhost: Optional[bool] = None,
block_allocation: bool = False,
init_function: Optional[Callable] = None,
) -> Union[InteractiveStepExecutor, InteractiveExecutor]:
check_init_function(block_allocation=block_allocation, init_function=init_function)
cores_per_worker = resource_dict.get("cores", 1)
resource_dict["cache_directory"] = cache_directory
resource_dict["hostname_localhost"] = hostname_localhost

check_gpus_per_worker(gpus_per_worker=resource_dict.get("gpus_per_core", 0))
check_command_line_argument_lst(
command_line_argument_lst=resource_dict.get("slurm_cmd_args", [])
)
if "threads_per_core" in resource_dict.keys():
del resource_dict["threads_per_core"]
if "gpus_per_core" in resource_dict.keys():