diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index b4fd3dec..6004442d 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -175,6 +175,23 @@ async def _receive(self) -> str: return answer + async def _wait_ack(self) -> None: + """Wait for the connection_ack message. Keep alive messages are ignored + """ + + while True: + init_answer = await self._receive() + + answer_type, answer_id, execution_result = self._parse_answer(init_answer) + + if answer_type == "connection_ack": + return + + if answer_type != "ka": + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) + async def _send_init_message_and_wait_ack(self) -> None: """Send init message to the provided websocket and wait for the connection ACK. @@ -188,14 +205,7 @@ async def _send_init_message_and_wait_ack(self) -> None: await self._send(init_message) # Wait for the connection_ack message or raise a TimeoutError - init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout) - - answer_type, answer_id, execution_result = self._parse_answer(init_answer) - - if answer_type != "connection_ack": - raise TransportProtocolError( - "Websocket server did not return a connection ack" - ) + await asyncio.wait_for(self._wait_ack(), self.ack_timeout) async def _send_stop_message(self, query_id: int) -> None: """Send stop message to the provided websocket connection and query_id. diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 9c1f7a17..aed27189 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -241,7 +241,7 @@ async def test_websocket_transport_protocol_errors(event_loop, client_and_server async def server_without_ack(ws, path): # Sending something else than an ack - await WebSocketServer.send_keepalive(ws) + await WebSocketServer.send_complete(ws, 1) await ws.wait_closed() diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 4d235d95..13014c28 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -472,3 +472,43 @@ async def test_websocket_add_extra_parameters_to_connect(event_loop, server): async with Client(transport=sample_transport) as session: await session.execute(query) + + +async def server_sending_keep_alive_before_connection_ack(ws, path): + await WebSocketServer.send_keepalive(ws) + await WebSocketServer.send_keepalive(ws) + await WebSocketServer.send_keepalive(ws) + await WebSocketServer.send_keepalive(ws) + await WebSocketServer.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(query1_server_answer.format(query_id=1)) + await WebSocketServer.send_complete(ws, 1) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_sending_keep_alive_before_connection_ack], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_websocket_non_regression_bug_108( + event_loop, client_and_server, query_str +): + + # This test will check that we now ignore keepalive message + # arriving before the connection_ack + # See bug #108 + + session, server = client_and_server + + query = gql(query_str) + + result = await session.execute(query) + + print("Client received:", result) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF"