diff --git a/.circleci/config.yml b/.circleci/config.yml index 55a8f07..b85de02 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,9 +1,9 @@ version: 2 jobs: - python2-test: + python3-test: docker: - - image: "python:2.7-stretch" + - image: "python:3.5-stretch" steps: - checkout - run: pip install -e .[test] @@ -12,24 +12,6 @@ jobs: - run: pycodestyle pyls_jsonrpc test - run: pyflakes pyls_jsonrpc test - python3-test: - docker: - - image: "python:3.5-stretch" - steps: - - checkout - - run: pip install -e .[test] - - run: py.test test/ - - lint: - docker: - - image: "python:2.7-stretch" - steps: - - checkout - - run: pip install -e .[all] .[test] - - run: pylint pyls_jsonrpc test - - run: pycodestyle pyls_jsonrpc test - - run: pyflakes pyls_jsonrpc test - publish: docker: - image: "python:3.5-stretch" @@ -42,8 +24,6 @@ workflows: version: 2 build: jobs: - - python2-test: - filters: { tags: { only: /.*/ } } - python3-test: filters: { tags: { only: /.*/ } } - publish: @@ -53,5 +33,4 @@ workflows: branches: ignore: /.*/ requires: - - python2-test - python3-test diff --git a/appveyor.yml b/appveyor.yml index d09b552..ce22666 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -1,9 +1,5 @@ environment: matrix: - - PYTHON: "C:\\Python27" - PYTHON_VERSION: "2.7.15" - PYTHON_ARCH: "64" - - PYTHON: "C:\\Python35" PYTHON_VERSION: "3.5.7" PYTHON_ARCH: "64" diff --git a/pyls_jsonrpc/dispatchers.py b/pyls_jsonrpc/dispatchers.py index 22c8a3b..bde3228 100644 --- a/pyls_jsonrpc/dispatchers.py +++ b/pyls_jsonrpc/dispatchers.py @@ -6,11 +6,11 @@ _RE_ALL_CAP = re.compile('([a-z0-9])([A-Z])') -class MethodDispatcher(object): +class MethodDispatcher: """JSON RPC dispatcher that calls methods on itself. - Method names are computed by converting camel case to snake case, slashes with double underscores, and removing - dollar signs. + Method names are computed by converting camel case to snake case, slashes + with double underscores, and removing dollar signs. """ def __getitem__(self, item): @@ -19,17 +19,17 @@ def __getitem__(self, item): method = getattr(self, method_name) @functools.wraps(method) - def handler(params): + async def handler(params): return method(**(params or {})) return handler raise KeyError() -def _method_to_string(method): +def _method_to_string(method: str) -> str: return _camel_to_underscore(method.replace("/", "__").replace("$", "")) -def _camel_to_underscore(string): +def _camel_to_underscore(string: str) -> str: s1 = _RE_FIRST_CAP.sub(r'\1_\2', string) return _RE_ALL_CAP.sub(r'\1_\2', s1).lower() diff --git a/pyls_jsonrpc/dispatchers.pyi b/pyls_jsonrpc/dispatchers.pyi new file mode 100644 index 0000000..b6339a0 --- /dev/null +++ b/pyls_jsonrpc/dispatchers.pyi @@ -0,0 +1,6 @@ + +from typing import Coroutine, Any + + +class MethodDispatcher(object): + def __getitem__(self, item: Any) -> Coroutine: ... diff --git a/pyls_jsonrpc/endpoint.py b/pyls_jsonrpc/endpoint.py index 00baf87..d945da1 100644 --- a/pyls_jsonrpc/endpoint.py +++ b/pyls_jsonrpc/endpoint.py @@ -1,42 +1,56 @@ # Copyright 2018 Palantir Technologies, Inc. +import asyncio import logging import uuid import sys -from concurrent import futures -from .exceptions import JsonRpcException, JsonRpcRequestCancelled, JsonRpcInternalError, JsonRpcMethodNotFound +from typing import Dict, Awaitable + +from .exceptions import ( + JsonRpcException, JsonRpcRequestCancelled, + JsonRpcInternalError, JsonRpcMethodNotFound) log = logging.getLogger(__name__) JSONRPC_VERSION = '2.0' CANCEL_METHOD = '$/cancelRequest' -class Endpoint(object): +class Endpoint: - def __init__(self, dispatcher, consumer, id_generator=lambda: str(uuid.uuid4()), max_workers=5): + def __init__(self, + dispatcher, + consumer, + id_generator=lambda: str(uuid.uuid4()), + loop=None): """A JSON RPC endpoint for managing messages sent to/from the client. Args: dispatcher (dict): A dictionary of method name to handler function. - The handler functions should return either the result or a callable that will be used to asynchronously - compute the result. - consumer (fn): A function that consumes JSON RPC message dicts and sends them to the client. - id_generator (fn, optional): A function used to generate request IDs. - Defaults to the string value of :func:`uuid.uuid4`. - max_workers (int, optional): The number of workers in the asynchronous executor pool. + The handler functions should return either the result or a + callable that will be used to asynchronously compute + the result. + consumer (fn): A function that consumes JSON RPC message dicts and + sends them to the client. + id_generator (fn, optional): A function used to generate + request IDs. Defaults to the string value + of :func:`uuid.uuid4`. + max_workers (int, optional): The number of workers in the + asynchronous executor pool. """ self._dispatcher = dispatcher self._consumer = consumer self._id_generator = id_generator - self._client_request_futures = {} - self._server_request_futures = {} - self._executor_service = futures.ThreadPoolExecutor(max_workers=max_workers) + self.loop = ( + asyncio.get_event_loop() if loop is None else loop) # type: asyncio.BaseEventLoop + self._client_request_futures = {} # type: Dict[str, Awaitable] + self._server_request_futures = {} # type: Dict[str, Awaitable] def shutdown(self): - self._executor_service.shutdown() + # self._executor_service.shutdown() + self.loop.close() - def notify(self, method, params=None): + async def notify(self, method, params=None): """Send a JSON RPC notification to the client. Args: @@ -44,7 +58,6 @@ def notify(self, method, params=None): params (any): The payload of the notification """ log.debug('Sending notification: %s %s', method, params) - message = { 'jsonrpc': JSONRPC_VERSION, 'method': method, @@ -52,9 +65,9 @@ def notify(self, method, params=None): if params is not None: message['params'] = params - self._consumer(message) + await self._consumer(message) - def request(self, method, params=None): + async def request(self, method, params=None): """Send a JSON RPC request to the client. Args: @@ -75,11 +88,11 @@ def request(self, method, params=None): if params is not None: message['params'] = params - request_future = futures.Future() + request_future = asyncio.Future() request_future.add_done_callback(self._cancel_callback(msg_id)) self._server_request_futures[msg_id] = request_future - self._consumer(message) + await self._consumer(message) return request_future @@ -87,11 +100,12 @@ def _cancel_callback(self, request_id): """Construct a cancellation callback for the given request ID.""" def callback(future): if future.cancelled(): - self.notify(CANCEL_METHOD, {'id': request_id}) - future.set_exception(JsonRpcRequestCancelled()) + asyncio.ensure_future( + self.notify(CANCEL_METHOD, {'id': request_id}), + loop=self.loop) return callback - def consume(self, message): + async def consume(self, message): """Consume a JSON RPC message from the client. Args: @@ -103,33 +117,36 @@ def consume(self, message): if 'id' not in message: log.debug("Handling notification from client %s", message) - self._handle_notification(message['method'], message.get('params')) + await self._handle_notification( + message['method'], message.get('params')) elif 'method' not in message: log.debug("Handling response from client %s", message) - self._handle_response(message['id'], message.get('result'), message.get('error')) + await self._handle_response( + message['id'], message.get('result'), message.get('error')) else: try: log.debug("Handling request from client %s", message) - self._handle_request(message['id'], message['method'], message.get('params')) + await self._handle_request( + message['id'], message['method'], message.get('params')) except JsonRpcException as e: log.exception("Failed to handle request %s", message['id']) - self._consumer({ + await self._consumer({ 'jsonrpc': JSONRPC_VERSION, 'id': message['id'], 'error': e.to_dict() }) except Exception: # pylint: disable=broad-except log.exception("Failed to handle request %s", message['id']) - self._consumer({ + await self._consumer({ 'jsonrpc': JSONRPC_VERSION, 'id': message['id'], 'error': JsonRpcInternalError.of(sys.exc_info()).to_dict() }) - def _handle_notification(self, method, params): + async def _handle_notification(self, method, params): """Handle a notification from the client.""" if method == CANCEL_METHOD: - self._handle_cancel_notification(params['id']) + await self._handle_cancel_notification(params['id']) return try: @@ -141,13 +158,16 @@ def _handle_notification(self, method, params): try: handler_result = handler(params) except Exception: # pylint: disable=broad-except - log.exception("Failed to handle notification %s: %s", method, params) + log.exception( + "Failed to handle notification %s: %s", method, params) return if callable(handler_result): - log.debug("Executing async notification handler %s", handler_result) - notification_future = self._executor_service.submit(handler_result) - notification_future.add_done_callback(self._notification_callback(method, params)) + log.debug( + "Executing async notification handler %s", handler_result) + notification_task = asyncio.ensure_future(handler_result) + notification_task.add_done_callback( + self._notification_callback(method, params)) @staticmethod def _notification_callback(method, params): @@ -155,24 +175,27 @@ def _notification_callback(method, params): def callback(future): try: future.result() - log.debug("Successfully handled async notification %s %s", method, params) + log.debug("Successfully handled async notification %s %s", + method, params) except Exception: # pylint: disable=broad-except - log.exception("Failed to handle async notification %s %s", method, params) + log.exception("Failed to handle async notification %s %s", + method, params) return callback - def _handle_cancel_notification(self, msg_id): + async def _handle_cancel_notification(self, msg_id): """Handle a cancel notification from the client.""" request_future = self._client_request_futures.pop(msg_id, None) if not request_future: - log.warning("Received cancel notification for unknown message id %s", msg_id) + log.warning("Received cancel notification for unknown message id %s", + msg_id) return # Will only work if the request hasn't started executing if request_future.cancel(): log.debug("Cancelled request with id %s", msg_id) - def _handle_request(self, msg_id, method, params): + async def _handle_request(self, msg_id, method, params): """Handle a request from the client.""" try: handler = self._dispatcher[method] @@ -181,18 +204,20 @@ def _handle_request(self, msg_id, method, params): handler_result = handler(params) - if callable(handler_result): + if callable(handler_result) or asyncio.iscoroutine(handler_result): log.debug("Executing async request handler %s", handler_result) - request_future = self._executor_service.submit(handler_result) - self._client_request_futures[msg_id] = request_future - request_future.add_done_callback(self._request_callback(msg_id)) - elif isinstance(handler_result, futures.Future): + # request_future = self._executor_service.submit(handler_result) + request_task = asyncio.ensure_future(handler_result) + self._client_request_futures[msg_id] = request_task + request_task.add_done_callback(self._request_callback(msg_id)) + elif isinstance(handler_result, asyncio.Future): log.debug("Request handler is already a future %s", handler_result) self._client_request_futures[msg_id] = handler_result handler_result.add_done_callback(self._request_callback(msg_id)) else: - log.debug("Got result from synchronous request handler: %s", handler_result) - self._consumer({ + log.debug("Got result from synchronous request handler: %s", + handler_result) + await self._consumer({ 'jsonrpc': JSONRPC_VERSION, 'id': msg_id, 'result': handler_result @@ -219,24 +244,25 @@ def callback(future): message['error'] = e.to_dict() except Exception: # pylint: disable=broad-except log.exception("Failed to handle request %s", request_id) - message['error'] = JsonRpcInternalError.of(sys.exc_info()).to_dict() + message['error'] = JsonRpcInternalError.of( + sys.exc_info()).to_dict() - self._consumer(message) + asyncio.ensure_future(self._consumer(message), loop=self.loop) return callback - def _handle_response(self, msg_id, result=None, error=None): + async def _handle_response(self, msg_id, result=None, error=None): """Handle a response from the client.""" - request_future = self._server_request_futures.pop(msg_id, None) + request_future = self._server_request_futures.pop(msg_id, None) # type: asyncio.Future if not request_future: log.warning("Received response to unknown message id %s", msg_id) return if error is not None: - log.debug("Received error response to message %s: %s", msg_id, error) + log.debug( + "Received error response to message %s: %s", msg_id, error) request_future.set_exception(JsonRpcException.from_dict(error)) - return - - log.debug("Received result for message %s: %s", msg_id, result) - request_future.set_result(result) + else: + log.debug("Received result for message %s: %s", msg_id, result) + request_future.set_result(result) diff --git a/pyls_jsonrpc/endpoint.pyi b/pyls_jsonrpc/endpoint.pyi new file mode 100644 index 0000000..ad0efe6 --- /dev/null +++ b/pyls_jsonrpc/endpoint.pyi @@ -0,0 +1,16 @@ + +import asyncio +from typing import Dict, Coroutine, Union, Callable, Awaitable + + +class Endpoint(object): + def __init__(self, dispatcher: Dict[str, Union[Coroutine, Dict, None]], + consumer: Coroutine, + id_generator: Callable[[], str], + max_workers: int = 5, + loop: asyncio.BaseEventLoop = None): ... + + def shutdown(self) -> None: ... + async def notify(self, method: str, params: Dict = None) -> None: ... + async def request(self, method: str, params: Dict = None) -> Awaitable: ... + async def consume(self, message: Dict) -> None: ... diff --git a/pyls_jsonrpc/streams.py b/pyls_jsonrpc/streams.py index 3277437..9eb2610 100644 --- a/pyls_jsonrpc/streams.py +++ b/pyls_jsonrpc/streams.py @@ -1,6 +1,7 @@ # Copyright 2018 Palantir Technologies, Inc. +import asyncio import logging -import threading +import functools try: import ujson as json @@ -10,45 +11,50 @@ log = logging.getLogger(__name__) -class JsonRpcStreamReader(object): +class JsonRpcStreamReader: - def __init__(self, rfile): + def __init__(self, rfile, loop=None): self._rfile = rfile + self.loop = asyncio.get_event_loop() if loop is not None else loop - def close(self): - self._rfile.close() + def close(self) -> None: + # self.close = True + self._rfile.feed_eof() + # self._rfile.close() - def listen(self, message_consumer): + async def listen(self, message_consumer): """Blocking call to listen for messages on the rfile. Args: - message_consumer (fn): function that is passed each message as it is read off the socket. + message_consumer (fn): function that is passed each message as it + is read off the socket. """ - while not self._rfile.closed: + while not self._rfile.at_eof(): try: - request_str = self._read_message() + request_str = await self._read_message() except ValueError: - if self._rfile.closed: + if self._rfile.at_eof(): return - else: - log.exception("Failed to read from rfile") + + log.exception("Failed to read from rfile") if request_str is None: break try: - message_consumer(json.loads(request_str.decode('utf-8'))) + body = json.loads(request_str.decode('utf-8')) + asyncio.ensure_future(message_consumer(body), loop=self.loop) except ValueError: log.exception("Failed to parse JSON message %s", request_str) continue - def _read_message(self): + async def _read_message(self): """Reads the contents of a message. Returns: body of message if parsable else None """ - line = self._rfile.readline() + line = await self._rfile.readline() if not line: return None @@ -57,13 +63,14 @@ def _read_message(self): # Blindly consume all header lines while line and line.strip(): - line = self._rfile.readline() + line = await self._rfile.readline() if not line: return None + content = await self._rfile.read(content_length) # Grab the body - return self._rfile.read(content_length) + return content @staticmethod def _content_length(line): @@ -74,39 +81,47 @@ def _content_length(line): try: return int(value) except ValueError: - raise ValueError("Invalid Content-Length header: {}".format(value)) + raise ValueError( + "Invalid Content-Length header: {}".format(value)) return None -class JsonRpcStreamWriter(object): +class JsonRpcStreamWriter: - def __init__(self, wfile, **json_dumps_args): + def __init__(self, wfile, loop=None, **json_dumps_args): self._wfile = wfile - self._wfile_lock = threading.Lock() + self._wfile_lock = asyncio.Lock() self._json_dumps_args = json_dumps_args + self.loop = asyncio.get_event_loop() if loop is None else loop - def close(self): - with self._wfile_lock: + async def close(self): + async with self._wfile_lock: self._wfile.close() + await self._wfile.wait_closed() - def write(self, message): - with self._wfile_lock: - if self._wfile.closed: + async def write(self, message): + async with self._wfile_lock: + if self._wfile.is_closing(): return try: - body = json.dumps(message, **self._json_dumps_args) + body = await self.loop.run_in_executor( + None, functools.partial( + json.dumps, message, **self._json_dumps_args)) # Ensure we get the byte length, not the character length - content_length = len(body) if isinstance(body, bytes) else len(body.encode('utf-8')) + content_length = (len(body) if isinstance(body, bytes) else + len(body.encode('utf-8'))) response = ( "Content-Length: {}\r\n" - "Content-Type: application/vscode-jsonrpc; charset=utf8\r\n\r\n" + "Content-Type: application/vscode-jsonrpc; " + "charset=utf8\r\n\r\n" "{}".format(content_length, body) ) self._wfile.write(response.encode('utf-8')) - self._wfile.flush() + await self._wfile.drain() except Exception: # pylint: disable=broad-except - log.exception("Failed to write message to output file %s", message) + log.exception( + "Failed to write message to output file %s", message) diff --git a/pyls_jsonrpc/streams.pyi b/pyls_jsonrpc/streams.pyi new file mode 100644 index 0000000..a52b522 --- /dev/null +++ b/pyls_jsonrpc/streams.pyi @@ -0,0 +1,23 @@ + +import asyncio +from typing import Coroutine, Dict, Union, Any + + +class JsonRpcStreamReader(object): + def __init__(self, rfile: asyncio.StreamReader, + loop: asyncio.AbstractEventLoop = None): ... + + def close(self) -> None: ... + async def listen(self, message_consumer: Coroutine) -> Any: ... + async def _read_message(self) -> Union[str, None]: ... + + @staticmethod + def _content_length(line: str) -> Union[int, None]: ... + + +class JsonRpcStreamWriter(object): + def __init__(self, wfile: asyncio.StreamWriter, + loop: asyncio.AbstractEventLoop = None, + **json_dumps_args): ... + async def close(self) -> None: ... + async def write(self, message: Dict) -> None: ... diff --git a/setup.py b/setup.py index 3c6ba2a..d0a7f24 100755 --- a/setup.py +++ b/setup.py @@ -6,10 +6,7 @@ README = open('README.rst', 'r').read() -install_requires = [ - 'future>=0.14.0; python_version<"3"', - 'futures; python_version<"3.2"', -] +install_requires = [] if sys.version_info[0] == 2: install_requires.append('ujson<=2.0.3; platform_system!="Windows"') @@ -49,6 +46,7 @@ # for example: # $ pip install -e .[test] extras_require={ - 'test': ['versioneer', 'pylint', 'pycodestyle', 'pyflakes', 'pytest', 'mock', 'pytest-cov', 'coverage'], + 'test': ['versioneer', 'pylint', 'pycodestyle', 'pyflakes', 'pytest', + 'mock', 'pytest-cov', 'coverage', 'pytest-asyncio'], }, ) diff --git a/test/test_endpoint.py b/test/test_endpoint.py index b954732..b550f26 100644 --- a/test/test_endpoint.py +++ b/test/test_endpoint.py @@ -1,6 +1,7 @@ # Copyright 2018 Palantir Technologies, Inc. # pylint: disable=redefined-outer-name -from concurrent import futures +# from concurrent import futures +import asyncio import time import mock import pytest @@ -11,6 +12,13 @@ MSG_ID = 'id' +async def async_magic(): + pass + +mock.MagicMock.__await__ = lambda x: async_magic().__await__() +mock.Mock.__await__ = lambda x: async_magic().__await__() + + @pytest.fixture() def dispatcher(): return {} @@ -22,17 +30,20 @@ def consumer(): @pytest.fixture() -def endpoint(dispatcher, consumer): - return Endpoint(dispatcher, consumer, id_generator=lambda: MSG_ID) +def endpoint(dispatcher, consumer, event_loop): + return Endpoint( + dispatcher, consumer, id_generator=lambda: MSG_ID, loop=event_loop) -def test_bad_message(endpoint): +@pytest.mark.asyncio +async def test_bad_message(endpoint: Endpoint): # Ensure doesn't raise for a bad message - endpoint.consume({'key': 'value'}) + await endpoint.consume({'key': 'value'}) -def test_notify(endpoint, consumer): - endpoint.notify('methodName', {'key': 'value'}) +@pytest.mark.asyncio +async def test_notify(endpoint, consumer): + await endpoint.notify('methodName', {'key': 'value'}) consumer.assert_called_once_with({ 'jsonrpc': '2.0', 'method': 'methodName', @@ -40,16 +51,18 @@ def test_notify(endpoint, consumer): }) -def test_notify_none_params(endpoint, consumer): - endpoint.notify('methodName', None) +@pytest.mark.asyncio +async def test_notify_none_params(endpoint, consumer): + await endpoint.notify('methodName', None) consumer.assert_called_once_with({ 'jsonrpc': '2.0', 'method': 'methodName', }) -def test_request(endpoint, consumer): - future = endpoint.request('methodName', {'key': 'value'}) +@pytest.mark.asyncio +async def test_request(endpoint: Endpoint, consumer): + future = await endpoint.request('methodName', {'key': 'value'}) assert not future.done() consumer.assert_called_once_with({ @@ -61,17 +74,18 @@ def test_request(endpoint, consumer): # Send the response back to the endpoint result = 1234 - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'id': MSG_ID, 'result': result }) - - assert future.result(timeout=2) == result + fut_result = await future + assert fut_result == result -def test_request_error(endpoint, consumer): - future = endpoint.request('methodName', {'key': 'value'}) +@pytest.mark.asyncio +async def test_request_error(endpoint: Endpoint, consumer): + future = await endpoint.request('methodName', {'key': 'value'}) assert not future.done() consumer.assert_called_once_with({ @@ -83,21 +97,23 @@ def test_request_error(endpoint, consumer): # Send an error back from the client error = exceptions.JsonRpcInvalidRequest(data=1234) - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'id': MSG_ID, 'error': error.to_dict() }) - # Verify the exception raised by the future is the same as the error the client serialized + # Verify the exception raised by the future is the same as the error + # the client serialized with pytest.raises(exceptions.JsonRpcException) as exc_info: - assert future.result(timeout=2) + await future # .result(timeout=2) assert exc_info.type == exceptions.JsonRpcInvalidRequest assert exc_info.value == error -def test_request_cancel(endpoint, consumer): - future = endpoint.request('methodName', {'key': 'value'}) +@pytest.mark.asyncio +async def test_request_cancel(endpoint: Endpoint, consumer): + future = await endpoint.request('methodName', {'key': 'value'}) assert not future.done() consumer.assert_called_once_with({ @@ -109,34 +125,35 @@ def test_request_cancel(endpoint, consumer): # Cancel the request future.cancel() + await asyncio.sleep(0.5) consumer.assert_any_call({ 'jsonrpc': '2.0', 'method': '$/cancelRequest', 'params': {'id': MSG_ID} }) - with pytest.raises((exceptions.JsonRpcException, futures.CancelledError)) as exc_info: - assert future.result(timeout=2) - assert exc_info.type in (exceptions.JsonRpcRequestCancelled, futures.CancelledError) - -def test_consume_notification(endpoint, dispatcher): +@pytest.mark.asyncio +async def test_consume_notification(endpoint: Endpoint, dispatcher): handler = mock.Mock() + # handler = asyncio.Future() dispatcher['methodName'] = handler - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'method': 'methodName', 'params': {'key': 'value'} }) + # await handler handler.assert_called_once_with({'key': 'value'}) -def test_consume_notification_error(endpoint, dispatcher): +@pytest.mark.asyncio +async def test_consume_notification_error(endpoint, dispatcher): handler = mock.Mock(side_effect=ValueError) dispatcher['methodName'] = handler # Verify the consume doesn't throw - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'method': 'methodName', 'params': {'key': 'value'} @@ -144,23 +161,26 @@ def test_consume_notification_error(endpoint, dispatcher): handler.assert_called_once_with({'key': 'value'}) -def test_consume_notification_method_not_found(endpoint): +@pytest.mark.asyncio +async def test_consume_notification_method_not_found(endpoint): # Verify consume doesn't throw for method not found - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'method': 'methodName', 'params': {'key': 'value'} }) -def test_consume_async_notification_error(endpoint, dispatcher): - def _async_handler(): +@pytest.mark.asyncio +async def test_consume_async_notification_error(endpoint, dispatcher): + async def _async_handler(): raise ValueError() - handler = mock.Mock(return_value=_async_handler) + + handler = mock.Mock(wraps=_async_handler) dispatcher['methodName'] = handler # Verify the consume doesn't throw - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'method': 'methodName', 'params': {'key': 'value'} @@ -168,12 +188,13 @@ def _async_handler(): handler.assert_called_once_with({'key': 'value'}) -def test_consume_request(endpoint, consumer, dispatcher): +@pytest.mark.asyncio +async def test_consume_request(endpoint, consumer, dispatcher): result = 1234 handler = mock.Mock(return_value=result) dispatcher['methodName'] = handler - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'id': MSG_ID, 'method': 'methodName', @@ -188,12 +209,18 @@ def test_consume_request(endpoint, consumer, dispatcher): }) -def test_consume_future_request(endpoint, consumer, dispatcher): - future_response = futures.ThreadPoolExecutor().submit(lambda: 1234) - handler = mock.Mock(return_value=future_response) +@pytest.mark.asyncio +async def test_consume_future_request( + endpoint: Endpoint, consumer, dispatcher): + # future_response = futures.ThreadPoolExecutor().submit(lambda: 1234) + async def future_wrap(*args, **kwargs): # pylint: disable=unused-argument + return 1234 + + task = asyncio.ensure_future(future_wrap()) + handler = mock.Mock(return_value=task) dispatcher['methodName'] = handler - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'id': MSG_ID, 'method': 'methodName', @@ -201,20 +228,23 @@ def test_consume_future_request(endpoint, consumer, dispatcher): }) handler.assert_called_once_with({'key': 'value'}) - await_assertion(lambda: consumer.assert_called_once_with({ + await asyncio.sleep(0.5) + consumer.assert_called_once_with({ 'jsonrpc': '2.0', 'id': MSG_ID, 'result': 1234 - })) + }) -def test_consume_async_request(endpoint, consumer, dispatcher): - def _async_handler(): +@pytest.mark.asyncio +async def test_consume_async_request(endpoint, consumer, dispatcher): + async def _async_handler(*args, **kwargs): # pylint: disable=unused-argument return 1234 - handler = mock.Mock(return_value=_async_handler) + + handler = mock.Mock(wraps=_async_handler) dispatcher['methodName'] = handler - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'id': MSG_ID, 'method': 'methodName', @@ -222,25 +252,29 @@ def _async_handler(): }) handler.assert_called_once_with({'key': 'value'}) - await_assertion(lambda: consumer.assert_called_once_with({ + await asyncio.sleep(0.5) + consumer.assert_called_once_with({ 'jsonrpc': '2.0', 'id': MSG_ID, 'result': 1234 - })) + }) +@pytest.mark.asyncio @pytest.mark.parametrize('exc_type, error', [ (ValueError, exceptions.JsonRpcInternalError(message='ValueError')), (KeyError, exceptions.JsonRpcInternalError(message='KeyError')), (exceptions.JsonRpcMethodNotFound, exceptions.JsonRpcMethodNotFound()), ]) -def test_consume_async_request_error(exc_type, error, endpoint, consumer, dispatcher): - def _async_handler(): +async def test_consume_async_request_error(exc_type, error, endpoint: Endpoint, + consumer, dispatcher): + async def _async_handler(*args, **kwargs): raise exc_type() - handler = mock.Mock(return_value=_async_handler) + + handler = mock.Mock(wraps=_async_handler) dispatcher['methodName'] = handler - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'id': MSG_ID, 'method': 'methodName', @@ -248,11 +282,13 @@ def _async_handler(): }) handler.assert_called_once_with({'key': 'value'}) - await_assertion(lambda: assert_consumer_error(consumer, error)) + await asyncio.sleep(0.5) + assert_consumer_error(consumer, error) -def test_consume_request_method_not_found(endpoint, consumer): - endpoint.consume({ +@pytest.mark.asyncio +async def test_consume_request_method_not_found(endpoint, consumer): + await endpoint.consume({ 'jsonrpc': '2.0', 'id': MSG_ID, 'method': 'methodName', @@ -261,16 +297,17 @@ def test_consume_request_method_not_found(endpoint, consumer): assert_consumer_error(consumer, exceptions.JsonRpcMethodNotFound.of('methodName')) +@pytest.mark.asyncio @pytest.mark.parametrize('exc_type, error', [ (ValueError, exceptions.JsonRpcInternalError(message='ValueError')), (KeyError, exceptions.JsonRpcInternalError(message='KeyError')), (exceptions.JsonRpcMethodNotFound, exceptions.JsonRpcMethodNotFound()), ]) -def test_consume_request_error(exc_type, error, endpoint, consumer, dispatcher): +async def test_consume_request_error(exc_type, error, endpoint, consumer, dispatcher): handler = mock.Mock(side_effect=exc_type) dispatcher['methodName'] = handler - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'id': MSG_ID, 'method': 'methodName', @@ -278,16 +315,18 @@ def test_consume_request_error(exc_type, error, endpoint, consumer, dispatcher): }) handler.assert_called_once_with({'key': 'value'}) - await_assertion(lambda: assert_consumer_error(consumer, error)) + assert_consumer_error(consumer, error) + +@pytest.mark.asyncio +async def test_consume_request_cancel(endpoint, dispatcher): + async def async_handler(): + await asyncio.sleep(3) -def test_consume_request_cancel(endpoint, dispatcher): - def async_handler(): - time.sleep(3) - handler = mock.Mock(return_value=async_handler) + handler = mock.Mock(wraps=async_handler) dispatcher['methodName'] = handler - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'id': MSG_ID, 'method': 'methodName', @@ -295,7 +334,7 @@ def async_handler(): }) handler.assert_called_once_with({'key': 'value'}) - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'method': '$/cancelRequest', 'params': {'id': MSG_ID} @@ -309,9 +348,10 @@ def async_handler(): # }) -def test_consume_request_cancel_unknown(endpoint): +@pytest.mark.asyncio +async def test_consume_request_cancel_unknown(endpoint): # Verify consume doesn't throw - endpoint.consume({ + await endpoint.consume({ 'jsonrpc': '2.0', 'method': '$/cancelRequest', 'params': {'id': 'unknown identifier'} diff --git a/test/test_streams.py b/test/test_streams.py index 6985aec..69addf4 100644 --- a/test/test_streams.py +++ b/test/test_streams.py @@ -1,6 +1,7 @@ # Copyright 2018 Palantir Technologies, Inc. # pylint: disable=redefined-outer-name from io import BytesIO +import asyncio import datetime import os import sys @@ -10,37 +11,126 @@ from pyls_jsonrpc.streams import JsonRpcStreamReader, JsonRpcStreamWriter +async def async_magic(): + pass + +mock.MagicMock.__await__ = lambda x: async_magic().__await__() +mock.Mock.__await__ = lambda x: async_magic().__await__() + + +async def stdio(loop=None): + if loop is None: + loop = asyncio.get_event_loop() + + return _wrap_file(loop) + + +def _wrap_file(loop): + # BytesIO cannot behave like a proper pipe/socket, thus we need to wrap it + # to use an executor to read from stdio and write to stdout + # note: if nothing ever drains the writer explicitly, + # no flushing ever takes place! + class InFileReader: + def __init__(self): + self.stdin = BytesIO() + self.eof = False + + async def read(self, num_bytes): + return self.stdin.read(num_bytes) + + def write(self, buf): + self.stdin.write(buf) + + async def drain(self): + pass + + def seek(self, offset): + self.stdin.seek(offset) + + def feed_eof(self): + self.eof = True + self.stdin.close() + + def at_eof(self): + return self.eof + + async def readline(self): + # a single call to sys.stdin.readline() is thread-safe + return await loop.run_in_executor(None, self.stdin.readline) + + class OutFileReader: + def __init__(self): + self.buffer = [] + self.pipe = BytesIO() + # self.stdout = self.pipe.getbuffer() + + def write(self, data): + self.buffer.append(data) + + def close(self): + self.pipe.close() + + def seek(self, byte): + self.pipe.seek(byte) + + async def read(self, num_bytes): + return self.pipe.read(num_bytes) + + def is_closing(self): + return self.pipe.closed + + async def drain(self): + data, self.buffer = self.buffer, [] + data = b''.join(data) + # a single call to sys.stdout.writelines() is thread-safe + return await loop.run_in_executor(None, self.pipe.write, data) + + def getvalue(self): + return self.pipe.getvalue() + + return InFileReader(), OutFileReader() + + +@pytest.fixture() +async def reader_writer(event_loop): + return await stdio(loop=event_loop) + + @pytest.fixture() -def rfile(): - return BytesIO() +async def rfile(reader_writer): + reader, _ = reader_writer + return reader @pytest.fixture() -def wfile(): - return BytesIO() +async def wfile(reader_writer): + _, writer = reader_writer + return writer @pytest.fixture() -def reader(rfile): - return JsonRpcStreamReader(rfile) +async def reader(rfile, event_loop): + return JsonRpcStreamReader(rfile, loop=event_loop) @pytest.fixture() -def writer(wfile): - return JsonRpcStreamWriter(wfile, sort_keys=True) +async def writer(wfile, event_loop): + return JsonRpcStreamWriter(wfile, loop=event_loop, sort_keys=True) -def test_reader(rfile, reader): +@pytest.mark.asyncio +async def test_reader(rfile, reader): rfile.write( b'Content-Length: 49\r\n' b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n' b'\r\n' b'{"id": "hello", "method": "method", "params": {}}' ) + await rfile.drain() rfile.seek(0) consumer = mock.Mock() - reader.listen(consumer) + await reader.listen(consumer) consumer.assert_called_once_with({ 'id': 'hello', @@ -49,17 +139,19 @@ def test_reader(rfile, reader): }) -def test_reader_bad_message(rfile, reader): +@pytest.mark.asyncio +async def test_reader_bad_message(rfile, reader): rfile.write(b'Hello world') rfile.seek(0) # Ensure the listener doesn't throw consumer = mock.Mock() - reader.listen(consumer) + await reader.listen(consumer) consumer.assert_not_called() -def test_reader_bad_json(rfile, reader): +@pytest.mark.asyncio +async def test_reader_bad_json(rfile, reader): rfile.write( b'Content-Length: 8\r\n' b'Content-Type: application/vscode-jsonrpc; charset=utf8\r\n' @@ -70,12 +162,13 @@ def test_reader_bad_json(rfile, reader): # Ensure the listener doesn't throw consumer = mock.Mock() - reader.listen(consumer) + await reader.listen(consumer) consumer.assert_not_called() -def test_writer(wfile, writer): - writer.write({ +@pytest.mark.asyncio +async def test_writer(wfile, writer): + await writer.write({ 'id': 'hello', 'method': 'method', 'params': {} @@ -107,12 +200,13 @@ def __json__(self): return '{0}'.format(dif) -def test_writer_bad_message(wfile, writer): +@pytest.mark.asyncio +async def test_writer_bad_message(wfile, writer): # A datetime isn't serializable(or poorly serializable), # ensure the write method doesn't throw, but the result could be empty # or the correct datetime datetime.datetime = JsonDatetime - writer.write(datetime.datetime( + await writer.write(datetime.datetime( year=2019, month=1, day=1,