diff --git a/asgiref/current_thread_executor.py b/asgiref/current_thread_executor.py index 0a9317ee..c9a97dd5 100644 --- a/asgiref/current_thread_executor.py +++ b/asgiref/current_thread_executor.py @@ -1,6 +1,17 @@ import queue +import sys import threading from concurrent.futures import Executor, Future +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +_T = TypeVar("_T") +_P = ParamSpec("_P") +_R = TypeVar("_R") class _WorkItem: @@ -9,13 +20,19 @@ class _WorkItem: Copied from ThreadPoolExecutor (but it's private, so we're not going to rely on importing it) """ - def __init__(self, future, fn, args, kwargs): + def __init__( + self, + future: "Future[_R]", + fn: Callable[_P, _R], + *args: _P.args, + **kwargs: _P.kwargs, + ): self.future = future self.fn = fn self.args = args self.kwargs = kwargs - def run(self): + def run(self) -> None: __traceback_hide__ = True # noqa: F841 if not self.future.set_running_or_notify_cancel(): return @@ -24,7 +41,7 @@ def run(self): except BaseException as exc: self.future.set_exception(exc) # Break a reference cycle with the exception 'exc' - self = None + self = None # type: ignore[assignment] else: self.future.set_result(result) @@ -36,12 +53,12 @@ class CurrentThreadExecutor(Executor): the thread they came from. """ - def __init__(self): + def __init__(self) -> None: self._work_thread = threading.current_thread() - self._work_queue = queue.Queue() + self._work_queue: queue.Queue[Union[_WorkItem, "Future[Any]"]] = queue.Queue() self._broken = False - def run_until_future(self, future): + def run_until_future(self, future: "Future[Any]") -> None: """ Runs the code in the work queue until a result is available from the future. Should be run from the thread the executor is initialised in. @@ -60,12 +77,18 @@ def run_until_future(self, future): work_item = self._work_queue.get() if work_item is future: return + assert isinstance(work_item, _WorkItem) work_item.run() del work_item finally: self._broken = True - def submit(self, fn, *args, **kwargs): + def _submit( + self, + fn: Callable[_P, _R], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> "Future[_R]": # Check they're not submitting from the same thread if threading.current_thread() == self._work_thread: raise RuntimeError( @@ -75,8 +98,18 @@ def submit(self, fn, *args, **kwargs): if self._broken: raise RuntimeError("CurrentThreadExecutor already quit or is broken") # Add to work queue - f = Future() - work_item = _WorkItem(f, fn, args, kwargs) + f: "Future[_R]" = Future() + work_item = _WorkItem(f, fn, *args, **kwargs) self._work_queue.put(work_item) # Return the future return f + + # Python 3.9+ has a new signature for submit with a "/" after `fn`, to enforce + # it to be a positional argument. If we ignore[override] mypy on 3.9+ will be + # happy but 3.7/3.8 will say that the ignore comment is unused, even when + # defining them differently based on sys.version_info. + # We should be able to remove this when we drop support for 3.7/3.8. + if not TYPE_CHECKING: + + def submit(self, fn, *args, **kwargs): + return self._submit(fn, *args, **kwargs) diff --git a/asgiref/sync.py b/asgiref/sync.py index 6ba55dc3..455fed00 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -9,21 +9,48 @@ import warnings import weakref from concurrent.futures import Future, ThreadPoolExecutor -from typing import Any, Callable, Dict, Optional, overload +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Coroutine, + Dict, + Generic, + List, + Optional, + TypeVar, + Union, + overload, +) from .current_thread_executor import CurrentThreadExecutor from .local import Local +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec -def _restore_context(context): +if TYPE_CHECKING: + # This is not available to import at runtime + from _typeshed import OptExcInfo + +_F = TypeVar("_F", bound=Callable[..., Any]) +_P = ParamSpec("_P") +_R = TypeVar("_R") + + +def _restore_context(context: contextvars.Context) -> None: # Check for changes in contextvars, and set them to the current # context for downstream consumers for cvar in context: + cvalue = context.get(cvar) try: - if cvar.get() != context.get(cvar): - cvar.set(context.get(cvar)) + if cvar.get() != cvalue: + cvar.set(cvalue) except LookupError: - cvar.set(context.get(cvar)) + cvar.set(cvalue) # Python 3.12 deprecates asyncio.iscoroutinefunction() as an alias for @@ -32,29 +59,25 @@ def _restore_context(context): # Until 3.12 is the minimum supported Python version, provide a shim. # Django 4.0 only supports 3.8+, so don't concern with the _or_partial backport. -# Type hint: should be generic: whatever T it takes it returns. (Same id) -def markcoroutinefunction(func: Any) -> Any: - if hasattr(inspect, "markcoroutinefunction"): - return inspect.markcoroutinefunction(func) - else: +if hasattr(inspect, "markcoroutinefunction"): + iscoroutinefunction = inspect.iscoroutinefunction + markcoroutinefunction: Callable[[_F], _F] = inspect.markcoroutinefunction +else: + iscoroutinefunction = asyncio.iscoroutinefunction # type: ignore[assignment] + + def markcoroutinefunction(func: _F) -> _F: func._is_coroutine = asyncio.coroutines._is_coroutine # type: ignore return func -def iscoroutinefunction(func: Any) -> bool: - if hasattr(inspect, "markcoroutinefunction"): - return inspect.iscoroutinefunction(func) - else: - return asyncio.iscoroutinefunction(func) +if sys.version_info >= (3, 8): + _iscoroutinefunction_or_partial = iscoroutinefunction +else: - -def _iscoroutinefunction_or_partial(func: Any) -> bool: - # Python < 3.8 does not correctly determine partially wrapped - # coroutine functions are coroutine functions, hence the need for - # this to exist. Code taken from CPython. - if sys.version_info >= (3, 8): - return iscoroutinefunction(func) - else: + def _iscoroutinefunction_or_partial(func: Any) -> bool: + # Python < 3.8 does not correctly determine partially wrapped + # coroutine functions are coroutine functions, hence the need for + # this to exist. Code taken from CPython. while inspect.ismethod(func): func = func.__func__ while isinstance(func, functools.partial): @@ -104,7 +127,7 @@ async def __aexit__(self, exc, value, tb): SyncToAsync.thread_sensitive_context.reset(self.token) -class AsyncToSync: +class AsyncToSync(Generic[_P, _R]): """ Utility class which turns an awaitable that only works on the thread with the event loop into a synchronous callable that works in a subthread. @@ -128,7 +151,14 @@ class AsyncToSync: # inside create_task, we'll look it up here from the running event loop. loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {} - def __init__(self, awaitable, force_new_loop=False): + def __init__( + self, + awaitable: Union[ + Callable[_P, Coroutine[Any, Any, _R]], + Callable[_P, Awaitable[_R]], + ], + force_new_loop: bool = False, + ): if not callable(awaitable) or ( not _iscoroutinefunction_or_partial(awaitable) and not _iscoroutinefunction_or_partial( @@ -142,7 +172,7 @@ def __init__(self, awaitable, force_new_loop=False): ) self.awaitable = awaitable try: - self.__self__ = self.awaitable.__self__ + self.__self__ = self.awaitable.__self__ # type: ignore[union-attr] except AttributeError: pass if force_new_loop: @@ -166,7 +196,7 @@ def __init__(self, awaitable, force_new_loop=False): else: self.main_event_loop = None - def __call__(self, *args, **kwargs): + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: __traceback_hide__ = True # noqa: F841 # You can't call AsyncToSync from a thread with a running event loop @@ -186,7 +216,7 @@ def __call__(self, *args, **kwargs): context = [contextvars.copy_context()] # Make a future for the return information - call_result = Future() + call_result: "Future[_R]" = Future() # Get the source thread source_thread = threading.current_thread() # Make a CurrentThreadExecutor we'll use to idle in this thread - we @@ -204,7 +234,12 @@ def __call__(self, *args, **kwargs): # in this thread. try: awaitable = self.main_wrap( - args, kwargs, call_result, source_thread, sys.exc_info(), context + call_result, + source_thread, + sys.exc_info(), + context, + *args, + **kwargs, ) if not (self.main_event_loop and self.main_event_loop.is_running()): @@ -277,7 +312,7 @@ async def gather(): loop.close() asyncio.set_event_loop(self.main_event_loop) - def __get__(self, parent, objtype): + def __get__(self, parent: Any, objtype: Any) -> Callable[_P, _R]: """ Include self for methods """ @@ -285,8 +320,14 @@ def __get__(self, parent, objtype): return functools.update_wrapper(func, self.awaitable) async def main_wrap( - self, args, kwargs, call_result, source_thread, exc_info, context - ): + self, + call_result: "Future[_R]", + source_thread: threading.Thread, + exc_info: "OptExcInfo", + context: List[contextvars.Context], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> None: """ Wraps the awaitable with something that puts the result into the result/exception future. @@ -298,6 +339,7 @@ async def main_wrap( _restore_context(context[0]) current_task = SyncToAsync.get_current_task() + assert current_task is not None self.launch_map[current_task] = source_thread try: # If we have an exception, run the function inside the except block @@ -319,7 +361,7 @@ async def main_wrap( context[0] = contextvars.copy_context() -class SyncToAsync: +class SyncToAsync(Generic[_P, _R]): """ Utility class which turns a synchronous callable into an awaitable that runs in a threadpool. It also sets a threadlocal inside the thread so @@ -352,8 +394,8 @@ class SyncToAsync: # Maintain a contextvar for the current execution context. Optionally used # for thread sensitive mode. - thread_sensitive_context: "contextvars.ContextVar[str]" = contextvars.ContextVar( - "thread_sensitive_context" + thread_sensitive_context: "contextvars.ContextVar[ThreadSensitiveContext]" = ( + contextvars.ContextVar("thread_sensitive_context") ) # Contextvar that is used to detect if the single thread executor @@ -364,13 +406,13 @@ class SyncToAsync: # Maintaining a weak reference to the context ensures that thread pools are # erased once the context goes out of scope. This terminates the thread pool. - context_to_thread_executor: "weakref.WeakKeyDictionary[object, ThreadPoolExecutor]" = ( + context_to_thread_executor: "weakref.WeakKeyDictionary[ThreadSensitiveContext, ThreadPoolExecutor]" = ( weakref.WeakKeyDictionary() ) def __init__( self, - func: Callable[..., Any], + func: Callable[_P, _R], thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, ) -> None: @@ -392,7 +434,7 @@ def __init__( except AttributeError: pass - async def __call__(self, *args, **kwargs): + async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: __traceback_hide__ = True # noqa: F841 loop = asyncio.get_running_loop() @@ -431,12 +473,10 @@ async def __call__(self, *args, **kwargs): context = contextvars.copy_context() child = functools.partial(self.func, *args, **kwargs) func = context.run - args = (child,) - kwargs = {} try: # Run the code in the right thread - ret = await loop.run_in_executor( + ret: _R = await loop.run_in_executor( executor, functools.partial( self.thread_handler, @@ -444,8 +484,7 @@ async def __call__(self, *args, **kwargs): self.get_current_task(), sys.exc_info(), func, - *args, - **kwargs, + child, ), ) @@ -455,7 +494,9 @@ async def __call__(self, *args, **kwargs): return ret - def __get__(self, parent, objtype): + def __get__( + self, parent: Any, objtype: Any + ) -> Callable[_P, Coroutine[Any, Any, _R]]: """ Include self for methods """ @@ -502,7 +543,7 @@ def thread_handler(self, loop, source_task, exc_info, func, *args, **kwargs): del self.launch_map[current_thread] @staticmethod - def get_current_task(): + def get_current_task() -> Optional["asyncio.Task[Any]"]: """ Implementation of asyncio.current_task() that returns None if there is no task. @@ -519,27 +560,28 @@ def get_current_task(): @overload def sync_to_async( - func: None = None, + *, thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, -) -> Callable[[Callable[..., Any]], SyncToAsync]: +) -> Callable[[Callable[_P, _R]], SyncToAsync[_P, _R]]: ... @overload def sync_to_async( - func: Callable[..., Any], + func: Callable[_P, _R], + *, thread_sensitive: bool = True, executor: Optional["ThreadPoolExecutor"] = None, -) -> SyncToAsync: +) -> SyncToAsync[_P, _R]: ... def sync_to_async( - func=None, - thread_sensitive=True, - executor=None, -): + func: Optional[Callable[_P, _R]] = None, + thread_sensitive: bool = True, + executor: Optional["ThreadPoolExecutor"] = None, +) -> Union[Callable[[Callable[_P, _R]], SyncToAsync[_P, _R]], SyncToAsync[_P, _R]]: if func is None: return lambda f: SyncToAsync( f,