Skip to content

Commit 4765e3e

Browse files
authored
Preserve context vars for sync tasks (#483)
1 parent d01e152 commit 4765e3e

File tree

2 files changed

+71
-26
lines changed

2 files changed

+71
-26
lines changed

taskiq/receiver/receiver.py

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import asyncio
2+
import contextvars
3+
import functools
24
import inspect
35
from concurrent.futures import Executor
46
from logging import getLogger
57
from time import time
6-
from typing import Any, Callable, Dict, List, Optional, Set, Union, get_type_hints
8+
from typing import Any, Callable, Dict, Optional, Set, Union, get_type_hints
79

810
import anyio
911
from taskiq_dependencies import DependencyGraph
@@ -23,25 +25,6 @@
2325
QUEUE_DONE = b"-1"
2426

2527

26-
def _run_sync(
27-
target: Callable[..., Any],
28-
args: List[Any],
29-
kwargs: Dict[str, Any],
30-
) -> Any:
31-
"""
32-
Runs function synchronously.
33-
34-
We use this function, because
35-
we cannot pass kwargs in loop.run_with_executor().
36-
37-
:param target: function to execute.
38-
:param args: list of function's args.
39-
:param kwargs: dict of function's kwargs.
40-
:return: result of function's execution.
41-
"""
42-
return target(*args, **kwargs)
43-
44-
4528
class Receiver:
4629
"""Class that uses as a callback handler."""
4730

@@ -255,13 +238,13 @@ async def run_task( # noqa: C901, PLR0912, PLR0915
255238
else:
256239
is_coroutine = False
257240
# If this is a synchronous function, we
258-
# run it in executor.
241+
# run it in executor and preserve the context.
242+
ctx = contextvars.copy_context()
243+
func = functools.partial(target, *message.args, **kwargs)
259244
target_future = loop.run_in_executor(
260245
self.executor,
261-
_run_sync,
262-
target,
263-
message.args,
264-
kwargs,
246+
ctx.run,
247+
func,
265248
)
266249
timeout = message.labels.get("timeout")
267250
if timeout is not None:

tests/receiver/test_receiver.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import asyncio
2+
import contextvars
23
import random
34
import time
45
from concurrent.futures import ThreadPoolExecutor
5-
from typing import Any, ClassVar, List, Optional
6+
from typing import Any, ClassVar, Generator, List, Optional
67

78
import pytest
89
from taskiq_dependencies import Depends
@@ -472,3 +473,64 @@ async def task_no_result() -> str:
472473
assert resp.return_value is None
473474
assert not broker._running_tasks
474475
assert isinstance(resp.error, ValueError)
476+
477+
478+
EXPECTED_CTX_VALUE = 42
479+
480+
481+
@pytest.fixture()
482+
def ctxvar() -> Generator[contextvars.ContextVar[int], None, None]:
483+
_ctx_variable: contextvars.ContextVar[int] = contextvars.ContextVar(
484+
"taskiq_test_ctx_var",
485+
)
486+
token = _ctx_variable.set(EXPECTED_CTX_VALUE)
487+
yield _ctx_variable
488+
_ctx_variable.reset(token)
489+
490+
491+
@pytest.mark.anyio
492+
async def test_run_task_successful_sync_preserve_contextvars(
493+
ctxvar: contextvars.ContextVar[int],
494+
) -> None:
495+
"""Running sync tasks should preserve context vars."""
496+
497+
def test_func() -> int:
498+
return ctxvar.get()
499+
500+
receiver = get_receiver()
501+
502+
result = await receiver.run_task(
503+
test_func,
504+
TaskiqMessage(
505+
task_id="",
506+
task_name="",
507+
labels={},
508+
args=[],
509+
kwargs={},
510+
),
511+
)
512+
assert result.return_value == EXPECTED_CTX_VALUE
513+
514+
515+
@pytest.mark.anyio
516+
async def test_run_task_successful_async_preserve_contextvars(
517+
ctxvar: contextvars.ContextVar[int],
518+
) -> None:
519+
"""Running async tasks should preserve context vars."""
520+
521+
async def test_func() -> int:
522+
return ctxvar.get()
523+
524+
receiver = get_receiver()
525+
526+
result = await receiver.run_task(
527+
test_func,
528+
TaskiqMessage(
529+
task_id="",
530+
task_name="",
531+
labels={},
532+
args=[],
533+
kwargs={},
534+
),
535+
)
536+
assert result.return_value == EXPECTED_CTX_VALUE

0 commit comments

Comments
 (0)