From 35108796bc8891a7dde1b08ac8948f83fec1c800 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 25 May 2025 22:00:17 +0200 Subject: [PATCH 1/2] Using GraphQLRequest as main argument for transports --- gql/client.py | 224 +++++++++----------- gql/transport/aiohttp.py | 38 +--- gql/transport/appsync_websockets.py | 19 +- gql/transport/async_transport.py | 26 ++- gql/transport/common/base.py | 23 +- gql/transport/httpx.py | 58 ++--- gql/transport/local_schema.py | 38 +++- gql/transport/phoenix_channel_websockets.py | 11 +- gql/transport/requests.py | 39 +--- gql/transport/transport.py | 25 ++- gql/transport/websockets_protocol.py | 13 +- tests/starwars/test_subscription.py | 6 +- tests/test_aiohttp.py | 8 +- tests/test_client.py | 4 +- tests/test_httpx.py | 44 ++-- tests/test_httpx_async.py | 8 +- tests/test_requests.py | 56 ++--- 17 files changed, 295 insertions(+), 345 deletions(-) diff --git a/gql/client.py b/gql/client.py index a4e80dcb..4e269a2a 100644 --- a/gql/client.py +++ b/gql/client.py @@ -40,7 +40,6 @@ from .transport.transport import Transport from .utilities import build_client_schema, get_introspection_query_ast from .utilities import parse_result as parse_result_fn -from .utilities import serialize_variable_values from .utils import str_first_element log = logging.getLogger(__name__) @@ -68,6 +67,7 @@ class Client: def __init__( self, + *, schema: Optional[Union[str, GraphQLSchema]] = None, introspection: Optional[IntrospectionQuery] = None, transport: Optional[Union[Transport, AsyncTransport]] = None, @@ -206,11 +206,11 @@ def _get_event_loop() -> asyncio.AbstractEventLoop: def execute_sync( self, document: DocumentNode, + *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... # pragma: no cover @@ -219,11 +219,11 @@ def execute_sync( def execute_sync( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... # pragma: no cover @@ -232,11 +232,11 @@ def execute_sync( def execute_sync( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover @@ -244,6 +244,7 @@ def execute_sync( def execute_sync( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -319,11 +320,11 @@ def execute_batch_sync( async def execute_async( self, document: DocumentNode, + *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... # pragma: no cover @@ -332,11 +333,11 @@ async def execute_async( async def execute_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... # pragma: no cover @@ -345,11 +346,11 @@ async def execute_async( async def execute_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover @@ -357,6 +358,7 @@ async def execute_async( async def execute_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -432,11 +434,11 @@ async def execute_batch_async( def execute( self, document: DocumentNode, + *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, # https://github.com/python/mypy/issues/7333#issuecomment-788255229 get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... # pragma: no cover @@ -445,11 +447,11 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... # pragma: no cover @@ -458,11 +460,11 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover @@ -470,6 +472,7 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -629,11 +632,11 @@ def execute_batch( def subscribe_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @@ -642,11 +645,11 @@ def subscribe_async( def subscribe_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @@ -655,11 +658,11 @@ def subscribe_async( def subscribe_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[ @@ -669,6 +672,7 @@ def subscribe_async( async def subscribe_async( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -697,11 +701,11 @@ async def subscribe_async( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Generator[Dict[str, Any], None, None]: ... # pragma: no cover @@ -710,11 +714,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> Generator[ExecutionResult, None, None]: ... # pragma: no cover @@ -723,11 +727,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[ @@ -737,11 +741,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, - *, get_execution_result: bool = False, **kwargs: Any, ) -> Union[ @@ -925,19 +929,17 @@ def __init__(self, client: Client): def _execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, ) -> ExecutionResult: - """Execute the provided document AST synchronously using + """Execute the provided request synchronously using the sync transport, returning an ExecutionResult object. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -948,34 +950,22 @@ def _execute( # Validate document if self.client.schema: - self.client.validate(document) + self.client.validate(request.document) # Parse variable values for custom scalars if requested - if variable_values is not None: + if request.variable_values is not None: if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + request = request.serialize_variable_values(self.client.schema) if self.client.batching_enabled: - request = GraphQLRequest( - document, - variable_values=variable_values, - operation_name=operation_name, - ) future_result = self._execute_future(request) result = future_result.result() else: result = self.transport.execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, **kwargs, ) @@ -984,9 +974,9 @@ def _execute( if parse_result or (parse_result is None and self.client.parse_results): result.data = parse_result_fn( self.client.schema, - document, + request.document, result.data, - operation_name=operation_name, + operation_name=request.operation_name, ) return result @@ -995,11 +985,11 @@ def _execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... # pragma: no cover @@ -1008,11 +998,11 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... # pragma: no cover @@ -1021,11 +1011,11 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover @@ -1033,6 +1023,7 @@ def execute( def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -1059,11 +1050,16 @@ def execute( The extra arguments are passed to the transport execute method.""" - # Validate and execute on the transport - result = self._execute( - document, + # Make GraphQLRequest object + request = GraphQLRequest( + document=document, variable_values=variable_values, operation_name=operation_name, + ) + + # Validate and execute on the transport + result = self._execute( + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1337,7 +1333,9 @@ def fetch_schema(self) -> None: introspection_query = get_introspection_query_ast( **self.client.introspection_args ) - execution_result = self.transport.execute(introspection_query) + execution_result = self.transport.execute( + GraphQLRequest(document=introspection_query) + ) self.client._build_schema_from_introspection(execution_result) @@ -1360,23 +1358,21 @@ def __init__(self, client: Client): async def _subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: - """Coroutine to subscribe asynchronously to the provided document AST + """Coroutine to subscribe asynchronously to the provided request asynchronously using the async transport, returning an async generator producing ExecutionResult objects. * Validate the query with the schema if provided. * Serialize the variable_values if requested. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1387,26 +1383,19 @@ async def _subscribe( # Validate document if self.client.schema: - self.client.validate(document) + self.client.validate(request.document) # Parse variable values for custom scalars if requested - if variable_values is not None: + if request.variable_values is not None: if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + request = request.serialize_variable_values(self.client.schema) # Subscribe to the transport inner_generator: AsyncGenerator[ExecutionResult, None] = ( self.transport.subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, + request, **kwargs, ) ) @@ -1423,9 +1412,9 @@ async def _subscribe( ): result.data = parse_result_fn( self.client.schema, - document, + request.document, result.data, - operation_name=operation_name, + operation_name=request.operation_name, ) yield result @@ -1437,11 +1426,11 @@ async def _subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> AsyncGenerator[Dict[str, Any], None]: ... # pragma: no cover @@ -1450,11 +1439,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: ... # pragma: no cover @@ -1463,11 +1452,11 @@ def subscribe( def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[ @@ -1477,6 +1466,7 @@ def subscribe( async def subscribe( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -1505,10 +1495,15 @@ async def subscribe( The extra arguments are passed to the transport subscribe method.""" - inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( - document, + # Make GraphQLRequest object + request = GraphQLRequest( + document=document, variable_values=variable_values, operation_name=operation_name, + ) + + inner_generator: AsyncGenerator[ExecutionResult, None] = self._subscribe( + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1536,22 +1531,20 @@ async def subscribe( async def _execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, ) -> ExecutionResult: - """Coroutine to execute the provided document AST asynchronously using + """Coroutine to execute the provided request asynchronously using the async transport, returning an ExecutionResult object. * Validate the query with the schema if provided. * Serialize the variable_values if requested. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters. - :param operation_name: Name of the operation that shall be executed. + :param request: graphql request as a + :class:`graphqlrequest ` object. :param serialize_variables: whether the variable values should be serialized. Used for custom scalars and/or enums. By default use the serialize_variables argument of the client. @@ -1562,26 +1555,19 @@ async def _execute( # Validate document if self.client.schema: - self.client.validate(document) + self.client.validate(request.document) # Parse variable values for custom scalars if requested - if variable_values is not None: + if request.variable_values is not None: if serialize_variables or ( serialize_variables is None and self.client.serialize_variables ): - variable_values = serialize_variable_values( - self.client.schema, - document, - variable_values, - operation_name=operation_name, - ) + request = request.serialize_variable_values(self.client.schema) # Execute the query with the transport with a timeout with fail_after(self.client.execute_timeout): result = await self.transport.execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, **kwargs, ) @@ -1590,9 +1576,9 @@ async def _execute( if parse_result or (parse_result is None and self.client.parse_results): result.data = parse_result_fn( self.client.schema, - document, + request.document, result.data, - operation_name=operation_name, + operation_name=request.operation_name, ) return result @@ -1601,11 +1587,11 @@ async def _execute( async def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[False] = ..., **kwargs: Any, ) -> Dict[str, Any]: ... # pragma: no cover @@ -1614,11 +1600,11 @@ async def execute( async def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: Literal[True], **kwargs: Any, ) -> ExecutionResult: ... # pragma: no cover @@ -1627,11 +1613,11 @@ async def execute( async def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = ..., operation_name: Optional[str] = ..., serialize_variables: Optional[bool] = ..., parse_result: Optional[bool] = ..., - *, get_execution_result: bool, **kwargs: Any, ) -> Union[Dict[str, Any], ExecutionResult]: ... # pragma: no cover @@ -1639,6 +1625,7 @@ async def execute( async def execute( self, document: DocumentNode, + *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, serialize_variables: Optional[bool] = None, @@ -1665,11 +1652,16 @@ async def execute( The extra arguments are passed to the transport execute method.""" - # Validate and execute on the transport - result = await self._execute( - document, + # Make GraphQLRequest object + request = GraphQLRequest( + document=document, variable_values=variable_values, operation_name=operation_name, + ) + + # Validate and execute on the transport + result = await self._execute( + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1844,7 +1836,9 @@ async def fetch_schema(self) -> None: introspection_query = get_introspection_query_ast( **self.client.introspection_args ) - execution_result = await self.transport.execute(introspection_query) + execution_result = await self.transport.execute( + GraphQLRequest(introspection_query) + ) self.client._build_schema_from_introspection(execution_result) @@ -1869,6 +1863,7 @@ class ReconnectingAsyncClientSession(AsyncClientSession): def __init__( self, client: Client, + *, retry_connect: Union[bool, _Decorator] = True, retry_execute: Union[bool, _Decorator] = True, ): @@ -1961,9 +1956,8 @@ async def stop_connecting_task(self): async def _execute_once( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, @@ -1974,9 +1968,7 @@ async def _execute_once( try: answer = await super()._execute( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -1989,9 +1981,8 @@ async def _execute_once( async def _execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, @@ -2002,9 +1993,7 @@ async def _execute( """ return await self._execute_with_retries( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, @@ -2012,9 +2001,8 @@ async def _execute( async def _subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, serialize_variables: Optional[bool] = None, parse_result: Optional[bool] = None, **kwargs: Any, @@ -2024,9 +2012,7 @@ async def _subscribe( """ inner_generator: AsyncGenerator[ExecutionResult, None] = super()._subscribe( - document, - variable_values=variable_values, - operation_name=operation_name, + request, serialize_variables=serialize_variables, parse_result=parse_result, **kwargs, diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 9535eef4..89c410d6 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -20,7 +20,7 @@ from aiohttp.client_reqrep import Fingerprint from aiohttp.helpers import BasicAuth from aiohttp.typedefs import LooseCookies, LooseHeaders -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult from multidict import CIMultiDictProxy from ..graphql_request import GraphQLRequest @@ -164,18 +164,6 @@ async def close(self) -> None: self.session = None - def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: - query_str = print_ast(req.document) - payload: Dict[str, Any] = {"query": query_str} - - if req.operation_name: - payload["operationName"] = req.operation_name - - if req.variable_values: - payload["variables"] = req.variable_values - - return payload - def _prepare_batch_request( self, reqs: List[GraphQLRequest], @@ -359,13 +347,12 @@ def _raise_invalid_result(self, result_text: str, reason: str) -> None: async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server + """Execute the provided request against the configured remote server using the current session. This uses the aiohttp library to perform a HTTP POST request asynchronously to the remote server. @@ -373,22 +360,15 @@ async def execute( Don't call this coroutine directly on the transport, instead use :code:`execute` on a client or a session. - :param document: the parsed GraphQL request - :param variable_values: An optional Dict of variable values - :param operation_name: An optional Operation name for the request + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param extra_args: additional arguments to send to the aiohttp post method :param upload_files: Set to True if you want to put files in the variable values :returns: an ExecutionResult object. """ - req = GraphQLRequest( - document=document, - variable_values=variable_values, - operation_name=operation_name, - ) - post_args = self._prepare_request( - req, + request, extra_args, upload_files, ) @@ -434,9 +414,7 @@ async def execute_batch( def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> AsyncGenerator[ExecutionResult, None]: """Subscribe is not supported on HTTP. diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index a6a7d180..bb9fdb6a 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -4,8 +4,9 @@ from typing import Any, Dict, Optional, Tuple, Union, cast from urllib.parse import urlparse -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult +from ..graphql_request import GraphQLRequest from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication from .common.adapters.websockets import WebSocketsAdapter from .common.base import SubscriptionTransportBase @@ -150,22 +151,14 @@ def _parse_answer( async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: query_id = self.next_query_id self.next_query_id += 1 - data: Dict = {"query": print_ast(document)} - - if variable_values: - data["variables"] = variable_values - - if operation_name: - data["operationName"] = operation_name + data: Dict[str, Any] = self._build_payload(request) serialized_data = json.dumps(data, separators=(",", ":")) @@ -203,9 +196,7 @@ async def _send_query( async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> ExecutionResult: """This method is not available. diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 243746e6..f2ce6b24 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,7 +1,7 @@ import abc -from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List -from graphql import DocumentNode, ExecutionResult +from graphql import ExecutionResult, print_ast from ..graphql_request import GraphQLRequest @@ -24,11 +24,9 @@ async def close(self): @abc.abstractmethod async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> ExecutionResult: - """Execute the provided document AST for either a remote or local GraphQL + """Execute the provided request for either a remote or local GraphQL Schema.""" raise NotImplementedError( "Any AsyncTransport subclass must implement execute method" @@ -54,9 +52,7 @@ async def execute_batch( @abc.abstractmethod def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using an async generator @@ -67,3 +63,15 @@ def subscribe( raise NotImplementedError( "Any AsyncTransport subclass must implement subscribe method" ) # pragma: no cover + + def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: + query_str = print_ast(req.document) + payload: Dict[str, Any] = {"query": query_str} + + if req.operation_name: + payload["operationName"] = req.operation_name + + if req.variable_values: + payload["variables"] = req.variable_values + + return payload diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index a285ad2c..f2070fe1 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -5,8 +5,9 @@ from contextlib import suppress from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union -from graphql import DocumentNode, ExecutionResult +from graphql import ExecutionResult +from ...graphql_request import GraphQLRequest from ..async_transport import AsyncTransport from ..exceptions import ( TransportAlreadyConnected, @@ -158,9 +159,7 @@ async def _receive(self) -> str: @abstractmethod async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: raise NotImplementedError # pragma: no cover @@ -267,9 +266,8 @@ async def _handle_answer( async def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, send_stop: Optional[bool] = True, ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using a python async generator. @@ -281,7 +279,7 @@ async def subscribe( # Send the query and receive the id query_id: int = await self._send_query( - document, variable_values, operation_name + request, ) # Create a queue to receive the answers for this query_id @@ -325,11 +323,9 @@ async def subscribe( async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server + """Execute the provided request against the configured remote server using the current session. Send a query but close the async generator as soon as we have the first answer. @@ -339,7 +335,8 @@ async def execute( first_result = None generator = self.subscribe( - document, variable_values, operation_name, send_stop=False + request, + send_stop=False, ) async for result in generator: diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 406c0523..43e14a6f 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -1,3 +1,4 @@ +import abc import io import json import logging @@ -15,7 +16,7 @@ ) import httpx -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult from ..graphql_request import GraphQLRequest from . import AsyncTransport, Transport @@ -57,17 +58,10 @@ def __init__( self.json_deserialize = json_deserialize self.kwargs = kwargs + @abc.abstractmethod def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: - query_str = print_ast(req.document) - payload: Dict[str, Any] = {"query": query_str} - - if req.operation_name: - payload["operationName"] = req.operation_name - - if req.variable_values: - payload["variables"] = req.variable_values - - return payload + """This is Implemented in Transport and AsyncTransport""" + raise NotImplementedError() # pragma: no cover def _prepare_request( self, @@ -243,21 +237,18 @@ def connect(self): def execute( # type: ignore self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST against the configured remote server. This + Execute the provided request against the configured remote server. This uses the httpx library to perform a HTTP POST request to the remote server. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters (Default: None). - :param operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param extra_args: additional arguments to send to the httpx post method :param upload_files: Set to True if you want to put files in the variable values :return: The result of execution. @@ -267,12 +258,6 @@ def execute( # type: ignore if not self.client: raise TransportClosed("Transport is not connected") - request = GraphQLRequest( - document=document, - variable_values=variable_values, - operation_name=operation_name, - ) - post_args = self._prepare_request( request, extra_args, @@ -343,22 +328,19 @@ async def connect(self): async def execute( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, + *, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST against the configured remote server. This + Execute the provided request against the configured remote server. This uses the httpx library to perform a HTTP POST request asynchronously to the remote server. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters (Default: None). - :param operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param extra_args: additional arguments to send to the httpx post method :param upload_files: Set to True if you want to put files in the variable values :return: The result of execution. @@ -368,12 +350,6 @@ async def execute( if not self.client: raise TransportClosed("Transport is not connected") - request = GraphQLRequest( - document=document, - variable_values=variable_values, - operation_name=operation_name, - ) - post_args = self._prepare_request( request, extra_args, @@ -420,9 +396,7 @@ async def execute_batch( def subscribe( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> AsyncGenerator[ExecutionResult, None]: """Subscribe is not supported on HTTP. diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index 19760ad6..f87854e2 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -2,10 +2,12 @@ from inspect import isawaitable from typing import Any, AsyncGenerator, Awaitable, cast -from graphql import DocumentNode, ExecutionResult, GraphQLSchema, execute, subscribe +from graphql import ExecutionResult, GraphQLSchema, execute, subscribe from gql.transport import AsyncTransport +from ..graphql_request import GraphQLRequest + class LocalSchemaTransport(AsyncTransport): """A transport for executing GraphQL queries against a local schema.""" @@ -30,13 +32,24 @@ async def close(self): async def execute( self, - document: DocumentNode, + request: GraphQLRequest, *args: Any, **kwargs: Any, ) -> ExecutionResult: - """Execute the provided document AST for on a local GraphQL Schema.""" - - result_or_awaitable = execute(self.schema, document, *args, **kwargs) + """Execute the provided request for on a local GraphQL Schema.""" + + inner_kwargs = { + "variable_values": request.variable_values, + "operation_name": request.operation_name, + **kwargs, + } + + result_or_awaitable = execute( + self.schema, + request.document, + *args, + **inner_kwargs, + ) execution_result: ExecutionResult @@ -57,7 +70,7 @@ async def _await_if_necessary(obj): async def subscribe( self, - document: DocumentNode, + request: GraphQLRequest, *args: Any, **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: @@ -66,8 +79,19 @@ async def subscribe( The results are sent as an ExecutionResult object """ + inner_kwargs = { + "variable_values": request.variable_values, + "operation_name": request.operation_name, + **kwargs, + } + subscribe_result = await self._await_if_necessary( - subscribe(self.schema, document, *args, **kwargs) + subscribe( + self.schema, + request.document, + *args, + **inner_kwargs, + ) ) if isinstance(subscribe_result, ExecutionResult): diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 8a975b73..8e7455e2 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -3,8 +3,9 @@ import logging from typing import Any, Dict, Optional, Tuple, Union -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult, print_ast +from ..graphql_request import GraphQLRequest from .common.adapters.websockets import WebSocketsAdapter from .common.base import SubscriptionTransportBase from .exceptions import ( @@ -182,9 +183,7 @@ async def _connection_terminate(self): async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: """Send a query to the provided websocket connection. @@ -201,8 +200,8 @@ async def _send_query( "topic": self.channel_name, "event": "doc", "payload": { - "query": print_ast(document), - "variables": variable_values or {}, + "query": print_ast(request.document), + "variables": request.variable_values or {}, }, "ref": query_id, } diff --git a/gql/transport/requests.py b/gql/transport/requests.py index d84ba9d3..6480260e 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -15,7 +15,7 @@ ) import requests -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult from requests.adapters import HTTPAdapter, Retry from requests.auth import AuthBase from requests.cookies import RequestsCookieJar @@ -139,22 +139,18 @@ def connect(self): def execute( # type: ignore self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, timeout: Optional[int] = None, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST against the configured remote server. This + Execute the provided request against the configured remote server. This uses the requests library to perform a HTTP POST request to the remote server. - :param document: GraphQL query as AST Node object. - :param variable_values: Dictionary of input parameters (Default: None). - :param operation_name: Name of the operation that shall be executed. - Only required in multi-operation documents (Default: None). + :param request: GraphQL request as a + :class:`GraphQLRequest ` object. :param timeout: Specifies a default timeout for requests (Default: None). :param extra_args: additional arguments to send to the requests post method :param upload_files: Set to True if you want to put files in the variable values @@ -166,11 +162,7 @@ def execute( # type: ignore if not self.session: raise TransportClosed("Transport is not connected") - query_str = print_ast(document) - payload: Dict[str, Any] = {"query": query_str} - - if operation_name: - payload["operationName"] = operation_name + payload = self._build_payload(request) post_args: Dict[str, Any] = { "headers": self.headers, @@ -182,12 +174,12 @@ def execute( # type: ignore if upload_files: # If the upload_files flag is set, then we need variable_values - assert variable_values is not None + assert request.variable_values is not None # If we upload files, we will extract the files present in the # variable_values dict and replace them by null values nulled_variable_values, files = extract_files( - variables=variable_values, + variables=request.variable_values, file_classes=self.file_classes, ) @@ -241,9 +233,6 @@ def execute( # type: ignore post_args["headers"]["Content-Type"] = data.content_type else: - if variable_values: - payload["variables"] = variable_values - data_key = "json" if self.use_json else "data" post_args[data_key] = payload @@ -398,18 +387,6 @@ def _build_batch_post_args( return post_args - def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: - query_str = print_ast(req.document) - payload: Dict[str, Any] = {"query": query_str} - - if req.operation_name: - payload["operationName"] = req.operation_name - - if req.variable_values: - payload["variables"] = req.variable_values - - return payload - def close(self): """Closing the transport by closing the inner session""" if self.session: diff --git a/gql/transport/transport.py b/gql/transport/transport.py index 49d0aa34..2b06ebdb 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,7 +1,7 @@ import abc -from typing import Any, List +from typing import Any, Dict, List -from graphql import DocumentNode, ExecutionResult +from graphql import ExecutionResult, print_ast from ..graphql_request import GraphQLRequest @@ -9,13 +9,16 @@ class Transport(abc.ABC): @abc.abstractmethod def execute( - self, document: DocumentNode, *args: Any, **kwargs: Any + self, + request: GraphQLRequest, + *args: Any, + **kwargs: Any, ) -> ExecutionResult: """Execute GraphQL query. - Execute the provided document AST for either a remote or local GraphQL Schema. + Execute the provided request for either a remote or local GraphQL Schema. - :param document: GraphQL query as AST Node or Document object. + :param request: GraphQL request as a GraphQLRequest object. :return: ExecutionResult """ raise NotImplementedError( @@ -51,3 +54,15 @@ def close(self): the session's connection pool. """ pass # pragma: no cover + + def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: + query_str = print_ast(req.document) + payload: Dict[str, Any] = {"query": query_str} + + if req.operation_name: + payload["operationName"] = req.operation_name + + if req.variable_values: + payload["variables"] = req.variable_values + + return payload diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index 61a4bb85..f1946665 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -4,8 +4,9 @@ from contextlib import suppress from typing import Any, Dict, List, Optional, Tuple, Union -from graphql import DocumentNode, ExecutionResult, print_ast +from graphql import ExecutionResult +from ..graphql_request import GraphQLRequest from .common.adapters.connection import AdapterConnection from .common.base import SubscriptionTransportBase from .exceptions import ( @@ -224,9 +225,7 @@ async def _send_connection_terminate_message(self) -> None: async def _send_query( self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, + request: GraphQLRequest, ) -> int: """Send a query to the provided websocket connection. @@ -238,11 +237,7 @@ async def _send_query( query_id = self.next_query_id self.next_query_id += 1 - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name + payload: Dict[str, Any] = self._build_payload(request) query_type = "start" diff --git a/tests/starwars/test_subscription.py b/tests/starwars/test_subscription.py index 0f412acc..bbaafd5c 100644 --- a/tests/starwars/test_subscription.py +++ b/tests/starwars/test_subscription.py @@ -3,7 +3,7 @@ import pytest from graphql import ExecutionResult, GraphQLError, subscribe -from gql import Client, gql +from gql import Client, GraphQLRequest, gql from .fixtures import reviews from .schema import StarWarsSchema @@ -93,7 +93,9 @@ async def test_subscription_support_using_client_invalid_field(): results = [ result async for result in await await_if_coroutine( - session.transport.subscribe(subs, variable_values=params) + session.transport.subscribe( + GraphQLRequest(subs, variable_values=params) + ) ) ] diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 0642e536..24f82c9d 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -6,7 +6,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -421,7 +421,7 @@ async def handler(request): query = gql(query1_str) with pytest.raises(TransportClosed): - await transport.execute(query) + await transport.execute(GraphQLRequest(query)) @pytest.mark.asyncio @@ -533,7 +533,9 @@ async def handler(request): query = gql(query2_str) # Execute query asynchronously - result = await session.execute(query, params, operation_name="getEurope") + result = await session.execute( + query, variable_values=params, operation_name="getEurope" + ) continent = result["continent"] diff --git a/tests/test_client.py b/tests/test_client.py index 55993a9e..3412059e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,7 +4,7 @@ from unittest import mock import pytest -from graphql import DocumentNode, ExecutionResult, build_ast_schema, parse +from graphql import ExecutionResult, build_ast_schema, parse from gql import Client, GraphQLRequest, gql from gql.transport import Transport @@ -40,7 +40,7 @@ class RandomTransport(Transport): class RandomTransport2(Transport): def execute( self, - document: DocumentNode, + request: GraphQLRequest, *args: Any, **kwargs: Any, ) -> ExecutionResult: diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 0991355a..b944391f 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -3,7 +3,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -470,7 +470,7 @@ def test_code(): query = gql(query1_str) with pytest.raises(TransportClosed): - transport.execute(query) + transport.execute(GraphQLRequest(query)) await run_sync_test(server, test_code) @@ -578,32 +578,32 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an opened file inside a FileVar object with open(file_path, "rb") as f: params = {"file": FileVar(f), "other_var": 42} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an filename string inside a FileVar object params = { "file": FileVar(file_path), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -650,22 +650,22 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using FileVar params = { "file": FileVar(file_path, content_type="application/pdf"), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -710,11 +710,11 @@ def test_code(): "file": FileVar(file_path), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -751,11 +751,11 @@ def test_code(): file_path = test_file.filename params = {"file": FileVar(file_path), "other_var": 42} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -797,11 +797,11 @@ def test_code(): params = {"file": FileVar(file_path), "other_var": 42} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -867,11 +867,11 @@ def test_code(): "file2": FileVar(file_path_2), } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -940,11 +940,11 @@ def test_code(): ], } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) diff --git a/tests/test_httpx_async.py b/tests/test_httpx_async.py index 87f1675a..56c65873 100644 --- a/tests/test_httpx_async.py +++ b/tests/test_httpx_async.py @@ -4,7 +4,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -429,7 +429,7 @@ async def handler(request): query = gql(query1_str) with pytest.raises(TransportClosed): - await transport.execute(query) + await transport.execute(GraphQLRequest(query)) @pytest.mark.aiohttp @@ -541,7 +541,9 @@ async def handler(request): query = gql(query2_str) # Execute query asynchronously - result = await session.execute(query, params, operation_name="getEurope") + result = await session.execute( + query, variable_values=params, operation_name="getEurope" + ) continent = result["continent"] diff --git a/tests/test_requests.py b/tests/test_requests.py index c184e230..ff6a5651 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -4,7 +4,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -471,7 +471,7 @@ def test_code(): query = gql(query1_str) with pytest.raises(TransportClosed): - transport.execute(query) + transport.execute(GraphQLRequest(query)) await run_sync_test(server, test_code) @@ -580,11 +580,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an opened file inside a FileVar object with open(file_path, "rb") as f: @@ -592,19 +592,19 @@ def test_code(): params = {"file": FileVar(f), "other_var": 42} with warnings.catch_warnings(): warnings.simplefilter("error") # Turn warnings into errors - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an filename string inside a FileVar object params = {"file": FileVar(file_path), "other_var": 42} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -651,11 +651,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] # Using an opened file inside a FileVar object with open(file_path, "rb") as f: @@ -664,11 +664,11 @@ def test_code(): "file": FileVar(f, content_type="application/pdf"), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -713,11 +713,11 @@ def test_code(): "file": FileVar(file_path), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -760,11 +760,11 @@ def test_code(): "file": FileVar(f, filename="filename1.txt"), "other_var": 42, } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -807,11 +807,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -859,11 +859,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] await run_sync_test(server, test_code) @@ -937,11 +937,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params_1, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() @@ -958,11 +958,11 @@ def test_code(): "file2": FileVar(f2), } - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params_2, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() @@ -1037,11 +1037,11 @@ def test_code(): DeprecationWarning, match="Not using FileVar for file upload is deprecated", ): - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() @@ -1055,11 +1055,11 @@ def test_code(): params_2 = {"files": [FileVar(f1), FileVar(f2)]} - execution_result = session._execute( + execution_result = session.execute( query, variable_values=params_2, upload_files=True ) - assert execution_result.data["success"] + assert execution_result["success"] f1.close() f2.close() From 5d234ede2701c05e936f48d512c3428c6633b581 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 25 May 2025 22:35:28 +0200 Subject: [PATCH 2/2] set payload as GraphQLRequest property --- gql/graphql_request.py | 15 ++++++++++++++- gql/transport/aiohttp.py | 12 ++++++------ gql/transport/appsync_websockets.py | 2 +- gql/transport/async_transport.py | 16 ++-------------- gql/transport/httpx.py | 10 ++-------- gql/transport/requests.py | 4 ++-- gql/transport/transport.py | 16 ++-------------- gql/transport/websockets_protocol.py | 2 +- 8 files changed, 30 insertions(+), 47 deletions(-) diff --git a/gql/graphql_request.py b/gql/graphql_request.py index b0c68f5c..7289a8f9 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Any, Dict, Optional -from graphql import DocumentNode, GraphQLSchema +from graphql import DocumentNode, GraphQLSchema, print_ast from .utilities import serialize_variable_values @@ -35,3 +35,16 @@ def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": ), operation_name=self.operation_name, ) + + @property + def payload(self) -> Dict[str, Any]: + query_str = print_ast(self.document) + payload: Dict[str, Any] = {"query": query_str} + + if self.operation_name: + payload["operationName"] = self.operation_name + + if self.variable_values: + payload["variables"] = self.variable_values + + return payload diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 89c410d6..0a677af3 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -170,7 +170,7 @@ def _prepare_batch_request( extra_args: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - payload = [self._build_payload(req) for req in reqs] + payload = [req.payload for req in reqs] post_args = {"json": payload} @@ -186,15 +186,15 @@ def _prepare_batch_request( def _prepare_request( self, - req: GraphQLRequest, + request: GraphQLRequest, extra_args: Optional[Dict[str, Any]] = None, upload_files: bool = False, ) -> Dict[str, Any]: - payload = self._build_payload(req) + payload = request.payload if upload_files: - post_args = self._prepare_file_uploads(req, payload) + post_args = self._prepare_file_uploads(request, payload) else: post_args = {"json": payload} @@ -216,11 +216,11 @@ def _prepare_request( return post_args def _prepare_file_uploads( - self, req: GraphQLRequest, payload: Dict[str, Any] + self, request: GraphQLRequest, payload: Dict[str, Any] ) -> Dict[str, Any]: # If the upload_files flag is set, then we need variable_values - variable_values = req.variable_values + variable_values = request.variable_values assert variable_values is not None # If we upload files, we will extract the files present in the diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index bb9fdb6a..e2ab4f96 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -158,7 +158,7 @@ async def _send_query( self.next_query_id += 1 - data: Dict[str, Any] = self._build_payload(request) + data: Dict[str, Any] = request.payload serialized_data = json.dumps(data, separators=(",", ":")) diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index f2ce6b24..526c97ba 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,7 +1,7 @@ import abc -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator, List -from graphql import ExecutionResult, print_ast +from graphql import ExecutionResult from ..graphql_request import GraphQLRequest @@ -63,15 +63,3 @@ def subscribe( raise NotImplementedError( "Any AsyncTransport subclass must implement subscribe method" ) # pragma: no cover - - def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: - query_str = print_ast(req.document) - payload: Dict[str, Any] = {"query": query_str} - - if req.operation_name: - payload["operationName"] = req.operation_name - - if req.variable_values: - payload["variables"] = req.variable_values - - return payload diff --git a/gql/transport/httpx.py b/gql/transport/httpx.py index 43e14a6f..f3416c24 100644 --- a/gql/transport/httpx.py +++ b/gql/transport/httpx.py @@ -1,4 +1,3 @@ -import abc import io import json import logging @@ -58,11 +57,6 @@ def __init__( self.json_deserialize = json_deserialize self.kwargs = kwargs - @abc.abstractmethod - def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: - """This is Implemented in Transport and AsyncTransport""" - raise NotImplementedError() # pragma: no cover - def _prepare_request( self, req: GraphQLRequest, @@ -70,7 +64,7 @@ def _prepare_request( upload_files: bool = False, ) -> Dict[str, Any]: - payload = self._build_payload(req) + payload = req.payload if upload_files: post_args = self._prepare_file_uploads(req, payload) @@ -93,7 +87,7 @@ def _prepare_batch_request( extra_args: Optional[Dict[str, Any]] = None, ) -> Dict[str, Any]: - payload = [self._build_payload(req) for req in reqs] + payload = [req.payload for req in reqs] post_args = {"json": payload} diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 6480260e..2087bbd0 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -162,7 +162,7 @@ def execute( # type: ignore if not self.session: raise TransportClosed("Transport is not connected") - payload = self._build_payload(request) + payload = request.payload post_args: Dict[str, Any] = { "headers": self.headers, @@ -372,7 +372,7 @@ def _build_batch_post_args( } data_key = "json" if self.use_json else "data" - post_args[data_key] = [self._build_payload(req) for req in reqs] + post_args[data_key] = [req.payload for req in reqs] # Log the payload if log.isEnabledFor(logging.INFO): diff --git a/gql/transport/transport.py b/gql/transport/transport.py index 2b06ebdb..7a72f9a6 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,7 +1,7 @@ import abc -from typing import Any, Dict, List +from typing import Any, List -from graphql import ExecutionResult, print_ast +from graphql import ExecutionResult from ..graphql_request import GraphQLRequest @@ -54,15 +54,3 @@ def close(self): the session's connection pool. """ pass # pragma: no cover - - def _build_payload(self, req: GraphQLRequest) -> Dict[str, Any]: - query_str = print_ast(req.document) - payload: Dict[str, Any] = {"query": query_str} - - if req.operation_name: - payload["operationName"] = req.operation_name - - if req.variable_values: - payload["variables"] = req.variable_values - - return payload diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py index f1946665..3b66a0cb 100644 --- a/gql/transport/websockets_protocol.py +++ b/gql/transport/websockets_protocol.py @@ -237,7 +237,7 @@ async def _send_query( query_id = self.next_query_id self.next_query_id += 1 - payload: Dict[str, Any] = self._build_payload(request) + payload: Dict[str, Any] = request.payload query_type = "start"