From 2821841c56a90d2aae87d6d951223dd99db18bf7 Mon Sep 17 00:00:00 2001 From: Michael Gorven Date: Mon, 12 May 2025 15:48:10 -0700 Subject: [PATCH] Cancel remaining fields on exceptions gather() returns when the first exception is raised, but does not cancel any remaining tasks. These continue to run which is inefficient, and can also cause problems if they access shared resources like database connections. Fixes: #236 --- src/graphql/execution/execute.py | 40 ++++++++++++++++++++++++++------ tests/execution/test_parallel.py | 40 ++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/src/graphql/execution/execute.py b/src/graphql/execution/execute.py index 1097e80f..6bff7873 100644 --- a/src/graphql/execution/execute.py +++ b/src/graphql/execution/execute.py @@ -2,7 +2,14 @@ from __future__ import annotations -from asyncio import ensure_future, gather, shield, wait_for +from asyncio import ( + CancelledError, + create_task, + ensure_future, + gather, + shield, + wait_for, +) from contextlib import suppress from copy import copy from typing import ( @@ -459,12 +466,23 @@ async def get_results() -> dict[str, Any]: field = awaitable_fields[0] results[field] = await results[field] else: - results.update( - zip( - awaitable_fields, - await gather(*(results[field] for field in awaitable_fields)), - ) - ) + tasks = { + create_task(results[field]): field # type: ignore[arg-type] + for field in awaitable_fields + } + + try: + awaited_results = await gather(*tasks) + except Exception: + # Cancel unfinished tasks before raising the exception + for task in tasks: + if not task.done(): + task.cancel() + await gather(*tasks, return_exceptions=True) + raise + + results.update(zip(awaitable_fields, awaited_results)) + return results return get_results() @@ -538,6 +556,10 @@ async def await_completed() -> Any: try: return await completed except Exception as raw_error: + # Before Python 3.8 CancelledError inherits Exception and + # so gets caught here. + if isinstance(raw_error, CancelledError): + raise # pragma: no cover self.handle_field_error( raw_error, return_type, @@ -745,6 +767,10 @@ async def complete_awaitable_value( if self.is_awaitable(completed): completed = await completed except Exception as raw_error: + # Before Python 3.8 CancelledError inherits Exception and + # so gets caught here. + if isinstance(raw_error, CancelledError): + raise # pragma: no cover self.handle_field_error( raw_error, return_type, field_group, path, incremental_data_record ) diff --git a/tests/execution/test_parallel.py b/tests/execution/test_parallel.py index f4dc86b1..4279952c 100644 --- a/tests/execution/test_parallel.py +++ b/tests/execution/test_parallel.py @@ -11,6 +11,7 @@ GraphQLInt, GraphQLInterfaceType, GraphQLList, + GraphQLNonNull, GraphQLObjectType, GraphQLSchema, GraphQLString, @@ -193,3 +194,42 @@ async def is_type_of_baz(obj, *_args): {"foo": [{"foo": "bar", "foobar": 1}, {"foo": "baz", "foobaz": 2}]}, None, ) + + @pytest.mark.asyncio + async def cancel_on_exception(): + barrier = Barrier(2) + completed = False + + async def succeed(*_args): + nonlocal completed + await barrier.wait() + completed = True # pragma: no cover + + async def fail(*_args): + raise RuntimeError("Oops") + + schema = GraphQLSchema( + GraphQLObjectType( + "Query", + { + "foo": GraphQLField(GraphQLNonNull(GraphQLBoolean), resolve=fail), + "bar": GraphQLField(GraphQLBoolean, resolve=succeed), + }, + ) + ) + + ast = parse("{foo, bar}") + + awaitable_result = execute(schema, ast) + assert isinstance(awaitable_result, Awaitable) + result = await asyncio.wait_for(awaitable_result, 1.0) + + assert result == ( + None, + [{"message": "Oops", "locations": [(1, 2)], "path": ["foo"]}], + ) + + # Unblock succeed() and check that it does not complete + await barrier.wait() + await asyncio.sleep(0) + assert not completed