Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 42 additions & 9 deletions asgiref/current_thread_executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import queue
import sys
import threading
from concurrent.futures import Executor, Future
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

_T = TypeVar("_T")
_P = ParamSpec("_P")
_R = TypeVar("_R")


class _WorkItem:
Expand All @@ -9,13 +20,19 @@ class _WorkItem:
Copied from ThreadPoolExecutor (but it's private, so we're not going to rely on importing it)
"""

def __init__(self, future, fn, args, kwargs):
def __init__(
self,
future: "Future[_R]",
fn: Callable[_P, _R],
*args: _P.args,
**kwargs: _P.kwargs,
):
self.future = future
self.fn = fn
self.args = args
self.kwargs = kwargs

def run(self):
def run(self) -> None:
__traceback_hide__ = True # noqa: F841
if not self.future.set_running_or_notify_cancel():
return
Expand All @@ -24,7 +41,7 @@ def run(self):
except BaseException as exc:
self.future.set_exception(exc)
# Break a reference cycle with the exception 'exc'
self = None
self = None # type: ignore[assignment]
else:
self.future.set_result(result)

Expand All @@ -36,12 +53,12 @@ class CurrentThreadExecutor(Executor):
the thread they came from.
"""

def __init__(self):
def __init__(self) -> None:
self._work_thread = threading.current_thread()
self._work_queue = queue.Queue()
self._work_queue: queue.Queue[Union[_WorkItem, "Future[Any]"]] = queue.Queue()
self._broken = False

def run_until_future(self, future):
def run_until_future(self, future: "Future[Any]") -> None:
"""
Runs the code in the work queue until a result is available from the future.
Should be run from the thread the executor is initialised in.
Expand All @@ -60,12 +77,18 @@ def run_until_future(self, future):
work_item = self._work_queue.get()
if work_item is future:
return
assert isinstance(work_item, _WorkItem)
work_item.run()
del work_item
finally:
self._broken = True

def submit(self, fn, *args, **kwargs):
def _submit(
self,
fn: Callable[_P, _R],
*args: _P.args,
**kwargs: _P.kwargs,
) -> "Future[_R]":
# Check they're not submitting from the same thread
if threading.current_thread() == self._work_thread:
raise RuntimeError(
Expand All @@ -75,8 +98,18 @@ def submit(self, fn, *args, **kwargs):
if self._broken:
raise RuntimeError("CurrentThreadExecutor already quit or is broken")
# Add to work queue
f = Future()
work_item = _WorkItem(f, fn, args, kwargs)
f: "Future[_R]" = Future()
work_item = _WorkItem(f, fn, *args, **kwargs)
self._work_queue.put(work_item)
# Return the future
return f

# Python 3.9+ has a new signature for submit with a "/" after `fn`, to enforce
# it to be a positional argument. If we ignore[override] mypy on 3.9+ will be
# happy but 3.7/3.8 will say that the ignore comment is unused, even when
# defining them differently based on sys.version_info.
# We should be able to remove this when we drop support for 3.7/3.8.
if not TYPE_CHECKING:

def submit(self, fn, *args, **kwargs):
return self._submit(fn, *args, **kwargs)
Loading