Skip to content

Commit a266dfc

Browse files
authored
Merge pull request #1513 from guilledk/issue1244
Fixed unhelpful error message in from_thread_run functions.
2 parents 54151c5 + f77c200 commit a266dfc

File tree

8 files changed

+227
-161
lines changed

8 files changed

+227
-161
lines changed

newsfragments/1244.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added a helpful error message if an async function is passed to `trio.from_thread.run_sync` or a sync function to `trio.from_thread.run`.

trio/_core/_run.py

Lines changed: 5 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from sniffio import current_async_library_cvar
1818

1919
import attr
20-
from async_generator import isasyncgen
2120
from sortedcontainers import SortedDict
2221
from outcome import Error, Value, capture
2322

@@ -36,7 +35,7 @@
3635
)
3736
from .. import _core
3837
from .._deprecate import deprecated
39-
from .._util import Final, NoPublicConstructor
38+
from .._util import Final, NoPublicConstructor, coroutine_or_error
4039

4140
_NO_SEND = object()
4241

@@ -1247,86 +1246,7 @@ def spawn_impl(self, async_fn, args, nursery, name, *, system_task=False):
12471246
# Call the function and get the coroutine object, while giving helpful
12481247
# errors for common mistakes.
12491248
######
1250-
1251-
def _return_value_looks_like_wrong_library(value):
1252-
# Returned by legacy @asyncio.coroutine functions, which includes
1253-
# a surprising proportion of asyncio builtins.
1254-
if isinstance(value, collections.abc.Generator):
1255-
return True
1256-
# The protocol for detecting an asyncio Future-like object
1257-
if getattr(value, "_asyncio_future_blocking", None) is not None:
1258-
return True
1259-
# This janky check catches tornado Futures and twisted Deferreds.
1260-
# By the time we're calling this function, we already know
1261-
# something has gone wrong, so a heuristic is pretty safe.
1262-
if value.__class__.__name__ in ("Future", "Deferred"):
1263-
return True
1264-
return False
1265-
1266-
try:
1267-
coro = async_fn(*args)
1268-
except TypeError:
1269-
# Give good error for: nursery.start_soon(trio.sleep(1))
1270-
if isinstance(async_fn, collections.abc.Coroutine):
1271-
raise TypeError(
1272-
"Trio was expecting an async function, but instead it got "
1273-
"a coroutine object {async_fn!r}\n"
1274-
"\n"
1275-
"Probably you did something like:\n"
1276-
"\n"
1277-
" trio.run({async_fn.__name__}(...)) # incorrect!\n"
1278-
" nursery.start_soon({async_fn.__name__}(...)) # incorrect!\n"
1279-
"\n"
1280-
"Instead, you want (notice the parentheses!):\n"
1281-
"\n"
1282-
" trio.run({async_fn.__name__}, ...) # correct!\n"
1283-
" nursery.start_soon({async_fn.__name__}, ...) # correct!"
1284-
.format(async_fn=async_fn)
1285-
) from None
1286-
1287-
# Give good error for: nursery.start_soon(future)
1288-
if _return_value_looks_like_wrong_library(async_fn):
1289-
raise TypeError(
1290-
"Trio was expecting an async function, but instead it got "
1291-
"{!r} – are you trying to use a library written for "
1292-
"asyncio/twisted/tornado or similar? That won't work "
1293-
"without some sort of compatibility shim."
1294-
.format(async_fn)
1295-
) from None
1296-
1297-
raise
1298-
1299-
# We can't check iscoroutinefunction(async_fn), because that will fail
1300-
# for things like functools.partial objects wrapping an async
1301-
# function. So we have to just call it and then check whether the
1302-
# return value is a coroutine object.
1303-
if not isinstance(coro, collections.abc.Coroutine):
1304-
# Give good error for: nursery.start_soon(func_returning_future)
1305-
if _return_value_looks_like_wrong_library(coro):
1306-
raise TypeError(
1307-
"start_soon got unexpected {!r} – are you trying to use a "
1308-
"library written for asyncio/twisted/tornado or similar? "
1309-
"That won't work without some sort of compatibility shim."
1310-
.format(coro)
1311-
)
1312-
1313-
if isasyncgen(coro):
1314-
raise TypeError(
1315-
"start_soon expected an async function but got an async "
1316-
"generator {!r}".format(coro)
1317-
)
1318-
1319-
# Give good error for: nursery.start_soon(some_sync_fn)
1320-
raise TypeError(
1321-
"Trio expected an async function, but {!r} appears to be "
1322-
"synchronous".format(
1323-
getattr(async_fn, "__qualname__", async_fn)
1324-
)
1325-
)
1326-
1327-
######
1328-
# Set up the Task object
1329-
######
1249+
coro = coroutine_or_error(async_fn, *args)
13301250

