Skip to content

Handle keep-alive behavior to close the connection #201

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 62 additions & 3 deletions gql/transport/websockets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import json
import logging
from contextlib import suppress
from ssl import SSLContext
from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast

Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
connect_timeout: int = 10,
close_timeout: int = 10,
ack_timeout: int = 10,
keep_alive_timeout: Optional[int] = None,
connect_args: Dict[str, Any] = {},
) -> None:
"""Initialize the transport with the given parameters.
Expand All @@ -107,6 +109,8 @@ def __init__(
:param close_timeout: Timeout in seconds for the close.
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
from the server.
:param keep_alive_timeout: Optional Timeout in seconds to receive
a sign of liveness from the server.
:param connect_args: Other parameters forwarded to websockets.connect
"""
self.url: str = url
Expand All @@ -117,6 +121,7 @@ def __init__(
self.connect_timeout: int = connect_timeout
self.close_timeout: int = close_timeout
self.ack_timeout: int = ack_timeout
self.keep_alive_timeout: Optional[int] = keep_alive_timeout

self.connect_args = connect_args

Expand All @@ -125,6 +130,7 @@ def __init__(
self.listeners: Dict[int, ListenerQueue] = {}

self.receive_data_task: Optional[asyncio.Future] = None
self.check_keep_alive_task: Optional[asyncio.Future] = None
self.close_task: Optional[asyncio.Future] = None

# We need to set an event loop here if there is none
Expand All @@ -141,6 +147,10 @@ def __init__(
self._no_more_listeners: asyncio.Event = asyncio.Event()
self._no_more_listeners.set()

if self.keep_alive_timeout is not None:
self._next_keep_alive_message: asyncio.Event = asyncio.Event()
self._next_keep_alive_message.set()

self._connecting: bool = False

self.close_exception: Optional[Exception] = None
Expand Down Expand Up @@ -315,8 +325,9 @@ def _parse_answer(
)

elif answer_type == "ka":
# KeepAlive message
pass
# Keep-alive message
if self.check_keep_alive_task is not None:
self._next_keep_alive_message.set()
elif answer_type == "connection_ack":
pass
elif answer_type == "connection_error":
Expand All @@ -332,8 +343,41 @@ def _parse_answer(

return answer_type, answer_id, execution_result

async def _receive_data_loop(self) -> None:
async def _check_ws_liveness(self) -> None:
"""Coroutine which will periodically check the liveness of the connection
through keep-alive messages
"""

try:
while True:
await asyncio.wait_for(
self._next_keep_alive_message.wait(), self.keep_alive_timeout
)

# Reset for the next iteration
self._next_keep_alive_message.clear()

except asyncio.TimeoutError:
# No keep-alive message in the appriopriate interval, close with error
# while trying to notify the server of a proper close (in case
# the keep-alive interval of the client or server was not aligned
# the connection still remains)

# If the timeout happens during a close already in progress, do nothing
if self.close_task is None:
await self._fail(
TransportServerError(
"No keep-alive message has been received within "
"the expected interval ('keep_alive_timeout' parameter)"
),
clean_close=False,
)

except asyncio.CancelledError:
# The client is probably closing, handle it properly
pass

async def _receive_data_loop(self) -> None:
try:
while True:

Expand Down Expand Up @@ -549,6 +593,13 @@ async def connect(self) -> None:
await self._fail(e, clean_close=False)
raise e

# If specified, create a task to check liveness of the connection
# through keep-alive messages
if self.keep_alive_timeout is not None:
self.check_keep_alive_task = asyncio.ensure_future(
self._check_ws_liveness()
)

# Create a task to listen to the incoming websocket messages
self.receive_data_task = asyncio.ensure_future(self._receive_data_loop())

Expand Down Expand Up @@ -597,6 +648,13 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
# We should always have an active websocket connection here
assert self.websocket is not None

# Properly shut down liveness checker if enabled
if self.check_keep_alive_task is not None:
# More info: https://stackoverflow.com/a/43810272/1113207
self.check_keep_alive_task.cancel()
with suppress(asyncio.CancelledError):
await self.check_keep_alive_task

# Saving exception to raise it later if trying to use the transport
# after it has already closed.
self.close_exception = e
Expand Down Expand Up @@ -629,6 +687,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:

self.websocket = None
self.close_task = None
self.check_keep_alive_task = None

self._wait_closed.set()

Expand Down
62 changes: 62 additions & 0 deletions tests/test_websocket_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from parse import search

from gql import Client, gql
from gql.transport.exceptions import TransportServerError

from .conftest import MS, WebSocketServerHelper

Expand Down Expand Up @@ -378,6 +379,67 @@ async def test_websocket_subscription_with_keepalive(
assert count == -1


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
async def test_websocket_subscription_with_keepalive_with_timeout_ok(
event_loop, server, subscription_str
):

from gql.transport.websockets import WebsocketsTransport

path = "/graphql"
url = f"ws://{server.hostname}:{server.port}{path}"
sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(500 * MS))

client = Client(transport=sample_transport)

count = 10
subscription = gql(subscription_str.format(count=count))

async with client as session:
async for result in session.subscribe(subscription):

number = result["number"]
print(f"Number received: {number}")

assert number == count
count -= 1

assert count == -1


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
async def test_websocket_subscription_with_keepalive_with_timeout_nok(
event_loop, server, subscription_str
):

from gql.transport.websockets import WebsocketsTransport

path = "/graphql"
url = f"ws://{server.hostname}:{server.port}{path}"
sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS))

client = Client(transport=sample_transport)

count = 10
subscription = gql(subscription_str.format(count=count))

async with client as session:
with pytest.raises(TransportServerError) as exc_info:
async for result in session.subscribe(subscription):

number = result["number"]
print(f"Number received: {number}")

assert number == count
count -= 1

assert "No keep-alive message has been received" in str(exc_info.value)


@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
def test_websocket_subscription_sync(server, subscription_str):
Expand Down