Skip to content

Fix running execute and subscribe of client in a new Thread #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions gql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,13 @@ def execute(self, document: DocumentNode, *args, **kwargs) -> Dict:

if isinstance(self.transport, AsyncTransport):

loop = asyncio.get_event_loop()
# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

assert not loop.is_running(), (
"Cannot run client.execute(query) if an asyncio loop is running."
Expand Down Expand Up @@ -146,9 +152,15 @@ def subscribe(
We need an async transport for this functionality.
"""

async_generator = self.subscribe_async(document, *args, **kwargs)
# Get the current asyncio event loop
# Or create a new event loop if there isn't one (in a new Thread)
try:
loop = asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

loop = asyncio.get_event_loop()
async_generator = self.subscribe_async(document, *args, **kwargs)

assert not loop.is_running(), (
"Cannot run client.subscribe(query) if an asyncio loop is running."
Expand Down
8 changes: 8 additions & 0 deletions gql/transport/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,14 @@ def __init__(
self.receive_data_task: Optional[asyncio.Future] = None
self.close_task: Optional[asyncio.Future] = None

# We need to set an event loop here if there is none
# Or else we will not be able to create an asyncio.Event()
try:
self._loop = asyncio.get_event_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)

self._wait_closed: asyncio.Event = asyncio.Event()
self._wait_closed.set()

Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pathlib
import ssl
import types
from concurrent.futures import ThreadPoolExecutor

import pytest
import websockets
Expand Down Expand Up @@ -266,3 +267,21 @@ async def client_and_server(server):

# Yield both client session and server
yield session, server


@pytest.fixture
async def run_sync_test():
async def run_sync_test_inner(event_loop, server, test_function):
"""This function will run the test in a different Thread.

This allows us to run sync code while aiohttp server can still run.
"""
executor = ThreadPoolExecutor(max_workers=2)
test_task = event_loop.run_in_executor(executor, test_function)

await test_task

if hasattr(server, "close"):
await server.close()

return run_sync_test_inner
59 changes: 59 additions & 0 deletions tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,3 +262,62 @@ async def handler(request):
continent = result["continent"]

assert continent["name"] == "Europe"


@pytest.mark.asyncio
async def test_aiohttp_execute_running_in_thread(
event_loop, aiohttp_server, run_sync_test
):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

def test_code():
sample_transport = AIOHTTPTransport(url=url)

client = Client(transport=sample_transport)

query = gql(query1_str)

client.execute(query)

await run_sync_test(event_loop, server, test_code)


@pytest.mark.asyncio
async def test_aiohttp_subscribe_running_in_thread(
event_loop, aiohttp_server, run_sync_test
):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

def test_code():
sample_transport = AIOHTTPTransport(url=url)

client = Client(transport=sample_transport)

query = gql(query1_str)

# Note: subscriptions are not supported on the aiohttp transport
# But we add this test in order to have 100% code coverage
# It is to check that we will correctly set an event loop
# in the subscribe function if there is none (in a Thread for example)
# We cannot test this with the websockets transport because
# the websockets transport will set an event loop in its init

with pytest.raises(NotImplementedError):
for result in client.subscribe(query):
pass

await run_sync_test(event_loop, server, test_code)
33 changes: 11 additions & 22 deletions tests/test_requests.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from concurrent.futures import ThreadPoolExecutor

import pytest
from aiohttp import web

from gql import Client, gql
from gql import Client, RequestsHTTPTransport, gql
from gql.transport.exceptions import (
TransportAlreadyConnected,
TransportClosed,
TransportProtocolError,
TransportQueryError,
TransportServerError,
)
from gql.transport.requests import RequestsHTTPTransport

query1_str = """
query getContinents {
Expand All @@ -31,20 +28,8 @@
)


async def run_sync_test(event_loop, server, test_function):
"""This function will run the test in a different Thread.

This allows us to run sync code while aiohttp server can still run.
"""
executor = ThreadPoolExecutor(max_workers=2)
test_task = event_loop.run_in_executor(executor, test_function)

await test_task
await server.close()


@pytest.mark.asyncio
async def test_requests_query(event_loop, aiohttp_server):
async def test_requests_query(event_loop, aiohttp_server, run_sync_test):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

Expand Down Expand Up @@ -74,7 +59,7 @@ def test_code():


@pytest.mark.asyncio
async def test_requests_error_code_500(event_loop, aiohttp_server):
async def test_requests_error_code_500(event_loop, aiohttp_server, run_sync_test):
async def handler(request):
# Will generate http error code 500
raise Exception("Server error")
Expand Down Expand Up @@ -102,7 +87,7 @@ def test_code():


@pytest.mark.asyncio
async def test_requests_error_code(event_loop, aiohttp_server):
async def test_requests_error_code(event_loop, aiohttp_server, run_sync_test):
async def handler(request):
return web.Response(
text=query1_server_error_answer, content_type="application/json"
Expand Down Expand Up @@ -136,7 +121,9 @@ def test_code():

@pytest.mark.asyncio
@pytest.mark.parametrize("response", invalid_protocol_responses)
async def test_requests_invalid_protocol(event_loop, aiohttp_server, response):
async def test_requests_invalid_protocol(
event_loop, aiohttp_server, response, run_sync_test
):
async def handler(request):
return web.Response(text=response, content_type="application/json")

Expand All @@ -160,7 +147,7 @@ def test_code():


@pytest.mark.asyncio
async def test_requests_cannot_connect_twice(event_loop, aiohttp_server):
async def test_requests_cannot_connect_twice(event_loop, aiohttp_server, run_sync_test):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

Expand All @@ -182,7 +169,9 @@ def test_code():


@pytest.mark.asyncio
async def test_requests_cannot_execute_if_not_connected(event_loop, aiohttp_server):
async def test_requests_cannot_execute_if_not_connected(
event_loop, aiohttp_server, run_sync_test
):
async def handler(request):
return web.Response(text=query1_server_answer, content_type="application/json")

Expand Down
29 changes: 29 additions & 0 deletions tests/test_websocket_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,32 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str)

# Check that the server received a connection_terminate message last
assert logged_messages.pop() == '{"type": "connection_terminate"}'


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
async def test_websocket_subscription_running_in_thread(
event_loop, server, subscription_str, run_sync_test
):
def test_code():
path = "/graphql"
url = f"ws://{server.hostname}:{server.port}{path}"
sample_transport = WebsocketsTransport(url=url)

client = Client(transport=sample_transport)

count = 10
subscription = gql(subscription_str.format(count=count))

for result in client.subscribe(subscription):

number = result["number"]
print(f"Number received: {number}")

assert number == count
count -= 1

assert count == -1

await run_sync_test(event_loop, server, test_code)