diff --git a/executorlib/task_scheduler/base.py b/executorlib/task_scheduler/base.py index 36c46c21..a3e79d5b 100644 --- a/executorlib/task_scheduler/base.py +++ b/executorlib/task_scheduler/base.py @@ -6,6 +6,9 @@ from concurrent.futures import ( Future, ) +from concurrent.futures import ( + wait as wait_for_futures, +) from threading import Thread from typing import Callable, Optional, Union @@ -31,6 +34,7 @@ def __init__(self, max_cores: Optional[int] = None): self._max_cores = max_cores self._future_queue: Optional[queue.Queue] = queue.Queue() self._process: Optional[Union[Thread, list[Thread]]] = None + self._futures: set[Future] = set() @property def max_workers(self) -> Optional[int]: @@ -124,6 +128,7 @@ def submit( # type: ignore "resource_dict": resource_dict, } ) + self._futures.add(f) return f def shutdown(self, wait: bool = True, *, cancel_futures: bool = False): @@ -143,11 +148,17 @@ def shutdown(self, wait: bool = True, *, cancel_futures: bool = False): """ if cancel_futures and self._future_queue is not None: cancel_items_in_queue(que=self._future_queue) + if cancel_futures: + for f in self._futures: + f.cancel() + if wait: + wait_for_futures(self._futures) if self._process is not None and self._future_queue is not None: self._future_queue.put({"shutdown": True, "wait": wait}) if wait and isinstance(self._process, Thread): self._process.join() self._future_queue.join() + self._futures.clear() self._process = None self._future_queue = None