diff --git a/gql/__init__.py b/gql/__init__.py index 7c21c1c8..bad425d4 100644 --- a/gql/__init__.py +++ b/gql/__init__.py @@ -1,6 +1,7 @@ from .client import Client from .gql import gql from .transport.aiohttp import AIOHTTPTransport +from .transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport from .transport.requests import RequestsHTTPTransport from .transport.websockets import WebsocketsTransport @@ -8,6 +9,7 @@ "gql", "AIOHTTPTransport", "Client", + "PhoenixChannelWebsocketsTransport", "RequestsHTTPTransport", "WebsocketsTransport", ] diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py new file mode 100644 index 00000000..6e96b72e --- /dev/null +++ b/gql/transport/phoenix_channel_websockets.py @@ -0,0 +1,250 @@ +import asyncio +import json +from typing import Dict, Optional, Tuple + +from graphql import DocumentNode, ExecutionResult, print_ast +from websockets.exceptions import ConnectionClosed + +from .exceptions import ( + TransportProtocolError, + TransportQueryError, + TransportServerError, +) +from .websockets import WebsocketsTransport + + +class PhoenixChannelWebsocketsTransport(WebsocketsTransport): + def __init__( + self, channel_name: str, heartbeat_interval: float = 30, *args, **kwargs + ) -> None: + self.channel_name = channel_name + self.heartbeat_interval = heartbeat_interval + self.subscription_ids_to_query_ids: Dict[str, int] = {} + super(PhoenixChannelWebsocketsTransport, self).__init__(*args, **kwargs) + """Initialize the transport with the given request parameters. + + :param channel_name Channel on the server this transport will join + :param heartbeat_interval Interval in second between each heartbeat messages + sent by the client + """ + + async def _send_init_message_and_wait_ack(self) -> None: + """Join the specified channel and wait for the connection ACK. + + If the answer is not a connection_ack message, we will return an Exception. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + init_message = json.dumps( + { + "topic": self.channel_name, + "event": "phx_join", + "payload": {}, + "ref": query_id, + } + ) + + await self._send(init_message) + + # Wait for the connection_ack message or raise a TimeoutError + init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout) + + answer_type, answer_id, execution_result = self._parse_answer(init_answer) + + if answer_type != "reply": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + + async def heartbeat_coro(): + while True: + await asyncio.sleep(self.heartbeat_interval) + try: + query_id = self.next_query_id + self.next_query_id += 1 + + await self._send( + json.dumps( + { + "topic": "phoenix", + "event": "heartbeat", + "payload": {}, + "ref": query_id, + } + ) + ) + except ConnectionClosed: # pragma: no cover + return + + self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) + + async def _send_stop_message(self, query_id: int) -> None: + try: + await self.listeners[query_id].put(("complete", None)) + except KeyError: # pragma: no cover + pass + + async def _send_connection_terminate_message(self) -> None: + """Send a phx_leave message to disconnect from the provided channel. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + connection_terminate_message = json.dumps( + { + "topic": self.channel_name, + "event": "phx_leave", + "payload": {}, + "ref": query_id, + } + ) + + await self._send(connection_terminate_message) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> int: + """Send a query to the provided websocket connection. + + We use an incremented id to reference the query. + + Returns the used id for this query. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + query_str = json.dumps( + { + "topic": self.channel_name, + "event": "doc", + "payload": { + "query": print_ast(document), + "variables": variable_values or {}, + }, + "ref": query_id, + } + ) + + await self._send(query_str) + + return query_id + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server + + Returns a list consisting of: + - the answer_type (between: + 'heartbeat', 'data', 'reply', 'error', 'close') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + event: str = "" + answer_id: Optional[int] = None + answer_type: str = "" + execution_result: Optional[ExecutionResult] = None + + try: + json_answer = json.loads(answer) + + event = str(json_answer.get("event")) + + if event == "subscription:data": + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + subscription_id = str(payload.get("subscriptionId")) + try: + answer_id = self.subscription_ids_to_query_ids[subscription_id] + except KeyError: + raise ValueError( + f"subscription '{subscription_id}' has not been registerd" + ) + + result = payload.get("result") + + if not isinstance(result, dict): + raise ValueError("result is not a dict") + + answer_type = "data" + + execution_result = ExecutionResult( + errors=payload.get("errors"), data=result.get("data") + ) + + elif event == "phx_reply": + answer_id = int(json_answer.get("ref")) + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + status = str(payload.get("status")) + + if status == "ok": + + answer_type = "reply" + response = payload.get("response") + + if isinstance(response, dict) and "subscriptionId" in response: + subscription_id = str(response.get("subscriptionId")) + self.subscription_ids_to_query_ids[subscription_id] = answer_id + + elif status == "error": + response = payload.get("response") + + if isinstance(response, dict): + if "errors" in response: + raise TransportQueryError( + str(response.get("errors")), query_id=answer_id + ) + elif "reason" in response: + raise TransportQueryError( + str(response.get("reason")), query_id=answer_id + ) + raise ValueError("reply error") + + elif status == "timeout": + raise TransportQueryError("reply timeout", query_id=answer_id) + + elif event == "phx_error": + raise TransportServerError("Server error") + elif event == "phx_close": + answer_type = "close" + else: + raise ValueError + + except ValueError as e: + raise TransportProtocolError( + "Server did not return a GraphQL result" + ) from e + + return answer_type, answer_id, execution_result + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + if answer_type == "close": + await self.close() + else: + await super()._handle_answer(answer_type, answer_id, execution_result) + + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: + if self.heartbeat_task is not None: + self.heartbeat_task.cancel() + + await super()._close_coro(e, clean_close) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 0c88b150..b4552b8c 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -371,19 +371,25 @@ async def _receive_data_loop(self) -> None: await self._fail(e, clean_close=False) break - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put( - (answer_type, execution_result) - ) - except KeyError: - # Do nothing if no one is listening to this query_id. - pass + await self._handle_answer(answer_type, answer_id, execution_result) finally: log.debug("Exiting _receive_data_loop()") + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + try: + # Put the answer in the queue + if answer_id is not None: + await self.listeners[answer_id].put((answer_type, execution_result)) + except KeyError: + # Do nothing if no one is listening to this query_id. + pass + async def subscribe( self, document: DocumentNode, diff --git a/tests/conftest.py b/tests/conftest.py index acbbb3af..c2edc236 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -136,6 +136,8 @@ async def stop(self): print("Server stopped\n\n\n") + +class WebSocketServerHelper: @staticmethod async def send_complete(ws, query_id): await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}') @@ -165,6 +167,26 @@ async def wait_connection_terminate(ws): assert json_result["type"] == "connection_terminate" +class PhoenixChannelServerHelper: + @staticmethod + async def send_close(ws): + await ws.send('{"event":"phx_close"}') + + @staticmethod + async def send_connection_ack(ws): + + # Line return for easy debugging + print("") + + # Wait for init + result = await ws.recv() + json_result = json.loads(result) + assert json_result["event"] == "phx_join" + + # Send ack + await ws.send('{"event":"phx_reply", "payload": {"status": "ok"}, "ref": 1}') + + def get_server_handler(request): """Get the server handler. @@ -181,7 +203,7 @@ def get_server_handler(request): async def default_server_handler(ws, path): try: - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) query_id = 1 for answer in answers: @@ -195,10 +217,10 @@ async def default_server_handler(ws, path): formatted_answer = answer await ws.send(formatted_answer) - await WebSocketServer.send_complete(ws, query_id) + await WebSocketServerHelper.send_complete(ws, query_id) query_id += 1 - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.wait_connection_terminate(ws) await ws.wait_closed() except ConnectionClosed: pass diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index 55239f9e..7b9ca253 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -8,7 +8,7 @@ from gql import Client, gql from gql.transport.websockets import WebsocketsTransport -from .conftest import MS, WebSocketServer +from .conftest import MS, WebSocketServerHelper from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef starwars_expected_one = { @@ -25,7 +25,7 @@ async def server_starwars(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) try: await ws.recv() @@ -42,8 +42,8 @@ async def server_starwars(ws, path): await ws.send(data) await asyncio.sleep(2 * MS) - await WebSocketServer.send_complete(ws, 1) - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.wait_connection_terminate(ws) except websockets.exceptions.ConnectionClosedOK: pass diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py new file mode 100644 index 00000000..97283650 --- /dev/null +++ b/tests/test_phoenix_channel_exceptions.py @@ -0,0 +1,198 @@ +import pytest + +from gql import Client, gql +from gql.transport.exceptions import ( + TransportProtocolError, + TransportQueryError, + TransportServerError, +) +from gql.transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport + +from .conftest import PhoenixChannelServerHelper + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +default_subscription_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +error_with_reason_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"reason":"internal error"},' + '"status":"error"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +multiple_errors_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"errors": ["error 1", "error 2"]},' + '"status":"error"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +timeout_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"status":"timeout"},' + '"ref":2,' + '"topic":"test_topic"}' +) + + +def server( + query_server_answer, subscription_server_answer=default_subscription_server_answer, +): + async def phoenix_server(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + if query_server_answer is not None: + await ws.send(query_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + return phoenix_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + server(error_with_reason_server_answer), + server(multiple_errors_server_answer), + server(timeout_server_answer), + ], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_query_error(event_loop, server, query_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + with pytest.raises(TransportQueryError): + async with Client(transport=sample_transport) as session: + await session.execute(query) + + +invalid_subscription_id_server_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":"INVALID","result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":3,' + '"topic":"test_topic"}' +) + +invalid_payload_server_answer = ( + '{"event":"subscription:data",' + '"payload":"INVALID",' + '"ref":3,' + '"topic":"test_topic"}' +) + +invalid_result_server_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":"test_subscription","result": "INVALID"},' + '"ref":3,' + '"topic":"test_topic"}' +) + +generic_error_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"status":"error"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +protocol_server_answer = '{"event":"unknown"}' + +invalid_payload_subscription_server_answer = ( + '{"event":"phx_reply", "payload":"INVALID", "ref":2, "topic":"test_topic"}' +) + + +async def no_connection_ack_phoenix_server(ws, path): + await ws.recv() + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + server(invalid_subscription_id_server_answer), + server(invalid_result_server_answer), + server(generic_error_server_answer), + no_connection_ack_phoenix_server, + server(protocol_server_answer), + server(invalid_payload_server_answer), + server(None, invalid_payload_subscription_server_answer), + ], + indirect=True, +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_protocol_error(event_loop, server, query_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + with pytest.raises(TransportProtocolError): + async with Client(transport=sample_transport) as session: + await session.execute(query) + + +server_error_subscription_server_answer = ( + '{"event":"phx_error", "ref":2, "topic":"test_topic"}' +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server(None, server_error_subscription_server_answer)], indirect=True, +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_server_error(event_loop, server, query_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + with pytest.raises(TransportServerError): + async with Client(transport=sample_transport) as session: + await session.execute(query) diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py new file mode 100644 index 00000000..d59050ac --- /dev/null +++ b/tests/test_phoenix_channel_query.py @@ -0,0 +1,69 @@ +import pytest + +from gql import Client, gql +from gql.transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport + +from .conftest import PhoenixChannelServerHelper + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +subscription_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +query1_server_answer = ( + '{"event":"subscription:data","payload":' + '{"subscriptionId":"test_subscription","result":' + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}},' + '"ref":3,' + '"topic":"test_topic"}' +) + + +@pytest.fixture +def ws_server_helper(request): + yield PhoenixChannelServerHelper + + +async def phoenix_server(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + await ws.send(query1_server_answer) + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [phoenix_server], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_phoenix_channel_simple_query(event_loop, server, query_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + query = gql(query_str) + async with Client(transport=sample_transport) as session: + result = await session.execute(query) + + print("Client received:", result) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py new file mode 100644 index 00000000..8efc9899 --- /dev/null +++ b/tests/test_phoenix_channel_subscription.py @@ -0,0 +1,175 @@ +import asyncio +import json + +import pytest +import websockets +from parse import search + +from gql import Client, gql +from gql.transport.phoenix_channel_websockets import PhoenixChannelWebsocketsTransport + +from .conftest import MS, PhoenixChannelServerHelper + +subscription_server_answer = ( + '{"event":"phx_reply",' + '"payload":' + '{"response":' + '{"subscriptionId":"test_subscription"},' + '"status":"ok"},' + '"ref":2,' + '"topic":"test_topic"}' +) + +countdown_server_answer = ( + '{{"event":"subscription:data",' + '"payload":{{"subscriptionId":"test_subscription","result":' + '{{"data":{{"number":{number}}}}}}},' + '"ref":{query_id}}}' +) + + +async def server_countdown(ws, path): + try: + await PhoenixChannelServerHelper.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["event"] == "doc" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["ref"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f"Countdown started from: {count}") + + await ws.send(subscription_server_answer) + + async def counting_coro(): + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format(query_id=query_id, number=number) + ) + await asyncio.sleep(2 * MS) + + counting_task = asyncio.ensure_future(counting_coro()) + + async def stopping_coro(): + nonlocal counting_task + while True: + + result = await ws.recv() + json_result = json.loads(result) + + if json_result["type"] == "stop" and json_result["id"] == str(query_id): + print("Cancelling counting task now") + counting_task.cancel() + + stopping_task = asyncio.ensure_future(stopping_coro()) + + try: + await counting_task + except asyncio.CancelledError: + print("Now counting task is cancelled") + + stopping_task.cancel() + + try: + await stopping_task + except asyncio.CancelledError: + print("Now stopping task is cancelled") + + await PhoenixChannelServerHelper.send_close(ws) + except websockets.exceptions.ConnectionClosedOK: + pass + finally: + await ws.wait_closed() + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_phoenix_channel_subscription(event_loop, server, subscription_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url + ) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with Client(transport=sample_transport) as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +heartbeat_server_answer = ( + '{{"event":"subscription:data",' + '"payload":{{"subscriptionId":"test_subscription","result":' + '{{"data":{{"heartbeat_count":{count}}}}}}},' + '"ref":1}}' +) + + +async def phoenix_heartbeat_server(ws, path): + await PhoenixChannelServerHelper.send_connection_ack(ws) + await ws.recv() + await ws.send(subscription_server_answer) + + for i in range(3): + heartbeat_result = await ws.recv() + json_result = json.loads(heartbeat_result) + assert json_result["event"] == "heartbeat" + await ws.send(heartbeat_server_answer.format(count=i)) + + await PhoenixChannelServerHelper.send_close(ws) + await ws.wait_closed() + + +heartbeat_subscription_str = """ + subscription { + heartbeat { + heartbeat_count + } + } +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [phoenix_heartbeat_server], indirect=True) +@pytest.mark.parametrize("subscription_str", [heartbeat_subscription_str]) +async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = PhoenixChannelWebsocketsTransport( + channel_name="test_channel", url=url, heartbeat_interval=1 + ) + + subscription = gql(heartbeat_subscription_str) + async with Client(transport=sample_transport) as session: + i = 0 + async for result in session.subscribe(subscription): + heartbeat_count = result["heartbeat_count"] + print(f"Heartbeat count received: {heartbeat_count}") + + assert heartbeat_count == i + i += 1 diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index aed27189..a2678a4a 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -15,7 +15,7 @@ ) from gql.transport.websockets import WebsocketsTransport -from .conftest import MS, WebSocketServer +from .conftest import MS, WebSocketServerHelper invalid_query_str = """ query getContinents { @@ -69,10 +69,10 @@ async def test_websocket_invalid_query(event_loop, client_and_server, query_str) async def server_invalid_subscription(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) await ws.recv() await ws.send(invalid_query1_server_answer.format(query_id=1)) - await WebSocketServer.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() @@ -123,7 +123,7 @@ async def test_websocket_server_does_not_send_ack(event_loop, server, query_str) async def server_connection_error(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") await ws.send(connection_error_server_answer) @@ -150,11 +150,11 @@ async def test_websocket_sending_invalid_data(event_loop, client_and_server, que async def server_invalid_payload(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") await ws.send(invalid_payload_server_answer) - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.wait_connection_terminate(ws) await ws.wait_closed() @@ -241,7 +241,7 @@ async def test_websocket_transport_protocol_errors(event_loop, client_and_server async def server_without_ack(ws, path): # Sending something else than an ack - await WebSocketServer.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() @@ -278,7 +278,7 @@ async def test_websocket_server_closing_directly(event_loop, server): async def server_closing_after_ack(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) await ws.close() @@ -300,7 +300,7 @@ async def test_websocket_server_closing_after_ack(event_loop, client_and_server) async def server_sending_invalid_query_errors(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) invalid_error = ( '{"type":"error","id":"404","payload":' '{"message":"error for no good reason on non existing query"}}' diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 13014c28..d44aa779 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -15,7 +15,7 @@ ) from gql.transport.websockets import WebsocketsTransport -from .conftest import MS, WebSocketServer +from .conftest import MS, WebSocketServerHelper query1_str = """ query getContinents { @@ -153,16 +153,16 @@ async def test_websocket_two_queries_in_series( async def server1_two_queries_in_parallel(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") result = await ws.recv() print(f"Server received: {result}") await ws.send(query1_server_answer.format(query_id=1)) await ws.send(query1_server_answer.format(query_id=2)) - await WebSocketServer.send_complete(ws, 1) - await WebSocketServer.send_complete(ws, 2) - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 2) + await WebSocketServerHelper.wait_connection_terminate(ws) await ws.wait_closed() @@ -200,11 +200,11 @@ async def task2_coro(): async def server_closing_while_we_are_doing_something_else(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") await ws.send(query1_server_answer.format(query_id=1)) - await WebSocketServer.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 1) await asyncio.sleep(1 * MS) # Closing server after first query @@ -350,7 +350,7 @@ async def server_with_authentication_in_connection_init_payload(ws, path): result = await ws.recv() print(f"Server received: {result}") await ws.send(query1_server_answer.format(query_id=1)) - await WebSocketServer.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 1) else: await ws.send( '{"type":"connection_error", "payload": "Invalid Authorization token"}' @@ -475,15 +475,15 @@ async def test_websocket_add_extra_parameters_to_connect(event_loop, server): async def server_sending_keep_alive_before_connection_ack(ws, path): - await WebSocketServer.send_keepalive(ws) - await WebSocketServer.send_keepalive(ws) - await WebSocketServer.send_keepalive(ws) - await WebSocketServer.send_keepalive(ws) - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() print(f"Server received: {result}") await ws.send(query1_server_answer.format(query_id=1)) - await WebSocketServer.send_complete(ws, 1) + await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 19133bd6..8152e07c 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -10,7 +10,7 @@ from gql import Client, gql from gql.transport.websockets import WebsocketsTransport -from .conftest import MS, WebSocketServer +from .conftest import MS, WebSocketServerHelper countdown_server_answer = ( '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' @@ -28,9 +28,9 @@ async def server_countdown(ws, path): global WITH_KEEPALIVE try: - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) if WITH_KEEPALIVE: - await WebSocketServer.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) result = await ws.recv() logged_messages.append(result) @@ -74,7 +74,7 @@ async def keepalive_coro(): while True: await asyncio.sleep(5 * MS) try: - await WebSocketServer.send_keepalive(ws) + await WebSocketServerHelper.send_keepalive(ws) except websockets.exceptions.ConnectionClosed: break @@ -100,8 +100,8 @@ async def keepalive_coro(): except asyncio.CancelledError: print("Now keepalive task is cancelled") - await WebSocketServer.send_complete(ws, query_id) - await WebSocketServer.wait_connection_terminate(ws) + await WebSocketServerHelper.send_complete(ws, query_id) + await WebSocketServerHelper.wait_connection_terminate(ws) except websockets.exceptions.ConnectionClosedOK: pass finally: @@ -246,7 +246,7 @@ async def close_transport_task_coro(): async def server_countdown_close_connection_in_middle(ws, path): - await WebSocketServer.send_connection_ack(ws) + await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() json_result = json.loads(result)