diff --git a/Makefile b/Makefile index 27913507..b677a7e8 100644 --- a/Makefile +++ b/Makefile @@ -10,10 +10,10 @@ all_tests: pytest tests --cov=gql --cov-report=term-missing --run-online -vv check: - isort --recursive gql tests - black gql tests - flake8 gql tests - mypy gql tests + isort --recursive gql tests scripts/gql-cli + black gql tests scripts/gql-cli + flake8 gql tests scripts/gql-cli + mypy gql tests scripts/gql-cli check-manifest docs: diff --git a/docs/conf.py b/docs/conf.py index 987bc3cd..db6e7c5f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,6 +32,7 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ + 'sphinxarg.ext', 'sphinx.ext.autodoc', 'sphinx_rtd_theme' ] diff --git a/docs/gql-cli/intro.rst b/docs/gql-cli/intro.rst new file mode 100644 index 00000000..3a25c6df --- /dev/null +++ b/docs/gql-cli/intro.rst @@ -0,0 +1,71 @@ +gql-cli +======= + +GQL provides a python 3.6+ script, called `gql-cli` which allows you to execute +GraphQL queries directly from the terminal. + +This script supports http(s) or websockets protocols. + +Usage +----- + +.. argparse:: + :module: gql.cli + :func: get_parser + :prog: gql-cli + +Examples +-------- + +Simple query using https +^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: shell + + $ echo 'query { continent(code:"AF") { name } }' | gql-cli https://countries.trevorblades.com + {"continent": {"name": "Africa"}} + +Simple query using websockets +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: shell + + $ echo 'query { continent(code:"AF") { name } }' | gql-cli wss://countries.trevorblades.com/graphql + {"continent": {"name": "Africa"}} + +Query with variable +^^^^^^^^^^^^^^^^^^^ + +.. code-block:: shell + + $ echo 'query getContinent($code:ID!) { continent(code:$code) { name } }' | gql-cli https://countries.trevorblades.com --variables code:AF + {"continent": {"name": "Africa"}} + +Interactive usage +^^^^^^^^^^^^^^^^^ + +Insert your query in the terminal, then press Ctrl-D to execute it. + +.. code-block:: shell + + $ gql-cli wss://countries.trevorblades.com/graphql --variables code:AF + +Execute query saved in a file +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Put the query in a file: + +.. code-block:: shell + + $ echo 'query { + continent(code:"AF") { + name + } + }' > query.gql + +Then execute query from the file: + +.. code-block:: shell + + $ cat query.gql | gql-cli wss://countries.trevorblades.com/graphql + {"continent": {"name": "Africa"}} diff --git a/docs/index.rst b/docs/index.rst index ead330e8..ff63ed3a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,6 +17,7 @@ Contents async/index transports/index advanced/index + gql-cli/intro modules/gql diff --git a/gql/cli.py b/gql/cli.py new file mode 100644 index 00000000..e4d10a6c --- /dev/null +++ b/gql/cli.py @@ -0,0 +1,277 @@ +import json +import logging +import sys +from argparse import ArgumentParser, Namespace, RawDescriptionHelpFormatter +from typing import Any, Dict + +from graphql import GraphQLError +from yarl import URL + +from gql import Client, __version__, gql +from gql.transport import AsyncTransport +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.exceptions import TransportQueryError +from gql.transport.websockets import WebsocketsTransport + +description = """ +Send GraphQL queries from the command line using http(s) or websockets. +If used interactively, write your query, then use Ctrl-D (EOF) to execute it. +""" + +examples = """ +EXAMPLES +======== + +# Simple query using https +echo 'query { continent(code:"AF") { name } }' | \ +gql-cli https://countries.trevorblades.com + +# Simple query using websockets +echo 'query { continent(code:"AF") { name } }' | \ +gql-cli wss://countries.trevorblades.com/graphql + +# Query with variable +echo 'query getContinent($code:ID!) { continent(code:$code) { name } }' | \ +gql-cli https://countries.trevorblades.com --variables code:AF + +# Interactive usage (insert your query in the terminal, then press Ctrl-D to execute it) +gql-cli wss://countries.trevorblades.com/graphql --variables code:AF + +# Execute query saved in a file +cat query.gql | gql-cli wss://countries.trevorblades.com/graphql + +""" + + +def get_parser(with_examples: bool = False) -> ArgumentParser: + """Provides an ArgumentParser for the gql-cli script. + + This function is also used by sphinx to generate the script documentation. + + :param with_examples: set to False by default so that the examples are not + present in the sphinx docs (they are put there with + a different layout) + """ + + parser = ArgumentParser( + description=description, + epilog=examples if with_examples else None, + formatter_class=RawDescriptionHelpFormatter, + ) + parser.add_argument( + "server", help="the server url starting with http://, https://, ws:// or wss://" + ) + parser.add_argument( + "-V", + "--variables", + nargs="*", + help="query variables in the form key:json_value", + ) + parser.add_argument( + "-H", "--headers", nargs="*", help="http headers in the form key:value" + ) + parser.add_argument("--version", action="version", version=f"v{__version__}") + group = parser.add_mutually_exclusive_group() + group.add_argument( + "-d", + "--debug", + help="print lots of debugging statements (loglevel==DEBUG)", + action="store_const", + dest="loglevel", + const=logging.DEBUG, + ) + group.add_argument( + "-v", + "--verbose", + help="show low level messages (loglevel==INFO)", + action="store_const", + dest="loglevel", + const=logging.INFO, + ) + parser.add_argument( + "-o", + "--operation-name", + help="set the operation_name value", + dest="operation_name", + ) + + return parser + + +def get_transport_args(args: Namespace) -> Dict[str, Any]: + """Extract extra arguments necessary for the transport + from the parsed command line args + + Will create a headers dict by splitting the colon + in the --headers arguments + + :param args: parsed command line arguments + """ + + transport_args: Dict[str, Any] = {} + + # Parse the headers argument + headers = {} + if args.headers is not None: + for header in args.headers: + + try: + # Split only the first colon (throw a ValueError if no colon is present) + header_key, header_value = header.split(":", 1) + + headers[header_key] = header_value + + except ValueError: + raise ValueError(f"Invalid header: {header}") + + if args.headers is not None: + transport_args["headers"] = headers + + return transport_args + + +def get_execute_args(args: Namespace) -> Dict[str, Any]: + """Extract extra arguments necessary for the execute or subscribe + methods from the parsed command line args + + Extract the operation_name + + Extract the variable_values from the --variables argument + by splitting the first colon, then loads the json value, + We try to add double quotes around the value if it does not work first + in order to simplify the passing of simple string values + (we allow --variables KEY:VALUE instead of KEY:\"VALUE\") + + :param args: parsed command line arguments + """ + + execute_args: Dict[str, Any] = {} + + # Parse the operation_name argument + if args.operation_name is not None: + execute_args["operation_name"] = args.operation_name + + # Parse the variables argument + if args.variables is not None: + + variables = {} + + for var in args.variables: + + try: + # Split only the first colon + # (throw a ValueError if no colon is present) + variable_key, variable_json_value = var.split(":", 1) + + # Extract the json value, + # trying with double quotes if it does not work + try: + variable_value = json.loads(variable_json_value) + except json.JSONDecodeError: + try: + variable_value = json.loads(f'"{variable_json_value}"') + except json.JSONDecodeError: + raise ValueError + + # Save the value in the variables dict + variables[variable_key] = variable_value + + except ValueError: + raise ValueError(f"Invalid variable: {var}") + + execute_args["variable_values"] = variables + + return execute_args + + +def get_transport(args: Namespace) -> AsyncTransport: + """Instanciate a transport from the parsed command line arguments + + :param args: parsed command line arguments + """ + + # Get the url scheme from server parameter + url = URL(args.server) + scheme = url.scheme + + # Get extra transport parameters from command line arguments + # (headers) + transport_args = get_transport_args(args) + + # Instanciate transport depending on url scheme + transport: AsyncTransport + if scheme in ["ws", "wss"]: + transport = WebsocketsTransport( + url=args.server, ssl=(scheme == "wss"), **transport_args + ) + elif scheme in ["http", "https"]: + transport = AIOHTTPTransport(url=args.server, **transport_args) + else: + raise ValueError("URL protocol should be one of: http, https, ws, wss") + + return transport + + +async def main(args: Namespace) -> int: + """Main entrypoint of the gql-cli script + + :param args: The parsed command line arguments + :return: The script exit code (0 = ok, 1 = error) + """ + + # Set requested log level + if args.loglevel is not None: + logging.basicConfig(level=args.loglevel) + + try: + # Instanciate transport from command line arguments + transport = get_transport(args) + + # Get extra execute parameters from command line arguments + # (variables, operation_name) + execute_args = get_execute_args(args) + + except ValueError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + # By default, the exit_code is 0 (everything is ok) + exit_code = 0 + + # Connect to the backend and provide a session + async with Client(transport=transport) as session: + + while True: + + # Read multiple lines from input and trim whitespaces + # Will read until EOF character is received (Ctrl-D) + query_str = sys.stdin.read().strip() + + # Exit if query is empty + if len(query_str) == 0: + break + + # Parse query, continue on error + try: + query = gql(query_str) + except GraphQLError as e: + print(e, file=sys.stderr) + exit_code = 1 + continue + + # Execute or Subscribe the query depending on transport + try: + if isinstance(transport, WebsocketsTransport): + try: + async for result in session.subscribe(query, **execute_args): + print(json.dumps(result)) + except KeyboardInterrupt: # pragma: no cover + pass + else: + result = await session.execute(query, **execute_args) + print(json.dumps(result)) + except (GraphQLError, TransportQueryError) as e: + print(e, file=sys.stderr) + exit_code = 1 + + return exit_code diff --git a/scripts/gql-cli b/scripts/gql-cli index bbbb0e3b..055919ff 100755 --- a/scripts/gql-cli +++ b/scripts/gql-cli @@ -1,79 +1,18 @@ #!/usr/bin/env python3 - -from gql import gql, Client -from gql.transport.websockets import WebsocketsTransport -from gql.transport.aiohttp import AIOHTTPTransport -from yarl import URL import asyncio -import argparse import sys -import json - -parser = argparse.ArgumentParser( - description="Send GraphQL queries from command line using http(s) or websockets" -) -parser.add_argument( - "server", help="the server url starting with http://, https://, ws:// or wss://" -) -parser.add_argument( - "-p", "--params", nargs="*", help="query parameters in the form param:json_value" -) -args = parser.parse_args() - - -async def main(): - - # Parse the params argument - params = {} - if args.params is not None: - for p in args.params: - - try: - # Split only the first colon (throw a ValueError if no colon is present) - param_key, param_json_value = p.split(':', 1) - - # Extract the json value, trying with double quotes if it does not work - try: - param_value = json.loads(param_json_value) - except json.JSONDecodeError: - try: - param_value = json.loads(f'"{param_json_value}"') - except json.JSONDecodeError: - raise ValueError - # Save the value in the params dict - params[param_key] = param_value +from gql.cli import get_parser, main - except ValueError: - print (f"Invalid parameter: {p}", file=sys.stderr) - return 1 - - url = URL(args.server) - - scheme = url.scheme - - if scheme in ["ws", "wss"]: - transport = WebsocketsTransport(url=args.server, ssl=(scheme == "wss")) - elif scheme in ["http", "https"]: - transport = AIOHTTPTransport(url=args.server) - else: - raise ValueError("URL protocol should be one of: http, https, ws, wss") - - async with Client(transport=transport) as session: - - while True: - try: - query_str = input() - except EOFError: - break - - query = gql(query_str) +# Get arguments from command line +parser = get_parser(with_examples=True) +args = parser.parse_args() - if scheme in ["ws", "wss"]: - async for result in session.subscribe(query, variable_values=params): - print(result) - else: - result = await session.execute(query, variable_values=params) - print(result) +try: + # Execute the script + exit_code = asyncio.run(main(args)) -asyncio.run(main()) + # Return with the correct exit code + sys.exit(exit_code) +except KeyboardInterrupt: + pass diff --git a/setup.py b/setup.py index 9233eab3..a4389fdf 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ "mypy==0.770", "sphinx>=3.0.0,<4", "sphinx_rtd_theme>=0.4,<1", + "sphinx-argparse==0.2.5", ] + tests_require # Get version from __version__.py file diff --git a/tests/conftest.py b/tests/conftest.py index c2a15605..709adeb0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ import os import pathlib import ssl +import sys import tempfile import types from concurrent.futures import ThreadPoolExecutor @@ -239,7 +240,7 @@ async def default_server_handler(ws, path): for answer in answers: result = await ws.recv() - print(f"Server received: {result}") + print(f"Server received: {result}", file=sys.stderr) if isinstance(answer, str) and "{query_id}" in answer: answer_format_params = {"query_id": query_id} diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 8f39319f..1a0d3af0 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -1,7 +1,11 @@ +import io +import json + import pytest from aiohttp import DummyCookieJar, web from gql import Client, gql +from gql.cli import get_parser, main from gql.transport.aiohttp import AIOHTTPTransport from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -22,15 +26,18 @@ } """ -query1_server_answer = ( - '{"data":{"continents":[' +query1_server_answer_data = ( + '{"continents":[' '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' - '{"code":"SA","name":"South America"}]}}' + '{"code":"SA","name":"South America"}]}' ) +query1_server_answer = f'{{"data":{query1_server_answer_data}}}' + + @pytest.mark.asyncio async def test_aiohttp_query(event_loop, aiohttp_server): async def handler(request): @@ -680,3 +687,103 @@ async def handler(request): success = result["success"] assert success + + +@pytest.mark.asyncio +async def test_aiohttp_using_cli(event_loop, aiohttp_server, monkeypatch, capsys): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + parser = get_parser(with_examples=True) + args = parser.parse_args([url, "--verbose"]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) + + exit_code = await main(args) + + assert exit_code == 0 + + # Check that the result has been printed on stdout + captured = capsys.readouterr() + captured_out = str(captured.out).strip() + + expected_answer = json.loads(query1_server_answer_data) + print(f"Captured: {captured_out}") + received_answer = json.loads(captured_out) + + assert received_answer == expected_answer + + +@pytest.mark.asyncio +async def test_aiohttp_using_cli_invalid_param( + event_loop, aiohttp_server, monkeypatch, capsys +): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + parser = get_parser(with_examples=True) + args = parser.parse_args([url, "--variables", "invalid_param"]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) + + # Checking that sys.exit() is called + with pytest.raises(SystemExit): + await main(args) + + # Check that the error has been printed on stdout + captured = capsys.readouterr() + captured_err = str(captured.err).strip() + print(f"Captured: {captured_err}") + + expected_error = "Error: Invalid variable: invalid_param" + + assert expected_error in captured_err + + +@pytest.mark.asyncio +async def test_aiohttp_using_cli_invalid_query( + event_loop, aiohttp_server, monkeypatch, capsys +): + async def handler(request): + return web.Response(text=query1_server_answer, content_type="application/json") + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + parser = get_parser(with_examples=True) + args = parser.parse_args([url]) + + # Send invalid query on standard input + monkeypatch.setattr("sys.stdin", io.StringIO("BLAHBLAH")) + + exit_code = await main(args) + + assert exit_code == 1 + + # Check that the error has been printed on stdout + captured = capsys.readouterr() + captured_err = str(captured.err).strip() + print(f"Captured: {captured_err}") + + expected_error = "Syntax Error: Unexpected Name 'BLAHBLAH'" + + assert expected_error in captured_err diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 00000000..44f61a15 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,175 @@ +import logging + +import pytest + +from gql.cli import get_execute_args, get_parser, get_transport, get_transport_args +from gql.transport.aiohttp import AIOHTTPTransport +from gql.transport.websockets import WebsocketsTransport + + +@pytest.fixture +def parser(): + return get_parser() + + +def test_cli_parser(parser): + + # Simple call with https server + # gql-cli https://your_server.com + args = parser.parse_args(["https://your_server.com"]) + + assert args.server == "https://your_server.com" + assert args.headers is None + assert args.loglevel is None + assert args.operation_name is None + assert args.variables is None + + # Call with variable values parameters + # gql-cli https://your_server.com --variables KEY1:value1 KEY2:value2 + args = parser.parse_args( + ["https://your_server.com", "--variables", "KEY1:value1", "KEY2:value2"] + ) + + assert args.server == "https://your_server.com" + assert args.variables == ["KEY1:value1", "KEY2:value2"] + + # Call with headers values parameters + # gql-cli https://your_server.com --headers HEADER1:value1 HEADER2:value2 + args = parser.parse_args( + ["https://your_server.com", "--headers", "HEADER1:value1", "HEADER2:value2"] + ) + + assert args.server == "https://your_server.com" + assert args.headers == ["HEADER1:value1", "HEADER2:value2"] + + # Call with header value with a space in value + # gql-cli https://your_server.com --headers Authorization:"Bearer blahblah" + args = parser.parse_args( + ["https://your_server.com", "--headers", "Authorization:Bearer blahblah"] + ) + + assert args.server == "https://your_server.com" + assert args.headers == ["Authorization:Bearer blahblah"] + + # Check loglevel flags + # gql-cli https://your_server.com --debug + args = parser.parse_args(["https://your_server.com", "--debug"]) + assert args.loglevel == logging.DEBUG + + # gql-cli https://your_server.com --verbose + args = parser.parse_args(["https://your_server.com", "--verbose"]) + + assert args.loglevel == logging.INFO + + # Check operation_name + # gql-cli https://your_server.com --operation-name my_operation + args = parser.parse_args( + ["https://your_server.com", "--operation-name", "my_operation"] + ) + assert args.operation_name == "my_operation" + + +def test_cli_parse_headers(parser): + + args = parser.parse_args( + [ + "https://your_server.com", + "--headers", + "Token1:1234", + "Token2:5678", + "TokenWithSpace:abc def", + "TokenWithColon:abc:def", + ] + ) + + transport_args = get_transport_args(args) + + expected_headers = { + "Token1": "1234", + "Token2": "5678", + "TokenWithSpace": "abc def", + "TokenWithColon": "abc:def", + } + + assert transport_args == {"headers": expected_headers} + + +def test_cli_parse_headers_invalid_header(parser): + + args = parser.parse_args( + ["https://your_server.com", "--headers", "TokenWithoutColon"] + ) + + with pytest.raises(ValueError): + get_transport_args(args) + + +def test_cli_parse_operation_name(parser): + + args = parser.parse_args(["https://your_server.com", "--operation-name", "myop"]) + + execute_args = get_execute_args(args) + + assert execute_args == {"operation_name": "myop"} + + +@pytest.mark.parametrize( + "param", + [ + {"args": ["key:abcdef"], "d": {"key": "abcdef"}}, + {"args": ['key:"abcdef"'], "d": {"key": "abcdef"}}, + {"args": ["key:1234"], "d": {"key": 1234}}, + {"args": ["key1:1234", "key2:5678"], "d": {"key1": 1234, "key2": 5678}}, + {"args": ["key1:null"], "d": {"key1": None}}, + {"args": ["key1:true"], "d": {"key1": True}}, + {"args": ["key1:false"], "d": {"key1": False}}, + { + "args": ["key1:null", "key2:abcd", "key3:5"], + "d": {"key1": None, "key2": "abcd", "key3": 5}, + }, + ], +) +def test_cli_parse_variable_value(parser, param): + + args = parser.parse_args(["https://your_server.com", "--variables", *param["args"]]) + + execute_args = get_execute_args(args) + + expected_variable_values = param["d"] + + assert execute_args == {"variable_values": expected_variable_values} + + +@pytest.mark.parametrize("param", ["nocolon", 'key:"']) +def test_cli_parse_variable_value_invalid_param(parser, param): + + args = parser.parse_args(["https://your_server.com", "--variables", param]) + + with pytest.raises(ValueError): + get_execute_args(args) + + +@pytest.mark.parametrize( + "param", + [ + {"args": ["http://your_server.com"], "class": AIOHTTPTransport}, + {"args": ["https://your_server.com"], "class": AIOHTTPTransport}, + {"args": ["ws://your_server.com/graphql"], "class": WebsocketsTransport}, + {"args": ["wss://your_server.com/graphql"], "class": WebsocketsTransport}, + ], +) +def test_cli_get_transport(parser, param): + + args = parser.parse_args([*param["args"]]) + + transport = get_transport(args) + + assert isinstance(transport, param["class"]) + + +def test_cli_get_transport_no_protocol(parser): + + args = parser.parse_args(["your_server.com"]) + + with pytest.raises(ValueError): + get_transport(args) diff --git a/tests/test_websocket_exceptions.py b/tests/test_websocket_exceptions.py index a2678a4a..0d4b3e05 100644 --- a/tests/test_websocket_exceptions.py +++ b/tests/test_websocket_exceptions.py @@ -348,3 +348,37 @@ async def client_connect(client): with pytest.raises(TransportAlreadyConnected): await asyncio.gather(connect_task1, connect_task2) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [invalid_query1_server], indirect=True) +async def test_websocket_using_cli_invalid_query( + event_loop, server, monkeypatch, capsys +): + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + from gql.cli import main, get_parser + import io + + parser = get_parser(with_examples=True) + args = parser.parse_args([url]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(invalid_query_str)) + + # Flush captured output + captured = capsys.readouterr() + + await main(args) + + # Check that the error has been printed on stdout + captured = capsys.readouterr() + captured_err = str(captured.err).strip() + print(f"Captured: {captured_err}") + + expected_error = 'Cannot query field "bloh" on type "Continent"' + + assert expected_error in captured_err diff --git a/tests/test_websocket_query.py b/tests/test_websocket_query.py index d44aa779..0aa77e88 100644 --- a/tests/test_websocket_query.py +++ b/tests/test_websocket_query.py @@ -1,6 +1,7 @@ import asyncio import json import ssl +import sys from typing import Dict import pytest @@ -26,6 +27,14 @@ } """ +query1_server_answer_data = ( + '{"continents":[' + '{"code":"AF","name":"Africa"},{"code":"AN","name":"Antarctica"},' + '{"code":"AS","name":"Asia"},{"code":"EU","name":"Europe"},' + '{"code":"NA","name":"North America"},{"code":"OC","name":"Oceania"},' + '{"code":"SA","name":"South America"}]}' +) + query1_server_answer = ( '{{"type":"data","id":"{query_id}","payload":{{"data":{{"continents":[' '{{"code":"AF","name":"Africa"}},{{"code":"AN","name":"Antarctica"}},' @@ -155,9 +164,9 @@ async def test_websocket_two_queries_in_series( async def server1_two_queries_in_parallel(ws, path): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() - print(f"Server received: {result}") + print(f"Server received: {result}", file=sys.stderr) result = await ws.recv() - print(f"Server received: {result}") + print(f"Server received: {result}", file=sys.stderr) await ws.send(query1_server_answer.format(query_id=1)) await ws.send(query1_server_answer.format(query_id=2)) await WebSocketServerHelper.send_complete(ws, 1) @@ -202,7 +211,7 @@ async def task2_coro(): async def server_closing_while_we_are_doing_something_else(ws, path): await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() - print(f"Server received: {result}") + print(f"Server received: {result}", file=sys.stderr) await ws.send(query1_server_answer.format(query_id=1)) await WebSocketServerHelper.send_complete(ws, 1) await asyncio.sleep(1 * MS) @@ -348,7 +357,7 @@ async def server_with_authentication_in_connection_init_payload(ws, path): await ws.send('{"type":"connection_ack"}') result = await ws.recv() - print(f"Server received: {result}") + print(f"Server received: {result}", file=sys.stderr) await ws.send(query1_server_answer.format(query_id=1)) await WebSocketServerHelper.send_complete(ws, 1) else: @@ -481,7 +490,7 @@ async def server_sending_keep_alive_before_connection_ack(ws, path): await WebSocketServerHelper.send_keepalive(ws) await WebSocketServerHelper.send_connection_ack(ws) result = await ws.recv() - print(f"Server received: {result}") + print(f"Server received: {result}", file=sys.stderr) await ws.send(query1_server_answer.format(query_id=1)) await WebSocketServerHelper.send_complete(ws, 1) await ws.wait_closed() @@ -512,3 +521,39 @@ async def test_websocket_non_regression_bug_108( africa = continents[0] assert africa["code"] == "AF" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server1_answers], indirect=True) +async def test_websocket_using_cli(event_loop, server, monkeypatch, capsys): + + url = f"ws://{server.hostname}:{server.port}/graphql" + print(f"url = {url}") + + from gql.cli import main, get_parser + import io + import json + + parser = get_parser(with_examples=True) + args = parser.parse_args([url]) + + # Monkeypatching sys.stdin to simulate getting the query + # via the standard input + monkeypatch.setattr("sys.stdin", io.StringIO(query1_str)) + + # Flush captured output + captured = capsys.readouterr() + + exit_code = await main(args) + + assert exit_code == 0 + + # Check that the result has been printed on stdout + captured = capsys.readouterr() + captured_out = str(captured.out).strip() + + expected_answer = json.loads(query1_server_answer_data) + print(f"Captured: {captured_out}") + received_answer = json.loads(captured_out) + + assert received_answer == expected_answer