diff --git a/executorlib/standalone/command.py b/executorlib/standalone/command.py index 68af9abc..38a2980c 100644 --- a/executorlib/standalone/command.py +++ b/executorlib/standalone/command.py @@ -21,6 +21,8 @@ def get_cache_execute_command( file_name: str, cores: int = 1, backend: Optional[str] = None, + exclusive: bool = False, + openmpi_oversubscribe: bool = False, pmi_mode: Optional[str] = None, ) -> list: """ @@ -30,7 +32,9 @@ def get_cache_execute_command( file_name (str): The name of the file. cores (int, optional): Number of cores used to execute the task. Defaults to 1. backend (str, optional): name of the backend used to spawn tasks ["slurm", "flux"]. - pmi_mode (str): PMI interface to use (OpenMPI v5 requires pmix) default is None (Flux only) + exclusive (bool): Whether to exclusively reserve the compute nodes, or allow sharing compute nodes. Defaults to False. + openmpi_oversubscribe (bool, optional): Whether to oversubscribe the cores. Defaults to False. + pmi_mode (str): PMI interface to use (OpenMPI v5 requires pmix) default is None Returns: list[str]: List of strings containing the python executable path and the backend script to execute @@ -47,6 +51,10 @@ def get_cache_execute_command( command_prepend = ["srun", "-n", str(cores)] if pmi_mode is not None: command_prepend += ["--mpi=" + pmi_mode] + if openmpi_oversubscribe: + command_prepend += ["--oversubscribe"] + if exclusive: + command_prepend += ["--exact"] command_lst = ( command_prepend + command_lst @@ -56,6 +64,14 @@ def get_cache_execute_command( flux_command = ["flux", "run"] if pmi_mode is not None: flux_command += ["-o", "pmi=" + pmi_mode] + if openmpi_oversubscribe: + raise ValueError( + "The option openmpi_oversubscribe is not available with the flux backend." + ) + if exclusive: + raise ValueError( + "The option exclusive is not available with the flux backend." 
+ ) command_lst = ( flux_command + ["-n", str(cores)] diff --git a/executorlib/task_scheduler/file/shared.py b/executorlib/task_scheduler/file/shared.py index c712c863..68f99431 100644 --- a/executorlib/task_scheduler/file/shared.py +++ b/executorlib/task_scheduler/file/shared.py @@ -156,6 +156,10 @@ def execute_tasks_h5( file_name=file_name, cores=task_resource_dict["cores"], backend=backend, + exclusive=task_resource_dict.get("exclusive", False), + openmpi_oversubscribe=task_resource_dict.get( + "openmpi_oversubscribe", False + ), pmi_mode=pmi_mode, ), file_name=file_name, diff --git a/executorlib/task_scheduler/file/task_scheduler.py b/executorlib/task_scheduler/file/task_scheduler.py index 587b0e0a..620a2fcf 100644 --- a/executorlib/task_scheduler/file/task_scheduler.py +++ b/executorlib/task_scheduler/file/task_scheduler.py @@ -56,6 +56,8 @@ def __init__( "cores": 1, "cwd": None, "cache_directory": "executorlib_cache", + "exclusive": False, + "openmpi_oversubscribe": False, } if resource_dict is None: resource_dict = {} diff --git a/tests/test_standalone_command.py b/tests/test_standalone_command.py index d1bb55f1..3166cc3a 100644 --- a/tests/test_standalone_command.py +++ b/tests/test_standalone_command.py @@ -51,14 +51,16 @@ def test_get_cache_execute_command_parallel(self): self.assertEqual(output[3], sys.executable) self.assertEqual(output[4].split(os.sep)[-1], "cache_parallel.py") self.assertEqual(output[5], file_name) - output = get_cache_execute_command(cores=2, file_name=file_name, backend="slurm", pmi_mode="pmi2") + output = get_cache_execute_command(cores=2, file_name=file_name, backend="slurm", pmi_mode="pmi2", openmpi_oversubscribe=True, exclusive=True) self.assertEqual(output[0], "srun") self.assertEqual(output[1], "-n") self.assertEqual(output[2], str(2)) self.assertEqual(output[3], "--mpi=pmi2") - self.assertEqual(output[4], sys.executable) - self.assertEqual(output[5].split(os.sep)[-1], "cache_parallel.py") - self.assertEqual(output[6], 
file_name) + self.assertEqual(output[4], "--oversubscribe") + self.assertEqual(output[5], "--exact") + self.assertEqual(output[6], sys.executable) + self.assertEqual(output[7].split(os.sep)[-1], "cache_parallel.py") + self.assertEqual(output[8], file_name) output = get_cache_execute_command(cores=2, file_name=file_name, backend="slurm") self.assertEqual(output[0], "srun") self.assertEqual(output[1], "-n") @@ -86,3 +88,7 @@ def test_get_cache_execute_command_parallel(self): self.assertEqual(output[8], file_name) with self.assertRaises(ValueError): get_cache_execute_command(cores=2, file_name=file_name, backend="test") + with self.assertRaises(ValueError): + get_cache_execute_command(cores=2, file_name=file_name, backend="flux", openmpi_oversubscribe=True) + with self.assertRaises(ValueError): + get_cache_execute_command(cores=2, file_name=file_name, backend="flux", exclusive=True)