diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index d1656de2..701b6de6 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,6 +1,7 @@ import asyncio import json import logging +from contextlib import suppress from ssl import SSLContext from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast @@ -94,6 +95,7 @@ def __init__( connect_timeout: int = 10, close_timeout: int = 10, ack_timeout: int = 10, + keep_alive_timeout: Optional[int] = None, connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. @@ -107,6 +109,8 @@ def __init__( :param close_timeout: Timeout in seconds for the close. :param ack_timeout: Timeout in seconds to wait for the connection_ack message from the server. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. :param connect_args: Other parameters forwarded to websockets.connect """ self.url: str = url @@ -117,6 +121,7 @@ def __init__( self.connect_timeout: int = connect_timeout self.close_timeout: int = close_timeout self.ack_timeout: int = ack_timeout + self.keep_alive_timeout: Optional[int] = keep_alive_timeout self.connect_args = connect_args @@ -125,6 +130,7 @@ def __init__( self.listeners: Dict[int, ListenerQueue] = {} self.receive_data_task: Optional[asyncio.Future] = None + self.check_keep_alive_task: Optional[asyncio.Future] = None self.close_task: Optional[asyncio.Future] = None # We need to set an event loop here if there is none @@ -141,6 +147,10 @@ def __init__( self._no_more_listeners: asyncio.Event = asyncio.Event() self._no_more_listeners.set() + if self.keep_alive_timeout is not None: + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() + self._connecting: bool = False self.close_exception: Optional[Exception] = None @@ -315,8 +325,9 @@ def _parse_answer( ) elif answer_type == "ka": - # KeepAlive message - pass + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() elif answer_type == "connection_ack": pass elif answer_type == "connection_error": @@ -332,8 +343,41 @@ def _parse_answer( return answer_type, answer_id, execution_result - async def _receive_data_loop(self) -> None: + async def _check_ws_liveness(self) -> None: + """Coroutine which will periodically check the liveness of the connection + through keep-alive messages + """ + + try: + while True: + await asyncio.wait_for( + self._next_keep_alive_message.wait(), self.keep_alive_timeout + ) + # Reset for the next iteration + self._next_keep_alive_message.clear() + + except asyncio.TimeoutError: + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) + + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=False, + ) + + except asyncio.CancelledError: + # The client is probably closing, handle it properly + pass + + async def _receive_data_loop(self) -> None: try: while True: @@ -549,6 +593,13 @@ async def connect(self) -> None: await self._fail(e, clean_close=False) raise e + # If specified, create a task to check liveness of the connection + # through keep-alive messages + if self.keep_alive_timeout is not None: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness() + ) + # Create a task to listen to the incoming websocket messages self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) @@ -597,6 +648,13 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: # We should always have an active websocket connection here assert self.websocket is not None + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + # Saving exception to raise it later if trying to use the transport # after it has already closed. self.close_exception = e @@ -629,6 +687,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: self.websocket = None self.close_task = None + self.check_keep_alive_task = None self._wait_closed.set() diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 7d80c8eb..fcd176b5 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -7,6 +7,7 @@ from parse import search from gql import Client, gql +from gql.transport.exceptions import TransportServerError from .conftest import MS, WebSocketServerHelper @@ -378,6 +379,67 @@ async def test_websocket_subscription_with_keepalive( assert count == -1 +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_keepalive_with_timeout_ok( + event_loop, server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(500 * MS)) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_keepalive_with_timeout_nok( + event_loop, server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) def test_websocket_subscription_sync(server, subscription_str):