diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index 1097e80f..6bff7873 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -2,7 +2,14 @@ from __future__ import annotations -from asyncio import ensure_future, gather, shield, wait_for +from asyncio import ( + CancelledError, + create_task, + ensure_future, + gather, + shield, + wait_for, +) from contextlib import suppress from copy import copy from typing import ( @@ -459,12 +466,23 @@ async def get_results() -> dict[str, Any]: field = awaitable_fields[0] results[field] = await results[field] else: - results.update( - zip( - awaitable_fields, - await gather(*(results[field] for field in awaitable_fields)), - ) - ) + tasks = { + create_task(results[field]): field # type: ignore[arg-type] + for field in awaitable_fields + } + + try: + awaited_results = await gather(*tasks) + except Exception: + # Cancel unfinished tasks before raising the exception + for task in tasks: + if not task.done(): + task.cancel() + await gather(*tasks, return_exceptions=True) + raise + + results.update(zip(awaitable_fields, awaited_results)) + return results return get_results() @@ -538,6 +556,10 @@ async def await_completed() -> Any: try: return await completed except Exception as raw_error: + # Before Python 3.8 CancelledError inherits Exception and + # so gets caught here. + if isinstance(raw_error, CancelledError): + raise # pragma: no cover self.handle_field_error( raw_error, return_type, @@ -745,6 +767,10 @@ async def complete_awaitable_value( if self.is_awaitable(completed): completed = await completed except Exception as raw_error: + # Before Python 3.8 CancelledError inherits Exception and + # so gets caught here. + if isinstance(raw_error, CancelledError): + raise # pragma: no cover self.handle_field_error( raw_error, return_type, field_group, path, incremental_data_record ) diff --git a/tests/execution/test_parallel.py b/tests/execution/test_parallel.py index f4dc86b1..4279952c 100644 --- a/tests/execution/test_parallel.py +++ b/tests/execution/test_parallel.py @@ -11,6 +11,7 @@ GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLSchema, GraphQLString, @@ -193,3 +194,42 @@ async def is_type_of_baz(obj, *_args): {"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]}, None, ) + + @pytest.mark.asyncio + async def cancel_on_exception(): + barrier = Barrier(2) + completed = False + + async def succeed(*_args): + nonlocal completed + await barrier.wait() + completed = True # pragma: no cover + + async def fail(*_args): + raise RuntimeError("Oops") + + schema = GraphQLSchema( + GraphQLObjectType( + "Query", + { + "foo": GraphQLField(GraphQLNonNull(GraphQLBoolean), resolve=fail), + "bar": GraphQLField(GraphQLBoolean, resolve=succeed), + }, + ) + ) + + ast = parse("{foo, bar}") + + awaitable_result = execute(schema, ast) + assert isinstance(awaitable_result, Awaitable) + result = await asyncio.wait_for(awaitable_result, 1.0) + + assert result == ( + None, + [{"message": "Oops", "locations": [(1, 2)], "path": ["foo"]}], + ) + + # Unblock succeed() and check that it does not complete + await barrier.wait() + await asyncio.sleep(0) + assert not completed