diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 0c332205..c1302794 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,5 +1,4 @@ import asyncio -import functools import io import json import logging @@ -28,6 +27,7 @@ from ..utils import extract_files from .appsync_auth import AppSyncAuthentication from .async_transport import AsyncTransport +from .common.aiohttp_closed_event import create_aiohttp_closed_event from .exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -147,59 +147,6 @@ async def connect(self) -> None: else: raise TransportAlreadyConnected("Transport is already connected") - @staticmethod - def create_aiohttp_closed_event(session) -> asyncio.Event: - """Work around aiohttp issue that doesn't properly close transports on exit. - - See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 - - Returns: - An event that will be set once all transports have been properly closed. - """ - - ssl_transports = 0 - all_is_lost = asyncio.Event() - - def connection_lost(exc, orig_lost): - nonlocal ssl_transports - - try: - orig_lost(exc) - finally: - ssl_transports -= 1 - if ssl_transports == 0: - all_is_lost.set() - - def eof_received(orig_eof_received): - try: # pragma: no cover - orig_eof_received() - except AttributeError: # pragma: no cover - # It may happen that eof_received() is called after - # _app_protocol and _transport are set to None. - pass - - for conn in session.connector._conns.values(): - for handler, _ in conn: - proto = getattr(handler.transport, "_ssl_protocol", None) - if proto is None: - continue - - ssl_transports += 1 - orig_lost = proto.connection_lost - orig_eof_received = proto.eof_received - - proto.connection_lost = functools.partial( - connection_lost, orig_lost=orig_lost - ) - proto.eof_received = functools.partial( - eof_received, orig_eof_received=orig_eof_received - ) - - if ssl_transports == 0: - all_is_lost.set() - - return all_is_lost - async def close(self) -> None: """Coroutine which will close the aiohttp session. @@ -219,7 +166,7 @@ async def close(self) -> None: log.debug("connector_owner is False -> not closing connector") else: - closed_event = self.create_aiohttp_closed_event(self.session) + closed_event = create_aiohttp_closed_event(self.session) await self.session.close() try: await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) diff --git a/gql/transport/aiohttp_websockets.py b/gql/transport/aiohttp_websockets.py index 18699b5e..59d870f6 100644 --- a/gql/transport/aiohttp_websockets.py +++ b/gql/transport/aiohttp_websockets.py @@ -1,106 +1,26 @@ -import asyncio -import json -import logging -import warnings -from contextlib import suppress from ssl import SSLContext -from typing import ( - Any, - AsyncGenerator, - Collection, - Dict, - Literal, - Mapping, - Optional, - Tuple, - Union, -) +from typing import Any, Dict, List, Literal, Mapping, Optional, Union -import aiohttp -from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp import BasicAuth, ClientSession, Fingerprint from aiohttp.typedefs import LooseHeaders, StrOrURL -from graphql import DocumentNode, ExecutionResult, print_ast -from multidict import CIMultiDictProxy -from gql.transport.aiohttp import AIOHTTPTransport -from gql.transport.async_transport import AsyncTransport -from gql.transport.exceptions import ( - TransportAlreadyConnected, - TransportClosed, - TransportProtocolError, - TransportQueryError, - TransportServerError, -) +from .common.adapters.aiohttp import AIOHTTPWebSocketsAdapter +from .websockets_protocol import WebsocketsProtocolTransportBase -log = logging.getLogger("gql.transport.aiohttp_websockets") -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] +class AIOHTTPWebsocketsTransport(WebsocketsProtocolTransportBase): + """:ref:`Async Transport ` used to execute GraphQL queries on + remote servers with websocket connection. - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue + This transport uses asyncio and the provided aiohttp adapter library + in order to send requests on a websocket connection. """ - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - - -class AIOHTTPWebsocketsTransport(AsyncTransport): - - # This transport supports two subprotocols and will autodetect the - # subprotocol supported on the server - APOLLO_SUBPROTOCOL: str = "graphql-ws" - GRAPHQLWS_SUBPROTOCOL: str = "graphql-transport-ws" - def __init__( self, url: StrOrURL, *, - subprotocols: Optional[Collection[str]] = None, + subprotocols: Optional[List[str]] = None, heartbeat: Optional[float] = None, auth: Optional[BasicAuth] = None, origin: Optional[str] = None, @@ -121,8 +41,9 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, + session: Optional[ClientSession] = None, client_session_args: Optional[Dict[str, Any]] = None, - connect_args: Dict[str, Any] = {}, + connect_args: Optional[Dict[str, Any]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -193,6 +114,7 @@ def __init__( :param answer_pings: Whether the client answers the pings from the backend (for the graphql-ws protocol). By default: True + :param session: Optional aiohttp.ClientSession instance. :param client_session_args: Dict of extra args passed to `aiohttp.ClientSession`_ :param connect_args: Dict of extra args passed to @@ -203,986 +125,46 @@ def __init__( .. _aiohttp.ClientSession: https://docs.aiohttp.org/en/stable/client_reference.html#aiohttp.ClientSession """ - self.url: StrOrURL = url - self.heartbeat: Optional[float] = heartbeat - self.auth: Optional[BasicAuth] = auth - self.origin: Optional[str] = origin - self.params: Optional[Mapping[str, str]] = params - self.headers: Optional[LooseHeaders] = headers - - self.proxy: Optional[StrOrURL] = proxy - self.proxy_auth: Optional[BasicAuth] = proxy_auth - self.proxy_headers: Optional[LooseHeaders] = proxy_headers - - self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl - - self.websocket_close_timeout: float = websocket_close_timeout - self.receive_timeout: Optional[float] = receive_timeout - - self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout - self.connect_timeout: Optional[Union[int, float]] = connect_timeout - self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout - self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout - - self.init_payload: Dict[str, Any] = init_payload - - # We need to set an event loop here if there is none - # Or else we will not be able to create an asyncio.Event() - try: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="There is no current event loop" - ) - self._loop = asyncio.get_event_loop() - except RuntimeError: - self._loop = asyncio.new_event_loop() - asyncio.set_event_loop(self._loop) - - self._next_keep_alive_message: asyncio.Event = asyncio.Event() - self._next_keep_alive_message.set() - - self.session: Optional[aiohttp.ClientSession] = None - self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None - self.next_query_id: int = 1 - self.listeners: Dict[int, ListenerQueue] = {} - self._connecting: bool = False - self.response_headers: Optional[CIMultiDictProxy[str]] = None - - self.receive_data_task: Optional[asyncio.Future] = None - self.check_keep_alive_task: Optional[asyncio.Future] = None - self.close_task: Optional[asyncio.Future] = None - - self._wait_closed: asyncio.Event = asyncio.Event() - self._wait_closed.set() - - self._no_more_listeners: asyncio.Event = asyncio.Event() - self._no_more_listeners.set() - - self.payloads: Dict[str, Any] = {} - - self.ping_interval: Optional[Union[int, float]] = ping_interval - self.pong_timeout: Optional[Union[int, float]] - self.answer_pings: bool = answer_pings - - if ping_interval is not None: - if pong_timeout is None: - self.pong_timeout = ping_interval / 2 - else: - self.pong_timeout = pong_timeout - - self.send_ping_task: Optional[asyncio.Future] = None - - self.ping_received: asyncio.Event = asyncio.Event() - """ping_received is an asyncio Event which will fire each time - a ping is received with the graphql-ws protocol""" - - self.pong_received: asyncio.Event = asyncio.Event() - """pong_received is an asyncio Event which will fire each time - a pong is received with the graphql-ws protocol""" - - self.supported_subprotocols: Collection[str] = subprotocols or ( - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ) - - self.close_exception: Optional[Exception] = None - - self.client_session_args = client_session_args - self.connect_args = connect_args - - def _parse_answer_graphqlws( - self, answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - graphql-ws protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - - Differences with the apollo websockets protocol (superclass): - - the "data" message is now called "next" - - the "stop" message is now called "complete" - - there is no connection_terminate or connection_error messages - - instead of a unidirectional keep-alive (ka) message from server to client, - there is now the possibility to send bidirectional ping/pong messages - - connection_ack has an optional payload - - the 'error' answer type returns a list of errors instead of a single error - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(answer.get("type")) - - if answer_type in ["next", "error", "complete"]: - answer_id = int(str(answer.get("id"))) - - if answer_type == "next" or answer_type == "error": - - payload = answer.get("payload") - - if answer_type == "next": - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - # Saving answer_type as 'data' to be understood with superclass - answer_type = "data" - - elif answer_type == "error": - - if not isinstance(payload, list): - raise ValueError("payload is not a list") - - raise TransportQueryError( - str(payload[0]), query_id=answer_id, errors=payload - ) - - elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = answer.get("payload", None) - - else: - raise ValueError - - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer_apollo( - self, answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - apollo websockets protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(answer.get("type")) - - if answer_type in ["data", "error", "complete"]: - answer_id = int(str(answer.get("id"))) - - if answer_type == "data" or answer_type == "error": - - payload = answer.get("payload") - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if answer_type == "data": - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - elif answer_type == "error": - - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) - - elif answer_type == "ka": - # Keep-alive message - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - elif answer_type == "connection_ack": - pass - elif answer_type == "connection_error": - error_payload = answer.get("payload") - raise TransportServerError(f"Server error: '{repr(error_payload)}'") - else: - raise ValueError - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server depending on - the detected subprotocol. - """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) - - return self._parse_answer_apollo(json_answer) - - async def _wait_ack(self) -> None: - """Wait for the connection_ack message. Keep alive messages are ignored""" - - while True: - init_answer = await self._receive() - - answer_type, _, _ = self._parse_answer(init_answer) - - if answer_type == "connection_ack": - return - - if answer_type != "ka": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) - - async def _send_init_message_and_wait_ack(self) -> None: - """Send init message to the provided websocket and wait for the connection ACK. - - If the answer is not a connection_ack message, we will return an Exception. - """ - - init_message = {"type": "connection_init", "payload": self.init_payload} - - await self._send(init_message) - - # Wait for the connection_ack message or raise a TimeoutError - await asyncio.wait_for(self._wait_ack(), self.ack_timeout) - - async def _initialize(self): - """Hook to send the initialization messages after the connection - and potentially wait for the backend ack. - """ - await self._send_init_message_and_wait_ack() - - async def _stop_listener(self, query_id: int): - """Hook to stop to listen to a specific query. - Will send a stop message in some subclasses. - """ - log.debug(f"stop listener {query_id}") - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - await self._send_complete_message(query_id) - await self.listeners[query_id].put(("complete", None)) - else: - await self._send_stop_message(query_id) - - async def _after_connect(self): - """Hook to add custom code for subclasses after the connection - has been established. - """ - # Find the backend subprotocol returned in the response headers - response_headers = self.websocket._response.headers - log.debug(f"Response headers: {response_headers!r}") - try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] - except KeyError: - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol""" - - ping_message = {"type": "ping"} - - if payload is not None: - ping_message["payload"] = payload - - await self._send(ping_message) - - async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol""" - - pong_message = {"type": "pong"} - - if payload is not None: - pong_message["payload"] = payload - - await self._send(pong_message) - - async def _send_stop_message(self, query_id: int) -> None: - """Send stop message to the provided websocket connection and query_id. - - The server should afterwards return a 'complete' message. - """ - - stop_message = {"id": str(query_id), "type": "stop"} - - await self._send(stop_message) - - async def _send_complete_message(self, query_id: int) -> None: - """Send a complete message for the provided query_id. - - This is only for the graphql-ws protocol. - """ - - complete_message = {"id": str(query_id), "type": "complete"} - - await self._send(complete_message) - - async def _send_ping_coro(self) -> None: - """Coroutine to periodically send a ping from the client to the backend. - - Only used for the graphql-ws protocol. - - Send a ping every ping_interval seconds. - Close the connection if a pong is not received within pong_timeout seconds. - """ - - assert self.ping_interval is not None - - try: - while True: - await asyncio.sleep(self.ping_interval) - - await self.send_ping() - - await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) - - # Reset for the next iteration - self.pong_received.clear() - - except asyncio.TimeoutError: - # No pong received in the appriopriate time, close with error - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - f"No pong received after {self.pong_timeout!r} seconds" - ), - clean_close=False, - ) - - async def _after_initialize(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - async def _close_hook(self): - """Hook to add custom code for subclasses for the connection close""" - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError): - await self.send_ping_task - self.send_ping_task = None - - async def _connection_terminate(self): - """Hook to add custom code for subclasses after the initialization - has been done. - """ - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - await self._send_connection_terminate_message() - async def _send_connection_terminate_message(self) -> None: - """Send a connection_terminate message to the provided websocket connection. - - This message indicates that the connection will disconnect. - """ - - connection_terminate_message = {"type": "connection_terminate"} - - await self._send(connection_terminate_message) - - async def _send_query( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> int: - """Send a query to the provided websocket connection. - - We use an incremented id to reference the query. - - Returns the used id for this query. - """ - - query_id = self.next_query_id - self.next_query_id += 1 - - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name - - query_type = "start" - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - query_type = "subscribe" - - query = {"id": str(query_id), "type": query_type, "payload": payload} - - await self._send(query) - - return query_id - - async def _send(self, message: Dict[str, Any]) -> None: - """Send the provided message to the websocket connection and log the message""" - - if self.websocket is None: - raise TransportClosed("WebSocket connection is closed") - - try: - await self.websocket.send_json(message) - log.info(">>> %s", message) - except ConnectionResetError as e: - await self._fail(e, clean_close=False) - raise e - - async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" - - # It is possible that the websocket has been already closed in another task - if self.websocket is None: - raise TransportClosed("Transport is already closed") - - while True: - ws_message = await self.websocket.receive() - - # Ignore low-level ping and pong received - if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): - break - - if ws_message.type in ( - WSMsgType.CLOSE, - WSMsgType.CLOSED, - WSMsgType.CLOSING, - WSMsgType.ERROR, - ): - raise ConnectionResetError - elif ws_message.type is WSMsgType.BINARY: - raise TransportProtocolError("Binary data received in the websocket") - - assert ws_message.type is WSMsgType.TEXT - - answer: str = ws_message.data - - log.info("<<< %s", answer) - - return answer - - def _remove_listener(self, query_id) -> None: - """After exiting from a subscription, remove the listener and - signal an event if this was the last listener for the client. - """ - if query_id in self.listeners: - del self.listeners[query_id] - - remaining = len(self.listeners) - log.debug(f"listener {query_id} deleted, {remaining} remaining") - - if remaining == 0: - self._no_more_listeners.set() - - async def _check_ws_liveness(self) -> None: - """Coroutine which will periodically check the liveness of the connection - through keep-alive messages - """ - - try: - while True: - await asyncio.wait_for( - self._next_keep_alive_message.wait(), self.keep_alive_timeout - ) - - # Reset for the next iteration - self._next_keep_alive_message.clear() - - except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case - # the keep-alive interval of the client or server was not aligned - # the connection still remains) - - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - "No keep-alive message has been received within " - "the expected interval ('keep_alive_timeout' parameter)" - ), - clean_close=False, - ) - - except asyncio.CancelledError: - # The client is probably closing, handle it properly - pass - - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: - - try: - # Put the answer in the queue - if answer_id is not None: - await self.listeners[answer_id].put((answer_type, execution_result)) - except KeyError: - # Do nothing if no one is listening to this query_id. - pass - - # Answer pong to ping for graphql-ws protocol - if answer_type == "ping": - self.ping_received.set() - if self.answer_pings: - await self.send_pong() - - elif answer_type == "pong": - self.pong_received.set() - - async def _receive_data_loop(self) -> None: - """Main asyncio task which will listen to the incoming messages and will - call the parse_answer and handle_answer methods of the subclass.""" - log.debug("Entering _receive_data_loop()") - - try: - while True: - - # Wait the next answer from the websocket server - try: - answer = await self._receive() - except (ConnectionResetError, TransportProtocolError) as e: - await self._fail(e, clean_close=False) - break - except TransportClosed as e: - await self._fail(e, clean_close=False) - raise e - - # Parse the answer - try: - answer_type, answer_id, execution_result = self._parse_answer( - answer - ) - except TransportQueryError as e: - # Received an exception for a specific query - # ==> Add an exception to this query queue - # The exception is raised for this specific query, - # but the transport is not closed. - assert isinstance( - e.query_id, int - ), "TransportQueryError should have a query_id defined here" - try: - await self.listeners[e.query_id].set_exception(e) - except KeyError: - # Do nothing if no one is listening to this query_id - pass - - continue - - except (TransportServerError, TransportProtocolError) as e: - # Received a global exception for this transport - # ==> close the transport - # The exception will be raised for all current queries. - await self._fail(e, clean_close=False) - break - - await self._handle_answer(answer_type, answer_id, execution_result) - - finally: - log.debug("Exiting _receive_data_loop()") - - async def connect(self) -> None: - log.debug("connect: starting") - - if self.session is None: - client_session_args: Dict[str, Any] = {} - - # Adding custom parameters passed from init - if self.client_session_args: - client_session_args.update(self.client_session_args) # type: ignore - - self.session = aiohttp.ClientSession(**client_session_args) - - if self.websocket is None and not self._connecting: - self._connecting = True - - connect_args: Dict[str, Any] = { - "url": self.url, - "headers": self.headers, - "auth": self.auth, - "heartbeat": self.heartbeat, - "origin": self.origin, - "params": self.params, - "protocols": self.supported_subprotocols, - "proxy": self.proxy, - "proxy_auth": self.proxy_auth, - "proxy_headers": self.proxy_headers, - "timeout": self.websocket_close_timeout, - "receive_timeout": self.receive_timeout, - } - - if self.ssl is not None: - connect_args.update( - { - "ssl": self.ssl, - } - ) - - # Adding custom parameters passed from init - if self.connect_args: - connect_args.update(self.connect_args) - - try: - # Connection to the specified url - # Generate a TimeoutError if taking more than connect_timeout seconds - # Set the _connecting flag to False after in all cases - self.websocket = await asyncio.wait_for( - self.session.ws_connect( - **connect_args, - ), - self.connect_timeout, - ) - finally: - self._connecting = False - - self.response_headers = self.websocket._response.headers - - await self._after_connect() - - self.next_query_id = 1 - self.close_exception = None - self._wait_closed.clear() - - # Send the init message and wait for the ack from the server - # Note: This should generate a TimeoutError - # if no ACKs are received within the ack_timeout - try: - await self._initialize() - except ConnectionResetError as e: - raise e - except ( - TransportProtocolError, - TransportServerError, - asyncio.TimeoutError, - ) as e: - await self._fail(e, clean_close=False) - raise e - - # Run the after_init hook of the subclass - await self._after_initialize() - - # If specified, create a task to check liveness of the connection - # through keep-alive messages - if self.keep_alive_timeout is not None: - self.check_keep_alive_task = asyncio.ensure_future( - self._check_ws_liveness() - ) - - # Create a task to listen to the incoming websocket messages - self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) - - else: - raise TransportAlreadyConnected("Transport is already connected") - - log.debug("connect: done") - - async def _clean_close(self) -> None: - """Coroutine which will: - - - send stop messages for each active subscription to the server - - send the connection terminate message - """ - log.debug(f"Listeners: {self.listeners}") - - # Send 'stop' message for all current queries - for query_id, listener in self.listeners.items(): - print(f"Listener {query_id} send_stop: {listener.send_stop}") - - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False - - # Wait that there is no more listeners (we received 'complete' for all queries) - try: - await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) - except asyncio.TimeoutError: # pragma: no cover - log.debug("Timer close_timeout fired") - - # Calling the subclass hook - await self._connection_terminate() - - async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: - """Coroutine which will: - - - do a clean_close if possible: - - send stop messages for each active query to the server - - send the connection terminate message - - close the websocket connection - - send the exception to all the remaining listeners - """ - - log.debug("_close_coro: starting") - - try: - - try: - # Properly shut down liveness checker if enabled - if self.check_keep_alive_task is not None: - # More info: https://stackoverflow.com/a/43810272/1113207 - self.check_keep_alive_task.cancel() - with suppress(asyncio.CancelledError): - await self.check_keep_alive_task - except Exception as exc: # pragma: no cover - log.warning( - "_close_coro cancel keep alive task exception: " + repr(exc) - ) - - try: - # Calling the subclass close hook - await self._close_hook() - except Exception as exc: # pragma: no cover - log.warning("_close_coro close_hook exception: " + repr(exc)) - - # Saving exception to raise it later if trying to use the transport - # after it has already closed. - self.close_exception = e - - if clean_close: - log.debug("_close_coro: starting clean_close") - try: - await self._clean_close() - except Exception as exc: # pragma: no cover - log.warning("Ignoring exception in _clean_close: " + repr(exc)) - - log.debug("_close_coro: sending exception to listeners") - - # Send an exception to all remaining listeners - for query_id, listener in self.listeners.items(): - await listener.set_exception(e) - - log.debug("_close_coro: close websocket connection") - - try: - assert self.websocket is not None - - await self.websocket.close() - self.websocket = None - except Exception as exc: - log.warning("_close_coro websocket close exception: " + repr(exc)) - - log.debug("_close_coro: close aiohttp session") - - if ( - self.client_session_args - and self.client_session_args.get("connector_owner") is False - ): - - log.debug("connector_owner is False -> not closing connector") - - else: - try: - assert self.session is not None - - closed_event = AIOHTTPTransport.create_aiohttp_closed_event( - self.session - ) - await self.session.close() - try: - await asyncio.wait_for( - closed_event.wait(), self.ssl_close_timeout - ) - except asyncio.TimeoutError: - pass - except Exception as exc: # pragma: no cover - log.warning("_close_coro session close exception: " + repr(exc)) - - self.session = None - - log.debug("_close_coro: aiohttp session closed") - - try: - assert self.receive_data_task is not None - - self.receive_data_task.cancel() - with suppress(asyncio.CancelledError): - await self.receive_data_task - except Exception as exc: # pragma: no cover - log.warning( - "_close_coro cancel receive data task exception: " + repr(exc) - ) - - except Exception as exc: # pragma: no cover - log.warning("Exception catched in _close_coro: " + repr(exc)) - - finally: - - log.debug("_close_coro: final cleanup") - - self.websocket = None - self.close_task = None - self.check_keep_alive_task = None - self.receive_data_task = None - self._wait_closed.set() - - log.debug("_close_coro: exiting") - - async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) - - if self.close_task is None: - - if self._wait_closed.is_set(): - log.debug("_fail started but transport is already closed") - else: - self.close_task = asyncio.shield( - asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) - ) - else: - log.debug( - "close_task is not None in _fail. Previous exception is: " - + repr(self.close_exception) - + " New exception is: " - + repr(e) - ) - - async def close(self) -> None: - log.debug("close: starting") - - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) - await self.wait_closed() - - log.debug("close: done") - - async def wait_closed(self) -> None: - log.debug("wait_close: starting") - - if not self._wait_closed.is_set(): - await self._wait_closed.wait() - - log.debug("wait_close: done") - - async def execute( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> ExecutionResult: - """Execute the provided document AST against the configured remote server - using the current session. - - Send a query but close the async generator as soon as we have the first answer. - - The result is sent as an ExecutionResult object. - """ - first_result = None - - generator = self.subscribe( - document, variable_values, operation_name, send_stop=False + # Instanciate a AIOHTTPWebSocketAdapter to indicate the use + # of the aiohttp dependency for this transport + self.adapter: AIOHTTPWebSocketsAdapter = AIOHTTPWebSocketsAdapter( + url=url, + headers=headers, + ssl=ssl, + session=session, + client_session_args=client_session_args, + connect_args=connect_args, + heartbeat=heartbeat, + auth=auth, + origin=origin, + params=params, + proxy=proxy, + proxy_auth=proxy_auth, + proxy_headers=proxy_headers, + websocket_close_timeout=websocket_close_timeout, + receive_timeout=receive_timeout, + ssl_close_timeout=ssl_close_timeout, ) - async for result in generator: - first_result = result - break - - if first_result is None: - raise TransportQueryError( - "Query completed without any answer received from the server" - ) - - return first_result - - async def subscribe( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - send_stop: Optional[bool] = True, - ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using a python async generator. - - The query can be a graphql query, mutation or subscription. - - The results are sent as an ExecutionResult object. - """ - - # Send the query and receive the id - query_id: int = await self._send_query( - document, variable_values, operation_name + # Initialize the WebsocketsProtocolTransportBase parent class + super().__init__( + adapter=self.adapter, + init_payload=init_payload, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + ping_interval=ping_interval, + pong_timeout=pong_timeout, + answer_pings=answer_pings, + subprotocols=subprotocols, ) - # Create a queue to receive the answers for this query_id - listener = ListenerQueue(query_id, send_stop=(send_stop is True)) - self.listeners[query_id] = listener - - # We will need to wait at close for this query to clean properly - self._no_more_listeners.clear() - - try: - # Loop over the received answers - while True: - - # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. - answer_type, execution_result = await listener.get() - - # If the received answer contains data, - # Then we will yield the results back as an ExecutionResult object - if execution_result is not None: - yield execution_result - - # If we receive a 'complete' answer from the server, - # Then we will end this async generator output without errors - elif answer_type == "complete": - log.debug( - f"Complete received for query {query_id} --> exit without error" - ) - break - - except (asyncio.CancelledError, GeneratorExit) as e: - log.debug(f"Exception in subscribe: {e!r}") - if listener.send_stop: - await self._stop_listener(query_id) - listener.send_stop = False + @property + def headers(self) -> Optional[LooseHeaders]: + return self.adapter.headers - finally: - log.debug(f"In subscribe finally for query_id {query_id}") - self._remove_listener(query_id) + @property + def ssl(self) -> Optional[Union[SSLContext, Literal[False], Fingerprint]]: + return self.adapter.ssl diff --git a/gql/transport/appsync_websockets.py b/gql/transport/appsync_websockets.py index 66091747..f35cefe5 100644 --- a/gql/transport/appsync_websockets.py +++ b/gql/transport/appsync_websockets.py @@ -7,8 +7,10 @@ from graphql import DocumentNode, ExecutionResult, print_ast from .appsync_auth import AppSyncAuthentication, AppSyncIAMAuthentication +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase from .exceptions import TransportProtocolError, TransportServerError -from .websockets import WebsocketsTransport, WebsocketsTransportBase +from .websockets import WebsocketsTransport log = logging.getLogger("gql.transport.appsync") @@ -19,7 +21,7 @@ pass -class AppSyncWebsocketsTransport(WebsocketsTransportBase): +class AppSyncWebsocketsTransport(SubscriptionTransportBase): """:ref:`Async Transport ` used to execute GraphQL subscription on AWS appsync realtime endpoint. @@ -32,6 +34,7 @@ class AppSyncWebsocketsTransport(WebsocketsTransportBase): def __init__( self, url: str, + *, auth: Optional[AppSyncAuthentication] = None, session: Optional["botocore.session.Session"] = None, ssl: Union[SSLContext, bool] = False, @@ -70,21 +73,29 @@ def __init__( auth = AppSyncIAMAuthentication(host=host, session=session) self.auth = auth + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + self.init_payload: Dict[str, Any] = {} url = self.auth.get_auth_url(url) - super().__init__( - url, + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, ssl=ssl, + connect_args=connect_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, connect_timeout=connect_timeout, close_timeout=close_timeout, - ack_timeout=ack_timeout, keep_alive_timeout=keep_alive_timeout, - connect_args=connect_args, ) # Using the same 'graphql-ws' protocol as the apollo protocol - self.supported_subprotocols = [ + self.adapter.subprotocols = [ WebsocketsTransport.APOLLO_SUBPROTOCOL, ] self.subprotocol = WebsocketsTransport.APOLLO_SUBPROTOCOL @@ -181,7 +192,7 @@ async def _send_query( return query_id - subscribe = WebsocketsTransportBase.subscribe + subscribe = SubscriptionTransportBase.subscribe # type: ignore[assignment] """Send a subscription query and receive the results using a python async generator. @@ -212,3 +223,7 @@ async def execute( WebsocketsTransport._send_init_message_and_wait_ack ) _wait_ack = WebsocketsTransport._wait_ack + + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl diff --git a/gql/transport/common/__init__.py b/gql/transport/common/__init__.py new file mode 100644 index 00000000..a60ce0b0 --- /dev/null +++ b/gql/transport/common/__init__.py @@ -0,0 +1,10 @@ +from .adapters import AdapterConnection +from .base import SubscriptionTransportBase +from .listener_queue import ListenerQueue, ParsedAnswer + +__all__ = [ + "AdapterConnection", + "ListenerQueue", + "ParsedAnswer", + "SubscriptionTransportBase", +] diff --git a/gql/transport/common/adapters/__init__.py b/gql/transport/common/adapters/__init__.py new file mode 100644 index 00000000..593c46b6 --- /dev/null +++ b/gql/transport/common/adapters/__init__.py @@ -0,0 +1,3 @@ +from .connection import AdapterConnection + +__all__ = ["AdapterConnection"] diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py new file mode 100644 index 00000000..f2dff699 --- /dev/null +++ b/gql/transport/common/adapters/aiohttp.py @@ -0,0 +1,269 @@ +import asyncio +import logging +from ssl import SSLContext +from typing import Any, Dict, Literal, Mapping, Optional, Union + +import aiohttp +from aiohttp import BasicAuth, Fingerprint, WSMsgType +from aiohttp.typedefs import LooseHeaders, StrOrURL +from multidict import CIMultiDictProxy + +from ...exceptions import TransportConnectionFailed, TransportProtocolError +from ..aiohttp_closed_event import create_aiohttp_closed_event +from .connection import AdapterConnection + +log = logging.getLogger("gql.transport.common.adapters.aiohttp") + + +class AIOHTTPWebSocketsAdapter(AdapterConnection): + """AdapterConnection implementation using the aiohttp library.""" + + def __init__( + self, + url: StrOrURL, + *, + headers: Optional[LooseHeaders] = None, + ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = None, + session: Optional[aiohttp.ClientSession] = None, + client_session_args: Optional[Dict[str, Any]] = None, + connect_args: Optional[Dict[str, Any]] = None, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + proxy_headers: Optional[LooseHeaders] = None, + websocket_close_timeout: float = 10.0, + receive_timeout: Optional[float] = None, + ssl_close_timeout: Optional[Union[int, float]] = 10, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: SSL validation mode. ``True`` for default SSL check + (:func:`ssl.create_default_context` is used), + ``False`` for skip SSL certificate validation, + :class:`aiohttp.Fingerprint` for fingerprint + validation, :class:`ssl.SSLContext` for custom SSL + certificate validation. + :param session: Optional aiohttp opened session. + :param client_session_args: Dict of extra args passed to + `aiohttp.ClientSession`_ + :param connect_args: Dict of extra args passed to + `aiohttp.ClientSession.ws_connect`_ + + :param float heartbeat: Send low level `ping` message every `heartbeat` + seconds and wait `pong` response, close + connection if `pong` response is not + received. The timer is reset on any data reception. + :param auth: An object that represents HTTP Basic Authorization. + :class:`~aiohttp.BasicAuth` (optional) + :param str origin: Origin header to send to server(optional) + :param params: Mapping, iterable of tuple of *key*/*value* pairs or + string to be sent as parameters in the query + string of the new request. Ignored for subsequent + redirected requests (optional) + + Allowed values are: + + - :class:`collections.abc.Mapping` e.g. :class:`dict`, + :class:`multidict.MultiDict` or + :class:`multidict.MultiDictProxy` + - :class:`collections.abc.Iterable` e.g. :class:`tuple` or + :class:`list` + - :class:`str` with preferably url-encoded content + (**Warning:** content will not be encoded by *aiohttp*) + :param proxy: Proxy URL, :class:`str` or :class:`~yarl.URL` (optional) + :param aiohttp.BasicAuth proxy_auth: an object that represents proxy HTTP + Basic Authorization (optional) + :param float websocket_close_timeout: Timeout for websocket to close. + ``10`` seconds by default + :param float receive_timeout: Timeout for websocket to receive + complete message. ``None`` (unlimited) + seconds by default + :param ssl_close_timeout: Timeout in seconds to wait for the ssl connection + to close properly + """ + super().__init__( + url=str(url), + connect_args=connect_args, + ) + + self._headers: Optional[LooseHeaders] = headers + self.ssl: Optional[Union[SSLContext, Literal[False], Fingerprint]] = ssl + + self.session: Optional[aiohttp.ClientSession] = session + self._using_external_session = True if self.session else False + + if client_session_args is None: + client_session_args = {} + self.client_session_args = client_session_args + + self.heartbeat: Optional[float] = heartbeat + self.auth: Optional[BasicAuth] = auth + self.origin: Optional[str] = origin + self.params: Optional[Mapping[str, str]] = params + + self.proxy: Optional[StrOrURL] = proxy + self.proxy_auth: Optional[BasicAuth] = proxy_auth + self.proxy_headers: Optional[LooseHeaders] = proxy_headers + + self.websocket_close_timeout: float = websocket_close_timeout + self.receive_timeout: Optional[float] = receive_timeout + + self.ssl_close_timeout: Optional[Union[int, float]] = ssl_close_timeout + + self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None + self._response_headers: Optional[CIMultiDictProxy[str]] = None + + async def connect(self) -> None: + """Connect to the WebSocket server.""" + + assert self.websocket is None + + # Create a session if necessary + if self.session is None: + client_session_args: Dict[str, Any] = {} + + # Adding custom parameters passed from init + client_session_args.update(self.client_session_args) # type: ignore + + self.session = aiohttp.ClientSession(**client_session_args) + + connect_args: Dict[str, Any] = { + "url": self.url, + "headers": self.headers, + "auth": self.auth, + "heartbeat": self.heartbeat, + "origin": self.origin, + "params": self.params, + "proxy": self.proxy, + "proxy_auth": self.proxy_auth, + "proxy_headers": self.proxy_headers, + "timeout": self.websocket_close_timeout, + "receive_timeout": self.receive_timeout, + } + + if self.subprotocols: + connect_args["protocols"] = self.subprotocols + + if self.ssl is not None: + connect_args["ssl"] = self.ssl + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + try: + self.websocket = await self.session.ws_connect( + **connect_args, + ) + except Exception as e: + raise TransportConnectionFailed("Connect failed") from e + + self._response_headers = self.websocket._response.headers + + async def send(self, message: str) -> None: + """Send message to the WebSocket server. + + Args: + message: String message to send + + Raises: + TransportConnectionFailed: If connection closed + """ + if self.websocket is None: + raise TransportConnectionFailed("Connection is already closed") + + try: + await self.websocket.send_str(message) + except ConnectionResetError as e: + raise TransportConnectionFailed("Connection was closed") from e + + async def receive(self) -> str: + """Receive message from the WebSocket server. + + Returns: + String message received + + Raises: + TransportConnectionFailed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportConnectionFailed("Connection is already closed") + + while True: + ws_message = await self.websocket.receive() + + # Ignore low-level ping and pong received + if ws_message.type not in (WSMsgType.PING, WSMsgType.PONG): + break + + if ws_message.type in ( + WSMsgType.CLOSE, + WSMsgType.CLOSED, + WSMsgType.CLOSING, + WSMsgType.ERROR, + ): + raise TransportConnectionFailed("Connection was closed") + elif ws_message.type is WSMsgType.BINARY: + raise TransportProtocolError("Binary data received in the websocket") + + assert ws_message.type is WSMsgType.TEXT + + answer: str = ws_message.data + + return answer + + async def _close_session(self) -> None: + """Close the aiohttp session.""" + + assert self.session is not None + + closed_event = create_aiohttp_closed_event(self.session) + await self.session.close() + try: + await asyncio.wait_for(closed_event.wait(), self.ssl_close_timeout) + except asyncio.TimeoutError: + pass + finally: + self.session = None + + async def close(self) -> None: + """Close the WebSocket connection.""" + + if self.websocket: + websocket = self.websocket + self.websocket = None + try: + await websocket.close() + except Exception as exc: # pragma: no cover + log.warning("websocket.close() exception: " + repr(exc)) + + if self.session and not self._using_external_session: + await self._close_session() + + @property + def headers(self) -> Optional[LooseHeaders]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._headers: + return self._headers + return {} + + @property + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._response_headers: + return dict(self._response_headers) + return {} diff --git a/gql/transport/common/adapters/connection.py b/gql/transport/common/adapters/connection.py new file mode 100644 index 00000000..ac178bc6 --- /dev/null +++ b/gql/transport/common/adapters/connection.py @@ -0,0 +1,68 @@ +import abc +from typing import Any, Dict, List, Optional + + +class AdapterConnection(abc.ABC): + """Abstract interface for subscription connections. + + This allows different WebSocket implementations to be used interchangeably. + """ + + url: str + connect_args: Dict[str, Any] + subprotocols: Optional[List[str]] + + def __init__(self, url: str, connect_args: Optional[Dict[str, Any]]): + """Initialize the connection adapter.""" + self.url: str = url + + if connect_args is None: + connect_args = {} + self.connect_args = connect_args + + self.subprotocols = None + + @abc.abstractmethod + async def connect(self) -> None: + """Connect to the server.""" + pass # pragma: no cover + + @abc.abstractmethod + async def send(self, message: str) -> None: + """Send message to the server. + + Args: + message: String message to send + + Raises: + TransportConnectionFailed: If connection closed + """ + pass # pragma: no cover + + @abc.abstractmethod + async def receive(self) -> str: + """Receive message from the server. + + Returns: + String message received + + Raises: + TransportConnectionFailed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + pass # pragma: no cover + + @abc.abstractmethod + async def close(self) -> None: + """Close the connection.""" + pass # pragma: no cover + + @property + @abc.abstractmethod + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the connection. + + Returns: + Dictionary of response headers + """ + pass # pragma: no cover diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py new file mode 100644 index 00000000..c2524fb4 --- /dev/null +++ b/gql/transport/common/adapters/websockets.py @@ -0,0 +1,150 @@ +import logging +from ssl import SSLContext +from typing import Any, Dict, Optional, Union + +import websockets +from websockets.client import WebSocketClientProtocol +from websockets.datastructures import Headers, HeadersLike + +from ...exceptions import TransportConnectionFailed, TransportProtocolError +from .connection import AdapterConnection + +log = logging.getLogger("gql.transport.common.adapters.websockets") + + +class WebSocketsAdapter(AdapterConnection): + """AdapterConnection implementation using the websockets library.""" + + def __init__( + self, + url: str, + *, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + connect_args: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize the transport with the given parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param connect_args: Other parameters forwarded to + `websockets.connect `_ + """ + super().__init__( + url=url, + connect_args=connect_args, + ) + + self._headers: Optional[HeadersLike] = headers + self.ssl = ssl + + self.websocket: Optional[WebSocketClientProtocol] = None + self._response_headers: Optional[Headers] = None + + async def connect(self) -> None: + """Connect to the WebSocket server.""" + + assert self.websocket is None + + ssl: Optional[Union[SSLContext, bool]] + if self.ssl: + ssl = self.ssl + else: + ssl = True if self.url.startswith("wss") else None + + # Set default arguments used in the websockets.connect call + connect_args: Dict[str, Any] = { + "ssl": ssl, + "extra_headers": self.headers, + } + + if self.subprotocols: + connect_args["subprotocols"] = self.subprotocols + + # Adding custom parameters passed from init + connect_args.update(self.connect_args) + + # Connection to the specified url + try: + self.websocket = await websockets.client.connect(self.url, **connect_args) + except Exception as e: + raise TransportConnectionFailed("Connect failed") from e + + self._response_headers = self.websocket.response_headers + + async def send(self, message: str) -> None: + """Send message to the WebSocket server. + + Args: + message: String message to send + + Raises: + TransportConnectionFailed: If connection closed + """ + if self.websocket is None: + raise TransportConnectionFailed("Connection is already closed") + + try: + await self.websocket.send(message) + except Exception as e: + raise TransportConnectionFailed("Connection was closed") from e + + async def receive(self) -> str: + """Receive message from the WebSocket server. + + Returns: + String message received + + Raises: + TransportConnectionFailed: If connection closed + TransportProtocolError: If protocol error or binary data received + """ + # It is possible that the websocket has been already closed in another task + if self.websocket is None: + raise TransportConnectionFailed("Connection is already closed") + + # Wait for the next websocket frame. Can raise ConnectionClosed + try: + data = await self.websocket.recv() + except Exception as e: + raise TransportConnectionFailed("Connection was closed") from e + + # websocket.recv() can return either str or bytes + # In our case, we should receive only str here + if not isinstance(data, str): + raise TransportProtocolError("Binary data received in the websocket") + + answer: str = data + + return answer + + async def close(self) -> None: + """Close the WebSocket connection.""" + if self.websocket: + websocket = self.websocket + self.websocket = None + await websocket.close() + + @property + def headers(self) -> Optional[HeadersLike]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._headers: + return self._headers + return {} + + @property + def response_headers(self) -> Dict[str, str]: + """Get the response headers from the WebSocket connection. + + Returns: + Dictionary of response headers + """ + if self._response_headers: + return dict(self._response_headers.raw_items()) + return {} diff --git a/gql/transport/common/aiohttp_closed_event.py b/gql/transport/common/aiohttp_closed_event.py new file mode 100644 index 00000000..412448f9 --- /dev/null +++ b/gql/transport/common/aiohttp_closed_event.py @@ -0,0 +1,59 @@ +import asyncio +import functools + +from aiohttp import ClientSession + + +def create_aiohttp_closed_event(session: ClientSession) -> asyncio.Event: + """Work around aiohttp issue that doesn't properly close transports on exit. + + See https://github.com/aio-libs/aiohttp/issues/1925#issuecomment-639080209 + + Returns: + An event that will be set once all transports have been properly closed. + """ + + ssl_transports = 0 + all_is_lost = asyncio.Event() + + def connection_lost(exc, orig_lost): + nonlocal ssl_transports + + try: + orig_lost(exc) + finally: + ssl_transports -= 1 + if ssl_transports == 0: + all_is_lost.set() + + def eof_received(orig_eof_received): + try: # pragma: no cover + orig_eof_received() + except AttributeError: # pragma: no cover + # It may happen that eof_received() is called after + # _app_protocol and _transport are set to None. + pass + + assert session.connector is not None + + for conn in session.connector._conns.values(): + for handler, _ in conn: + proto = getattr(handler.transport, "_ssl_protocol", None) + if proto is None: + continue + + ssl_transports += 1 + orig_lost = proto.connection_lost + orig_eof_received = proto.eof_received + + proto.connection_lost = functools.partial( + connection_lost, orig_lost=orig_lost + ) + proto.eof_received = functools.partial( + eof_received, orig_eof_received=orig_eof_received + ) + + if ssl_transports == 0: + all_is_lost.set() + + return all_is_lost diff --git a/gql/transport/websockets_base.py b/gql/transport/common/base.py similarity index 72% rename from gql/transport/websockets_base.py rename to gql/transport/common/base.py index accca275..770a8b34 100644 --- a/gql/transport/websockets_base.py +++ b/gql/transport/common/base.py @@ -3,132 +3,54 @@ import warnings from abc import abstractmethod from contextlib import suppress -from ssl import SSLContext -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast +from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union -import websockets from graphql import DocumentNode, ExecutionResult -from websockets.client import WebSocketClientProtocol -from websockets.datastructures import Headers, HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Data, Subprotocol -from .async_transport import AsyncTransport -from .exceptions import ( +from ..async_transport import AsyncTransport +from ..exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, ) +from .adapters import AdapterConnection +from .listener_queue import ListenerQueue -log = logging.getLogger("gql.transport.websockets") +log = logging.getLogger("gql.transport.common.base") -ParsedAnswer = Tuple[str, Optional[ExecutionResult]] - -class ListenerQueue: - """Special queue used for each query waiting for server answers - - If the server is stopped while the listener is still waiting, - Then we send an exception to the queue and this exception will be raised - to the consumer once all the previous messages have been consumed from the queue - """ - - def __init__(self, query_id: int, send_stop: bool) -> None: - self.query_id: int = query_id - self.send_stop: bool = send_stop - self._queue: asyncio.Queue = asyncio.Queue() - self._closed: bool = False - - async def get(self) -> ParsedAnswer: - - try: - item = self._queue.get_nowait() - except asyncio.QueueEmpty: - item = await self._queue.get() - - self._queue.task_done() - - # If we receive an exception when reading the queue, we raise it - if isinstance(item, Exception): - self._closed = True - raise item - - # Don't need to save new answers or - # send the stop message if we already received the complete message - answer_type, execution_result = item - if answer_type == "complete": - self.send_stop = False - self._closed = True - - return item - - async def put(self, item: ParsedAnswer) -> None: - - if not self._closed: - await self._queue.put(item) - - async def set_exception(self, exception: Exception) -> None: - - # Put the exception in the queue - await self._queue.put(exception) - - # Don't need to send stop messages in case of error - self.send_stop = False - self._closed = True - - -class WebsocketsTransportBase(AsyncTransport): +class SubscriptionTransportBase(AsyncTransport): """abstract :ref:`Async Transport ` used to implement - different websockets protocols. - - This transport uses asyncio and the websockets library in order to send requests - on a websocket connection. + different subscription protocols (mainly websockets). """ def __init__( self, - url: str, - headers: Optional[HeadersLike] = None, - ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, + *, + adapter: AdapterConnection, connect_timeout: Optional[Union[int, float]] = 10, close_timeout: Optional[Union[int, float]] = 10, - ack_timeout: Optional[Union[int, float]] = 10, keep_alive_timeout: Optional[Union[int, float]] = None, - connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. - :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. - :param headers: Dict of HTTP Headers. - :param ssl: ssl_context of the connection. Use ssl=False to disable encryption - :param init_payload: Dict of the payload sent in the connection_init message. + :param adapter: The connection dependency adapter :param connect_timeout: Timeout in seconds for the establishment - of the websocket connection. If None is provided this will wait forever. + of the connection. If None is provided this will wait forever. :param close_timeout: Timeout in seconds for the close. If None is provided this will wait forever. - :param ack_timeout: Timeout in seconds to wait for the connection_ack message - from the server. If None is provided this will wait forever. :param keep_alive_timeout: Optional Timeout in seconds to receive a sign of liveness from the server. - :param connect_args: Other parameters forwarded to websockets.connect """ - self.url: str = url - self.headers: Optional[HeadersLike] = headers - self.ssl: Union[SSLContext, bool] = ssl - self.init_payload: Dict[str, Any] = init_payload - self.connect_timeout: Optional[Union[int, float]] = connect_timeout self.close_timeout: Optional[Union[int, float]] = close_timeout - self.ack_timeout: Optional[Union[int, float]] = ack_timeout self.keep_alive_timeout: Optional[Union[int, float]] = keep_alive_timeout + self.adapter: AdapterConnection = adapter - self.connect_args = connect_args - - self.websocket: Optional[WebSocketClientProtocol] = None self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} @@ -158,18 +80,14 @@ def __init__( self._next_keep_alive_message: asyncio.Event = asyncio.Event() self._next_keep_alive_message.set() - self.payloads: Dict[str, Any] = {} - """payloads is a dict which will contain the payloads received - for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" - self._connecting: bool = False + self._connected: bool = False self.close_exception: Optional[Exception] = None - # The list of supported subprotocols should be defined in the subclass - self.supported_subprotocols: List[Subprotocol] = [] - - self.response_headers: Optional[Headers] = None + @property + def response_headers(self) -> Dict[str, str]: + return self.adapter.response_headers async def _initialize(self): """Hook to send the initialization messages after the connection @@ -206,36 +124,30 @@ async def _connection_terminate(self): pass # pragma: no cover async def _send(self, message: str) -> None: - """Send the provided message to the websocket connection and log the message""" + """Send the provided message to the adapter connection and log the message""" - if not self.websocket: + if not self._connected: raise TransportClosed( "Transport is not connected" ) from self.close_exception try: - await self.websocket.send(message) + await self.adapter.send(message) log.info(">>> %s", message) - except ConnectionClosed as e: + except TransportConnectionFailed as e: await self._fail(e, clean_close=False) raise e async def _receive(self) -> str: - """Wait the next message from the websocket connection and log the answer""" + """Wait the next message from the connection and log the answer""" - # It is possible that the websocket has been already closed in another task - if self.websocket is None: + # It is possible that the connection has been already closed in another task + if not self._connected: raise TransportClosed("Transport is already closed") - # Wait for the next websocket frame. Can raise ConnectionClosed - data: Data = await self.websocket.recv() - - # websocket.recv() can return either str or bytes - # In our case, we should receive only str here - if not isinstance(data, str): - raise TransportProtocolError("Binary data received in the websocket") - - answer: str = data + # Wait for the next frame. + # Can raise TransportConnectionFailed or TransportProtocolError + answer: str = await self.adapter.receive() log.info("<<< %s", answer) @@ -296,10 +208,10 @@ async def _receive_data_loop(self) -> None: try: while True: - # Wait the next answer from the websocket server + # Wait the next answer from the server try: answer = await self._receive() - except (ConnectionClosed, TransportProtocolError) as e: + except (TransportConnectionFailed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break except TransportClosed: @@ -384,7 +296,7 @@ async def subscribe( while True: # Wait for the answer from the queue of this query_id - # This can raise a TransportError or ConnectionClosed exception. + # This can raise TransportError or TransportConnectionFailed answer_type, execution_result = await listener.get() # If the received answer contains data, @@ -405,6 +317,8 @@ async def subscribe( if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False + if isinstance(e, GeneratorExit): + raise e finally: log.debug(f"In subscribe finally for query_id {query_id}") @@ -433,6 +347,11 @@ async def execute( first_result = result break + # Apparently, on pypy the GeneratorExit exception is not raised after a break + # --> the clean_close has to time out + # We still need to manually close the async generator + await generator.aclose() + if first_result is None: raise TransportQueryError( "Query completed without any answer received from the server" @@ -447,52 +366,30 @@ async def connect(self) -> None: - send the init message - wait for the connection acknowledge from the server - create an asyncio task which will be used to receive - and parse the websocket answers + and parse the answers Should be cleaned with a call to the close coroutine """ log.debug("connect: starting") - if self.websocket is None and not self._connecting: + if not self._connected and not self._connecting: # Set connecting to True to avoid a race condition if user is trying # to connect twice using the same client at the same time self._connecting = True - # If the ssl parameter is not provided, - # generate the ssl value depending on the url - ssl: Optional[Union[SSLContext, bool]] - if self.ssl: - ssl = self.ssl - else: - ssl = True if self.url.startswith("wss") else None - - # Set default arguments used in the websockets.connect call - connect_args: Dict[str, Any] = { - "ssl": ssl, - "extra_headers": self.headers, - "subprotocols": self.supported_subprotocols, - } - - # Adding custom parameters passed from init - connect_args.update(self.connect_args) - - # Connection to the specified url # Generate a TimeoutError if taking more than connect_timeout seconds # Set the _connecting flag to False after in all cases try: - self.websocket = await asyncio.wait_for( - websockets.client.connect(self.url, **connect_args), + await asyncio.wait_for( + self.adapter.connect(), self.connect_timeout, ) + self._connected = True finally: self._connecting = False - self.websocket = cast(WebSocketClientProtocol, self.websocket) - - self.response_headers = self.websocket.response_headers - # Run the after_connect hook of the subclass await self._after_connect() @@ -505,7 +402,7 @@ async def connect(self) -> None: # if no ACKs are received within the ack_timeout try: await self._initialize() - except ConnectionClosed as e: + except TransportConnectionFailed as e: raise e except ( TransportProtocolError, @@ -555,7 +452,6 @@ async def _clean_close(self, e: Exception) -> None: # Send 'stop' message for all current queries for query_id, listener in self.listeners.items(): - if listener.send_stop: await self._stop_listener(query_id) listener.send_stop = False @@ -584,7 +480,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: try: # We should always have an active websocket connection here - assert self.websocket is not None + assert self._connected # Properly shut down liveness checker if enabled if self.check_keep_alive_task is not None: @@ -613,11 +509,11 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: for query_id, listener in self.listeners.items(): await listener.set_exception(e) - log.debug("_close_coro: close websocket connection") + log.debug("_close_coro: close connection") - await self.websocket.close() + await self.adapter.close() - log.debug("_close_coro: websocket connection closed") + log.debug("_close_coro: connection closed") except Exception as exc: # pragma: no cover log.warning("Exception catched in _close_coro: " + repr(exc)) @@ -626,7 +522,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("_close_coro: start cleanup") - self.websocket = None + self._connected = False self.close_task = None self.check_keep_alive_task = None self._wait_closed.set() @@ -638,12 +534,12 @@ async def _fail(self, e: Exception, clean_close: bool = True) -> None: if self.close_task is None: - if self.websocket is None: - log.debug("_fail started with self.websocket == None -> already closed") - else: + if self._connected: self.close_task = asyncio.shield( asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) ) + else: + log.debug("_fail started with self._connected:False -> already closed") else: log.debug( "close_task is not None in _fail. Previous exception is: " @@ -655,7 +551,7 @@ async def _fail(self, e: Exception, clean_close: bool = True) -> None: async def close(self) -> None: log.debug("close: starting") - await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self._fail(TransportClosed("Transport closed by user")) await self.wait_closed() log.debug("close: done") @@ -663,6 +559,17 @@ async def close(self) -> None: async def wait_closed(self) -> None: log.debug("wait_close: starting") - await self._wait_closed.wait() + try: + await asyncio.wait_for(self._wait_closed.wait(), self.close_timeout) + except asyncio.TimeoutError: + log.warning("Timer close_timeout fired in wait_closed") log.debug("wait_close: done") + + @property + def url(self) -> str: + return self.adapter.url + + @property + def connect_args(self) -> Dict[str, Any]: + return self.adapter.connect_args diff --git a/gql/transport/common/listener_queue.py b/gql/transport/common/listener_queue.py new file mode 100644 index 00000000..54aa650f --- /dev/null +++ b/gql/transport/common/listener_queue.py @@ -0,0 +1,58 @@ +import asyncio +from typing import Optional, Tuple + +from graphql import ExecutionResult + +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index 7ec27a33..3e63f0bc 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -61,6 +61,13 @@ class TransportClosed(TransportError): """ +class TransportConnectionFailed(TransportError): + """Transport adapter connection closed. + + This exception is by the connection adapter code when a connection closed. + """ + + class TransportAlreadyConnected(TransportError): """Transport is already connected. diff --git a/gql/transport/phoenix_channel_websockets.py b/gql/transport/phoenix_channel_websockets.py index 08cde8cc..3885fcac 100644 --- a/gql/transport/phoenix_channel_websockets.py +++ b/gql/transport/phoenix_channel_websockets.py @@ -1,17 +1,18 @@ import asyncio import json import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union from graphql import DocumentNode, ExecutionResult, print_ast -from websockets.exceptions import ConnectionClosed +from .common.adapters.websockets import WebSocketsAdapter +from .common.base import SubscriptionTransportBase from .exceptions import ( + TransportConnectionFailed, TransportProtocolError, TransportQueryError, TransportServerError, ) -from .websockets_base import WebsocketsTransportBase log = logging.getLogger(__name__) @@ -24,7 +25,7 @@ def __init__(self, query_id: int) -> None: self.unsubscribe_id: Optional[int] = None -class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase): +class PhoenixChannelWebsocketsTransport(SubscriptionTransportBase): """The PhoenixChannelWebsocketsTransport is an async transport which allows you to execute queries and subscriptions against an `Absinthe`_ backend using the `Phoenix`_ framework `channels`_. @@ -36,23 +37,48 @@ class PhoenixChannelWebsocketsTransport(WebsocketsTransportBase): def __init__( self, + url: str, + *, channel_name: str = "__absinthe__:control", heartbeat_interval: float = 30, - *args, + ack_timeout: Optional[Union[int, float]] = 10, **kwargs, ) -> None: """Initialize the transport with the given parameters. + :param url: The server URL.'. :param channel_name: Channel on the server this transport will join. The default for Absinthe servers is "__absinthe__:control" :param heartbeat_interval: Interval in second between each heartbeat messages sent by the client + :param ack_timeout: Timeout in seconds to wait for the reply message + from the server. """ self.channel_name: str = channel_name self.heartbeat_interval: float = heartbeat_interval self.heartbeat_task: Optional[asyncio.Future] = None self.subscriptions: Dict[str, Subscription] = {} - super().__init__(*args, **kwargs) + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + ws_adapter_args = {} + for ws_arg in ["headers", "ssl", "connect_args"]: + try: + ws_adapter_args[ws_arg] = kwargs.pop(ws_arg) + except KeyError: + pass + + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, + **ws_adapter_args, + ) + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + **kwargs, + ) async def _initialize(self) -> None: """Join the specified channel and wait for the connection ACK. @@ -101,7 +127,7 @@ async def heartbeat_coro(): } ) ) - except ConnectionClosed: # pragma: no cover + except TransportConnectionFailed: # pragma: no cover return self.heartbeat_task = asyncio.ensure_future(heartbeat_coro()) @@ -370,7 +396,7 @@ async def _handle_answer( execution_result: Optional[ExecutionResult], ) -> None: if answer_type == "close": - await self.close() + pass else: await super()._handle_answer(answer_type, answer_id, execution_result) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 02abb61f..7a0ce10a 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,26 +1,13 @@ -import asyncio -import json -import logging -from contextlib import suppress from ssl import SSLContext -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Union -from graphql import DocumentNode, ExecutionResult, print_ast from websockets.datastructures import HeadersLike -from websockets.exceptions import ConnectionClosed -from websockets.typing import Subprotocol -from .exceptions import ( - TransportProtocolError, - TransportQueryError, - TransportServerError, -) -from .websockets_base import WebsocketsTransportBase +from .common.adapters.websockets import WebSocketsAdapter +from .websockets_protocol import WebsocketsProtocolTransportBase -log = logging.getLogger(__name__) - -class WebsocketsTransport(WebsocketsTransportBase): +class WebsocketsTransport(WebsocketsProtocolTransportBase): """:ref:`Async Transport ` used to execute GraphQL queries on remote servers with websocket connection. @@ -28,17 +15,13 @@ class WebsocketsTransport(WebsocketsTransportBase): on a websocket connection. """ - # This transport supports two subprotocols and will autodetect the - # subprotocol supported on the server - APOLLO_SUBPROTOCOL = cast(Subprotocol, "graphql-ws") - GRAPHQLWS_SUBPROTOCOL = cast(Subprotocol, "graphql-transport-ws") - def __init__( self, url: str, + *, headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, - init_payload: Dict[str, Any] = {}, + init_payload: Optional[Dict[str, Any]] = None, connect_timeout: Optional[Union[int, float]] = 10, close_timeout: Optional[Union[int, float]] = 10, ack_timeout: Optional[Union[int, float]] = 10, @@ -46,8 +29,8 @@ def __init__( ping_interval: Optional[Union[int, float]] = None, pong_timeout: Optional[Union[int, float]] = None, answer_pings: bool = True, - connect_args: Dict[str, Any] = {}, - subprotocols: Optional[List[Subprotocol]] = None, + connect_args: Optional[Dict[str, Any]] = None, + subprotocols: Optional[List[str]] = None, ) -> None: """Initialize the transport with the given parameters. @@ -83,438 +66,33 @@ def __init__( By default: both apollo and graphql-ws subprotocols. """ - super().__init__( - url, - headers, - ssl, - init_payload, - connect_timeout, - close_timeout, - ack_timeout, - keep_alive_timeout, - connect_args, + # Instanciate a WebSocketAdapter to indicate the use + # of the websockets dependency for this transport + self.adapter: WebSocketsAdapter = WebSocketsAdapter( + url=url, + headers=headers, + ssl=ssl, + connect_args=connect_args, ) - self.ping_interval: Optional[Union[int, float]] = ping_interval - self.pong_timeout: Optional[Union[int, float]] - self.answer_pings: bool = answer_pings - - if ping_interval is not None: - if pong_timeout is None: - self.pong_timeout = ping_interval / 2 - else: - self.pong_timeout = pong_timeout - - self.send_ping_task: Optional[asyncio.Future] = None - - self.ping_received: asyncio.Event = asyncio.Event() - """ping_received is an asyncio Event which will fire each time - a ping is received with the graphql-ws protocol""" - - self.pong_received: asyncio.Event = asyncio.Event() - """pong_received is an asyncio Event which will fire each time - a pong is received with the graphql-ws protocol""" - - if subprotocols is None: - self.supported_subprotocols = [ - self.APOLLO_SUBPROTOCOL, - self.GRAPHQLWS_SUBPROTOCOL, - ] - else: - self.supported_subprotocols = subprotocols - - async def _wait_ack(self) -> None: - """Wait for the connection_ack message. Keep alive messages are ignored""" - - while True: - init_answer = await self._receive() - - answer_type, answer_id, execution_result = self._parse_answer(init_answer) - - if answer_type == "connection_ack": - return - - if answer_type != "ka": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) - - async def _send_init_message_and_wait_ack(self) -> None: - """Send init message to the provided websocket and wait for the connection ACK. - - If the answer is not a connection_ack message, we will return an Exception. - """ - - init_message = json.dumps( - {"type": "connection_init", "payload": self.init_payload} - ) - - await self._send(init_message) - - # Wait for the connection_ack message or raise a TimeoutError - await asyncio.wait_for(self._wait_ack(), self.ack_timeout) - - async def _initialize(self): - await self._send_init_message_and_wait_ack() - - async def send_ping(self, payload: Optional[Any] = None) -> None: - """Send a ping message for the graphql-ws protocol""" - - ping_message = {"type": "ping"} - - if payload is not None: - ping_message["payload"] = payload - - await self._send(json.dumps(ping_message)) - - async def send_pong(self, payload: Optional[Any] = None) -> None: - """Send a pong message for the graphql-ws protocol""" - - pong_message = {"type": "pong"} - - if payload is not None: - pong_message["payload"] = payload - - await self._send(json.dumps(pong_message)) - - async def _send_stop_message(self, query_id: int) -> None: - """Send stop message to the provided websocket connection and query_id. - - The server should afterwards return a 'complete' message. - """ - - stop_message = json.dumps({"id": str(query_id), "type": "stop"}) - - await self._send(stop_message) - - async def _send_complete_message(self, query_id: int) -> None: - """Send a complete message for the provided query_id. - - This is only for the graphql-ws protocol. - """ - - complete_message = json.dumps({"id": str(query_id), "type": "complete"}) - - await self._send(complete_message) - - async def _stop_listener(self, query_id: int): - """Stop the listener corresponding to the query_id depending on the - detected backend protocol. - - For apollo: send a "stop" message - (a "complete" message will be sent from the backend) - - For graphql-ws: send a "complete" message and simulate the reception - of a "complete" message from the backend - """ - log.debug(f"stop listener {query_id}") - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - await self._send_complete_message(query_id) - await self.listeners[query_id].put(("complete", None)) - else: - await self._send_stop_message(query_id) - - async def _send_connection_terminate_message(self) -> None: - """Send a connection_terminate message to the provided websocket connection. - - This message indicates that the connection will disconnect. - """ - - connection_terminate_message = json.dumps({"type": "connection_terminate"}) - - await self._send(connection_terminate_message) - - async def _send_query( - self, - document: DocumentNode, - variable_values: Optional[Dict[str, Any]] = None, - operation_name: Optional[str] = None, - ) -> int: - """Send a query to the provided websocket connection. - - We use an incremented id to reference the query. - - Returns the used id for this query. - """ - - query_id = self.next_query_id - self.next_query_id += 1 - - payload: Dict[str, Any] = {"query": print_ast(document)} - if variable_values: - payload["variables"] = variable_values - if operation_name: - payload["operationName"] = operation_name - - query_type = "start" - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - query_type = "subscribe" - - query_str = json.dumps( - {"id": str(query_id), "type": query_type, "payload": payload} + # Initialize the WebsocketsProtocolTransportBase parent class + super().__init__( + adapter=self.adapter, + init_payload=init_payload, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + ack_timeout=ack_timeout, + keep_alive_timeout=keep_alive_timeout, + ping_interval=ping_interval, + pong_timeout=pong_timeout, + answer_pings=answer_pings, + subprotocols=subprotocols, ) - await self._send(query_str) - - return query_id - - async def _connection_terminate(self): - if self.subprotocol == self.APOLLO_SUBPROTOCOL: - await self._send_connection_terminate_message() - - def _parse_answer_graphqlws( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - graphql-ws protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - - Differences with the apollo websockets protocol (superclass): - - the "data" message is now called "next" - - the "stop" message is now called "complete" - - there is no connection_terminate or connection_error messages - - instead of a unidirectional keep-alive (ka) message from server to client, - there is now the possibility to send bidirectional ping/pong messages - - connection_ack has an optional payload - - the 'error' answer type returns a list of errors instead of a single error - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(json_answer.get("type")) - - if answer_type in ["next", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) - - if answer_type == "next" or answer_type == "error": - - payload = json_answer.get("payload") - - if answer_type == "next": - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - # Saving answer_type as 'data' to be understood with superclass - answer_type = "data" - - elif answer_type == "error": - - if not isinstance(payload, list): - raise ValueError("payload is not a list") - - raise TransportQueryError( - str(payload[0]), query_id=answer_id, errors=payload - ) - - elif answer_type in ["ping", "pong", "connection_ack"]: - self.payloads[answer_type] = json_answer.get("payload", None) - - else: - raise ValueError - - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer_apollo( - self, json_answer: Dict[str, Any] - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server if the server supports the - apollo websockets protocol. - - Returns a list consisting of: - - the answer_type (between: - 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') - - the answer id (Integer) if received or None - - an execution Result if the answer_type is 'data' or None - """ - - answer_type: str = "" - answer_id: Optional[int] = None - execution_result: Optional[ExecutionResult] = None - - try: - answer_type = str(json_answer.get("type")) - - if answer_type in ["data", "error", "complete"]: - answer_id = int(str(json_answer.get("id"))) - - if answer_type == "data" or answer_type == "error": - - payload = json_answer.get("payload") - - if not isinstance(payload, dict): - raise ValueError("payload is not a dict") - - if answer_type == "data": - - if "errors" not in payload and "data" not in payload: - raise ValueError( - "payload does not contain 'data' or 'errors' fields" - ) - - execution_result = ExecutionResult( - errors=payload.get("errors"), - data=payload.get("data"), - extensions=payload.get("extensions"), - ) - - elif answer_type == "error": - - raise TransportQueryError( - str(payload), query_id=answer_id, errors=[payload] - ) - - elif answer_type == "ka": - # Keep-alive message - if self.check_keep_alive_task is not None: - self._next_keep_alive_message.set() - elif answer_type == "connection_ack": - pass - elif answer_type == "connection_error": - error_payload = json_answer.get("payload") - raise TransportServerError(f"Server error: '{repr(error_payload)}'") - else: - raise ValueError - - except ValueError as e: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {json_answer}" - ) from e - - return answer_type, answer_id, execution_result - - def _parse_answer( - self, answer: str - ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: - """Parse the answer received from the server depending on - the detected subprotocol. - """ - try: - json_answer = json.loads(answer) - except ValueError: - raise TransportProtocolError( - f"Server did not return a GraphQL result: {answer}" - ) - - if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: - return self._parse_answer_graphqlws(json_answer) - - return self._parse_answer_apollo(json_answer) - - async def _send_ping_coro(self) -> None: - """Coroutine to periodically send a ping from the client to the backend. - - Only used for the graphql-ws protocol. - - Send a ping every ping_interval seconds. - Close the connection if a pong is not received within pong_timeout seconds. - """ - - assert self.ping_interval is not None - - try: - while True: - await asyncio.sleep(self.ping_interval) - - await self.send_ping() - - await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) - - # Reset for the next iteration - self.pong_received.clear() - - except asyncio.TimeoutError: - # No pong received in the appriopriate time, close with error - # If the timeout happens during a close already in progress, do nothing - if self.close_task is None: - await self._fail( - TransportServerError( - f"No pong received after {self.pong_timeout!r} seconds" - ), - clean_close=False, - ) - - async def _handle_answer( - self, - answer_type: str, - answer_id: Optional[int], - execution_result: Optional[ExecutionResult], - ) -> None: - - # Put the answer in the queue - await super()._handle_answer(answer_type, answer_id, execution_result) - - # Answer pong to ping for graphql-ws protocol - if answer_type == "ping": - self.ping_received.set() - if self.answer_pings: - await self.send_pong() - - elif answer_type == "pong": - self.pong_received.set() - - async def _after_connect(self): - - # Find the backend subprotocol returned in the response headers - response_headers = self.websocket.response_headers - try: - self.subprotocol = response_headers["Sec-WebSocket-Protocol"] - except KeyError: - # If the server does not send the subprotocol header, using - # the apollo subprotocol by default - self.subprotocol = self.APOLLO_SUBPROTOCOL - - log.debug(f"backend subprotocol returned: {self.subprotocol!r}") - - async def _after_initialize(self): - - # If requested, create a task to send periodic pings to the backend - if ( - self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL - and self.ping_interval is not None - ): - - self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) - - async def _close_hook(self): - log.debug("_close_hook: start") - - # Properly shut down the send ping task if enabled - if self.send_ping_task is not None: - log.debug("_close_hook: cancelling send_ping_task") - self.send_ping_task.cancel() - with suppress(asyncio.CancelledError, ConnectionClosed): - log.debug("_close_hook: awaiting send_ping_task") - await self.send_ping_task - self.send_ping_task = None + @property + def headers(self) -> Optional[HeadersLike]: + return self.adapter.headers - log.debug("_close_hook: end") + @property + def ssl(self) -> Union[SSLContext, bool]: + return self.adapter.ssl diff --git a/gql/transport/websockets_protocol.py b/gql/transport/websockets_protocol.py new file mode 100644 index 00000000..3348c576 --- /dev/null +++ b/gql/transport/websockets_protocol.py @@ -0,0 +1,516 @@ +import asyncio +import json +import logging +from contextlib import suppress +from typing import Any, Dict, List, Optional, Tuple, Union + +from graphql import DocumentNode, ExecutionResult, print_ast + +from .common.adapters.connection import AdapterConnection +from .common.base import SubscriptionTransportBase +from .exceptions import ( + TransportConnectionFailed, + TransportProtocolError, + TransportQueryError, + TransportServerError, +) + +log = logging.getLogger("gql.transport.websockets") + + +class WebsocketsProtocolTransportBase(SubscriptionTransportBase): + """:ref:`Async Transport ` used to execute GraphQL queries on + remote servers with websocket connection. + + This transport uses asyncio and the provided websockets adapter library + in order to send requests on a websocket connection. + """ + + # This transport supports two subprotocols and will autodetect the + # subprotocol supported on the server + APOLLO_SUBPROTOCOL = "graphql-ws" + GRAPHQLWS_SUBPROTOCOL = "graphql-transport-ws" + + def __init__( + self, + *, + adapter: AdapterConnection, + init_payload: Optional[Dict[str, Any]] = None, + connect_timeout: Optional[Union[int, float]] = 10, + close_timeout: Optional[Union[int, float]] = 10, + ack_timeout: Optional[Union[int, float]] = 10, + keep_alive_timeout: Optional[Union[int, float]] = None, + ping_interval: Optional[Union[int, float]] = None, + pong_timeout: Optional[Union[int, float]] = None, + answer_pings: bool = True, + subprotocols: Optional[List[str]] = None, + ) -> None: + """Initialize the transport with the given parameters. + + :param adapter: The connection dependency adapter + :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment + of the websocket connection. If None is provided this will wait forever. + :param close_timeout: Timeout in seconds for the close. If None is provided + this will wait forever. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message + from the server. If None is provided this will wait forever. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. + :param ping_interval: Delay in seconds between pings sent by the client to + the backend for the graphql-ws protocol. None (by default) means that + we don't send pings. Note: there are also pings sent by the underlying + websockets protocol. See the + :ref:`keepalive documentation ` + for more information about this. + :param pong_timeout: Delay in seconds to receive a pong from the backend + after we sent a ping (only for the graphql-ws protocol). + By default equal to half of the ping_interval. + :param answer_pings: Whether the client answers the pings from the backend + (for the graphql-ws protocol). + By default: True + :param subprotocols: list of subprotocols sent to the + backend in the 'subprotocols' http header. + By default: both apollo and graphql-ws subprotocols. + """ + + if subprotocols is None: + subprotocols = [ + self.APOLLO_SUBPROTOCOL, + self.GRAPHQLWS_SUBPROTOCOL, + ] + + self.adapter.subprotocols = subprotocols + + # Initialize the generic SubscriptionTransportBase parent class + super().__init__( + adapter=self.adapter, + connect_timeout=connect_timeout, + close_timeout=close_timeout, + keep_alive_timeout=keep_alive_timeout, + ) + + if init_payload is None: + init_payload = {} + + self.init_payload: Dict[str, Any] = init_payload + self.ack_timeout: Optional[Union[int, float]] = ack_timeout + + self.payloads: Dict[str, Any] = {} + """payloads is a dict which will contain the payloads received + for example with the graphql-ws protocol: 'ping', 'pong', 'connection_ack'""" + + self.ping_interval: Optional[Union[int, float]] = ping_interval + self.pong_timeout: Optional[Union[int, float]] + self.answer_pings: bool = answer_pings + + if ping_interval is not None: + if pong_timeout is None: + self.pong_timeout = ping_interval / 2 + else: + self.pong_timeout = pong_timeout + + self.send_ping_task: Optional[asyncio.Future] = None + + self.ping_received: asyncio.Event = asyncio.Event() + """ping_received is an asyncio Event which will fire each time + a ping is received with the graphql-ws protocol""" + + self.pong_received: asyncio.Event = asyncio.Event() + """pong_received is an asyncio Event which will fire each time + a pong is received with the graphql-ws protocol""" + + async def _wait_ack(self) -> None: + """Wait for the connection_ack message. Keep alive messages are ignored""" + + while True: + init_answer = await self._receive() + + answer_type, answer_id, execution_result = self._parse_answer(init_answer) + + if answer_type == "connection_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + + async def _send_init_message_and_wait_ack(self) -> None: + """Send init message to the provided websocket and wait for the connection ACK. + + If the answer is not a connection_ack message, we will return an Exception. + """ + + init_message = json.dumps( + {"type": "connection_init", "payload": self.init_payload} + ) + + await self._send(init_message) + + # Wait for the connection_ack message or raise a TimeoutError + await asyncio.wait_for(self._wait_ack(), self.ack_timeout) + + async def _initialize(self): + await self._send_init_message_and_wait_ack() + + async def send_ping(self, payload: Optional[Any] = None) -> None: + """Send a ping message for the graphql-ws protocol""" + + ping_message = {"type": "ping"} + + if payload is not None: + ping_message["payload"] = payload + + await self._send(json.dumps(ping_message)) + + async def send_pong(self, payload: Optional[Any] = None) -> None: + """Send a pong message for the graphql-ws protocol""" + + pong_message = {"type": "pong"} + + if payload is not None: + pong_message["payload"] = payload + + await self._send(json.dumps(pong_message)) + + async def _send_stop_message(self, query_id: int) -> None: + """Send stop message to the provided websocket connection and query_id. + + The server should afterwards return a 'complete' message. + """ + + stop_message = json.dumps({"id": str(query_id), "type": "stop"}) + + await self._send(stop_message) + + async def _send_complete_message(self, query_id: int) -> None: + """Send a complete message for the provided query_id. + + This is only for the graphql-ws protocol. + """ + + complete_message = json.dumps({"id": str(query_id), "type": "complete"}) + + await self._send(complete_message) + + async def _stop_listener(self, query_id: int): + """Stop the listener corresponding to the query_id depending on the + detected backend protocol. + + For apollo: send a "stop" message + (a "complete" message will be sent from the backend) + + For graphql-ws: send a "complete" message and simulate the reception + of a "complete" message from the backend + """ + log.debug(f"stop listener {query_id}") + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + await self._send_complete_message(query_id) + await self.listeners[query_id].put(("complete", None)) + else: + await self._send_stop_message(query_id) + + async def _send_connection_terminate_message(self) -> None: + """Send a connection_terminate message to the provided websocket connection. + + This message indicates that the connection will disconnect. + """ + + connection_terminate_message = json.dumps({"type": "connection_terminate"}) + + await self._send(connection_terminate_message) + + async def _send_query( + self, + document: DocumentNode, + variable_values: Optional[Dict[str, Any]] = None, + operation_name: Optional[str] = None, + ) -> int: + """Send a query to the provided websocket connection. + + We use an incremented id to reference the query. + + Returns the used id for this query. + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + payload: Dict[str, Any] = {"query": print_ast(document)} + if variable_values: + payload["variables"] = variable_values + if operation_name: + payload["operationName"] = operation_name + + query_type = "start" + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + query_type = "subscribe" + + query_str = json.dumps( + {"id": str(query_id), "type": query_type, "payload": payload} + ) + + await self._send(query_str) + + return query_id + + async def _connection_terminate(self): + if self.subprotocol == self.APOLLO_SUBPROTOCOL: + await self._send_connection_terminate_message() + + def _parse_answer_graphqlws( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + graphql-ws protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ping', 'pong', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + + Differences with the apollo websockets protocol (superclass): + - the "data" message is now called "next" + - the "stop" message is now called "complete" + - there is no connection_terminate or connection_error messages + - instead of a unidirectional keep-alive (ka) message from server to client, + there is now the possibility to send bidirectional ping/pong messages + - connection_ack has an optional payload + - the 'error' answer type returns a list of errors instead of a single error + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["next", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "next" or answer_type == "error": + + payload = json_answer.get("payload") + + if answer_type == "next": + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + # Saving answer_type as 'data' to be understood with superclass + answer_type = "data" + + elif answer_type == "error": + + if not isinstance(payload, list): + raise ValueError("payload is not a list") + + raise TransportQueryError( + str(payload[0]), query_id=answer_id, errors=payload + ) + + elif answer_type in ["ping", "pong", "connection_ack"]: + self.payloads[answer_type] = json_answer.get("payload", None) + + else: + raise ValueError + + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer_apollo( + self, json_answer: Dict[str, Any] + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server if the server supports the + apollo websockets protocol. + + Returns a list consisting of: + - the answer_type (between: + 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None + + try: + answer_type = str(json_answer.get("type")) + + if answer_type in ["data", "error", "complete"]: + answer_id = int(str(json_answer.get("id"))) + + if answer_type == "data" or answer_type == "error": + + payload = json_answer.get("payload") + + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") + + if answer_type == "data": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), + data=payload.get("data"), + extensions=payload.get("extensions"), + ) + + elif answer_type == "error": + + raise TransportQueryError( + str(payload), query_id=answer_id, errors=[payload] + ) + + elif answer_type == "ka": + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() + elif answer_type == "connection_ack": + pass + elif answer_type == "connection_error": + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{repr(error_payload)}'") + else: + raise ValueError + + except ValueError as e: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {json_answer}" + ) from e + + return answer_type, answer_id, execution_result + + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: + """Parse the answer received from the server depending on + the detected subprotocol. + """ + try: + json_answer = json.loads(answer) + except ValueError: + raise TransportProtocolError( + f"Server did not return a GraphQL result: {answer}" + ) + + if self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL: + return self._parse_answer_graphqlws(json_answer) + + return self._parse_answer_apollo(json_answer) + + async def _send_ping_coro(self) -> None: + """Coroutine to periodically send a ping from the client to the backend. + + Only used for the graphql-ws protocol. + + Send a ping every ping_interval seconds. + Close the connection if a pong is not received within pong_timeout seconds. + """ + + assert self.ping_interval is not None + + try: + while True: + await asyncio.sleep(self.ping_interval) + + await self.send_ping() + + await asyncio.wait_for(self.pong_received.wait(), self.pong_timeout) + + # Reset for the next iteration + self.pong_received.clear() + + except asyncio.TimeoutError: + # No pong received in the appriopriate time, close with error + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + f"No pong received after {self.pong_timeout!r} seconds" + ), + clean_close=False, + ) + + async def _handle_answer( + self, + answer_type: str, + answer_id: Optional[int], + execution_result: Optional[ExecutionResult], + ) -> None: + + # Put the answer in the queue + await super()._handle_answer(answer_type, answer_id, execution_result) + + # Answer pong to ping for graphql-ws protocol + if answer_type == "ping": + self.ping_received.set() + if self.answer_pings: + await self.send_pong() + + elif answer_type == "pong": + self.pong_received.set() + + async def _after_connect(self): + + # Find the backend subprotocol returned in the response headers + try: + self.subprotocol = self.response_headers["Sec-WebSocket-Protocol"] + except KeyError: + # If the server does not send the subprotocol header, using + # the apollo subprotocol by default + self.subprotocol = self.APOLLO_SUBPROTOCOL + + log.debug(f"backend subprotocol returned: {self.subprotocol!r}") + + async def _after_initialize(self): + + # If requested, create a task to send periodic pings to the backend + if ( + self.subprotocol == self.GRAPHQLWS_SUBPROTOCOL + and self.ping_interval is not None + ): + + self.send_ping_task = asyncio.ensure_future(self._send_ping_coro()) + + async def _close_hook(self): + log.debug("_close_hook: start") + + # Properly shut down the send ping task if enabled + if self.send_ping_task is not None: + log.debug("_close_hook: cancelling send_ping_task") + self.send_ping_task.cancel() + with suppress(asyncio.CancelledError, TransportConnectionFailed): + log.debug("_close_hook: awaiting send_ping_task") + await self.send_ping_task + self.send_ping_task = None + + log.debug("_close_hook: end") diff --git a/tests/conftest.py b/tests/conftest.py index b0103a99..f9e11dab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import logging import os import pathlib +import platform import re import ssl import sys @@ -19,6 +20,9 @@ all_transport_dependencies = ["aiohttp", "requests", "httpx", "websockets", "botocore"] +PyPy = platform.python_implementation() == "PyPy" + + def pytest_addoption(parser): parser.addoption( "--run-online", @@ -121,9 +125,10 @@ async def ssl_aiohttp_server(): "gql.transport.aiohttp", "gql.transport.aiohttp_websockets", "gql.transport.appsync", + "gql.transport.common.base", + "gql.transport.httpx", "gql.transport.phoenix_channel_websockets", "gql.transport.requests", - "gql.transport.httpx", "gql.transport.websockets", "gql.dsl", "gql.utilities.parse_result", diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 8ee44d2c..81c79ba7 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -7,7 +7,7 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -148,7 +148,7 @@ async def test_aiohttp_websocket_sending_invalid_data( invalid_data = "QSDF" print(f">>> {invalid_data}") - await session.transport.websocket.send_str(invalid_data) + await session.transport.adapter.websocket.send_str(invalid_data) await asyncio.sleep(2 * MS) @@ -289,7 +289,7 @@ async def test_aiohttp_websocket_server_closing_directly(event_loop, server): sample_transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionFailed): async with Client(transport=sample_transport): pass @@ -309,7 +309,7 @@ async def test_aiohttp_websocket_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index b234d296..f87682d2 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -6,6 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -247,7 +248,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_directly( transport = AIOHTTPWebsocketsTransport(url=url) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionFailed): async with Client(transport=transport): pass @@ -267,7 +268,7 @@ async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( query = gql("query { hello }") - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index d40d15ce..f380948c 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -8,9 +8,9 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in this file with the aiohttp AND websockets marker pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] @@ -260,7 +260,8 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -274,6 +275,9 @@ async def test_aiohttp_websocket_graphqlws_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @@ -390,7 +394,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_server_connection_closed count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): number = result["number"] print(f"Number received: {number}") @@ -839,7 +843,7 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except ConnectionResetError: + except TransportConnectionFailed: pass await asyncio.sleep(50 * MS) @@ -847,23 +851,33 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( # Then with the same session handle, we make a subscription or an execute # which will detect that the transport is closed so that the client could # try to reconnect + generator = None try: if execute_instead_of_subscribe: print("\nEXECUTION_2\n") await session.execute(subscription) else: print("\nSUBSCRIPTION_2\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: pass - except TransportClosed: + except (TransportClosed, TransportConnectionFailed): + if generator: + await generator.aclose() pass - await asyncio.sleep(50 * MS) + timeout = 50 + + if PyPy: + timeout = 500 + + await asyncio.sleep(timeout * MS) # And finally with the same session handle, we make a subscription # which works correctly print("\nSUBSCRIPTION_3\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -871,6 +885,8 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( assert number == count count -= 1 + await generator.aclose() + assert count == -1 await client.close_async() diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index d76d646f..8786d58d 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportQueryError, TransportServerError, ) @@ -60,7 +61,14 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - transport = AIOHTTPWebsocketsTransport(url=url, websocket_close_timeout=10) + transport = AIOHTTPWebsocketsTransport( + url=url, + websocket_close_timeout=10, + headers={"test": "1234"}, + ) + + assert transport.response_headers == {} + assert transport.headers["test"] == "1234" async with Client(transport=transport) as session: @@ -84,7 +92,7 @@ async def test_aiohttp_websocket_starting_client_in_context_manager( assert transport.response_headers["dummy"] == "test1234" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -135,7 +143,7 @@ async def test_aiohttp_websocket_using_ssl_connection( assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -166,19 +174,26 @@ async def test_aiohttp_websocket_using_ssl_connection_self_cert_fail( **extra_args, ) - with pytest.raises(ClientConnectorCertificateError) as exc_info: + if verify_https == "explicitely_enabled": + assert transport.ssl is True + + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) await session.execute(query1) + cause = exc_info.value.__cause__ + + assert isinstance(cause, ClientConnectorCertificateError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -380,13 +395,13 @@ async def test_aiohttp_websocket_multiple_connections_in_series( await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -519,8 +534,8 @@ async def test_aiohttp_websocket_connect_failed_with_authentication_in_connectio await session.execute(query1) - assert transport.session is None - assert transport.websocket is None + assert transport.adapter.session is None + assert transport._connected is False @pytest.mark.parametrize("aiohttp_ws_server", [server1_answers], indirect=True) @@ -564,7 +579,7 @@ def test_aiohttp_websocket_execute_sync(aiohttp_ws_server): assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -753,6 +768,6 @@ async def test_aiohttp_websocket_connector_owner_false(event_loop, aiohttp_ws_se assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False await connector.close() diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index 9d2d652b..4ea11a7b 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -9,7 +9,7 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportClosed, TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError from .conftest import MS, WebSocketServerHelper from .starwars.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef @@ -250,7 +250,8 @@ async def test_aiohttp_websocket_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -264,6 +265,9 @@ async def test_aiohttp_websocket_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @@ -381,7 +385,7 @@ async def test_aiohttp_websocket_subscription_server_connection_closed( count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(ConnectionResetError): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -772,14 +776,12 @@ async def test_subscribe_on_closing_transport(event_loop, server, subscription_s subscription = gql(subscription_str.format(count=count)) async with client as session: - session.transport.websocket._writer._closing = True + session.transport.adapter.websocket._writer._closing = True - with pytest.raises(ConnectionResetError) as e: + with pytest.raises(TransportConnectionFailed): async for _ in session.subscribe(subscription): pass - assert e.value.args[0] == "Cannot write to closing transport" - @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @@ -798,9 +800,7 @@ async def test_subscribe_on_null_transport(event_loop, server, subscription_str) async with client as session: - session.transport.websocket = None - with pytest.raises(TransportClosed) as e: + session.transport.adapter.websocket = None + with pytest.raises(TransportConnectionFailed): async for _ in session.subscribe(subscription): pass - - assert e.value.args[0] == "WebSocket connection is closed" diff --git a/tests/test_client.py b/tests/test_client.py index 1e794558..e5edec8b 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -280,3 +280,7 @@ async def test_async_transport_close_on_schema_retrieval_failure(): pass assert client.transport.session is None + + import asyncio + + await asyncio.sleep(1) diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index befeeb4e..3b6bd901 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -6,6 +6,7 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -233,7 +234,6 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_closing_directly], indirect=True) async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): - import websockets from gql.transport.websockets import WebsocketsTransport url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" @@ -241,7 +241,7 @@ async def test_graphqlws_server_closing_directly(event_loop, graphqlws_server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=sample_transport): pass @@ -257,13 +257,11 @@ async def test_graphqlws_server_closing_after_ack( event_loop, client_and_graphqlws_server ): - import websockets - session, server = client_and_graphqlws_server query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 683da43a..d4bed34f 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -8,9 +8,9 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -260,7 +260,8 @@ async def test_graphqlws_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -274,6 +275,9 @@ async def test_graphqlws_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("graphqlws_server", [server_countdown], indirect=True) @@ -385,14 +389,12 @@ async def server_countdown_close_connection_in_middle(ws): async def test_graphqlws_subscription_server_connection_closed( event_loop, client_and_graphqlws_server, subscription_str ): - import websockets - session, server = client_and_graphqlws_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(websockets.exceptions.ConnectionClosedOK): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -812,7 +814,6 @@ async def test_graphqlws_subscription_reconnecting_session( event_loop, graphqlws_server, subscription_str, execute_instead_of_subscribe ): - import websockets from gql.transport.websockets import WebsocketsTransport from gql.transport.exceptions import TransportClosed @@ -838,7 +839,7 @@ async def test_graphqlws_subscription_reconnecting_session( print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") async for result in session.subscribe(subscription_with_disconnect): pass - except websockets.exceptions.ConnectionClosedOK: + except TransportConnectionFailed: pass await asyncio.sleep(50 * MS) @@ -846,23 +847,33 @@ async def test_graphqlws_subscription_reconnecting_session( # Then with the same session handle, we make a subscription or an execute # which will detect that the transport is closed so that the client could # try to reconnect + generator = None try: if execute_instead_of_subscribe: print("\nEXECUTION_2\n") await session.execute(subscription) else: print("\nSUBSCRIPTION_2\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: pass - except TransportClosed: + except (TransportClosed, TransportConnectionFailed): + if generator: + await generator.aclose() pass - await asyncio.sleep(50 * MS) + timeout = 50 + + if PyPy: + timeout = 500 + + await asyncio.sleep(timeout * MS) # And finally with the same session handle, we make a subscription # which works correctly print("\nSUBSCRIPTION_3\n") - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -870,6 +881,8 @@ async def test_graphqlws_subscription_reconnecting_session( assert number == count count -= 1 + await generator.aclose() + assert count == -1 await client.close_async() diff --git a/tests/test_phoenix_channel_query.py b/tests/test_phoenix_channel_query.py index f39edacb..56d28875 100644 --- a/tests/test_phoenix_channel_query.py +++ b/tests/test_phoenix_channel_query.py @@ -1,6 +1,7 @@ import pytest from gql import Client, gql +from gql.transport.exceptions import TransportConnectionFailed from .conftest import get_localhost_ssl_context_client @@ -65,16 +66,16 @@ async def test_phoenix_channel_query(event_loop, server, query_str): result = await session.execute(query) print("Client received:", result) + continents = result["continents"] + print("Continents received:", continents) + africa = continents[0] + assert africa["code"] == "AF" -@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [query_server], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_phoenix_channel_query_ssl( - event_loop, ws_ssl_server, query_str, verify_https -): +async def test_phoenix_channel_query_ssl(event_loop, ws_ssl_server, query_str): from gql.transport.phoenix_channel_websockets import ( PhoenixChannelWebsocketsTransport, ) @@ -85,12 +86,9 @@ async def test_phoenix_channel_query_ssl( extra_args = {} - if verify_https == "cert_provided": - _, ssl_context = get_localhost_ssl_context_client() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["ssl"] = ssl_context - elif verify_https == "disabled": - extra_args["ssl"] = False + extra_args["ssl"] = ssl_context transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", @@ -134,13 +132,17 @@ async def test_phoenix_channel_query_ssl_self_cert_fail( query = gql(query_str) - with pytest.raises(SSLCertVerificationError) as exc_info: + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: await session.execute(query) + cause = exc_info.value.__cause__ + + assert isinstance(cause, SSLCertVerificationError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) query2_str = """ @@ -214,8 +216,12 @@ async def test_phoenix_channel_subscription(event_loop, server, query_str): first_result = None query = gql(query_str) async with Client(transport=transport) as session: - async for result in session.subscribe(query): + generator = session.subscribe(query) + async for result in generator: first_result = result break + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + print("Client received:", first_result) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 6193c658..35ca665b 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -186,7 +186,7 @@ async def test_phoenix_channel_subscription( PhoenixChannelWebsocketsTransport, ) from gql.transport.phoenix_channel_websockets import log as phoenix_logger - from gql.transport.websockets import log as websockets_logger + from gql.transport.websockets_protocol import log as websockets_logger websockets_logger.setLevel(logging.DEBUG) phoenix_logger.setLevel(logging.DEBUG) @@ -201,7 +201,9 @@ async def test_phoenix_channel_subscription( subscription = gql(subscription_str.format(count=count)) async with Client(transport=sample_transport) as session: - async for result in session.subscribe(subscription): + + generator = session.subscribe(subscription) + async for result in generator: number = result["countdown"]["number"] print(f"Number received: {number}") @@ -212,6 +214,9 @@ async def test_phoenix_channel_subscription( count -= 1 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + assert count == end_count @@ -227,7 +232,7 @@ async def test_phoenix_channel_subscription_no_break( PhoenixChannelWebsocketsTransport, ) from gql.transport.phoenix_channel_websockets import log as phoenix_logger - from gql.transport.websockets import log as websockets_logger + from gql.transport.websockets_protocol import log as websockets_logger from .conftest import MS @@ -378,7 +383,8 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): subscription = gql(heartbeat_subscription_str) async with Client(transport=sample_transport) as session: i = 0 - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: heartbeat_count = result["heartbeat"]["heartbeat_count"] print(f"Heartbeat count received: {heartbeat_count}") @@ -387,3 +393,6 @@ async def test_phoenix_channel_heartbeat(event_loop, server, subscription_str): break i += 1 + + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index cb9e7274..68b2fe52 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportProtocolError, TransportQueryError, ) @@ -141,7 +142,7 @@ async def test_websocket_sending_invalid_data(event_loop, client_and_server, que invalid_data = "QSDF" print(f">>> {invalid_data}") - await session.transport.websocket.send(invalid_data) + await session.transport.adapter.websocket.send(invalid_data) await asyncio.sleep(2 * MS) @@ -272,7 +273,6 @@ async def server_closing_directly(ws): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) async def test_websocket_server_closing_directly(event_loop, server): - import websockets from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" @@ -280,7 +280,7 @@ async def test_websocket_server_closing_directly(event_loop, server): sample_transport = WebsocketsTransport(url=url) - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionFailed): async with Client(transport=sample_transport): pass @@ -294,13 +294,11 @@ async def server_closing_after_ack(ws): @pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) async def test_websocket_server_closing_after_ack(event_loop, client_and_server): - import websockets - session, server = client_and_server query = gql("query { hello }") - with pytest.raises(websockets.exceptions.ConnectionClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 2c723b3f..b1e3c07a 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -9,6 +9,7 @@ from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, + TransportConnectionFailed, TransportQueryError, TransportServerError, ) @@ -51,19 +52,19 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) async def test_websocket_starting_client_in_context_manager(event_loop, server): - import websockets from gql.transport.websockets import WebsocketsTransport url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url, headers={"test": "1234"}) + + assert transport.response_headers == {} + assert transport.headers["test"] == "1234" async with Client(transport=transport) as session: - assert isinstance( - transport.websocket, websockets.client.WebSocketClientProtocol - ) + assert transport._connected is True query1 = gql(query1_str) @@ -85,14 +86,12 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): assert transport.response_headers["dummy"] == "test1234" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False -@pytest.mark.skip(reason="ssl=False is not working for now") @pytest.mark.asyncio @pytest.mark.parametrize("ws_ssl_server", [server1_answers], indirect=True) -@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) -async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_https): +async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server): import websockets from gql.transport.websockets import WebsocketsTransport @@ -103,19 +102,16 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_ extra_args = {} - if verify_https == "cert_provided": - _, ssl_context = get_localhost_ssl_context_client() + _, ssl_context = get_localhost_ssl_context_client() - extra_args["ssl"] = ssl_context - elif verify_https == "disabled": - extra_args["ssl"] = False + extra_args["ssl"] = ssl_context transport = WebsocketsTransport(url=url, **extra_args) async with Client(transport=transport) as session: assert isinstance( - transport.websocket, websockets.client.WebSocketClientProtocol + transport.adapter.websocket, websockets.client.WebSocketClientProtocol ) query1 = gql(query1_str) @@ -133,7 +129,7 @@ async def test_websocket_using_ssl_connection(event_loop, ws_ssl_server, verify_ assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -157,19 +153,26 @@ async def test_websocket_using_ssl_connection_self_cert_fail( transport = WebsocketsTransport(url=url, **extra_args) - with pytest.raises(SSLCertVerificationError) as exc_info: + if verify_https == "explicitely_enabled": + assert transport.ssl is True + + with pytest.raises(TransportConnectionFailed) as exc_info: async with Client(transport=transport) as session: query1 = gql(query1_str) await session.execute(query1) + cause = exc_info.value.__cause__ + + assert isinstance(cause, SSLCertVerificationError) + expected_error = "certificate verify failed: self-signed certificate" - assert expected_error in str(exc_info.value) + assert expected_error in str(cause) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -355,13 +358,13 @@ async def test_websocket_multiple_connections_in_series(event_loop, server): await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False async with Client(transport=transport) as session: await assert_client_is_working(session) # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -484,7 +487,7 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( await session.execute(query1) - assert transport.websocket is None + assert transport._connected is False @pytest.mark.parametrize("server", [server1_answers], indirect=True) @@ -526,7 +529,7 @@ def test_websocket_execute_sync(server): assert africa["code"] == "AF" # Check client is disconnect here - assert transport.websocket is None + assert transport._connected is False @pytest.mark.asyncio @@ -649,3 +652,52 @@ async def test_websocket_simple_query_with_extensions( execution_result = await session.execute(query, get_execution_result=True) assert execution_result.extensions["key1"] == "val1" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_adapter_connection_closed(event_loop, server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport(url=url, headers={"test": "1234"}) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + # Close adapter connection manually (should not be done) + await transport.adapter.close() + + with pytest.raises(TransportClosed): + await session.execute(query1) + + # Check client is disconnect here + assert transport._connected is False + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_transport_closed_in_receive(event_loop, server): + from gql.transport.websockets import WebsocketsTransport + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + transport = WebsocketsTransport( + url=url, + close_timeout=0.1, + ) + + async with Client(transport=transport) as session: + + query1 = gql(query1_str) + + # Close adapter connection manually (should not be done) + # await transport.adapter.close() + transport._connected = False + + with pytest.raises(TransportClosed): + await session.execute(query1) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 5af44d59..6f291218 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -9,9 +9,9 @@ from parse import search from gql import Client, gql -from gql.transport.exceptions import TransportServerError +from gql.transport.exceptions import TransportConnectionFailed, TransportServerError -from .conftest import MS, WebSocketServerHelper +from .conftest import MS, PyPy, WebSocketServerHelper # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -181,7 +181,8 @@ async def test_websocket_subscription_break( count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in session.subscribe(subscription): + generator = session.subscribe(subscription) + async for result in generator: number = result["number"] print(f"Number received: {number}") @@ -195,6 +196,9 @@ async def test_websocket_subscription_break( assert count == 5 + # Using aclose here to make it stop cleanly on pypy + await generator.aclose() + @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @@ -306,14 +310,12 @@ async def server_countdown_close_connection_in_middle(ws): async def test_websocket_subscription_server_connection_closed( event_loop, client_and_server, subscription_str ): - import websockets - session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - with pytest.raises(websockets.exceptions.ConnectionClosedOK): + with pytest.raises(TransportConnectionFailed): async for result in session.subscribe(subscription): @@ -415,7 +417,14 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(20 * MS)) + + keep_alive_timeout = 20 * MS + if PyPy: + keep_alive_timeout = 200 * MS + + sample_transport = WebsocketsTransport( + url=url, keep_alive_timeout=keep_alive_timeout + ) client = Client(transport=sample_transport) diff --git a/tests/test_websockets_adapter.py b/tests/test_websockets_adapter.py new file mode 100644 index 00000000..85fbf00a --- /dev/null +++ b/tests/test_websockets_adapter.py @@ -0,0 +1,98 @@ +import json + +import pytest +from graphql import print_ast + +from gql import gql +from gql.transport.exceptions import TransportConnectionFailed + +# Marking all tests in this file with the websockets marker +pytestmark = pytest.mark.websockets + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' + '{{"code":"AS","name":"Asia"}},{{"code":"EU","name":"Europe"}},' + '{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}}}}}}' +) + +server1_answers = [ + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websockets_adapter_simple_query(event_loop, server): + from gql.transport.common.adapters.websockets import WebSocketsAdapter + + url = f"ws://{server.hostname}:{server.port}/graphql" + + query = print_ast(gql(query1_str)) + print("query=", query) + + adapter = WebSocketsAdapter(url) + + await adapter.connect() + + init_message = json.dumps({"type": "connection_init", "payload": {}}) + + await adapter.send(init_message) + + result = await adapter.receive() + print(f"result={result}") + + payload = json.dumps({"query": query}) + query_message = json.dumps({"id": 1, "type": "start", "payload": payload}) + + await adapter.send(query_message) + + result = await adapter.receive() + print(f"result={result}") + + await adapter.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websockets_adapter_edge_cases(event_loop, server): + from gql.transport.common.adapters.websockets import WebSocketsAdapter + + url = f"ws://{server.hostname}:{server.port}/graphql" + + query = print_ast(gql(query1_str)) + print("query=", query) + + adapter = WebSocketsAdapter(url, headers={"a": 1}, ssl=False, connect_args={}) + + await adapter.connect() + + assert adapter.headers["a"] == 1 + assert adapter.ssl is False + assert adapter.connect_args == {} + assert adapter.response_headers["dummy"] == "test1234" + + # Connect twice causes AssertionError + with pytest.raises(AssertionError): + await adapter.connect() + + await adapter.close() + + # Second close call is ignored + await adapter.close() + + with pytest.raises(TransportConnectionFailed): + await adapter.send("Blah") + + with pytest.raises(TransportConnectionFailed): + await adapter.receive()