Implementation of automatic batching for async #554

Merged

1 change: 1 addition & 0 deletions README.md
@@ -40,6 +40,7 @@ The complete documentation for GQL can be found at
* Supports [sync or async usage](https://gql.readthedocs.io/en/latest/async/index.html), [allowing concurrent requests](https://gql.readthedocs.io/en/latest/advanced/async_advanced_usage.html#async-advanced-usage)
* Supports [File uploads](https://gql.readthedocs.io/en/latest/usage/file_upload.html)
* Supports [Custom scalars / Enums](https://gql.readthedocs.io/en/latest/usage/custom_scalars_and_enums.html)
* Supports [Batching requests](https://gql.readthedocs.io/en/latest/advanced/batching_requests.html)
* [gql-cli script](https://gql.readthedocs.io/en/latest/gql-cli/intro.html) to execute GraphQL queries or download schemas from the command line
* [DSL module](https://gql.readthedocs.io/en/latest/advanced/dsl_module.html) to compose GraphQL queries dynamically

96 changes: 96 additions & 0 deletions docs/advanced/batching_requests.rst
@@ -0,0 +1,96 @@
.. _batching_requests:

Batching requests
=================

If you need to send multiple GraphQL queries to a backend,
and if the backend supports batch requests,
then you might want to send those requests in a batch instead of
making multiple execution requests.

.. warning::
- Some backends do not support batch requests
- File uploads and subscriptions are not supported with batch requests

Batching requests manually
^^^^^^^^^^^^^^^^^^^^^^^^^^

To execute a batch of requests manually:

- First, make a list of :class:`GraphQLRequest <gql.GraphQLRequest>` objects, each containing:
    * your GraphQL query
    * optional :code:`variable_values`
    * optional :code:`operation_name`

.. code-block:: python

    from gql import GraphQLRequest

    request1 = GraphQLRequest("""
query getContinents {
continents {
code
name
}
}
"""
)

request2 = GraphQLRequest("""
query getContinentName ($code: ID!) {
continent (code: $code) {
name
}
}
""",
variable_values={
"code": "AF",
},
)

requests = [request1, request2]

- Then use one of the :code:`execute_batch` methods, either on the :class:`Client <gql.Client>`,
  or on a sync or async session:

**Sync**:

.. code-block:: python

    from gql import Client
    from gql.transport.requests import RequestsHTTPTransport

    transport = RequestsHTTPTransport(url=url)
    # Or: from gql.transport.httpx import HTTPXTransport
    #     transport = HTTPXTransport(url=url)

with Client(transport=transport) as session:

results = session.execute_batch(requests)

result1 = results[0]
result2 = results[1]

**Async**:

.. code-block:: python

    from gql import Client
    from gql.transport.aiohttp import AIOHTTPTransport

    transport = AIOHTTPTransport(url=url)
    # Or: from gql.transport.httpx import HTTPXAsyncTransport
    #     transport = HTTPXAsyncTransport(url=url)

async with Client(transport=transport) as session:

results = await session.execute_batch(requests)

result1 = results[0]
result2 = results[1]

.. note::
If any request in the batch returns an error, then a TransportQueryError will be raised
with the first error found.
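
For example, a minimal sketch of handling this error in the sync case
(:code:`TransportQueryError` is importable from :code:`gql.transport.exceptions`):

.. code-block:: python

    from gql.transport.exceptions import TransportQueryError

    try:
        results = session.execute_batch(requests)
    except TransportQueryError as e:
        # e.errors contains the error entries returned by the backend
        print(f"Batch request failed: {e}")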

Automatic Batching of requests
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

If your code executes multiple requests independently within a short time
(either from different threads in sync code, or from different asyncio tasks in async code),
then you can use gql's automatic request batching functionality.

You define a :code:`batch_interval` in your :class:`Client <gql.Client>`,
and each time a new execution request is received through an :code:`execute` method,
gql will wait that interval (in seconds) for other requests to arrive
before sending all the requests received in that interval in a single batch.
You can also set :code:`batch_max` to limit the number of requests sent in a single batch.
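
For example, a minimal sketch of enabling automatic batching on an async client,
reusing :code:`request1` and :code:`request2` from above (this assumes
:code:`execute` accepts :class:`GraphQLRequest <gql.GraphQLRequest>` objects;
:code:`batch_interval` and :code:`batch_max` are the :class:`Client <gql.Client>`
options used by the batching loop):

.. code-block:: python

    import asyncio

    transport = AIOHTTPTransport(url=url)

    client = Client(
        transport=transport,
        batch_interval=0.002,  # wait up to 2 ms for other requests to arrive
        batch_max=10,          # send at most 10 requests in a single batch
    )

    async with client as session:

        # Requests made concurrently from separate asyncio tasks
        # are grouped together and sent in a single batch
        results = await asyncio.gather(
            session.execute(request1),
            session.execute(request2),
        )
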
1 change: 1 addition & 0 deletions docs/advanced/index.rst
@@ -6,6 +6,7 @@ Advanced

async_advanced_usage
async_permanent_session
batching_requests
logging
error_handling
local_schema
177 changes: 160 additions & 17 deletions gql/client.py
@@ -829,15 +829,11 @@ async def connect_async(self, reconnecting=False, **kwargs):

if reconnecting:
self.session = ReconnectingAsyncClientSession(client=self, **kwargs)
await self.session.start_connecting_task()
else:
try:
await self.transport.connect()
except Exception as e:
await self.transport.close()
raise e
self.session = AsyncClientSession(client=self)

await self.session.connect()

# Get schema from transport if needed
try:
if self.fetch_schema_from_transport and not self.schema:
@@ -846,18 +842,15 @@ async def connect_async(self, reconnecting=False, **kwargs):
# we don't know what type of exception is thrown here because it
# depends on the underlying transport; we just make sure that the
# transport is closed and re-raise the exception
await self.transport.close()
await self.session.close()
raise

return self.session

async def close_async(self):
"""Close the async transport and stop the optional reconnecting task."""

if isinstance(self.session, ReconnectingAsyncClientSession):
await self.session.stop_connecting_task()

await self.transport.close()
await self.session.close()

async def __aenter__(self):
return await self.connect_async()
@@ -1564,12 +1557,17 @@ async def _execute(
):
request = request.serialize_variable_values(self.client.schema)

# Execute the query with the transport with a timeout
with fail_after(self.client.execute_timeout):
result = await self.transport.execute(
request,
**kwargs,
)
# Check if batching is enabled
if self.client.batching_enabled:
future_result = await self._execute_future(request)
result = await future_result
else:
# Execute the query with the transport with a timeout
with fail_after(self.client.execute_timeout):
result = await self.transport.execute(
request,
**kwargs,
)

# Unserialize the result if requested
if self.client.schema:
@@ -1828,6 +1826,134 @@ async def execute_batch(

return cast(List[Dict[str, Any]], [result.data for result in results])

async def _batch_loop(self) -> None:
"""Main loop of the task used to wait for requests
to execute them in a batch"""

stop_loop = False

while not stop_loop:
            # First, wait for a request from the batch queue
requests_and_futures: List[Tuple[GraphQLRequest, asyncio.Future]] = []

# Wait for the first request
request_and_future: Optional[Tuple[GraphQLRequest, asyncio.Future]] = (
await self.batch_queue.get()
)

if request_and_future is None:
# None is our sentinel value to stop the loop
break

requests_and_futures.append(request_and_future)

            # Then wait for the batch interval, unless we already
            # have the maximum number of requests in the queue
if self.batch_queue.qsize() < self.client.batch_max - 1:
# Wait for the batch interval
await asyncio.sleep(self.client.batch_interval)

# Then get the requests which had been made during that wait interval
for _ in range(self.client.batch_max - 1):
try:
# Use get_nowait since we don't want to wait here
request_and_future = self.batch_queue.get_nowait()

if request_and_future is None:
# Sentinel value - stop after processing current batch
stop_loop = True
break

requests_and_futures.append(request_and_future)

except asyncio.QueueEmpty:
# No more requests in queue, that's fine
break

# Extract requests and futures
requests = [request for request, _ in requests_and_futures]
futures = [future for _, future in requests_and_futures]

# Execute the batch
try:
results: List[ExecutionResult] = await self._execute_batch(
requests,
serialize_variables=False, # already done
parse_result=False, # will be done later
validate_document=False, # already validated
)

# Set the result for each future
for result, future in zip(results, futures):
if not future.cancelled():
future.set_result(result)

except Exception as exc:
# If batch execution fails, propagate the error to all futures
for future in futures:
if not future.cancelled():
future.set_exception(exc)

# Signal that the task has stopped
self._batch_task_stopped_event.set()

async def _execute_future(
self,
request: GraphQLRequest,
) -> asyncio.Future:
"""If batching is enabled, this method will put a request in the batching queue
instead of executing it directly so that the requests could be put in a batch.
"""

assert hasattr(self, "batch_queue"), "Batching is not enabled"
assert not self._batch_task_stop_requested, "Batching task has been stopped"

future: asyncio.Future = asyncio.Future()
await self.batch_queue.put((request, future))

return future

async def _batch_init(self):
"""Initialize the batch task loop if batching is enabled."""
if self.client.batching_enabled:
self.batch_queue: asyncio.Queue = asyncio.Queue()
self._batch_task_stop_requested = False
self._batch_task_stopped_event = asyncio.Event()
self._batch_task = asyncio.create_task(self._batch_loop())

async def _batch_cleanup(self):
"""Cleanup the batching task if batching is enabled."""
if hasattr(self, "_batch_task_stopped_event"):
# Send a None in the queue to indicate that the batching task must stop
# after having processed the remaining requests in the queue
self._batch_task_stop_requested = True
await self.batch_queue.put(None)

# Wait for the task to process remaining requests and stop
await self._batch_task_stopped_event.wait()

async def connect(self):
"""Connect the transport and initialize the batch task loop if batching
is enabled."""

await self._batch_init()

try:
await self.transport.connect()
except Exception as e:
await self.transport.close()
raise e

async def close(self):
"""Close the transport and cleanup the batching task if batching is enabled.

Will wait until all the remaining requests in the batch processing queue
have been executed.
"""
await self._batch_cleanup()

await self.transport.close()

async def fetch_schema(self) -> None:
"""Fetch the GraphQL schema explicitly using introspection.

@@ -1954,6 +2080,23 @@ async def stop_connecting_task(self):
self._connect_task.cancel()
self._connect_task = None

async def connect(self):
"""Start the connect task and initialize the batch task loop if batching
is enabled."""

await self._batch_init()

await self.start_connecting_task()

async def close(self):
"""Stop the connect task and cleanup the batching task
if batching is enabled."""
await self._batch_cleanup()

await self.stop_connecting_task()

await self.transport.close()

async def _execute_once(
self,
request: GraphQLRequest,
39 changes: 27 additions & 12 deletions gql/graphql_request.py
@@ -1,26 +1,38 @@
from dataclasses import dataclass
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

from graphql import DocumentNode, GraphQLSchema, print_ast

from .gql import gql
from .utilities import serialize_variable_values


@dataclass(frozen=True)
class GraphQLRequest:
"""GraphQL Request to be executed."""

document: DocumentNode
"""GraphQL query as AST Node object."""
def __init__(
self,
document: Union[DocumentNode, str],
*,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
):
"""
Initialize a GraphQL request.

variable_values: Optional[Dict[str, Any]] = None
"""Dictionary of input parameters (Default: None)."""
Args:
document: GraphQL query as AST Node object or as a string.
If string, it will be converted to DocumentNode using gql().
variable_values: Dictionary of input parameters (Default: None).
operation_name: Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
"""
if isinstance(document, str):
self.document = gql(document)
else:
self.document = document

operation_name: Optional[str] = None
"""
Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
"""
self.variable_values = variable_values
self.operation_name = operation_name

def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
assert self.variable_values
@@ -48,3 +60,6 @@ def payload(self) -> Dict[str, Any]:
payload["variables"] = self.variable_values

return payload

def __str__(self):
return str(self.payload)
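
With this change, a :code:`GraphQLRequest` can be built directly from a query string
(converted internally with :code:`gql()`), and :code:`str(request)` returns the request
payload. A quick sketch of the new usage:

.. code-block:: python

    from gql import GraphQLRequest

    request = GraphQLRequest(
        """
        query getContinentName($code: ID!) {
          continent(code: $code) {
            name
          }
        }
        """,
        variable_values={"code": "AF"},
    )

    # __str__ delegates to the payload property
    print(request)  # {'query': '...', 'variables': {'code': 'AF'}}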