From 7da1f1b027009fcbb7b61bdb93d0c7dd76d32006 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Rame=CC=81?= Date: Wed, 7 Apr 2021 15:27:52 +0200 Subject: [PATCH 1/5] Handle keep-alive behavior to close the connection --- gql/transport/websockets.py | 63 +++++++++++++++++++++++++++++++++++-- 1 file changed, 60 insertions(+), 3 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 76a234bd..c3be1aa8 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import suppress import json import logging from ssl import SSLContext @@ -94,6 +95,7 @@ def __init__( connect_timeout: int = 10, close_timeout: int = 10, ack_timeout: int = 10, + keep_alive_timeout: int = 0, connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. @@ -107,6 +109,7 @@ def __init__( :param close_timeout: Timeout in seconds for the close. :param ack_timeout: Timeout in seconds to wait for the connection_ack message from the server. + :param keep_alive_timeout: Timeout in seconds to receive a sign of liveness from the server. :param connect_args: Other parameters forwarded to websockets.connect """ self.url: str = url @@ -117,6 +120,7 @@ def __init__( self.connect_timeout: int = connect_timeout self.close_timeout: int = close_timeout self.ack_timeout: int = ack_timeout + self.keep_alive_timeout: int = keep_alive_timeout self.connect_args = connect_args @@ -125,6 +129,7 @@ def __init__( self.listeners: Dict[int, ListenerQueue] = {} self.receive_data_task: Optional[asyncio.Future] = None + self.check_keep_alive_task: Optional[asyncio.Future] = None self.close_task: Optional[asyncio.Future] = None # We need to set an event loop here if there is none @@ -141,6 +146,10 @@ def __init__( self._no_more_listeners: asyncio.Event = asyncio.Event() self._no_more_listeners.set() + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() + + self._keep_alive_not_received: bool = False self._connecting: bool = False self.close_exception: Optional[Exception] = None @@ -313,8 +322,9 @@ def _parse_answer( ) elif answer_type == "ka": - # KeepAlive message - pass + # Keep-alive message + if self.check_keep_alive_task is not None: + self._next_keep_alive_message.set() elif answer_type == "connection_ack": pass elif answer_type == "connection_error": @@ -330,8 +340,31 @@ def _parse_answer( return answer_type, answer_id, execution_result - async def _receive_data_loop(self) -> None: + async def _check_ws_liveness(self) -> None: + """Coroutine which will periodically check the liveness of the connection through keep-alive messages + """ + try: + while True: + await asyncio.wait_for(self._next_keep_alive_message.wait(), self.keep_alive_timeout) + + # Reset for the next iteration + self._next_keep_alive_message.clear() + except asyncio.TimeoutError: + # No keep-alive message in the appriopriate interval, make the receive loop stopping properly by knowing from where it comes + self._keep_alive_not_received = True + + if self.receive_data_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.receive_data_task.cancel() + with suppress(asyncio.CancelledError): + await self.receive_data_task + + except asyncio.CancelledError: + # The client is probably closing, handle it properly + pass + + async def _receive_data_loop(self) -> None: try: while True: @@ -372,6 +405,16 @@ async def _receive_data_loop(self) -> None: await self._handle_answer(answer_type, answer_id, execution_result) + except Exception as err: + if self._keep_alive_not_received: + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify the server of a proper close (in case the keep-alive interval of the client or server was not aligned the connection still remains) + e = TransportServerError( + "No keep-alive message has been received within the expected interval ('keep_alive_timeout' parameter)") + + await self._fail(e, clean_close=True) + else: + raise err finally: log.debug("Exiting _receive_data_loop()") @@ -534,6 +577,7 @@ async def connect(self) -> None: self.next_query_id = 1 self.close_exception = None + self._keep_alive_not_received = False self._wait_closed.clear() # Send the init message and wait for the ack from the server @@ -547,6 +591,11 @@ async def connect(self) -> None: await self._fail(e, clean_close=False) raise e + # If specified, create a task to check liveness of the connection (through keep-alive messages) + if self.keep_alive_timeout > 0: + self.check_keep_alive_task = asyncio.ensure_future( + self._check_ws_liveness()) + # Create a task to listen to the incoming websocket messages self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) @@ -595,6 +644,13 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: # We should always have an active websocket connection here assert self.websocket is not None + # Properly shut down liveness checker if enabled + if self.check_keep_alive_task is not None: + # More info: https://stackoverflow.com/a/43810272/1113207 + self.check_keep_alive_task.cancel() + with suppress(asyncio.CancelledError): + await self.check_keep_alive_task + # Saving exception to raise it later if trying to use the transport # after it has already closed. self.close_exception = e @@ -627,6 +683,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: self.websocket = None self.close_task = None + self.check_keep_alive_task = None self._wait_closed.set() From a1351ab390c9006b9c945ae0acf52f5095b18c31 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Rame=CC=81?= Date: Wed, 7 Apr 2021 16:21:05 +0200 Subject: [PATCH 2/5] Fix some check issues --- gql/transport/websockets.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index c3be1aa8..0eb6e97b 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,7 +1,7 @@ import asyncio -from contextlib import suppress import json import logging +from contextlib import suppress from ssl import SSLContext from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast @@ -346,7 +346,9 @@ async def _check_ws_liveness(self) -> None: try: while True: - await asyncio.wait_for(self._next_keep_alive_message.wait(), self.keep_alive_timeout) + await asyncio.wait_for( + self._next_keep_alive_message.wait(), self.keep_alive_timeout + ) # Reset for the next iteration self._next_keep_alive_message.clear() @@ -409,10 +411,12 @@ async def _receive_data_loop(self) -> None: if self._keep_alive_not_received: # No keep-alive message in the appriopriate interval, close with error # while trying to notify the server of a proper close (in case the keep-alive interval of the client or server was not aligned the connection still remains) - e = TransportServerError( - "No keep-alive message has been received within the expected interval ('keep_alive_timeout' parameter)") - - await self._fail(e, clean_close=True) + await self._fail( + TransportServerError( + "No keep-alive message has been received within the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=True, + ) else: raise err finally: @@ -594,7 +598,8 @@ async def connect(self) -> None: # If specified, create a task to check liveness of the connection (through keep-alive messages) if self.keep_alive_timeout > 0: self.check_keep_alive_task = asyncio.ensure_future( - self._check_ws_liveness()) + self._check_ws_liveness() + ) # Create a task to listen to the incoming websocket messages self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) From 4a14637f859ebe38396db1bd2432854dc6697970 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Thomas=20Rame=CC=81?= Date: Wed, 7 Apr 2021 16:27:10 +0200 Subject: [PATCH 3/5] Manual fix for lines length --- gql/transport/websockets.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 0eb6e97b..d46d43db 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -109,7 +109,8 @@ def __init__( :param close_timeout: Timeout in seconds for the close. :param ack_timeout: Timeout in seconds to wait for the connection_ack message from the server. - :param keep_alive_timeout: Timeout in seconds to receive a sign of liveness from the server. + :param keep_alive_timeout: Timeout in seconds to receive a sign of liveness + from the server. :param connect_args: Other parameters forwarded to websockets.connect """ self.url: str = url @@ -341,7 +342,8 @@ def _parse_answer( return answer_type, answer_id, execution_result async def _check_ws_liveness(self) -> None: - """Coroutine which will periodically check the liveness of the connection through keep-alive messages + """Coroutine which will periodically check the liveness of the connection + through keep-alive messages """ try: @@ -353,7 +355,8 @@ async def _check_ws_liveness(self) -> None: # Reset for the next iteration self._next_keep_alive_message.clear() except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, make the receive loop stopping properly by knowing from where it comes + # No keep-alive message in the appriopriate interval, + # make the receive loop stopping properly by knowing from where it comes self._keep_alive_not_received = True if self.receive_data_task is not None: @@ -410,10 +413,13 @@ async def _receive_data_loop(self) -> None: except Exception as err: if self._keep_alive_not_received: # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case the keep-alive interval of the client or server was not aligned the connection still remains) + # while trying to notify the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) await self._fail( TransportServerError( - "No keep-alive message has been received within the expected interval ('keep_alive_timeout' parameter)" + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" ), clean_close=True, ) @@ -595,7 +601,8 @@ async def connect(self) -> None: await self._fail(e, clean_close=False) raise e - # If specified, create a task to check liveness of the connection (through keep-alive messages) + # If specified, create a task to check liveness of the connection + # through keep-alive messages if self.keep_alive_timeout > 0: self.check_keep_alive_task = asyncio.ensure_future( self._check_ws_liveness() From 6fd76ff09c1db9eb9bea523e4bc63717bb210443 Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 24 Apr 2021 19:12:28 +0200 Subject: [PATCH 4/5] Refactor and add tests modifications: * clean_close = False * keep_alive_timeout is now Optional[int] default None * calling self._fail directly from the _check_ws_liveness coro * no need to cancel the _receive_data_loop coro, it will stop itself once the websocket will close --- gql/transport/websockets.py | 54 ++++++++++------------- tests/test_websocket_subscription.py | 64 ++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 32 deletions(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 2fa4ce05..a0e02d62 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -95,7 +95,7 @@ def __init__( connect_timeout: int = 10, close_timeout: int = 10, ack_timeout: int = 10, - keep_alive_timeout: int = 0, + keep_alive_timeout: Optional[int] = None, connect_args: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given parameters. @@ -109,8 +109,8 @@ def __init__( :param close_timeout: Timeout in seconds for the close. :param ack_timeout: Timeout in seconds to wait for the connection_ack message from the server. - :param keep_alive_timeout: Timeout in seconds to receive a sign of liveness - from the server. + :param keep_alive_timeout: Optional Timeout in seconds to receive + a sign of liveness from the server. :param connect_args: Other parameters forwarded to websockets.connect """ self.url: str = url @@ -121,7 +121,7 @@ def __init__( self.connect_timeout: int = connect_timeout self.close_timeout: int = close_timeout self.ack_timeout: int = ack_timeout - self.keep_alive_timeout: int = keep_alive_timeout + self.keep_alive_timeout: Optional[int] = keep_alive_timeout self.connect_args = connect_args @@ -147,10 +147,10 @@ def __init__( self._no_more_listeners: asyncio.Event = asyncio.Event() self._no_more_listeners.set() - self._next_keep_alive_message: asyncio.Event = asyncio.Event() - self._next_keep_alive_message.set() + if self.keep_alive_timeout is not None: + self._next_keep_alive_message: asyncio.Event = asyncio.Event() + self._next_keep_alive_message.set() - self._keep_alive_not_received: bool = False self._connecting: bool = False self.close_exception: Optional[Exception] = None @@ -356,16 +356,22 @@ async def _check_ws_liveness(self) -> None: # Reset for the next iteration self._next_keep_alive_message.clear() + except asyncio.TimeoutError: - # No keep-alive message in the appriopriate interval, - # make the receive loop stopping properly by knowing from where it comes - self._keep_alive_not_received = True + # No keep-alive message in the appriopriate interval, close with error + # while trying to notify the server of a proper close (in case + # the keep-alive interval of the client or server was not aligned + # the connection still remains) - if self.receive_data_task is not None: - # More info: https://stackoverflow.com/a/43810272/1113207 - self.receive_data_task.cancel() - with suppress(asyncio.CancelledError): - await self.receive_data_task + # If the timeout happens during a close already in progress, do nothing + if self.close_task is None: + await self._fail( + TransportServerError( + "No keep-alive message has been received within " + "the expected interval ('keep_alive_timeout' parameter)" + ), + clean_close=False, + ) except asyncio.CancelledError: # The client is probably closing, handle it properly @@ -412,21 +418,6 @@ async def _receive_data_loop(self) -> None: await self._handle_answer(answer_type, answer_id, execution_result) - except Exception as err: - if self._keep_alive_not_received: - # No keep-alive message in the appriopriate interval, close with error - # while trying to notify the server of a proper close (in case - # the keep-alive interval of the client or server was not aligned - # the connection still remains) - await self._fail( - TransportServerError( - "No keep-alive message has been received within " - "the expected interval ('keep_alive_timeout' parameter)" - ), - clean_close=True, - ) - else: - raise err finally: log.debug("Exiting _receive_data_loop()") @@ -589,7 +580,6 @@ async def connect(self) -> None: self.next_query_id = 1 self.close_exception = None - self._keep_alive_not_received = False self._wait_closed.clear() # Send the init message and wait for the ack from the server @@ -605,7 +595,7 @@ async def connect(self) -> None: # If specified, create a task to check liveness of the connection # through keep-alive messages - if self.keep_alive_timeout > 0: + if self.keep_alive_timeout is not None: self.check_keep_alive_task = asyncio.ensure_future( self._check_ws_liveness() ) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 7d80c8eb..4e75621a 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -7,6 +7,7 @@ from parse import search from gql import Client, gql +from gql.transport.exceptions import TransportServerError from .conftest import MS, WebSocketServerHelper @@ -378,6 +379,69 @@ async def test_websocket_subscription_with_keepalive( assert count == -1 +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_keepalive_with_timeout_ok( + event_loop, server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(500 * MS)) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_keepalive_with_timeout_nok( + event_loop, server, subscription_str +): + + from gql.transport.websockets import WebsocketsTransport + + path = "/graphql" + url = f"ws://{server.hostname}:{server.port}{path}" + sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + + client = Client(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async with client as session: + with pytest.raises(TransportServerError) as exc_info: + async for result in session.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert "No keep-alive message has been received" in str(exc_info.value) + + assert count == 9 + + @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) def test_websocket_subscription_sync(server, subscription_str): From bfd6292c06a429c2b61b93d2e9fc1e6cea016bbb Mon Sep 17 00:00:00 2001 From: Hanusz Leszek Date: Sat, 24 Apr 2021 21:51:45 +0200 Subject: [PATCH 5/5] Remove flaky assert --- tests/test_websocket_subscription.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 4e75621a..fcd176b5 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -439,8 +439,6 @@ async def test_websocket_subscription_with_keepalive_with_timeout_nok( assert "No keep-alive message has been received" in str(exc_info.value) - assert count == 9 - @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str])