From a2a35245520fd08aed18576e1a19f7a5e4c3af70 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 12 Mar 2025 21:28:23 +0100 Subject: [PATCH 1/5] Bump websockets to >=14.2 Using new websockets asyncio implementation --- gql/transport/common/adapters/websockets.py | 12 ++++--- gql/transport/common/base.py | 12 ++++--- setup.py | 2 +- tests/conftest.py | 39 +++++++-------------- tests/test_aiohttp_websocket_exceptions.py | 20 +++++------ tests/test_appsync_websockets.py | 2 +- tests/test_graphqlws_exceptions.py | 14 ++++---- tests/test_websocket_exceptions.py | 2 +- tests/test_websocket_query.py | 4 +-- 9 files changed, 48 insertions(+), 59 deletions(-) diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py index c2524fb4..bf8574ca 100644 --- a/gql/transport/common/adapters/websockets.py +++ b/gql/transport/common/adapters/websockets.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional, Union import websockets -from websockets.client import WebSocketClientProtocol +from websockets import ClientConnection from websockets.datastructures import Headers, HeadersLike from ...exceptions import TransportConnectionFailed, TransportProtocolError @@ -40,7 +40,7 @@ def __init__( self._headers: Optional[HeadersLike] = headers self.ssl = ssl - self.websocket: Optional[WebSocketClientProtocol] = None + self.websocket: Optional[ClientConnection] = None self._response_headers: Optional[Headers] = None async def connect(self) -> None: @@ -57,7 +57,7 @@ async def connect(self) -> None: # Set default arguments used in the websockets.connect call connect_args: Dict[str, Any] = { "ssl": ssl, - "extra_headers": self.headers, + "additional_headers": self.headers, } if self.subprotocols: @@ -68,11 +68,13 @@ async def connect(self) -> None: # Connection to the specified url try: - self.websocket = await websockets.client.connect(self.url, **connect_args) + self.websocket = await websockets.connect(self.url, **connect_args) except Exception as e: raise TransportConnectionFailed("Connect failed") from e - self._response_headers = self.websocket.response_headers + assert self.websocket.response is not None + + self._response_headers = self.websocket.response.headers async def send(self, message: str) -> None: """Send message to the WebSocket server. diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index a3d025c0..664353df 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -482,6 +482,10 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: # We should always have an active websocket connection here assert self._connected + # Saving exception to raise it later if trying to use the transport + # after it has already closed. + self.close_exception = e + # Properly shut down liveness checker if enabled if self.check_keep_alive_task is not None: # More info: https://stackoverflow.com/a/43810272/1113207 @@ -492,10 +496,6 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: # Calling the subclass close hook await self._close_hook() - # Saving exception to raise it later if trying to use the transport - # after it has already closed. - self.close_exception = e - if clean_close: log.debug("_close_coro: starting clean_close") try: @@ -503,7 +503,9 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: except Exception as exc: # pragma: no cover log.warning("Ignoring exception in _clean_close: " + repr(exc)) - log.debug("_close_coro: sending exception to listeners") + log.debug( + f"_close_coro: sending exception to {len(self.listeners)} listeners" + ) # Send an exception to all remaining listeners for query_id, listener in self.listeners.items(): diff --git a/setup.py b/setup.py index a36284b0..aed15440 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ ] install_websockets_requires = [ - "websockets>=10.1,<14", + "websockets>=14.2,<16", ] install_botocore_requires = [ diff --git a/tests/conftest.py b/tests/conftest.py index 70a050d5..c69551b0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -197,7 +197,7 @@ def __init__(self, with_ssl: bool = False): async def start(self, handler, extra_serve_args=None): - import websockets.server + import websockets print("Starting server") @@ -209,16 +209,21 @@ async def start(self, handler, extra_serve_args=None): extra_serve_args["ssl"] = ssl_context # Adding dummy response headers - extra_serve_args["extra_headers"] = {"dummy": "test1234"} + extra_headers = {"dummy": "test1234"} + + def process_response(connection, request, response): + response.headers.update(extra_headers) + return response # Start a server with a random open port - self.start_server = websockets.server.serve( - handler, "127.0.0.1", 0, **extra_serve_args + self.server = await websockets.serve( + handler, + "127.0.0.1", + 0, + process_response=process_response, + **extra_serve_args, ) - # Wait that the server is started - self.server = await self.start_server - # Get hostname and port hostname, port = self.server.sockets[0].getsockname()[:2] # type: ignore assert hostname == "127.0.0.1" @@ -603,24 +608,6 @@ async def graphqlws_server(request): subprotocol = "graphql-transport-ws" - from websockets.server import WebSocketServerProtocol - - class CustomSubprotocol(WebSocketServerProtocol): - def select_subprotocol(self, client_subprotocols, server_subprotocols): - print(f"Client subprotocols: {client_subprotocols!r}") - print(f"Server subprotocols: {server_subprotocols!r}") - - return subprotocol - - def process_subprotocol(self, headers, available_subprotocols): - # Overwriting available subprotocols - available_subprotocols = [subprotocol] - - print(f"headers: {headers!r}") - # print (f"Available subprotocols: {available_subprotocols!r}") - - return super().process_subprotocol(headers, available_subprotocols) - server_handler = get_server_handler(request) try: @@ -628,7 +615,7 @@ def process_subprotocol(self, headers, available_subprotocols): # Starting the server with the fixture param as the handler function await test_server.start( - server_handler, extra_serve_args={"create_protocol": CustomSubprotocol} + server_handler, extra_serve_args={"subprotocols": [subprotocol]} ) yield test_server diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 86c502a9..30776d61 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -118,10 +118,10 @@ async def test_aiohttp_websocket_server_does_not_send_ack(server, query_str): url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -261,10 +261,10 @@ async def test_aiohttp_websocket_server_does_not_ack(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -281,10 +281,10 @@ async def test_aiohttp_websocket_server_closing_directly(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) with pytest.raises(TransportConnectionFailed): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -323,10 +323,10 @@ async def test_aiohttp_websocket_server_sending_invalid_query_errors(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) # Invalid server message is ignored - async with Client(transport=sample_transport): + async with Client(transport=transport): await asyncio.sleep(2 * MS) @@ -342,9 +342,9 @@ async def test_aiohttp_websocket_non_regression_bug_105(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = AIOHTTPWebsocketsTransport(url=url) + transport = AIOHTTPWebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) # Create a coroutine which start the connection with the transport but does nothing async def client_connect(client): diff --git a/tests/test_appsync_websockets.py b/tests/test_appsync_websockets.py index 37cbe460..0be04034 100644 --- a/tests/test_appsync_websockets.py +++ b/tests/test_appsync_websockets.py @@ -139,7 +139,7 @@ async def realtime_appsync_server_template(ws): ) return - path = ws.path + path = ws.request.path print(f"path = {path}") diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index 2e3514d1..4cf8b89c 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -111,10 +111,10 @@ async def test_graphqlws_server_does_not_send_ack(graphqlws_server, query_str): url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" - sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + transport = WebsocketsTransport(url=url, ack_timeout=1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -212,10 +212,10 @@ async def test_graphqlws_server_does_not_ack(graphqlws_server): url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -231,10 +231,10 @@ async def test_graphqlws_server_closing_directly(graphqlws_server): url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportConnectionFailed): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -251,7 +251,7 @@ async def test_graphqlws_server_closing_after_ack(client_and_graphqlws_server): query = gql("query { hello }") - with pytest.raises(TransportConnectionFailed): + with pytest.raises(TransportClosed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 08058aea..31f2712b 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -296,7 +296,7 @@ async def test_websocket_server_closing_after_ack(client_and_server): query = gql("query { hello }") - with pytest.raises(TransportConnectionFailed): + with pytest.raises(TransportClosed): await session.execute(query) await session.transport.wait_closed() diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 99ff7334..732d686f 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -112,9 +112,7 @@ async def test_websocket_using_ssl_connection(ws_ssl_server): async with Client(transport=transport) as session: - assert isinstance( - transport.adapter.websocket, websockets.client.WebSocketClientProtocol - ) + assert isinstance(transport.adapter.websocket, websockets.ClientConnection) query1 = gql(query1_str) From a0f75a5bb602f54b7fefa9f3f79b9bd393f2a091 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 12 Mar 2025 22:04:13 +0100 Subject: [PATCH 2/5] Rename sample_transport to transport once and for all --- tests/test_aiohttp_online.py | 14 ++++----- tests/test_appsync_auth.py | 34 +++++++++++----------- tests/test_appsync_http.py | 4 +-- tests/test_async_client_validation.py | 16 +++++----- tests/test_http_async_sync.py | 18 ++++++------ tests/test_httpx_online.py | 14 ++++----- tests/test_phoenix_channel_exceptions.py | 32 ++++++++------------ tests/test_phoenix_channel_subscription.py | 12 ++++---- tests/test_websocket_exceptions.py | 20 ++++++------- tests/test_websocket_subscription.py | 30 +++++++++---------- 10 files changed, 90 insertions(+), 104 deletions(-) diff --git a/tests/test_aiohttp_online.py b/tests/test_aiohttp_online.py index 7cacd921..a4f2480c 100644 --- a/tests/test_aiohttp_online.py +++ b/tests/test_aiohttp_online.py @@ -19,10 +19,10 @@ async def test_aiohttp_simple_query(): url = "https://countries.trevorblades.com/graphql" # Get transport - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -60,11 +60,9 @@ async def test_aiohttp_invalid_query(): from gql.transport.aiohttp import AIOHTTPTransport - sample_transport = AIOHTTPTransport( - url="https://countries.trevorblades.com/graphql" - ) + transport = AIOHTTPTransport(url="https://countries.trevorblades.com/graphql") - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -89,12 +87,12 @@ async def test_aiohttp_two_queries_in_parallel_using_two_tasks(): from gql.transport.aiohttp import AIOHTTPTransport - sample_transport = AIOHTTPTransport( + transport = AIOHTTPTransport( url="https://countries.trevorblades.com/graphql", ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql( """ diff --git a/tests/test_appsync_auth.py b/tests/test_appsync_auth.py index 8abb3410..94eaed2b 100644 --- a/tests/test_appsync_auth.py +++ b/tests/test_appsync_auth.py @@ -9,15 +9,15 @@ def test_appsync_init_with_minimal_args(fake_session_factory): from gql.transport.appsync_auth import AppSyncIAMAuthentication from gql.transport.appsync_websockets import AppSyncWebsocketsTransport - sample_transport = AppSyncWebsocketsTransport( + transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory() ) - assert isinstance(sample_transport.auth, AppSyncIAMAuthentication) - assert sample_transport.connect_timeout == 10 - assert sample_transport.close_timeout == 10 - assert sample_transport.ack_timeout == 10 - assert sample_transport.ssl is False - assert sample_transport.connect_args == {} + assert isinstance(transport.auth, AppSyncIAMAuthentication) + assert transport.connect_timeout == 10 + assert transport.close_timeout == 10 + assert transport.ack_timeout == 10 + assert transport.ssl is False + assert transport.connect_args == {} @pytest.mark.botocore @@ -27,11 +27,11 @@ def test_appsync_init_with_no_credentials(caplog, fake_session_factory): from gql.transport.appsync_websockets import AppSyncWebsocketsTransport with pytest.raises(botocore.exceptions.NoCredentialsError): - sample_transport = AppSyncWebsocketsTransport( + transport = AppSyncWebsocketsTransport( url=mock_transport_url, session=fake_session_factory(credentials=None), ) - assert sample_transport.auth is None + assert transport.auth is None expected_error = "Credentials not found" @@ -46,8 +46,8 @@ def test_appsync_init_with_jwt_auth(): from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncJWTAuthentication(host=mock_transport_host, jwt="some-jwt") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) - assert sample_transport.auth is auth + transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert transport.auth is auth assert auth.get_headers() == { "host": mock_transport_host, @@ -61,8 +61,8 @@ def test_appsync_init_with_apikey_auth(): from gql.transport.appsync_websockets import AppSyncWebsocketsTransport auth = AppSyncApiKeyAuthentication(host=mock_transport_host, api_key="some-api-key") - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) - assert sample_transport.auth is auth + transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert transport.auth is auth assert auth.get_headers() == { "host": mock_transport_host, @@ -95,8 +95,8 @@ def test_appsync_init_with_iam_auth_with_creds(fake_credentials_factory): credentials=fake_credentials_factory(), region_name="us-east-1", ) - sample_transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) - assert sample_transport.auth is auth + transport = AppSyncWebsocketsTransport(url=mock_transport_url, auth=auth) + assert transport.auth is auth @pytest.mark.botocore @@ -153,7 +153,7 @@ def test_munge_url(fake_signer_factory, fake_request_factory): signer=fake_signer_factory(), request_creator=fake_request_factory, ) - sample_transport = AppSyncWebsocketsTransport(url=test_url, auth=auth) + transport = AppSyncWebsocketsTransport(url=test_url, auth=auth) header_string = ( "eyJGYWtlQXV0aG9yaXphdGlvbiI6ImEiLCJGYWtlVGltZSI6InRvZGF5" @@ -164,7 +164,7 @@ def test_munge_url(fake_signer_factory, fake_request_factory): "wss://appsync-realtime-api.aws.example.org/" f"some-other-params?header={header_string}&payload=e30=" ) - assert sample_transport.url == expected_url + assert transport.url == expected_url @pytest.mark.botocore diff --git a/tests/test_appsync_http.py b/tests/test_appsync_http.py index 536b2fe9..168924bc 100644 --- a/tests/test_appsync_http.py +++ b/tests/test_appsync_http.py @@ -49,9 +49,9 @@ async def handler(request): region_name="us-east-1", ) - sample_transport = AIOHTTPTransport(url=url, auth=auth) + transport = AIOHTTPTransport(url=url, auth=auth) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ diff --git a/tests/test_async_client_validation.py b/tests/test_async_client_validation.py index be214134..c256e5dd 100644 --- a/tests/test_async_client_validation.py +++ b/tests/test_async_client_validation.py @@ -91,9 +91,9 @@ async def test_async_client_validation(server, subscription_str, client_params): url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport, **client_params) + client = Client(transport=transport, **client_params) async with client as session: @@ -138,9 +138,9 @@ async def test_async_client_validation_invalid_query( url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport, **client_params) + client = Client(transport=transport, **client_params) async with client as session: @@ -171,10 +171,10 @@ async def test_async_client_validation_different_schemas_parameters_forbidden( url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(AssertionError): - async with Client(transport=sample_transport, **client_params): + async with Client(transport=transport, **client_params): pass @@ -261,10 +261,10 @@ async def test_async_client_validation_fetch_schema_from_server_with_client_argu url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) async with Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=True, ) as session: diff --git a/tests/test_http_async_sync.py b/tests/test_http_async_sync.py index 45efd7f5..61dc1809 100644 --- a/tests/test_http_async_sync.py +++ b/tests/test_http_async_sync.py @@ -15,11 +15,11 @@ async def test_async_client_async_transport(fetch_schema_from_transport): url = "https://countries.trevorblades.com/graphql" # Get async transport - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) # Instantiate client async with Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ) as session: @@ -58,17 +58,17 @@ async def test_async_client_sync_transport(fetch_schema_from_transport): url = "http://countries.trevorblades.com/graphql" # Get sync transport - sample_transport = RequestsHTTPTransport(url=url, use_json=True) + transport = RequestsHTTPTransport(url=url, use_json=True) # Impossible to use a sync transport asynchronously with pytest.raises(AssertionError): async with Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ): pass - sample_transport.close() + transport.close() @pytest.mark.aiohttp @@ -82,11 +82,11 @@ def test_sync_client_async_transport(fetch_schema_from_transport): url = "https://countries.trevorblades.com/graphql" # Get async transport - sample_transport = AIOHTTPTransport(url=url) + transport = AIOHTTPTransport(url=url) # Instanciate client client = Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ) @@ -125,11 +125,11 @@ def test_sync_client_sync_transport(fetch_schema_from_transport): url = "https://countries.trevorblades.com/graphql" # Get sync transport - sample_transport = RequestsHTTPTransport(url=url, use_json=True) + transport = RequestsHTTPTransport(url=url, use_json=True) # Instanciate client client = Client( - transport=sample_transport, + transport=transport, fetch_schema_from_transport=fetch_schema_from_transport, ) diff --git a/tests/test_httpx_online.py b/tests/test_httpx_online.py index 3b08fa18..c6e84368 100644 --- a/tests/test_httpx_online.py +++ b/tests/test_httpx_online.py @@ -19,10 +19,10 @@ async def test_httpx_simple_query(): url = "https://countries.trevorblades.com/graphql" # Get transport - sample_transport = HTTPXAsyncTransport(url=url) + transport = HTTPXAsyncTransport(url=url) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -60,11 +60,9 @@ async def test_httpx_invalid_query(): from gql.transport.httpx import HTTPXAsyncTransport - sample_transport = HTTPXAsyncTransport( - url="https://countries.trevorblades.com/graphql" - ) + transport = HTTPXAsyncTransport(url="https://countries.trevorblades.com/graphql") - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query = gql( """ @@ -89,12 +87,12 @@ async def test_httpx_two_queries_in_parallel_using_two_tasks(): from gql.transport.httpx import HTTPXAsyncTransport - sample_transport = HTTPXAsyncTransport( + transport = HTTPXAsyncTransport( url="https://countries.trevorblades.com/graphql", ) # Instanciate client - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: query1 = gql( """ diff --git a/tests/test_phoenix_channel_exceptions.py b/tests/test_phoenix_channel_exceptions.py index 09c129b3..b7f11dcb 100644 --- a/tests/test_phoenix_channel_exceptions.py +++ b/tests/test_phoenix_channel_exceptions.py @@ -167,13 +167,11 @@ async def test_phoenix_channel_query_protocol_error(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await session.execute(query) @@ -197,13 +195,11 @@ async def test_phoenix_channel_query_error(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportQueryError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await session.execute(query) @@ -414,13 +410,11 @@ async def test_phoenix_channel_subscription_protocol_error(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for _result in session.subscribe(query): await asyncio.sleep(10 * MS) break @@ -444,13 +438,11 @@ async def test_phoenix_channel_server_error(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( - channel_name="test_channel", url=url - ) + transport = PhoenixChannelWebsocketsTransport(channel_name="test_channel", url=url) query = gql(query_str) with pytest.raises(TransportServerError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: await session.execute(query) @@ -476,12 +468,12 @@ async def test_phoenix_channel_unsubscribe_error(server, query_str): # Reduce close_timeout. These tests will wait for an unsubscribe # reply that will never come... - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", url=url, close_timeout=1 ) query = gql(query_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for _result in session.subscribe(query): break @@ -504,13 +496,13 @@ async def test_phoenix_channel_unsubscribe_error_forcing(server, query_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name="test_channel", url=url, close_timeout=1 ) query = gql(query_str) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for _result in session.subscribe(query): await session.transport._send_stop_message(2) await asyncio.sleep(10 * MS) diff --git a/tests/test_phoenix_channel_subscription.py b/tests/test_phoenix_channel_subscription.py index 25ca0f0b..ecda9c38 100644 --- a/tests/test_phoenix_channel_subscription.py +++ b/tests/test_phoenix_channel_subscription.py @@ -191,14 +191,14 @@ async def test_phoenix_channel_subscription(server, subscription_str, end_count) path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name=test_channel, url=url, close_timeout=5 ) count = 10 subscription = gql(subscription_str.format(count=count)) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: generator = session.subscribe(subscription) async for result in generator: @@ -240,14 +240,14 @@ async def test_phoenix_channel_subscription_no_break(server, subscription_str): async def testing_stopping_without_break(): - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name=test_channel, url=url, close_timeout=(5000 * MS) ) count = 10 subscription = gql(subscription_str.format(count=count)) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: async for result in session.subscribe(subscription): number = result["countdown"]["number"] print(f"Number received: {number}") @@ -372,12 +372,12 @@ async def test_phoenix_channel_heartbeat(server, subscription_str): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = PhoenixChannelWebsocketsTransport( + transport = PhoenixChannelWebsocketsTransport( channel_name=test_channel, url=url, heartbeat_interval=0.1 ) subscription = gql(heartbeat_subscription_str) - async with Client(transport=sample_transport) as session: + async with Client(transport=transport) as session: i = 0 generator = session.subscribe(subscription) async for result in generator: diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 31f2712b..0a3b37fd 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -118,10 +118,10 @@ async def test_websocket_server_does_not_send_ack(server, query_str): url = f"ws://{server.hostname}:{server.port}/graphql" - sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + transport = WebsocketsTransport(url=url, ack_timeout=1) with pytest.raises(asyncio.TimeoutError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -257,10 +257,10 @@ async def test_websocket_server_does_not_ack(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -276,10 +276,10 @@ async def test_websocket_server_closing_directly(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) with pytest.raises(TransportConnectionFailed): - async with Client(transport=sample_transport): + async with Client(transport=transport): pass @@ -323,10 +323,10 @@ async def test_websocket_server_sending_invalid_query_errors(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) # Invalid server message is ignored - async with Client(transport=sample_transport): + async with Client(transport=transport): await asyncio.sleep(2 * MS) @@ -342,9 +342,9 @@ async def test_websocket_non_regression_bug_105(server): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) # Create a coroutine which start the connection with the transport but does nothing async def client_connect(client): diff --git a/tests/test_websocket_subscription.py b/tests/test_websocket_subscription.py index 89acd635..8d2fd152 100644 --- a/tests/test_websocket_subscription.py +++ b/tests/test_websocket_subscription.py @@ -420,11 +420,9 @@ async def test_websocket_subscription_with_keepalive_with_timeout_ok( if PyPy: keep_alive_timeout = 200 * MS - sample_transport = WebsocketsTransport( - url=url, keep_alive_timeout=keep_alive_timeout - ) + transport = WebsocketsTransport(url=url, keep_alive_timeout=keep_alive_timeout) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -452,9 +450,9 @@ async def test_websocket_subscription_with_keepalive_with_timeout_nok( path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) + transport = WebsocketsTransport(url=url, keep_alive_timeout=(1 * MS)) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -480,9 +478,9 @@ def test_websocket_subscription_sync(server, subscription_str): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -506,9 +504,9 @@ def test_websocket_subscription_sync_user_exception(server, subscription_str): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -537,9 +535,9 @@ def test_websocket_subscription_sync_break(server, subscription_str): url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -578,9 +576,9 @@ def test_websocket_subscription_sync_graceful_shutdown(server, subscription_str) url = f"ws://{server.hostname}:{server.port}/graphql" print(f"url = {url}") - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) @@ -630,9 +628,9 @@ async def test_websocket_subscription_running_in_thread( def test_code(): path = "/graphql" url = f"ws://{server.hostname}:{server.port}{path}" - sample_transport = WebsocketsTransport(url=url) + transport = WebsocketsTransport(url=url) - client = Client(transport=sample_transport) + client = Client(transport=transport) count = 10 subscription = gql(subscription_str.format(count=count)) From ed18fd4f85ed4761e9abb907834cda37ced237a0 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 12 Mar 2025 22:19:19 +0100 Subject: [PATCH 3/5] faster tests --- tests/test_aiohttp_websocket_exceptions.py | 2 +- tests/test_aiohttp_websocket_graphqlws_exceptions.py | 2 +- tests/test_aiohttp_websocket_query.py | 2 +- tests/test_graphqlws_exceptions.py | 2 +- tests/test_websocket_exceptions.py | 2 +- tests/test_websocket_query.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index 30776d61..f5dd8964 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -118,7 +118,7 @@ async def test_aiohttp_websocket_server_does_not_send_ack(server, query_str): url = f"ws://{server.hostname}:{server.port}/graphql" - transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): async with Client(transport=transport): diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index a7548cce..789bdc01 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -117,7 +117,7 @@ async def test_aiohttp_websocket_graphqlws_server_does_not_send_ack( url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" - transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=1) + transport = AIOHTTPWebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): async with Client(transport=transport): diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index cf91d148..2d13b70c 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -319,7 +319,7 @@ async def test_aiohttp_websocket_server_closing_after_first_query( await session.execute(query) # Then we do other things - await asyncio.sleep(1000 * MS) + await asyncio.sleep(10 * MS) # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index 4cf8b89c..bed455eb 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -111,7 +111,7 @@ async def test_graphqlws_server_does_not_send_ack(graphqlws_server, query_str): url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}/graphql" - transport = WebsocketsTransport(url=url, ack_timeout=1) + transport = WebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): async with Client(transport=transport): diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 0a3b37fd..845e2fc7 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -118,7 +118,7 @@ async def test_websocket_server_does_not_send_ack(server, query_str): url = f"ws://{server.hostname}:{server.port}/graphql" - transport = WebsocketsTransport(url=url, ack_timeout=1) + transport = WebsocketsTransport(url=url, ack_timeout=0.1) with pytest.raises(asyncio.TimeoutError): async with Client(transport=transport): diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index 732d686f..be255850 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -288,7 +288,7 @@ async def test_websocket_server_closing_after_first_query(client_and_server, que await session.execute(query) # Then we do other things - await asyncio.sleep(100 * MS) + await asyncio.sleep(10 * MS) # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception From a1eae602fcaf66e1ac9984f72167327fac82ff31 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Fri, 14 Mar 2025 23:41:23 +0100 Subject: [PATCH 4/5] Using TransportConnectionFailed instead of TransportClosed Now the reconnecting session will reconnect as soon as it detects that the connection failed. Better log messages --- gql/client.py | 14 +-- gql/transport/common/adapters/aiohttp.py | 11 ++- gql/transport/common/adapters/websockets.py | 10 ++- gql/transport/common/base.py | 29 +++--- tests/test_aiohttp_websocket_exceptions.py | 9 ++ ..._aiohttp_websocket_graphqlws_exceptions.py | 7 +- ...iohttp_websocket_graphqlws_subscription.py | 89 ++++++++++--------- tests/test_aiohttp_websocket_query.py | 3 +- tests/test_graphqlws_exceptions.py | 27 +++++- tests/test_graphqlws_subscription.py | 89 ++++++++++--------- tests/test_websocket_exceptions.py | 27 +++++- tests/test_websocket_query.py | 7 +- 12 files changed, 204 insertions(+), 118 deletions(-) diff --git a/gql/client.py b/gql/client.py index faf3230a..99cd6e46 100644 --- a/gql/client.py +++ b/gql/client.py @@ -35,7 +35,7 @@ from .graphql_request import GraphQLRequest from .transport.async_transport import AsyncTransport -from .transport.exceptions import TransportClosed, TransportQueryError +from .transport.exceptions import TransportConnectionFailed, TransportQueryError from .transport.local_schema import LocalSchemaTransport from .transport.transport import Transport from .utilities import build_client_schema, get_introspection_query_ast @@ -1730,6 +1730,7 @@ async def _connection_loop(self): # Then wait for the reconnect event self._reconnect_request_event.clear() await self._reconnect_request_event.wait() + await self.transport.close() async def start_connecting_task(self): """Start the task responsible to restart the connection @@ -1758,7 +1759,7 @@ async def _execute_once( **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent method _execute but requesting a - reconnection if we receive a TransportClosed exception. + reconnection if we receive a TransportConnectionFailed exception. """ try: @@ -1770,7 +1771,7 @@ async def _execute_once( parse_result=parse_result, **kwargs, ) - except TransportClosed: + except TransportConnectionFailed: self._reconnect_request_event.set() raise @@ -1786,7 +1787,8 @@ async def _execute( **kwargs: Any, ) -> ExecutionResult: """Same Coroutine as parent, but with optional retries - and requesting a reconnection if we receive a TransportClosed exception. + and requesting a reconnection if we receive a + TransportConnectionFailed exception. """ return await self._execute_with_retries( @@ -1808,7 +1810,7 @@ async def _subscribe( **kwargs: Any, ) -> AsyncGenerator[ExecutionResult, None]: """Same Async generator as parent method _subscribe but requesting a - reconnection if we receive a TransportClosed exception. + reconnection if we receive a TransportConnectionFailed exception. """ inner_generator: AsyncGenerator[ExecutionResult, None] = super()._subscribe( @@ -1824,7 +1826,7 @@ async def _subscribe( async for result in inner_generator: yield result - except TransportClosed: + except TransportConnectionFailed: self._reconnect_request_event.set() raise diff --git a/gql/transport/common/adapters/aiohttp.py b/gql/transport/common/adapters/aiohttp.py index d5b16a82..d2e1a346 100644 --- a/gql/transport/common/adapters/aiohttp.py +++ b/gql/transport/common/adapters/aiohttp.py @@ -178,12 +178,14 @@ async def send(self, message: str) -> None: TransportConnectionFailed: If connection closed """ if self.websocket is None: - raise TransportConnectionFailed("Connection is already closed") + raise TransportConnectionFailed("WebSocket connection is already closed") try: await self.websocket.send_str(message) - except ConnectionResetError as e: - raise TransportConnectionFailed("Connection was closed") from e + except Exception as e: + raise TransportConnectionFailed( + f"Error trying to send data: {type(e).__name__}" + ) from e async def receive(self) -> str: """Receive message from the WebSocket server. @@ -200,6 +202,9 @@ async def receive(self) -> str: raise TransportConnectionFailed("Connection is already closed") while True: + # Should not raise any exception: + # https://docs.aiohttp.org/en/stable/_modules/aiohttp/client_ws.html + # #ClientWebSocketResponse.receive ws_message = await self.websocket.receive() # Ignore low-level ping and pong received diff --git a/gql/transport/common/adapters/websockets.py b/gql/transport/common/adapters/websockets.py index bf8574ca..6d248e71 100644 --- a/gql/transport/common/adapters/websockets.py +++ b/gql/transport/common/adapters/websockets.py @@ -86,12 +86,14 @@ async def send(self, message: str) -> None: TransportConnectionFailed: If connection closed """ if self.websocket is None: - raise TransportConnectionFailed("Connection is already closed") + raise TransportConnectionFailed("WebSocket connection is already closed") try: await self.websocket.send(message) except Exception as e: - raise TransportConnectionFailed("Connection was closed") from e + raise TransportConnectionFailed( + f"Error trying to send data: {type(e).__name__}" + ) from e async def receive(self) -> str: """Receive message from the WebSocket server. @@ -111,7 +113,9 @@ async def receive(self) -> str: try: data = await self.websocket.recv() except Exception as e: - raise TransportConnectionFailed("Connection was closed") from e + raise TransportConnectionFailed( + f"Error trying to receive data: {type(e).__name__}" + ) from e # websocket.recv() can return either str or bytes # In our case, we should receive only str here diff --git a/gql/transport/common/base.py b/gql/transport/common/base.py index 664353df..cae8f488 100644 --- a/gql/transport/common/base.py +++ b/gql/transport/common/base.py @@ -127,11 +127,13 @@ async def _send(self, message: str) -> None: """Send the provided message to the adapter connection and log the message""" if not self._connected: - raise TransportClosed( - "Transport is not connected" - ) from self.close_exception + if isinstance(self.close_exception, TransportConnectionFailed): + raise self.close_exception + else: + raise TransportConnectionFailed() from self.close_exception try: + # Can raise TransportConnectionFailed await self.adapter.send(message) log.info(">>> %s", message) except TransportConnectionFailed as e: @@ -143,7 +145,7 @@ async def _receive(self) -> str: # It is possible that the connection has been already closed in another task if not self._connected: - raise TransportClosed("Transport is already closed") + raise TransportConnectionFailed() from self.close_exception # Wait for the next frame. # Can raise TransportConnectionFailed or TransportProtocolError @@ -214,8 +216,6 @@ async def _receive_data_loop(self) -> None: except (TransportConnectionFailed, TransportProtocolError) as e: await self._fail(e, clean_close=False) break - except TransportClosed: - break # Parse the answer try: @@ -503,9 +503,10 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: except Exception as exc: # pragma: no cover log.warning("Ignoring exception in _clean_close: " + repr(exc)) - log.debug( - f"_close_coro: sending exception to {len(self.listeners)} listeners" - ) + if log.isEnabledFor(logging.DEBUG): + log.debug( + f"_close_coro: sending exception to {len(self.listeners)} listeners" + ) # Send an exception to all remaining listeners for query_id, listener in self.listeners.items(): @@ -532,7 +533,15 @@ async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: log.debug("_close_coro: exiting") async def _fail(self, e: Exception, clean_close: bool = True) -> None: - log.debug("_fail: starting with exception: " + repr(e)) + if log.isEnabledFor(logging.DEBUG): + import inspect + + current_frame = inspect.currentframe() + assert current_frame is not None + caller_frame = current_frame.f_back + assert caller_frame is not None + caller_name = inspect.getframeinfo(caller_frame).function + log.debug(f"_fail from {caller_name}: " + repr(e)) if self.close_task is None: diff --git a/tests/test_aiohttp_websocket_exceptions.py b/tests/test_aiohttp_websocket_exceptions.py index f5dd8964..2fb6722c 100644 --- a/tests/test_aiohttp_websocket_exceptions.py +++ b/tests/test_aiohttp_websocket_exceptions.py @@ -301,6 +301,15 @@ async def test_aiohttp_websocket_server_closing_after_ack(aiohttp_client_and_ser query = gql("query { hello }") + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed): + await session.execute(query) + + await session.transport.wait_closed() + + print("\n Trying to execute second query.\n") + with pytest.raises(TransportConnectionFailed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_exceptions.py b/tests/test_aiohttp_websocket_graphqlws_exceptions.py index 789bdc01..52bc27a4 100644 --- a/tests/test_aiohttp_websocket_graphqlws_exceptions.py +++ b/tests/test_aiohttp_websocket_graphqlws_exceptions.py @@ -5,7 +5,6 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, TransportConnectionFailed, TransportProtocolError, TransportQueryError, @@ -264,10 +263,14 @@ async def test_aiohttp_websocket_graphqlws_server_closing_after_ack( query = gql("query { hello }") + print("\n Trying to execute first query.\n") + with pytest.raises(TransportConnectionFailed): await session.execute(query) await session.transport.wait_closed() - with pytest.raises(TransportClosed): + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed): await session.execute(query) diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index e8832217..2c417efe 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -11,7 +11,7 @@ from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError -from .conftest import MS, PyPy, WebSocketServerHelper +from .conftest import MS, WebSocketServerHelper # Marking all tests in this file with the aiohttp AND websockets marker pytestmark = [pytest.mark.aiohttp, pytest.mark.websockets] @@ -821,7 +821,6 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( ): from gql.transport.aiohttp_websockets import AIOHTTPWebsocketsTransport - from gql.transport.exceptions import TransportClosed path = "/graphql" url = f"ws://{graphqlws_server.hostname}:{graphqlws_server.port}{path}" @@ -839,56 +838,64 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( reconnecting=True, retry_connect=False, retry_execute=False ) - # First we make a subscription which will cause a disconnect in the backend - # (count=8) - try: - print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") - async for result in session.subscribe(subscription_with_disconnect): - pass - except TransportConnectionFailed: - pass - - await asyncio.sleep(50 * MS) - - # Then with the same session handle, we make a subscription or an execute - # which will detect that the transport is closed so that the client could - # try to reconnect - generator = None + # First we make a query or subscription which will cause a disconnect + # in the backend (count=8) try: if execute_instead_of_subscribe: - print("\nEXECUTION_2\n") - await session.execute(subscription) + print("\nEXECUTION_1\n") + await session.execute(subscription_with_disconnect) else: - print("\nSUBSCRIPTION_2\n") - generator = session.subscribe(subscription) - async for result in generator: + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): pass - except (TransportClosed, TransportConnectionFailed): - if generator: - await generator.aclose() + except TransportConnectionFailed: pass - timeout = 50 - - if PyPy: - timeout = 500 + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break - await asyncio.sleep(timeout * MS) + assert transport._connected is False - # And finally with the same session handle, we make a subscription - # which works correctly - print("\nSUBSCRIPTION_3\n") - generator = session.subscribe(subscription) - async for result in generator: + # Wait for reconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if transport._connected: + print(f"\nConnected again in {i+1} MS") + break - number = result["number"] - print(f"Number received: {number}") + assert transport._connected is True + + # Then after the reconnection, we make a query or a subscription + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + result = await session.execute(subscription) + assert result["number"] == 10 + else: + print("\nSUBSCRIPTION_2\n") + generator = session.subscribe(subscription) + async for result in generator: + number = result["number"] + print(f"Number received: {number}") - assert number == count - count -= 1 + assert number == count + count -= 1 - await generator.aclose() + await generator.aclose() - assert count == -1 + assert count == -1 + # Close the reconnecting session await client.close_async() + + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break + + assert transport._connected is False diff --git a/tests/test_aiohttp_websocket_query.py b/tests/test_aiohttp_websocket_query.py index 2d13b70c..a3087d78 100644 --- a/tests/test_aiohttp_websocket_query.py +++ b/tests/test_aiohttp_websocket_query.py @@ -8,7 +8,6 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, - TransportClosed, TransportConnectionFailed, TransportQueryError, TransportServerError, @@ -323,7 +322,7 @@ async def test_aiohttp_websocket_server_closing_after_first_query( # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) diff --git a/tests/test_graphqlws_exceptions.py b/tests/test_graphqlws_exceptions.py index bed455eb..6f30c8da 100644 --- a/tests/test_graphqlws_exceptions.py +++ b/tests/test_graphqlws_exceptions.py @@ -5,7 +5,6 @@ from gql import Client, gql from gql.transport.exceptions import ( - TransportClosed, TransportConnectionFailed, TransportProtocolError, TransportQueryError, @@ -251,10 +250,32 @@ async def test_graphqlws_server_closing_after_ack(client_and_graphqlws_server): query = gql("query { hello }") - with pytest.raises(TransportClosed): + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed) as exc1: await session.execute(query) + exc1_cause = exc1.value.__cause__ + exc1_cause_str = f"{type(exc1_cause).__name__}:{exc1_cause!s}" + + print(f"\n First query Exception cause: {exc1_cause_str}\n") + + assert ( + exc1_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) + await session.transport.wait_closed() - with pytest.raises(TransportClosed): + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed) as exc2: await session.execute(query) + + exc2_cause = exc2.value.__cause__ + exc2_cause_str = f"{type(exc2_cause).__name__}:{exc2_cause!s}" + + print(f" Second query Exception cause: {exc2_cause_str}\n") + + assert ( + exc2_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 94028d26..2fdcb41a 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -11,7 +11,7 @@ from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError -from .conftest import MS, PyPy, WebSocketServerHelper +from .conftest import MS, WebSocketServerHelper # Marking all tests in this file with the websockets marker pytestmark = pytest.mark.websockets @@ -814,7 +814,6 @@ async def test_graphqlws_subscription_reconnecting_session( graphqlws_server, subscription_str, execute_instead_of_subscribe ): - from gql.transport.exceptions import TransportClosed from gql.transport.websockets import WebsocketsTransport path = "/graphql" @@ -833,56 +832,64 @@ async def test_graphqlws_subscription_reconnecting_session( reconnecting=True, retry_connect=False, retry_execute=False ) - # First we make a subscription which will cause a disconnect in the backend - # (count=8) - try: - print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") - async for result in session.subscribe(subscription_with_disconnect): - pass - except TransportConnectionFailed: - pass - - await asyncio.sleep(50 * MS) - - # Then with the same session handle, we make a subscription or an execute - # which will detect that the transport is closed so that the client could - # try to reconnect - generator = None + # First we make a query or subscription which will cause a disconnect + # in the backend (count=8) try: if execute_instead_of_subscribe: - print("\nEXECUTION_2\n") - await session.execute(subscription) + print("\nEXECUTION_1\n") + await session.execute(subscription_with_disconnect) else: - print("\nSUBSCRIPTION_2\n") - generator = session.subscribe(subscription) - async for result in generator: + print("\nSUBSCRIPTION_1_WITH_DISCONNECT\n") + async for result in session.subscribe(subscription_with_disconnect): pass - except (TransportClosed, TransportConnectionFailed): - if generator: - await generator.aclose() + except TransportConnectionFailed: pass - timeout = 50 - - if PyPy: - timeout = 500 + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break - await asyncio.sleep(timeout * MS) + assert transport._connected is False - # And finally with the same session handle, we make a subscription - # which works correctly - print("\nSUBSCRIPTION_3\n") - generator = session.subscribe(subscription) - async for result in generator: + # Wait for reconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if transport._connected: + print(f"\nConnected again in {i+1} MS") + break - number = result["number"] - print(f"Number received: {number}") + assert transport._connected is True + + # Then after the reconnection, we make a query or a subscription + if execute_instead_of_subscribe: + print("\nEXECUTION_2\n") + result = await session.execute(subscription) + assert result["number"] == 10 + else: + print("\nSUBSCRIPTION_2\n") + generator = session.subscribe(subscription) + async for result in generator: + number = result["number"] + print(f"Number received: {number}") - assert number == count - count -= 1 + assert number == count + count -= 1 - await generator.aclose() + await generator.aclose() - assert count == -1 + assert count == -1 + # Close the reconnecting session await client.close_async() + + # Wait for disconnect + for i in range(200): + await asyncio.sleep(1 * MS) + if not transport._connected: + print(f"\nDisconnected in {i+1} MS") + break + + assert transport._connected is False diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index 845e2fc7..b6169468 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -8,7 +8,6 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, - TransportClosed, TransportConnectionFailed, TransportProtocolError, TransportQueryError, @@ -296,14 +295,36 @@ async def test_websocket_server_closing_after_ack(client_and_server): query = gql("query { hello }") - with pytest.raises(TransportClosed): + print("\n Trying to execute first query.\n") + + with pytest.raises(TransportConnectionFailed) as exc1: await session.execute(query) + exc1_cause = exc1.value.__cause__ + exc1_cause_str = f"{type(exc1_cause).__name__}:{exc1_cause!s}" + + print(f"\n First query Exception cause: {exc1_cause_str}\n") + + assert ( + exc1_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) + await session.transport.wait_closed() - with pytest.raises(TransportClosed): + print("\n Trying to execute second query.\n") + + with pytest.raises(TransportConnectionFailed) as exc2: await session.execute(query) + exc2_cause = exc2.value.__cause__ + exc2_cause_str = f"{type(exc2_cause).__name__}:{exc2_cause!s}" + + print(f" Second query Exception cause: {exc2_cause_str}\n") + + assert ( + exc2_cause_str == "ConnectionClosedOK:received 1000 (OK); then sent 1000 (OK)" + ) + async def server_sending_invalid_query_errors(ws): await WebSocketServerHelper.send_connection_ack(ws) diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index be255850..979bb99b 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -8,7 +8,6 @@ from gql import Client, gql from gql.transport.exceptions import ( TransportAlreadyConnected, - TransportClosed, TransportConnectionFailed, TransportQueryError, TransportServerError, @@ -292,7 +291,7 @@ async def test_websocket_server_closing_after_first_query(client_and_server, que # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query) @@ -661,7 +660,7 @@ async def test_websocket_adapter_connection_closed(server): # Close adapter connection manually (should not be done) await transport.adapter.close() - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query1) # Check client is disconnect here @@ -689,5 +688,5 @@ async def test_websocket_transport_closed_in_receive(server): # await transport.adapter.close() transport._connected = False - with pytest.raises(TransportClosed): + with pytest.raises(TransportConnectionFailed): await session.execute(query1) From cc3c3dde91f406f1d276843d2aa7848cd4d82218 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Fri, 14 Mar 2025 23:57:58 +0100 Subject: [PATCH 5/5] Try to fix test for fast computers --- tests/test_aiohttp_websocket_graphqlws_subscription.py | 2 -- tests/test_graphqlws_subscription.py | 2 -- 2 files changed, 4 deletions(-) diff --git a/tests/test_aiohttp_websocket_graphqlws_subscription.py b/tests/test_aiohttp_websocket_graphqlws_subscription.py index 2c417efe..7c000d01 100644 --- a/tests/test_aiohttp_websocket_graphqlws_subscription.py +++ b/tests/test_aiohttp_websocket_graphqlws_subscription.py @@ -858,8 +858,6 @@ async def test_aiohttp_websocket_graphqlws_subscription_reconnecting_session( print(f"\nDisconnected in {i+1} MS") break - assert transport._connected is False - # Wait for reconnect for i in range(200): await asyncio.sleep(1 * MS) diff --git a/tests/test_graphqlws_subscription.py b/tests/test_graphqlws_subscription.py index 2fdcb41a..b4c6a17b 100644 --- a/tests/test_graphqlws_subscription.py +++ b/tests/test_graphqlws_subscription.py @@ -852,8 +852,6 @@ async def test_graphqlws_subscription_reconnecting_session( print(f"\nDisconnected in {i+1} MS") break - assert transport._connected is False - # Wait for reconnect for i in range(200): await asyncio.sleep(1 * MS)