13311251
if name is None:
13321252
name = async_fn
@@ -1353,6 +1273,9 @@ async def python_wrapper(orig_coro):
13531273
LOCALS_KEY_KI_PROTECTION_ENABLED, system_task
13541274
)
13551275

1276+
######
1277+
# Set up the Task object
1278+
######
13561279
task = Task._create(
13571280
coro=coro,
13581281
parent_nursery=nursery,

trio/_core/tests/test_run.py

Lines changed: 17 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
import sniffio
1616
import pytest
1717

18-
from .tutil import slow, check_sequence_matches, gc_collect_harder
18+
from .tutil import (
19+
slow, check_sequence_matches, gc_collect_harder,
20+
ignore_coroutine_never_awaited_warnings
21+
)
22+
1923
from ... import _core
2024
from ..._threads import to_thread_run_sync
2125
from ..._timeouts import sleep, fail_after
@@ -33,24 +37,6 @@ async def sleep_forever():
3337
return await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
3438

3539

36-
# Some of our tests need to leak coroutines, and thus trigger the
37-
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
38-
# manager should be used anywhere this happens to hide those messages, because
39-
# when expected they're clutter.
40-
@contextmanager
41-
def ignore_coroutine_never_awaited_warnings():
42-
with warnings.catch_warnings():
43-
warnings.filterwarnings(
44-
"ignore", message="coroutine '.*' was never awaited"
45-
)
46-
try:
47-
yield
48-
finally:
49-
# Make sure to trigger any coroutine __del__ methods now, before
50-
# we leave the context manager.
51-
gc_collect_harder()
52-
53-
5440
def test_basic():
5541
async def trivial(x):
5642
return x
@@ -1696,8 +1682,6 @@ async def test_current_effective_deadline(mock_clock):
16961682
assert _core.current_effective_deadline() == inf
16971683

16981684

1699-
# @coroutine is deprecated since python 3.8, which is fine with us.
1700-
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
17011685
def test_nice_error_on_bad_calls_to_run_or_spawn():
17021686
def bad_call_run(*args):
17031687
_core.run(*args)
@@ -1709,59 +1693,22 @@ async def main():
17091693

17101694
_core.run(main)
17111695

1712-
class Deferred:
1713-
"Just kidding"
1714-
1715-
with ignore_coroutine_never_awaited_warnings():
1716-
for bad_call in bad_call_run, bad_call_spawn:
1717-
1718-
async def f(): # pragma: no cover
1719-
pass
1720-
1721-
with pytest.raises(TypeError) as excinfo:
1722-
bad_call(f())
1723-
assert "expecting an async function" in str(excinfo.value)
1724-
1725-
import asyncio
1726-
1727-
@asyncio.coroutine
1728-
def generator_based_coro(): # pragma: no cover
1729-
yield from asyncio.sleep(1)
1730-
1731-
with pytest.raises(TypeError) as excinfo:
1732-
bad_call(generator_based_coro())
1733-
assert "asyncio" in str(excinfo.value)
1696+
for bad_call in bad_call_run, bad_call_spawn:
17341697

1735-
with pytest.raises(TypeError) as excinfo:
1736-
bad_call(asyncio.Future())
1737-
assert "asyncio" in str(excinfo.value)
1738-
1739-
with pytest.raises(TypeError) as excinfo:
1740-
bad_call(lambda: asyncio.Future())
1741-
assert "asyncio" in str(excinfo.value)
1742-
1743-
with pytest.raises(TypeError) as excinfo:
1744-
bad_call(Deferred())
1745-
assert "twisted" in str(excinfo.value)
1746-
1747-
with pytest.raises(TypeError) as excinfo:
1748-
bad_call(lambda: Deferred())
1749-
assert "twisted" in str(excinfo.value)
1750-
1751-
with pytest.raises(TypeError) as excinfo:
1752-
bad_call(len, [1, 2, 3])
1753-
assert "appears to be synchronous" in str(excinfo.value)
1698+
async def f(): # pragma: no cover
1699+
pass
17541700

1755-
async def async_gen(arg): # pragma: no cover
1756-
yield
1701+
with pytest.raises(TypeError, match="expecting an async function"):
1702+
bad_call(f())
17571703

1758-
with pytest.raises(TypeError) as excinfo:
1759-
bad_call(async_gen, 0)
1760-
msg = "expected an async function but got an async generator"
1761-
assert msg in str(excinfo.value)
1704+
async def async_gen(arg): # pragma: no cover
1705+
yield arg
17621706

1763-
# Make sure no references are kept around to keep anything alive
1764-
del excinfo
1707+
with pytest.raises(
1708+
TypeError,
1709+
match="expected an async function but got an async generator"
1710+
):
1711+
bad_call(async_gen, 0)
17651712

17661713

17671714
def test_calling_asyncio_function_gives_nice_error():

trio/_core/tests/tutil.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import os
44

55
import pytest
6+
import warnings
7+
from contextlib import contextmanager
68

79
import gc
810

@@ -52,6 +54,24 @@ def gc_collect_harder():
5254
gc.collect()
5355

5456

57+
# Some of our tests need to leak coroutines, and thus trigger the
58+
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
59+
# manager should be used anywhere this happens to hide those messages, because
60+
# when expected they're clutter.
61+
@contextmanager
62+
def ignore_coroutine_never_awaited_warnings():
63+
with warnings.catch_warnings():
64+
warnings.filterwarnings(
65+
"ignore", message="coroutine '.*' was never awaited"
66+
)
67+
try:
68+
yield
69+
finally:
70+
# Make sure to trigger any coroutine __del__ methods now, before
71+
# we leave the context manager.
72+
gc_collect_harder()
73+
74+
5575
# template is like:
5676
# [1, {2.1, 2.2}, 3] -> matches [1, 2.1, 2.2, 3] or [1, 2.2, 2.1, 3]
5777
def check_sequence_matches(seq, template):

trio/_threads.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@
33
from itertools import count
44

55
import attr
6+
import inspect
67
import outcome
78

89
import trio
910

1011
from ._sync import CapacityLimiter
1112
from ._core import enable_ki_protection, disable_ki_protection, RunVar, TrioToken
13+
from ._util import coroutine_or_error
1214

1315
# Global due to Threading API, thread local storage for trio token
1416
TOKEN_LOCAL = threading.local()
@@ -365,6 +367,7 @@ def from_thread_run(afn, *args, trio_token=None):
365367
which would otherwise cause a deadlock.
366368
AttributeError: if no ``trio_token`` was provided, and we can't infer
367369
one from context.
370+
TypeError: if ``afn`` is not an asynchronous function.
368371
369372
**Locating a Trio Token**: There are two ways to specify which
370373
`trio.run` loop to reenter:
@@ -380,7 +383,8 @@ def from_thread_run(afn, *args, trio_token=None):
380383
def callback(q, afn, args):
381384
@disable_ki_protection
382385
async def unprotected_afn():
383-
return await afn(*args)
386+
coro = coroutine_or_error(afn, *args)
387+
return await coro
384388

385389
async def await_in_trio_thread_task():
386390
q.put_nowait(await outcome.acapture(unprotected_afn))
@@ -403,13 +407,11 @@ def from_thread_run_sync(fn, *args, trio_token=None):
403407
Raises:
404408
RunFinishedError: if the corresponding call to `trio.run` has
405409
already completed.
406-
Cancelled: if the corresponding call to `trio.run` completes
407-
while ``afn(*args)`` is running, then ``afn`` is likely to raise
408-
:exc:`trio.Cancelled`, and this will propagate out into
409410
RuntimeError: if you try calling this from inside the Trio thread,
410411
which would otherwise cause a deadlock.
411412
AttributeError: if no ``trio_token`` was provided, and we can't infer
412413
one from context.
414+
TypeError: if ``fn`` is an async function.
413415
414416
**Locating a Trio Token**: There are two ways to specify which
415417
`trio.run` loop to reenter:
@@ -425,7 +427,17 @@ def from_thread_run_sync(fn, *args, trio_token=None):
425427
def callback(q, fn, args):
426428
@disable_ki_protection
427429
def unprotected_fn():
428-
return fn(*args)
430+
ret = fn(*args)
431+
432+
if inspect.iscoroutine(ret):
433+
# Manually close coroutine to avoid RuntimeWarnings
434+
ret.close()
435+
raise TypeError(
436+
"Trio expected a sync function, but {!r} appears to be "
437+
"asynchronous".format(getattr(fn, "__qualname__", fn))
438+
)
439+
440+
return ret
429441

430442
res = outcome.capture(unprotected_fn)
431443
q.put_nowait(res)

0 commit comments

Comments
 (0)