diff --git a/executorlib/backend/interactive_parallel.py b/executorlib/backend/interactive_parallel.py index 3d5aadcc..36880dc9 100644 --- a/executorlib/backend/interactive_parallel.py +++ b/executorlib/backend/interactive_parallel.py @@ -43,7 +43,7 @@ def main() -> None: host=argument_dict["host"], port=argument_dict["zmqport"] ) - memory = None + memory = {"executorlib_worker_id": int(argument_dict["worker_id"])} # required for flux interface - otherwise the current path is not included in the python path cwd = abspath(".") @@ -97,7 +97,7 @@ def main() -> None: and "args" in input_dict and "kwargs" in input_dict ): - memory = call_funct(input_dict=input_dict, funct=None) + memory.update(call_funct(input_dict=input_dict, funct=None, memory=memory)) if __name__ == "__main__": diff --git a/executorlib/backend/interactive_serial.py b/executorlib/backend/interactive_serial.py index c72d95a0..859e2905 100644 --- a/executorlib/backend/interactive_serial.py +++ b/executorlib/backend/interactive_serial.py @@ -29,7 +29,7 @@ def main(argument_lst: Optional[list[str]] = None): host=argument_dict["host"], port=argument_dict["zmqport"] ) - memory = None + memory = {"executorlib_worker_id": int(argument_dict["worker_id"])} # required for flux interface - otherwise the current path is not included in the python path cwd = abspath(".") @@ -72,7 +72,7 @@ def main(argument_lst: Optional[list[str]] = None): and "args" in input_dict and "kwargs" in input_dict ): - memory = call_funct(input_dict=input_dict, funct=None) + memory.update(call_funct(input_dict=input_dict, funct=None, memory=memory)) if __name__ == "__main__": diff --git a/executorlib/standalone/interactive/backend.py b/executorlib/standalone/interactive/backend.py index 2fb4fdd6..53d014c5 100644 --- a/executorlib/standalone/interactive/backend.py +++ b/executorlib/standalone/interactive/backend.py @@ -48,8 +48,9 @@ def parse_arguments(argument_lst: list[str]) -> dict: argument_dict={ "zmqport": "--zmqport", "host": "--host", + "worker_id": "--worker-id", }, - default_dict={"host": "localhost"}, + default_dict={"host": "localhost", "worker_id": 0}, ) diff --git a/executorlib/standalone/interactive/communication.py b/executorlib/standalone/interactive/communication.py index 5d0c6711..9cc0cc68 100644 --- a/executorlib/standalone/interactive/communication.py +++ b/executorlib/standalone/interactive/communication.py @@ -136,6 +136,7 @@ def interface_bootup( connections, hostname_localhost: Optional[bool] = None, log_obj_size: bool = False, + worker_id: Optional[int] = None, ) -> SocketInterface: """ Start interface for ZMQ communication @@ -152,6 +153,8 @@ def interface_bootup( this look up for security reasons. So on MacOS it is required to set this option to true log_obj_size (boolean): Enable debug mode which reports the size of the communicated objects. + worker_id (int): Communicate the worker which ID was assigned to it for future reference and resource + distribution. Returns: executorlib.shared.communication.SocketInterface: socket interface for zmq communication @@ -165,6 +168,8 @@ def interface_bootup( "--host", gethostname(), ] + if worker_id is not None: + command_lst += ["--worker-id", str(worker_id)] interface = SocketInterface( spawner=connections, log_obj_size=log_obj_size, diff --git a/executorlib/task_scheduler/interactive/blockallocation.py b/executorlib/task_scheduler/interactive/blockallocation.py index b417c655..96cec2c1 100644 --- a/executorlib/task_scheduler/interactive/blockallocation.py +++ b/executorlib/task_scheduler/interactive/blockallocation.py @@ -65,9 +65,9 @@ def __init__( process=[ Thread( target=execute_tasks, - kwargs=executor_kwargs, + kwargs=executor_kwargs | {"worker_id": worker_id}, ) - for _ in range(self._max_workers) + for worker_id in range(self._max_workers) ], ) diff --git a/executorlib/task_scheduler/interactive/shared.py b/executorlib/task_scheduler/interactive/shared.py index 55d75fba..bdd1e2de 100644 --- a/executorlib/task_scheduler/interactive/shared.py +++ b/executorlib/task_scheduler/interactive/shared.py @@ -25,6 +25,7 @@ def execute_tasks( queue_join_on_shutdown: bool = True, log_obj_size: bool = False, error_log_file: Optional[str] = None, + worker_id: Optional[int] = None, **kwargs, ) -> None: """ @@ -49,6 +50,8 @@ def execute_tasks( log_obj_size (bool): Enable debug mode which reports the size of the communicated objects. error_log_file (str): Name of the error log file to use for storing exceptions raised by the Python functions submitted to the Executor. + worker_id (int): Communicate the worker which ID was assigned to it for future reference and resource + distribution. """ interface = interface_bootup( command_lst=get_interactive_execute_command( @@ -57,6 +60,7 @@ def execute_tasks( connections=spawner(cores=cores, **kwargs), hostname_localhost=hostname_localhost, log_obj_size=log_obj_size, + worker_id=worker_id, ) if init_function is not None: interface.send_dict( diff --git a/tests/test_singlenodeexecutor_noblock.py b/tests/test_singlenodeexecutor_noblock.py index 03f21ef6..0872359a 100644 --- a/tests/test_singlenodeexecutor_noblock.py +++ b/tests/test_singlenodeexecutor_noblock.py @@ -1,4 +1,5 @@ import unittest +from time import sleep from executorlib import SingleNodeExecutor from executorlib.standalone.serialize import cloudpickle_register @@ -12,6 +13,15 @@ def resource_dict(resource_dict): return resource_dict +def get_worker_id(executorlib_worker_id): + sleep(0.1) + return executorlib_worker_id + + +def init_function(): + return {"a": 1, "b": 2} + + class TestExecutorBackend(unittest.TestCase): def test_meta_executor_serial_with_dependencies(self): with SingleNodeExecutor( @@ -75,3 +85,58 @@ def test_errors(self): block_allocation=True, ) as exe: exe.submit(resource_dict, resource_dict={}) + + +class TestWorkerID(unittest.TestCase): + def test_block_allocation_True(self): + with SingleNodeExecutor( + max_cores=1, + block_allocation=True, + ) as exe: + worker_id = exe.submit(get_worker_id, resource_dict={}).result() + self.assertEqual(worker_id, 0) + + def test_block_allocation_True_two_workers(self): + with SingleNodeExecutor( + max_cores=2, + block_allocation=True, + ) as exe: + f1_worker_id = exe.submit(get_worker_id, resource_dict={}) + f2_worker_id = exe.submit(get_worker_id, resource_dict={}) + self.assertEqual(sum([f1_worker_id.result(), f2_worker_id.result()]), 1) + + def test_init_function(self): + with SingleNodeExecutor( + max_cores=1, + block_allocation=True, + init_function=init_function, + ) as exe: + worker_id = exe.submit(get_worker_id, resource_dict={}).result() + self.assertEqual(worker_id, 0) + + def test_init_function_two_workers(self): + with SingleNodeExecutor( + max_cores=2, + block_allocation=True, + init_function=init_function, + ) as exe: + f1_worker_id = exe.submit(get_worker_id, resource_dict={}) + f2_worker_id = exe.submit(get_worker_id, resource_dict={}) + self.assertEqual(sum([f1_worker_id.result(), f2_worker_id.result()]), 1) + + def test_block_allocation_False(self): + with SingleNodeExecutor( + max_cores=1, + block_allocation=False, + ) as exe: + worker_id = exe.submit(get_worker_id, resource_dict={}).result() + self.assertEqual(worker_id, 0) + + def test_block_allocation_False_two_workers(self): + with SingleNodeExecutor( + max_cores=2, + block_allocation=False, + ) as exe: + f1_worker_id = exe.submit(get_worker_id, resource_dict={}) + f2_worker_id = exe.submit(get_worker_id, resource_dict={}) + self.assertEqual(sum([f1_worker_id.result(), f2_worker_id.result()]), 0) \ No newline at end of file diff --git a/tests/test_standalone_interactive_backend.py b/tests/test_standalone_interactive_backend.py index c2306cae..226cdb26 100644 --- a/tests/test_standalone_interactive_backend.py +++ b/tests/test_standalone_interactive_backend.py @@ -11,6 +11,7 @@ class TestParser(unittest.TestCase): def test_command_local(self): result_dict = { "host": "localhost", + "worker_id": 0, "zmqport": "22", } command_lst = [ @@ -35,6 +36,7 @@ def test_command_local(self): def test_command_slurm(self): result_dict = { "host": "127.0.0.1", + "worker_id": 0, "zmqport": "22", } command_lst = [ @@ -76,6 +78,7 @@ def test_command_slurm(self): def test_command_slurm_user_command(self): result_dict = { "host": "127.0.0.1", + "worker_id": 0, "zmqport": "22", } command_lst = [