From 1be0d6bd2715f02e34233e83e9eb5b5d9ffa0557 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 25 Mar 2020 13:39:11 +0100 Subject: [PATCH 01/46] Implementation of a websockets transport using asyncio Necessitate Python 3.6 Allows a connection to an Apollo GraphQL websocket endpoint Supports fetching schema from transport using the introspection query Supports queries, mutations AND subscriptions Supports connection using a client ssl certificate Only one request per connection for now --- README.md | 98 +++++++++++ gql/client.py | 22 +++ gql/transport/websockets.py | 220 ++++++++++++++++++++++++ setup.py | 3 +- tests_py36/test_websockets_transport.py | 53 ++++++ 5 files changed, 395 insertions(+), 1 deletion(-) create mode 100644 gql/transport/websockets.py create mode 100644 tests_py36/test_websockets_transport.py diff --git a/README.md b/README.md index 4085d239..4e268308 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,104 @@ query = gql(''' client.execute(query) ``` +## Websockets transport with asyncio + +It is possible to use the websockets transport using the `asyncio` library. +Python3.6 is required for this transport. + +The websockets transport uses the apollo protocol described here: + +[Apollo websockets transport protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md) + +This transport allows to do subscriptions ! + +For the moment, only one request is done for each websocket connection + +```python +import logging +logging.basicConfig(level=logging.INFO) + +from gql import gql, Client +from gql.transport.websockets import WebsocketsTransport +import asyncio + +async def main() + + sample_transport = WebsocketsTransport( + url='wss://countries.trevorblades.com/graphql', + ssl=True, + headers={'Authorization': 'token'} + ) + + client = Client(transport=sample_transport) + + # Fetch schema (optional) + await client.fetch_schema() + + # Execute single query + query = gql(''' + query getContinents { + continents { + code + name + } + } + ''') + result = await client.execute_async(query) + print (f'result data = {result.data}, errors = {result.errors}') + + # Request subscription + subscription = gql(''' + subscription { + somethingChanged: { + id + } + } + ''') + async for result in client.subscribe(subscription): + print (f'result.data = {result.data}') + +asyncio.run(main()) +``` + +### Websockets SSL + +If you need to connect to an ssl encrypted endpoint: + +* use _wss_ instead of _ws_ in the url of the transport +* set the parameter ssl to True + +```python +import ssl + +sample_transport = WebsocketsTransport( + url='wss://SERVER_URL:SERVER_PORT/graphql', + headers={'Authorization': 'token'}, + ssl=True +) +``` + +If you have a self-signed ssl certificate, you need to provide an ssl_context with the server public certificate: + +```python +import pathlib +import ssl + +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +localhost_pem = pathlib.Path(__file__).with_name("YOUR_SERVER_PUBLIC_CERTIFICATE.pem") +ssl_context.load_verify_locations(localhost_pem) + +sample_transport = WebsocketsTransport( + url='wss://SERVER_URL:SERVER_PORT/graphql', + ssl=ssl_context +) +``` + +If you have also need to have a client ssl certificate, add: + +```python +ssl_context.load_cert_chain(certfile='YOUR_CLIENT_CERTIFICATE.pem', keyfile='YOUR_CLIENT_CERTIFICATE_KEY.key') +``` ## Contributing See [CONTRIBUTING.md](CONTRIBUTING.md) diff --git a/gql/client.py b/gql/client.py index 8c583259..e5304a31 100644 --- a/gql/client.py +++ b/gql/client.py @@ -34,6 +34,10 @@ def __init__( assert ( not schema ), "Cant fetch the schema from transport if is already provided" + if hasattr(transport, 'USING_ASYNCIO'): + assert ( + not transport.USING_ASYNCIO + ), "With an asyncio transport, please use 'await client.fetch_schema()' instead of fetch_schema_from_transport=True" introspection = transport.execute(parse(introspection_query)).data if introspection: assert not schema, "Cant provide introspection and schema at the same time" @@ -93,3 +97,21 @@ def _get_result(self, document, *args, **kwargs): retries_count += 1 raise RetryError(retries_count, last_exception) + + async def subscribe(self, document, *args, **kwargs): + if self.schema: + self.validate(document) + + async for result in self.transport.subscribe(document, *args, **kwargs): + yield result + + async def execute_async(self, document, *args, **kwargs): + if self.schema: + self.validate(document) + + return await self.transport.single_query(document, *args, **kwargs) + + async def fetch_schema(self): + execution_result = await self.transport.single_query(parse(introspection_query)) + self.introspection = execution_result.data + self.schema = build_client_schema(self.introspection) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py new file mode 100644 index 00000000..4966ff6e --- /dev/null +++ b/gql/transport/websockets.py @@ -0,0 +1,220 @@ +from __future__ import absolute_import + +from typing import Any, Dict, Union + +import websockets +import asyncio +import json +import logging + +from graphql.execution import ExecutionResult +from graphql.language.ast import Document +from graphql.language.printer import print_ast + +from gql.transport import Transport + +log = logging.getLogger(__name__) + +class WebsocketsTransport(Transport): + """Transport to execute GraphQL queries on remote servers with a websocket connection. + + This transport use asyncio + The transport uses the websockets library in order to send requests on a websocket connection. + + See README.md for Usage + """ + + USING_ASYNCIO = True + + def __init__( + self, + url, + headers=None, + ssl=False, + ): + """Initialize the transport with the given request parameters. + + :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. + :param headers: Dict of HTTP Headers. + :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + """ + self.url = url + self.ssl = ssl + self.headers = headers + self.next_query_id = 1 + + async def _send_message(self, websocket, message): + """Send the provided message to the websocket connection and log the message + """ + + await websocket.send(message) + log.info('>>> %s', message) + + async def _wait_answer(self, websocket): + """Wait the next message from the websocket connection and log the answer + """ + + answer = await websocket.recv() + log.info('<<< %s', answer) + + return answer + + async def _send_init_message_and_wait_ack(self, websocket): + """Send an init message to the provided websocket then wait for the connection ack + + If the answer is not a connection_ack message, we will return an Exception + """ + + await self._send_message(websocket, '{"type":"connection_init","payload":{}}') + + init_answer = await self._wait_answer(websocket) + + answer_type, answer_id, execution_result = self._parse_answer(init_answer) + + if answer_type != 'connection_ack': + raise Exception('Websocket server did not return a connection ack') + + async def _send_stop_message(self, websocket, query_id): + """Send a stop message to the provided websocket connection for the provided query_id + + The server should afterwards return a 'complete' message + """ + + stop_message = json.dumps({ + 'id': str(query_id), + 'type': 'stop' + }) + + await self._send_message(websocket, stop_message) + + async def _send_query(self, websocket, document, variable_values, operation_name): + """Send a query to the provided websocket connection + + We use an incremented id to reference the query + + Returns the used id for this query + """ + + query_id = self.next_query_id + self.next_query_id += 1 + + query_str = json.dumps({ + 'id': str(query_id), + 'type': 'start', + 'payload': { + 'variables': variable_values or {}, + #'extensions': {}, + 'operationName': operation_name or '', + 'query': print_ast(document) + } + }) + + await self._send_message(websocket, query_str) + + return query_id + + def _parse_answer(self, answer): + """Parse the answer received from the server + + Returns a list consisting of: + - the answer_type (between: 'connection_ack', 'ka', 'connection_error', 'data', 'error', 'complete') + - the answer id (Integer) if received or None + - an execution Result if the answer_type is 'data' or None + """ + + answer_type = None + answer_id = None + execution_result = None + + try: + json_answer = json.loads(answer) + + if not isinstance(json_answer, dict): + raise ValueError + + answer_type = json_answer.get('type') + + if answer_type in ['data', 'error', 'complete']: + answer_id = int(json_answer.get('id')) + + if answer_type == 'data': + result = json_answer.get('payload') + + if 'errors' not in result and 'data' not in result: + raise ValueError + + execution_result = ExecutionResult(errors=result.get('errors'), data=result.get('data')) + + elif answer_type == 'error': + raise Exception('Websocket server error') + + elif answer_type == 'ka': + # KeepAlive message + pass + elif answer_type == 'connection_ack': + pass + elif answer_type == 'connection_error': + raise Exception('Websocket Connection Error') + else: + raise ValueError + + except ValueError: + raise Exception('Websocket server did not return a GraphQL result') + + return (answer_type, answer_id, execution_result) + + + async def subscribe(self, document, variable_values=None, operation_name=None): + """Send a query and receive the results using a python async generator + + The query can be a graphql query, mutation or subscription + + The results are sent as an ExecutionResult object + """ + + # Connection to the specified url + async with websockets.connect( + self.url, + ssl=self.ssl, + extra_headers=self.headers, + subprotocols=['graphql-ws'] + ) as websocket: + + # Send the init message and wait for the ack from the server + await self._send_init_message_and_wait_ack(websocket) + + # Send the query and receive the id + query_id = await self._send_query(websocket, document, variable_values, operation_name) + + # Loop over the received answers + while True: + + # Wait the next answer from the websocket server + answer = await self._wait_answer(websocket) + + # Parse the answer + answer_type, answer_id, execution_result = self._parse_answer(answer) + + # If the received answer id corresponds to the query id, + # Then we will yield the results back as an ExecutionResult object + # If we receive a 'complete' answer from the server, + # Then we will end this async generator output and disconnect from the server + if answer_id == query_id: + if execution_result is not None: + yield execution_result + + elif answer_type == 'complete': + return + + async def single_query(self, document, variable_values=None, operation_name=None): + """Send a query but close the connection as soon as we have the first answer + + The result is sent as an ExecutionResult object + """ + async for result in self.subscribe(document, variable_values, operation_name): + return result + + def execute(self, document): + raise NotImplementedError( + "You should use the async function 'execute_async' for this transport" + ) diff --git a/setup.py b/setup.py index 2cad42fd..380b4fac 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,8 @@ 'six>=1.10.0', 'graphql-core>=2,<3', 'promise>=2.0,<3', - 'requests>=2.12,<3' + 'requests>=2.12,<3', + 'websockets' ] tests_require = [ diff --git a/tests_py36/test_websockets_transport.py b/tests_py36/test_websockets_transport.py new file mode 100644 index 00000000..fb4e7830 --- /dev/null +++ b/tests_py36/test_websockets_transport.py @@ -0,0 +1,53 @@ +import logging + +logging.basicConfig(level=logging.INFO) + +from gql import gql, Client +from gql.transport.websockets import WebsocketsTransport +from graphql.execution import ExecutionResult +from typing import Dict + +import asyncio +import pytest + +@pytest.mark.asyncio +async def test_websocket_query(): + + # Get Websockets transport + sample_transport = WebsocketsTransport( + url='wss://countries.trevorblades.com/graphql', + ssl=True + ) + + # Instanciate client + client = Client(transport=sample_transport) + + query = gql(''' + query getContinents { + continents { + code + name + } + } + ''') + + # Fetch schema + await client.fetch_schema() + + # Execute query + result = await client.execute_async(query) + + # Verify result + assert isinstance(result, ExecutionResult) + assert result.errors == None + + assert isinstance(result.data, Dict) + + continents = result.data['continents'] + + africa = continents[0] + + assert africa['code'] == 'AF' + + # Sleep 1 second to allow the connections to end + await asyncio.sleep(1) From 159a35cb3cdf6e334634461c143b37f67ba9d195 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 25 Mar 2020 15:57:36 +0100 Subject: [PATCH 02/46] Add missing 'pathlib' dependency --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 380b4fac..215d31e0 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ 'graphql-core>=2,<3', 'promise>=2.0,<3', 'requests>=2.12,<3', + 'pathlib', 'websockets' ] From d65c697363c203fcea139a507a0928769ac12951 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 26 Mar 2020 22:49:06 +0100 Subject: [PATCH 03/46] Now using AsyncClient and AsyncTransport classes It is now possible to execute multiple queries in parallel using the same websocket connection --- README.md | 98 +++++++++---- gql/__init__.py | 4 +- gql/client.py | 23 +++- gql/transport/__init__.py | 39 ++++++ gql/transport/websockets.py | 175 +++++++++++++++++------- tests_py36/test_websockets_transport.py | 109 +++++++++++---- 6 files changed, 339 insertions(+), 109 deletions(-) diff --git a/README.md b/README.md index 4e268308..955f4595 100644 --- a/README.md +++ b/README.md @@ -122,19 +122,17 @@ The websockets transport uses the apollo protocol described here: [Apollo websockets transport protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md) -This transport allows to do subscriptions ! - -For the moment, only one request is done for each websocket connection +This transport allows to do multiple queries, mutations and subscriptions on the same websocket connection ```python import logging logging.basicConfig(level=logging.INFO) -from gql import gql, Client +from gql import gql, AsyncClient from gql.transport.websockets import WebsocketsTransport import asyncio -async def main() +async def main(): sample_transport = WebsocketsTransport( url='wss://countries.trevorblades.com/graphql', @@ -142,33 +140,33 @@ async def main() headers={'Authorization': 'token'} ) - client = Client(transport=sample_transport) - - # Fetch schema (optional) - await client.fetch_schema() - - # Execute single query - query = gql(''' - query getContinents { - continents { - code - name - } - } - ''') - result = await client.execute_async(query) - print (f'result data = {result.data}, errors = {result.errors}') + async with AsyncClient(transport=sample_transport) as client: - # Request subscription - subscription = gql(''' - subscription { - somethingChanged: { - id + # Fetch schema (optional) + await client.fetch_schema() + + # Execute single query + query = gql(''' + query getContinents { + continents { + code + name + } + } + ''') + result = await client.execute(query) + print (f'result data = {result.data}, errors = {result.errors}') + + # Request subscription + subscription = gql(''' + subscription { + somethingChanged { + id + } } - } - ''') - async for result in client.subscribe(subscription): - print (f'result.data = {result.data}') + ''') + async for result in client.subscribe(subscription): + print (f'result.data = {result.data}') asyncio.run(main()) ``` @@ -212,6 +210,46 @@ If you have also need to have a client ssl certificate, add: ssl_context.load_cert_chain(certfile='YOUR_CLIENT_CERTIFICATE.pem', keyfile='YOUR_CLIENT_CERTIFICATE_KEY.key') ``` +### Websockets advanced usage + +It is possible to send multiple GraphQL queries (query, mutation or subscription) in parallel, +on the same websocket connection, using asyncio tasks + +```python + +async def execute_query1(): + result = await client.execute(query1) + print (f'result data = {result.data}, errors = {result.errors}') + +async def execute_query2(): + result = await client.execute(query2) + print (f'result data = {result.data}, errors = {result.errors}') + +async def execute_subscription1(): + async for result in client.subscribe(subscription1): + print (f'result data = {result.data}, errors = {result.errors}') + +async def execute_subscription2(): + async for result in client.subscribe(subscription2): + print (f'result data = {result.data}, errors = {result.errors}') + +task1 = asyncio.create_task(execute_query1()) +task2 = asyncio.create_task(execute_query2()) +task3 = asyncio.create_task(execute_subscription1()) +task4 = asyncio.create_task(execute_subscription2()) + +await task1 +await task2 +await task3 +await task4 +``` + +Subscriptions tasks can be stopped at any time by running + +```python +task.cancel() +``` + ## Contributing See [CONTRIBUTING.md](CONTRIBUTING.md) diff --git a/gql/__init__.py b/gql/__init__.py index cd6c0088..f3c8f920 100644 --- a/gql/__init__.py +++ b/gql/__init__.py @@ -1,4 +1,4 @@ from .gql import gql -from .client import Client +from .client import Client, AsyncClient -__all__ = ["gql", "Client"] +__all__ = ["gql", "Client", "AsyncClient"] diff --git a/gql/client.py b/gql/client.py index e5304a31..cd51e644 100644 --- a/gql/client.py +++ b/gql/client.py @@ -4,6 +4,7 @@ from graphql.validation import validate from .transport.local_schema import LocalSchemaTransport +from gql.transport import AsyncTransport log = logging.getLogger(__name__) @@ -34,10 +35,9 @@ def __init__( assert ( not schema ), "Cant fetch the schema from transport if is already provided" - if hasattr(transport, 'USING_ASYNCIO'): - assert ( - not transport.USING_ASYNCIO - ), "With an asyncio transport, please use 'await client.fetch_schema()' instead of fetch_schema_from_transport=True" + assert ( + not isinstance(transport, AsyncTransport) + ), "With an asyncio transport, please use 'await client.fetch_schema()' instead of fetch_schema_from_transport=True" introspection = transport.execute(parse(introspection_query)).data if introspection: assert not schema, "Cant provide introspection and schema at the same time" @@ -98,6 +98,8 @@ def _get_result(self, document, *args, **kwargs): raise RetryError(retries_count, last_exception) +class AsyncClient(Client): + async def subscribe(self, document, *args, **kwargs): if self.schema: self.validate(document) @@ -105,13 +107,20 @@ async def subscribe(self, document, *args, **kwargs): async for result in self.transport.subscribe(document, *args, **kwargs): yield result - async def execute_async(self, document, *args, **kwargs): + async def execute(self, document, *args, **kwargs): if self.schema: self.validate(document) - return await self.transport.single_query(document, *args, **kwargs) + return await self.transport.execute(document, *args, **kwargs) async def fetch_schema(self): - execution_result = await self.transport.single_query(parse(introspection_query)) + execution_result = await self.transport.execute(parse(introspection_query)) self.introspection = execution_result.data self.schema = build_client_schema(self.introspection) + + async def __aenter__(self): + await self.transport.connect() + return self + + async def __aexit__(self, *args): + await self.transport.close() diff --git a/gql/transport/__init__.py b/gql/transport/__init__.py index 89c636a0..6193ddd2 100644 --- a/gql/transport/__init__.py +++ b/gql/transport/__init__.py @@ -20,3 +20,42 @@ def execute(self, document): raise NotImplementedError( "Any Transport subclass must implement execute method" ) + +@six.add_metaclass(abc.ABCMeta) +class AsyncTransport(Transport): + + @abc.abstractmethod + async def connect(self): + """Coroutine used to create a connection to the specified address + """ + raise NotImplementedError( + "Any AsyncTransport subclass must implement execute method" + ) + + @abc.abstractmethod + async def close(self): + """Coroutine used to Close an established connection + """ + raise NotImplementedError( + "Any AsyncTransport subclass must implement execute method" + ) + + @abc.abstractmethod + async def execute(self, document, variable_values=None, operation_name=None): + """Execute the provided document AST for either a remote or local GraphQL Schema. + """ + raise NotImplementedError( + "Any AsyncTransport subclass must implement execute method" + ) + + @abc.abstractmethod + async def subscribe(self, document, variable_values=None, operation_name=None): + """Send a query and receive the results using an async generator + + The query can be a graphql query, mutation or subscription + + The results are sent as an ExecutionResult object + """ + raise NotImplementedError( + "Any AsyncTransport subclass must implement execute method" + ) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 4966ff6e..0cb24e71 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -11,11 +11,11 @@ from graphql.language.ast import Document from graphql.language.printer import print_ast -from gql.transport import Transport +from gql.transport import AsyncTransport log = logging.getLogger(__name__) -class WebsocketsTransport(Transport): +class WebsocketsTransport(AsyncTransport): """Transport to execute GraphQL queries on remote servers with a websocket connection. This transport use asyncio @@ -24,8 +24,6 @@ class WebsocketsTransport(Transport): See README.md for Usage """ - USING_ASYNCIO = True - def __init__( self, url, @@ -41,40 +39,46 @@ def __init__( self.url = url self.ssl = ssl self.headers = headers + + self.websocket = None self.next_query_id = 1 + self.listeners = {} - async def _send_message(self, websocket, message): + async def _send(self, message): """Send the provided message to the websocket connection and log the message """ - await websocket.send(message) + if not self.websocket: + raise Exception('Transport is not connected') + + await self.websocket.send(message) log.info('>>> %s', message) - async def _wait_answer(self, websocket): + async def _receive(self): """Wait the next message from the websocket connection and log the answer """ - answer = await websocket.recv() + answer = await self.websocket.recv() log.info('<<< %s', answer) return answer - async def _send_init_message_and_wait_ack(self, websocket): + async def _send_init_message_and_wait_ack(self): """Send an init message to the provided websocket then wait for the connection ack If the answer is not a connection_ack message, we will return an Exception """ - await self._send_message(websocket, '{"type":"connection_init","payload":{}}') + await self._send('{"type":"connection_init","payload":{}}') - init_answer = await self._wait_answer(websocket) + init_answer = await self._receive() answer_type, answer_id, execution_result = self._parse_answer(init_answer) if answer_type != 'connection_ack': raise Exception('Websocket server did not return a connection ack') - async def _send_stop_message(self, websocket, query_id): + async def _send_stop_message(self, query_id): """Send a stop message to the provided websocket connection for the provided query_id The server should afterwards return a 'complete' message @@ -85,9 +89,21 @@ async def _send_stop_message(self, websocket, query_id): 'type': 'stop' }) - await self._send_message(websocket, stop_message) + await self._send(stop_message) + + async def _send_connection_terminate_message(self): + """Send a connection_terminate message to the provided websocket connection + + This message indicate that the connection will disconnect + """ + + connection_terminate_message = json.dumps({ + 'type': 'connection_terminate' + }) + + await self._send(connection_terminate_message) - async def _send_query(self, websocket, document, variable_values, operation_name): + async def _send_query(self, document, variable_values, operation_name): """Send a query to the provided websocket connection We use an incremented id to reference the query @@ -109,7 +125,7 @@ async def _send_query(self, websocket, document, variable_values, operation_name } }) - await self._send_message(websocket, query_str) + await self._send(query_str) return query_id @@ -163,6 +179,25 @@ def _parse_answer(self, answer): return (answer_type, answer_id, execution_result) + async def _answer_loop(self): + + while True: + + # Wait the next answer from the websocket server + answer = await self._receive() + + # Parse the answer + answer_type, answer_id, execution_result = self._parse_answer(answer) + + # Continue if no listener exists for this id + if answer_id not in self.listeners: + continue + + # Get the related queue + queue = self.listeners[answer_id] + + # Put the answer in the queue + await queue.put((answer_type, execution_result)) async def subscribe(self, document, variable_values=None, operation_name=None): """Send a query and receive the results using a python async generator @@ -172,49 +207,97 @@ async def subscribe(self, document, variable_values=None, operation_name=None): The results are sent as an ExecutionResult object """ - # Connection to the specified url - async with websockets.connect( - self.url, - ssl=self.ssl, - extra_headers=self.headers, - subprotocols=['graphql-ws'] - ) as websocket: + # Send the query and receive the id + query_id = await self._send_query(document, variable_values, operation_name) - # Send the init message and wait for the ack from the server - await self._send_init_message_and_wait_ack(websocket) - - # Send the query and receive the id - query_id = await self._send_query(websocket, document, variable_values, operation_name) + # Create a queue to receive the answers for this query_id + self.listeners[query_id] = asyncio.Queue() + try: # Loop over the received answers while True: - # Wait the next answer from the websocket server - answer = await self._wait_answer(websocket) + # Wait for the answer from the queue of this query_id + answer_type, execution_result = await self.listeners[query_id].get() - # Parse the answer - answer_type, answer_id, execution_result = self._parse_answer(answer) + # Set the task as done in the listeners queue + self.listeners[query_id].task_done() - # If the received answer id corresponds to the query id, + # If the received answer contains data, # Then we will yield the results back as an ExecutionResult object - # If we receive a 'complete' answer from the server, - # Then we will end this async generator output and disconnect from the server - if answer_id == query_id: - if execution_result is not None: - yield execution_result + if execution_result is not None: + yield execution_result + + # If we receive a 'complete' answer from the server, + # Then we will end this async generator output and disconnect from the server + elif answer_type == 'complete': + break + + except asyncio.CancelledError as error: + await self._send_stop_message(query_id) - elif answer_type == 'complete': - return + finally: + del self.listeners[query_id] - async def single_query(self, document, variable_values=None, operation_name=None): + + async def execute(self, document, variable_values=None, operation_name=None): """Send a query but close the connection as soon as we have the first answer The result is sent as an ExecutionResult object """ - async for result in self.subscribe(document, variable_values, operation_name): - return result + generator = self.subscribe(document, variable_values, operation_name) + + async for execution_result in generator: + first_result = execution_result + generator.aclose() + + return first_result + + async def connect(self): + """Coroutine which will: + + - connect to the websocket address + - send the init message + - wait for the connection acknowledge from the server + - create an asyncio task which will be used to receive and parse the websocket answers + + Should be cleaned with a call to the close coroutine + """ + + if self.websocket == None: + + # Connection to the specified url + self.websocket = await websockets.connect( + self.url, + ssl=self.ssl, + extra_headers=self.headers, + subprotocols=['graphql-ws'] + ) + + # Send the init message and wait for the ack from the server + await self._send_init_message_and_wait_ack() + + # Create a task to listen to the incoming websocket messages + self.listen_loop = asyncio.ensure_future(self._answer_loop()) + + async def close(self): + """Coroutine which will: + + - send the connection terminate message + - close the websocket connection + - send 'complete' messages to close all the existing subscribe async generators + - remove the listen_loop task + """ + + if self.websocket: + + await self._send_connection_terminate_message() + + await self.websocket.close() + + for query_id in self.listeners: + await self.listeners[query_id].put(('complete', None)) + + self.websocket = None - def execute(self, document): - raise NotImplementedError( - "You should use the async function 'execute_async' for this transport" - ) + self.listen_loop.cancel() diff --git a/tests_py36/test_websockets_transport.py b/tests_py36/test_websockets_transport.py index fb4e7830..3cf84cc2 100644 --- a/tests_py36/test_websockets_transport.py +++ b/tests_py36/test_websockets_transport.py @@ -2,16 +2,17 @@ logging.basicConfig(level=logging.INFO) -from gql import gql, Client +from gql import gql, AsyncClient from gql.transport.websockets import WebsocketsTransport from graphql.execution import ExecutionResult from typing import Dict import asyncio import pytest +import sys @pytest.mark.asyncio -async def test_websocket_query(): +async def test_websocket_simple_query(): # Get Websockets transport sample_transport = WebsocketsTransport( @@ -20,34 +21,94 @@ async def test_websocket_query(): ) # Instanciate client - client = Client(transport=sample_transport) + async with AsyncClient(transport=sample_transport) as client: - query = gql(''' - query getContinents { - continents { - code - name - } - } - ''') + query = gql(''' + query getContinents { + continents { + code + name + } + } + ''') - # Fetch schema - await client.fetch_schema() + # Fetch schema + await client.fetch_schema() - # Execute query - result = await client.execute_async(query) + # Execute query + result = await client.execute(query) - # Verify result - assert isinstance(result, ExecutionResult) - assert result.errors == None + # Verify result + assert isinstance(result, ExecutionResult) + assert result.errors == None - assert isinstance(result.data, Dict) + assert isinstance(result.data, Dict) - continents = result.data['continents'] + continents = result.data['continents'] - africa = continents[0] + africa = continents[0] - assert africa['code'] == 'AF' + assert africa['code'] == 'AF' - # Sleep 1 second to allow the connections to end - await asyncio.sleep(1) + print (sys.version_info) + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +@pytest.mark.asyncio +async def test_websocket_two_queries_in_parallel_using_two_tasks(): + + # Get Websockets transport + sample_transport = WebsocketsTransport( + url='wss://countries.trevorblades.com/graphql', + ssl=True + ) + + # Instanciate client + async with AsyncClient(transport=sample_transport) as client: + + query1 = gql(''' + query getContinents { + continents { + code + } + } + ''') + + query2 = gql(''' + query getContinents { + continents { + name + } + } + ''') + + async def query_task1(): + result = await client.execute(query1) + + assert isinstance(result, ExecutionResult) + assert result.errors == None + + assert isinstance(result.data, Dict) + + continents = result.data['continents'] + + africa = continents[0] + assert africa['code'] == 'AF' + + async def query_task2(): + result = await client.execute(query2) + + assert isinstance(result, ExecutionResult) + assert result.errors == None + + assert isinstance(result.data, Dict) + + continents = result.data['continents'] + + africa = continents[0] + assert africa['name'] == 'Africa' + + task1 = asyncio.create_task(query_task1()) + task2 = asyncio.create_task(query_task2()) + + await task1 + await task2 From 81b795064f06e4466ce30affea8465a86c3358b3 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Fri, 27 Mar 2020 11:29:13 +0100 Subject: [PATCH 04/46] Adding gql-cli.py to allow to send queries to a websocket endpoint from the command line --- gql-cli.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 gql-cli.py diff --git a/gql-cli.py b/gql-cli.py new file mode 100644 index 00000000..ce10d54f --- /dev/null +++ b/gql-cli.py @@ -0,0 +1,28 @@ +from gql import gql, AsyncClient +from gql.transport.websockets import WebsocketsTransport +import asyncio +import argparse + +parser = argparse.ArgumentParser(description='Send GraphQL queries from command line to a websocket endpoint') +parser.add_argument('server', help='the server websocket url starting with ws:// or wss://') +args = parser.parse_args() + +async def main(): + + transport = WebsocketsTransport(url=args.server, ssl=args.server.startswith('wss')) + + async with AsyncClient(transport=transport) as client: + + while True: + try: + query_str = input() + except EOFError: + break + + query = gql(query_str) + + async for result in client.subscribe(query): + + print (result.data) + +asyncio.run(main()) From ae1ac365b68958687bd58c53c37bc4144e55c91d Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Fri, 3 Apr 2020 16:43:58 +0200 Subject: [PATCH 05/46] Better management of ConnectionClosed Exceptions --- gql-cli.py | 2 +- gql/transport/websockets.py | 42 ++++++++++++++++++++++++++++--------- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/gql-cli.py b/gql-cli.py index ce10d54f..6ebf415a 100644 --- a/gql-cli.py +++ b/gql-cli.py @@ -8,7 +8,7 @@ args = parser.parse_args() async def main(): - + transport = WebsocketsTransport(url=args.server, ssl=args.server.startswith('wss')) async with AsyncClient(transport=transport) as client: diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 0cb24e71..cd61bb92 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -43,6 +43,7 @@ def __init__( self.websocket = None self.next_query_id = 1 self.listeners = {} + self._is_closing = False async def _send(self, message): """Send the provided message to the websocket connection and log the message @@ -51,15 +52,25 @@ async def _send(self, message): if not self.websocket: raise Exception('Transport is not connected') - await self.websocket.send(message) - log.info('>>> %s', message) + try: + await self.websocket.send(message) + log.info('>>> %s', message) + except (websockets.exceptions.ConnectionClosedError) as e: + await self.close() + raise e async def _receive(self): """Wait the next message from the websocket connection and log the answer """ - answer = await self.websocket.recv() - log.info('<<< %s', answer) + answer = None + + try: + answer = await self.websocket.recv() + log.info('<<< %s', answer) + except websockets.exceptions.ConnectionClosedError as e: + await self.close() + raise e return answer @@ -233,24 +244,28 @@ async def subscribe(self, document, variable_values=None, operation_name=None): elif answer_type == 'complete': break - except asyncio.CancelledError as error: + except (asyncio.CancelledError, GeneratorExit) as e: await self._send_stop_message(query_id) finally: del self.listeners[query_id] - async def execute(self, document, variable_values=None, operation_name=None): - """Send a query but close the connection as soon as we have the first answer + """Send a query but close the async generator as soon as we have the first answer The result is sent as an ExecutionResult object """ generator = self.subscribe(document, variable_values, operation_name) + first_result = None + async for execution_result in generator: first_result = execution_result generator.aclose() + if first_result is None: + raise asyncio.CancelledError + return first_result async def connect(self): @@ -274,6 +289,9 @@ async def connect(self): subprotocols=['graphql-ws'] ) + # Reset the next query id + self.next_query_id = 1 + # Send the init message and wait for the ack from the server await self._send_init_message_and_wait_ack() @@ -289,11 +307,15 @@ async def close(self): - remove the listen_loop task """ - if self.websocket: + if self.websocket and not self._is_closing: - await self._send_connection_terminate_message() + self._is_closing = True - await self.websocket.close() + try: + await self._send_connection_terminate_message() + await self.websocket.close() + except websockets.exceptions.ConnectionClosedError: + pass for query_id in self.listeners: await self.listeners[query_id].put(('complete', None)) From ada2d6ce50388a714fe9cbea89f5f642e0e5f2e0 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Fri, 3 Apr 2020 17:19:46 +0200 Subject: [PATCH 06/46] Fix dependencies --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 215d31e0..f689bdde 100644 --- a/setup.py +++ b/setup.py @@ -7,8 +7,7 @@ 'graphql-core>=2,<3', 'promise>=2.0,<3', 'requests>=2.12,<3', - 'pathlib', - 'websockets' + 'websockets>=8.1,<9' ] tests_require = [ From 22188502aae29a728f150614f36684683663fc41 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 6 Apr 2020 13:42:39 +0200 Subject: [PATCH 07/46] Using black to format code --- gql/client.py | 6 +-- gql/transport/__init__.py | 2 +- gql/transport/websockets.py | 81 ++++++++++++++++++------------------- 3 files changed, 43 insertions(+), 46 deletions(-) diff --git a/gql/client.py b/gql/client.py index cd51e644..492f8804 100644 --- a/gql/client.py +++ b/gql/client.py @@ -35,8 +35,8 @@ def __init__( assert ( not schema ), "Cant fetch the schema from transport if is already provided" - assert ( - not isinstance(transport, AsyncTransport) + assert not isinstance( + transport, AsyncTransport ), "With an asyncio transport, please use 'await client.fetch_schema()' instead of fetch_schema_from_transport=True" introspection = transport.execute(parse(introspection_query)).data if introspection: @@ -98,8 +98,8 @@ def _get_result(self, document, *args, **kwargs): raise RetryError(retries_count, last_exception) -class AsyncClient(Client): +class AsyncClient(Client): async def subscribe(self, document, *args, **kwargs): if self.schema: self.validate(document) diff --git a/gql/transport/__init__.py b/gql/transport/__init__.py index 6193ddd2..bd2f0ed8 100644 --- a/gql/transport/__init__.py +++ b/gql/transport/__init__.py @@ -21,9 +21,9 @@ def execute(self, document): "Any Transport subclass must implement execute method" ) + @six.add_metaclass(abc.ABCMeta) class AsyncTransport(Transport): - @abc.abstractmethod async def connect(self): """Coroutine used to create a connection to the specified address diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index cd61bb92..c7a560fa 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -15,6 +15,7 @@ log = logging.getLogger(__name__) + class WebsocketsTransport(AsyncTransport): """Transport to execute GraphQL queries on remote servers with a websocket connection. @@ -25,10 +26,7 @@ class WebsocketsTransport(AsyncTransport): """ def __init__( - self, - url, - headers=None, - ssl=False, + self, url, headers=None, ssl=False, ): """Initialize the transport with the given request parameters. @@ -50,11 +48,11 @@ async def _send(self, message): """ if not self.websocket: - raise Exception('Transport is not connected') + raise Exception("Transport is not connected") try: await self.websocket.send(message) - log.info('>>> %s', message) + log.info(">>> %s", message) except (websockets.exceptions.ConnectionClosedError) as e: await self.close() raise e @@ -67,7 +65,7 @@ async def _receive(self): try: answer = await self.websocket.recv() - log.info('<<< %s', answer) + log.info("<<< %s", answer) except websockets.exceptions.ConnectionClosedError as e: await self.close() raise e @@ -86,8 +84,8 @@ async def _send_init_message_and_wait_ack(self): answer_type, answer_id, execution_result = self._parse_answer(init_answer) - if answer_type != 'connection_ack': - raise Exception('Websocket server did not return a connection ack') + if answer_type != "connection_ack": + raise Exception("Websocket server did not return a connection ack") async def _send_stop_message(self, query_id): """Send a stop message to the provided websocket connection for the provided query_id @@ -95,10 +93,7 @@ async def _send_stop_message(self, query_id): The server should afterwards return a 'complete' message """ - stop_message = json.dumps({ - 'id': str(query_id), - 'type': 'stop' - }) + stop_message = json.dumps({"id": str(query_id), "type": "stop"}) await self._send(stop_message) @@ -108,9 +103,7 @@ async def _send_connection_terminate_message(self): This message indicate that the connection will disconnect """ - connection_terminate_message = json.dumps({ - 'type': 'connection_terminate' - }) + connection_terminate_message = json.dumps({"type": "connection_terminate"}) await self._send(connection_terminate_message) @@ -125,16 +118,18 @@ async def _send_query(self, document, variable_values, operation_name): query_id = self.next_query_id self.next_query_id += 1 - query_str = json.dumps({ - 'id': str(query_id), - 'type': 'start', - 'payload': { - 'variables': variable_values or {}, - #'extensions': {}, - 'operationName': operation_name or '', - 'query': print_ast(document) + query_str = json.dumps( + { + "id": str(query_id), + "type": "start", + "payload": { + "variables": variable_values or {}, + #'extensions': {}, + "operationName": operation_name or "", + "query": print_ast(document), + }, } - }) + ) await self._send(query_str) @@ -159,34 +154,36 @@ def _parse_answer(self, answer): if not isinstance(json_answer, dict): raise ValueError - answer_type = json_answer.get('type') + answer_type = json_answer.get("type") - if answer_type in ['data', 'error', 'complete']: - answer_id = int(json_answer.get('id')) + if answer_type in ["data", "error", "complete"]: + answer_id = int(json_answer.get("id")) - if answer_type == 'data': - result = json_answer.get('payload') + if answer_type == "data": + result = json_answer.get("payload") - if 'errors' not in result and 'data' not in result: + if "errors" not in result and "data" not in result: raise ValueError - execution_result = ExecutionResult(errors=result.get('errors'), data=result.get('data')) + execution_result = ExecutionResult( + errors=result.get("errors"), data=result.get("data") + ) - elif answer_type == 'error': - raise Exception('Websocket server error') + elif answer_type == "error": + raise Exception("Websocket server error") - elif answer_type == 'ka': + elif answer_type == "ka": # KeepAlive message pass - elif answer_type == 'connection_ack': + elif answer_type == "connection_ack": pass - elif answer_type == 'connection_error': - raise Exception('Websocket Connection Error') + elif answer_type == "connection_error": + raise Exception("Websocket Connection Error") else: raise ValueError except ValueError: - raise Exception('Websocket server did not return a GraphQL result') + raise Exception("Websocket server did not return a GraphQL result") return (answer_type, answer_id, execution_result) @@ -241,7 +238,7 @@ async def subscribe(self, document, variable_values=None, operation_name=None): # If we receive a 'complete' answer from the server, # Then we will end this async generator output and disconnect from the server - elif answer_type == 'complete': + elif answer_type == "complete": break except (asyncio.CancelledError, GeneratorExit) as e: @@ -286,7 +283,7 @@ async def connect(self): self.url, ssl=self.ssl, extra_headers=self.headers, - subprotocols=['graphql-ws'] + subprotocols=["graphql-ws"], ) # Reset the next query id @@ -318,7 +315,7 @@ async def close(self): pass for query_id in self.listeners: - await self.listeners[query_id].put(('complete', None)) + await self.listeners[query_id].put(("complete", None)) self.websocket = None From 0a6863b80e0e2698d88aece123c4746e24ac04ce Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 6 Apr 2020 13:50:05 +0200 Subject: [PATCH 08/46] Fix flake8 style enforcement --- gql/client.py | 2 +- gql/transport/websockets.py | 8 ++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/gql/client.py b/gql/client.py index 492f8804..4a78a43c 100644 --- a/gql/client.py +++ b/gql/client.py @@ -37,7 +37,7 @@ def __init__( ), "Cant fetch the schema from transport if is already provided" assert not isinstance( transport, AsyncTransport - ), "With an asyncio transport, please use 'await client.fetch_schema()' instead of fetch_schema_from_transport=True" + ), "With an asyncio transport, please use the AsyncClient class" introspection = transport.execute(parse(introspection_query)).data if introspection: assert not schema, "Cant provide introspection and schema at the same time" diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index c7a560fa..e0e8bdda 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,14 +1,11 @@ from __future__ import absolute_import -from typing import Any, Dict, Union - import websockets import asyncio import json import logging from graphql.execution import ExecutionResult -from graphql.language.ast import Document from graphql.language.printer import print_ast from gql.transport import AsyncTransport @@ -124,7 +121,6 @@ async def _send_query(self, document, variable_values, operation_name): "type": "start", "payload": { "variables": variable_values or {}, - #'extensions': {}, "operationName": operation_name or "", "query": print_ast(document), }, @@ -241,7 +237,7 @@ async def subscribe(self, document, variable_values=None, operation_name=None): elif answer_type == "complete": break - except (asyncio.CancelledError, GeneratorExit) as e: + except (asyncio.CancelledError, GeneratorExit): await self._send_stop_message(query_id) finally: @@ -276,7 +272,7 @@ async def connect(self): Should be cleaned with a call to the close coroutine """ - if self.websocket == None: + if self.websocket is None: # Connection to the specified url self.websocket = await websockets.connect( From ff867440f90155808ba085be65f61b9150c4940c Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 6 Apr 2020 15:34:11 +0200 Subject: [PATCH 09/46] Adding typing hints --- gql/client.py | 12 +++- gql/transport/__init__.py | 18 +++++- gql/transport/websockets.py | 114 +++++++++++++++++++++++++----------- 3 files changed, 105 insertions(+), 39 deletions(-) diff --git a/gql/client.py b/gql/client.py index 4a78a43c..72f5b873 100644 --- a/gql/client.py +++ b/gql/client.py @@ -2,6 +2,10 @@ from graphql import build_ast_schema, build_client_schema, introspection_query, parse from graphql.validation import validate +from graphql.execution import ExecutionResult +from graphql.language.ast import Document + +from typing import AsyncGenerator from .transport.local_schema import LocalSchemaTransport from gql.transport import AsyncTransport @@ -100,20 +104,22 @@ def _get_result(self, document, *args, **kwargs): class AsyncClient(Client): - async def subscribe(self, document, *args, **kwargs): + async def subscribe( + self, document: Document, *args, **kwargs + ) -> AsyncGenerator[ExecutionResult, None]: if self.schema: self.validate(document) async for result in self.transport.subscribe(document, *args, **kwargs): yield result - async def execute(self, document, *args, **kwargs): + async def execute(self, document: Document, *args, **kwargs) -> ExecutionResult: if self.schema: self.validate(document) return await self.transport.execute(document, *args, **kwargs) - async def fetch_schema(self): + async def fetch_schema(self) -> None: execution_result = await self.transport.execute(parse(introspection_query)) self.introspection = execution_result.data self.schema = build_client_schema(self.introspection) diff --git a/gql/transport/__init__.py b/gql/transport/__init__.py index bd2f0ed8..71fd1ede 100644 --- a/gql/transport/__init__.py +++ b/gql/transport/__init__.py @@ -6,6 +6,8 @@ from graphql.language.ast import Document from promise import Promise +from typing import Dict, Optional, AsyncGenerator + @six.add_metaclass(abc.ABCMeta) class Transport: @@ -23,7 +25,7 @@ def execute(self, document): @six.add_metaclass(abc.ABCMeta) -class AsyncTransport(Transport): +class AsyncTransport: @abc.abstractmethod async def connect(self): """Coroutine used to create a connection to the specified address @@ -41,7 +43,12 @@ async def close(self): ) @abc.abstractmethod - async def execute(self, document, variable_values=None, operation_name=None): + async def execute( + self, + document: Document, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: """Execute the provided document AST for either a remote or local GraphQL Schema. """ raise NotImplementedError( @@ -49,7 +56,12 @@ async def execute(self, document, variable_values=None, operation_name=None): ) @abc.abstractmethod - async def subscribe(self, document, variable_values=None, operation_name=None): + def subscribe( + self, + document: Document, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using an async generator The query can be a graphql query, mutation or subscription diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index e0e8bdda..24b99dc5 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,11 +1,21 @@ from __future__ import absolute_import import websockets +from websockets.http import HeadersLike +from websockets.typing import Data, Subprotocol +from websockets.client import WebSocketClientProtocol +from websockets.exceptions import ConnectionClosedError + +from ssl import SSLContext + import asyncio import json import logging +from typing import cast, Dict, Optional, Tuple, Union, NoReturn, AsyncGenerator + from graphql.execution import ExecutionResult +from graphql.language.ast import Document from graphql.language.printer import print_ast from gql.transport import AsyncTransport @@ -23,24 +33,27 @@ class WebsocketsTransport(AsyncTransport): """ def __init__( - self, url, headers=None, ssl=False, - ): + self, + url: str, + headers: Optional[HeadersLike] = None, + ssl: Union[SSLContext, bool] = False, + ) -> None: """Initialize the transport with the given request parameters. :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. :param headers: Dict of HTTP Headers. :param ssl: ssl_context of the connection. Use ssl=False to disable encryption """ - self.url = url - self.ssl = ssl - self.headers = headers + self.url: str = url + self.ssl: Union[SSLContext, bool] = ssl + self.headers: Optional[HeadersLike] = headers - self.websocket = None - self.next_query_id = 1 - self.listeners = {} - self._is_closing = False + self.websocket: Optional[WebSocketClientProtocol] = None + self.next_query_id: int = 1 + self.listeners: Dict[int, asyncio.Queue] = {} + self._is_closing: bool = False - async def _send(self, message): + async def _send(self, message: str) -> None: """Send the provided message to the websocket connection and log the message """ @@ -50,26 +63,37 @@ async def _send(self, message): try: await self.websocket.send(message) log.info(">>> %s", message) - except (websockets.exceptions.ConnectionClosedError) as e: + except (ConnectionClosedError) as e: await self.close() raise e - async def _receive(self): + async def _receive(self) -> str: """Wait the next message from the websocket connection and log the answer """ - answer = None + answer: Optional[str] = None + + if not self.websocket: + raise Exception("Transport is not connected") try: - answer = await self.websocket.recv() + data: Data = await self.websocket.recv() + + # websocket.recv() can return either str or bytes + # In our case, we should receive only str here + if not isinstance(data, str): + raise Exception("Binary data received in the websocket") + + answer = data + log.info("<<< %s", answer) - except websockets.exceptions.ConnectionClosedError as e: + except ConnectionClosedError as e: await self.close() raise e return answer - async def _send_init_message_and_wait_ack(self): + async def _send_init_message_and_wait_ack(self) -> None: """Send an init message to the provided websocket then wait for the connection ack If the answer is not a connection_ack message, we will return an Exception @@ -84,7 +108,7 @@ async def _send_init_message_and_wait_ack(self): if answer_type != "connection_ack": raise Exception("Websocket server did not return a connection ack") - async def _send_stop_message(self, query_id): + async def _send_stop_message(self, query_id: int) -> None: """Send a stop message to the provided websocket connection for the provided query_id The server should afterwards return a 'complete' message @@ -94,7 +118,7 @@ async def _send_stop_message(self, query_id): await self._send(stop_message) - async def _send_connection_terminate_message(self): + async def _send_connection_terminate_message(self) -> None: """Send a connection_terminate message to the provided websocket connection This message indicate that the connection will disconnect @@ -104,7 +128,12 @@ async def _send_connection_terminate_message(self): await self._send(connection_terminate_message) - async def _send_query(self, document, variable_values, operation_name): + async def _send_query( + self, + document: Document, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> int: """Send a query to the provided websocket connection We use an incremented id to reference the query @@ -131,7 +160,9 @@ async def _send_query(self, document, variable_values, operation_name): return query_id - def _parse_answer(self, answer): + def _parse_answer( + self, answer: str + ) -> Tuple[str, Optional[int], Optional[ExecutionResult]]: """Parse the answer received from the server Returns a list consisting of: @@ -140,9 +171,9 @@ def _parse_answer(self, answer): - an execution Result if the answer_type is 'data' or None """ - answer_type = None - answer_id = None - execution_result = None + answer_type: str = "" + answer_id: Optional[int] = None + execution_result: Optional[ExecutionResult] = None try: json_answer = json.loads(answer) @@ -150,14 +181,17 @@ def _parse_answer(self, answer): if not isinstance(json_answer, dict): raise ValueError - answer_type = json_answer.get("type") + answer_type = str(json_answer.get("type")) if answer_type in ["data", "error", "complete"]: - answer_id = int(json_answer.get("id")) + answer_id = int(str(json_answer.get("id"))) if answer_type == "data": result = json_answer.get("payload") + if not isinstance(result, Dict): + raise ValueError + if "errors" not in result and "data" not in result: raise ValueError @@ -183,7 +217,7 @@ def _parse_answer(self, answer): return (answer_type, answer_id, execution_result) - async def _answer_loop(self): + async def _answer_loop(self) -> NoReturn: while True: @@ -203,7 +237,12 @@ async def _answer_loop(self): # Put the answer in the queue await queue.put((answer_type, execution_result)) - async def subscribe(self, document, variable_values=None, operation_name=None): + async def subscribe( + self, + document: Document, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using a python async generator The query can be a graphql query, mutation or subscription @@ -212,7 +251,9 @@ async def subscribe(self, document, variable_values=None, operation_name=None): """ # Send the query and receive the id - query_id = await self._send_query(document, variable_values, operation_name) + query_id: int = await self._send_query( + document, variable_values, operation_name + ) # Create a queue to receive the answers for this query_id self.listeners[query_id] = asyncio.Queue() @@ -243,7 +284,12 @@ async def subscribe(self, document, variable_values=None, operation_name=None): finally: del self.listeners[query_id] - async def execute(self, document, variable_values=None, operation_name=None): + async def execute( + self, + document: Document, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: """Send a query but close the async generator as soon as we have the first answer The result is sent as an ExecutionResult object @@ -261,7 +307,7 @@ async def execute(self, document, variable_values=None, operation_name=None): return first_result - async def connect(self): + async def connect(self) -> None: """Coroutine which will: - connect to the websocket address @@ -272,6 +318,8 @@ async def connect(self): Should be cleaned with a call to the close coroutine """ + GRAPHQLWS_SUBPROTOCOL: Subprotocol = cast(Subprotocol, "graphql-ws") + if self.websocket is None: # Connection to the specified url @@ -279,7 +327,7 @@ async def connect(self): self.url, ssl=self.ssl, extra_headers=self.headers, - subprotocols=["graphql-ws"], + subprotocols=[GRAPHQLWS_SUBPROTOCOL], ) # Reset the next query id @@ -291,7 +339,7 @@ async def connect(self): # Create a task to listen to the incoming websocket messages self.listen_loop = asyncio.ensure_future(self._answer_loop()) - async def close(self): + async def close(self) -> None: """Coroutine which will: - send the connection terminate message @@ -307,7 +355,7 @@ async def close(self): try: await self._send_connection_terminate_message() await self.websocket.close() - except websockets.exceptions.ConnectionClosedError: + except ConnectionClosedError: pass for query_id in self.listeners: From 2398866173893b9d8ae3db9e66e7461b7099eb99 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 6 Apr 2020 17:17:25 +0200 Subject: [PATCH 10/46] Adding gql-cli as a python script + added to MANIFEST --- MANIFEST.in | 2 ++ gql-cli.py => scripts/gql-cli | 2 ++ setup.py | 1 + 3 files changed, 5 insertions(+) rename gql-cli.py => scripts/gql-cli (97%) mode change 100644 => 100755 diff --git a/MANIFEST.in b/MANIFEST.in index 369523a6..8ccdab11 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -10,6 +10,8 @@ include Makefile include tox.ini +include scripts/gql-cli + recursive-include tests *.py *.yaml *.graphql recursive-include tests_py36 *.py diff --git a/gql-cli.py b/scripts/gql-cli old mode 100644 new mode 100755 similarity index 97% rename from gql-cli.py rename to scripts/gql-cli index 6ebf415a..46cbca1b --- a/gql-cli.py +++ b/scripts/gql-cli @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + from gql import gql, AsyncClient from gql.transport.websockets import WebsocketsTransport import asyncio diff --git a/setup.py b/setup.py index f689bdde..9206488b 100644 --- a/setup.py +++ b/setup.py @@ -63,4 +63,5 @@ include_package_data=True, zip_safe=False, platforms="any", + scripts=['scripts/gql-cli'], ) From 3e314b8b8854b8ff2c432d4cddef7d8dc57d0d56 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 7 Apr 2020 10:52:06 +0200 Subject: [PATCH 11/46] setup.py adding websockets dependency only if python version is >= 3.6 --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 9206488b..b022914d 100644 --- a/setup.py +++ b/setup.py @@ -6,8 +6,7 @@ 'six>=1.10.0', 'graphql-core>=2,<3', 'promise>=2.0,<3', - 'requests>=2.12,<3', - 'websockets>=8.1,<9' + 'requests>=2.12,<3' ] tests_require = [ @@ -20,6 +19,7 @@ if sys.version_info > (3, 6): tests_require.append('pytest-asyncio>=0.9.0') + install_requires.append('websockets>=8.1,<9') dev_requires = [ 'flake8==3.7.9', From 00f2433b2c2ff7146523dfd66416a4e839b81170 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 7 Apr 2020 11:08:02 +0200 Subject: [PATCH 12/46] Moving python3 code in separate files to try to fix tests for python2 --- gql/__init__.py | 12 ++++- gql/async_client.py | 60 +++++++++++++++++++++++++ gql/client.py | 39 ++--------------- gql/transport/__init__.py | 75 +++----------------------------- gql/transport/async_transport.py | 56 ++++++++++++++++++++++++ gql/transport/transport.py | 22 ++++++++++ gql/transport/websockets.py | 5 ++- setup.py | 6 ++- 8 files changed, 166 insertions(+), 109 deletions(-) create mode 100644 gql/async_client.py create mode 100644 gql/transport/async_transport.py create mode 100644 gql/transport/transport.py diff --git a/gql/__init__.py b/gql/__init__.py index f3c8f920..577a71e2 100644 --- a/gql/__init__.py +++ b/gql/__init__.py @@ -1,4 +1,12 @@ +import sys + from .gql import gql -from .client import Client, AsyncClient +from .client import Client + +__all__ = ["gql", "Client"] + +if sys.version_info > (3, 6): + from .async_client import AsyncClient -__all__ = ["gql", "Client", "AsyncClient"] + # Cannot use __all__.append here because of flake8 warning + __all__ = ["gql", "Client", "AsyncClient"] diff --git a/gql/async_client.py b/gql/async_client.py new file mode 100644 index 00000000..96452aaa --- /dev/null +++ b/gql/async_client.py @@ -0,0 +1,60 @@ +from graphql import build_ast_schema, build_client_schema, introspection_query, parse +from graphql.execution import ExecutionResult +from graphql.language.ast import Document + +from typing import AsyncGenerator + +from gql.transport import AsyncTransport +from gql import Client + + +class AsyncClient(Client): + def __init__( + self, schema=None, introspection=None, type_def=None, transport=None, + ): + assert isinstance( + transport, AsyncTransport + ), "Only a transport of type AsyncTransport is supported on AsyncClient" + assert not ( + type_def and introspection + ), "Cant provide introspection type definition at the same time" + if introspection: + assert not schema, "Cant provide introspection and schema at the same time" + schema = build_client_schema(introspection) + elif type_def: + assert ( + not schema + ), "Cant provide Type definition and schema at the same time" + type_def_ast = parse(type_def) + schema = build_ast_schema(type_def_ast) + + self.schema = schema + self.introspection = introspection + self.transport = transport + + async def subscribe( + self, document: Document, *args, **kwargs + ) -> AsyncGenerator[ExecutionResult, None]: + if self.schema: + self.validate(document) + + async for result in self.transport.subscribe(document, *args, **kwargs): + yield result + + async def execute(self, document: Document, *args, **kwargs) -> ExecutionResult: + if self.schema: + self.validate(document) + + return await self.transport.execute(document, *args, **kwargs) + + async def fetch_schema(self) -> None: + execution_result = await self.transport.execute(parse(introspection_query)) + self.introspection = execution_result.data + self.schema = build_client_schema(self.introspection) + + async def __aenter__(self): + await self.transport.connect() + return self + + async def __aexit__(self, *args): + await self.transport.close() diff --git a/gql/client.py b/gql/client.py index 72f5b873..70d8a0a9 100644 --- a/gql/client.py +++ b/gql/client.py @@ -2,13 +2,9 @@ from graphql import build_ast_schema, build_client_schema, introspection_query, parse from graphql.validation import validate -from graphql.execution import ExecutionResult -from graphql.language.ast import Document - -from typing import AsyncGenerator from .transport.local_schema import LocalSchemaTransport -from gql.transport import AsyncTransport +from .transport import Transport log = logging.getLogger(__name__) @@ -39,8 +35,8 @@ def __init__( assert ( not schema ), "Cant fetch the schema from transport if is already provided" - assert not isinstance( - transport, AsyncTransport + assert isinstance( + transport, Transport ), "With an asyncio transport, please use the AsyncClient class" introspection = transport.execute(parse(introspection_query)).data if introspection: @@ -101,32 +97,3 @@ def _get_result(self, document, *args, **kwargs): retries_count += 1 raise RetryError(retries_count, last_exception) - - -class AsyncClient(Client): - async def subscribe( - self, document: Document, *args, **kwargs - ) -> AsyncGenerator[ExecutionResult, None]: - if self.schema: - self.validate(document) - - async for result in self.transport.subscribe(document, *args, **kwargs): - yield result - - async def execute(self, document: Document, *args, **kwargs) -> ExecutionResult: - if self.schema: - self.validate(document) - - return await self.transport.execute(document, *args, **kwargs) - - async def fetch_schema(self) -> None: - execution_result = await self.transport.execute(parse(introspection_query)) - self.introspection = execution_result.data - self.schema = build_client_schema(self.introspection) - - async def __aenter__(self): - await self.transport.connect() - return self - - async def __aexit__(self, *args): - await self.transport.close() diff --git a/gql/transport/__init__.py b/gql/transport/__init__.py index 71fd1ede..0fdf95d6 100644 --- a/gql/transport/__init__.py +++ b/gql/transport/__init__.py @@ -1,73 +1,12 @@ -import abc -from typing import Union +import sys -import six -from graphql.execution import ExecutionResult -from graphql.language.ast import Document -from promise import Promise +from .transport import Transport -from typing import Dict, Optional, AsyncGenerator +__all__ = ["Transport"] -@six.add_metaclass(abc.ABCMeta) -class Transport: - @abc.abstractmethod - def execute(self, document): - # type: (Document) -> Union[ExecutionResult, Promise[ExecutionResult]] - """Execute the provided document AST for either a remote or local GraphQL Schema. +if sys.version_info > (3, 6): + from .async_transport import AsyncTransport - :param document: GraphQL query as AST Node or Document object. - :return: Either ExecutionResult or a Promise that resolves to ExecutionResult object. - """ - raise NotImplementedError( - "Any Transport subclass must implement execute method" - ) - - -@six.add_metaclass(abc.ABCMeta) -class AsyncTransport: - @abc.abstractmethod - async def connect(self): - """Coroutine used to create a connection to the specified address - """ - raise NotImplementedError( - "Any AsyncTransport subclass must implement execute method" - ) - - @abc.abstractmethod - async def close(self): - """Coroutine used to Close an established connection - """ - raise NotImplementedError( - "Any AsyncTransport subclass must implement execute method" - ) - - @abc.abstractmethod - async def execute( - self, - document: Document, - variable_values: Optional[Dict[str, str]] = None, - operation_name: Optional[str] = None, - ) -> ExecutionResult: - """Execute the provided document AST for either a remote or local GraphQL Schema. - """ - raise NotImplementedError( - "Any AsyncTransport subclass must implement execute method" - ) - - @abc.abstractmethod - def subscribe( - self, - document: Document, - variable_values: Optional[Dict[str, str]] = None, - operation_name: Optional[str] = None, - ) -> AsyncGenerator[ExecutionResult, None]: - """Send a query and receive the results using an async generator - - The query can be a graphql query, mutation or subscription - - The results are sent as an ExecutionResult object - """ - raise NotImplementedError( - "Any AsyncTransport subclass must implement execute method" - ) + # Cannot use __all__.append here because of flake8 warning + __all__ = ["Transport", "AsyncTransport"] diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py new file mode 100644 index 00000000..2d4107cf --- /dev/null +++ b/gql/transport/async_transport.py @@ -0,0 +1,56 @@ +import abc + +import six +from graphql.execution import ExecutionResult +from graphql.language.ast import Document + +from typing import Dict, Optional, AsyncGenerator + + +@six.add_metaclass(abc.ABCMeta) +class AsyncTransport: + @abc.abstractmethod + async def connect(self): + """Coroutine used to create a connection to the specified address + """ + raise NotImplementedError( + "Any AsyncTransport subclass must implement execute method" + ) + + @abc.abstractmethod + async def close(self): + """Coroutine used to Close an established connection + """ + raise NotImplementedError( + "Any AsyncTransport subclass must implement execute method" + ) + + @abc.abstractmethod + async def execute( + self, + document: Document, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> ExecutionResult: + """Execute the provided document AST for either a remote or local GraphQL Schema. + """ + raise NotImplementedError( + "Any AsyncTransport subclass must implement execute method" + ) + + @abc.abstractmethod + def subscribe( + self, + document: Document, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> AsyncGenerator[ExecutionResult, None]: + """Send a query and receive the results using an async generator + + The query can be a graphql query, mutation or subscription + + The results are sent as an ExecutionResult object + """ + raise NotImplementedError( + "Any AsyncTransport subclass must implement execute method" + ) diff --git a/gql/transport/transport.py b/gql/transport/transport.py new file mode 100644 index 00000000..89c636a0 --- /dev/null +++ b/gql/transport/transport.py @@ -0,0 +1,22 @@ +import abc +from typing import Union + +import six +from graphql.execution import ExecutionResult +from graphql.language.ast import Document +from promise import Promise + + +@six.add_metaclass(abc.ABCMeta) +class Transport: + @abc.abstractmethod + def execute(self, document): + # type: (Document) -> Union[ExecutionResult, Promise[ExecutionResult]] + """Execute the provided document AST for either a remote or local GraphQL Schema. + + :param document: GraphQL query as AST Node or Document object. + :return: Either ExecutionResult or a Promise that resolves to ExecutionResult object. + """ + raise NotImplementedError( + "Any Transport subclass must implement execute method" + ) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 24b99dc5..86944129 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -12,7 +12,7 @@ import json import logging -from typing import cast, Dict, Optional, Tuple, Union, NoReturn, AsyncGenerator +from typing import cast, Dict, Optional, Tuple, Union, AsyncGenerator from graphql.execution import ExecutionResult from graphql.language.ast import Document @@ -217,7 +217,8 @@ def _parse_answer( return (answer_type, answer_id, execution_result) - async def _answer_loop(self) -> NoReturn: + async def _answer_loop(self): + # Note: the return type here is NoReturn but NoReturn is not yet supported in python 3.6 while True: diff --git a/setup.py b/setup.py index b022914d..6913b24f 100644 --- a/setup.py +++ b/setup.py @@ -17,9 +17,13 @@ 'vcrpy==3.0.0', ] +scripts = [] + if sys.version_info > (3, 6): tests_require.append('pytest-asyncio>=0.9.0') install_requires.append('websockets>=8.1,<9') + scripts.append['scripts/gql-cli'] + dev_requires = [ 'flake8==3.7.9', @@ -63,5 +67,5 @@ include_package_data=True, zip_safe=False, platforms="any", - scripts=['scripts/gql-cli'], + scripts=scripts, ) From 4197748e4a3ac93ff5a9c8c2546dfb76943cacd2 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 7 Apr 2020 13:00:13 +0200 Subject: [PATCH 13/46] Fix typo in setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6913b24f..7c743e95 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ if sys.version_info > (3, 6): tests_require.append('pytest-asyncio>=0.9.0') install_requires.append('websockets>=8.1,<9') - scripts.append['scripts/gql-cli'] + scripts.append('scripts/gql-cli') dev_requires = [ From 930ae7ea519a28d528ebf1556ad8fc12578d6e23 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 7 Apr 2020 13:13:29 +0200 Subject: [PATCH 14/46] Using relative imports to fix mypy --- gql/async_client.py | 4 ++-- gql/transport/websockets.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/gql/async_client.py b/gql/async_client.py index 96452aaa..2b4f1a1b 100644 --- a/gql/async_client.py +++ b/gql/async_client.py @@ -4,8 +4,8 @@ from typing import AsyncGenerator -from gql.transport import AsyncTransport -from gql import Client +from .transport import AsyncTransport +from .client import Client class AsyncClient(Client): diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 86944129..65c9aced 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -18,7 +18,7 @@ from graphql.language.ast import Document from graphql.language.printer import print_ast -from gql.transport import AsyncTransport +from .async_transport import AsyncTransport log = logging.getLogger(__name__) From 55807c2bf1b037335657d9bb69e983bd3f02fae3 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 7 Apr 2020 13:22:59 +0200 Subject: [PATCH 15/46] Using relative imports to fix mypy - try 2 --- gql/async_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/async_client.py b/gql/async_client.py index 2b4f1a1b..2592f766 100644 --- a/gql/async_client.py +++ b/gql/async_client.py @@ -4,7 +4,7 @@ from typing import AsyncGenerator -from .transport import AsyncTransport +from .transport.async_transport import AsyncTransport from .client import Client From 58aa38a9e0cdcbdd860e4a1387701c334688445d Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 8 Apr 2020 15:47:04 +0200 Subject: [PATCH 16/46] Better management of cleanup + adding some tests Adding fixtures to create a websocket server for the tests Allow ssl=False argument in transport --- gql/transport/websockets.py | 150 ++++++--- setup.py | 1 + ..._transport.py => test_websocket_online.py} | 55 ++-- tests_py36/test_websocket_query.py | 165 ++++++++++ tests_py36/test_websocket_subscription.py | 291 ++++++++++++++++++ tests_py36/websocket_fixtures.py | 111 +++++++ 6 files changed, 709 insertions(+), 64 deletions(-) rename tests_py36/{test_websockets_transport.py => test_websocket_online.py} (73%) create mode 100644 tests_py36/test_websocket_query.py create mode 100644 tests_py36/test_websocket_subscription.py create mode 100644 tests_py36/websocket_fixtures.py diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 65c9aced..ed15f12c 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -4,7 +4,7 @@ from websockets.http import HeadersLike from websockets.typing import Data, Subprotocol from websockets.client import WebSocketClientProtocol -from websockets.exceptions import ConnectionClosedError +from websockets.exceptions import ConnectionClosedOK, ConnectionClosed from ssl import SSLContext @@ -22,6 +22,52 @@ log = logging.getLogger(__name__) +ParsedAnswer = Tuple[str, Optional[ExecutionResult]] + + +class ListenerQueue: + """Special queue used for each query waiting for server answers + + If the server is stopped while the listener is still waiting, + Then we send an exception to the queue and this exception will be raised + to the consumer once all the previous messages have been consumed from the queue + """ + + def __init__(self, query_id: int, send_stop: bool) -> None: + self.query_id: int = query_id + self.send_stop: bool = send_stop + self._queue: asyncio.Queue = asyncio.Queue() + self._closed: bool = False + + async def get(self) -> ParsedAnswer: + + item = await self._queue.get() + self._queue.task_done() + + # If we receive an exception when reading the queue, we raise it + if isinstance(item, Exception): + self._closed = True + raise item + + # Don't need to save new answers or + # send the stop message if we already received the complete message + answer_type, execution_result = item + if answer_type == "complete": + self.send_stop = False + self._closed = True + + return item + + async def put(self, item: ParsedAnswer) -> None: + + if not self._closed: + await self._queue.put(item) + + async def set_exception(self, exception: Exception) -> None: + + # Put the exception in the queue + await self._queue.put(exception) + class WebsocketsTransport(AsyncTransport): """Transport to execute GraphQL queries on remote servers with a websocket connection. @@ -50,8 +96,9 @@ def __init__( self.websocket: Optional[WebSocketClientProtocol] = None self.next_query_id: int = 1 - self.listeners: Dict[int, asyncio.Queue] = {} + self.listeners: Dict[int, ListenerQueue] = {} self._is_closing: bool = False + self._no_more_listeners: asyncio.Event = asyncio.Event() async def _send(self, message: str) -> None: """Send the provided message to the websocket connection and log the message @@ -63,8 +110,8 @@ async def _send(self, message: str) -> None: try: await self.websocket.send(message) log.info(">>> %s", message) - except (ConnectionClosedError) as e: - await self.close() + except (ConnectionClosed) as e: + await self._close_with_exception(e) raise e async def _receive(self) -> str: @@ -87,8 +134,8 @@ async def _receive(self) -> str: answer = data log.info("<<< %s", answer) - except ConnectionClosedError as e: - await self.close() + except ConnectionClosed as e: + await self._close_with_exception(e) raise e return answer @@ -217,13 +264,15 @@ def _parse_answer( return (answer_type, answer_id, execution_result) - async def _answer_loop(self): - # Note: the return type here is NoReturn but NoReturn is not yet supported in python 3.6 + async def _answer_loop(self) -> None: while True: # Wait the next answer from the websocket server - answer = await self._receive() + try: + answer = await self._receive() + except ConnectionClosed: + return # Parse the answer answer_type, answer_id, execution_result = self._parse_answer(answer) @@ -232,17 +281,15 @@ async def _answer_loop(self): if answer_id not in self.listeners: continue - # Get the related queue - queue = self.listeners[answer_id] - # Put the answer in the queue - await queue.put((answer_type, execution_result)) + await self.listeners[answer_id].put((answer_type, execution_result)) async def subscribe( self, document: Document, variable_values: Optional[Dict[str, str]] = None, operation_name: Optional[str] = None, + send_stop: Optional[bool] = True, ) -> AsyncGenerator[ExecutionResult, None]: """Send a query and receive the results using a python async generator @@ -257,17 +304,15 @@ async def subscribe( ) # Create a queue to receive the answers for this query_id - self.listeners[query_id] = asyncio.Queue() + listener = ListenerQueue(query_id, send_stop=(send_stop is True)) + self.listeners[query_id] = listener try: # Loop over the received answers while True: # Wait for the answer from the queue of this query_id - answer_type, execution_result = await self.listeners[query_id].get() - - # Set the task as done in the listeners queue - self.listeners[query_id].task_done() + answer_type, execution_result = await listener.get() # If the received answer contains data, # Then we will yield the results back as an ExecutionResult object @@ -277,13 +322,21 @@ async def subscribe( # If we receive a 'complete' answer from the server, # Then we will end this async generator output and disconnect from the server elif answer_type == "complete": + log.debug( + f"Complete received for query {query_id} --> exit without error" + ) break - except (asyncio.CancelledError, GeneratorExit): - await self._send_stop_message(query_id) + except (asyncio.CancelledError, GeneratorExit) as e: + log.debug("Exception in subscribe: " + repr(e)) + if listener.send_stop: + await self._send_stop_message(query_id) + listener.send_stop = False finally: del self.listeners[query_id] + if len(self.listeners) == 0: + self._no_more_listeners.set() async def execute( self, @@ -295,16 +348,10 @@ async def execute( The result is sent as an ExecutionResult object """ - generator = self.subscribe(document, variable_values, operation_name) - - first_result = None - - async for execution_result in generator: - first_result = execution_result - generator.aclose() - - if first_result is None: - raise asyncio.CancelledError + async for result in self.subscribe( + document, variable_values, operation_name, send_stop=False + ): + first_result = result return first_result @@ -326,7 +373,7 @@ async def connect(self) -> None: # Connection to the specified url self.websocket = await websockets.connect( self.url, - ssl=self.ssl, + ssl=self.ssl if self.ssl else None, extra_headers=self.headers, subprotocols=[GRAPHQLWS_SUBPROTOCOL], ) @@ -343,24 +390,49 @@ async def connect(self) -> None: async def close(self) -> None: """Coroutine which will: + - send stop messages for each active query to the server - send the connection terminate message - close the websocket connection - - send 'complete' messages to close all the existing subscribe async generators + + - send the exceptions to all current listeners - remove the listen_loop task """ + if self.websocket and not self._is_closing: + + # Send stop message for all current queries + for query_id, listener in self.listeners.items(): + + if listener.send_stop: + await self._send_stop_message(query_id) + listener.send_stop = False + + # Wait that there is no more listeners (we received 'complete' for all queries) + await asyncio.wait_for(self._no_more_listeners.wait(), timeout=5) + await self._send_connection_terminate_message() + + await self.websocket.close() + + await self._close_with_exception( + ConnectionClosedOK( + code=1000, reason="Websocket GraphQL transport closed by user" + ) + ) + + async def _close_with_exception(self, e: Exception) -> None: + """Coroutine called to close the transport if the underlaying websocket transport + has closed itself + + - send the exceptions to all current listeners + - remove the listen_loop task + """ if self.websocket and not self._is_closing: self._is_closing = True - try: - await self._send_connection_terminate_message() - await self.websocket.close() - except ConnectionClosedError: - pass + for query_id, listener in self.listeners.items(): - for query_id in self.listeners: - await self.listeners[query_id].put(("complete", None)) + await listener.set_exception(e) self.websocket = None diff --git a/setup.py b/setup.py index 7c743e95..39c99bb5 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ if sys.version_info > (3, 6): tests_require.append('pytest-asyncio>=0.9.0') + tests_require.append('parse>=1.6.0') install_requires.append('websockets>=8.1,<9') scripts.append('scripts/gql-cli') diff --git a/tests_py36/test_websockets_transport.py b/tests_py36/test_websocket_online.py similarity index 73% rename from tests_py36/test_websockets_transport.py rename to tests_py36/test_websocket_online.py index 3cf84cc2..bb503bdb 100644 --- a/tests_py36/test_websockets_transport.py +++ b/tests_py36/test_websocket_online.py @@ -1,36 +1,37 @@ import logging - -logging.basicConfig(level=logging.INFO) +import asyncio +import pytest +import sys from gql import gql, AsyncClient from gql.transport.websockets import WebsocketsTransport from graphql.execution import ExecutionResult from typing import Dict -import asyncio -import pytest -import sys +logging.basicConfig(level=logging.INFO) + @pytest.mark.asyncio async def test_websocket_simple_query(): # Get Websockets transport sample_transport = WebsocketsTransport( - url='wss://countries.trevorblades.com/graphql', - ssl=True + url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client async with AsyncClient(transport=sample_transport) as client: - query = gql(''' + query = gql( + """ query getContinents { continents { code name } } - ''') + """ + ) # Fetch schema await client.fetch_schema() @@ -40,17 +41,18 @@ async def test_websocket_simple_query(): # Verify result assert isinstance(result, ExecutionResult) - assert result.errors == None + assert result.errors is None assert isinstance(result.data, Dict) - continents = result.data['continents'] + continents = result.data["continents"] africa = continents[0] - assert africa['code'] == 'AF' + assert africa["code"] == "AF" + + print(sys.version_info) - print (sys.version_info) @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @pytest.mark.asyncio @@ -58,54 +60,57 @@ async def test_websocket_two_queries_in_parallel_using_two_tasks(): # Get Websockets transport sample_transport = WebsocketsTransport( - url='wss://countries.trevorblades.com/graphql', - ssl=True + url="wss://countries.trevorblades.com/graphql", ssl=True ) # Instanciate client async with AsyncClient(transport=sample_transport) as client: - query1 = gql(''' + query1 = gql( + """ query getContinents { continents { code } } - ''') + """ + ) - query2 = gql(''' + query2 = gql( + """ query getContinents { continents { name } } - ''') + """ + ) async def query_task1(): result = await client.execute(query1) assert isinstance(result, ExecutionResult) - assert result.errors == None + assert result.errors is None assert isinstance(result.data, Dict) - continents = result.data['continents'] + continents = result.data["continents"] africa = continents[0] - assert africa['code'] == 'AF' + assert africa["code"] == "AF" async def query_task2(): result = await client.execute(query2) assert isinstance(result, ExecutionResult) - assert result.errors == None + assert result.errors is None assert isinstance(result.data, Dict) - continents = result.data['continents'] + continents = result.data["continents"] africa = continents[0] - assert africa['name'] == 'Africa' + assert africa["name"] == "Africa" task1 = asyncio.create_task(query_task1()) task2 = asyncio.create_task(query_task2()) diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py new file mode 100644 index 00000000..9d448e21 --- /dev/null +++ b/tests_py36/test_websocket_query.py @@ -0,0 +1,165 @@ +import asyncio +import pytest +import websockets + +from .websocket_fixtures import server, client_and_server, TestServer +from graphql.execution import ExecutionResult +from gql.transport.websockets import WebsocketsTransport +from gql import gql, AsyncClient +from typing import Dict + + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' + '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},{{"code":"AS","name":"Asia"}},' + '{{"code":"EU","name":"Europe"}},{{"code":"NA","name":"North America"}},{{"code":"OC","name":"Oceania"}},' + '{{"code":"SA","name":"South America"}}]}}}}}}' +) + + +async def server1(ws, path): + await TestServer.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(query1_server_answer.format(query_id=1)) + await TestServer.send_complete(ws, 1) + await TestServer.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1], indirect=True) +async def test_websocket_starting_client_in_context_manager(server): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + async with AsyncClient(transport=sample_transport) as client: + + assert isinstance( + sample_transport.websocket, websockets.client.WebSocketClientProtocol + ) + + query1 = gql(query1_str) + + result = await client.execute(query1) + + assert isinstance(result, ExecutionResult) + + print("Client received: " + str(result.data)) + + # Verify result + assert result.errors is None + assert isinstance(result.data, Dict) + + continents = result.data["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert sample_transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_websocket_simple_query(client_and_server, query_str): + + client, server = client_and_server + + query = gql(query_str) + + result = await client.execute(query) + + print("Client received: " + str(result.data)) + + +async def server1_two_queries_in_series(ws, path): + await TestServer.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(query1_server_answer.format(query_id=1)) + await TestServer.send_complete(ws, 1) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(query1_server_answer.format(query_id=2)) + await TestServer.send_complete(ws, 2) + await TestServer.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_two_queries_in_series], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_websocket_two_queries_in_series(client_and_server, query_str): + + client, server = client_and_server + + query = gql(query_str) + + result1 = await client.execute(query) + + print("Query1 received: " + str(result1.data)) + + result2 = await client.execute(query) + + print("Query2 received: " + str(result2.data)) + + assert str(result1.data) == str(result2.data) + + +async def server1_two_queries_in_parallel(ws, path): + await TestServer.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(query1_server_answer.format(query_id=1)) + await ws.send(query1_server_answer.format(query_id=2)) + await TestServer.send_complete(ws, 1) + await TestServer.send_complete(ws, 2) + await TestServer.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_websocket_two_queries_in_parallel(client_and_server, query_str): + + client, server = client_and_server + + query = gql(query_str) + + result1 = None + result2 = None + + async def task1_coro(): + nonlocal result1 + result1 = await client.execute(query) + + async def task2_coro(): + nonlocal result2 + result2 = await client.execute(query) + + task1 = asyncio.ensure_future(task1_coro()) + task2 = asyncio.ensure_future(task2_coro()) + + await asyncio.gather(task1, task2) + + print("Query1 received: " + str(result1.data)) + print("Query2 received: " + str(result2.data)) + + assert str(result1.data) == str(result2.data) diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py new file mode 100644 index 00000000..c404784e --- /dev/null +++ b/tests_py36/test_websocket_subscription.py @@ -0,0 +1,291 @@ +import asyncio +import pytest +import json +import websockets + +from parse import search +from .websocket_fixtures import MS, server, client_and_server, TestServer +from graphql.execution import ExecutionResult +from gql import gql + + +countdown_server_answer = ( + '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' +) + + +async def server_countdown(ws, path): + try: + await TestServer.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "start" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + print(f"Countdown started from: {count}") + + async def counting_coro(): + for number in range(count, -1, -1): + await ws.send( + countdown_server_answer.format(query_id=query_id, number=number) + ) + await asyncio.sleep(2 * MS) + + counting_task = asyncio.ensure_future(counting_coro()) + + async def stopping_coro(): + nonlocal counting_task + while True: + + result = await ws.recv() + json_result = json.loads(result) + + if json_result["type"] == "stop" and json_result["id"] == str(query_id): + print("Cancelling counting task now") + counting_task.cancel() + + stopping_task = asyncio.ensure_future(stopping_coro()) + + try: + await counting_task + except asyncio.CancelledError: + print("Now counting task is cancelled") + + stopping_task.cancel() + + try: + await stopping_task + except asyncio.CancelledError: + print("Now stopping task is cancelled") + + await TestServer.send_complete(ws, query_id) + await TestServer.wait_connection_terminate(ws) + except websockets.exceptions.ConnectionClosedOK: + pass + finally: + await ws.wait_closed() + + +countdown_subscription_str = """ + subscription {{ + countdown (count: {count}) {{ + number + }} + }} +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription(client_and_server, subscription_str): + + client, server = client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in client.subscribe(subscription): + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_break(client_and_server, subscription_str): + + client, server = client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in client.subscribe(subscription): + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + + if count <= 5: + break + + count -= 1 + + #await asyncio.sleep(1) + + assert count == 5 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_task_cancel(client_and_server, subscription_str): + + client, server = client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in client.subscribe(subscription): + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def cancel_task_coro(): + nonlocal task + + await asyncio.sleep(11 * MS) + + task.cancel() + + cancel_task = asyncio.ensure_future(cancel_task_coro()) + + await asyncio.gather(task, cancel_task) + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_close_transport( + client_and_server, subscription_str +): + + client, server = client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async def task_coro(): + nonlocal count + async for result in client.subscribe(subscription): + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + task = asyncio.ensure_future(task_coro()) + + async def close_transport_task_coro(): + nonlocal task + + await asyncio.sleep(11 * MS) + + await client.transport.close() + + close_transport_task = asyncio.ensure_future(close_transport_task_coro()) + + # with pytest.raises(websockets.exceptions.ConnectionClosedOK): + await asyncio.gather(task, close_transport_task) + + assert count > 0 + + +async def server_countdown_close_connection_in_middle(ws, path): + await TestServer.send_connection_ack(ws) + + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "start" + payload = json_result["payload"] + query = payload["query"] + query_id = json_result["id"] + + count_found = search("count: {:d}", query) + count = count_found[0] + stopping_before = count // 2 + print(f"Countdown started from: {count}, stopping server before {stopping_before}") + for number in range(count, stopping_before, -1): + await ws.send(countdown_server_answer.format(query_id=query_id, number=number)) + await asyncio.sleep(2 * MS) + + print("Closing server while subscription is still running now") + await ws.close() + await ws.wait_closed() + print("Server is now closed") + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_countdown_close_connection_in_middle], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_server_connection_closed( + client_and_server, subscription_str +): + + client, server = client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + with pytest.raises(websockets.exceptions.ConnectionClosedOK): + + async for result in client.subscribe(subscription): + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + assert count > 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_countdown], indirect=True +) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_slow_consumer( + client_and_server, subscription_str +): + + client, server = client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in client.subscribe(subscription): + await asyncio.sleep(10 * MS) + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + + count -= 1 + + assert count == -1 + diff --git a/tests_py36/websocket_fixtures.py b/tests_py36/websocket_fixtures.py new file mode 100644 index 00000000..7c24bfca --- /dev/null +++ b/tests_py36/websocket_fixtures.py @@ -0,0 +1,111 @@ +import websockets +import asyncio +import json +import os +import pytest +import logging + +from gql.transport.websockets import WebsocketsTransport +from gql import AsyncClient + +# Adding debug logs to websocket tests +for name in ["websockets.server", "gql.transport.websockets"]: + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + if len(logger.handlers) < 1: + logger.addHandler(logging.StreamHandler()) + +# Unit for timeouts. May be increased on slow machines by setting the +# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. +MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) + + +class TestServer: + """ + Class used to generate a websocket server on localhost on a free port + + Will allow us to test our client by simulating different correct and incorrect server responses + """ + + async def start(self, handler): + + print("Starting server") + + # Start a server with a random open port + self.start_server = websockets.server.serve(handler, "localhost", 0) + + # Wait that the server is started + self.server = await self.start_server + + # Get hostname and port + hostname, port = self.server.sockets[0].getsockname() + + self.hostname = hostname + self.port = port + + print(f"Server started on port {port}") + + async def stop(self): + print("Stopping server") + + self.server.close() + try: + await asyncio.wait_for(self.server.wait_closed(), timeout=1) + except asyncio.TimeoutError: # pragma: no cover + assert False, "Server failed to stop" + + print("Server stopped\n\n\n") + + @staticmethod + async def send_complete(ws, query_id): + await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}') + + @staticmethod + async def send_connection_ack(ws): + + # Line return for easy debugging + print("") + + # Wait for init + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "connection_init" + + # Send ack + await ws.send('{"type":"connection_ack"}') + + @staticmethod + async def wait_connection_terminate(ws): + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "connection_terminate" + + +@pytest.fixture +async def server(request): + try: + test_server = TestServer() + + # Starting the server with the fixture param as the handler function + await test_server.start(request.param) + + yield test_server + except Exception as e: + print("Exception received in server fixture: " + str(e)) + finally: + await test_server.stop() + + +@pytest.fixture +async def client_and_server(server): + + # Generate transport to connect to the server fixture + path = "/graphql" + url = "ws://" + server.hostname + ":" + str(server.port) + path + sample_transport = WebsocketsTransport(url=url) + + async with AsyncClient(transport=sample_transport) as client: + + # Yield both client and server + yield (client, server) From ece46e89ab5fbe33930a47b793520ece01910f5c Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 9 Apr 2020 00:32:58 +0200 Subject: [PATCH 17/46] Fix asyncio.wait_for for pypy3 ? --- gql/transport/websockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index ed15f12c..1e09e235 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -407,7 +407,7 @@ async def close(self) -> None: listener.send_stop = False # Wait that there is no more listeners (we received 'complete' for all queries) - await asyncio.wait_for(self._no_more_listeners.wait(), timeout=5) + await asyncio.wait_for(self._no_more_listeners.wait(), 10) await self._send_connection_terminate_message() From 6c90378a451ced9a152d02dacca160816b46e386 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 9 Apr 2020 00:41:56 +0200 Subject: [PATCH 18/46] Fix asyncio.wait_for for pypy3 ? - try 2 --- gql/transport/websockets.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 1e09e235..50361ddf 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -407,7 +407,10 @@ async def close(self) -> None: listener.send_stop = False # Wait that there is no more listeners (we received 'complete' for all queries) - await asyncio.wait_for(self._no_more_listeners.wait(), 10) + try: + await asyncio.wait_for(self._no_more_listeners.wait(), 10) + except asyncio.TimeoutError: + pass await self._send_connection_terminate_message() From e0cf9bb1013bd1ba103495c6fb9aa359395db43d Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 13 Apr 2020 16:45:32 +0200 Subject: [PATCH 19/46] Better managenement of edge cases and 100% websocket transport coverage Adding exceptions: - TransportProtocolError - TransportQueryError - TransportServerError - TransportClosed Now 100% code coverage for transport/websockets.py Improvement of 'server' test fixture to simplify tests (list only answers) --- gql/transport/async_transport.py | 8 +- gql/transport/exceptions.py | 29 +++ gql/transport/transport.py | 2 +- gql/transport/websockets.py | 231 +++++++++++++-------- tests_py36/test_websocket_exceptions.py | 232 ++++++++++++++++++++++ tests_py36/test_websocket_online.py | 145 +++++++++++++- tests_py36/test_websocket_query.py | 107 +++++++--- tests_py36/test_websocket_subscription.py | 50 ++++- tests_py36/websocket_fixtures.py | 50 ++++- 9 files changed, 733 insertions(+), 121 deletions(-) create mode 100644 gql/transport/exceptions.py create mode 100644 tests_py36/test_websocket_exceptions.py diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 2d4107cf..81aada2d 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -15,7 +15,7 @@ async def connect(self): """ raise NotImplementedError( "Any AsyncTransport subclass must implement execute method" - ) + ) # pragma: no cover @abc.abstractmethod async def close(self): @@ -23,7 +23,7 @@ async def close(self): """ raise NotImplementedError( "Any AsyncTransport subclass must implement execute method" - ) + ) # pragma: no cover @abc.abstractmethod async def execute( @@ -36,7 +36,7 @@ async def execute( """ raise NotImplementedError( "Any AsyncTransport subclass must implement execute method" - ) + ) # pragma: no cover @abc.abstractmethod def subscribe( @@ -53,4 +53,4 @@ def subscribe( """ raise NotImplementedError( "Any AsyncTransport subclass must implement execute method" - ) + ) # pragma: no cover diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py new file mode 100644 index 00000000..d1ab4e9b --- /dev/null +++ b/gql/transport/exceptions.py @@ -0,0 +1,29 @@ +class TransportError(Exception): + pass + + +class TransportProtocolError(TransportError): + """ An answer received from the server does not correspond to the transport protocol""" + + +class TransportServerError(TransportError): + """ The server returned a global error + + This exception will close the transport connection + """ + + +class TransportQueryError(Exception): + """ The server returned an error for a specific query + + This exception should not close the transport connection + """ + + def __init__(self, msg, query_id=None): + super().__init__(msg) + self.query_id = query_id + + +class TransportClosed(TransportError): + """ Exception generated when the client is trying to use the transport + while the transport was previously closed """ diff --git a/gql/transport/transport.py b/gql/transport/transport.py index 89c636a0..476f8be2 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -19,4 +19,4 @@ def execute(self, document): """ raise NotImplementedError( "Any Transport subclass must implement execute method" - ) + ) # pragma: no cover diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 50361ddf..0ae96274 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -4,7 +4,7 @@ from websockets.http import HeadersLike from websockets.typing import Data, Subprotocol from websockets.client import WebSocketClientProtocol -from websockets.exceptions import ConnectionClosedOK, ConnectionClosed +from websockets.exceptions import ConnectionClosed from ssl import SSLContext @@ -19,6 +19,12 @@ from graphql.language.printer import print_ast from .async_transport import AsyncTransport +from .exceptions import ( + TransportProtocolError, + TransportQueryError, + TransportServerError, + TransportClosed, +) log = logging.getLogger(__name__) @@ -98,45 +104,48 @@ def __init__( self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} self._is_closing: bool = False + + self.listen_loop: Optional[asyncio.Future] = None + self._no_more_listeners: asyncio.Event = asyncio.Event() + self._no_more_listeners.set() + + self.close_exception: Optional[Exception] = None async def _send(self, message: str) -> None: """Send the provided message to the websocket connection and log the message """ if not self.websocket: - raise Exception("Transport is not connected") + raise TransportClosed( + "Transport is not connected" + ) from self.close_exception try: await self.websocket.send(message) log.info(">>> %s", message) except (ConnectionClosed) as e: - await self._close_with_exception(e) + await self._close(e, clean_close=False) raise e async def _receive(self) -> str: """Wait the next message from the websocket connection and log the answer """ - answer: Optional[str] = None + # We should always have an active websocket connection here + assert self.websocket is not None - if not self.websocket: - raise Exception("Transport is not connected") + # Wait for the next websocket frame. Can raise ConnectionClosed + data: Data = await self.websocket.recv() - try: - data: Data = await self.websocket.recv() - - # websocket.recv() can return either str or bytes - # In our case, we should receive only str here - if not isinstance(data, str): - raise Exception("Binary data received in the websocket") + # websocket.recv() can return either str or bytes + # In our case, we should receive only str here + if not isinstance(data, str): + raise TransportProtocolError("Binary data received in the websocket") - answer = data + answer: str = data - log.info("<<< %s", answer) - except ConnectionClosed as e: - await self._close_with_exception(e) - raise e + log.info("<<< %s", answer) return answer @@ -153,7 +162,9 @@ async def _send_init_message_and_wait_ack(self) -> None: answer_type, answer_id, execution_result = self._parse_answer(init_answer) if answer_type != "connection_ack": - raise Exception("Websocket server did not return a connection ack") + raise TransportProtocolError( + "Websocket server did not return a connection ack" + ) async def _send_stop_message(self, query_id: int) -> None: """Send a stop message to the provided websocket connection for the provided query_id @@ -225,29 +236,32 @@ def _parse_answer( try: json_answer = json.loads(answer) - if not isinstance(json_answer, dict): - raise ValueError - answer_type = str(json_answer.get("type")) if answer_type in ["data", "error", "complete"]: answer_id = int(str(json_answer.get("id"))) - if answer_type == "data": - result = json_answer.get("payload") + if answer_type == "data" or answer_type == "error": - if not isinstance(result, Dict): - raise ValueError + payload = json_answer.get("payload") - if "errors" not in result and "data" not in result: - raise ValueError + if not isinstance(payload, dict): + raise ValueError("payload is not a dict") - execution_result = ExecutionResult( - errors=result.get("errors"), data=result.get("data") - ) + if answer_type == "data": + + if "errors" not in payload and "data" not in payload: + raise ValueError( + "payload does not contain 'data' or 'errors' fields" + ) + + execution_result = ExecutionResult( + errors=payload.get("errors"), data=payload.get("data") + ) - elif answer_type == "error": - raise Exception("Websocket server error") + elif answer_type == "error": + + raise TransportQueryError(str(payload), query_id=answer_id) elif answer_type == "ka": # KeepAlive message @@ -255,34 +269,66 @@ def _parse_answer( elif answer_type == "connection_ack": pass elif answer_type == "connection_error": - raise Exception("Websocket Connection Error") + error_payload = json_answer.get("payload") + raise TransportServerError(f"Server error: '{repr(error_payload)}'") else: raise ValueError - except ValueError: - raise Exception("Websocket server did not return a GraphQL result") + except ValueError as e: + raise TransportProtocolError( + "Server did not return a GraphQL result" + ) from e return (answer_type, answer_id, execution_result) async def _answer_loop(self) -> None: - while True: + try: + while True: - # Wait the next answer from the websocket server - try: - answer = await self._receive() - except ConnectionClosed: - return + # Wait the next answer from the websocket server + try: + answer = await self._receive() + except (ConnectionClosed, TransportProtocolError) as e: + await self._close(e, clean_close=False) + break - # Parse the answer - answer_type, answer_id, execution_result = self._parse_answer(answer) + # Parse the answer + try: + answer_type, answer_id, execution_result = self._parse_answer( + answer + ) + except TransportQueryError as e: + # Received an exception for a specific query + # ==> Add an exception to this query queue + # The exception is raised for this specific query but the transport is not closed + try: + await self.listeners[e.query_id].set_exception(e) + except KeyError: + # Do nothing if no one is listening to this query_id + pass + + continue + + except (TransportServerError, TransportProtocolError) as e: + # Received a global exception for this transport + # ==> close the transport + # The exception will be raised for all current queries + await self._close(e, clean_close=False) + break - # Continue if no listener exists for this id - if answer_id not in self.listeners: - continue + try: + # Put the answer in the queue + if answer_id is not None: + await self.listeners[answer_id].put( + (answer_type, execution_result) + ) + except KeyError: + # Do nothing if no one is listening to this query_id + pass - # Put the answer in the queue - await self.listeners[answer_id].put((answer_type, execution_result)) + finally: + log.debug("Exiting _answer_loop()") async def subscribe( self, @@ -307,11 +353,15 @@ async def subscribe( listener = ListenerQueue(query_id, send_stop=(send_stop is True)) self.listeners[query_id] = listener + # We will need to wait at close for this query to clean properly + self._no_more_listeners.clear() + try: # Loop over the received answers while True: # Wait for the answer from the queue of this query_id + # This can raise a TransportError exception or a ConnectionClosed exception answer_type, execution_result = await listener.get() # If the received answer contains data, @@ -348,10 +398,18 @@ async def execute( The result is sent as an ExecutionResult object """ + first_result = None + async for result in self.subscribe( document, variable_values, operation_name, send_stop=False ): first_result = result + break + + if first_result is None: + raise TransportQueryError( + "Query completed without any answer received from the server" + ) return first_result @@ -378,65 +436,70 @@ async def connect(self) -> None: subprotocols=[GRAPHQLWS_SUBPROTOCOL], ) - # Reset the next query id self.next_query_id = 1 + self.close_exception = None # Send the init message and wait for the ack from the server - await self._send_init_message_and_wait_ack() + try: + await self._send_init_message_and_wait_ack() + except ConnectionClosed as e: + raise e + except TransportProtocolError as e: + await self._close(e, clean_close=False) + raise e # Create a task to listen to the incoming websocket messages self.listen_loop = asyncio.ensure_future(self._answer_loop()) - async def close(self) -> None: + async def _clean_close(self, e: Exception) -> None: """Coroutine which will: - - send stop messages for each active query to the server + - send stop messages for each active subscription to the server - send the connection terminate message - - close the websocket connection - - - send the exceptions to all current listeners - - remove the listen_loop task """ - if self.websocket and not self._is_closing: - - # Send stop message for all current queries - for query_id, listener in self.listeners.items(): - if listener.send_stop: - await self._send_stop_message(query_id) - listener.send_stop = False + # Send stop message for all current queries + for query_id, listener in self.listeners.items(): - # Wait that there is no more listeners (we received 'complete' for all queries) - try: - await asyncio.wait_for(self._no_more_listeners.wait(), 10) - except asyncio.TimeoutError: - pass - - await self._send_connection_terminate_message() + if listener.send_stop: + await self._send_stop_message(query_id) + listener.send_stop = False - await self.websocket.close() + # Wait that there is no more listeners (we received 'complete' for all queries) + try: + await asyncio.wait_for(self._no_more_listeners.wait(), 10) + except asyncio.TimeoutError: # pragma: no cover + pass - await self._close_with_exception( - ConnectionClosedOK( - code=1000, reason="Websocket GraphQL transport closed by user" - ) - ) + await self._send_connection_terminate_message() - async def _close_with_exception(self, e: Exception) -> None: - """Coroutine called to close the transport if the underlaying websocket transport - has closed itself + async def _close(self, e: Exception, clean_close: bool = True) -> None: + """Coroutine which will: - - send the exceptions to all current listeners - - remove the listen_loop task + - do a clean_close if possible: + - send stop messages for each active query to the server + - send the connection terminate message + - close the websocket connection + - send the exception to all the remaining listeners """ if self.websocket and not self._is_closing: self._is_closing = True - for query_id, listener in self.listeners.items(): + # Saving exception to raise it later if trying to use the transport after it has closed + self.close_exception = e + if clean_close: + await self._clean_close(e) + + # Send an exception to all remaining listeners + for query_id, listener in self.listeners.items(): await listener.set_exception(e) + await self.websocket.close() + self.websocket = None - self.listen_loop.cancel() + async def close(self) -> None: + + await self._close(TransportClosed("Websocket GraphQL transport closed by user")) diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py new file mode 100644 index 00000000..87e983bb --- /dev/null +++ b/tests_py36/test_websocket_exceptions.py @@ -0,0 +1,232 @@ +import asyncio +import pytest +import json +import websockets +import types +import gql + +from parse import search +from .websocket_fixtures import MS, server, client_and_server, TestServer +from graphql.execution import ExecutionResult +from gql import gql, AsyncClient +from gql.transport.websockets import WebsocketsTransport +from gql.transport.exceptions import ( + TransportProtocolError, + TransportQueryError, + TransportServerError, + TransportClosed, +) + + +invalid_query_str = """ + query getContinents { + continents { + code + bloh + } + } +""" + +invalid_query1_server_answer = ( + '{{"type":"data","id":"{query_id}",' + '"payload":{{"errors":[{{"message":"Cannot query field \\"bloh\\" on type \\"Continent\\".",' + '"locations":[{{"line":4,"column":5}}],"extensions":{{"code":"INTERNAL_SERVER_ERROR"}}}}]}}}}' +) + +invalid_query1_server = [ + invalid_query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [invalid_query1_server,], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_websocket_invalid_query(client_and_server, query_str): + + client, server = client_and_server + + query = gql(query_str) + + result = await client.execute(query) + + print("Client received: " + str(result.data)) + + assert isinstance(result, ExecutionResult) + + print(f"result = {repr(result.data)}, {repr(result.errors)}") + + assert result.data is None + assert result.errors is not None + + +connection_error_server_answer = ( + '{"type":"connection_error","id":null,' + '"payload":{"message":"Unexpected token Q in JSON at position 0"}}' +) + + +async def server_connection_error(ws, path): + await TestServer.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(connection_error_server_answer) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_connection_error], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_websocket_sending_invalid_data(client_and_server, query_str): + + client, server = client_and_server + + invalid_data = "QSDF" + print(f">>> {invalid_data}") + await client.transport.websocket.send(invalid_data) + + await asyncio.sleep(2 * MS) + + +invalid_payload_server_answer = ( + '{"type":"error","id":"1","payload":{"message":"Must provide document"}}' +) + + +async def server_invalid_payload(ws, path): + await TestServer.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(invalid_payload_server_answer) + await TestServer.wait_connection_terminate(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_websocket_sending_invalid_payload(client_and_server, query_str): + + client, server = client_and_server + + # Monkey patching the _send_query method to send an invalid payload + + async def monkey_patch_send_query( + self, document, variable_values=None, operation_name=None, + ) -> int: + query_id = self.next_query_id + self.next_query_id += 1 + + query_str = json.dumps( + {"id": str(query_id), "type": "start", "payload": "BLAHBLAH",} + ) + + await self._send(query_str) + return query_id + + client.transport._send_query = types.MethodType( + monkey_patch_send_query, client.transport + ) + + query = gql(query_str) + + with pytest.raises(TransportQueryError): + result = await client.execute(query) + + +not_json_answer = ["BLAHBLAH"] +missing_type_answer = ["{}"] +missing_id_answer_1 = ['{"type": "data"}'] +missing_id_answer_2 = ['{"type": "error"}'] +missing_id_answer_3 = ['{"type": "complete"}'] +data_without_payload = ['{"type": "data", "id":"1"}'] +error_without_payload = ['{"type": "error", "id":"1"}'] +payload_is_not_a_dict = ['{"type": "data", "id":"1", "payload": "BLAH"}'] +empty_payload = ['{"type": "data", "id":"1", "payload": {}}'] +sending_bytes = [b"\x01\x02\x03"] + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", + [ + not_json_answer, + missing_type_answer, + missing_id_answer_1, + missing_id_answer_2, + missing_id_answer_3, + data_without_payload, + error_without_payload, + payload_is_not_a_dict, + empty_payload, + sending_bytes, + ], + indirect=True, +) +async def test_websocket_transport_protocol_errors(client_and_server): + + client, server = client_and_server + + query = gql("query { hello }") + + with pytest.raises(TransportProtocolError): + result = await client.execute(query) + + +async def server_without_ack(ws, path): + # Sending something else than an ack + await TestServer.send_keepalive(ws) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_without_ack,], indirect=True) +async def test_websocket_server_does_not_ack(server): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + with pytest.raises(TransportProtocolError): + async with AsyncClient(transport=sample_transport) as client: + + pass + + +async def server_closing_directly(ws, path): + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_closing_directly,], indirect=True) +async def test_websocket_server_closing_directly(server): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + with pytest.raises(websockets.exceptions.ConnectionClosed): + async with AsyncClient(transport=sample_transport) as client: + + pass + + +async def server_closing_after_ack(ws, path): + await TestServer.send_connection_ack(ws) + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_closing_after_ack,], indirect=True) +async def test_websocket_server_closing_after_ack(client_and_server): + + client, server = client_and_server + + query = gql("query { hello }") + + with pytest.raises(websockets.exceptions.ConnectionClosed): + result = await client.execute(query) + + with pytest.raises(TransportClosed): + result = await client.execute(query) diff --git a/tests_py36/test_websocket_online.py b/tests_py36/test_websocket_online.py index bb503bdb..c13bf18a 100644 --- a/tests_py36/test_websocket_online.py +++ b/tests_py36/test_websocket_online.py @@ -5,8 +5,10 @@ from gql import gql, AsyncClient from gql.transport.websockets import WebsocketsTransport +from gql.transport.exceptions import TransportError from graphql.execution import ExecutionResult from typing import Dict +from .websocket_fixtures import MS logging.basicConfig(level=logging.INFO) @@ -51,7 +53,148 @@ async def test_websocket_simple_query(): assert africa["code"] == "AF" - print(sys.version_info) + +@pytest.mark.asyncio +async def test_websocket_invalid_query(): + + # Get Websockets transport + sample_transport = WebsocketsTransport( + url="wss://countries.trevorblades.com/graphql", ssl=True + ) + + # Instanciate client + async with AsyncClient(transport=sample_transport) as client: + + query = gql( + """ + query getContinents { + continents { + code + bloh + } + } + """ + ) + + # Execute query + result = await client.execute(query) + + # Verify result + assert isinstance(result, ExecutionResult) + + assert result.data is None + + print(f"result = {repr(result.data)}, {repr(result.errors)}") + assert result.errors is not None + + +@pytest.mark.asyncio +async def test_websocket_sending_invalid_data(): + + # Get Websockets transport + sample_transport = WebsocketsTransport( + url="wss://countries.trevorblades.com/graphql", ssl=True + ) + + # Instanciate client + async with AsyncClient(transport=sample_transport) as client: + + query = gql( + """ + query getContinents { + continents { + code + } + } + """ + ) + + # Execute query + result = await client.execute(query) + + # Verify result + assert isinstance(result, ExecutionResult) + + print(f"result = {repr(result.data)}, {repr(result.errors)}") + + assert result.errors is None + + invalid_data = "QSDF" + print(f">>> {invalid_data}") + await sample_transport.websocket.send(invalid_data) + + await asyncio.sleep(2) + + +@pytest.mark.asyncio +async def test_websocket_sending_invalid_payload(): + + # Get Websockets transport + sample_transport = WebsocketsTransport( + url="wss://countries.trevorblades.com/graphql", ssl=True + ) + + # Instanciate client + async with AsyncClient(transport=sample_transport) as client: + + invalid_payload = '{"id": "1", "type": "start", "payload": "BLAHBLAH"}' + + print(f">>> {invalid_payload}") + await sample_transport.websocket.send(invalid_payload) + + await asyncio.sleep(2) + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +@pytest.mark.asyncio +async def test_websocket_sending_invalid_data_while_other_query_is_running(): + + # Get Websockets transport + sample_transport = WebsocketsTransport( + url="wss://countries.trevorblades.com/graphql", ssl=True + ) + + # Instanciate client + async with AsyncClient(transport=sample_transport) as client: + + query = gql( + """ + query getContinents { + continents { + code + } + } + """ + ) + + async def query_task1(): + await asyncio.sleep(2 * MS) + + with pytest.raises(TransportError): + result = await client.execute(query) + + assert isinstance(result, ExecutionResult) + assert result.errors is None + + assert isinstance(result.data, Dict) + + continents = result.data["continents"] + + africa = continents[0] + assert africa["code"] == "AF" + + async def query_task2(): + + invalid_data = "QSDF" + print(f">>> {invalid_data}") + await sample_transport.websocket.send(invalid_data) + + task1 = asyncio.create_task(query_task1()) + task2 = asyncio.create_task(query_task2()) + + # await task1 + # await task2 + await asyncio.gather(task1, task2) @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 9d448e21..db24e1d4 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -2,9 +2,10 @@ import pytest import websockets -from .websocket_fixtures import server, client_and_server, TestServer +from .websocket_fixtures import MS, server, client_and_server, TestServer from graphql.execution import ExecutionResult from gql.transport.websockets import WebsocketsTransport +from gql.transport.exceptions import TransportClosed, TransportQueryError from gql import gql, AsyncClient from typing import Dict @@ -25,19 +26,13 @@ '{{"code":"SA","name":"South America"}}]}}}}}}' ) - -async def server1(ws, path): - await TestServer.send_connection_ack(ws) - result = await ws.recv() - print(f"Server received: {result}") - await ws.send(query1_server_answer.format(query_id=1)) - await TestServer.send_complete(ws, 1) - await TestServer.wait_connection_terminate(ws) - await ws.wait_closed() +server1_answers = [ + query1_server_answer, +] @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1], indirect=True) +@pytest.mark.parametrize("server", [server1_answers,], indirect=True) async def test_websocket_starting_client_in_context_manager(server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" @@ -73,7 +68,7 @@ async def test_websocket_starting_client_in_context_manager(server): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1], indirect=True) +@pytest.mark.parametrize("server", [server1_answers,], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_simple_query(client_and_server, query_str): @@ -86,22 +81,14 @@ async def test_websocket_simple_query(client_and_server, query_str): print("Client received: " + str(result.data)) -async def server1_two_queries_in_series(ws, path): - await TestServer.send_connection_ack(ws) - result = await ws.recv() - print(f"Server received: {result}") - await ws.send(query1_server_answer.format(query_id=1)) - await TestServer.send_complete(ws, 1) - result = await ws.recv() - print(f"Server received: {result}") - await ws.send(query1_server_answer.format(query_id=2)) - await TestServer.send_complete(ws, 2) - await TestServer.wait_connection_terminate(ws) - await ws.wait_closed() +server1_two_answers_in_series = [ + query1_server_answer, + query1_server_answer, +] @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_two_queries_in_series], indirect=True) +@pytest.mark.parametrize("server", [server1_two_answers_in_series,], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_two_queries_in_series(client_and_server, query_str): @@ -163,3 +150,73 @@ async def task2_coro(): print("Query2 received: " + str(result2.data)) assert str(result1.data) == str(result2.data) + + +async def server_closing_while_we_are_doing_something_else(ws, path): + await TestServer.send_connection_ack(ws) + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(query1_server_answer.format(query_id=1)) + await TestServer.send_complete(ws, 1) + await asyncio.sleep(1 * MS) + + # Closing server after first query + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_closing_while_we_are_doing_something_else,], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_websocket_server_closing_after_first_query(client_and_server, query_str): + + client, server = client_and_server + + query = gql(query_str) + + # First query is working + result = await client.execute(query) + + assert isinstance(result, ExecutionResult) + assert result.data is not None + assert result.errors is None + + # Then we do other things + await asyncio.sleep(2 * MS) + await asyncio.sleep(2 * MS) + await asyncio.sleep(2 * MS) + + # Now the server is closed but we don't know it yet, we have to send a query + # to notice it and to receive the exception + with pytest.raises(TransportClosed): + result = await client.execute(query) + + +ignore_invalid_id_answers = [ + query1_server_answer, + '{"type":"complete","id": "55"}', + query1_server_answer, +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [ignore_invalid_id_answers,], indirect=True) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_websocket_ignore_invalid_id(client_and_server, query_str): + + client, server = client_and_server + + query = gql(query_str) + + # First query is working + result = await client.execute(query) + assert isinstance(result, ExecutionResult) + + # Second query gets no answer -> raises + with pytest.raises(TransportQueryError): + result = await client.execute(query) + + # Third query is working + result = await client.execute(query) + assert isinstance(result, ExecutionResult) diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index c404784e..4c3984a6 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -13,10 +13,15 @@ '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' ) +WITH_KEEPALIVE = False + async def server_countdown(ws, path): + global WITH_KEEPALIVE try: await TestServer.send_connection_ack(ws) + if WITH_KEEPALIVE: + await TestServer.send_keepalive(ws) result = await ws.recv() json_result = json.loads(result) @@ -49,7 +54,13 @@ async def stopping_coro(): print("Cancelling counting task now") counting_task.cancel() + async def keepalive_coro(): + while True: + await asyncio.sleep(5 * MS) + await TestServer.send_keepalive(ws) + stopping_task = asyncio.ensure_future(stopping_coro()) + keepalive_task = asyncio.ensure_future(keepalive_coro()) try: await counting_task @@ -63,6 +74,13 @@ async def stopping_coro(): except asyncio.CancelledError: print("Now stopping task is cancelled") + if WITH_KEEPALIVE: + keepalive_task.cancel() + try: + await keepalive_task + except asyncio.CancelledError: + print("Now keepalive task is cancelled") + await TestServer.send_complete(ws, query_id) await TestServer.wait_connection_terminate(ws) except websockets.exceptions.ConnectionClosedOK: @@ -125,8 +143,6 @@ async def test_websocket_subscription_break(client_and_server, subscription_str) count -= 1 - #await asyncio.sleep(1) - assert count == 5 @@ -263,9 +279,7 @@ async def test_websocket_subscription_server_connection_closed( @pytest.mark.asyncio -@pytest.mark.parametrize( - "server", [server_countdown], indirect=True -) +@pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_slow_consumer( client_and_server, subscription_str @@ -289,3 +303,29 @@ async def test_websocket_subscription_slow_consumer( assert count == -1 + +WITH_KEEPALIVE = True + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_websocket_subscription_with_keepalive( + client_and_server, subscription_str +): + + client, server = client_and_server + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + async for result in client.subscribe(subscription): + assert isinstance(result, ExecutionResult) + + number = result.data["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 diff --git a/tests_py36/websocket_fixtures.py b/tests_py36/websocket_fixtures.py index 7c24bfca..c00fc670 100644 --- a/tests_py36/websocket_fixtures.py +++ b/tests_py36/websocket_fixtures.py @@ -4,8 +4,10 @@ import os import pytest import logging +import types from gql.transport.websockets import WebsocketsTransport +from websockets.exceptions import ConnectionClosed from gql import AsyncClient # Adding debug logs to websocket tests @@ -18,6 +20,7 @@ # Unit for timeouts. May be increased on slow machines by setting the # WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. +# Copied from websockets source MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) @@ -61,6 +64,10 @@ async def stop(self): async def send_complete(ws, query_id): await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}') + @staticmethod + async def send_keepalive(ws): + await ws.send('{"type":"ka"}') + @staticmethod async def send_connection_ack(ws): @@ -84,11 +91,51 @@ async def wait_connection_terminate(ws): @pytest.fixture async def server(request): + """server is a fixture used to start a dummy server to test the client behaviour. + + It can take as argument either a handler function for the websocket server for complete control + OR an array of answers to be sent by the default server handler + """ + + if isinstance(request.param, types.FunctionType): + server_handler = request.param + + else: + answers = request.param + + async def default_server_handler(ws, path): + + try: + await TestServer.send_connection_ack(ws) + query_id = 1 + + for answer in answers: + result = await ws.recv() + print(f"Server received: {result}") + + if isinstance(answer, str) and "{query_id}" in answer: + answer_format_params = {} + answer_format_params["query_id"] = query_id + formatted_answer = answer.format(**answer_format_params) + else: + formatted_answer = answer + + await ws.send(formatted_answer) + await TestServer.send_complete(ws, query_id) + query_id += 1 + + await TestServer.wait_connection_terminate(ws) + await ws.wait_closed() + except ConnectionClosed: + pass + + server_handler = default_server_handler + try: test_server = TestServer() # Starting the server with the fixture param as the handler function - await test_server.start(request.param) + await test_server.start(server_handler) yield test_server except Exception as e: @@ -99,6 +146,7 @@ async def server(request): @pytest.fixture async def client_and_server(server): + """client_and_server is a helper fixture to start a server and a client connected to its port""" # Generate transport to connect to the server fixture path = "/graphql" From 705edd1c6a1ec8d31649f74d42276d556465d615 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 19 Apr 2020 15:54:55 +0200 Subject: [PATCH 20/46] Close is now done in a shielded task Added _wait_closed event and wait_close coroutine to wait for the close task to finish Add exception 'TransportAlreadyConnected' used when trying to connect to an already connected transport Added new tests to test multiple connections in series or in parallel Refactor: - rename _close to _fail - rename _answer_loop to _receive_data_loop - rename listen_loop to receive_data_task --- gql/transport/exceptions.py | 5 ++ gql/transport/websockets.py | 53 ++++++++++------ tests_py36/test_websocket_exceptions.py | 2 + tests_py36/test_websocket_query.py | 81 ++++++++++++++++++++++++- 4 files changed, 123 insertions(+), 18 deletions(-) diff --git a/gql/transport/exceptions.py b/gql/transport/exceptions.py index d1ab4e9b..9a00dbcc 100644 --- a/gql/transport/exceptions.py +++ b/gql/transport/exceptions.py @@ -27,3 +27,8 @@ def __init__(self, msg, query_id=None): class TransportClosed(TransportError): """ Exception generated when the client is trying to use the transport while the transport was previously closed """ + + +class TransportAlreadyConnected(TransportError): + """ Exception generated when the client is trying to connect to the transport + while the transport is already connected """ diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 0ae96274..01b76d71 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -24,6 +24,7 @@ TransportQueryError, TransportServerError, TransportClosed, + TransportAlreadyConnected, ) log = logging.getLogger(__name__) @@ -103,9 +104,12 @@ def __init__( self.websocket: Optional[WebSocketClientProtocol] = None self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} - self._is_closing: bool = False - self.listen_loop: Optional[asyncio.Future] = None + self.receive_data_task: Optional[asyncio.Future] = None + self.close_task: Optional[asyncio.Future] = None + + self._wait_closed: asyncio.Event = asyncio.Event() + self._wait_closed.set() self._no_more_listeners: asyncio.Event = asyncio.Event() self._no_more_listeners.set() @@ -125,7 +129,7 @@ async def _send(self, message: str) -> None: await self.websocket.send(message) log.info(">>> %s", message) except (ConnectionClosed) as e: - await self._close(e, clean_close=False) + await self._fail(e, clean_close=False) raise e async def _receive(self) -> str: @@ -281,7 +285,7 @@ def _parse_answer( return (answer_type, answer_id, execution_result) - async def _answer_loop(self) -> None: + async def _receive_data_loop(self) -> None: try: while True: @@ -290,7 +294,7 @@ async def _answer_loop(self) -> None: try: answer = await self._receive() except (ConnectionClosed, TransportProtocolError) as e: - await self._close(e, clean_close=False) + await self._fail(e, clean_close=False) break # Parse the answer @@ -314,7 +318,7 @@ async def _answer_loop(self) -> None: # Received a global exception for this transport # ==> close the transport # The exception will be raised for all current queries - await self._close(e, clean_close=False) + await self._fail(e, clean_close=False) break try: @@ -328,7 +332,7 @@ async def _answer_loop(self) -> None: pass finally: - log.debug("Exiting _answer_loop()") + log.debug("Exiting _receive_data_loop()") async def subscribe( self, @@ -365,12 +369,12 @@ async def subscribe( answer_type, execution_result = await listener.get() # If the received answer contains data, - # Then we will yield the results back as an ExecutionResult object + # Then we will yield the results back as an ExecutionResult object if execution_result is not None: yield execution_result # If we receive a 'complete' answer from the server, - # Then we will end this async generator output and disconnect from the server + # Then we will end this async generator output without errors elif answer_type == "complete": log.debug( f"Complete received for query {query_id} --> exit without error" @@ -438,6 +442,7 @@ async def connect(self) -> None: self.next_query_id = 1 self.close_exception = None + self._wait_closed.clear() # Send the init message and wait for the ack from the server try: @@ -445,11 +450,14 @@ async def connect(self) -> None: except ConnectionClosed as e: raise e except TransportProtocolError as e: - await self._close(e, clean_close=False) + await self._fail(e, clean_close=False) raise e # Create a task to listen to the incoming websocket messages - self.listen_loop = asyncio.ensure_future(self._answer_loop()) + self.receive_data_task = asyncio.ensure_future(self._receive_data_loop()) + + else: + raise TransportAlreadyConnected("Transport is already connected") async def _clean_close(self, e: Exception) -> None: """Coroutine which will: @@ -458,7 +466,7 @@ async def _clean_close(self, e: Exception) -> None: - send the connection terminate message """ - # Send stop message for all current queries + # Send 'stop' message for all current queries for query_id, listener in self.listeners.items(): if listener.send_stop: @@ -471,9 +479,10 @@ async def _clean_close(self, e: Exception) -> None: except asyncio.TimeoutError: # pragma: no cover pass + # Finally send the 'connection_terminate' message await self._send_connection_terminate_message() - async def _close(self, e: Exception, clean_close: bool = True) -> None: + async def _close_coro(self, e: Exception, clean_close: bool = True) -> None: """Coroutine which will: - do a clean_close if possible: @@ -482,9 +491,7 @@ async def _close(self, e: Exception, clean_close: bool = True) -> None: - close the websocket connection - send the exception to all the remaining listeners """ - if self.websocket and not self._is_closing: - - self._is_closing = True + if self.websocket: # Saving exception to raise it later if trying to use the transport after it has closed self.close_exception = e @@ -500,6 +507,18 @@ async def _close(self, e: Exception, clean_close: bool = True) -> None: self.websocket = None + self.close_task = None + self._wait_closed.set() + + async def _fail(self, e: Exception, clean_close: bool = True) -> None: + if self.close_task is None: + self.close_task = asyncio.shield( + asyncio.ensure_future(self._close_coro(e, clean_close=clean_close)) + ) + async def close(self) -> None: + await self._fail(TransportClosed("Websocket GraphQL transport closed by user")) + await self.wait_closed() - await self._close(TransportClosed("Websocket GraphQL transport closed by user")) + async def wait_closed(self) -> None: + await self._wait_closed.wait() diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index 87e983bb..f6f737d3 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -228,5 +228,7 @@ async def test_websocket_server_closing_after_ack(client_and_server): with pytest.raises(websockets.exceptions.ConnectionClosed): result = await client.execute(query) + await client.transport.wait_closed() + with pytest.raises(TransportClosed): result = await client.execute(query) diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index db24e1d4..9b8e0364 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -5,7 +5,11 @@ from .websocket_fixtures import MS, server, client_and_server, TestServer from graphql.execution import ExecutionResult from gql.transport.websockets import WebsocketsTransport -from gql.transport.exceptions import TransportClosed, TransportQueryError +from gql.transport.exceptions import ( + TransportClosed, + TransportQueryError, + TransportAlreadyConnected, +) from gql import gql, AsyncClient from typing import Dict @@ -220,3 +224,78 @@ async def test_websocket_ignore_invalid_id(client_and_server, query_str): # Third query is working result = await client.execute(query) assert isinstance(result, ExecutionResult) + + +async def assert_client_is_working(client): + query1 = gql(query1_str) + + result = await client.execute(query1) + + assert isinstance(result, ExecutionResult) + + print("Client received: " + str(result.data)) + + # Verify result + assert result.errors is None + assert isinstance(result.data, Dict) + + continents = result.data["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers,], indirect=True) +async def test_websocket_multiple_connections_in_series(server): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + async with AsyncClient(transport=sample_transport) as client: + await assert_client_is_working(client) + + # Check client is disconnect here + assert sample_transport.websocket is None + + async with AsyncClient(transport=sample_transport) as client: + await assert_client_is_working(client) + + # Check client is disconnect here + assert sample_transport.websocket is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers,], indirect=True) +async def test_websocket_multiple_connections_in_parallel(server): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + async def task_coro(): + sample_transport = WebsocketsTransport(url=url) + async with AsyncClient(transport=sample_transport) as client: + await assert_client_is_working(client) + + task1 = asyncio.ensure_future(task_coro()) + task2 = asyncio.ensure_future(task_coro()) + + await asyncio.gather(task1, task2) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers,], indirect=True) +async def test_websocket_trying_to_connect_to_already_connected_transport(server): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + async with AsyncClient(transport=sample_transport) as client: + await assert_client_is_working(client) + + with pytest.raises(TransportAlreadyConnected): + async with AsyncClient(transport=sample_transport) as client2: + pass From a329324da9b31bbc61bd5230c34a051ab7b259c7 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 26 Apr 2020 15:23:36 +0200 Subject: [PATCH 21/46] TESTS_PY36 add tests for async client validation Testing validation using schema, type_def or introspection Fixing some flake8/black styling issues in tests_py36 --- tests_py36/schema.py | 40 +++-- tests_py36/test_async_client_validation.py | 170 +++++++++++++++++++++ tests_py36/test_query.py | 34 ++--- tests_py36/test_websocket_exceptions.py | 27 ++-- tests_py36/test_websocket_online.py | 2 +- tests_py36/test_websocket_query.py | 18 +-- tests_py36/test_websocket_subscription.py | 1 - 7 files changed, 237 insertions(+), 55 deletions(-) create mode 100644 tests_py36/test_async_client_validation.py diff --git a/tests_py36/schema.py b/tests_py36/schema.py index d3e1d272..de29d50a 100644 --- a/tests_py36/schema.py +++ b/tests_py36/schema.py @@ -1,28 +1,46 @@ -from graphql import GraphQLField, GraphQLArgument, GraphQLObjectType, GraphQLSchema +from graphql import ( + graphql, + print_schema, + GraphQLField, + GraphQLArgument, + GraphQLObjectType, + GraphQLSchema, +) +from graphql.utils.introspection_query import introspection_query -from tests.starwars.schema import reviewType, episodeEnum, queryType, mutationType, humanType, droidType, \ - reviewInputType +from tests.starwars.schema import ( + reviewType, + episodeEnum, + queryType, + mutationType, + humanType, + droidType, + reviewInputType, +) from tests_py36.fixtures import reviewAdded subscriptionType = GraphQLObjectType( - 'Subscription', + "Subscription", fields=lambda: { - 'reviewAdded': GraphQLField( + "reviewAdded": GraphQLField( reviewType, args={ - 'episode': GraphQLArgument( - description='Episode to review', - type=episodeEnum, + "episode": GraphQLArgument( + description="Episode to review", type=episodeEnum, ) }, - resolver=lambda root, info, **args: reviewAdded(args.get('episode')), + resolver=lambda root, info, **args: reviewAdded(args.get("episode")), ) - } + }, ) StarWarsSchema = GraphQLSchema( query=queryType, mutation=mutationType, subscription=subscriptionType, - types=[humanType, droidType, reviewType, reviewInputType] + types=[humanType, droidType, reviewType, reviewInputType], ) + +StarWarsIntrospection = graphql(StarWarsSchema, introspection_query).data # type: ignore + +StarWarsTypeDef = print_schema(StarWarsSchema) diff --git a/tests_py36/test_async_client_validation.py b/tests_py36/test_async_client_validation.py new file mode 100644 index 00000000..38c8ecce --- /dev/null +++ b/tests_py36/test_async_client_validation.py @@ -0,0 +1,170 @@ +import asyncio +import pytest +import json +import websockets +import graphql + +from .websocket_fixtures import MS, server, TestServer +from graphql.execution import ExecutionResult +from gql import gql, AsyncClient +from gql.transport.websockets import WebsocketsTransport +from tests_py36.schema import StarWarsSchema, StarWarsTypeDef, StarWarsIntrospection + + +starwars_expected_one = { + "stars": 3, + "commentary": "Was expecting more stuff", + "episode": "JEDI", +} + +starwars_expected_two = { + "stars": 5, + "commentary": "This is a great movie!", + "episode": "JEDI", +} + + +async def server_starwars(ws, path): + await TestServer.send_connection_ack(ws) + + try: + await ws.recv() + + reviews = [starwars_expected_one, starwars_expected_two] + + for review in reviews: + + data = '{{"type":"data","id":"1","payload":{{"data":{{"reviewAdded": {0}}}}}}}'.format( + json.dumps(review) + ) + await ws.send(data) + await asyncio.sleep(2 * MS) + + await TestServer.send_complete(ws, 1) + await TestServer.wait_connection_terminate(ws) + + except websockets.exceptions.ConnectionClosedOK: + pass + + print("Server is now closed") + + +starwars_subscription_str = """ + subscription ListenEpisodeReviews($ep: Episode!) { + reviewAdded(episode: $ep) { + stars, + commentary, + episode + } + } +""" + +starwars_invalid_subscription_str = """ + subscription ListenEpisodeReviews($ep: Episode!) { + reviewAdded(episode: $ep) { + not_valid_field, + stars, + commentary, + episode + } + } +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_starwars], indirect=True) +@pytest.mark.parametrize("subscription_str", [starwars_subscription_str]) +@pytest.mark.parametrize( + "client_params", + [ + {"schema": StarWarsSchema}, + {"introspection": StarWarsIntrospection}, + {"type_def": StarWarsTypeDef}, + ], +) +async def test_async_client_validation(server, subscription_str, client_params): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + + sample_transport = WebsocketsTransport(url=url) + + async with AsyncClient(transport=sample_transport, **client_params) as client: + + variable_values = {"ep": "JEDI"} + + subscription = gql(subscription_str) + + expected = [] + + async for result in client.subscribe( + subscription, variable_values=variable_values + ): + + assert isinstance(result, ExecutionResult) + assert result.errors is None + + review = result.data["reviewAdded"] + expected.append(review) + + assert "stars" in review + assert "commentary" in review + assert "episode" in review + + assert expected[0] == starwars_expected_one + assert expected[1] == starwars_expected_two + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_starwars], indirect=True) +@pytest.mark.parametrize("subscription_str", [starwars_invalid_subscription_str]) +@pytest.mark.parametrize( + "client_params", + [ + {"schema": StarWarsSchema}, + {"introspection": StarWarsIntrospection}, + {"type_def": StarWarsTypeDef}, + ], +) +async def test_async_client_validation_invalid_query( + server, subscription_str, client_params +): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + + sample_transport = WebsocketsTransport(url=url) + + async with AsyncClient(transport=sample_transport, **client_params) as client: + + variable_values = {"ep": "JEDI"} + + subscription = gql(subscription_str) + + with pytest.raises(graphql.error.base.GraphQLError): + async for result in client.subscribe( + subscription, variable_values=variable_values + ): + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_starwars], indirect=True) +@pytest.mark.parametrize("subscription_str", [starwars_invalid_subscription_str]) +@pytest.mark.parametrize( + "client_params", + [ + {"schema": StarWarsSchema, "introspection": StarWarsIntrospection}, + {"schema": StarWarsSchema, "type_def": StarWarsTypeDef}, + {"introspection": StarWarsIntrospection, "type_def": StarWarsTypeDef}, + ], +) +async def test_async_client_validation_different_schemas_parameters_forbidden( + server, subscription_str, client_params +): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + + sample_transport = WebsocketsTransport(url=url) + + with pytest.raises(AssertionError): + async with AsyncClient(transport=sample_transport, **client_params): + pass diff --git a/tests_py36/test_query.py b/tests_py36/test_query.py index 937ca2b0..aca7bfca 100644 --- a/tests_py36/test_query.py +++ b/tests_py36/test_query.py @@ -20,15 +20,15 @@ def __aiter__(self): async def __anext__(self): type_, val = await self.queue.get() - if type_ in ('E', 'C'): + if type_ in ("E", "C"): raise StopAsyncIteration() return val async def __aenter__(self): self.disposable = self.observable.subscribe( - on_next=lambda val: self.queue.put_nowait(('N', val)), - on_error=lambda exc: self.queue.put_nowait(('E', exc)), - on_completed=lambda: self.queue.put_nowait(('C', None)), + on_next=lambda val: self.queue.put_nowait(("N", val)), + on_error=lambda exc: self.queue.put_nowait(("E", exc)), + on_completed=lambda: self.queue.put_nowait(("C", None)), ) return self @@ -38,7 +38,8 @@ async def __aexit__(self, exc_type, exc_value, traceback): @pytest.mark.asyncio async def test_subscription_support(): - subs = gql(''' + subs = gql( + """ subscription ListenEpisodeReviews($ep: Episode!) { reviewAdded(episode: $ep) { stars, @@ -46,19 +47,18 @@ async def test_subscription_support(): episode } } - ''') - params = { - 'ep': 'JEDI' - } + """ + ) + params = {"ep": "JEDI"} expected_one = { - 'stars': 3, - 'commentary': 'Was expecting more stuff', - 'episode': 'JEDI' + "stars": 3, + "commentary": "Was expecting more stuff", + "episode": "JEDI", } expected_two = { - 'stars': 5, - 'commentary': 'This is a great movie!', - 'episode': 'JEDI' + "stars": 5, + "commentary": "This is a great movie!", + "episode": "JEDI", } # For asyncio, requires set return_promise=True as stated on the following comment # https://github.com/graphql-python/graphql-core/issues/63#issuecomment-568270864 @@ -68,12 +68,12 @@ async def test_subscription_support(): document_ast=subs, return_promise=True, variable_values=params, - executor=AsyncioExecutor(loop=loop) + executor=AsyncioExecutor(loop=loop), ) expected = [] async with ObservableAsyncIterable(execution_result) as oai: async for i in oai: review = i.to_dict() - expected.append(review['data']['reviewAdded']) + expected.append(review["data"]["reviewAdded"]) assert expected[0] == expected_one assert expected[1] == expected_two diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index f6f737d3..98ce2f7a 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -3,9 +3,7 @@ import json import websockets import types -import gql -from parse import search from .websocket_fixtures import MS, server, client_and_server, TestServer from graphql.execution import ExecutionResult from gql import gql, AsyncClient @@ -13,7 +11,6 @@ from gql.transport.exceptions import ( TransportProtocolError, TransportQueryError, - TransportServerError, TransportClosed, ) @@ -39,7 +36,7 @@ @pytest.mark.asyncio -@pytest.mark.parametrize("server", [invalid_query1_server,], indirect=True) +@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_websocket_invalid_query(client_and_server, query_str): @@ -117,7 +114,7 @@ async def monkey_patch_send_query( self.next_query_id += 1 query_str = json.dumps( - {"id": str(query_id), "type": "start", "payload": "BLAHBLAH",} + {"id": str(query_id), "type": "start", "payload": "BLAHBLAH"} ) await self._send(query_str) @@ -130,7 +127,7 @@ async def monkey_patch_send_query( query = gql(query_str) with pytest.raises(TransportQueryError): - result = await client.execute(query) + await client.execute(query) not_json_answer = ["BLAHBLAH"] @@ -169,7 +166,7 @@ async def test_websocket_transport_protocol_errors(client_and_server): query = gql("query { hello }") with pytest.raises(TransportProtocolError): - result = await client.execute(query) + await client.execute(query) async def server_without_ack(ws, path): @@ -179,7 +176,7 @@ async def server_without_ack(ws, path): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server_without_ack,], indirect=True) +@pytest.mark.parametrize("server", [server_without_ack], indirect=True) async def test_websocket_server_does_not_ack(server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" @@ -188,8 +185,7 @@ async def test_websocket_server_does_not_ack(server): sample_transport = WebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with AsyncClient(transport=sample_transport) as client: - + async with AsyncClient(transport=sample_transport): pass @@ -198,7 +194,7 @@ async def server_closing_directly(ws, path): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server_closing_directly,], indirect=True) +@pytest.mark.parametrize("server", [server_closing_directly], indirect=True) async def test_websocket_server_closing_directly(server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" @@ -207,8 +203,7 @@ async def test_websocket_server_closing_directly(server): sample_transport = WebsocketsTransport(url=url) with pytest.raises(websockets.exceptions.ConnectionClosed): - async with AsyncClient(transport=sample_transport) as client: - + async with AsyncClient(transport=sample_transport): pass @@ -218,7 +213,7 @@ async def server_closing_after_ack(ws, path): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server_closing_after_ack,], indirect=True) +@pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) async def test_websocket_server_closing_after_ack(client_and_server): client, server = client_and_server @@ -226,9 +221,9 @@ async def test_websocket_server_closing_after_ack(client_and_server): query = gql("query { hello }") with pytest.raises(websockets.exceptions.ConnectionClosed): - result = await client.execute(query) + await client.execute(query) await client.transport.wait_closed() with pytest.raises(TransportClosed): - result = await client.execute(query) + await client.execute(query) diff --git a/tests_py36/test_websocket_online.py b/tests_py36/test_websocket_online.py index c13bf18a..af652f09 100644 --- a/tests_py36/test_websocket_online.py +++ b/tests_py36/test_websocket_online.py @@ -135,7 +135,7 @@ async def test_websocket_sending_invalid_payload(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport): invalid_payload = '{"id": "1", "type": "start", "payload": "BLAHBLAH"}' diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 9b8e0364..28032268 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -36,7 +36,7 @@ @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers,], indirect=True) +@pytest.mark.parametrize("server", [server1_answers], indirect=True) async def test_websocket_starting_client_in_context_manager(server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" @@ -72,7 +72,7 @@ async def test_websocket_starting_client_in_context_manager(server): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers,], indirect=True) +@pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_simple_query(client_and_server, query_str): @@ -92,7 +92,7 @@ async def test_websocket_simple_query(client_and_server, query_str): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_two_answers_in_series,], indirect=True) +@pytest.mark.parametrize("server", [server1_two_answers_in_series], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_two_queries_in_series(client_and_server, query_str): @@ -170,7 +170,7 @@ async def server_closing_while_we_are_doing_something_else(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize( - "server", [server_closing_while_we_are_doing_something_else,], indirect=True + "server", [server_closing_while_we_are_doing_something_else], indirect=True ) @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_server_closing_after_first_query(client_and_server, query_str): @@ -205,7 +205,7 @@ async def test_websocket_server_closing_after_first_query(client_and_server, que @pytest.mark.asyncio -@pytest.mark.parametrize("server", [ignore_invalid_id_answers,], indirect=True) +@pytest.mark.parametrize("server", [ignore_invalid_id_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_ignore_invalid_id(client_and_server, query_str): @@ -246,7 +246,7 @@ async def assert_client_is_working(client): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers,], indirect=True) +@pytest.mark.parametrize("server", [server1_answers], indirect=True) async def test_websocket_multiple_connections_in_series(server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" @@ -268,7 +268,7 @@ async def test_websocket_multiple_connections_in_series(server): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers,], indirect=True) +@pytest.mark.parametrize("server", [server1_answers], indirect=True) async def test_websocket_multiple_connections_in_parallel(server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" @@ -286,7 +286,7 @@ async def task_coro(): @pytest.mark.asyncio -@pytest.mark.parametrize("server", [server1_answers,], indirect=True) +@pytest.mark.parametrize("server", [server1_answers], indirect=True) async def test_websocket_trying_to_connect_to_already_connected_transport(server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" @@ -297,5 +297,5 @@ async def test_websocket_trying_to_connect_to_already_connected_transport(server await assert_client_is_working(client) with pytest.raises(TransportAlreadyConnected): - async with AsyncClient(transport=sample_transport) as client2: + async with AsyncClient(transport=sample_transport): pass diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index 4c3984a6..97bb383a 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -219,7 +219,6 @@ async def close_transport_task_coro(): close_transport_task = asyncio.ensure_future(close_transport_task_coro()) - # with pytest.raises(websockets.exceptions.ConnectionClosedOK): await asyncio.gather(task, close_transport_task) assert count > 0 From f8ead7495ece24c01b2ee6a3e72860d8e8b40b21 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 26 Apr 2020 18:23:02 +0200 Subject: [PATCH 22/46] Fix tests for pytest-asyncio==0.11.0 From this commit of pytest-asyncio: https://github.com/pytest-dev/pytest-asyncio/commit/a9e2213224151f35a695dd588d3bb79c42ff9f3b The asyncio tests were not working anymore This is a bug of pytest-asyncio: https://github.com/pytest-dev/pytest-asyncio/issues/157 As a workaround, adding event_loop parameter to all asyncio tests --- tests_py36/test_async_client_validation.py | 6 +++--- tests_py36/test_websocket_exceptions.py | 14 +++++++------- tests_py36/test_websocket_query.py | 18 +++++++++--------- tests_py36/test_websocket_subscription.py | 14 +++++++------- 4 files changed, 26 insertions(+), 26 deletions(-) diff --git a/tests_py36/test_async_client_validation.py b/tests_py36/test_async_client_validation.py index 38c8ecce..ff8459f7 100644 --- a/tests_py36/test_async_client_validation.py +++ b/tests_py36/test_async_client_validation.py @@ -82,7 +82,7 @@ async def server_starwars(ws, path): {"type_def": StarWarsTypeDef}, ], ) -async def test_async_client_validation(server, subscription_str, client_params): +async def test_async_client_validation(event_loop, server, subscription_str, client_params): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" @@ -126,7 +126,7 @@ async def test_async_client_validation(server, subscription_str, client_params): ], ) async def test_async_client_validation_invalid_query( - server, subscription_str, client_params + event_loop, server, subscription_str, client_params ): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" @@ -158,7 +158,7 @@ async def test_async_client_validation_invalid_query( ], ) async def test_async_client_validation_different_schemas_parameters_forbidden( - server, subscription_str, client_params + event_loop, server, subscription_str, client_params ): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index 98ce2f7a..996943e6 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -38,7 +38,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_invalid_query(client_and_server, query_str): +async def test_websocket_invalid_query(event_loop, client_and_server, query_str): client, server = client_and_server @@ -73,7 +73,7 @@ async def server_connection_error(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_connection_error], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_sending_invalid_data(client_and_server, query_str): +async def test_websocket_sending_invalid_data(event_loop, client_and_server, query_str): client, server = client_and_server @@ -101,7 +101,7 @@ async def server_invalid_payload(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_sending_invalid_payload(client_and_server, query_str): +async def test_websocket_sending_invalid_payload(event_loop, client_and_server, query_str): client, server = client_and_server @@ -159,7 +159,7 @@ async def monkey_patch_send_query( ], indirect=True, ) -async def test_websocket_transport_protocol_errors(client_and_server): +async def test_websocket_transport_protocol_errors(event_loop, client_and_server): client, server = client_and_server @@ -177,7 +177,7 @@ async def server_without_ack(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_without_ack], indirect=True) -async def test_websocket_server_does_not_ack(server): +async def test_websocket_server_does_not_ack(event_loop, server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") @@ -195,7 +195,7 @@ async def server_closing_directly(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_directly], indirect=True) -async def test_websocket_server_closing_directly(server): +async def test_websocket_server_closing_directly(event_loop, server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") @@ -214,7 +214,7 @@ async def server_closing_after_ack(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) -async def test_websocket_server_closing_after_ack(client_and_server): +async def test_websocket_server_closing_after_ack(event_loop, client_and_server): client, server = client_and_server diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 28032268..ff41ed3e 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -37,7 +37,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_starting_client_in_context_manager(server): +async def test_websocket_starting_client_in_context_manager(event_loop, server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") @@ -74,7 +74,7 @@ async def test_websocket_starting_client_in_context_manager(server): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_simple_query(client_and_server, query_str): +async def test_websocket_simple_query(event_loop, client_and_server, query_str): client, server = client_and_server @@ -94,7 +94,7 @@ async def test_websocket_simple_query(client_and_server, query_str): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_two_answers_in_series], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_two_queries_in_series(client_and_server, query_str): +async def test_websocket_two_queries_in_series(event_loop, client_and_server, query_str): client, server = client_and_server @@ -128,7 +128,7 @@ async def server1_two_queries_in_parallel(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_two_queries_in_parallel(client_and_server, query_str): +async def test_websocket_two_queries_in_parallel(event_loop, client_and_server, query_str): client, server = client_and_server @@ -173,7 +173,7 @@ async def server_closing_while_we_are_doing_something_else(ws, path): "server", [server_closing_while_we_are_doing_something_else], indirect=True ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_server_closing_after_first_query(client_and_server, query_str): +async def test_websocket_server_closing_after_first_query(event_loop, client_and_server, query_str): client, server = client_and_server @@ -207,7 +207,7 @@ async def test_websocket_server_closing_after_first_query(client_and_server, que @pytest.mark.asyncio @pytest.mark.parametrize("server", [ignore_invalid_id_answers], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_ignore_invalid_id(client_and_server, query_str): +async def test_websocket_ignore_invalid_id(event_loop, client_and_server, query_str): client, server = client_and_server @@ -247,7 +247,7 @@ async def assert_client_is_working(client): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_multiple_connections_in_series(server): +async def test_websocket_multiple_connections_in_series(event_loop, server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") @@ -269,7 +269,7 @@ async def test_websocket_multiple_connections_in_series(server): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_multiple_connections_in_parallel(server): +async def test_websocket_multiple_connections_in_parallel(event_loop, server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") @@ -287,7 +287,7 @@ async def task_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_trying_to_connect_to_already_connected_transport(server): +async def test_websocket_trying_to_connect_to_already_connected_transport(event_loop, server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index 97bb383a..e2f664d7 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -101,7 +101,7 @@ async def keepalive_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription(client_and_server, subscription_str): +async def test_websocket_subscription(event_loop, client_and_server, subscription_str): client, server = client_and_server @@ -123,7 +123,7 @@ async def test_websocket_subscription(client_and_server, subscription_str): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription_break(client_and_server, subscription_str): +async def test_websocket_subscription_break(event_loop, client_and_server, subscription_str): client, server = client_and_server @@ -149,7 +149,7 @@ async def test_websocket_subscription_break(client_and_server, subscription_str) @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription_task_cancel(client_and_server, subscription_str): +async def test_websocket_subscription_task_cancel(event_loop, client_and_server, subscription_str): client, server = client_and_server @@ -188,7 +188,7 @@ async def cancel_task_coro(): @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_close_transport( - client_and_server, subscription_str + event_loop, client_and_server, subscription_str ): client, server = client_and_server @@ -254,7 +254,7 @@ async def server_countdown_close_connection_in_middle(ws, path): ) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_server_connection_closed( - client_and_server, subscription_str + event_loop, client_and_server, subscription_str ): client, server = client_and_server @@ -281,7 +281,7 @@ async def test_websocket_subscription_server_connection_closed( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_slow_consumer( - client_and_server, subscription_str + event_loop, client_and_server, subscription_str ): client, server = client_and_server @@ -310,7 +310,7 @@ async def test_websocket_subscription_slow_consumer( @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription_with_keepalive( - client_and_server, subscription_str + event_loop, client_and_server, subscription_str ): client, server = client_and_server From 804d433eb1213535cbdeae2b5d8e78962bf05f98 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 4 May 2020 20:57:58 +0200 Subject: [PATCH 23/46] Adding init_payload parameter to websockets transport to allow to specify init payload + tests and documentation + black for style in tests_py36 folder --- README.md | 24 +++++ gql/transport/websockets.py | 11 ++- tests_py36/test_async_client_validation.py | 4 +- tests_py36/test_websocket_exceptions.py | 4 +- tests_py36/test_websocket_query.py | 102 ++++++++++++++++++++- tests_py36/test_websocket_subscription.py | 8 +- 6 files changed, 142 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 955f4595..34d6ca61 100644 --- a/README.md +++ b/README.md @@ -210,6 +210,30 @@ If you have also need to have a client ssl certificate, add: ssl_context.load_cert_chain(certfile='YOUR_CLIENT_CERTIFICATE.pem', keyfile='YOUR_CLIENT_CERTIFICATE_KEY.key') ``` +### Websockets authentication + +There are two ways to send authentication tokens with websockets depending on the server configuration. + +1. Using HTTP Headers + +```python +sample_transport = WebsocketsTransport( + url='wss://SERVER_URL:SERVER_PORT/graphql', + headers={'Authorization': 'token'}, + ssl=True +) +``` + +2. With a payload in the connection_init websocket message + +```python +sample_transport = WebsocketsTransport( + url='wss://SERVER_URL:SERVER_PORT/graphql', + init_payload={'Authorization': 'token'}, + ssl=True +) +``` + ### Websockets advanced usage It is possible to send multiple GraphQL queries (query, mutation or subscription) in parallel, diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 01b76d71..eb2dfa72 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -12,7 +12,7 @@ import json import logging -from typing import cast, Dict, Optional, Tuple, Union, AsyncGenerator +from typing import cast, Dict, Optional, Tuple, Union, AsyncGenerator, Any from graphql.execution import ExecutionResult from graphql.language.ast import Document @@ -90,16 +90,19 @@ def __init__( url: str, headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, + init_payload: Dict[str, Any] = {}, ) -> None: """Initialize the transport with the given request parameters. :param url: The GraphQL server URL. Example: 'wss://server.com:PORT/graphql'. :param headers: Dict of HTTP Headers. :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param init_payload: Dict of the payload sent in the connection_init message. """ self.url: str = url self.ssl: Union[SSLContext, bool] = ssl self.headers: Optional[HeadersLike] = headers + self.init_payload: Dict[str, Any] = init_payload self.websocket: Optional[WebSocketClientProtocol] = None self.next_query_id: int = 1 @@ -159,7 +162,11 @@ async def _send_init_message_and_wait_ack(self) -> None: If the answer is not a connection_ack message, we will return an Exception """ - await self._send('{"type":"connection_init","payload":{}}') + init_message = json.dumps( + {"type": "connection_init", "payload": self.init_payload} + ) + + await self._send(init_message) init_answer = await self._receive() diff --git a/tests_py36/test_async_client_validation.py b/tests_py36/test_async_client_validation.py index ff8459f7..78441df8 100644 --- a/tests_py36/test_async_client_validation.py +++ b/tests_py36/test_async_client_validation.py @@ -82,7 +82,9 @@ async def server_starwars(ws, path): {"type_def": StarWarsTypeDef}, ], ) -async def test_async_client_validation(event_loop, server, subscription_str, client_params): +async def test_async_client_validation( + event_loop, server, subscription_str, client_params +): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index 996943e6..3066eded 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -101,7 +101,9 @@ async def server_invalid_payload(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_invalid_payload], indirect=True) @pytest.mark.parametrize("query_str", [invalid_query_str]) -async def test_websocket_sending_invalid_payload(event_loop, client_and_server, query_str): +async def test_websocket_sending_invalid_payload( + event_loop, client_and_server, query_str +): client, server = client_and_server diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index ff41ed3e..048bce3b 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -1,6 +1,7 @@ import asyncio import pytest import websockets +import json from .websocket_fixtures import MS, server, client_and_server, TestServer from graphql.execution import ExecutionResult @@ -8,6 +9,7 @@ from gql.transport.exceptions import ( TransportClosed, TransportQueryError, + TransportServerError, TransportAlreadyConnected, ) from gql import gql, AsyncClient @@ -37,7 +39,7 @@ @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_starting_client_in_context_manager(event_loop, server): +async def test_websocket_starting_client_in_context_manager(event_loop, server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") @@ -94,7 +96,9 @@ async def test_websocket_simple_query(event_loop, client_and_server, query_str): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_two_answers_in_series], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_two_queries_in_series(event_loop, client_and_server, query_str): +async def test_websocket_two_queries_in_series( + event_loop, client_and_server, query_str +): client, server = client_and_server @@ -128,7 +132,9 @@ async def server1_two_queries_in_parallel(ws, path): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_two_queries_in_parallel], indirect=True) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_two_queries_in_parallel(event_loop, client_and_server, query_str): +async def test_websocket_two_queries_in_parallel( + event_loop, client_and_server, query_str +): client, server = client_and_server @@ -173,7 +179,9 @@ async def server_closing_while_we_are_doing_something_else(ws, path): "server", [server_closing_while_we_are_doing_something_else], indirect=True ) @pytest.mark.parametrize("query_str", [query1_str]) -async def test_websocket_server_closing_after_first_query(event_loop, client_and_server, query_str): +async def test_websocket_server_closing_after_first_query( + event_loop, client_and_server, query_str +): client, server = client_and_server @@ -287,7 +295,9 @@ async def task_coro(): @pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -async def test_websocket_trying_to_connect_to_already_connected_transport(event_loop, server): +async def test_websocket_trying_to_connect_to_already_connected_transport( + event_loop, server +): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") @@ -299,3 +309,85 @@ async def test_websocket_trying_to_connect_to_already_connected_transport(event_ with pytest.raises(TransportAlreadyConnected): async with AsyncClient(transport=sample_transport): pass + + +async def server_with_authentication_in_connection_init_payload(ws, path): + # Wait the connection_init message + init_message_str = await ws.recv() + init_message = json.loads(init_message_str) + payload = init_message["payload"] + + if "Authorization" in payload and payload["Authorization"] == 12345: + + await ws.send('{"type":"connection_ack"}') + + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(query1_server_answer.format(query_id=1)) + await TestServer.send_complete(ws, 1) + await asyncio.sleep(1 * MS) + + else: + await ws.send( + '{"type":"connection_error", "payload": "Invalid Authorization token"}' + ) + + await ws.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_with_authentication_in_connection_init_payload], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +async def test_websocket_connect_success_with_authentication_in_connection_init( + event_loop, server, query_str +): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + init_payload = {"Authorization": 12345} + + sample_transport = WebsocketsTransport(url=url, init_payload=init_payload) + + async with AsyncClient(transport=sample_transport) as client: + + query1 = gql(query_str) + + result = await client.execute(query1) + + assert isinstance(result, ExecutionResult) + + print("Client received: " + str(result.data)) + + # Verify result + assert result.errors is None + assert isinstance(result.data, Dict) + + continents = result.data["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "server", [server_with_authentication_in_connection_init_payload], indirect=True +) +@pytest.mark.parametrize("query_str", [query1_str]) +@pytest.mark.parametrize("init_payload", [{}, {"Authorization": "invalid_code"}]) +async def test_websocket_connect_failed_with_authentication_in_connection_init( + event_loop, server, query_str, init_payload +): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url, init_payload=init_payload) + + with pytest.raises(TransportServerError): + async with AsyncClient(transport=sample_transport) as client: + query1 = gql(query_str) + + await client.execute(query1) diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index e2f664d7..6df530a7 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -123,7 +123,9 @@ async def test_websocket_subscription(event_loop, client_and_server, subscriptio @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription_break(event_loop, client_and_server, subscription_str): +async def test_websocket_subscription_break( + event_loop, client_and_server, subscription_str +): client, server = client_and_server @@ -149,7 +151,9 @@ async def test_websocket_subscription_break(event_loop, client_and_server, subsc @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -async def test_websocket_subscription_task_cancel(event_loop, client_and_server, subscription_str): +async def test_websocket_subscription_task_cancel( + event_loop, client_and_server, subscription_str +): client, server = client_and_server From ff22c0af0b7b0389261345f56a6b84747a7c961a Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 5 May 2020 12:36:50 +0200 Subject: [PATCH 24/46] Skip online tests if --run-online pytest arg is not set Adding tests to have 100% async client coverage without online tests --- tests_py36/conftest.py | 26 ++++++++ tests_py36/test_async_client_validation.py | 69 +++++++++++++++++++++- tests_py36/test_websocket_exceptions.py | 20 +++++++ tests_py36/test_websocket_online.py | 6 ++ 4 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 tests_py36/conftest.py diff --git a/tests_py36/conftest.py b/tests_py36/conftest.py new file mode 100644 index 00000000..691f6c93 --- /dev/null +++ b/tests_py36/conftest.py @@ -0,0 +1,26 @@ +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--run-online", + action="store_true", + default=False, + help="run tests necessitating online ressources", + ) + + +def pytest_configure(config): + config.addinivalue_line( + "markers", "online: mark test as necessitating external online ressources" + ) + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--run-online"): + # --run-online given in cli: do not skip online tests + return + skip_online = pytest.mark.skip(reason="need --run-online option to run") + for item in items: + if "online" in item.keywords: + item.add_marker(skip_online) diff --git a/tests_py36/test_async_client_validation.py b/tests_py36/test_async_client_validation.py index 78441df8..81d79bed 100644 --- a/tests_py36/test_async_client_validation.py +++ b/tests_py36/test_async_client_validation.py @@ -4,7 +4,7 @@ import websockets import graphql -from .websocket_fixtures import MS, server, TestServer +from .websocket_fixtures import MS, server, client_and_server, TestServer from graphql.execution import ExecutionResult from gql import gql, AsyncClient from gql.transport.websockets import WebsocketsTransport @@ -170,3 +170,70 @@ async def test_async_client_validation_different_schemas_parameters_forbidden( with pytest.raises(AssertionError): async with AsyncClient(transport=sample_transport, **client_params): pass + + +hero_server_answers = ( + f'{{"type":"data","id":"1","payload":{{"data":{json.dumps(StarWarsIntrospection)}}}}}', + '{"type":"data","id":"2","payload":{"data":{"hero":{"name": "R2-D2"}}}}', +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [hero_server_answers], indirect=True) +async def test_async_client_validation_fetch_schema_from_server_valid_query( + event_loop, client_and_server +): + client, server = client_and_server + + # No schema in the client at the beginning + assert client.introspection is None + assert client.schema is None + + # Fetch schema from server + await client.fetch_schema() + + # Check that the async client correctly recreated the schema + assert client.introspection == StarWarsIntrospection + assert client.schema is not None + + query = gql( + """ + query HeroNameQuery { + hero { + name + } + } + """ + ) + + result = await client.execute(query) + + print("Client received: " + str(result.data)) + expected = {"hero": {"name": "R2-D2"}} + + assert result.data == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [hero_server_answers], indirect=True) +async def test_async_client_validation_fetch_schema_from_server_invalid_query( + event_loop, client_and_server +): + client, server = client_and_server + + # Fetch schema from server + await client.fetch_schema() + + query = gql( + """ + query HeroNameQuery { + hero { + name + sldkfjqlmsdkjfqlskjfmlqkjsfmkjqsdf + } + } + """ + ) + + with pytest.raises(graphql.error.base.GraphQLError): + await client.execute(query) diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index 3066eded..45c68eb6 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -229,3 +229,23 @@ async def test_websocket_server_closing_after_ack(event_loop, client_and_server) with pytest.raises(TransportClosed): await client.execute(query) + + +async def server_sending_invalid_query_errors(ws, path): + await TestServer.send_connection_ack(ws) + invalid_error = '{"type":"error","id":"404","payload":{"message":"error for no good reason on non existing query"}}' + await ws.send(invalid_error) + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_sending_invalid_query_errors], indirect=True) +async def test_websocket_server_sending_invalid_query_errors(event_loop, server): + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + # Invalid server message is ignored + async with AsyncClient(transport=sample_transport): + await asyncio.sleep(2 * MS) diff --git a/tests_py36/test_websocket_online.py b/tests_py36/test_websocket_online.py index af652f09..afbbe573 100644 --- a/tests_py36/test_websocket_online.py +++ b/tests_py36/test_websocket_online.py @@ -13,6 +13,7 @@ logging.basicConfig(level=logging.INFO) +@pytest.mark.online @pytest.mark.asyncio async def test_websocket_simple_query(): @@ -54,6 +55,7 @@ async def test_websocket_simple_query(): assert africa["code"] == "AF" +@pytest.mark.online @pytest.mark.asyncio async def test_websocket_invalid_query(): @@ -88,6 +90,7 @@ async def test_websocket_invalid_query(): assert result.errors is not None +@pytest.mark.online @pytest.mark.asyncio async def test_websocket_sending_invalid_data(): @@ -126,6 +129,7 @@ async def test_websocket_sending_invalid_data(): await asyncio.sleep(2) +@pytest.mark.online @pytest.mark.asyncio async def test_websocket_sending_invalid_payload(): @@ -145,6 +149,7 @@ async def test_websocket_sending_invalid_payload(): await asyncio.sleep(2) +@pytest.mark.online @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @pytest.mark.asyncio async def test_websocket_sending_invalid_data_while_other_query_is_running(): @@ -197,6 +202,7 @@ async def query_task2(): await asyncio.gather(task1, task2) +@pytest.mark.online @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @pytest.mark.asyncio async def test_websocket_two_queries_in_parallel_using_two_tasks(): From 268d7a76487373c5aeb658f6495bb9649a0341b5 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 5 May 2020 14:26:47 +0200 Subject: [PATCH 25/46] Fix tests on pypy3 --- tests_py36/test_websocket_query.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 048bce3b..3e86b064 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -317,22 +317,24 @@ async def server_with_authentication_in_connection_init_payload(ws, path): init_message = json.loads(init_message_str) payload = init_message["payload"] - if "Authorization" in payload and payload["Authorization"] == 12345: - - await ws.send('{"type":"connection_ack"}') - - result = await ws.recv() - print(f"Server received: {result}") - await ws.send(query1_server_answer.format(query_id=1)) - await TestServer.send_complete(ws, 1) - await asyncio.sleep(1 * MS) - + if "Authorization" in payload: + if payload["Authorization"] == 12345: + await ws.send('{"type":"connection_ack"}') + + result = await ws.recv() + print(f"Server received: {result}") + await ws.send(query1_server_answer.format(query_id=1)) + await TestServer.send_complete(ws, 1) + else: + await ws.send( + '{"type":"connection_error", "payload": "Invalid Authorization token"}' + ) else: await ws.send( - '{"type":"connection_error", "payload": "Invalid Authorization token"}' + '{"type":"connection_error", "payload": "No Authorization token"}' ) - await ws.close() + await ws.wait_closed() @pytest.mark.asyncio From c167755be511e2d9006c3dbb1221c9eae16e8f8a Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 7 May 2020 14:50:51 +0200 Subject: [PATCH 26/46] Implementation of AIOHTTPTransport --- .gitignore | 3 + README.md | 49 +++++++++- gql/transport/aiohttp.py | 144 ++++++++++++++++++++++++++++ setup.py | 1 + tests_py36/test_aiohttp_online.py | 154 ++++++++++++++++++++++++++++++ 5 files changed, 348 insertions(+), 3 deletions(-) create mode 100644 gql/transport/aiohttp.py create mode 100644 tests_py36/test_aiohttp_online.py diff --git a/.gitignore b/.gitignore index 49246968..3904dc54 100644 --- a/.gitignore +++ b/.gitignore @@ -89,3 +89,6 @@ target/ ### VisualStudioCode ### .vscode/* + +# VIM +*.swp diff --git a/README.md b/README.md index 34d6ca61..e6d6848c 100644 --- a/README.md +++ b/README.md @@ -113,10 +113,53 @@ query = gql(''' client.execute(query) ``` -## Websockets transport with asyncio +# Async clients and transports -It is possible to use the websockets transport using the `asyncio` library. -Python3.6 is required for this transport. +It is possible to use async clients and transports using [asyncio](https://docs.python.org/3/library/asyncio.html). +Python3.6 is required for async clients and transports + +## HTTP async transport + +This transport uses the [aiohttp library](https://docs.aiohttp.org) + +GraphQL subscriptions are not supported on this HTTP transport. +For subscriptions you should use the websockets transport. + +```python +from gql import gql, AsyncClient +from gql.transport.aiohttp import AIOHTTPTransport +import asyncio + +async def main(): + + sample_transport = AIOHTTPTransport( + url='https://countries.trevorblades.com/graphql', + headers={'Authorization': 'token'} + ) + + async with AsyncClient(transport=sample_transport) as client: + + # Fetch schema (optional) + await client.fetch_schema() + + # Execute single query + query = gql(''' + query getContinents { + continents { + code + name + } + } + ''') + + result = await client.execute(query) + + print (f'result data = {result.data}, errors = {result.errors}') + +asyncio.run(main()) +``` + +## Websockets async transport The websockets transport uses the apollo protocol described here: diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py new file mode 100644 index 00000000..1ceccf80 --- /dev/null +++ b/gql/transport/aiohttp.py @@ -0,0 +1,144 @@ +import aiohttp + +from aiohttp.typedefs import LooseCookies, LooseHeaders +from aiohttp.helpers import BasicAuth +from aiohttp.client_reqrep import Fingerprint + +from ssl import SSLContext + +from typing import Dict, Optional, Union, AsyncGenerator, Any + +from graphql.execution import ExecutionResult +from graphql.language.ast import Document +from graphql.language.printer import print_ast + +from .async_transport import AsyncTransport +from .exceptions import ( + TransportProtocolError, + TransportClosed, + TransportAlreadyConnected, +) + + +class AIOHTTPTransport(AsyncTransport): + """Transport to execute GraphQL queries on remote servers with an http connection. + + This transport use the aiohttp library with asyncio + + See README.md for Usage + """ + + def __init__( + self, + url: str, + headers: Optional[LooseHeaders] = None, + cookies: Optional[LooseCookies] = None, + auth: Optional[BasicAuth] = None, + ssl: Union[SSLContext, bool, Fingerprint] = False, + timeout: Optional[int] = None, + **kwargs, + ) -> None: + """Initialize the transport with the given aiohttp parameters. + + :param url: The GraphQL server URL. Example: 'https://server.com:PORT/path'. + :param headers: Dict of HTTP Headers. + :param cookies: Dict of HTTP cookies. + :param auth: BasicAuth object to enable Basic HTTP auth if needed + :param ssl: ssl_context of the connection. Use ssl=False to disable encryption + :param kwargs: Other parameters forwarded to aiohttp.ClientSession + """ + self.url: str = url + self.headers: Optional[LooseHeaders] = headers + self.cookies: Optional[LooseCookies] = cookies + self.auth: Optional[BasicAuth] = auth + self.ssl: Union[SSLContext, bool, Fingerprint] = ssl + self.timeout: Optional[int] = timeout + self.kwargs = kwargs + + self.session: Optional[aiohttp.ClientSession] = None + + async def connect(self) -> None: + """Coroutine which will: + + - create an aiohttp ClientSession() as self.session + + Should be cleaned with a call to the close coroutine + """ + + if self.session is None: + + client_session_args: Dict[str, Any] = { + "cookies": self.cookies, + "headers": self.headers, + "auth": self.auth, + } + + if self.timeout is not None: + client_session_args["timeout"] = aiohttp.ClientTimeout( + total=self.timeout + ) + + # Adding custom parameters passed from init + client_session_args.update(self.kwargs) + + self.session = aiohttp.ClientSession(**client_session_args) + + else: + raise TransportAlreadyConnected("Transport is already connected") + + async def close(self) -> None: + if self.session is not None: + await self.session.close() + self.session = None + + async def execute( + self, + document: Document, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + **kwargs, + ) -> ExecutionResult: + """Execute the provided document AST against the configured remote server. + This uses the aiohttp library to perform a HTTP POST request asynchronously to the remote server. + + The result is sent as an ExecutionResult object + """ + + query_str = print_ast(document) + payload = { + "query": query_str, + "variables": variable_values or {}, + "operationName": operation_name or "", + } + + post_args = { + "json": payload, + } + + # Pass kwargs to aiohttp post method + post_args.update(kwargs) + + if self.session is None: + raise TransportClosed("Transport is not connected") + + async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: + try: + result = await resp.json() + if not isinstance(result, dict): + raise ValueError + except ValueError: + result = {} + + if "errors" not in result and "data" not in result: + resp.raise_for_status() + raise TransportProtocolError("Server did not return a GraphQL result") + + return ExecutionResult(errors=result.get("errors"), data=result.get("data")) + + def subscribe( + self, + document: Document, + variable_values: Optional[Dict[str, str]] = None, + operation_name: Optional[str] = None, + ) -> AsyncGenerator[ExecutionResult, None]: + raise NotImplementedError(" The HTTP transport does not support subscriptions") diff --git a/setup.py b/setup.py index 598ee8fe..e3d79e3b 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ 'parse>=1.6.0', ]) install_requires.append('websockets>=8.1,<9') + install_requires.append('aiohttp==3.6.2') scripts.append('scripts/gql-cli') else: tests_require.append([ diff --git a/tests_py36/test_aiohttp_online.py b/tests_py36/test_aiohttp_online.py new file mode 100644 index 00000000..d495c814 --- /dev/null +++ b/tests_py36/test_aiohttp_online.py @@ -0,0 +1,154 @@ +import asyncio +import pytest +import sys + +from gql import gql, AsyncClient +from gql.transport.aiohttp import AIOHTTPTransport +from graphql.execution import ExecutionResult +from typing import Dict + + +@pytest.mark.online +@pytest.mark.asyncio +@pytest.mark.parametrize("protocol", ["http", "https"]) +async def test_aiohttp_simple_query(event_loop, protocol): + + # Create http or https url + url = f"{protocol}://countries.trevorblades.com/graphql" + + # Get transport + sample_transport = AIOHTTPTransport(url=url) + + # Instanciate client + async with AsyncClient(transport=sample_transport) as client: + + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + + # Fetch schema + await client.fetch_schema() + + # Execute query + result = await client.execute(query) + + # Verify result + assert isinstance(result, ExecutionResult) + assert result.errors is None + + assert isinstance(result.data, Dict) + + print(result.data) + + continents = result.data["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.online +@pytest.mark.asyncio +async def test_aiohttp_invalid_query(event_loop): + + sample_transport = AIOHTTPTransport( + url="https://countries.trevorblades.com/graphql" + ) + + async with AsyncClient(transport=sample_transport) as client: + + query = gql( + """ + query getContinents { + continents { + code + bloh + } + } + """ + ) + + result = await client.execute(query) + + assert isinstance(result, ExecutionResult) + + assert result.data is None + + print(f"result = {repr(result.data)}, {repr(result.errors)}") + assert result.errors is not None + + +@pytest.mark.online +@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") +@pytest.mark.asyncio +async def test_aiohttp_two_queries_in_parallel_using_two_tasks(event_loop): + + sample_transport = AIOHTTPTransport( + url="https://countries.trevorblades.com/graphql", + ) + + # Instanciate client + async with AsyncClient(transport=sample_transport) as client: + + query1 = gql( + """ + query getContinents { + continents { + code + } + } + """ + ) + + query2 = gql( + """ + query getContinents { + continents { + name + } + } + """ + ) + + async def query_task1(): + result = await client.execute(query1) + + assert isinstance(result, ExecutionResult) + assert result.errors is None + + assert isinstance(result.data, Dict) + + print(result.data) + + continents = result.data["continents"] + + africa = continents[0] + assert africa["code"] == "AF" + + async def query_task2(): + result = await client.execute(query2) + + assert isinstance(result, ExecutionResult) + assert result.errors is None + + assert isinstance(result.data, Dict) + + print(result.data) + + continents = result.data["continents"] + + africa = continents[0] + assert africa["name"] == "Africa" + + task1 = asyncio.create_task(query_task1()) + task2 = asyncio.create_task(query_task2()) + + await task1 + await task2 From 51df20422501387b5b9f922cefb0ce0d8522e4b3 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 9 May 2020 14:29:11 +0200 Subject: [PATCH 27/46] MAKEFILE add clean target --- Makefile | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 83609cef..6e7b9b75 100644 --- a/Makefile +++ b/Makefile @@ -2,4 +2,16 @@ dev-setup: python pip install -e ".[test]" tests: - pytest tests --cov=gql -vv \ No newline at end of file + pytest tests --cov=gql -vv + +clean: + find . -name "*.pyc" -delete + find . -name "__pycache__" | xargs -I {} rm -rf {} + rm -rf ./htmlcov + rm -rf ./.mypy_cache + rm -rf ./.pytest_cache + rm -rf ./.tox + rm -rf ./gql.egg-info + rm -rf ./dist + rm -rf ./build + rm -f ./.coverage From ec5d4ee5f46d249dafdaca98fb3d9b88a218eb00 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sat, 9 May 2020 14:39:33 +0200 Subject: [PATCH 28/46] Implementation of class AsyncClientSession This allows us to keep the execute method sync in the AsyncClient class Rename client to session in all tests It is now possible to execute GraphQL queries synchronously on asynchronous transports Returning ExecutionResult.data instead of ExecutionResult in AsyncClient (This corresponds to the previous usage of the library) --- README.md | 61 +++++---- gql/async_client.py | 145 +++++++++++++++++--- tests_py36/test_aiohttp_online.py | 53 +++---- tests_py36/test_async_client_validation.py | 59 +++++--- tests_py36/test_http_async_sync.py | 152 +++++++++++++++++++++ tests_py36/test_websocket_exceptions.py | 39 ++---- tests_py36/test_websocket_online.py | 71 ++++------ tests_py36/test_websocket_query.py | 110 +++++++-------- tests_py36/test_websocket_subscription.py | 54 +++----- tests_py36/websocket_fixtures.py | 6 +- 10 files changed, 482 insertions(+), 268 deletions(-) create mode 100644 tests_py36/test_http_async_sync.py diff --git a/README.md b/README.md index e6d6848c..a42568ac 100644 --- a/README.md +++ b/README.md @@ -113,10 +113,21 @@ query = gql(''' client.execute(query) ``` -# Async clients and transports +# Async usage with asyncio and subscriptions -It is possible to use async clients and transports using [asyncio](https://docs.python.org/3/library/asyncio.html). -Python3.6 is required for async clients and transports +When using the `execute` function directly on the client, the execution is synchronous. +It means that we are blocked until we receive an answer from the server and +we cannot do anything else while waiting for this answer. + +It is now possible to use this library asynchronously using [asyncio](https://docs.python.org/3/library/asyncio.html). + +Async Features: +* Execute GraphQL subscriptions (See [using the websockets transport](#Websockets-async-transport)) +* Execute GraphQL queries and subscriptions in parallel + +To use the async features, you need to use an async transport: +* [AIOHTTPTransport](#HTTP-async-transport) for the HTTP(s) protocols +* [WebsocketsTransport](#Websockets-async-transport) for the ws(s) protocols ## HTTP async transport @@ -137,10 +148,10 @@ async def main(): headers={'Authorization': 'token'} ) - async with AsyncClient(transport=sample_transport) as client: - - # Fetch schema (optional) - await client.fetch_schema() + async with AsyncClient( + transport=sample_transport, + fetch_schema_from_transport=True, + ) as session: # Execute single query query = gql(''' @@ -152,9 +163,9 @@ async def main(): } ''') - result = await client.execute(query) + result = await session.execute(query) - print (f'result data = {result.data}, errors = {result.errors}') + print(result) asyncio.run(main()) ``` @@ -183,10 +194,10 @@ async def main(): headers={'Authorization': 'token'} ) - async with AsyncClient(transport=sample_transport) as client: - - # Fetch schema (optional) - await client.fetch_schema() + async with AsyncClient( + transport=sample_transport, + fetch_schema_from_transport=True, + ) as session: # Execute single query query = gql(''' @@ -197,8 +208,8 @@ async def main(): } } ''') - result = await client.execute(query) - print (f'result data = {result.data}, errors = {result.errors}') + result = await session.execute(query) + print(result) # Request subscription subscription = gql(''' @@ -208,8 +219,8 @@ async def main(): } } ''') - async for result in client.subscribe(subscription): - print (f'result.data = {result.data}') + async for result in session.subscribe(subscription): + print(result) asyncio.run(main()) ``` @@ -285,20 +296,20 @@ on the same websocket connection, using asyncio tasks ```python async def execute_query1(): - result = await client.execute(query1) - print (f'result data = {result.data}, errors = {result.errors}') + result = await session.execute(query1) + print(result) async def execute_query2(): - result = await client.execute(query2) - print (f'result data = {result.data}, errors = {result.errors}') + result = await session.execute(query2) + print(result) async def execute_subscription1(): - async for result in client.subscribe(subscription1): - print (f'result data = {result.data}, errors = {result.errors}') + async for result in session.subscribe(subscription1): + print(result) async def execute_subscription2(): - async for result in client.subscribe(subscription2): - print (f'result data = {result.data}, errors = {result.errors}') + async for result in session.subscribe(subscription2): + print(result) task1 = asyncio.create_task(execute_query1()) task2 = asyncio.create_task(execute_query2()) diff --git a/gql/async_client.py b/gql/async_client.py index 2592f766..46026497 100644 --- a/gql/async_client.py +++ b/gql/async_client.py @@ -1,23 +1,34 @@ +import asyncio + from graphql import build_ast_schema, build_client_schema, introspection_query, parse -from graphql.execution import ExecutionResult from graphql.language.ast import Document -from typing import AsyncGenerator +from typing import AsyncGenerator, Dict from .transport.async_transport import AsyncTransport +from .transport.exceptions import TransportQueryError from .client import Client class AsyncClient(Client): def __init__( - self, schema=None, introspection=None, type_def=None, transport=None, + self, + schema=None, + introspection=None, + type_def=None, + transport=None, + fetch_schema_from_transport=False, ): - assert isinstance( - transport, AsyncTransport - ), "Only a transport of type AsyncTransport is supported on AsyncClient" assert not ( type_def and introspection ), "Cant provide introspection type definition at the same time" + if transport and fetch_schema_from_transport: + assert ( + not schema + ), "Cant fetch the schema from transport if is already provided" + if not isinstance(transport, AsyncTransport): + # For sync transports, we fetch the schema directly + introspection = transport.execute(parse(introspection_query)).data if introspection: assert not schema, "Cant provide introspection and schema at the same time" schema = build_client_schema(introspection) @@ -31,30 +42,120 @@ def __init__( self.schema = schema self.introspection = introspection self.transport = transport + self.fetch_schema_from_transport = fetch_schema_from_transport - async def subscribe( - self, document: Document, *args, **kwargs - ) -> AsyncGenerator[ExecutionResult, None]: - if self.schema: - self.validate(document) + async def _execute_in_async_session(self, document: Document, *args, **kwargs): + async with self as session: + return await session.execute(document, *args, **kwargs) - async for result in self.transport.subscribe(document, *args, **kwargs): - yield result + def execute(self, document: Document, *args, **kwargs) -> Dict: + """Execute the provided document AST against the configured remote server. - async def execute(self, document: Document, *args, **kwargs) -> ExecutionResult: - if self.schema: - self.validate(document) + This function is synchronous and WILL BLOCK until the result is received from the server. - return await self.transport.execute(document, *args, **kwargs) + Either the transport is sync and we execute the query directly + OR the transport is async and we will create a new asyncio event loop to + execute the query in a synchronous way (blocking here until answer) + """ - async def fetch_schema(self) -> None: - execution_result = await self.transport.execute(parse(introspection_query)) - self.introspection = execution_result.data - self.schema = build_client_schema(self.introspection) + if isinstance(self.transport, AsyncTransport): + + loop = asyncio.new_event_loop() + + timeout = kwargs.get("timeout", 10) + result = loop.run_until_complete( + asyncio.wait_for( + self._execute_in_async_session(document, *args, **kwargs), timeout + ) + ) + + loop.stop() + loop.close() + + return result + + else: # Sync transports + + if self.schema: + self.validate(document) + + result = self.transport.execute(document, *args, **kwargs) + + if result.errors: + raise TransportQueryError(str(result.errors[0])) + + return result.data async def __aenter__(self): + + assert isinstance( + self.transport, AsyncTransport + ), "Only a transport of type AsyncTransport can be used asynchronously" + await self.transport.connect() - return self + + if not hasattr(self, "session"): + self.session = AsyncClientSession(client=self) + + return self.session async def __aexit__(self, *args): + await self.transport.close() + + +class AsyncClientSession: + """ An instance of this class is created when using 'async with' on the client. + + It contains the async methods (execute, subscribe) to send queries with the async transports""" + + def __init__(self, client: AsyncClient): + self.client = client + + async def validate(self, document: Document): + """ Fetch schema from transport if needed and validate document if schema is present """ + + # Get schema from transport if needed + if self.client.fetch_schema_from_transport and not self.client.schema: + await self.fetch_schema() + + # Validate document + if self.client.schema: + self.client.validate(document) + + async def subscribe( + self, document: Document, *args, **kwargs + ) -> AsyncGenerator[Dict, None]: + + # Fetch schema from transport if needed and validate document if schema is present + await self.validate(document) + + # Subscribe to the transport and yield data or raise error + async for result in self.transport.subscribe(document, *args, **kwargs): + if result.errors: + raise TransportQueryError(str(result.errors[0])) + + yield result.data + + async def execute(self, document: Document, *args, **kwargs) -> Dict: + + # Fetch schema from transport if needed and validate document if schema is present + await self.validate(document) + + # Execute the query with the transport + result = await self.transport.execute(document, *args, **kwargs) + + # Raise an error if an error is returned in the ExecutionResult object + if result.errors: + raise TransportQueryError(str(result.errors[0])) + + return result.data + + async def fetch_schema(self) -> None: + execution_result = await self.transport.execute(parse(introspection_query)) + self.client.introspection = execution_result.data + self.client.schema = build_client_schema(self.client.introspection) + + @property + def transport(self): + return self.client.transport diff --git a/tests_py36/test_aiohttp_online.py b/tests_py36/test_aiohttp_online.py index d495c814..e21d3b46 100644 --- a/tests_py36/test_aiohttp_online.py +++ b/tests_py36/test_aiohttp_online.py @@ -4,7 +4,7 @@ from gql import gql, AsyncClient from gql.transport.aiohttp import AIOHTTPTransport -from graphql.execution import ExecutionResult +from gql.transport.exceptions import TransportQueryError from typing import Dict @@ -20,7 +20,7 @@ async def test_aiohttp_simple_query(event_loop, protocol): sample_transport = AIOHTTPTransport(url=url) # Instanciate client - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query = gql( """ @@ -34,20 +34,17 @@ async def test_aiohttp_simple_query(event_loop, protocol): ) # Fetch schema - await client.fetch_schema() + await session.fetch_schema() # Execute query - result = await client.execute(query) + result = await session.execute(query) # Verify result - assert isinstance(result, ExecutionResult) - assert result.errors is None + assert isinstance(result, Dict) - assert isinstance(result.data, Dict) + print(result) - print(result.data) - - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] @@ -62,7 +59,7 @@ async def test_aiohttp_invalid_query(event_loop): url="https://countries.trevorblades.com/graphql" ) - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query = gql( """ @@ -75,14 +72,8 @@ async def test_aiohttp_invalid_query(event_loop): """ ) - result = await client.execute(query) - - assert isinstance(result, ExecutionResult) - - assert result.data is None - - print(f"result = {repr(result.data)}, {repr(result.errors)}") - assert result.errors is not None + with pytest.raises(TransportQueryError): + await session.execute(query) @pytest.mark.online @@ -95,7 +86,7 @@ async def test_aiohttp_two_queries_in_parallel_using_two_tasks(event_loop): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query1 = gql( """ @@ -118,31 +109,25 @@ async def test_aiohttp_two_queries_in_parallel_using_two_tasks(event_loop): ) async def query_task1(): - result = await client.execute(query1) - - assert isinstance(result, ExecutionResult) - assert result.errors is None + result = await session.execute(query1) - assert isinstance(result.data, Dict) + assert isinstance(result, Dict) - print(result.data) + print(result) - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] assert africa["code"] == "AF" async def query_task2(): - result = await client.execute(query2) - - assert isinstance(result, ExecutionResult) - assert result.errors is None + result = await session.execute(query2) - assert isinstance(result.data, Dict) + assert isinstance(result, Dict) - print(result.data) + print(result) - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] assert africa["name"] == "Africa" diff --git a/tests_py36/test_async_client_validation.py b/tests_py36/test_async_client_validation.py index 81d79bed..acded7af 100644 --- a/tests_py36/test_async_client_validation.py +++ b/tests_py36/test_async_client_validation.py @@ -5,7 +5,6 @@ import graphql from .websocket_fixtures import MS, server, client_and_server, TestServer -from graphql.execution import ExecutionResult from gql import gql, AsyncClient from gql.transport.websockets import WebsocketsTransport from tests_py36.schema import StarWarsSchema, StarWarsTypeDef, StarWarsIntrospection @@ -90,7 +89,7 @@ async def test_async_client_validation( sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport, **client_params) as client: + async with AsyncClient(transport=sample_transport, **client_params) as session: variable_values = {"ep": "JEDI"} @@ -98,14 +97,11 @@ async def test_async_client_validation( expected = [] - async for result in client.subscribe( + async for result in session.subscribe( subscription, variable_values=variable_values ): - assert isinstance(result, ExecutionResult) - assert result.errors is None - - review = result.data["reviewAdded"] + review = result["reviewAdded"] expected.append(review) assert "stars" in review @@ -135,14 +131,14 @@ async def test_async_client_validation_invalid_query( sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport, **client_params) as client: + async with AsyncClient(transport=sample_transport, **client_params) as session: variable_values = {"ep": "JEDI"} subscription = gql(subscription_str) with pytest.raises(graphql.error.base.GraphQLError): - async for result in client.subscribe( + async for result in session.subscribe( subscription, variable_values=variable_values ): pass @@ -183,14 +179,15 @@ async def test_async_client_validation_different_schemas_parameters_forbidden( async def test_async_client_validation_fetch_schema_from_server_valid_query( event_loop, client_and_server ): - client, server = client_and_server + session, server = client_and_server + client = session.client # No schema in the client at the beginning assert client.introspection is None assert client.schema is None # Fetch schema from server - await client.fetch_schema() + await session.fetch_schema() # Check that the async client correctly recreated the schema assert client.introspection == StarWarsIntrospection @@ -206,12 +203,12 @@ async def test_async_client_validation_fetch_schema_from_server_valid_query( """ ) - result = await client.execute(query) + result = await session.execute(query) - print("Client received: " + str(result.data)) + print("Client received: " + str(result)) expected = {"hero": {"name": "R2-D2"}} - assert result.data == expected + assert result == expected @pytest.mark.asyncio @@ -219,10 +216,10 @@ async def test_async_client_validation_fetch_schema_from_server_valid_query( async def test_async_client_validation_fetch_schema_from_server_invalid_query( event_loop, client_and_server ): - client, server = client_and_server + session, server = client_and_server # Fetch schema from server - await client.fetch_schema() + await session.fetch_schema() query = gql( """ @@ -236,4 +233,32 @@ async def test_async_client_validation_fetch_schema_from_server_invalid_query( ) with pytest.raises(graphql.error.base.GraphQLError): - await client.execute(query) + await session.execute(query) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [hero_server_answers], indirect=True) +async def test_async_client_validation_fetch_schema_from_server_with_client_argument( + event_loop, server +): + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + + sample_transport = WebsocketsTransport(url=url) + + async with AsyncClient( + transport=sample_transport, fetch_schema_from_transport=True, + ) as session: + + query = gql( + """ + query HeroNameQuery { + hero { + name + sldkfjqlmsdkjfqlskjfmlqkjsfmkjqsdf + } + } + """ + ) + + with pytest.raises(graphql.error.base.GraphQLError): + await session.execute(query) diff --git a/tests_py36/test_http_async_sync.py b/tests_py36/test_http_async_sync.py new file mode 100644 index 00000000..86d33eea --- /dev/null +++ b/tests_py36/test_http_async_sync.py @@ -0,0 +1,152 @@ +import pytest + +from gql import gql, AsyncClient +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.requests import RequestsHTTPTransport + + +@pytest.mark.online +@pytest.mark.asyncio +@pytest.mark.parametrize("protocol", ["http", "https"]) +@pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) +async def test_async_client_async_transport( + event_loop, protocol, fetch_schema_from_transport +): + + # Create http or https url + url = f"{protocol}://countries.trevorblades.com/graphql" + + # Get async transport + sample_transport = AIOHTTPTransport(url=url) + + # Instanciate client + async with AsyncClient( + transport=sample_transport, + fetch_schema_from_transport=fetch_schema_from_transport, + ) as session: + + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + + # Execute query + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + if fetch_schema_from_transport: + assert session.client.schema is not None + + +@pytest.mark.online +@pytest.mark.asyncio +@pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) +async def test_async_client_sync_transport(event_loop, fetch_schema_from_transport): + + url = "http://countries.trevorblades.com/graphql" + + # Get sync transport + sample_transport = RequestsHTTPTransport(url=url, use_json=True) + + # Impossible to use a sync transport asynchronously + with pytest.raises(AssertionError): + async with AsyncClient( + transport=sample_transport, + fetch_schema_from_transport=fetch_schema_from_transport, + ): + pass + + sample_transport.close() + + +@pytest.mark.online +@pytest.mark.parametrize("protocol", ["http", "https"]) +@pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) +def test_sync_client_async_transport(protocol, fetch_schema_from_transport): + + # Create http or https url + url = f"{protocol}://countries.trevorblades.com/graphql" + + # Get async transport + sample_transport = AIOHTTPTransport(url=url) + + # Instanciate client + client = AsyncClient( + transport=sample_transport, + fetch_schema_from_transport=fetch_schema_from_transport, + ) + + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + + # Execute query synchronously + result = client.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + if fetch_schema_from_transport: + assert client.schema is not None + + +@pytest.mark.online +@pytest.mark.parametrize("protocol", ["http", "https"]) +@pytest.mark.parametrize("fetch_schema_from_transport", [True, False]) +def test_sync_client_sync_transport(protocol, fetch_schema_from_transport): + + # Create http or https url + url = f"{protocol}://countries.trevorblades.com/graphql" + + # Get sync transport + sample_transport = RequestsHTTPTransport(url=url, use_json=True) + + # Instanciate client + client = AsyncClient( + transport=sample_transport, + fetch_schema_from_transport=fetch_schema_from_transport, + ) + + query = gql( + """ + query getContinents { + continents { + code + name + } + } + """ + ) + + # Execute query synchronously + result = client.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + if fetch_schema_from_transport: + assert client.schema is not None diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index 45c68eb6..473e2308 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -5,7 +5,6 @@ import types from .websocket_fixtures import MS, server, client_and_server, TestServer -from graphql.execution import ExecutionResult from gql import gql, AsyncClient from gql.transport.websockets import WebsocketsTransport from gql.transport.exceptions import ( @@ -40,20 +39,12 @@ @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_websocket_invalid_query(event_loop, client_and_server, query_str): - client, server = client_and_server + session, server = client_and_server query = gql(query_str) - result = await client.execute(query) - - print("Client received: " + str(result.data)) - - assert isinstance(result, ExecutionResult) - - print(f"result = {repr(result.data)}, {repr(result.errors)}") - - assert result.data is None - assert result.errors is not None + with pytest.raises(TransportQueryError): + await session.execute(query) connection_error_server_answer = ( @@ -75,11 +66,11 @@ async def server_connection_error(ws, path): @pytest.mark.parametrize("query_str", [invalid_query_str]) async def test_websocket_sending_invalid_data(event_loop, client_and_server, query_str): - client, server = client_and_server + session, server = client_and_server invalid_data = "QSDF" print(f">>> {invalid_data}") - await client.transport.websocket.send(invalid_data) + await session.transport.websocket.send(invalid_data) await asyncio.sleep(2 * MS) @@ -105,7 +96,7 @@ async def test_websocket_sending_invalid_payload( event_loop, client_and_server, query_str ): - client, server = client_and_server + session, server = client_and_server # Monkey patching the _send_query method to send an invalid payload @@ -122,14 +113,14 @@ async def monkey_patch_send_query( await self._send(query_str) return query_id - client.transport._send_query = types.MethodType( - monkey_patch_send_query, client.transport + session.transport._send_query = types.MethodType( + monkey_patch_send_query, session.transport ) query = gql(query_str) with pytest.raises(TransportQueryError): - await client.execute(query) + await session.execute(query) not_json_answer = ["BLAHBLAH"] @@ -163,12 +154,12 @@ async def monkey_patch_send_query( ) async def test_websocket_transport_protocol_errors(event_loop, client_and_server): - client, server = client_and_server + session, server = client_and_server query = gql("query { hello }") with pytest.raises(TransportProtocolError): - await client.execute(query) + await session.execute(query) async def server_without_ack(ws, path): @@ -218,17 +209,17 @@ async def server_closing_after_ack(ws, path): @pytest.mark.parametrize("server", [server_closing_after_ack], indirect=True) async def test_websocket_server_closing_after_ack(event_loop, client_and_server): - client, server = client_and_server + session, server = client_and_server query = gql("query { hello }") with pytest.raises(websockets.exceptions.ConnectionClosed): - await client.execute(query) + await session.execute(query) - await client.transport.wait_closed() + await session.transport.wait_closed() with pytest.raises(TransportClosed): - await client.execute(query) + await session.execute(query) async def server_sending_invalid_query_errors(ws, path): diff --git a/tests_py36/test_websocket_online.py b/tests_py36/test_websocket_online.py index afbbe573..901a0e22 100644 --- a/tests_py36/test_websocket_online.py +++ b/tests_py36/test_websocket_online.py @@ -5,8 +5,7 @@ from gql import gql, AsyncClient from gql.transport.websockets import WebsocketsTransport -from gql.transport.exceptions import TransportError -from graphql.execution import ExecutionResult +from gql.transport.exceptions import TransportError, TransportQueryError from typing import Dict from .websocket_fixtures import MS @@ -23,7 +22,7 @@ async def test_websocket_simple_query(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query = gql( """ @@ -37,18 +36,15 @@ async def test_websocket_simple_query(): ) # Fetch schema - await client.fetch_schema() + await session.fetch_schema() # Execute query - result = await client.execute(query) + result = await session.execute(query) # Verify result - assert isinstance(result, ExecutionResult) - assert result.errors is None + assert isinstance(result, Dict) - assert isinstance(result.data, Dict) - - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] @@ -65,7 +61,7 @@ async def test_websocket_invalid_query(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query = gql( """ @@ -79,15 +75,8 @@ async def test_websocket_invalid_query(): ) # Execute query - result = await client.execute(query) - - # Verify result - assert isinstance(result, ExecutionResult) - - assert result.data is None - - print(f"result = {repr(result.data)}, {repr(result.errors)}") - assert result.errors is not None + with pytest.raises(TransportQueryError): + await session.execute(query) @pytest.mark.online @@ -100,7 +89,7 @@ async def test_websocket_sending_invalid_data(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query = gql( """ @@ -113,14 +102,9 @@ async def test_websocket_sending_invalid_data(): ) # Execute query - result = await client.execute(query) + result = await session.execute(query) - # Verify result - assert isinstance(result, ExecutionResult) - - print(f"result = {repr(result.data)}, {repr(result.errors)}") - - assert result.errors is None + print(f"result = {result!r}") invalid_data = "QSDF" print(f">>> {invalid_data}") @@ -160,7 +144,7 @@ async def test_websocket_sending_invalid_data_while_other_query_is_running(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query = gql( """ @@ -176,14 +160,11 @@ async def query_task1(): await asyncio.sleep(2 * MS) with pytest.raises(TransportError): - result = await client.execute(query) - - assert isinstance(result, ExecutionResult) - assert result.errors is None + result = await session.execute(query) - assert isinstance(result.data, Dict) + assert isinstance(result, Dict) - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] assert africa["code"] == "AF" @@ -213,7 +194,7 @@ async def test_websocket_two_queries_in_parallel_using_two_tasks(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query1 = gql( """ @@ -236,27 +217,21 @@ async def test_websocket_two_queries_in_parallel_using_two_tasks(): ) async def query_task1(): - result = await client.execute(query1) + result = await session.execute(query1) - assert isinstance(result, ExecutionResult) - assert result.errors is None + assert isinstance(result, Dict) - assert isinstance(result.data, Dict) - - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] assert africa["code"] == "AF" async def query_task2(): - result = await client.execute(query2) - - assert isinstance(result, ExecutionResult) - assert result.errors is None + result = await session.execute(query2) - assert isinstance(result.data, Dict) + assert isinstance(result, Dict) - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] assert africa["name"] == "Africa" diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 3e86b064..9c0c62a4 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -4,7 +4,6 @@ import json from .websocket_fixtures import MS, server, client_and_server, TestServer -from graphql.execution import ExecutionResult from gql.transport.websockets import WebsocketsTransport from gql.transport.exceptions import ( TransportClosed, @@ -46,7 +45,7 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: assert isinstance( sample_transport.websocket, websockets.client.WebSocketClientProtocol @@ -54,17 +53,14 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): query1 = gql(query1_str) - result = await client.execute(query1) + result = await session.execute(query1) - assert isinstance(result, ExecutionResult) - - print("Client received: " + str(result.data)) + print("Client received: " + str(result)) # Verify result - assert result.errors is None - assert isinstance(result.data, Dict) + assert isinstance(result, Dict) - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] assert africa["code"] == "AF" @@ -78,13 +74,13 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_simple_query(event_loop, client_and_server, query_str): - client, server = client_and_server + session, server = client_and_server query = gql(query_str) - result = await client.execute(query) + result = await session.execute(query) - print("Client received: " + str(result.data)) + print("Client received: " + str(result)) server1_two_answers_in_series = [ @@ -100,19 +96,19 @@ async def test_websocket_two_queries_in_series( event_loop, client_and_server, query_str ): - client, server = client_and_server + session, server = client_and_server query = gql(query_str) - result1 = await client.execute(query) + result1 = await session.execute(query) - print("Query1 received: " + str(result1.data)) + print("Query1 received: " + str(result1)) - result2 = await client.execute(query) + result2 = await session.execute(query) - print("Query2 received: " + str(result2.data)) + print("Query2 received: " + str(result2)) - assert str(result1.data) == str(result2.data) + assert result1 == result2 async def server1_two_queries_in_parallel(ws, path): @@ -136,7 +132,7 @@ async def test_websocket_two_queries_in_parallel( event_loop, client_and_server, query_str ): - client, server = client_and_server + session, server = client_and_server query = gql(query_str) @@ -145,21 +141,21 @@ async def test_websocket_two_queries_in_parallel( async def task1_coro(): nonlocal result1 - result1 = await client.execute(query) + result1 = await session.execute(query) async def task2_coro(): nonlocal result2 - result2 = await client.execute(query) + result2 = await session.execute(query) task1 = asyncio.ensure_future(task1_coro()) task2 = asyncio.ensure_future(task2_coro()) await asyncio.gather(task1, task2) - print("Query1 received: " + str(result1.data)) - print("Query2 received: " + str(result2.data)) + print("Query1 received: " + str(result1)) + print("Query2 received: " + str(result2)) - assert str(result1.data) == str(result2.data) + assert result1 == result2 async def server_closing_while_we_are_doing_something_else(ws, path): @@ -183,16 +179,12 @@ async def test_websocket_server_closing_after_first_query( event_loop, client_and_server, query_str ): - client, server = client_and_server + session, server = client_and_server query = gql(query_str) # First query is working - result = await client.execute(query) - - assert isinstance(result, ExecutionResult) - assert result.data is not None - assert result.errors is None + await session.execute(query) # Then we do other things await asyncio.sleep(2 * MS) @@ -202,7 +194,7 @@ async def test_websocket_server_closing_after_first_query( # Now the server is closed but we don't know it yet, we have to send a query # to notice it and to receive the exception with pytest.raises(TransportClosed): - result = await client.execute(query) + await session.execute(query) ignore_invalid_id_answers = [ @@ -217,37 +209,32 @@ async def test_websocket_server_closing_after_first_query( @pytest.mark.parametrize("query_str", [query1_str]) async def test_websocket_ignore_invalid_id(event_loop, client_and_server, query_str): - client, server = client_and_server + session, server = client_and_server query = gql(query_str) # First query is working - result = await client.execute(query) - assert isinstance(result, ExecutionResult) + await session.execute(query) # Second query gets no answer -> raises with pytest.raises(TransportQueryError): - result = await client.execute(query) + await session.execute(query) # Third query is working - result = await client.execute(query) - assert isinstance(result, ExecutionResult) + await session.execute(query) -async def assert_client_is_working(client): +async def assert_client_is_working(session): query1 = gql(query1_str) - result = await client.execute(query1) + result = await session.execute(query1) - assert isinstance(result, ExecutionResult) - - print("Client received: " + str(result.data)) + print("Client received: " + str(result)) # Verify result - assert result.errors is None - assert isinstance(result.data, Dict) + assert isinstance(result, Dict) - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] assert africa["code"] == "AF" @@ -262,14 +249,14 @@ async def test_websocket_multiple_connections_in_series(event_loop, server): sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as client: - await assert_client_is_working(client) + async with AsyncClient(transport=sample_transport) as session: + await assert_client_is_working(session) # Check client is disconnect here assert sample_transport.websocket is None - async with AsyncClient(transport=sample_transport) as client: - await assert_client_is_working(client) + async with AsyncClient(transport=sample_transport) as session: + await assert_client_is_working(session) # Check client is disconnect here assert sample_transport.websocket is None @@ -284,8 +271,8 @@ async def test_websocket_multiple_connections_in_parallel(event_loop, server): async def task_coro(): sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as client: - await assert_client_is_working(client) + async with AsyncClient(transport=sample_transport) as session: + await assert_client_is_working(session) task1 = asyncio.ensure_future(task_coro()) task2 = asyncio.ensure_future(task_coro()) @@ -303,8 +290,8 @@ async def test_websocket_trying_to_connect_to_already_connected_transport( print(f"url = {url}") sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as client: - await assert_client_is_working(client) + async with AsyncClient(transport=sample_transport) as session: + await assert_client_is_working(session) with pytest.raises(TransportAlreadyConnected): async with AsyncClient(transport=sample_transport): @@ -353,21 +340,18 @@ async def test_websocket_connect_success_with_authentication_in_connection_init( sample_transport = WebsocketsTransport(url=url, init_payload=init_payload) - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query1 = gql(query_str) - result = await client.execute(query1) - - assert isinstance(result, ExecutionResult) + result = await session.execute(query1) - print("Client received: " + str(result.data)) + print("Client received: " + str(result)) # Verify result - assert result.errors is None - assert isinstance(result.data, Dict) + assert isinstance(result, Dict) - continents = result.data["continents"] + continents = result["continents"] africa = continents[0] assert africa["code"] == "AF" @@ -389,7 +373,7 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( sample_transport = WebsocketsTransport(url=url, init_payload=init_payload) with pytest.raises(TransportServerError): - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: query1 = gql(query_str) - await client.execute(query1) + await session.execute(query1) diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index 6df530a7..c05b5897 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -5,7 +5,6 @@ from parse import search from .websocket_fixtures import MS, server, client_and_server, TestServer -from graphql.execution import ExecutionResult from gql import gql @@ -103,15 +102,14 @@ async def keepalive_coro(): @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) async def test_websocket_subscription(event_loop, client_and_server, subscription_str): - client, server = client_and_server + session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in client.subscribe(subscription): - assert isinstance(result, ExecutionResult) + async for result in session.subscribe(subscription): - number = result.data["number"] + number = result["number"] print(f"Number received: {number}") assert number == count @@ -127,15 +125,14 @@ async def test_websocket_subscription_break( event_loop, client_and_server, subscription_str ): - client, server = client_and_server + session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in client.subscribe(subscription): - assert isinstance(result, ExecutionResult) + async for result in session.subscribe(subscription): - number = result.data["number"] + number = result["number"] print(f"Number received: {number}") assert number == count @@ -155,17 +152,16 @@ async def test_websocket_subscription_task_cancel( event_loop, client_and_server, subscription_str ): - client, server = client_and_server + session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) async def task_coro(): nonlocal count - async for result in client.subscribe(subscription): - assert isinstance(result, ExecutionResult) + async for result in session.subscribe(subscription): - number = result.data["number"] + number = result["number"] print(f"Number received: {number}") assert number == count @@ -195,17 +191,16 @@ async def test_websocket_subscription_close_transport( event_loop, client_and_server, subscription_str ): - client, server = client_and_server + session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) async def task_coro(): nonlocal count - async for result in client.subscribe(subscription): - assert isinstance(result, ExecutionResult) + async for result in session.subscribe(subscription): - number = result.data["number"] + number = result["number"] print(f"Number received: {number}") assert number == count @@ -219,7 +214,7 @@ async def close_transport_task_coro(): await asyncio.sleep(11 * MS) - await client.transport.close() + await session.transport.close() close_transport_task = asyncio.ensure_future(close_transport_task_coro()) @@ -261,25 +256,22 @@ async def test_websocket_subscription_server_connection_closed( event_loop, client_and_server, subscription_str ): - client, server = client_and_server + session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) with pytest.raises(websockets.exceptions.ConnectionClosedOK): - async for result in client.subscribe(subscription): - assert isinstance(result, ExecutionResult) + async for result in session.subscribe(subscription): - number = result.data["number"] + number = result["number"] print(f"Number received: {number}") assert number == count count -= 1 - assert count > 0 - @pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @@ -288,16 +280,15 @@ async def test_websocket_subscription_slow_consumer( event_loop, client_and_server, subscription_str ): - client, server = client_and_server + session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in client.subscribe(subscription): + async for result in session.subscribe(subscription): await asyncio.sleep(10 * MS) - assert isinstance(result, ExecutionResult) - number = result.data["number"] + number = result["number"] print(f"Number received: {number}") assert number == count @@ -317,15 +308,14 @@ async def test_websocket_subscription_with_keepalive( event_loop, client_and_server, subscription_str ): - client, server = client_and_server + session, server = client_and_server count = 10 subscription = gql(subscription_str.format(count=count)) - async for result in client.subscribe(subscription): - assert isinstance(result, ExecutionResult) + async for result in session.subscribe(subscription): - number = result.data["number"] + number = result["number"] print(f"Number received: {number}") assert number == count diff --git a/tests_py36/websocket_fixtures.py b/tests_py36/websocket_fixtures.py index c00fc670..c1a21233 100644 --- a/tests_py36/websocket_fixtures.py +++ b/tests_py36/websocket_fixtures.py @@ -153,7 +153,7 @@ async def client_and_server(server): url = "ws://" + server.hostname + ":" + str(server.port) + path sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as client: + async with AsyncClient(transport=sample_transport) as session: - # Yield both client and server - yield (client, server) + # Yield both client session and server + yield (session, server) From 789083f5ca2d6051598b6b1607badab16e0500da Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 10 May 2020 10:26:54 +0200 Subject: [PATCH 29/46] Adding subscribe generator (not async) to AsyncClient Now client.execute will use the existing asyncio loop if using an async transport. --- gql/async_client.py | 57 ++++++++++++++++------- tests_py36/test_websocket_subscription.py | 28 ++++++++++- 2 files changed, 67 insertions(+), 18 deletions(-) diff --git a/gql/async_client.py b/gql/async_client.py index 46026497..b47c7e14 100644 --- a/gql/async_client.py +++ b/gql/async_client.py @@ -1,9 +1,10 @@ import asyncio from graphql import build_ast_schema, build_client_schema, introspection_query, parse +from graphql.execution import ExecutionResult from graphql.language.ast import Document -from typing import AsyncGenerator, Dict +from typing import Generator, AsyncGenerator, Dict, Any, cast from .transport.async_transport import AsyncTransport from .transport.exceptions import TransportQueryError @@ -44,47 +45,69 @@ def __init__( self.transport = transport self.fetch_schema_from_transport = fetch_schema_from_transport - async def _execute_in_async_session(self, document: Document, *args, **kwargs): + async def execute_async(self, document: Document, *args, **kwargs) -> Dict: async with self as session: return await session.execute(document, *args, **kwargs) def execute(self, document: Document, *args, **kwargs) -> Dict: """Execute the provided document AST against the configured remote server. - This function is synchronous and WILL BLOCK until the result is received from the server. + This function WILL BLOCK until the result is received from the server. - Either the transport is sync and we execute the query directly - OR the transport is async and we will create a new asyncio event loop to - execute the query in a synchronous way (blocking here until answer) + Either the transport is sync and we execute the query synchronously directly + OR the transport is async and we execute the query in the asyncio loop (blocking here until answer) """ if isinstance(self.transport, AsyncTransport): - loop = asyncio.new_event_loop() + loop = asyncio.get_event_loop() timeout = kwargs.get("timeout", 10) - result = loop.run_until_complete( - asyncio.wait_for( - self._execute_in_async_session(document, *args, **kwargs), timeout - ) + data: Dict[Any, Any] = loop.run_until_complete( + asyncio.wait_for(self.execute_async(document, *args, **kwargs), timeout) ) - loop.stop() - loop.close() - - return result + return data else: # Sync transports if self.schema: self.validate(document) - result = self.transport.execute(document, *args, **kwargs) + result: ExecutionResult = self.transport.execute(document, *args, **kwargs) if result.errors: raise TransportQueryError(str(result.errors[0])) - return result.data + # Running cast to make mypy happy. result.data should never be None here + return cast(Dict[Any, Any], result.data) + + async def subscribe_async( + self, document: Document, *args, **kwargs + ) -> AsyncGenerator[Dict, None]: + async with self as session: + async for result in session.subscribe(document, *args, **kwargs): + yield result + + def subscribe( + self, document: Document, *args, **kwargs + ) -> Generator[Dict, None, None]: + """Execute a GraphQL subscription with a python generator. + + We need an async transport for this functionality. + """ + + async_generator = self.subscribe_async(document, *args, **kwargs) + + loop = asyncio.get_event_loop() + + try: + while True: + result = loop.run_until_complete(async_generator.__anext__()) + yield result + + except StopAsyncIteration: + pass async def __aenter__(self): diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index c05b5897..2b155787 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -5,7 +5,8 @@ from parse import search from .websocket_fixtures import MS, server, client_and_server, TestServer -from gql import gql +from gql import gql, AsyncClient +from gql.transport.websockets import WebsocketsTransport countdown_server_answer = ( @@ -322,3 +323,28 @@ async def test_websocket_subscription_with_keepalive( count -= 1 assert count == -1 + + +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +def test_websocket_subscription_sync(server, subscription_str): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + client = AsyncClient(transport=sample_transport) + + count = 10 + subscription = gql(subscription_str.format(count=count)) + + for result in client.subscribe(subscription): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 From 41645be13ff13423ef2975dc4a8432ae90c701a9 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 10 May 2020 11:02:59 +0200 Subject: [PATCH 30/46] Using AsyncClient as the gql.Client if python>3.6 Rename AsyncClient to Client in all tests Put back the LocalSchemaTransport in the AsyncClient class All the old tests are passing with the AsyncClient as the Client in python > 3.6 retries on the client is now removed --- gql/__init__.py | 10 ++++------ gql/async_client.py | 17 +++++++++++++++++ tests/test_client.py | 6 ++++-- tests_py36/test_aiohttp_online.py | 8 ++++---- tests_py36/test_async_client_validation.py | 10 +++++----- tests_py36/test_http_async_sync.py | 10 +++++----- tests_py36/test_websocket_exceptions.py | 8 ++++---- tests_py36/test_websocket_online.py | 14 +++++++------- tests_py36/test_websocket_query.py | 18 +++++++++--------- tests_py36/test_websocket_subscription.py | 4 ++-- tests_py36/websocket_fixtures.py | 4 ++-- 11 files changed, 63 insertions(+), 46 deletions(-) diff --git a/gql/__init__.py b/gql/__init__.py index 577a71e2..2651c40e 100644 --- a/gql/__init__.py +++ b/gql/__init__.py @@ -1,12 +1,10 @@ import sys from .gql import gql -from .client import Client - -__all__ = ["gql", "Client"] if sys.version_info > (3, 6): - from .async_client import AsyncClient + from .async_client import AsyncClient as Client +else: + from .client import Client - # Cannot use __all__.append here because of flake8 warning - __all__ = ["gql", "Client", "AsyncClient"] +__all__ = ["gql", "Client"] diff --git a/gql/async_client.py b/gql/async_client.py index b47c7e14..443f280f 100644 --- a/gql/async_client.py +++ b/gql/async_client.py @@ -6,6 +6,7 @@ from typing import Generator, AsyncGenerator, Dict, Any, cast +from .transport.local_schema import LocalSchemaTransport from .transport.async_transport import AsyncTransport from .transport.exceptions import TransportQueryError from .client import Client @@ -39,6 +40,8 @@ def __init__( ), "Cant provide Type definition and schema at the same time" type_def_ast = parse(type_def) schema = build_ast_schema(type_def_ast) + elif schema and not transport: + transport = LocalSchemaTransport(schema) self.schema = schema self.introspection = introspection @@ -126,6 +129,20 @@ async def __aexit__(self, *args): await self.transport.close() + def close(self): + """Close the client and it's underlying transport (only for Sync transports)""" + if not isinstance(self.transport, AsyncTransport): + self.transport.close() + + def __enter__(self): + assert not isinstance( + self.transport, AsyncTransport + ), "Only a sync transport can be use. Use 'async with Client(...)' instead" + return self + + def __exit__(self, *args): + self.close() + class AsyncClientSession: """ An instance of this class is created when using 'async with' on the client. diff --git a/tests/test_client.py b/tests/test_client.py index ac6675ba..78ac0914 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,5 @@ import os +import sys import mock import pytest @@ -34,6 +35,9 @@ def execute(self): assert "Any Transport subclass must implement execute method" == str(exc_info.value) +@pytest.mark.skipif( + sys.version_info > (3, 6), reason="retries on client deprecated in latest versions" +) @mock.patch("gql.transport.requests.RequestsHTTPTransport.execute") def test_retries(execute_mock): expected_retries = 3 @@ -108,10 +112,8 @@ def test_no_schema_exception(): def test_execute_result_error(): - expected_retries = 3 client = Client( - retries=expected_retries, transport=RequestsHTTPTransport( url="https://countries.trevorblades.com/", use_json=True, diff --git a/tests_py36/test_aiohttp_online.py b/tests_py36/test_aiohttp_online.py index e21d3b46..3d6940d7 100644 --- a/tests_py36/test_aiohttp_online.py +++ b/tests_py36/test_aiohttp_online.py @@ -2,7 +2,7 @@ import pytest import sys -from gql import gql, AsyncClient +from gql import gql, Client from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.exceptions import TransportQueryError from typing import Dict @@ -20,7 +20,7 @@ async def test_aiohttp_simple_query(event_loop, protocol): sample_transport = AIOHTTPTransport(url=url) # Instanciate client - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query = gql( """ @@ -59,7 +59,7 @@ async def test_aiohttp_invalid_query(event_loop): url="https://countries.trevorblades.com/graphql" ) - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query = gql( """ @@ -86,7 +86,7 @@ async def test_aiohttp_two_queries_in_parallel_using_two_tasks(event_loop): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query1 = gql( """ diff --git a/tests_py36/test_async_client_validation.py b/tests_py36/test_async_client_validation.py index acded7af..c1f40763 100644 --- a/tests_py36/test_async_client_validation.py +++ b/tests_py36/test_async_client_validation.py @@ -5,7 +5,7 @@ import graphql from .websocket_fixtures import MS, server, client_and_server, TestServer -from gql import gql, AsyncClient +from gql import gql, Client from gql.transport.websockets import WebsocketsTransport from tests_py36.schema import StarWarsSchema, StarWarsTypeDef, StarWarsIntrospection @@ -89,7 +89,7 @@ async def test_async_client_validation( sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport, **client_params) as session: + async with Client(transport=sample_transport, **client_params) as session: variable_values = {"ep": "JEDI"} @@ -131,7 +131,7 @@ async def test_async_client_validation_invalid_query( sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport, **client_params) as session: + async with Client(transport=sample_transport, **client_params) as session: variable_values = {"ep": "JEDI"} @@ -164,7 +164,7 @@ async def test_async_client_validation_different_schemas_parameters_forbidden( sample_transport = WebsocketsTransport(url=url) with pytest.raises(AssertionError): - async with AsyncClient(transport=sample_transport, **client_params): + async with Client(transport=sample_transport, **client_params): pass @@ -245,7 +245,7 @@ async def test_async_client_validation_fetch_schema_from_server_with_client_argu sample_transport = WebsocketsTransport(url=url) - async with AsyncClient( + async with Client( transport=sample_transport, fetch_schema_from_transport=True, ) as session: diff --git a/tests_py36/test_http_async_sync.py b/tests_py36/test_http_async_sync.py index 86d33eea..e361cedb 100644 --- a/tests_py36/test_http_async_sync.py +++ b/tests_py36/test_http_async_sync.py @@ -1,6 +1,6 @@ import pytest -from gql import gql, AsyncClient +from gql import gql, Client from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.requests import RequestsHTTPTransport @@ -20,7 +20,7 @@ async def test_async_client_async_transport( sample_transport = AIOHTTPTransport(url=url) # Instanciate client - async with AsyncClient( + async with Client( transport=sample_transport, fetch_schema_from_transport=fetch_schema_from_transport, ) as session: @@ -61,7 +61,7 @@ async def test_async_client_sync_transport(event_loop, fetch_schema_from_transpo # Impossible to use a sync transport asynchronously with pytest.raises(AssertionError): - async with AsyncClient( + async with Client( transport=sample_transport, fetch_schema_from_transport=fetch_schema_from_transport, ): @@ -82,7 +82,7 @@ def test_sync_client_async_transport(protocol, fetch_schema_from_transport): sample_transport = AIOHTTPTransport(url=url) # Instanciate client - client = AsyncClient( + client = Client( transport=sample_transport, fetch_schema_from_transport=fetch_schema_from_transport, ) @@ -123,7 +123,7 @@ def test_sync_client_sync_transport(protocol, fetch_schema_from_transport): sample_transport = RequestsHTTPTransport(url=url, use_json=True) # Instanciate client - client = AsyncClient( + client = Client( transport=sample_transport, fetch_schema_from_transport=fetch_schema_from_transport, ) diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index 473e2308..e359b249 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -5,7 +5,7 @@ import types from .websocket_fixtures import MS, server, client_and_server, TestServer -from gql import gql, AsyncClient +from gql import gql, Client from gql.transport.websockets import WebsocketsTransport from gql.transport.exceptions import ( TransportProtocolError, @@ -178,7 +178,7 @@ async def test_websocket_server_does_not_ack(event_loop, server): sample_transport = WebsocketsTransport(url=url) with pytest.raises(TransportProtocolError): - async with AsyncClient(transport=sample_transport): + async with Client(transport=sample_transport): pass @@ -196,7 +196,7 @@ async def test_websocket_server_closing_directly(event_loop, server): sample_transport = WebsocketsTransport(url=url) with pytest.raises(websockets.exceptions.ConnectionClosed): - async with AsyncClient(transport=sample_transport): + async with Client(transport=sample_transport): pass @@ -238,5 +238,5 @@ async def test_websocket_server_sending_invalid_query_errors(event_loop, server) sample_transport = WebsocketsTransport(url=url) # Invalid server message is ignored - async with AsyncClient(transport=sample_transport): + async with Client(transport=sample_transport): await asyncio.sleep(2 * MS) diff --git a/tests_py36/test_websocket_online.py b/tests_py36/test_websocket_online.py index 901a0e22..8a0b717c 100644 --- a/tests_py36/test_websocket_online.py +++ b/tests_py36/test_websocket_online.py @@ -3,7 +3,7 @@ import pytest import sys -from gql import gql, AsyncClient +from gql import gql, Client from gql.transport.websockets import WebsocketsTransport from gql.transport.exceptions import TransportError, TransportQueryError from typing import Dict @@ -22,7 +22,7 @@ async def test_websocket_simple_query(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query = gql( """ @@ -61,7 +61,7 @@ async def test_websocket_invalid_query(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query = gql( """ @@ -89,7 +89,7 @@ async def test_websocket_sending_invalid_data(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query = gql( """ @@ -123,7 +123,7 @@ async def test_websocket_sending_invalid_payload(): ) # Instanciate client - async with AsyncClient(transport=sample_transport): + async with Client(transport=sample_transport): invalid_payload = '{"id": "1", "type": "start", "payload": "BLAHBLAH"}' @@ -144,7 +144,7 @@ async def test_websocket_sending_invalid_data_while_other_query_is_running(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query = gql( """ @@ -194,7 +194,7 @@ async def test_websocket_two_queries_in_parallel_using_two_tasks(): ) # Instanciate client - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query1 = gql( """ diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 9c0c62a4..8ab09aa6 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -11,7 +11,7 @@ TransportServerError, TransportAlreadyConnected, ) -from gql import gql, AsyncClient +from gql import gql, Client from typing import Dict @@ -45,7 +45,7 @@ async def test_websocket_starting_client_in_context_manager(event_loop, server): sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: assert isinstance( sample_transport.websocket, websockets.client.WebSocketClientProtocol @@ -249,13 +249,13 @@ async def test_websocket_multiple_connections_in_series(event_loop, server): sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: await assert_client_is_working(session) # Check client is disconnect here assert sample_transport.websocket is None - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: await assert_client_is_working(session) # Check client is disconnect here @@ -271,7 +271,7 @@ async def test_websocket_multiple_connections_in_parallel(event_loop, server): async def task_coro(): sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: await assert_client_is_working(session) task1 = asyncio.ensure_future(task_coro()) @@ -290,11 +290,11 @@ async def test_websocket_trying_to_connect_to_already_connected_transport( print(f"url = {url}") sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: await assert_client_is_working(session) with pytest.raises(TransportAlreadyConnected): - async with AsyncClient(transport=sample_transport): + async with Client(transport=sample_transport): pass @@ -340,7 +340,7 @@ async def test_websocket_connect_success_with_authentication_in_connection_init( sample_transport = WebsocketsTransport(url=url, init_payload=init_payload) - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query1 = gql(query_str) @@ -373,7 +373,7 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( sample_transport = WebsocketsTransport(url=url, init_payload=init_payload) with pytest.raises(TransportServerError): - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: query1 = gql(query_str) await session.execute(query1) diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index 2b155787..ee482056 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -5,7 +5,7 @@ from parse import search from .websocket_fixtures import MS, server, client_and_server, TestServer -from gql import gql, AsyncClient +from gql import gql, Client from gql.transport.websockets import WebsocketsTransport @@ -334,7 +334,7 @@ def test_websocket_subscription_sync(server, subscription_str): sample_transport = WebsocketsTransport(url=url) - client = AsyncClient(transport=sample_transport) + client = Client(transport=sample_transport) count = 10 subscription = gql(subscription_str.format(count=count)) diff --git a/tests_py36/websocket_fixtures.py b/tests_py36/websocket_fixtures.py index c1a21233..346adabe 100644 --- a/tests_py36/websocket_fixtures.py +++ b/tests_py36/websocket_fixtures.py @@ -8,7 +8,7 @@ from gql.transport.websockets import WebsocketsTransport from websockets.exceptions import ConnectionClosed -from gql import AsyncClient +from gql import Client # Adding debug logs to websocket tests for name in ["websockets.server", "gql.transport.websockets"]: @@ -153,7 +153,7 @@ async def client_and_server(server): url = "ws://" + server.hostname + ":" + str(server.port) + path sample_transport = WebsocketsTransport(url=url) - async with AsyncClient(transport=sample_transport) as session: + async with Client(transport=sample_transport) as session: # Yield both client session and server yield (session, server) From 8d277c5c0d89d81154c7c4866d21bda0fdb324f8 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 10 May 2020 11:49:48 +0200 Subject: [PATCH 31/46] Update gql-cli script to use http or websockets transport depending on server url Adding yarl dependency (already in aiohttp...) --- scripts/gql-cli | 34 +++++++++++++++++++++++++++------- setup.py | 7 +++++-- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/scripts/gql-cli b/scripts/gql-cli index 46cbca1b..533c72ff 100755 --- a/scripts/gql-cli +++ b/scripts/gql-cli @@ -1,19 +1,35 @@ #!/usr/bin/env python3 -from gql import gql, AsyncClient +from gql import gql, Client from gql.transport.websockets import WebsocketsTransport +from gql.transport.aiohttp import AIOHTTPTransport +from yarl import URL import asyncio import argparse -parser = argparse.ArgumentParser(description='Send GraphQL queries from command line to a websocket endpoint') -parser.add_argument('server', help='the server websocket url starting with ws:// or wss://') +parser = argparse.ArgumentParser( + description="Send GraphQL queries from command line using http(s) or websockets" +) +parser.add_argument( + "server", help="the server url starting with http://, https://, ws:// or wss://" +) args = parser.parse_args() + async def main(): - transport = WebsocketsTransport(url=args.server, ssl=args.server.startswith('wss')) + url = URL(args.server) + + scheme = url.scheme + + if scheme in ["ws", "wss"]: + transport = WebsocketsTransport(url=args.server, ssl=(scheme == "wss")) + elif scheme in ["http", "https"]: + transport = AIOHTTPTransport(url=args.server) + else: + raise Exception("URL protocol should be one of: http, https, ws, wss") - async with AsyncClient(transport=transport) as client: + async with Client(transport=transport) as session: while True: try: @@ -23,8 +39,12 @@ async def main(): query = gql(query_str) - async for result in client.subscribe(query): + if scheme in ["ws", "wss"]: + async for result in session.subscribe(query): + print(result) + else: + result = await session.execute(query) + print(result) - print (result.data) asyncio.run(main()) diff --git a/setup.py b/setup.py index e3d79e3b..f4ba671b 100644 --- a/setup.py +++ b/setup.py @@ -24,8 +24,11 @@ 'pytest-asyncio==0.11.0', 'parse>=1.6.0', ]) - install_requires.append('websockets>=8.1,<9') - install_requires.append('aiohttp==3.6.2') + install_requires.append([ + 'websockets>=8.1,<9', + 'aiohttp==3.6.2', + 'yarl>=1.0,<2.0', + ]) scripts.append('scripts/gql-cli') else: tests_require.append([ From 5ff88a74485a0c331e8b598e24b91d499a05dbec Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Sun, 10 May 2020 15:10:10 +0200 Subject: [PATCH 32/46] Adding tests for AIOHTTPTransport Fix AIOHTTPTransport exceptions Adding asserts to ensure client.execute and client.subscribe are not called with a running asyncio event loop --- gql/async_client.py | 9 + gql/transport/aiohttp.py | 19 ++- setup.py | 1 + tests_py36/test_aiohttp.py | 194 ++++++++++++++++++++++ tests_py36/test_websocket_exceptions.py | 24 +++ tests_py36/test_websocket_query.py | 43 ++++- tests_py36/test_websocket_subscription.py | 4 +- 7 files changed, 287 insertions(+), 7 deletions(-) create mode 100644 tests_py36/test_aiohttp.py diff --git a/gql/async_client.py b/gql/async_client.py index 443f280f..26db7427 100644 --- a/gql/async_client.py +++ b/gql/async_client.py @@ -65,7 +65,12 @@ def execute(self, document: Document, *args, **kwargs) -> Dict: loop = asyncio.get_event_loop() + assert ( + not loop.is_running() + ), "Cannot run client.execute if an asyncio loop is running. Use execute_async instead" + timeout = kwargs.get("timeout", 10) + data: Dict[Any, Any] = loop.run_until_complete( asyncio.wait_for(self.execute_async(document, *args, **kwargs), timeout) ) @@ -104,6 +109,10 @@ def subscribe( loop = asyncio.get_event_loop() + assert ( + not loop.is_running() + ), "Cannot run client.subscribe if an asyncio loop is running. Use subscribe_async instead" + try: while True: result = loop.run_until_complete(async_generator.__anext__()) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 1ceccf80..95720163 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -3,6 +3,7 @@ from aiohttp.typedefs import LooseCookies, LooseHeaders from aiohttp.helpers import BasicAuth from aiohttp.client_reqrep import Fingerprint +from aiohttp.client_exceptions import ClientResponseError from ssl import SSLContext @@ -15,6 +16,7 @@ from .async_transport import AsyncTransport from .exceptions import ( TransportProtocolError, + TransportServerError, TransportClosed, TransportAlreadyConnected, ) @@ -124,13 +126,20 @@ async def execute( async with self.session.post(self.url, ssl=self.ssl, **post_args) as resp: try: result = await resp.json() - if not isinstance(result, dict): - raise ValueError - except ValueError: - result = {} + except Exception: + # We raise a TransportServerError if the status code is 400 or higher + # We raise a TransportProtocolError in the other cases + + try: + # Raise a ClientResponseError if the response status is 400 or higher + resp.raise_for_status() + + except ClientResponseError as e: + raise TransportServerError from e + + raise TransportProtocolError("Server did not return a GraphQL result") if "errors" not in result and "data" not in result: - resp.raise_for_status() raise TransportProtocolError("Server did not return a GraphQL result") return ExecutionResult(errors=result.get("errors"), data=result.get("data")) diff --git a/setup.py b/setup.py index f4ba671b..782984f0 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ 'pytest==5.4.1', 'pytest-asyncio==0.11.0', 'parse>=1.6.0', + 'pytest_aiohttp==0.3.0', ]) install_requires.append([ 'websockets>=8.1,<9', diff --git a/tests_py36/test_aiohttp.py b/tests_py36/test_aiohttp.py new file mode 100644 index 00000000..75d8c21f --- /dev/null +++ b/tests_py36/test_aiohttp.py @@ -0,0 +1,194 @@ +import pytest + +from gql import gql, Client +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.exceptions import ( + TransportServerError, + TransportQueryError, + TransportProtocolError, + TransportAlreadyConnected, + TransportClosed, +) + +from aiohttp import web +from pytest_aiohttp import aiohttp_server + +query1_str = """ + query getContinents { + continents { + code + name + } + } +""" + +query1_server_answer = ( + '{"data":{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},{"code":"AS","name":"Asia"},' + '{"code":"EU","name":"Europe"},{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}}' +) + + +@pytest.mark.asyncio +async def test_aiohttp_query(event_loop, aiohttp_server): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + # Execute query asynchronously + result = await session.execute(query) + + continents = result["continents"] + + africa = continents[0] + + assert africa["code"] == "AF" + + +@pytest.mark.asyncio +async def test_aiohttp_error_code_500(event_loop, aiohttp_server): + async def handler(request): + # Will generate http error code 500 + raise Exception("Server error") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + with pytest.raises(TransportServerError): + await session.execute(query) + + +query1_server_error_answer = '{"errors": ["Error 1", "Error 2"]}' + + +@pytest.mark.asyncio +async def test_aiohttp_error_code(event_loop, aiohttp_server): + async def handler(request): + return web.Response( + text=query1_server_error_answer, content_type="application/json" + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + with pytest.raises(TransportQueryError): + await session.execute(query) + + +invalid_protocol_responses = [ + "{}", + "qlsjfqsdlkj", + '{"not_data_or_errors": 35}', +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("response", invalid_protocol_responses) +async def test_aiohttp_invalid_protocol(event_loop, aiohttp_server, response): + async def handler(request): + return web.Response(text=response, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + with pytest.raises(TransportProtocolError): + await session.execute(query) + + +@pytest.mark.asyncio +async def test_aiohttp_subscribe_not_supported(event_loop, aiohttp_server): + async def handler(request): + return web.Response(text="does not matter", content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url) + + async with Client(transport=sample_transport,) as session: + + query = gql(query1_str) + + with pytest.raises(NotImplementedError): + async for result in session.subscribe(query): + pass + + +@pytest.mark.asyncio +async def test_aiohttp_cannot_connect_twice(event_loop, aiohttp_server): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=sample_transport,) as session: + + with pytest.raises(TransportAlreadyConnected): + await session.transport.connect() + + +@pytest.mark.asyncio +async def test_aiohttp_cannot_execute_if_not_connected(event_loop, aiohttp_server): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + sample_transport = AIOHTTPTransport(url=url, timeout=10) + + query = gql(query1_str) + + with pytest.raises(TransportClosed): + await sample_transport.execute(query) diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index e359b249..dea5ab21 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -47,6 +47,30 @@ async def test_websocket_invalid_query(event_loop, client_and_server, query_str) await session.execute(query) +invalid_subscription_str = """ + subscription getContinents { + continents { + code + bloh + } + } +""" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_subscription_str]) +async def test_websocket_invalid_subscription(event_loop, client_and_server, query_str): + + session, server = client_and_server + + query = gql(query_str) + + with pytest.raises(TransportQueryError): + async for result in session.subscribe(query): + pass + + connection_error_server_answer = ( '{"type":"connection_error","id":null,' '"payload":{"message":"Unexpected token Q in JSON at position 0"}}' diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 8ab09aa6..59f17392 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -14,7 +14,6 @@ from gql import gql, Client from typing import Dict - query1_str = """ query getContinents { continents { @@ -377,3 +376,45 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( query1 = gql(query_str) await session.execute(query1) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +def test_websocket_execute_sync(event_loop, server): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + print(f"url = {url}") + + sample_transport = WebsocketsTransport(url=url) + + client = Client(transport=sample_transport) + + query1 = gql(query1_str) + + result = client.execute(query1) + + print("Client received: " + str(result)) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Execute sync a second time + result = client.execute(query1) + + print("Client received: " + str(result)) + + # Verify result + assert isinstance(result, Dict) + + continents = result["continents"] + africa = continents[0] + + assert africa["code"] == "AF" + + # Check client is disconnect here + assert sample_transport.websocket is None diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index ee482056..7a913183 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -325,9 +325,11 @@ async def test_websocket_subscription_with_keepalive( assert count == -1 +# Note: forced to put mark.asyncio here to avoid problem with aiohttp plugin +@pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -def test_websocket_subscription_sync(server, subscription_str): +def test_websocket_subscription_sync(event_loop, server, subscription_str): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") From 7faa6a237e48a49ce077b89609dc5017a5751d13 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 11 May 2020 21:12:55 +0200 Subject: [PATCH 33/46] Fix aiohttp tests Removing pytest_aiohttp dependency (copy only the fixture we need) --- setup.py | 1 - tests_py36/aiohttp_fixtures.py | 24 +++++++++++++++++++++++ tests_py36/test_aiohttp.py | 2 +- tests_py36/test_websocket_query.py | 3 +-- tests_py36/test_websocket_subscription.py | 4 +--- 5 files changed, 27 insertions(+), 7 deletions(-) create mode 100644 tests_py36/aiohttp_fixtures.py diff --git a/setup.py b/setup.py index 782984f0..f4ba671b 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,6 @@ 'pytest==5.4.1', 'pytest-asyncio==0.11.0', 'parse>=1.6.0', - 'pytest_aiohttp==0.3.0', ]) install_requires.append([ 'websockets>=8.1,<9', diff --git a/tests_py36/aiohttp_fixtures.py b/tests_py36/aiohttp_fixtures.py new file mode 100644 index 00000000..9a255fe6 --- /dev/null +++ b/tests_py36/aiohttp_fixtures.py @@ -0,0 +1,24 @@ +import pytest +import asyncio + +from aiohttp.test_utils import TestServer + + +@pytest.fixture +async def aiohttp_server(): + """Factory to create a TestServer instance, given an app. + + aiohttp_server(app, **kwargs) + """ + servers = [] + + async def go(app, *, port=None, **kwargs): # type: ignore + server = TestServer(app, port=port) + await server.start_server(**kwargs) + servers.append(server) + return server + + yield go + + while servers: + await servers.pop().close() diff --git a/tests_py36/test_aiohttp.py b/tests_py36/test_aiohttp.py index 75d8c21f..a7f0695a 100644 --- a/tests_py36/test_aiohttp.py +++ b/tests_py36/test_aiohttp.py @@ -11,7 +11,7 @@ ) from aiohttp import web -from pytest_aiohttp import aiohttp_server +from .aiohttp_fixtures import aiohttp_server query1_str = """ query getContinents { diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 59f17392..4eea86f4 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -378,9 +378,8 @@ async def test_websocket_connect_failed_with_authentication_in_connection_init( await session.execute(query1) -@pytest.mark.asyncio @pytest.mark.parametrize("server", [server1_answers], indirect=True) -def test_websocket_execute_sync(event_loop, server): +def test_websocket_execute_sync(server): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index 7a913183..ee482056 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -325,11 +325,9 @@ async def test_websocket_subscription_with_keepalive( assert count == -1 -# Note: forced to put mark.asyncio here to avoid problem with aiohttp plugin -@pytest.mark.asyncio @pytest.mark.parametrize("server", [server_countdown], indirect=True) @pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) -def test_websocket_subscription_sync(event_loop, server, subscription_str): +def test_websocket_subscription_sync(server, subscription_str): url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" print(f"url = {url}") From c3a5c18ddd6c67dfc7deded87afdf7f9a0568c43 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Mon, 11 May 2020 22:04:19 +0200 Subject: [PATCH 34/46] Fix isort/black compatibility Fixing import order with isort Modify isort config to be compatible with black Now isort will fail if invalid import order is found isort upgraded to version 4.2.8 to fix bug https://github.com/timothycrosley/isort/issues/537 Fix posargs in tox.ini to allow to run tox for specific tests --- gql/async_client.py | 7 +++---- gql/client.py | 2 +- gql/transport/aiohttp.py | 19 ++++++++----------- gql/transport/async_transport.py | 3 +-- gql/transport/websockets.py | 21 +++++++++------------ setup.cfg | 3 +++ setup.py | 2 +- tests/test_client.py | 3 +-- tests_py36/aiohttp_fixtures.py | 2 +- tests_py36/schema.py | 14 +++++++------- tests_py36/test_aiohttp.py | 10 +++++----- tests_py36/test_aiohttp_online.py | 7 ++++--- tests_py36/test_async_client_validation.py | 11 ++++++----- tests_py36/test_http_async_sync.py | 2 +- tests_py36/test_websocket_exceptions.py | 13 +++++++------ tests_py36/test_websocket_online.py | 12 +++++++----- tests_py36/test_websocket_query.py | 14 ++++++++------ tests_py36/test_websocket_subscription.py | 9 +++++---- tests_py36/websocket_fixtures.py | 9 +++++---- tox.ini | 15 ++++++++------- 20 files changed, 91 insertions(+), 87 deletions(-) diff --git a/gql/async_client.py b/gql/async_client.py index 26db7427..e5ed7c5c 100644 --- a/gql/async_client.py +++ b/gql/async_client.py @@ -1,15 +1,14 @@ import asyncio +from typing import Any, AsyncGenerator, Dict, Generator, cast from graphql import build_ast_schema, build_client_schema, introspection_query, parse from graphql.execution import ExecutionResult from graphql.language.ast import Document -from typing import Generator, AsyncGenerator, Dict, Any, cast - -from .transport.local_schema import LocalSchemaTransport +from .client import Client from .transport.async_transport import AsyncTransport from .transport.exceptions import TransportQueryError -from .client import Client +from .transport.local_schema import LocalSchemaTransport class AsyncClient(Client): diff --git a/gql/client.py b/gql/client.py index 78cdb67d..7e2b839d 100644 --- a/gql/client.py +++ b/gql/client.py @@ -3,8 +3,8 @@ from graphql import build_ast_schema, build_client_schema, introspection_query, parse from graphql.validation import validate -from .transport.local_schema import LocalSchemaTransport from .transport import Transport +from .transport.local_schema import LocalSchemaTransport log = logging.getLogger(__name__) diff --git a/gql/transport/aiohttp.py b/gql/transport/aiohttp.py index 95720163..e24c4045 100644 --- a/gql/transport/aiohttp.py +++ b/gql/transport/aiohttp.py @@ -1,24 +1,21 @@ -import aiohttp - -from aiohttp.typedefs import LooseCookies, LooseHeaders -from aiohttp.helpers import BasicAuth -from aiohttp.client_reqrep import Fingerprint -from aiohttp.client_exceptions import ClientResponseError - from ssl import SSLContext +from typing import Any, AsyncGenerator, Dict, Optional, Union -from typing import Dict, Optional, Union, AsyncGenerator, Any - +import aiohttp +from aiohttp.client_exceptions import ClientResponseError +from aiohttp.client_reqrep import Fingerprint +from aiohttp.helpers import BasicAuth +from aiohttp.typedefs import LooseCookies, LooseHeaders from graphql.execution import ExecutionResult from graphql.language.ast import Document from graphql.language.printer import print_ast from .async_transport import AsyncTransport from .exceptions import ( + TransportAlreadyConnected, + TransportClosed, TransportProtocolError, TransportServerError, - TransportClosed, - TransportAlreadyConnected, ) diff --git a/gql/transport/async_transport.py b/gql/transport/async_transport.py index 81aada2d..99b93643 100644 --- a/gql/transport/async_transport.py +++ b/gql/transport/async_transport.py @@ -1,11 +1,10 @@ import abc +from typing import AsyncGenerator, Dict, Optional import six from graphql.execution import ExecutionResult from graphql.language.ast import Document -from typing import Dict, Optional, AsyncGenerator - @six.add_metaclass(abc.ABCMeta) class AsyncTransport: diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index eb2dfa72..f7856619 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,30 +1,27 @@ from __future__ import absolute_import -import websockets -from websockets.http import HeadersLike -from websockets.typing import Data, Subprotocol -from websockets.client import WebSocketClientProtocol -from websockets.exceptions import ConnectionClosed - -from ssl import SSLContext - import asyncio import json import logging +from ssl import SSLContext +from typing import Any, AsyncGenerator, Dict, Optional, Tuple, Union, cast -from typing import cast, Dict, Optional, Tuple, Union, AsyncGenerator, Any - +import websockets from graphql.execution import ExecutionResult from graphql.language.ast import Document from graphql.language.printer import print_ast +from websockets.client import WebSocketClientProtocol +from websockets.exceptions import ConnectionClosed +from websockets.http import HeadersLike +from websockets.typing import Data, Subprotocol from .async_transport import AsyncTransport from .exceptions import ( + TransportAlreadyConnected, + TransportClosed, TransportProtocolError, TransportQueryError, TransportServerError, - TransportClosed, - TransportAlreadyConnected, ) log = logging.getLogger(__name__) diff --git a/setup.cfg b/setup.cfg index 4a307fe7..054b5c07 100644 --- a/setup.cfg +++ b/setup.cfg @@ -6,6 +6,9 @@ max-line-length = 120 [isort] known_first_party=gql +multi_line_output=3 +include_trailing_comma=True +line_length=88 [tool:pytest] norecursedirs = venv .venv .tox .git .cache .mypy_cache .pytest_cache diff --git a/setup.py b/setup.py index f4ba671b..01a573b6 100644 --- a/setup.py +++ b/setup.py @@ -37,7 +37,7 @@ dev_requires = [ 'flake8==3.7.9', - 'isort<4.0.0', + 'isort==4.2.8', 'black==19.10b0', 'mypy==0.761', 'check-manifest>=0.40,<1', diff --git a/tests/test_client.py b/tests/test_client.py index 78ac0914..9ab7bae4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,9 +3,8 @@ import mock import pytest -from urllib3.exceptions import NewConnectionError - from graphql import build_ast_schema, parse +from urllib3.exceptions import NewConnectionError from gql import Client, gql from gql.transport.requests import RequestsHTTPTransport, Transport diff --git a/tests_py36/aiohttp_fixtures.py b/tests_py36/aiohttp_fixtures.py index 9a255fe6..c9634243 100644 --- a/tests_py36/aiohttp_fixtures.py +++ b/tests_py36/aiohttp_fixtures.py @@ -1,6 +1,6 @@ -import pytest import asyncio +import pytest from aiohttp.test_utils import TestServer diff --git a/tests_py36/schema.py b/tests_py36/schema.py index de29d50a..c97073f8 100644 --- a/tests_py36/schema.py +++ b/tests_py36/schema.py @@ -1,21 +1,21 @@ from graphql import ( - graphql, - print_schema, - GraphQLField, GraphQLArgument, + GraphQLField, GraphQLObjectType, GraphQLSchema, + graphql, + print_schema, ) from graphql.utils.introspection_query import introspection_query from tests.starwars.schema import ( - reviewType, + droidType, episodeEnum, - queryType, - mutationType, humanType, - droidType, + mutationType, + queryType, reviewInputType, + reviewType, ) from tests_py36.fixtures import reviewAdded diff --git a/tests_py36/test_aiohttp.py b/tests_py36/test_aiohttp.py index a7f0695a..d0b4f457 100644 --- a/tests_py36/test_aiohttp.py +++ b/tests_py36/test_aiohttp.py @@ -1,16 +1,16 @@ import pytest +from aiohttp import web -from gql import gql, Client +from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.exceptions import ( - TransportServerError, - TransportQueryError, - TransportProtocolError, TransportAlreadyConnected, TransportClosed, + TransportProtocolError, + TransportQueryError, + TransportServerError, ) -from aiohttp import web from .aiohttp_fixtures import aiohttp_server query1_str = """ diff --git a/tests_py36/test_aiohttp_online.py b/tests_py36/test_aiohttp_online.py index 3d6940d7..4c994e16 100644 --- a/tests_py36/test_aiohttp_online.py +++ b/tests_py36/test_aiohttp_online.py @@ -1,11 +1,12 @@ import asyncio -import pytest import sys +from typing import Dict -from gql import gql, Client +import pytest + +from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.exceptions import TransportQueryError -from typing import Dict @pytest.mark.online diff --git a/tests_py36/test_async_client_validation.py b/tests_py36/test_async_client_validation.py index c1f40763..edaf2c75 100644 --- a/tests_py36/test_async_client_validation.py +++ b/tests_py36/test_async_client_validation.py @@ -1,14 +1,15 @@ import asyncio -import pytest import json -import websockets + import graphql +import pytest +import websockets -from .websocket_fixtures import MS, server, client_and_server, TestServer -from gql import gql, Client +from gql import Client, gql from gql.transport.websockets import WebsocketsTransport -from tests_py36.schema import StarWarsSchema, StarWarsTypeDef, StarWarsIntrospection +from tests_py36.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef +from .websocket_fixtures import MS, TestServer, client_and_server, server starwars_expected_one = { "stars": 3, diff --git a/tests_py36/test_http_async_sync.py b/tests_py36/test_http_async_sync.py index e361cedb..6fc7eed7 100644 --- a/tests_py36/test_http_async_sync.py +++ b/tests_py36/test_http_async_sync.py @@ -1,6 +1,6 @@ import pytest -from gql import gql, Client +from gql import Client, gql from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.requests import RequestsHTTPTransport diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index dea5ab21..d5019466 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -1,18 +1,19 @@ import asyncio -import pytest import json -import websockets import types -from .websocket_fixtures import MS, server, client_and_server, TestServer -from gql import gql, Client -from gql.transport.websockets import WebsocketsTransport +import pytest +import websockets + +from gql import Client, gql from gql.transport.exceptions import ( + TransportClosed, TransportProtocolError, TransportQueryError, - TransportClosed, ) +from gql.transport.websockets import WebsocketsTransport +from .websocket_fixtures import MS, TestServer, client_and_server, server invalid_query_str = """ query getContinents { diff --git a/tests_py36/test_websocket_online.py b/tests_py36/test_websocket_online.py index 8a0b717c..ede0d2d2 100644 --- a/tests_py36/test_websocket_online.py +++ b/tests_py36/test_websocket_online.py @@ -1,12 +1,14 @@ -import logging import asyncio -import pytest +import logging import sys +from typing import Dict -from gql import gql, Client -from gql.transport.websockets import WebsocketsTransport +import pytest + +from gql import Client, gql from gql.transport.exceptions import TransportError, TransportQueryError -from typing import Dict +from gql.transport.websockets import WebsocketsTransport + from .websocket_fixtures import MS logging.basicConfig(level=logging.INFO) diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 4eea86f4..19ac4ec5 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -1,18 +1,20 @@ import asyncio +import json +from typing import Dict + import pytest import websockets -import json -from .websocket_fixtures import MS, server, client_and_server, TestServer -from gql.transport.websockets import WebsocketsTransport +from gql import Client, gql from gql.transport.exceptions import ( + TransportAlreadyConnected, TransportClosed, TransportQueryError, TransportServerError, - TransportAlreadyConnected, ) -from gql import gql, Client -from typing import Dict +from gql.transport.websockets import WebsocketsTransport + +from .websocket_fixtures import MS, TestServer, client_and_server, server query1_str = """ query getContinents { diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index ee482056..47a13874 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -1,13 +1,14 @@ import asyncio -import pytest import json -import websockets +import pytest +import websockets from parse import search -from .websocket_fixtures import MS, server, client_and_server, TestServer -from gql import gql, Client + +from gql import Client, gql from gql.transport.websockets import WebsocketsTransport +from .websocket_fixtures import MS, TestServer, client_and_server, server countdown_server_answer = ( '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' diff --git a/tests_py36/websocket_fixtures.py b/tests_py36/websocket_fixtures.py index 346adabe..6ae02dbf 100644 --- a/tests_py36/websocket_fixtures.py +++ b/tests_py36/websocket_fixtures.py @@ -1,14 +1,15 @@ -import websockets import asyncio import json -import os -import pytest import logging +import os import types -from gql.transport.websockets import WebsocketsTransport +import pytest +import websockets from websockets.exceptions import ConnectionClosed + from gql import Client +from gql.transport.websockets import WebsocketsTransport # Adding debug logs to websocket tests for name in ["websockets.server", "gql.transport.websockets"]: diff --git a/tox.ini b/tox.ini index a3712250..800b03b5 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = +envlist = black,flake8,import-order,mypy,manifest, py{27,35,36,37,38,39-dev,py,py3} ; requires = tox-conda @@ -14,21 +14,21 @@ setenv = MULTIDICT_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/multidict YARL_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/yarl install_command = python -m pip install --ignore-installed {opts} {packages} -whitelist_externals = +whitelist_externals = python deps = -e.[test] ; Prevent installing issues: https://github.com/ContinuumIO/anaconda-issues/issues/542 commands = pip install -U setuptools - py{27,35,py}: pytest tests {posargs} - py{36,38,39-dev,py3}: pytest tests tests_py36 {posargs} - py{37}: pytest tests tests_py36 {posargs: --cov-report=term-missing --cov=gql} + py{27,35,py}: pytest {posargs:tests} + py{36,38,39-dev,py3}: pytest {posargs:tests tests_py36} + py{37}: pytest {posargs:tests tests_py36 --cov-report=term-missing --cov=gql} [testenv:black] basepython=python3.6 deps = -e.[dev] commands = - black --check gql tests + black --check gql tests tests_py36 [testenv:flake8] basepython = python3.6 @@ -40,7 +40,8 @@ commands = basepython=python3.6 deps = -e.[dev] commands = - isort -rc gql/ tests/ + isort -rc -c gql/ tests tests_py36/ +; Note: if the previous command fails, run it without the -c flag to fix automatically [testenv:mypy] basepython=python3.6 From 6250068aaa05336c81b481c1daf6c137a7b1f753 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 12 May 2020 18:44:21 +0200 Subject: [PATCH 35/46] Fix test_websocket_exception on pypy3 with python 3.6.1 Add -s to pytest to have better logs in tox Add -vb (verbose) to isort to have better logs in tox --- tests_py36/test_websocket_exceptions.py | 9 ++++++++- tox.ini | 8 ++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index d5019466..d01ff1c1 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -58,8 +58,15 @@ async def test_websocket_invalid_query(event_loop, client_and_server, query_str) """ +async def server_invalid_subscription(ws, path): + await TestServer.send_connection_ack(ws) + result = await ws.recv() + await ws.send(invalid_query1_server_answer.format(query_id=1)) + await ws.wait_closed() + + @pytest.mark.asyncio -@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) +@pytest.mark.parametrize("server", [server_invalid_subscription], indirect=True) @pytest.mark.parametrize("query_str", [invalid_subscription_str]) async def test_websocket_invalid_subscription(event_loop, client_and_server, query_str): diff --git a/tox.ini b/tox.ini index 800b03b5..0ef1e182 100644 --- a/tox.ini +++ b/tox.ini @@ -20,9 +20,9 @@ deps = -e.[test] ; Prevent installing issues: https://github.com/ContinuumIO/anaconda-issues/issues/542 commands = pip install -U setuptools - py{27,35,py}: pytest {posargs:tests} - py{36,38,39-dev,py3}: pytest {posargs:tests tests_py36} - py{37}: pytest {posargs:tests tests_py36 --cov-report=term-missing --cov=gql} + py{27,35,py}: pytest {posargs:tests -s} + py{36,38,39-dev,py3}: pytest {posargs:tests tests_py36 -s} + py{37}: pytest {posargs:tests tests_py36 --cov-report=term-missing --cov=gql -s} [testenv:black] basepython=python3.6 @@ -40,7 +40,7 @@ commands = basepython=python3.6 deps = -e.[dev] commands = - isort -rc -c gql/ tests tests_py36/ + isort -rc -c -vb gql tests tests_py36 ; Note: if the previous command fails, run it without the -c flag to fix automatically [testenv:mypy] From a9684b2eb8e065347ea165a2c0ed53501f99347c Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 12 May 2020 20:02:25 +0200 Subject: [PATCH 36/46] TRAVIS modify timeout factor to 100 for travis automated tests --- .travis.yml | 3 +++ tests_py36/websocket_fixtures.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index c27e0b40..ae189f22 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,6 @@ +env: + global: + - GQL_TESTS_TIMEOUT_FACTOR=100 language: python sudo: false python: diff --git a/tests_py36/websocket_fixtures.py b/tests_py36/websocket_fixtures.py index 6ae02dbf..20097dc1 100644 --- a/tests_py36/websocket_fixtures.py +++ b/tests_py36/websocket_fixtures.py @@ -20,9 +20,9 @@ logger.addHandler(logging.StreamHandler()) # Unit for timeouts. May be increased on slow machines by setting the -# WEBSOCKETS_TESTS_TIMEOUT_FACTOR environment variable. +# GQL_TESTS_TIMEOUT_FACTOR environment variable. # Copied from websockets source -MS = 0.001 * int(os.environ.get("WEBSOCKETS_TESTS_TIMEOUT_FACTOR", 1)) +MS = 0.001 * int(os.environ.get("GQL_TESTS_TIMEOUT_FACTOR", 1)) class TestServer: From f80243f327b2038272e279511c398150b9ac8716 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Tue, 12 May 2020 21:58:37 +0200 Subject: [PATCH 37/46] Increase timeout for pypy3 version 3.6.1 For some reason on pypy3 with the version on travis (3.6.1), it takes 10 seconds to leave the async generator... So we increase the default timeout for a sync execute --- gql/async_client.py | 5 ++++- gql/transport/websockets.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/gql/async_client.py b/gql/async_client.py index e5ed7c5c..2fed06e4 100644 --- a/gql/async_client.py +++ b/gql/async_client.py @@ -12,6 +12,9 @@ class AsyncClient(Client): + + DEFAULT_TIMEOUT = 60 + def __init__( self, schema=None, @@ -68,7 +71,7 @@ def execute(self, document: Document, *args, **kwargs) -> Dict: not loop.is_running() ), "Cannot run client.execute if an asyncio loop is running. Use execute_async instead" - timeout = kwargs.get("timeout", 10) + timeout = kwargs.get("timeout", AsyncClient.DEFAULT_TIMEOUT) data: Dict[Any, Any] = loop.run_until_complete( asyncio.wait_for(self.execute_async(document, *args, **kwargs), timeout) diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index f7856619..843e234c 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -45,7 +45,11 @@ def __init__(self, query_id: int, send_stop: bool) -> None: async def get(self) -> ParsedAnswer: - item = await self._queue.get() + try: + item = self._queue.get_nowait() + except asyncio.QueueEmpty: + item = await self._queue.get() + self._queue.task_done() # If we receive an exception when reading the queue, we raise it From 77319a16aaeedc00aa3b0c9212b32e75d2c42cef Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 13 May 2020 07:09:20 +0200 Subject: [PATCH 38/46] Better management of timeouts Add execute_timeout to AsyncClient Add connect_timeout, ack_timeout and close_timeout to websockets transport Fix asyncio bugs of pypy3 v3.6.1 by adding await generator.aclose() in async for --- gql/async_client.py | 81 ++++++++++++++++------- gql/transport/websockets.py | 47 ++++++++++--- tests_py36/test_websocket_exceptions.py | 19 ++++++ tests_py36/test_websocket_subscription.py | 2 + 4 files changed, 116 insertions(+), 33 deletions(-) diff --git a/gql/async_client.py b/gql/async_client.py index 2fed06e4..143c555c 100644 --- a/gql/async_client.py +++ b/gql/async_client.py @@ -1,27 +1,27 @@ import asyncio -from typing import Any, AsyncGenerator, Dict, Generator, cast +from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, cast from graphql import build_ast_schema, build_client_schema, introspection_query, parse from graphql.execution import ExecutionResult from graphql.language.ast import Document +from graphql.type import GraphQLSchema from .client import Client from .transport.async_transport import AsyncTransport from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport +from .transport.transport import Transport class AsyncClient(Client): - - DEFAULT_TIMEOUT = 60 - def __init__( self, - schema=None, + schema: Optional[GraphQLSchema] = None, introspection=None, - type_def=None, - transport=None, - fetch_schema_from_transport=False, + type_def: Optional[str] = None, + transport: Optional[Union[Transport, AsyncTransport]] = None, + fetch_schema_from_transport: bool = False, + execute_timeout: Optional[int] = 10, ): assert not ( type_def and introspection @@ -30,9 +30,11 @@ def __init__( assert ( not schema ), "Cant fetch the schema from transport if is already provided" - if not isinstance(transport, AsyncTransport): + if isinstance(transport, Transport): # For sync transports, we fetch the schema directly - introspection = transport.execute(parse(introspection_query)).data + execution_result = transport.execute(parse(introspection_query)) + execution_result = cast(ExecutionResult, execution_result) + introspection = execution_result.data if introspection: assert not schema, "Cant provide introspection and schema at the same time" schema = build_client_schema(introspection) @@ -45,10 +47,21 @@ def __init__( elif schema and not transport: transport = LocalSchemaTransport(schema) - self.schema = schema + # GraphQL schema + self.schema: Optional[GraphQLSchema] = schema + + # Answer of the introspection query self.introspection = introspection - self.transport = transport - self.fetch_schema_from_transport = fetch_schema_from_transport + + # GraphQL transport chosen + self.transport: Optional[Union[Transport, AsyncTransport]] = transport + + # Flag to indicate that we need to fetch the schema from the transport + # On async transports, we fetch the schema before executing the first query + self.fetch_schema_from_transport: bool = fetch_schema_from_transport + + # Enforced timeout of the execute function + self.execute_timeout = execute_timeout async def execute_async(self, document: Document, *args, **kwargs) -> Dict: async with self as session: @@ -71,10 +84,8 @@ def execute(self, document: Document, *args, **kwargs) -> Dict: not loop.is_running() ), "Cannot run client.execute if an asyncio loop is running. Use execute_async instead" - timeout = kwargs.get("timeout", AsyncClient.DEFAULT_TIMEOUT) - data: Dict[Any, Any] = loop.run_until_complete( - asyncio.wait_for(self.execute_async(document, *args, **kwargs), timeout) + self.execute_async(document, *args, **kwargs) ) return data @@ -84,7 +95,14 @@ def execute(self, document: Document, *args, **kwargs) -> Dict: if self.schema: self.validate(document) - result: ExecutionResult = self.transport.execute(document, *args, **kwargs) + result: ExecutionResult + + if isinstance(self.transport, LocalSchemaTransport): + result = cast( + ExecutionResult, self.transport.execute(document, *args, **kwargs) + ) + elif isinstance(self.transport, Transport): + result = cast(ExecutionResult, self.transport.execute(document)) if result.errors: raise TransportQueryError(str(result.errors[0])) @@ -96,7 +114,12 @@ async def subscribe_async( self, document: Document, *args, **kwargs ) -> AsyncGenerator[Dict, None]: async with self as session: - async for result in session.subscribe(document, *args, **kwargs): + + self._generator: AsyncGenerator[Dict, None] = session.subscribe( + document, *args, **kwargs + ) + + async for result in self._generator: yield result def subscribe( @@ -136,7 +159,7 @@ async def __aenter__(self): return self.session - async def __aexit__(self, *args): + async def __aexit__(self, exc_type, exc, tb): await self.transport.close() @@ -182,19 +205,31 @@ async def subscribe( await self.validate(document) # Subscribe to the transport and yield data or raise error - async for result in self.transport.subscribe(document, *args, **kwargs): + self._generator: AsyncGenerator[ + ExecutionResult, None + ] = self.transport.subscribe(document, *args, **kwargs) + + async for result in self._generator: if result.errors: + # Note: we need to run generator.aclose() here or the finally block in + # the transport.subscribe will not be reached in pypy3 (python version 3.6.1) + await self._generator.aclose() + raise TransportQueryError(str(result.errors[0])) - yield result.data + elif result.data is not None: + yield result.data async def execute(self, document: Document, *args, **kwargs) -> Dict: # Fetch schema from transport if needed and validate document if schema is present await self.validate(document) - # Execute the query with the transport - result = await self.transport.execute(document, *args, **kwargs) + # Execute the query with the transport with a timeout + result = await asyncio.wait_for( + self.transport.execute(document, *args, **kwargs), + self.client.execute_timeout, + ) # Raise an error if an error is returned in the ExecutionResult object if result.errors: diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index 843e234c..a3cbc572 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -76,6 +76,10 @@ async def set_exception(self, exception: Exception) -> None: # Put the exception in the queue await self._queue.put(exception) + # Don't need to send stop messages in case of error + self.send_stop = False + self._closed = True + class WebsocketsTransport(AsyncTransport): """Transport to execute GraphQL queries on remote servers with a websocket connection. @@ -92,6 +96,9 @@ def __init__( headers: Optional[HeadersLike] = None, ssl: Union[SSLContext, bool] = False, init_payload: Dict[str, Any] = {}, + connect_timeout: int = 10, + close_timeout: int = 10, + ack_timeout: int = 10, ) -> None: """Initialize the transport with the given request parameters. @@ -99,12 +106,19 @@ def __init__( :param headers: Dict of HTTP Headers. :param ssl: ssl_context of the connection. Use ssl=False to disable encryption :param init_payload: Dict of the payload sent in the connection_init message. + :param connect_timeout: Timeout in seconds for the establishment of the websocket connection. + :param close_timeout: Timeout in seconds for the close. + :param ack_timeout: Timeout in seconds to wait for the connection_ack message from the server. """ self.url: str = url self.ssl: Union[SSLContext, bool] = ssl self.headers: Optional[HeadersLike] = headers self.init_payload: Dict[str, Any] = init_payload + self.connect_timeout: int = connect_timeout + self.close_timeout: int = close_timeout + self.ack_timeout: int = ack_timeout + self.websocket: Optional[WebSocketClientProtocol] = None self.next_query_id: int = 1 self.listeners: Dict[int, ListenerQueue] = {} @@ -169,7 +183,8 @@ async def _send_init_message_and_wait_ack(self) -> None: await self._send(init_message) - init_answer = await self._receive() + # Wait for the connection_ack message or raise a TimeoutError + init_answer = await asyncio.wait_for(self._receive(), self.ack_timeout) answer_type, answer_id, execution_result = self._parse_answer(init_answer) @@ -412,10 +427,17 @@ async def execute( """ first_result = None - async for result in self.subscribe( + generator = self.subscribe( document, variable_values, operation_name, send_stop=False - ): + ) + + async for result in generator: first_result = result + + # Note: we need to run generator.aclose() here or the finally block in + # the subscribe will not be reached in pypy3 (python version 3.6.1) + await generator.aclose() + break if first_result is None: @@ -441,11 +463,15 @@ async def connect(self) -> None: if self.websocket is None: # Connection to the specified url - self.websocket = await websockets.connect( - self.url, - ssl=self.ssl if self.ssl else None, - extra_headers=self.headers, - subprotocols=[GRAPHQLWS_SUBPROTOCOL], + # Generate a TimeoutError if taking more than connect_timeout seconds + self.websocket = await asyncio.wait_for( + websockets.connect( + self.url, + ssl=self.ssl if self.ssl else None, + extra_headers=self.headers, + subprotocols=[GRAPHQLWS_SUBPROTOCOL], + ), + self.connect_timeout, ) self.next_query_id = 1 @@ -453,11 +479,12 @@ async def connect(self) -> None: self._wait_closed.clear() # Send the init message and wait for the ack from the server + # Note: will generate a TimeoutError if no acks are received within the ack_timeout try: await self._send_init_message_and_wait_ack() except ConnectionClosed as e: raise e - except TransportProtocolError as e: + except (TransportProtocolError, asyncio.TimeoutError) as e: await self._fail(e, clean_close=False) raise e @@ -483,7 +510,7 @@ async def _clean_close(self, e: Exception) -> None: # Wait that there is no more listeners (we received 'complete' for all queries) try: - await asyncio.wait_for(self._no_more_listeners.wait(), 10) + await asyncio.wait_for(self._no_more_listeners.wait(), self.close_timeout) except asyncio.TimeoutError: # pragma: no cover pass diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index d01ff1c1..8ee4dd3d 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -62,6 +62,7 @@ async def server_invalid_subscription(ws, path): await TestServer.send_connection_ack(ws) result = await ws.recv() await ws.send(invalid_query1_server_answer.format(query_id=1)) + await TestServer.send_complete(ws, 1) await ws.wait_closed() @@ -85,6 +86,24 @@ async def test_websocket_invalid_subscription(event_loop, client_and_server, que ) +async def server_no_ack(ws, path): + await ws.wait_closed() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_no_ack], indirect=True) +@pytest.mark.parametrize("query_str", [invalid_query_str]) +async def test_websocket_server_does_not_send_ack(event_loop, server, query_str): + + url = "ws://" + server.hostname + ":" + str(server.port) + "/graphql" + + sample_transport = WebsocketsTransport(url=url, ack_timeout=1) + + with pytest.raises(asyncio.TimeoutError): + async with Client(transport=sample_transport): + pass + + async def server_connection_error(ws, path): await TestServer.send_connection_ack(ws) result = await ws.recv() diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index 47a13874..0cf1b836 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -140,6 +140,8 @@ async def test_websocket_subscription_break( assert number == count if count <= 5: + # Note: the following line is only necessary for pypy3 v3.6.1 + await session._generator.aclose() break count -= 1 From 8cae234fa80c5cd4700d2078e2afbbdf8bc8b313 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 13 May 2020 10:39:33 +0200 Subject: [PATCH 39/46] Set graphql-core dependency for now to 2.3.1 to fix tests --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 01a573b6..5a50fd11 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ install_requires = [ 'six>=1.10.0', - 'graphql-core>=2,<3', + 'graphql-core==2.3.1', 'promise>=2.0,<3', 'requests>=2.12,<3' ] From 6c29f397a2f7d6aa7b7868a202681493cb795516 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 13 May 2020 13:23:30 +0200 Subject: [PATCH 40/46] tox.ini add --diff to isort to easily see where the problem is --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 0ef1e182..2fcae200 100644 --- a/tox.ini +++ b/tox.ini @@ -40,7 +40,7 @@ commands = basepython=python3.6 deps = -e.[dev] commands = - isort -rc -c -vb gql tests tests_py36 + isort --recursive --check-only --diff --verbose gql tests tests_py36 ; Note: if the previous command fails, run it without the -c flag to fix automatically [testenv:mypy] From 514cce753e049c774db9c8b504312d60ba257392 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 13 May 2020 13:44:37 +0200 Subject: [PATCH 41/46] setup.cfg adding ssl to known_standard_library for isort --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 054b5c07..c1853480 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,6 +5,7 @@ universal=1 max-line-length = 120 [isort] +known_standard_library=ssl known_first_party=gql multi_line_output=3 include_trailing_comma=True From 62981259925a7545e114fc86cb8a11a4bfae9dc9 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 13 May 2020 15:04:46 +0200 Subject: [PATCH 42/46] README.md add sync client.subscribe usage + rename AsyncClient to Client --- README.md | 45 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 524ff1a7..1f9c7fab 100644 --- a/README.md +++ b/README.md @@ -113,17 +113,42 @@ query = gql(''' client.execute(query) ``` -# Async usage with asyncio and subscriptions +With a python version > 3.6, it is possible to execute GraphQL subscriptions using the websockets transport: -When using the `execute` function directly on the client, the execution is synchronous. +```python +from gql import gql, Client +from gql.transport.websockets import WebsocketsTransport + +sample_transport = WebsocketsTransport(url='wss://your_server/graphql') + +client = Client( + transport=sample_transport, + fetch_schema_from_transport=True, +) + +query = gql(''' + subscription yourSubscription { + ... + } +''') + +for result in client.subscribe(query): + print (f"result = {result!s}") +``` + +Note: the websockets transport can also execute queries or mutations + +# Async usage with asyncio + +When using the `execute` or `subscribe` function directly on the client, the execution is synchronous. It means that we are blocked until we receive an answer from the server and we cannot do anything else while waiting for this answer. -It is now possible to use this library asynchronously using [asyncio](https://docs.python.org/3/library/asyncio.html). +It is also possible to use this library asynchronously using [asyncio](https://docs.python.org/3/library/asyncio.html). Async Features: * Execute GraphQL subscriptions (See [using the websockets transport](#Websockets-async-transport)) -* Execute GraphQL queries and subscriptions in parallel +* Execute GraphQL queries, mutations and subscriptions in parallel To use the async features, you need to use an async transport: * [AIOHTTPTransport](#HTTP-async-transport) for the HTTP(s) protocols @@ -133,11 +158,11 @@ To use the async features, you need to use an async transport: This transport uses the [aiohttp library](https://docs.aiohttp.org) -GraphQL subscriptions are not supported on this HTTP transport. +GraphQL subscriptions are not supported on the HTTP transport. For subscriptions you should use the websockets transport. ```python -from gql import gql, AsyncClient +from gql import gql, Client from gql.transport.aiohttp import AIOHTTPTransport import asyncio @@ -148,7 +173,7 @@ async def main(): headers={'Authorization': 'token'} ) - async with AsyncClient( + async with Client( transport=sample_transport, fetch_schema_from_transport=True, ) as session: @@ -182,7 +207,7 @@ This transport allows to do multiple queries, mutations and subscriptions on the import logging logging.basicConfig(level=logging.INFO) -from gql import gql, AsyncClient +from gql import gql, Client from gql.transport.websockets import WebsocketsTransport import asyncio @@ -194,7 +219,7 @@ async def main(): headers={'Authorization': 'token'} ) - async with AsyncClient( + async with Client( transport=sample_transport, fetch_schema_from_transport=True, ) as session: @@ -288,7 +313,7 @@ sample_transport = WebsocketsTransport( ) ``` -### Websockets advanced usage +### Async advanced usage It is possible to send multiple GraphQL queries (query, mutation or subscription) in parallel, on the same websocket connection, using asyncio tasks From 6d0a85d293d04e58b35e94b9b6259ecc0c6bb3d8 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 13 May 2020 16:12:27 +0200 Subject: [PATCH 43/46] Moving pytest fixtures to conftest.py Now flake8 also verify the tests_py36 folder --- tests_py36/aiohttp_fixtures.py | 24 --- tests_py36/conftest.py | 181 +++++++++++++++++++++ tests_py36/test_aiohttp.py | 2 - tests_py36/test_async_client_validation.py | 2 +- tests_py36/test_websocket_exceptions.py | 4 +- tests_py36/test_websocket_online.py | 2 +- tests_py36/test_websocket_query.py | 2 +- tests_py36/test_websocket_subscription.py | 2 +- tests_py36/websocket_fixtures.py | 160 ------------------ tox.ini | 2 +- 10 files changed, 188 insertions(+), 193 deletions(-) delete mode 100644 tests_py36/aiohttp_fixtures.py delete mode 100644 tests_py36/websocket_fixtures.py diff --git a/tests_py36/aiohttp_fixtures.py b/tests_py36/aiohttp_fixtures.py deleted file mode 100644 index c9634243..00000000 --- a/tests_py36/aiohttp_fixtures.py +++ /dev/null @@ -1,24 +0,0 @@ -import asyncio - -import pytest -from aiohttp.test_utils import TestServer - - -@pytest.fixture -async def aiohttp_server(): - """Factory to create a TestServer instance, given an app. - - aiohttp_server(app, **kwargs) - """ - servers = [] - - async def go(app, *, port=None, **kwargs): # type: ignore - server = TestServer(app, port=port) - await server.start_server(**kwargs) - servers.append(server) - return server - - yield go - - while servers: - await servers.pop().close() diff --git a/tests_py36/conftest.py b/tests_py36/conftest.py index 691f6c93..07032e2c 100644 --- a/tests_py36/conftest.py +++ b/tests_py36/conftest.py @@ -1,4 +1,16 @@ +import asyncio +import json +import logging +import os +import types + import pytest +import websockets +from aiohttp.test_utils import TestServer as AIOHTTPTestServer +from websockets.exceptions import ConnectionClosed + +from gql import Client +from gql.transport.websockets import WebsocketsTransport def pytest_addoption(parser): @@ -24,3 +36,172 @@ def pytest_collection_modifyitems(config, items): for item in items: if "online" in item.keywords: item.add_marker(skip_online) + + +@pytest.fixture +async def aiohttp_server(): + """Factory to create a TestServer instance, given an app. + + aiohttp_server(app, **kwargs) + """ + servers = [] + + async def go(app, *, port=None, **kwargs): # type: ignore + server = AIOHTTPTestServer(app, port=port) + await server.start_server(**kwargs) + servers.append(server) + return server + + yield go + + while servers: + await servers.pop().close() + + +# Adding debug logs to websocket tests +for name in ["websockets.server", "gql.transport.websockets"]: + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + + if len(logger.handlers) < 1: + logger.addHandler(logging.StreamHandler()) + +# Unit for timeouts. May be increased on slow machines by setting the +# GQL_TESTS_TIMEOUT_FACTOR environment variable. +# Copied from websockets source +MS = 0.001 * int(os.environ.get("GQL_TESTS_TIMEOUT_FACTOR", 1)) + + +class TestServer: + """ + Class used to generate a websocket server on localhost on a free port + + Will allow us to test our client by simulating different correct and incorrect server responses + """ + + async def start(self, handler): + + print("Starting server") + + # Start a server with a random open port + self.start_server = websockets.server.serve(handler, "localhost", 0) + + # Wait that the server is started + self.server = await self.start_server + + # Get hostname and port + hostname, port = self.server.sockets[0].getsockname() + + self.hostname = hostname + self.port = port + + print(f"Server started on port {port}") + + async def stop(self): + print("Stopping server") + + self.server.close() + try: + await asyncio.wait_for(self.server.wait_closed(), timeout=1) + except asyncio.TimeoutError: # pragma: no cover + assert False, "Server failed to stop" + + print("Server stopped\n\n\n") + + @staticmethod + async def send_complete(ws, query_id): + await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}') + + @staticmethod + async def send_keepalive(ws): + await ws.send('{"type":"ka"}') + + @staticmethod + async def send_connection_ack(ws): + + # Line return for easy debugging + print("") + + # Wait for init + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "connection_init" + + # Send ack + await ws.send('{"type":"connection_ack"}') + + @staticmethod + async def wait_connection_terminate(ws): + result = await ws.recv() + json_result = json.loads(result) + assert json_result["type"] == "connection_terminate" + + +@pytest.fixture +async def server(request): + """server is a fixture used to start a dummy server to test the client behaviour. + + It can take as argument either a handler function for the websocket server for complete control + OR an array of answers to be sent by the default server handler + """ + + if isinstance(request.param, types.FunctionType): + server_handler = request.param + + else: + answers = request.param + + async def default_server_handler(ws, path): + + try: + await TestServer.send_connection_ack(ws) + query_id = 1 + + for answer in answers: + result = await ws.recv() + print(f"Server received: {result}") + + if isinstance(answer, str) and "{query_id}" in answer: + answer_format_params = {} + answer_format_params["query_id"] = query_id + formatted_answer = answer.format(**answer_format_params) + else: + formatted_answer = answer + + await ws.send(formatted_answer) + await TestServer.send_complete(ws, query_id) + query_id += 1 + + await TestServer.wait_connection_terminate(ws) + await ws.wait_closed() + except ConnectionClosed: + pass + + server_handler = default_server_handler + + try: + test_server = TestServer() + + # Starting the server with the fixture param as the handler function + await test_server.start(server_handler) + + yield test_server + except Exception as e: + print("Exception received in server fixture: " + str(e)) + finally: + await test_server.stop() + + +@pytest.fixture +async def client_and_server(server): + """client_and_server is a helper fixture to start a server and a client connected to its port""" + + # Generate transport to connect to the server fixture + path = "/graphql" + url = "ws://" + server.hostname + ":" + str(server.port) + path + sample_transport = WebsocketsTransport(url=url) + + async with Client(transport=sample_transport) as session: + + # Yield both client session and server + yield (session, server) diff --git a/tests_py36/test_aiohttp.py b/tests_py36/test_aiohttp.py index d0b4f457..3e2ef5e2 100644 --- a/tests_py36/test_aiohttp.py +++ b/tests_py36/test_aiohttp.py @@ -11,8 +11,6 @@ TransportServerError, ) -from .aiohttp_fixtures import aiohttp_server - query1_str = """ query getContinents { continents { diff --git a/tests_py36/test_async_client_validation.py b/tests_py36/test_async_client_validation.py index edaf2c75..25d01e36 100644 --- a/tests_py36/test_async_client_validation.py +++ b/tests_py36/test_async_client_validation.py @@ -9,7 +9,7 @@ from gql.transport.websockets import WebsocketsTransport from tests_py36.schema import StarWarsIntrospection, StarWarsSchema, StarWarsTypeDef -from .websocket_fixtures import MS, TestServer, client_and_server, server +from .conftest import MS, TestServer starwars_expected_one = { "stars": 3, diff --git a/tests_py36/test_websocket_exceptions.py b/tests_py36/test_websocket_exceptions.py index 8ee4dd3d..14b64fdd 100644 --- a/tests_py36/test_websocket_exceptions.py +++ b/tests_py36/test_websocket_exceptions.py @@ -13,7 +13,7 @@ ) from gql.transport.websockets import WebsocketsTransport -from .websocket_fixtures import MS, TestServer, client_and_server, server +from .conftest import MS, TestServer invalid_query_str = """ query getContinents { @@ -60,7 +60,7 @@ async def test_websocket_invalid_query(event_loop, client_and_server, query_str) async def server_invalid_subscription(ws, path): await TestServer.send_connection_ack(ws) - result = await ws.recv() + await ws.recv() await ws.send(invalid_query1_server_answer.format(query_id=1)) await TestServer.send_complete(ws, 1) await ws.wait_closed() diff --git a/tests_py36/test_websocket_online.py b/tests_py36/test_websocket_online.py index ede0d2d2..6f447a26 100644 --- a/tests_py36/test_websocket_online.py +++ b/tests_py36/test_websocket_online.py @@ -9,7 +9,7 @@ from gql.transport.exceptions import TransportError, TransportQueryError from gql.transport.websockets import WebsocketsTransport -from .websocket_fixtures import MS +from .conftest import MS logging.basicConfig(level=logging.INFO) diff --git a/tests_py36/test_websocket_query.py b/tests_py36/test_websocket_query.py index 19ac4ec5..f2c3c49b 100644 --- a/tests_py36/test_websocket_query.py +++ b/tests_py36/test_websocket_query.py @@ -14,7 +14,7 @@ ) from gql.transport.websockets import WebsocketsTransport -from .websocket_fixtures import MS, TestServer, client_and_server, server +from .conftest import MS, TestServer query1_str = """ query getContinents { diff --git a/tests_py36/test_websocket_subscription.py b/tests_py36/test_websocket_subscription.py index 0cf1b836..78c9f225 100644 --- a/tests_py36/test_websocket_subscription.py +++ b/tests_py36/test_websocket_subscription.py @@ -8,7 +8,7 @@ from gql import Client, gql from gql.transport.websockets import WebsocketsTransport -from .websocket_fixtures import MS, TestServer, client_and_server, server +from .conftest import MS, TestServer countdown_server_answer = ( '{{"type":"data","id":"{query_id}","payload":{{"data":{{"number":{number}}}}}}}' diff --git a/tests_py36/websocket_fixtures.py b/tests_py36/websocket_fixtures.py deleted file mode 100644 index 20097dc1..00000000 --- a/tests_py36/websocket_fixtures.py +++ /dev/null @@ -1,160 +0,0 @@ -import asyncio -import json -import logging -import os -import types - -import pytest -import websockets -from websockets.exceptions import ConnectionClosed - -from gql import Client -from gql.transport.websockets import WebsocketsTransport - -# Adding debug logs to websocket tests -for name in ["websockets.server", "gql.transport.websockets"]: - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - - if len(logger.handlers) < 1: - logger.addHandler(logging.StreamHandler()) - -# Unit for timeouts. May be increased on slow machines by setting the -# GQL_TESTS_TIMEOUT_FACTOR environment variable. -# Copied from websockets source -MS = 0.001 * int(os.environ.get("GQL_TESTS_TIMEOUT_FACTOR", 1)) - - -class TestServer: - """ - Class used to generate a websocket server on localhost on a free port - - Will allow us to test our client by simulating different correct and incorrect server responses - """ - - async def start(self, handler): - - print("Starting server") - - # Start a server with a random open port - self.start_server = websockets.server.serve(handler, "localhost", 0) - - # Wait that the server is started - self.server = await self.start_server - - # Get hostname and port - hostname, port = self.server.sockets[0].getsockname() - - self.hostname = hostname - self.port = port - - print(f"Server started on port {port}") - - async def stop(self): - print("Stopping server") - - self.server.close() - try: - await asyncio.wait_for(self.server.wait_closed(), timeout=1) - except asyncio.TimeoutError: # pragma: no cover - assert False, "Server failed to stop" - - print("Server stopped\n\n\n") - - @staticmethod - async def send_complete(ws, query_id): - await ws.send(f'{{"type":"complete","id":"{query_id}","payload":null}}') - - @staticmethod - async def send_keepalive(ws): - await ws.send('{"type":"ka"}') - - @staticmethod - async def send_connection_ack(ws): - - # Line return for easy debugging - print("") - - # Wait for init - result = await ws.recv() - json_result = json.loads(result) - assert json_result["type"] == "connection_init" - - # Send ack - await ws.send('{"type":"connection_ack"}') - - @staticmethod - async def wait_connection_terminate(ws): - result = await ws.recv() - json_result = json.loads(result) - assert json_result["type"] == "connection_terminate" - - -@pytest.fixture -async def server(request): - """server is a fixture used to start a dummy server to test the client behaviour. - - It can take as argument either a handler function for the websocket server for complete control - OR an array of answers to be sent by the default server handler - """ - - if isinstance(request.param, types.FunctionType): - server_handler = request.param - - else: - answers = request.param - - async def default_server_handler(ws, path): - - try: - await TestServer.send_connection_ack(ws) - query_id = 1 - - for answer in answers: - result = await ws.recv() - print(f"Server received: {result}") - - if isinstance(answer, str) and "{query_id}" in answer: - answer_format_params = {} - answer_format_params["query_id"] = query_id - formatted_answer = answer.format(**answer_format_params) - else: - formatted_answer = answer - - await ws.send(formatted_answer) - await TestServer.send_complete(ws, query_id) - query_id += 1 - - await TestServer.wait_connection_terminate(ws) - await ws.wait_closed() - except ConnectionClosed: - pass - - server_handler = default_server_handler - - try: - test_server = TestServer() - - # Starting the server with the fixture param as the handler function - await test_server.start(server_handler) - - yield test_server - except Exception as e: - print("Exception received in server fixture: " + str(e)) - finally: - await test_server.stop() - - -@pytest.fixture -async def client_and_server(server): - """client_and_server is a helper fixture to start a server and a client connected to its port""" - - # Generate transport to connect to the server fixture - path = "/graphql" - url = "ws://" + server.hostname + ":" + str(server.port) + path - sample_transport = WebsocketsTransport(url=url) - - async with Client(transport=sample_transport) as session: - - # Yield both client session and server - yield (session, server) diff --git a/tox.ini b/tox.ini index f60e4399..e475a8ec 100644 --- a/tox.ini +++ b/tox.ini @@ -34,7 +34,7 @@ commands = basepython = python3.8 deps = -e.[dev] commands = - flake8 gql tests + flake8 gql tests tests_py36 [testenv:import-order] basepython=python3.8 From 9eef3fe8ac89a6cbe54b0070d20fc5bca5ad6745 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Wed, 13 May 2020 16:36:42 +0200 Subject: [PATCH 44/46] tox.ini trying to combine coverage of python2.7 and 3.8 --- tox.ini | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index e475a8ec..ce778198 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{27,35,36,37,38,39-dev,py,py3} + py{27,35,36,37,38,39-dev,py,py3},cov ; requires = tox-conda [pytest] @@ -13,6 +13,7 @@ setenv = PYTHONPATH = {toxinidir} MULTIDICT_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/multidict YARL_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/yarl + COVERAGE_FILE=.coverage.{envname} install_command = python -m pip install --ignore-installed {opts} {packages} whitelist_externals = python @@ -20,9 +21,12 @@ deps = -e.[test] ; Prevent installing issues: https://github.com/ContinuumIO/anaconda-issues/issues/542 commands = pip install -U setuptools - py{27,35,py}: pytest {posargs:tests -s} + py{27}: pytest {posargs:tests --cov-report=term-missing --cov=gql -s} + py{35,py}: pytest {posargs:tests -s} py{36,37,39-dev,py3}: pytest {posargs:tests tests_py36 -s} py{38}: pytest {posargs:tests tests_py36 --cov-report=term-missing --cov=gql -s} + cov: coverage combine {toxinidir}/.coverage.py27 {toxinidir}/.coverage.py38 + cov: coverage report --show-missing [testenv:black] basepython=python3.8 From d183573062ff086191b1825aaa90c132d851c45a Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 14 May 2020 11:25:57 +0200 Subject: [PATCH 45/46] Remove support for python versions < 3.6 --- gql/__init__.py | 8 +- gql/async_client.py | 247 ------------------------------- gql/client.py | 263 ++++++++++++++++++++++++++-------- gql/transport/__init__.py | 12 +- gql/transport/local_schema.py | 11 +- gql/transport/requests.py | 34 ++--- gql/transport/transport.py | 5 +- gql/transport/websockets.py | 2 - setup.cfg | 13 +- setup.py | 50 +++---- tests/test_client.py | 32 ----- tox.ini | 7 +- 12 files changed, 252 insertions(+), 432 deletions(-) delete mode 100644 gql/async_client.py diff --git a/gql/__init__.py b/gql/__init__.py index 2651c40e..571c7371 100644 --- a/gql/__init__.py +++ b/gql/__init__.py @@ -1,10 +1,4 @@ -import sys - +from .client import Client from .gql import gql -if sys.version_info > (3, 6): - from .async_client import AsyncClient as Client -else: - from .client import Client - __all__ = ["gql", "Client"] diff --git a/gql/async_client.py b/gql/async_client.py deleted file mode 100644 index 143c555c..00000000 --- a/gql/async_client.py +++ /dev/null @@ -1,247 +0,0 @@ -import asyncio -from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, cast - -from graphql import build_ast_schema, build_client_schema, introspection_query, parse -from graphql.execution import ExecutionResult -from graphql.language.ast import Document -from graphql.type import GraphQLSchema - -from .client import Client -from .transport.async_transport import AsyncTransport -from .transport.exceptions import TransportQueryError -from .transport.local_schema import LocalSchemaTransport -from .transport.transport import Transport - - -class AsyncClient(Client): - def __init__( - self, - schema: Optional[GraphQLSchema] = None, - introspection=None, - type_def: Optional[str] = None, - transport: Optional[Union[Transport, AsyncTransport]] = None, - fetch_schema_from_transport: bool = False, - execute_timeout: Optional[int] = 10, - ): - assert not ( - type_def and introspection - ), "Cant provide introspection type definition at the same time" - if transport and fetch_schema_from_transport: - assert ( - not schema - ), "Cant fetch the schema from transport if is already provided" - if isinstance(transport, Transport): - # For sync transports, we fetch the schema directly - execution_result = transport.execute(parse(introspection_query)) - execution_result = cast(ExecutionResult, execution_result) - introspection = execution_result.data - if introspection: - assert not schema, "Cant provide introspection and schema at the same time" - schema = build_client_schema(introspection) - elif type_def: - assert ( - not schema - ), "Cant provide Type definition and schema at the same time" - type_def_ast = parse(type_def) - schema = build_ast_schema(type_def_ast) - elif schema and not transport: - transport = LocalSchemaTransport(schema) - - # GraphQL schema - self.schema: Optional[GraphQLSchema] = schema - - # Answer of the introspection query - self.introspection = introspection - - # GraphQL transport chosen - self.transport: Optional[Union[Transport, AsyncTransport]] = transport - - # Flag to indicate that we need to fetch the schema from the transport - # On async transports, we fetch the schema before executing the first query - self.fetch_schema_from_transport: bool = fetch_schema_from_transport - - # Enforced timeout of the execute function - self.execute_timeout = execute_timeout - - async def execute_async(self, document: Document, *args, **kwargs) -> Dict: - async with self as session: - return await session.execute(document, *args, **kwargs) - - def execute(self, document: Document, *args, **kwargs) -> Dict: - """Execute the provided document AST against the configured remote server. - - This function WILL BLOCK until the result is received from the server. - - Either the transport is sync and we execute the query synchronously directly - OR the transport is async and we execute the query in the asyncio loop (blocking here until answer) - """ - - if isinstance(self.transport, AsyncTransport): - - loop = asyncio.get_event_loop() - - assert ( - not loop.is_running() - ), "Cannot run client.execute if an asyncio loop is running. Use execute_async instead" - - data: Dict[Any, Any] = loop.run_until_complete( - self.execute_async(document, *args, **kwargs) - ) - - return data - - else: # Sync transports - - if self.schema: - self.validate(document) - - result: ExecutionResult - - if isinstance(self.transport, LocalSchemaTransport): - result = cast( - ExecutionResult, self.transport.execute(document, *args, **kwargs) - ) - elif isinstance(self.transport, Transport): - result = cast(ExecutionResult, self.transport.execute(document)) - - if result.errors: - raise TransportQueryError(str(result.errors[0])) - - # Running cast to make mypy happy. result.data should never be None here - return cast(Dict[Any, Any], result.data) - - async def subscribe_async( - self, document: Document, *args, **kwargs - ) -> AsyncGenerator[Dict, None]: - async with self as session: - - self._generator: AsyncGenerator[Dict, None] = session.subscribe( - document, *args, **kwargs - ) - - async for result in self._generator: - yield result - - def subscribe( - self, document: Document, *args, **kwargs - ) -> Generator[Dict, None, None]: - """Execute a GraphQL subscription with a python generator. - - We need an async transport for this functionality. - """ - - async_generator = self.subscribe_async(document, *args, **kwargs) - - loop = asyncio.get_event_loop() - - assert ( - not loop.is_running() - ), "Cannot run client.subscribe if an asyncio loop is running. Use subscribe_async instead" - - try: - while True: - result = loop.run_until_complete(async_generator.__anext__()) - yield result - - except StopAsyncIteration: - pass - - async def __aenter__(self): - - assert isinstance( - self.transport, AsyncTransport - ), "Only a transport of type AsyncTransport can be used asynchronously" - - await self.transport.connect() - - if not hasattr(self, "session"): - self.session = AsyncClientSession(client=self) - - return self.session - - async def __aexit__(self, exc_type, exc, tb): - - await self.transport.close() - - def close(self): - """Close the client and it's underlying transport (only for Sync transports)""" - if not isinstance(self.transport, AsyncTransport): - self.transport.close() - - def __enter__(self): - assert not isinstance( - self.transport, AsyncTransport - ), "Only a sync transport can be use. Use 'async with Client(...)' instead" - return self - - def __exit__(self, *args): - self.close() - - -class AsyncClientSession: - """ An instance of this class is created when using 'async with' on the client. - - It contains the async methods (execute, subscribe) to send queries with the async transports""" - - def __init__(self, client: AsyncClient): - self.client = client - - async def validate(self, document: Document): - """ Fetch schema from transport if needed and validate document if schema is present """ - - # Get schema from transport if needed - if self.client.fetch_schema_from_transport and not self.client.schema: - await self.fetch_schema() - - # Validate document - if self.client.schema: - self.client.validate(document) - - async def subscribe( - self, document: Document, *args, **kwargs - ) -> AsyncGenerator[Dict, None]: - - # Fetch schema from transport if needed and validate document if schema is present - await self.validate(document) - - # Subscribe to the transport and yield data or raise error - self._generator: AsyncGenerator[ - ExecutionResult, None - ] = self.transport.subscribe(document, *args, **kwargs) - - async for result in self._generator: - if result.errors: - # Note: we need to run generator.aclose() here or the finally block in - # the transport.subscribe will not be reached in pypy3 (python version 3.6.1) - await self._generator.aclose() - - raise TransportQueryError(str(result.errors[0])) - - elif result.data is not None: - yield result.data - - async def execute(self, document: Document, *args, **kwargs) -> Dict: - - # Fetch schema from transport if needed and validate document if schema is present - await self.validate(document) - - # Execute the query with the transport with a timeout - result = await asyncio.wait_for( - self.transport.execute(document, *args, **kwargs), - self.client.execute_timeout, - ) - - # Raise an error if an error is returned in the ExecutionResult object - if result.errors: - raise TransportQueryError(str(result.errors[0])) - - return result.data - - async def fetch_schema(self) -> None: - execution_result = await self.transport.execute(parse(introspection_query)) - self.client.introspection = execution_result.data - self.client.schema = build_client_schema(self.client.introspection) - - @property - def transport(self): - return self.client.transport diff --git a/gql/client.py b/gql/client.py index b5bad8c7..c739199a 100644 --- a/gql/client.py +++ b/gql/client.py @@ -1,32 +1,27 @@ -import logging +import asyncio +from typing import Any, AsyncGenerator, Dict, Generator, Optional, Union, cast from graphql import build_ast_schema, build_client_schema, introspection_query, parse +from graphql.execution import ExecutionResult +from graphql.language.ast import Document +from graphql.type import GraphQLSchema from graphql.validation import validate -from .transport import Transport +from .transport.async_transport import AsyncTransport +from .transport.exceptions import TransportQueryError from .transport.local_schema import LocalSchemaTransport +from .transport.transport import Transport -log = logging.getLogger(__name__) - -class RetryError(Exception): - """Custom exception thrown when retry logic fails""" - - def __init__(self, retries_count, last_exception): - message = "Failed %s retries: %s" % (retries_count, last_exception) - super(RetryError, self).__init__(message) - self.last_exception = last_exception - - -class Client(object): +class Client: def __init__( self, - schema=None, + schema: Optional[GraphQLSchema] = None, introspection=None, - type_def=None, - transport=None, - fetch_schema_from_transport=False, - retries=0, # We should remove this parameter and let the transport level handle it + type_def: Optional[str] = None, + transport: Optional[Union[Transport, AsyncTransport]] = None, + fetch_schema_from_transport: bool = False, + execute_timeout: Optional[int] = 10, ): assert not ( type_def and introspection @@ -35,10 +30,11 @@ def __init__( assert ( not schema ), "Cant fetch the schema from transport if is already provided" - assert isinstance( - transport, Transport - ), "With an asyncio transport, please use the AsyncClient class" - introspection = transport.execute(parse(introspection_query)).data + if isinstance(transport, Transport): + # For sync transports, we fetch the schema directly + execution_result = transport.execute(parse(introspection_query)) + execution_result = cast(ExecutionResult, execution_result) + introspection = execution_result.data if introspection: assert not schema, "Cant provide introspection and schema at the same time" schema = build_client_schema(introspection) @@ -51,16 +47,21 @@ def __init__( elif schema and not transport: transport = LocalSchemaTransport(schema) - self.schema = schema + # GraphQL schema + self.schema: Optional[GraphQLSchema] = schema + + # Answer of the introspection query self.introspection = introspection - self.transport = transport - self.retries = retries - if self.retries: - log.warning( - "The retries parameter on the Client class is deprecated." - "You can pass it to the RequestsHTTPTransport." - ) + # GraphQL transport chosen + self.transport: Optional[Union[Transport, AsyncTransport]] = transport + + # Flag to indicate that we need to fetch the schema from the transport + # On async transports, we fetch the schema before executing the first query + self.fetch_schema_from_transport: bool = fetch_schema_from_transport + + # Enforced timeout of the execute function + self.execute_timeout = execute_timeout def validate(self, document): if not self.schema: @@ -71,45 +72,183 @@ def validate(self, document): if validation_errors: raise validation_errors[0] - def execute(self, document, *args, **kwargs): - if self.schema: - self.validate(document) + async def execute_async(self, document: Document, *args, **kwargs) -> Dict: + async with self as session: + return await session.execute(document, *args, **kwargs) - result = self._get_result(document, *args, **kwargs) - if result.errors: - raise Exception(str(result.errors[0])) + def execute(self, document: Document, *args, **kwargs) -> Dict: + """Execute the provided document AST against the configured remote server. - return result.data + This function WILL BLOCK until the result is received from the server. + + Either the transport is sync and we execute the query synchronously directly + OR the transport is async and we execute the query in the asyncio loop (blocking here until answer) + """ + + if isinstance(self.transport, AsyncTransport): + + loop = asyncio.get_event_loop() + + assert ( + not loop.is_running() + ), "Cannot run client.execute if an asyncio loop is running. Use execute_async instead" + + data: Dict[Any, Any] = loop.run_until_complete( + self.execute_async(document, *args, **kwargs) + ) + + return data + + else: # Sync transports + + if self.schema: + self.validate(document) + + assert self.transport is not None, "Cant execute without a tranport" + + result: ExecutionResult = self.transport.execute(document, *args, **kwargs) + + if result.errors: + raise TransportQueryError(str(result.errors[0])) + + assert ( + result.data is not None + ), "Transport returned an ExecutionResult without data or errors" + + return result.data + + async def subscribe_async( + self, document: Document, *args, **kwargs + ) -> AsyncGenerator[Dict, None]: + async with self as session: + + self._generator: AsyncGenerator[Dict, None] = session.subscribe( + document, *args, **kwargs + ) + + async for result in self._generator: + yield result + + def subscribe( + self, document: Document, *args, **kwargs + ) -> Generator[Dict, None, None]: + """Execute a GraphQL subscription with a python generator. + + We need an async transport for this functionality. + """ + + async_generator = self.subscribe_async(document, *args, **kwargs) - def _get_result(self, document, *args, **kwargs): - if not self.retries: - return self.transport.execute(document, *args, **kwargs) - - last_exception = None - retries_count = 0 - while retries_count < self.retries: - try: - result = self.transport.execute(document, *args, **kwargs) - return result - except Exception as e: - last_exception = e - log.warning( - "Request failed with exception %s. Retrying for the %s time...", - e, - retries_count + 1, - exc_info=True, - ) - finally: - retries_count += 1 - - raise RetryError(retries_count, last_exception) + loop = asyncio.get_event_loop() + + assert ( + not loop.is_running() + ), "Cannot run client.subscribe if an asyncio loop is running. Use subscribe_async instead" + + try: + while True: + result = loop.run_until_complete(async_generator.__anext__()) + yield result + + except StopAsyncIteration: + pass + + async def __aenter__(self): + + assert isinstance( + self.transport, AsyncTransport + ), "Only a transport of type AsyncTransport can be used asynchronously" + + await self.transport.connect() + + if not hasattr(self, "session"): + self.session = ClientSession(client=self) + + return self.session + + async def __aexit__(self, exc_type, exc, tb): + + await self.transport.close() def close(self): - """Close the client and it's underlying transport""" - self.transport.close() + """Close the client and it's underlying transport (only for Sync transports)""" + if not isinstance(self.transport, AsyncTransport): + self.transport.close() def __enter__(self): + assert not isinstance( + self.transport, AsyncTransport + ), "Only a sync transport can be use. Use 'async with Client(...)' instead" return self def __exit__(self, *args): self.close() + + +class ClientSession: + """ An instance of this class is created when using 'async with' on the client. + + It contains the async methods (execute, subscribe) to send queries with the async transports""" + + def __init__(self, client: Client): + self.client = client + + async def validate(self, document: Document): + """ Fetch schema from transport if needed and validate document if schema is present """ + + # Get schema from transport if needed + if self.client.fetch_schema_from_transport and not self.client.schema: + await self.fetch_schema() + + # Validate document + if self.client.schema: + self.client.validate(document) + + async def subscribe( + self, document: Document, *args, **kwargs + ) -> AsyncGenerator[Dict, None]: + + # Fetch schema from transport if needed and validate document if schema is present + await self.validate(document) + + # Subscribe to the transport and yield data or raise error + self._generator: AsyncGenerator[ + ExecutionResult, None + ] = self.transport.subscribe(document, *args, **kwargs) + + async for result in self._generator: + if result.errors: + # Note: we need to run generator.aclose() here or the finally block in + # the transport.subscribe will not be reached in pypy3 (python version 3.6.1) + await self._generator.aclose() + + raise TransportQueryError(str(result.errors[0])) + + elif result.data is not None: + yield result.data + + async def execute(self, document: Document, *args, **kwargs) -> Dict: + + # Fetch schema from transport if needed and validate document if schema is present + await self.validate(document) + + # Execute the query with the transport with a timeout + result = await asyncio.wait_for( + self.transport.execute(document, *args, **kwargs), + self.client.execute_timeout, + ) + + # Raise an error if an error is returned in the ExecutionResult object + if result.errors: + raise TransportQueryError(str(result.errors[0])) + + return result.data + + async def fetch_schema(self) -> None: + execution_result = await self.transport.execute(parse(introspection_query)) + self.client.introspection = execution_result.data + self.client.schema = build_client_schema(self.client.introspection) + + @property + def transport(self): + return self.client.transport diff --git a/gql/transport/__init__.py b/gql/transport/__init__.py index 0fdf95d6..ca8b6252 100644 --- a/gql/transport/__init__.py +++ b/gql/transport/__init__.py @@ -1,12 +1,4 @@ -import sys - +from .async_transport import AsyncTransport from .transport import Transport -__all__ = ["Transport"] - - -if sys.version_info > (3, 6): - from .async_transport import AsyncTransport - - # Cannot use __all__.append here because of flake8 warning - __all__ = ["Transport", "AsyncTransport"] +__all__ = ["AsyncTransport", "Transport"] diff --git a/gql/transport/local_schema.py b/gql/transport/local_schema.py index d2b82c55..2af1d1be 100644 --- a/gql/transport/local_schema.py +++ b/gql/transport/local_schema.py @@ -1,9 +1,6 @@ -from typing import Any, Union - from graphql import GraphQLSchema from graphql.execution import ExecutionResult, execute from graphql.language.ast import Document -from promise import Promise from gql.transport import Transport @@ -12,8 +9,7 @@ class LocalSchemaTransport(Transport): """A transport for executing GraphQL queries against a local schema.""" def __init__( - self, # type: LocalSchemaTransport - schema, # type: GraphQLSchema + self, schema: GraphQLSchema, ): """Initialize the transport with the given local schema. @@ -21,8 +17,7 @@ def __init__( """ self.schema = schema - def execute(self, document, *args, **kwargs): - # type: (Document, *Any, **Any) -> Union[ExecutionResult, Promise[ExecutionResult]] + def execute(self, document: Document, *args, **kwargs) -> ExecutionResult: """Execute the given document against the configured local schema. :param document: GraphQL query as AST Node object. @@ -30,4 +25,4 @@ def execute(self, document, *args, **kwargs): :param kwargs: Keyword options passed to execute method from graphql-core library. :return: Either ExecutionResult or a Promise that resolves to ExecutionResult object. """ - return execute(self.schema, document, *args, **kwargs) + return execute(self.schema, document, *args, **kwargs) # type: ignore diff --git a/gql/transport/requests.py b/gql/transport/requests.py index 238dcc26..3f48e945 100644 --- a/gql/transport/requests.py +++ b/gql/transport/requests.py @@ -1,6 +1,4 @@ -from __future__ import absolute_import - -from typing import Any, Dict, Union +from typing import Any, Dict, Optional, Union import requests from graphql.execution import ExecutionResult @@ -20,17 +18,17 @@ class RequestsHTTPTransport(Transport): """ def __init__( - self, # type: RequestsHTTPTransport - url, # type: str - headers=None, # type: Dict[str, Any] - cookies=None, # type: Union[Dict[str, Any], RequestsCookieJar] - auth=None, # type: AuthBase - use_json=False, # type: bool - timeout=None, # type: int - verify=True, # type: bool - retries=0, # type: int - method="POST", # type: str - **kwargs # type: Any + self, + url: str, + headers: Optional[Dict[str, Any]] = None, + cookies: Optional[Union[Dict[str, Any], RequestsCookieJar]] = None, + auth: Optional[AuthBase] = None, + use_json: bool = False, + timeout: Optional[int] = None, + verify: bool = True, + retries: int = 0, + method: str = "POST", + **kwargs, ): """Initialize the transport with the given request parameters. @@ -76,8 +74,12 @@ def __init__( for prefix in "http://", "https://": self.session.mount(prefix, adapter) - def execute(self, document, variable_values=None, timeout=None): - # type: (Document, Dict, int) -> ExecutionResult + def execute( # type: ignore + self, + document: Document, + variable_values: Optional[Dict[str, Any]] = None, + timeout: Optional[int] = None, + ) -> ExecutionResult: """Execute the provided document AST against the configured remote server. This uses the requests library to perform a HTTP POST request to the remote server. diff --git a/gql/transport/transport.py b/gql/transport/transport.py index b65ffcc6..9d4d6711 100644 --- a/gql/transport/transport.py +++ b/gql/transport/transport.py @@ -1,17 +1,14 @@ import abc -from typing import Union import six from graphql.execution import ExecutionResult from graphql.language.ast import Document -from promise import Promise @six.add_metaclass(abc.ABCMeta) class Transport: @abc.abstractmethod - def execute(self, document): - # type: (Document) -> Union[ExecutionResult, Promise[ExecutionResult]] + def execute(self, document: Document, *args, **kwargs) -> ExecutionResult: """Execute the provided document AST for either a remote or local GraphQL Schema. :param document: GraphQL query as AST Node or Document object. diff --git a/gql/transport/websockets.py b/gql/transport/websockets.py index a3cbc572..3dea3bad 100644 --- a/gql/transport/websockets.py +++ b/gql/transport/websockets.py @@ -1,5 +1,3 @@ -from __future__ import absolute_import - import asyncio import json import logging diff --git a/setup.cfg b/setup.cfg index c1853480..50a6335c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,15 +1,16 @@ [wheel] -universal=1 +universal = 1 [flake8] max-line-length = 120 [isort] -known_standard_library=ssl -known_first_party=gql -multi_line_output=3 -include_trailing_comma=True -line_length=88 +known_standard_library = ssl +known_first_party = gql +multi_line_output = 3 +include_trailing_comma = True +line_length = 88 +not_skip = __init__.py [tool:pytest] norecursedirs = venv .venv .tox .git .cache .mypy_cache .pytest_cache diff --git a/setup.py b/setup.py index 919dcae6..c2c87934 100644 --- a/setup.py +++ b/setup.py @@ -1,48 +1,34 @@ -import sys - from setuptools import setup, find_packages install_requires = [ - "six>=1.10.0", + "aiohttp==3.6.2", "graphql-core>=2.3.2,<3", - "promise>=2.3,<3", "requests>=2.12,<3", + "six>=1.10.0", + "websockets>=8.1,<9", + "yarl>=1.0,<2.0", ] -scripts = [] - -if sys.version_info > (3, 6): - install_requires.append( - ["websockets>=8.1,<9", "aiohttp==3.6.2", "yarl>=1.0,<2.0",] - ) - scripts.append("scripts/gql-cli") +scripts = [ + "scripts/gql-cli", +] -tests_require = ( - [ - "pytest==5.4.2", - "pytest-asyncio==0.11.0", - "pytest-cov==2.8.1", - "mock==4.0.2", - "vcrpy==4.0.2", - "coveralls==2.0.0", - "parse>=1.6.0", - ] - if sys.version_info > (3, 6) - else [ - "pytest==4.6.9", - "pytest-cov==2.8.1", - "vcrpy==3.0.0", - "mock==3.0.0", - "coveralls==1.11.1", - ] -) +tests_require = [ + "coveralls==2.0.0", + "parse>=1.6.0", + "pytest==5.4.2", + "pytest-asyncio==0.11.0", + "pytest-cov==2.8.1", + "mock==4.0.2", + "vcrpy==4.0.2", +] dev_requires = [ + "black==19.10b0", + "check-manifest>=0.40,<1", "flake8==3.7.9", "isort==4.2.8", - "black==19.10b0", "mypy==0.770", - "check-manifest>=0.40,<1", ] + tests_require setup( diff --git a/tests/test_client.py b/tests/test_client.py index d5372740..6532acb7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,5 +1,4 @@ import os -import sys import mock import pytest @@ -34,37 +33,6 @@ def execute(self): assert "Any Transport subclass must implement execute method" == str(exc_info.value) -@pytest.mark.skipif( - sys.version_info > (3, 6), reason="retries on client deprecated in latest versions" -) -@mock.patch("gql.transport.requests.RequestsHTTPTransport.execute") -def test_retries(execute_mock): - expected_retries = 3 - execute_mock.side_effect = Exception("fail") - - client = Client( - retries=expected_retries, - transport=RequestsHTTPTransport(url="http://swapi.graphene-python.org/graphql"), - ) - - query = gql( - """ - { - myFavoriteFilm: film(id:"RmlsbToz") { - id - title - episodeId - } - } - """ - ) - - with pytest.raises(Exception): - client.execute(query) - client.close() - assert execute_mock.call_count == expected_retries - - @mock.patch("urllib3.connection.HTTPConnection._new_conn") def test_retries_on_transport(execute_mock): """Testing retries on the transport level diff --git a/tox.ini b/tox.ini index ce778198..7f78e12a 100644 --- a/tox.ini +++ b/tox.ini @@ -1,7 +1,7 @@ [tox] envlist = black,flake8,import-order,mypy,manifest, - py{27,35,36,37,38,39-dev,py,py3},cov + py{36,37,38,39-dev,py3} ; requires = tox-conda [pytest] @@ -13,7 +13,6 @@ setenv = PYTHONPATH = {toxinidir} MULTIDICT_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/multidict YARL_NO_EXTENSIONS = 1 ; Related to https://github.com/aio-libs/yarl - COVERAGE_FILE=.coverage.{envname} install_command = python -m pip install --ignore-installed {opts} {packages} whitelist_externals = python @@ -21,12 +20,8 @@ deps = -e.[test] ; Prevent installing issues: https://github.com/ContinuumIO/anaconda-issues/issues/542 commands = pip install -U setuptools - py{27}: pytest {posargs:tests --cov-report=term-missing --cov=gql -s} - py{35,py}: pytest {posargs:tests -s} py{36,37,39-dev,py3}: pytest {posargs:tests tests_py36 -s} py{38}: pytest {posargs:tests tests_py36 --cov-report=term-missing --cov=gql -s} - cov: coverage combine {toxinidir}/.coverage.py27 {toxinidir}/.coverage.py38 - cov: coverage report --show-missing [testenv:black] basepython=python3.8 From 564065b6e6c73e094205de568b12a1c7f708c849 Mon Sep 17 00:00:00 2001 From: Leszek Hanusz Date: Thu, 14 May 2020 12:31:12 +0200 Subject: [PATCH 46/46] .travis.yml: stop testing on python 2.7, 3.5 and pypy --- .travis.yml | 3 --- 1 file changed, 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 8d9110fe..520178fe 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,13 +4,10 @@ env: language: python sudo: false python: - - 2.7 - - 3.5 - 3.6 - 3.7 - 3.8 - 3.9-dev - - pypy - pypy3 matrix: include: