Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion executorlib/backend/interactive_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,23 @@ def main() -> None:
and "args" in input_dict
and "kwargs" in input_dict
):
memory.update(call_funct(input_dict=input_dict, funct=None, memory=memory))
try:
memory.update(
call_funct(input_dict=input_dict, funct=None, memory=memory)
)
except Exception as error:
if mpi_rank_zero:
interface_send(
socket=socket,
result_dict={"error": error},
)
backend_write_error_file(
error=error,
apply_dict=input_dict,
)
else:
if mpi_rank_zero:
interface_send(socket=socket, result_dict={"result": True})
Comment on lines +100 to +116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Bug: init failures on non-root ranks are ignored; client may receive success

If any non-root rank raises during init, rank 0 still sends {"result": True}. This leaves ranks with divergent memory and hides failures.

Aggregate errors across ranks and only acknowledge success if all ranks succeeded. Example patch:

-            try:
-                memory.update(
-                    call_funct(input_dict=input_dict, funct=None, memory=memory)
-                )
-            except Exception as error:
-                if mpi_rank_zero:
-                    interface_send(
-                        socket=socket,
-                        result_dict={"error": error},
-                    )
-                    backend_write_error_file(
-                        error=error,
-                        apply_dict=input_dict,
-                    )
-            else:
-                if mpi_rank_zero:
-                    interface_send(socket=socket, result_dict={"result": True})
+            # Run init on all ranks and consolidate errors on rank 0
+            local_error = None
+            local_update = None
+            try:
+                local_update = call_funct(input_dict=input_dict, funct=None, memory=memory)
+            except Exception as error:  # noqa: BLE001
+                local_error = error
+
+            from mpi4py import MPI  # already available; keeps scope explicit
+            # Did any rank fail?
+            any_failed = MPI.COMM_WORLD.allreduce(bool(local_error), op=MPI.LOR)
+
+            if any_failed:
+                # Gather first non-None error to rank 0 and report
+                errors = MPI.COMM_WORLD.gather(local_error, root=0)
+                if mpi_rank_zero:
+                    first_error = next((e for e in errors if e is not None), None)
+                    interface_send(socket=socket, result_dict={"error": first_error})
+                    backend_write_error_file(error=first_error, apply_dict=input_dict)
+            else:
+                # All ranks succeeded: update memory consistently and acknowledge
+                memory.update(local_update)  # same update on all ranks
+                if mpi_rank_zero:
+                    interface_send(socket=socket, result_dict={"result": True})
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
try:
memory.update(
call_funct(input_dict=input_dict, funct=None, memory=memory)
)
except Exception as error:
if mpi_rank_zero:
interface_send(
socket=socket,
result_dict={"error": error},
)
backend_write_error_file(
error=error,
apply_dict=input_dict,
)
else:
if mpi_rank_zero:
interface_send(socket=socket, result_dict={"result": True})
# Run init on all ranks and consolidate errors on rank 0
local_error = None
local_update = None
try:
local_update = call_funct(input_dict=input_dict, funct=None, memory=memory)
except Exception as error: # noqa: BLE001
local_error = error
from mpi4py import MPI # already available; keeps scope explicit
# Did any rank fail?
any_failed = MPI.COMM_WORLD.allreduce(bool(local_error), op=MPI.LOR)
if any_failed:
# Gather first non-None error to rank 0 and report
errors = MPI.COMM_WORLD.gather(local_error, root=0)
if mpi_rank_zero:
first_error = next((e for e in errors if e is not None), None)
interface_send(socket=socket, result_dict={"error": first_error})
backend_write_error_file(error=first_error, apply_dict=input_dict)
else:
# All ranks succeeded: update memory consistently and acknowledge
memory.update(local_update) # same update on all ranks
if mpi_rank_zero:
interface_send(socket=socket, result_dict={"result": True})
🧰 Tools
🪛 Ruff (0.12.2)

104-104: Do not catch blind exception: Exception

(BLE001)



if __name__ == "__main__":
Expand Down
16 changes: 15 additions & 1 deletion executorlib/backend/interactive_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,21 @@ def main(argument_lst: Optional[list[str]] = None):
and "args" in input_dict
and "kwargs" in input_dict
):
memory.update(call_funct(input_dict=input_dict, funct=None, memory=memory))
try:
memory.update(
call_funct(input_dict=input_dict, funct=None, memory=memory)
)
except Exception as error:
interface_send(
socket=socket,
result_dict={"error": error},
)
backend_write_error_file(
error=error,
apply_dict=input_dict,
)
else:
interface_send(socket=socket, result_dict={"result": True})


if __name__ == "__main__":
Expand Down
29 changes: 18 additions & 11 deletions executorlib/task_scheduler/interactive/blockallocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,14 @@ def _execute_multiple_tasks(
log_obj_size=log_obj_size,
worker_id=worker_id,
)
interface_initialization_exception = None
if init_function is not None:
interface.send_dict(
input_dict={"init": True, "fn": init_function, "args": (), "kwargs": {}}
)
try:
_ = interface.send_and_receive_dict(
input_dict={"init": True, "fn": init_function, "args": (), "kwargs": {}}
)
except Exception as init_exception:
interface_initialization_exception = init_exception
while True:
task_dict = future_queue.get()
if "shutdown" in task_dict and task_dict["shutdown"]:
Expand All @@ -240,12 +244,15 @@ def _execute_multiple_tasks(
break
elif "fn" in task_dict and "future" in task_dict:
f = task_dict.pop("future")
execute_task_dict(
task_dict=task_dict,
future_obj=f,
interface=interface,
cache_directory=cache_directory,
cache_key=cache_key,
error_log_file=error_log_file,
)
if interface_initialization_exception is not None:
f.set_exception(exception=interface_initialization_exception)
else:
execute_task_dict(
task_dict=task_dict,
future_obj=f,
interface=interface,
cache_directory=cache_directory,
cache_key=cache_key,
error_log_file=error_log_file,
)
Comment on lines +247 to +257
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Respect cancellation before setting exception to avoid InvalidStateError

Setting an exception on a cancelled Future can raise InvalidStateError and break the worker loop. Mirror execute_task_dict’s pattern.

-            if interface_initialization_exception is not None:
-                f.set_exception(exception=interface_initialization_exception)
+            if interface_initialization_exception is not None:
+                if (not f.done()) and f.set_running_or_notify_cancel():
+                    f.set_exception(interface_initialization_exception)
             else:
                 execute_task_dict(
                     task_dict=task_dict,
                     future_obj=f,
                     interface=interface,
                     cache_directory=cache_directory,
                     cache_key=cache_key,
                     error_log_file=error_log_file,
                 )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if interface_initialization_exception is not None:
f.set_exception(exception=interface_initialization_exception)
else:
execute_task_dict(
task_dict=task_dict,
future_obj=f,
interface=interface,
cache_directory=cache_directory,
cache_key=cache_key,
error_log_file=error_log_file,
)
if interface_initialization_exception is not None:
# Only set the exception if the future wasn't already done/cancelled
if (not f.done()) and f.set_running_or_notify_cancel():
f.set_exception(interface_initialization_exception)
else:
execute_task_dict(
task_dict=task_dict,
future_obj=f,
interface=interface,
cache_directory=cache_directory,
cache_key=cache_key,
error_log_file=error_log_file,
)
🤖 Prompt for AI Agents
In executorlib/task_scheduler/interactive/blockallocation.py around lines 248 to
258, avoid calling f.set_exception on a Future that may already be cancelled
(which can raise InvalidStateError); before setting the exception, check if
f.cancelled() (or otherwise mirror the guard used in execute_task_dict) and only
call set_exception when the future is not cancelled, otherwise skip setting the
exception (or handle via the same cancellation/ignore path used by
execute_task_dict).

task_done(future_queue=future_queue)
32 changes: 32 additions & 0 deletions tests/test_backend_interactive_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def set_global():
return {"j": 5}


def raise_error():
raise ValueError("interface error")


def submit(socket):
socket.send(
cloudpickle.dumps({"init": True, "fn": set_global, "args": (), "kwargs": {}})
Expand All @@ -31,6 +35,14 @@ def submit_error(socket):
socket.send(cloudpickle.dumps({"shutdown": True, "wait": True}))


def submit_init_error(socket):
socket.send(
cloudpickle.dumps({"init": True, "fn": raise_error, "args": (), "kwargs": {}})
)
socket.send(cloudpickle.dumps({"fn": calc, "args": (), "kwargs": {"i": 2}}))
socket.send(cloudpickle.dumps({"shutdown": True, "wait": True}))


class TestSerial(unittest.TestCase):
def test_main_as_thread(self):
context = zmq.Context()
Expand All @@ -39,6 +51,7 @@ def test_main_as_thread(self):
t = Thread(target=main, kwargs={"argument_lst": ["--zmqport", str(port)]})
t.start()
submit(socket=socket)
self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True})
self.assertEqual(cloudpickle.loads(socket.recv()), {"result": 7})
self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True})
socket.close()
Expand All @@ -51,6 +64,23 @@ def test_main_as_thread_error(self):
t = Thread(target=main, kwargs={"argument_lst": ["--zmqport", str(port)]})
t.start()
submit_error(socket=socket)
self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True})
self.assertEqual(
str(type(cloudpickle.loads(socket.recv())["error"])), "<class 'TypeError'>"
)
self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True})
socket.close()
context.term()

def test_main_as_thread_init_error(self):
context = zmq.Context()
socket = context.socket(zmq.PAIR)
port = socket.bind_to_random_port("tcp://*")
t = Thread(target=main, kwargs={"argument_lst": ["--zmqport", str(port)]})
t.start()
submit_init_error(socket=socket)
self.assertEqual(
str(type(cloudpickle.loads(socket.recv())["error"])), "<class 'ValueError'>")
self.assertEqual(
str(type(cloudpickle.loads(socket.recv())["error"])), "<class 'TypeError'>"
)
Expand All @@ -65,6 +95,7 @@ def test_submit_as_thread(self):
t = Thread(target=submit, kwargs={"socket": socket})
t.start()
main(argument_lst=["--zmqport", str(port)])
self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True})
self.assertEqual(cloudpickle.loads(socket.recv()), {"result": 7})
self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True})
socket.close()
Expand All @@ -77,6 +108,7 @@ def test_submit_as_thread_error(self):
t = Thread(target=submit_error, kwargs={"socket": socket})
t.start()
main(argument_lst=["--zmqport", str(port)])
self.assertEqual(cloudpickle.loads(socket.recv()), {"result": True})
self.assertEqual(
str(type(cloudpickle.loads(socket.recv())["error"])), "<class 'TypeError'>"
)
Expand Down
Loading