Skip to content

Commit 2821841

Browse files
committed
Cancel remaining fields on exceptions
gather() returns when the first exception is raised, but does not cancel any remaining tasks. These continue to run which is inefficient, and can also cause problems if they access shared resources like database connections. Fixes: #236
1 parent 0107e30 commit 2821841

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

src/graphql/execution/execute.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22

33
from __future__ import annotations
44

5-
from asyncio import ensure_future, gather, shield, wait_for
5+
from asyncio import (
6+
CancelledError,
7+
create_task,
8+
ensure_future,
9+
gather,
10+
shield,
11+
wait_for,
12+
)
613
from contextlib import suppress
714
from copy import copy
815
from typing import (
@@ -459,12 +466,23 @@ async def get_results() -> dict[str, Any]:
459466
field = awaitable_fields[0]
460467
results[field] = await results[field]
461468
else:
462-
results.update(
463-
zip(
464-
awaitable_fields,
465-
await gather(*(results[field] for field in awaitable_fields)),
466-
)
467-
)
469+
tasks = {
470+
create_task(results[field]): field # type: ignore[arg-type]
471+
for field in awaitable_fields
472+
}
473+
474+
try:
475+
awaited_results = await gather(*tasks)
476+
except Exception:
477+
# Cancel unfinished tasks before raising the exception
478+
for task in tasks:
479+
if not task.done():
480+
task.cancel()
481+
await gather(*tasks, return_exceptions=True)
482+
raise
483+
484+
results.update(zip(awaitable_fields, awaited_results))
485+
468486
return results
469487

470488
return get_results()
@@ -538,6 +556,10 @@ async def await_completed() -> Any:
538556
try:
539557
return await completed
540558
except Exception as raw_error:
559+
# Before Python 3.8 CancelledError inherits Exception and
560+
# so gets caught here.
561+
if isinstance(raw_error, CancelledError):
562+
raise # pragma: no cover
541563
self.handle_field_error(
542564
raw_error,
543565
return_type,
@@ -745,6 +767,10 @@ async def complete_awaitable_value(
745767
if self.is_awaitable(completed):
746768
completed = await completed
747769
except Exception as raw_error:
770+
# Before Python 3.8 CancelledError inherits Exception and
771+
# so gets caught here.
772+
if isinstance(raw_error, CancelledError):
773+
raise # pragma: no cover
748774
self.handle_field_error(
749775
raw_error, return_type, field_group, path, incremental_data_record
750776
)

tests/execution/test_parallel.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
GraphQLInt,
1212
GraphQLInterfaceType,
1313
GraphQLList,
14+
GraphQLNonNull,
1415
GraphQLObjectType,
1516
GraphQLSchema,
1617
GraphQLString,
@@ -193,3 +194,42 @@ async def is_type_of_baz(obj, *_args):
193194
{"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]},
194195
None,
195196
)
197+
198+
@pytest.mark.asyncio
199+
async def cancel_on_exception():
200+
barrier = Barrier(2)
201+
completed = False
202+
203+
async def succeed(*_args):
204+
nonlocal completed
205+
await barrier.wait()
206+
completed = True # pragma: no cover
207+
208+
async def fail(*_args):
209+
raise RuntimeError("Oops")
210+
211+
schema = GraphQLSchema(
212+
GraphQLObjectType(
213+
"Query",
214+
{
215+
"foo": GraphQLField(GraphQLNonNull(GraphQLBoolean), resolve=fail),
216+
"bar": GraphQLField(GraphQLBoolean, resolve=succeed),
217+
},
218+
)
219+
)
220+
221+
ast = parse("{foo, bar}")
222+
223+
awaitable_result = execute(schema, ast)
224+
assert isinstance(awaitable_result, Awaitable)
225+
result = await asyncio.wait_for(awaitable_result, 1.0)
226+
227+
assert result == (
228+
None,
229+
[{"message": "Oops", "locations": [(1, 2)], "path": ["foo"]}],
230+
)
231+
232+
# Unblock succeed() and check that it does not complete
233+
await barrier.wait()
234+
await asyncio.sleep(0)
235+
assert not completed

0 commit comments

Comments
 (0)