Skip to content

Commit 35203e8

Browse files
authored
Handle keep-alive behavior to close the connection (#201)
1 parent 4528977 commit 35203e8

File tree

2 files changed

+124
-3
lines changed

2 files changed

+124
-3
lines changed

gql/transport/websockets.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import json
33
import logging
4+
from contextlib import suppress
45
from ssl import SSLContext
56
from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast
67

@@ -94,6 +95,7 @@ def __init__(
9495
connect_timeout: int = 10,
9596
close_timeout: int = 10,
9697
ack_timeout: int = 10,
98+
keep_alive_timeout: Optional[int] = None,
9799
connect_args: Dict[str, Any] = {},
98100
) -> None:
99101
"""Initialize the transport with the given parameters.
@@ -107,6 +109,8 @@ def __init__(
107109
:param close_timeout: Timeout in seconds for the close.
108110
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
109111
from the server.
112+
:param keep_alive_timeout: Optional Timeout in seconds to receive
113+
a sign of liveness from the server.
110114
:param connect_args: Other parameters forwarded to websockets.connect
111115
"""
112116
self.url: str = url
@@ -117,6 +121,7 @@ def __init__(
117121
self.connect_timeout: int = connect_timeout
118122
self.close_timeout: int = close_timeout
119123
self.ack_timeout: int = ack_timeout
124+
self.keep_alive_timeout: Optional[int] = keep_alive_timeout
120125

121126
self.connect_args = connect_args
122127

@@ -125,6 +130,7 @@ def __init__(
125130
self.listeners: Dict[int, ListenerQueue] = {}
126131

127132
self.receive_data_task: Optional[asyncio.Future] = None
133+
self.check_keep_alive_task: Optional[asyncio.Future] = None
128134
self.close_task: Optional[asyncio.Future] = None
129135

130136
# We need to set an event loop here if there is none
@@ -141,6 +147,10 @@ def __init__(
141147
self._no_more_listeners: asyncio.Event = asyncio.Event()
142148
self._no_more_listeners.set()
143149

150+
if self.keep_alive_timeout is not None:
151+
self._next_keep_alive_message: asyncio.Event = asyncio.Event()
152+
self._next_keep_alive_message.set()
153+
144154
self._connecting: bool = False
145155

146156
self.close_exception: Optional[Exception] = None
@@ -315,8 +325,9 @@ def _parse_answer(
315325
)
316326

317327
elif answer_type == "ka":
318-
# KeepAlive message
319-
pass
328+
# Keep-alive message
329+
if self.check_keep_alive_task is not None:
330+
self._next_keep_alive_message.set()
320331
elif answer_type == "connection_ack":
321332
pass
322333
elif answer_type == "connection_error":
@@ -332,8 +343,41 @@ def _parse_answer(
332343

333344
return answer_type, answer_id, execution_result
334345

335-
async def _receive_data_loop(self) -> None:
346+
async def _check_ws_liveness(self) -> None:
347+
"""Coroutine which will periodically check the liveness of the connection
348+
through keep-alive messages
349+
"""
350+
351+
try:
352+
while True:
353+
await asyncio.wait_for(
354+
self._next_keep_alive_message.wait(), self.keep_alive_timeout
355+
)
336356

357+
# Reset for the next iteration
358+
self._next_keep_alive_message.clear()
359+
360+
except asyncio.TimeoutError:
361+
# No keep-alive message in the appriopriate interval, close with error
362+
# while trying to notify the server of a proper close (in case
363+
# the keep-alive interval of the client or server was not aligned
364+
# the connection still remains)
365+
366+
# If the timeout happens during a close already in progress, do nothing
367+
if self.close_task is None:
368+
await self._fail(
369+
TransportServerError(
370+
"No keep-alive message has been received within "
371+
"the expected interval ('keep_alive_timeout' parameter)"
372+
),
373+
clean_close=False,
374+
)
375+
376+
except asyncio.CancelledError:
377+
# The client is probably closing, handle it properly
378+
pass
379+
380+
async def _receive_data_loop(self) -> None:
337381
try:
338382
while True:
339383

@@ -549,6 +593,13 @@ async def connect(self) -> None:
549593
await self._fail(e, clean_close=False)
550594
raise e
551595

596+
# If specified, create a task to check liveness of the connection
597+
# through keep-alive messages
598+
if self.keep_alive_timeout is not None:
599+
self.check_keep_alive_task = asyncio.ensure_future(
600+
self._check_ws_liveness()
601+
)
602+
552603
# Create a task to listen to the incoming websocket messages
553604
self.receive_data_task = asyncio.ensure_future(self._receive_data_loop())
554605

@@ -597,6 +648,13 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
597648
# We should always have an active websocket connection here
598649
assert self.websocket is not None
599650

651+
# Properly shut down liveness checker if enabled
652+
if self.check_keep_alive_task is not None:
653+
# More info: https://stackoverflow.com/a/43810272/1113207
654+
self.check_keep_alive_task.cancel()
655+
with suppress(asyncio.CancelledError):
656+
await self.check_keep_alive_task
657+
600658
# Saving exception to raise it later if trying to use the transport
601659
# after it has already closed.
602660
self.close_exception = e
@@ -629,6 +687,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
629687

630688
self.websocket = None
631689
self.close_task = None
690+
self.check_keep_alive_task = None
632691

633692
self._wait_closed.set()
634693

tests/test_websocket_subscription.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from parse import search
88

99
from gql import Client, gql
10+
from gql.transport.exceptions import TransportServerError
1011

1112
from .conftest import MS, WebSocketServerHelper
1213

@@ -378,6 +379,67 @@ async def test_websocket_subscription_with_keepalive(
378379
assert count == -1
379380

380381

382+
@pytest.mark.asyncio
383+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
384+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
385+
async def test_websocket_subscription_with_keepalive_with_timeout_ok(
386+
event_loop, server, subscription_str
387+
):
388+
389+
from gql.transport.websockets import WebsocketsTransport
390+
391+
path = "/graphql"
392+
url = f"ws://{server.hostname}:{server.port}{path}"
393+
sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(500 * MS))
394+
395+
client = Client(transport=sample_transport)
396+
397+
count = 10
398+
subscription = gql(subscription_str.format(count=count))
399+
400+
async with client as session:
401+
async for result in session.subscribe(subscription):
402+
403+
number = result["number"]
404+
print(f"Number received: {number}")
405+
406+
assert number == count
407+
count -= 1
408+
409+
assert count == -1
410+
411+
412+
@pytest.mark.asyncio
413+
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
414+
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
415+
async def test_websocket_subscription_with_keepalive_with_timeout_nok(
416+
event_loop, server, subscription_str
417+
):
418+
419+
from gql.transport.websockets import WebsocketsTransport
420+
421+
path = "/graphql"
422+
url = f"ws://{server.hostname}:{server.port}{path}"
423+
sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS))
424+
425+
client = Client(transport=sample_transport)
426+
427+
count = 10
428+
subscription = gql(subscription_str.format(count=count))
429+
430+
async with client as session:
431+
with pytest.raises(TransportServerError) as exc_info:
432+
async for result in session.subscribe(subscription):
433+
434+
number = result["number"]
435+
print(f"Number received: {number}")
436+
437+
assert number == count
438+
count -= 1
439+
440+
assert "No keep-alive message has been received" in str(exc_info.value)
441+
442+
381443
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
382444
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
383445
def test_websocket_subscription_sync(server, subscription_str):

0 commit comments

Comments
 (0)