1
1
import asyncio
2
2
import json
3
3
import logging
4
+ from contextlib import suppress
4
5
from ssl import SSLContext
5
6
from typing import Any , AsyncGenerator , Dict , Optional , Tuple , Union , cast
6
7
@@ -94,6 +95,7 @@ def __init__(
94
95
connect_timeout : int = 10 ,
95
96
close_timeout : int = 10 ,
96
97
ack_timeout : int = 10 ,
98
+ keep_alive_timeout : Optional [int ] = None ,
97
99
connect_args : Dict [str , Any ] = {},
98
100
) -> None :
99
101
"""Initialize the transport with the given parameters.
@@ -107,6 +109,8 @@ def __init__(
107
109
:param close_timeout: Timeout in seconds for the close.
108
110
:param ack_timeout: Timeout in seconds to wait for the connection_ack message
109
111
from the server.
112
+ :param keep_alive_timeout: Optional Timeout in seconds to receive
113
+ a sign of liveness from the server.
110
114
:param connect_args: Other parameters forwarded to websockets.connect
111
115
"""
112
116
self .url : str = url
@@ -117,6 +121,7 @@ def __init__(
117
121
self .connect_timeout : int = connect_timeout
118
122
self .close_timeout : int = close_timeout
119
123
self .ack_timeout : int = ack_timeout
124
+ self .keep_alive_timeout : Optional [int ] = keep_alive_timeout
120
125
121
126
self .connect_args = connect_args
122
127
@@ -125,6 +130,7 @@ def __init__(
125
130
self .listeners : Dict [int , ListenerQueue ] = {}
126
131
127
132
self .receive_data_task : Optional [asyncio .Future ] = None
133
+ self .check_keep_alive_task : Optional [asyncio .Future ] = None
128
134
self .close_task : Optional [asyncio .Future ] = None
129
135
130
136
# We need to set an event loop here if there is none
@@ -141,6 +147,10 @@ def __init__(
141
147
self ._no_more_listeners : asyncio .Event = asyncio .Event ()
142
148
self ._no_more_listeners .set ()
143
149
150
+ if self .keep_alive_timeout is not None :
151
+ self ._next_keep_alive_message : asyncio .Event = asyncio .Event ()
152
+ self ._next_keep_alive_message .set ()
153
+
144
154
self ._connecting : bool = False
145
155
146
156
self .close_exception : Optional [Exception ] = None
@@ -315,8 +325,9 @@ def _parse_answer(
315
325
)
316
326
317
327
elif answer_type == "ka" :
318
- # KeepAlive message
319
- pass
328
+ # Keep-alive message
329
+ if self .check_keep_alive_task is not None :
330
+ self ._next_keep_alive_message .set ()
320
331
elif answer_type == "connection_ack" :
321
332
pass
322
333
elif answer_type == "connection_error" :
@@ -332,8 +343,41 @@ def _parse_answer(
332
343
333
344
return answer_type , answer_id , execution_result
334
345
335
- async def _receive_data_loop (self ) -> None :
346
+ async def _check_ws_liveness (self ) -> None :
347
+ """Coroutine which will periodically check the liveness of the connection
348
+ through keep-alive messages
349
+ """
350
+
351
+ try :
352
+ while True :
353
+ await asyncio .wait_for (
354
+ self ._next_keep_alive_message .wait (), self .keep_alive_timeout
355
+ )
336
356
357
+ # Reset for the next iteration
358
+ self ._next_keep_alive_message .clear ()
359
+
360
+ except asyncio .TimeoutError :
361
+ # No keep-alive message in the appriopriate interval, close with error
362
+ # while trying to notify the server of a proper close (in case
363
+ # the keep-alive interval of the client or server was not aligned
364
+ # the connection still remains)
365
+
366
+ # If the timeout happens during a close already in progress, do nothing
367
+ if self .close_task is None :
368
+ await self ._fail (
369
+ TransportServerError (
370
+ "No keep-alive message has been received within "
371
+ "the expected interval ('keep_alive_timeout' parameter)"
372
+ ),
373
+ clean_close = False ,
374
+ )
375
+
376
+ except asyncio .CancelledError :
377
+ # The client is probably closing, handle it properly
378
+ pass
379
+
380
+ async def _receive_data_loop (self ) -> None :
337
381
try :
338
382
while True :
339
383
@@ -549,6 +593,13 @@ async def connect(self) -> None:
549
593
await self ._fail (e , clean_close = False )
550
594
raise e
551
595
596
+ # If specified, create a task to check liveness of the connection
597
+ # through keep-alive messages
598
+ if self .keep_alive_timeout is not None :
599
+ self .check_keep_alive_task = asyncio .ensure_future (
600
+ self ._check_ws_liveness ()
601
+ )
602
+
552
603
# Create a task to listen to the incoming websocket messages
553
604
self .receive_data_task = asyncio .ensure_future (self ._receive_data_loop ())
554
605
@@ -597,6 +648,13 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
597
648
# We should always have an active websocket connection here
598
649
assert self .websocket is not None
599
650
651
+ # Properly shut down liveness checker if enabled
652
+ if self .check_keep_alive_task is not None :
653
+ # More info: https://stackoverflow.com/a/43810272/1113207
654
+ self .check_keep_alive_task .cancel ()
655
+ with suppress (asyncio .CancelledError ):
656
+ await self .check_keep_alive_task
657
+
600
658
# Saving exception to raise it later if trying to use the transport
601
659
# after it has already closed.
602
660
self .close_exception = e
@@ -629,6 +687,7 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None:
629
687
630
688
self .websocket = None
631
689
self .close_task = None
690
+ self .check_keep_alive_task = None
632
691
633
692
self ._wait_closed .set ()
634
693
0 commit comments