Skip to content

Commit 57c083d

Browse files
authored
Cancel remaining fields on exceptions (#238)
1 parent 0107e30 commit 57c083d

File tree

2 files changed

+73
-7
lines changed

2 files changed

+73
-7
lines changed

src/graphql/execution/execute.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,14 @@
22

33
from __future__ import annotations
44

5-
from asyncio import ensure_future, gather, shield, wait_for
5+
from asyncio import (
6+
CancelledError,
7+
create_task,
8+
ensure_future,
9+
gather,
10+
shield,
11+
wait_for,
12+
)
613
from contextlib import suppress
714
from copy import copy
815
from typing import (
@@ -459,12 +466,23 @@ async def get_results() -> dict[str, Any]:
459466
field = awaitable_fields[0]
460467
results[field] = await results[field]
461468
else:
462-
results.update(
463-
zip(
464-
awaitable_fields,
465-
await gather(*(results[field] for field in awaitable_fields)),
466-
)
467-
)
469+
tasks = {
470+
create_task(results[field]): field # type: ignore[arg-type]
471+
for field in awaitable_fields
472+
}
473+
474+
try:
475+
awaited_results = await gather(*tasks)
476+
except Exception:
477+
# Cancel unfinished tasks before raising the exception
478+
for task in tasks:
479+
if not task.done():
480+
task.cancel()
481+
await gather(*tasks, return_exceptions=True)
482+
raise
483+
484+
results.update(zip(awaitable_fields, awaited_results))
485+
468486
return results
469487

470488
return get_results()
@@ -538,6 +556,10 @@ async def await_completed() -> Any:
538556
try:
539557
return await completed
540558
except Exception as raw_error:
559+
# Before Python 3.8 CancelledError inherits Exception and
560+
# so gets caught here.
561+
if isinstance(raw_error, CancelledError):
562+
raise # pragma: no cover
541563
self.handle_field_error(
542564
raw_error,
543565
return_type,
@@ -745,6 +767,10 @@ async def complete_awaitable_value(
745767
if self.is_awaitable(completed):
746768
completed = await completed
747769
except Exception as raw_error:
770+
# Before Python 3.8 CancelledError inherits Exception and
771+
# so gets caught here.
772+
if isinstance(raw_error, CancelledError):
773+
raise # pragma: no cover
748774
self.handle_field_error(
749775
raw_error, return_type, field_group, path, incremental_data_record
750776
)

tests/execution/test_parallel.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
GraphQLInt,
1212
GraphQLInterfaceType,
1313
GraphQLList,
14+
GraphQLNonNull,
1415
GraphQLObjectType,
1516
GraphQLSchema,
1617
GraphQLString,
@@ -193,3 +194,42 @@ async def is_type_of_baz(obj, *_args):
193194
{"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]},
194195
None,
195196
)
197+
198+
@pytest.mark.asyncio
199+
async def cancel_on_exception():
200+
barrier = Barrier(2)
201+
completed = False
202+
203+
async def succeed(*_args):
204+
nonlocal completed
205+
await barrier.wait()
206+
completed = True # pragma: no cover
207+
208+
async def fail(*_args):
209+
raise RuntimeError("Oops")
210+
211+
schema = GraphQLSchema(
212+
GraphQLObjectType(
213+
"Query",
214+
{
215+
"foo": GraphQLField(GraphQLNonNull(GraphQLBoolean), resolve=fail),
216+
"bar": GraphQLField(GraphQLBoolean, resolve=succeed),
217+
},
218+
)
219+
)
220+
221+
ast = parse("{foo, bar}")
222+
223+
awaitable_result = execute(schema, ast)
224+
assert isinstance(awaitable_result, Awaitable)
225+
result = await asyncio.wait_for(awaitable_result, 1.0)
226+
227+
assert result == (
228+
None,
229+
[{"message": "Oops", "locations": [(1, 2)], "path": ["foo"]}],
230+
)
231+
232+
# Unblock succeed() and check that it does not complete
233+
await barrier.wait()
234+
await asyncio.sleep(0)
235+
assert not completed

0 commit comments

Comments
 (0)