diff --git a/Lib/concurrent/futures/__init__.py b/Lib/concurrent/futures/__init__.py
index 292e886d5a88ac..062430dd80b9f9 100644
--- a/Lib/concurrent/futures/__init__.py
+++ b/Lib/concurrent/futures/__init__.py
@@ -30,6 +30,7 @@
     'as_completed',
     'ProcessPoolExecutor',
     'ThreadPoolExecutor',
+    'InterpreterPoolExecutor',
 )
 
 
@@ -38,7 +39,7 @@ def __dir__():
 
 
 def __getattr__(name):
-    global ProcessPoolExecutor, ThreadPoolExecutor
+    global ProcessPoolExecutor, ThreadPoolExecutor, InterpreterPoolExecutor
 
     if name == 'ProcessPoolExecutor':
         from .process import ProcessPoolExecutor as pe
@@ -50,4 +51,9 @@ def __getattr__(name):
         ThreadPoolExecutor = te
         return te
 
+    if name == 'InterpreterPoolExecutor':
+        from .interpreter import InterpreterPoolExecutor as ie
+        InterpreterPoolExecutor = ie
+        return ie
+
     raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git a/Lib/concurrent/futures/interpreter.py b/Lib/concurrent/futures/interpreter.py
new file mode 100644
index 00000000000000..7d1cdb433c034e
--- /dev/null
+++ b/Lib/concurrent/futures/interpreter.py
@@ -0,0 +1,71 @@
+
+"""Implements InterpreterPoolExecutor."""
+
+import concurrent.futures.thread as _thread
+import pickle
+from test.support import interpreters
+
+
+class InterpreterPoolExecutor(_thread.ThreadPoolExecutor):
+
+    @classmethod
+    def _normalize_initializer(cls, initializer, initargs):
+        shared, initargs = initargs
+        if initializer is None:
+            if shared:
+                def initializer(ctx, *ignored):
+                    interp, _ = ctx
+                    interp.prepare_main(**shared)
+            return initializer, initargs
+
+        pickled = pickle.dumps((initializer, initargs))
+        def initializer(ctx):
+            interp, _ = ctx
+            if shared:
+                interp.prepare_main(**shared)
+            interp.exec(f'''if True:
+                import pickle
+                initializer, initargs = pickle.loads({pickled!r})
+                initializer(*initargs)
+                ''')
+        return initializer, ()
+
+    @classmethod
+    def _normalize_task(cls, fn, args, kwargs):
+        if isinstance(fn, str):
+            if args or kwargs:
+                raise ValueError(f'a script does not take args or kwargs, got {args!r} and {kwargs!r}')
+            script = fn
+            def wrapped(ctx):
+                interp, _ = ctx
+                interp.exec(script)
+                return None
+        else:
+            pickled = pickle.dumps((fn, args, kwargs))
+            def wrapped(ctx):
+                interp, results = ctx
+                interp.exec(f'''if True:
+                    import pickle
+                    fn, args, kwargs = pickle.loads({pickled!r})
+                    res = fn(*args, **kwargs)
+                    _interp_pool_executor_results.put(res)
+                    ''')
+                return results.get_nowait()
+        return wrapped, (), {}
+
+    @classmethod
+    def _run_worker(cls, *args):
+        interp = interpreters.create()
+        interp.exec('import test.support.interpreters.queues')
+        results = interpreters.create_queue()
+        interp.prepare_main(_interp_pool_executor_results=results)
+        ctx = (interp, results)
+        try:
+            _thread._worker(ctx, *args)
+        finally:
+            interp.close()
+
+    def __init__(self, max_workers=None, thread_name_prefix='',
+                 initializer=None, initargs=(), shared=None):
+        initargs = (shared, initargs)
+        super().__init__(max_workers, thread_name_prefix, initializer, initargs)
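Note on the new module above: a submitted callable is pickled together with its arguments, executed inside the worker's subinterpreter, and its result is handed back through the per-worker results queue, while a submitted string is treated as a script and executed as-is (its future resolves to None). A rough usage sketch of what this enables; the `shared` mapping's contents (the `greeting` key and value) are illustrative, not part of the patch:

    from concurrent.futures import InterpreterPoolExecutor

    # Values in `shared` are bound into each worker interpreter's __main__
    # via Interpreter.prepare_main() before that worker runs any tasks.
    with InterpreterPoolExecutor(max_workers=2,
                                 shared={'greeting': 'hello'}) as pool:
        # A callable task: pickled, run in a subinterpreter, result sent back.
        assert pool.submit(pow, 2, 8).result() == 256

        # A script task: a plain string executed in the subinterpreter;
        # it takes no args/kwargs and its future resolves to None.
        pool.submit('print(greeting)').result()
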
diff --git a/Lib/concurrent/futures/thread.py b/Lib/concurrent/futures/thread.py
index a024033f35fb54..d3164c988418b8 100644
--- a/Lib/concurrent/futures/thread.py
+++ b/Lib/concurrent/futures/thread.py
@@ -50,12 +50,12 @@ def __init__(self, future, fn, args, kwargs):
         self.args = args
         self.kwargs = kwargs
 
-    def run(self):
+    def run(self, ctx=None):
         if not self.future.set_running_or_notify_cancel():
             return
 
         try:
-            result = self.fn(*self.args, **self.kwargs)
+            result = self.fn(ctx, *self.args, **self.kwargs)
         except BaseException as exc:
             self.future.set_exception(exc)
             # Break a reference cycle with the exception 'exc'
@@ -66,10 +66,10 @@ def run(self):
     __class_getitem__ = classmethod(types.GenericAlias)
 
 
-def _worker(executor_reference, work_queue, initializer, initargs):
+def _worker(ctx, executor_reference, work_queue, initializer, initargs):
     if initializer is not None:
         try:
-            initializer(*initargs)
+            initializer(ctx, *initargs)
         except BaseException:
             _base.LOGGER.critical('Exception in initializer:', exc_info=True)
             executor = executor_reference()
@@ -89,7 +89,7 @@ def _worker(executor_reference, work_queue, initializer, initargs):
                 work_item = work_queue.get(block=True)
 
             if work_item is not None:
-                work_item.run()
+                work_item.run(ctx)
                 # Delete references to object. See GH-60488
                 del work_item
                 continue
@@ -123,6 +123,26 @@ class ThreadPoolExecutor(_base.Executor):
     # Used to assign unique thread names when thread_name_prefix is not supplied.
     _counter = itertools.count().__next__
 
+    @classmethod
+    def _normalize_initializer(cls, initializer, initargs):
+        if initializer is None:
+            return None, ()
+        actual = initializer
+        def initializer(ctx, *args):
+            actual(*args)
+        return initializer, initargs
+
+    @classmethod
+    def _normalize_task(cls, fn, args, kwargs):
+        def wrapped(ctx, *args, **kwargs):
+            return fn(*args, **kwargs)
+        return wrapped, args, kwargs
+
+    @classmethod
+    def _run_worker(cls, *args):
+        ctx = None
+        return _worker(ctx, *args)
+
     def __init__(self, max_workers=None, thread_name_prefix='',
                  initializer=None, initargs=()):
         """Initializes a new ThreadPoolExecutor instance.
@@ -158,8 +178,9 @@ def __init__(self, max_workers=None, thread_name_prefix='',
         self._shutdown_lock = threading.Lock()
         self._thread_name_prefix = (thread_name_prefix or
                                     ("ThreadPoolExecutor-%d" % self._counter()))
-        self._initializer = initializer
-        self._initargs = initargs
+        (self._initializer,
+         self._initargs
+         ) = type(self)._normalize_initializer(initializer, initargs)
 
     def submit(self, fn, /, *args, **kwargs):
         with self._shutdown_lock, _global_shutdown_lock:
@@ -173,6 +194,7 @@ def submit(self, fn, /, *args, **kwargs):
                                    'interpreter shutdown')
 
             f = _base.Future()
+            fn, args, kwargs = type(self)._normalize_task(fn, args, kwargs)
             w = _WorkItem(f, fn, args, kwargs)
 
             self._work_queue.put(w)
@@ -194,7 +216,8 @@ def weakref_cb(_, q=self._work_queue):
             if num_threads < self._max_workers:
                 thread_name = '%s_%d' % (self._thread_name_prefix or self,
                                          num_threads)
-                t = threading.Thread(name=thread_name, target=_worker,
+                t = threading.Thread(name=thread_name,
+                                     target=type(self)._run_worker,
                                      args=(weakref.ref(self, weakref_cb),
                                            self._work_queue,
                                            self._initializer,
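The thread.py refactoring above turns ThreadPoolExecutor into a template: `_normalize_initializer`, `_normalize_task`, and `_run_worker` are the override points, and they cooperate on a per-worker `ctx` value that `_worker` now threads through to every initializer call and work item. A minimal, hypothetical subclass (not part of the patch) sketching the protocol, assuming the patched module:

    import concurrent.futures.thread as _thread
    import threading

    class LabeledThreadPoolExecutor(_thread.ThreadPoolExecutor):
        # Hypothetical subclass: ctx is the worker thread's name.

        @classmethod
        def _normalize_task(cls, fn, args, kwargs):
            # Wrap the task so it receives the per-worker ctx first.
            def wrapped(ctx, *args, **kwargs):
                print(f'[{ctx}] running {fn!r}')
                return fn(*args, **kwargs)
            return wrapped, args, kwargs

        @classmethod
        def _run_worker(cls, *args):
            # Build the per-worker context, then defer to the shared loop.
            ctx = threading.current_thread().name
            return _thread._worker(ctx, *args)
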
diff --git a/Lib/test/test_concurrent_futures/executor.py b/Lib/test/test_concurrent_futures/executor.py
index 6a79fe69ec37cf..52fa4e924b1aa9 100644
--- a/Lib/test/test_concurrent_futures/executor.py
+++ b/Lib/test/test_concurrent_futures/executor.py
@@ -26,6 +26,10 @@ def make_dummy_object(_):
 class ExecutorTest:
     # Executor.shutdown() and context manager usage is tested by
     # ExecutorShutdownTest.
+
+    def _assertRaises(self, exctype, *args, **kwargs):
+        return self.assertRaises(exctype, *args, **kwargs)
+
     def test_submit(self):
         future = self.executor.submit(pow, 2, 8)
         self.assertEqual(256, future.result())
@@ -53,7 +57,7 @@ def test_map_exception(self):
         i = self.executor.map(divmod, [1, 1, 1, 1], [2, 3, 0, 5])
         self.assertEqual(i.__next__(), (0, 1))
         self.assertEqual(i.__next__(), (0, 1))
-        self.assertRaises(ZeroDivisionError, i.__next__)
+        self._assertRaises(ZeroDivisionError, i.__next__)
 
     @support.requires_resource('walltime')
     def test_map_timeout(self):
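The `_assertRaises` indirection added above exists because a task that fails inside a subinterpreter does not surface its original exception type: `Interpreter.exec()` raises `interpreters.ExecutionFailed` in the worker thread, and that is what ends up on the future. The interpreter-pool tests below override the hook accordingly. A brief illustration of the behavior this implies (a sketch, not test code):

    from test.support import interpreters
    from concurrent.futures import InterpreterPoolExecutor

    with InterpreterPoolExecutor() as pool:
        fut = pool.submit(divmod, 1, 0)
        try:
            fut.result()
        except interpreters.ExecutionFailed:
            # The worker's ZeroDivisionError was raised by Interpreter.exec()
            # inside the task wrapper, so the future stores ExecutionFailed.
            pass
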
diff --git a/Lib/test/test_concurrent_futures/test_interpreter_pool.py b/Lib/test/test_concurrent_futures/test_interpreter_pool.py
new file mode 100644
index 00000000000000..1bd359750b40db
--- /dev/null
+++ b/Lib/test/test_concurrent_futures/test_interpreter_pool.py
@@ -0,0 +1,102 @@
+import contextlib
+from test.support import interpreters
+import os
+import unittest
+
+from concurrent.futures import InterpreterPoolExecutor
+from .executor import ExecutorTest, mul
+from .util import BaseTestCase, InterpreterPoolMixin, setup_module
+
+
+def wait_for_token(queue):
+    while True:
+        try:
+            queue.get(timeout=0.1)
+        except interpreters.QueueEmpty:
+            continue
+        break
+
+
+def log_n_wait(args):
+    logqueue, waitqueue, ident = args
+    logqueue.put_nowait(f"{ident=} started")
+    try:
+        wait_for_token(waitqueue)
+    finally:
+        logqueue.put_nowait(f"{ident=} stopped")
+
+
+class InterpreterPoolExecutorTest(InterpreterPoolMixin, ExecutorTest, BaseTestCase):
+
+    def _assertRaises(self, exctype, *args, **kwargs):
+        return self.assertRaises(interpreters.ExecutionFailed, *args, **kwargs)
+
+    def test_default_workers(self):
+        executor = InterpreterPoolExecutor()
+        expected = min(32, (os.process_cpu_count() or 1) + 4)
+        self.assertEqual(executor._max_workers, expected)
+
+    def test_saturation(self):
+        executor = InterpreterPoolExecutor(max_workers=4)
+        waitqueue = interpreters.create_queue(syncobj=True)
+
+        for i in range(15 * executor._max_workers):
+            executor.submit(wait_for_token, waitqueue)
+        self.assertEqual(len(executor._threads), executor._max_workers)
+        for i in range(15 * executor._max_workers):
+            waitqueue.put(None)
+        executor.shutdown(wait=True)
+
+    def test_idle_thread_reuse(self):
+        executor = InterpreterPoolExecutor()
+        executor.submit(mul, 21, 2).result()
+        executor.submit(mul, 6, 7).result()
+        executor.submit(mul, 3, 14).result()
+        self.assertEqual(len(executor._threads), 1)
+        executor.shutdown(wait=True)
+
+    def test_executor_map_current_future_cancel(self):
+        logqueue = interpreters.create_queue(syncobj=True)
+        waitqueue = interpreters.create_queue(syncobj=True)
+        idents = ['first', 'second', 'third']
+        _idents = iter(idents)
+
+        with InterpreterPoolExecutor(max_workers=1) as pool:
+            # submit work to saturate the pool
+            fut = pool.submit(log_n_wait, (logqueue, waitqueue, next(_idents)))
+            try:
+                with contextlib.closing(
+                    pool.map(log_n_wait,
+                             [(logqueue, waitqueue, ident) for ident in _idents],
+                             timeout=0)
+                ) as gen:
+                    with self.assertRaises(TimeoutError):
+                        next(gen)
+            finally:
+                for _ in idents:
+                    waitqueue.put_nowait(None)
+            fut.result()
+
+            # When the pool shuts down (due to the context manager),
+            # each worker's interpreter is finalized.  When that
+            # happens, any items the finalizing interpreter put in
+            # a queue are removed.  Thus, we must copy the items
+            # *before* leaving the context manager.
+            # XXX This can be surprising.  Perhaps give users
+            # the option to keep objects beyond the original interpreter?
+            assert not logqueue.empty(), logqueue.qsize()
+            log = []
+            while not logqueue.empty():
+                log.append(logqueue.get_nowait())
+
+        # ident='second' is cancelled when map() raises TimeoutError;
+        # ident='third' is cancelled because it remained in the collection of futures.
+        self.assertListEqual(log, ["ident='first' started", "ident='first' stopped"])
+
+
+def setUpModule():
+    setup_module()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/Lib/test/test_concurrent_futures/util.py b/Lib/test/test_concurrent_futures/util.py
index 3e855031913042..b941129d9b5091 100644
--- a/Lib/test/test_concurrent_futures/util.py
+++ b/Lib/test/test_concurrent_futures/util.py
@@ -74,6 +74,10 @@ class ThreadPoolMixin(ExecutorMixin):
     executor_type = futures.ThreadPoolExecutor
 
 
+class InterpreterPoolMixin(ExecutorMixin):
+    executor_type = futures.InterpreterPoolExecutor
+
+
 class ProcessPoolForkMixin(ExecutorMixin):
     executor_type = futures.ProcessPoolExecutor
     ctx = "fork"
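A final caveat, flagged by the XXX comment in the test above: items in an interpreters queue do not outlive the interpreter that put them, so results routed through such a queue must be drained while the executor is still open. A minimal sketch of the pattern, assuming `produce` is defined in a module the subinterpreter can import when unpickling it by reference:

    from test.support import interpreters
    from concurrent.futures import InterpreterPoolExecutor

    queue = interpreters.create_queue(syncobj=True)

    def produce(q):
        q.put_nowait('payload')

    with InterpreterPoolExecutor(max_workers=1) as pool:
        pool.submit(produce, queue).result()
        # Drain while the pool is open: once it shuts down, each worker's
        # interpreter is finalized and the items it put are removed.
        items = []
        while not queue.empty():
            items.append(queue.get_nowait())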