diff --git a/executorlib/backend/interactive_parallel.py b/executorlib/backend/interactive_parallel.py index 36880dc9..7f968391 100644 --- a/executorlib/backend/interactive_parallel.py +++ b/executorlib/backend/interactive_parallel.py @@ -97,7 +97,23 @@ def main() -> None: and "args" in input_dict and "kwargs" in input_dict ): - memory.update(call_funct(input_dict=input_dict, funct=None, memory=memory)) + try: + memory.update( + call_funct(input_dict=input_dict, funct=None, memory=memory) + ) + except Exception as error: + if mpi_rank_zero: + interface_send( + socket=socket, + result_dict={"error": error}, + ) + backend_write_error_file( + error=error, + apply_dict=input_dict, + ) + else: + if mpi_rank_zero: + interface_send(socket=socket, result_dict={"result": True}) if __name__ == "__main__": diff --git a/executorlib/backend/interactive_serial.py b/executorlib/backend/interactive_serial.py index 859e2905..08eab5b9 100644 --- a/executorlib/backend/interactive_serial.py +++ b/executorlib/backend/interactive_serial.py @@ -72,7 +72,21 @@ def main(argument_lst: Optional[list[str]] = None): and "args" in input_dict and "kwargs" in input_dict ): - memory.update(call_funct(input_dict=input_dict, funct=None, memory=memory)) + try: + memory.update( + call_funct(input_dict=input_dict, funct=None, memory=memory) + ) + except Exception as error: + interface_send( + socket=socket, + result_dict={"error": error}, + ) + backend_write_error_file( + error=error, + apply_dict=input_dict, + ) + else: + interface_send(socket=socket, result_dict={"result": True}) if __name__ == "__main__": diff --git a/executorlib/task_scheduler/interactive/blockallocation.py b/executorlib/task_scheduler/interactive/blockallocation.py index 005a2ffa..29a6bae3 100644 --- a/executorlib/task_scheduler/interactive/blockallocation.py +++ b/executorlib/task_scheduler/interactive/blockallocation.py @@ -226,10 +226,14 @@ def _execute_multiple_tasks( log_obj_size=log_obj_size, worker_id=worker_id, ) + interface_initialization_exception = None if init_function is not None: - interface.send_dict( - input_dict={"init": True, "fn": init_function, "args": (), "kwargs": {}} - ) + try: + _ = interface.send_and_receive_dict( + input_dict={"init": True, "fn": init_function, "args": (), "kwargs": {}} + ) + except Exception as init_exception: + interface_initialization_exception = init_exception while True: task_dict = future_queue.get() if "shutdown" in task_dict and task_dict["shutdown"]: @@ -240,12 +244,15 @@ def _execute_multiple_tasks( break elif "fn" in task_dict and "future" in task_dict: f = task_dict.pop("future") - execute_task_dict( - task_dict=task_dict, - future_obj=f, - interface=interface, - cache_directory=cache_directory, - cache_key=cache_key, - error_log_file=error_log_file, - ) + if interface_initialization_exception is not None: + f.set_exception(exception=interface_initialization_exception) + else: + execute_task_dict( + task_dict=task_dict, + future_obj=f, + interface=interface, + cache_directory=cache_directory, + cache_key=cache_key, + error_log_file=error_log_file, + ) task_done(future_queue=future_queue) diff --git a/tests/test_backend_interactive_serial.py b/tests/test_backend_interactive_serial.py index 6c9ee40d..a17cc2be 100644 --- a/tests/test_backend_interactive_serial.py +++ b/tests/test_backend_interactive_serial.py @@ -15,6 +15,10 @@ def set_global(): return {"j": 5} +def raise_error(): + raise ValueError("interface error") + + def submit(socket): socket.send( cloudpickle.dumps({"init": True, "fn": set_global, "args": (), "kwargs": {}}) @@ -31,6 +35,14 @@ def submit_error(socket): socket.send(cloudpickle.dumps({"shutdown": True, "wait": True})) +def submit_init_error(socket): + socket.send( + cloudpickle.dumps({"init": True, "fn": raise_error, "args": (), "kwargs": {}}) + ) + socket.send(cloudpickle.dumps({"fn": calc, "args": (), "kwargs": {"i": 2}})) + socket.send(cloudpickle.dumps({"shutdown": True, "wait": True})) + + class TestSerial(unittest.TestCase): def test_main_as_thread(self): context = zmq.Context() @@ -39,6 +51,7 @@ def test_main_as_thread(self): t = Thread(target=main, kwargs={"argument_lst": ["--zmqport", str(port)]}) t.start() submit(socket=socket) + self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True}) self.assertEqual(cloudpickle.loads(socket.recv()), {"result": 7}) self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True}) socket.close() @@ -51,6 +64,23 @@ def test_main_as_thread_error(self): t = Thread(target=main, kwargs={"argument_lst": ["--zmqport", str(port)]}) t.start() submit_error(socket=socket) + self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True}) + self.assertEqual( + str(type(cloudpickle.loads(socket.recv())["error"])), "" + ) + self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True}) + socket.close() + context.term() + + def test_main_as_thread_init_error(self): + context = zmq.Context() + socket = context.socket(zmq.PAIR) + port = socket.bind_to_random_port("tcp://*") + t = Thread(target=main, kwargs={"argument_lst": ["--zmqport", str(port)]}) + t.start() + submit_init_error(socket=socket) + self.assertEqual( + str(type(cloudpickle.loads(socket.recv())["error"])), "") self.assertEqual( str(type(cloudpickle.loads(socket.recv())["error"])), "" ) @@ -65,6 +95,7 @@ def test_submit_as_thread(self): t = Thread(target=submit, kwargs={"socket": socket}) t.start() main(argument_lst=["--zmqport", str(port)]) + self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True}) self.assertEqual(cloudpickle.loads(socket.recv()), {"result": 7}) self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True}) socket.close() @@ -77,6 +108,7 @@ def test_submit_as_thread_error(self): t = Thread(target=submit_error, kwargs={"socket": socket}) t.start() main(argument_lst=["--zmqport", str(port)]) + self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True}) self.assertEqual( str(type(cloudpickle.loads(socket.recv())["error"])), "" )