diff --git a/.circleci/config.yml b/.circleci/config.yml index a6c85d237..0877c161a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -29,6 +29,15 @@ jobs: - checkout - run: sudo pip install tox - run: tox -e py37 + py38: + docker: + - image: circleci/python:3.8.0rc1 + steps: + # Remove IPv6 entry for localhost in Circle CI containers because it doesn't work anyway. + - run: sudo cp /etc/hosts /tmp; sudo sed -i '/::1/d' /tmp/hosts; sudo cp /tmp/hosts /etc + - checkout + - run: sudo pip install tox + - run: tox -e py38 workflows: version: 2 @@ -41,3 +50,6 @@ workflows: - py37: requires: - main + - py38: + requires: + - main diff --git a/docs/changelog.rst b/docs/changelog.rst index 87b2e4380..2a106fbc0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,8 @@ Changelog *In development* +* Added compatibility with Python 3.8. + 8.0.2 ..... diff --git a/setup.py b/setup.py index c76430104..f35819247 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', ], package_dir = {'': 'src'}, package_data = {'websockets': ['py.typed']}, diff --git a/src/websockets/__init__.py b/src/websockets/__init__.py index e7ba31ce5..6bad0f7bc 100644 --- a/src/websockets/__init__.py +++ b/src/websockets/__init__.py @@ -1,12 +1,13 @@ # This relies on each of the submodules having an __all__ variable. -from .auth import * -from .client import * -from .exceptions import * -from .protocol import * -from .server import * -from .typing import * -from .uri import * +from . import auth, client, exceptions, protocol, server, typing, uri +from .auth import * # noqa +from .client import * # noqa +from .exceptions import * # noqa +from .protocol import * # noqa +from .server import * # noqa +from .typing import * # noqa +from .uri import * # noqa from .version import version as __version__ # noqa diff --git a/src/websockets/__main__.py b/src/websockets/__main__.py index bccb8aa52..394f7ac79 100644 --- a/src/websockets/__main__.py +++ b/src/websockets/__main__.py @@ -6,8 +6,8 @@ import threading from typing import Any, Set -import websockets -from websockets.exceptions import format_close +from .client import connect +from .exceptions import ConnectionClosed, format_close if sys.platform == "win32": @@ -95,7 +95,7 @@ async def run_client( stop: "asyncio.Future[None]", ) -> None: try: - websocket = await websockets.connect(uri) + websocket = await connect(uri) except Exception as exc: print_over_input(f"Failed to connect to {uri}: {exc}.") exit_from_event_loop_thread(loop, stop) @@ -122,7 +122,7 @@ async def run_client( if incoming in done: try: message = incoming.result() - except websockets.ConnectionClosed: + except ConnectionClosed: break else: if isinstance(message, str): diff --git a/src/websockets/client.py b/src/websockets/client.py index c1fdf88a0..725ec1e7a 100644 --- a/src/websockets/client.py +++ b/src/websockets/client.py @@ -24,7 +24,6 @@ from .extensions.permessage_deflate import ClientPerMessageDeflateFactory from .handshake import build_request, check_response from .headers import ( - ExtensionHeader, build_authorization_basic, build_extension, build_subprotocol, @@ -33,7 +32,7 @@ ) from .http import USER_AGENT, Headers, HeadersLike, read_response from .protocol import WebSocketCommonProtocol -from .typing import Origin, Subprotocol +from .typing import ExtensionHeader, Origin, Subprotocol from .uri import WebSocketURI, parse_uri @@ -85,7 +84,7 @@ def write_http_request(self, path: str, headers: Headers) -> None: request = f"GET {path} HTTP/1.1\r\n" request += str(headers) - self.writer.write(request.encode()) + self.transport.write(request.encode()) async def read_http_response(self) -> Tuple[int, Headers]: """ diff --git a/src/websockets/framing.py b/src/websockets/framing.py index 81a3185b0..c24b8a73d 100644 --- a/src/websockets/framing.py +++ b/src/websockets/framing.py @@ -147,7 +147,7 @@ async def read( def write( frame, - writer: Callable[[bytes], Any], + write: Callable[[bytes], Any], *, mask: bool, extensions: Optional[Sequence["websockets.extensions.base.Extension"]] = None, @@ -156,7 +156,7 @@ def write( Write a WebSocket frame. :param frame: frame to write - :param writer: function that writes bytes + :param write: function that writes bytes :param mask: whether the frame should be masked i.e. whether the write happens on the client side :param extensions: list of classes with an ``encode()`` method that @@ -210,10 +210,10 @@ def write( # Send the frame. - # The frame is written in a single call to writer in order to prevent + # The frame is written in a single call to write in order to prevent # TCP fragmentation. See #68 for details. This also makes it safe to # send frames concurrently from multiple coroutines. - writer(output.getvalue()) + write(output.getvalue()) def check(frame) -> None: """ diff --git a/src/websockets/handshake.py b/src/websockets/handshake.py index 17332d155..9bfe27754 100644 --- a/src/websockets/handshake.py +++ b/src/websockets/handshake.py @@ -29,9 +29,10 @@ import binascii import hashlib import random +from typing import List from .exceptions import InvalidHeader, InvalidHeaderValue, InvalidUpgrade -from .headers import parse_connection, parse_upgrade +from .headers import ConnectionOption, UpgradeProtocol, parse_connection, parse_upgrade from .http import Headers, MultipleValuesError @@ -74,14 +75,16 @@ def check_request(headers: Headers) -> str: is invalid; then the server must return 400 Bad Request error """ - connection = sum( + connection: List[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", ", ".join(connection)) - upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. @@ -148,14 +151,16 @@ def check_response(headers: Headers, key: str) -> None: is invalid """ - connection = sum( + connection: List[ConnectionOption] = sum( [parse_connection(value) for value in headers.get_all("Connection")], [] ) if not any(value.lower() == "upgrade" for value in connection): raise InvalidUpgrade("Connection", " ".join(connection)) - upgrade = sum([parse_upgrade(value) for value in headers.get_all("Upgrade")], []) + upgrade: List[UpgradeProtocol] = sum( + [parse_upgrade(value) for value in headers.get_all("Upgrade")], [] + ) # For compatibility with non-strict implementations, ignore case when # checking the Upgrade header. It's supposed to be 'WebSocket'. diff --git a/src/websockets/protocol.py b/src/websockets/protocol.py index 1f0edcce2..6c29b2a52 100644 --- a/src/websockets/protocol.py +++ b/src/websockets/protocol.py @@ -14,6 +14,7 @@ import logging import random import struct +import sys import warnings from typing import ( Any, @@ -61,7 +62,7 @@ class State(enum.IntEnum): # between the check and the assignment. -class WebSocketCommonProtocol(asyncio.StreamReaderProtocol): +class WebSocketCommonProtocol(asyncio.Protocol): """ :class:`~asyncio.Protocol` subclass implementing the data transfer phase. @@ -212,8 +213,6 @@ def __init__( self.read_limit = read_limit self.write_limit = write_limit - # Store a reference to loop to avoid relying on self._loop, a private - # attribute of StreamReaderProtocol, inherited from FlowControlMixin. if loop is None: loop = asyncio.get_event_loop() self.loop = loop @@ -227,12 +226,15 @@ def __init__( # ``self.read_limit``. The ``limit`` argument controls the line length # limit and half the buffer limit of :class:`~asyncio.StreamReader`. # That's why it must be set to half of ``self.read_limit``. - stream_reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) - super().__init__(stream_reader, self.client_connected, loop) + self.reader = asyncio.StreamReader(limit=read_limit // 2, loop=loop) - self.reader: asyncio.StreamReader - self.writer: asyncio.StreamWriter - self._drain_lock = asyncio.Lock(loop=loop) + # Copied from asyncio.FlowControlMixin + self._paused = False + self._drain_waiter: Optional[asyncio.Future[None]] = None + + self._drain_lock = asyncio.Lock( + loop=loop if sys.version_info[:2] < (3, 8) else None + ) # This class implements the data transfer and closing handshake, which # are shared between the client-side and the server-side. @@ -284,19 +286,36 @@ def __init__( # Task closing the TCP connection. self.close_connection_task: asyncio.Task[None] - def client_connected( - self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ) -> None: - """ - Callback when the TCP connection is established. - - Record references to the stream reader and the stream writer to avoid - using private attributes ``_stream_reader`` and ``_stream_writer`` of - :class:`~asyncio.StreamReaderProtocol`. - - """ - self.reader = reader - self.writer = writer + # Copied from asyncio.FlowControlMixin + async def _drain_helper(self) -> None: # pragma: no cover + if self.connection_lost_waiter.done(): + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = self.loop.create_future() + self._drain_waiter = waiter + await waiter + + # Copied from asyncio.StreamWriter + async def _drain(self) -> None: # pragma: no cover + if self.reader is not None: + exc = self.reader.exception() + if exc is not None: + raise exc + if self.transport is not None: + if self.transport.is_closing(): + # Yield to the event loop so connection_lost() may be + # called. Without this, _drain_helper() would return + # immediately, and code that calls + # write(...); yield from drain() + # in a loop would never call connection_lost(), so it + # would not see an error when the socket is closed. + await asyncio.sleep( + 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) + await self._drain_helper() def connection_open(self) -> None: """ @@ -344,9 +363,12 @@ def local_address(self) -> Any: been established yet. """ - if self.writer is None: + try: + transport = self.transport + except AttributeError: return None - return self.writer.get_extra_info("sockname") + else: + return transport.get_extra_info("sockname") @property def remote_address(self) -> Any: @@ -357,9 +379,12 @@ def remote_address(self) -> Any: been established yet. """ - if self.writer is None: + try: + transport = self.transport + except AttributeError: return None - return self.writer.get_extra_info("peername") + else: + return transport.get_extra_info("peername") @property def open(self) -> bool: @@ -466,7 +491,7 @@ async def recv(self) -> Data: # pop_message_waiter and self.transfer_data_task. await asyncio.wait( [pop_message_waiter, self.transfer_data_task], - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, return_when=asyncio.FIRST_COMPLETED, ) finally: @@ -584,10 +609,14 @@ async def send( elif isinstance(message, AsyncIterable): # aiter_message = aiter(message) without aiter - aiter_message = type(message).__aiter__(message) + # https://github.com/python/mypy/issues/5738 + aiter_message = type(message).__aiter__(message) # type: ignore try: # message_chunk = anext(aiter_message) without anext - message_chunk = await type(aiter_message).__anext__(aiter_message) + # https://github.com/python/mypy/issues/5738 + message_chunk = await type(aiter_message).__anext__( # type: ignore + aiter_message + ) except StopAsyncIteration: return opcode, data = prepare_data(message_chunk) @@ -598,7 +627,8 @@ async def send( await self.write_frame(False, opcode, data) # Other fragments. - async for message_chunk in aiter_message: + # https://github.com/python/mypy/issues/5738 + async for message_chunk in aiter_message: # type: ignore confirm_opcode, data = prepare_data(message_chunk) if confirm_opcode != opcode: raise TypeError("data contains inconsistent types") @@ -646,7 +676,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: await asyncio.wait_for( self.write_close_frame(serialize_close(code, reason)), self.close_timeout, - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except asyncio.TimeoutError: # If the close frame cannot be sent because the send buffers @@ -665,7 +695,9 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # If close() is canceled during the wait, self.transfer_data_task # is canceled before the timeout elapses. await asyncio.wait_for( - self.transfer_data_task, self.close_timeout, loop=self.loop + self.transfer_data_task, + self.close_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except (asyncio.TimeoutError, asyncio.CancelledError): pass @@ -882,8 +914,7 @@ async def read_message(self) -> Optional[Data]: max_size = self.max_size if text: decoder_factory = codecs.getincrementaldecoder("utf-8") - # https://github.com/python/typeshed/pull/2752 - decoder = decoder_factory(errors="strict") # type: ignore + decoder = decoder_factory(errors="strict") if max_size is None: def append(frame: Frame) -> None: @@ -1033,7 +1064,9 @@ async def write_frame( frame = Frame(fin, opcode, data) logger.debug("%s > %r", self.side, frame) - frame.write(self.writer.write, mask=self.is_client, extensions=self.extensions) + frame.write( + self.transport.write, mask=self.is_client, extensions=self.extensions + ) try: # drain() cannot be called concurrently by multiple coroutines: @@ -1041,7 +1074,7 @@ async def write_frame( # version of Python where this bugs exists is supported anymore. async with self._drain_lock: # Handle flow control automatically. - await self.writer.drain() + await self._drain() except ConnectionError: # Terminate the connection if the socket died. self.fail_connection() @@ -1083,7 +1116,10 @@ async def keepalive_ping(self) -> None: try: while True: - await asyncio.sleep(self.ping_interval, loop=self.loop) + await asyncio.sleep( + self.ping_interval, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, + ) # ping() raises CancelledError if the connection is closed, # when close_connection() cancels self.keepalive_ping_task. @@ -1096,7 +1132,9 @@ async def keepalive_ping(self) -> None: if self.ping_timeout is not None: try: await asyncio.wait_for( - ping_waiter, self.ping_timeout, loop=self.loop + ping_waiter, + self.ping_timeout, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except asyncio.TimeoutError: logger.debug("%s ! timed out waiting for pong", self.side) @@ -1143,9 +1181,9 @@ async def close_connection(self) -> None: logger.debug("%s ! timed out waiting for TCP close", self.side) # Half-close the TCP connection if possible (when there's no TLS). - if self.writer.can_write_eof(): + if self.transport.can_write_eof(): logger.debug("%s x half-closing TCP connection", self.side) - self.writer.write_eof() + self.transport.write_eof() if await self.wait_for_connection_lost(): return @@ -1158,17 +1196,12 @@ async def close_connection(self) -> None: # If connection_lost() was called, the TCP connection is closed. # However, if TLS is enabled, the transport still needs closing. # Else asyncio complains: ResourceWarning: unclosed transport. - try: - writer_is_closing = self.writer.is_closing # type: ignore - except AttributeError: # pragma: no cover - # Python < 3.7 - writer_is_closing = self.writer.transport.is_closing - if self.connection_lost_waiter.done() and writer_is_closing(): + if self.connection_lost_waiter.done() and self.transport.is_closing(): return # Close the TCP connection. Buffers are flushed asynchronously. logger.debug("%s x closing TCP connection", self.side) - self.writer.close() + self.transport.close() if await self.wait_for_connection_lost(): return @@ -1176,8 +1209,7 @@ async def close_connection(self) -> None: # Abort the TCP connection. Buffers are discarded. logger.debug("%s x aborting TCP connection", self.side) - # mypy thinks self.writer.transport is a BaseTransport, not a Transport. - self.writer.transport.abort() # type: ignore + self.transport.abort() # connection_lost() is called quickly after aborting. await self.wait_for_connection_lost() @@ -1194,7 +1226,7 @@ async def wait_for_connection_lost(self) -> bool: await asyncio.wait_for( asyncio.shield(self.connection_lost_waiter), self.close_timeout, - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) except asyncio.TimeoutError: pass @@ -1257,7 +1289,7 @@ def fail_connection(self, code: int = 1006, reason: str = "") -> None: frame = Frame(True, OP_CLOSE, frame_data) logger.debug("%s > %r", self.side, frame) frame.write( - self.writer.write, mask=self.is_client, extensions=self.extensions + self.transport.write, mask=self.is_client, extensions=self.extensions ) # Start close_connection_task if the opening handshake didn't succeed. @@ -1289,7 +1321,7 @@ def abort_pings(self) -> None: "%s - aborted pending ping%s: %s", self.side, plural, pings_hex ) - # asyncio.StreamReaderProtocol methods + # asyncio.Protocol methods def connection_made(self, transport: asyncio.BaseTransport) -> None: """ @@ -1306,36 +1338,13 @@ def connection_made(self, transport: asyncio.BaseTransport) -> None: """ logger.debug("%s - event = connection_made(%s)", self.side, transport) - # mypy thinks transport is a BaseTransport, not a Transport. - transport.set_write_buffer_limits(self.write_limit) # type: ignore - super().connection_made(transport) - - def eof_received(self) -> bool: - """ - Close the transport after receiving EOF. - - Since Python 3.5, `:meth:~StreamReaderProtocol.eof_received` returns - ``True`` on non-TLS connections. - - See http://bugs.python.org/issue24539 for more information. - - This is inappropriate for ``websockets`` for at least three reasons: - - 1. The use case is to read data until EOF with self.reader.read(-1). - Since WebSocket is a TLV protocol, this never happens. - 2. It doesn't work on TLS connections. A falsy value must be - returned to have the same behavior on TLS and plain connections. + transport = cast(asyncio.Transport, transport) + transport.set_write_buffer_limits(self.write_limit) + self.transport = transport - 3. The WebSocket protocol has its own closing handshake. Endpoints - close the TCP connection after sending a close frame. - - As a consequence we revert to the previous, more useful behavior. - - """ - logger.debug("%s - event = eof_received()", self.side) - super().eof_received() - return False + # Copied from asyncio.StreamReaderProtocol + self.reader.set_transport(transport) def connection_lost(self, exc: Optional[Exception]) -> None: """ @@ -1360,4 +1369,61 @@ def connection_lost(self, exc: Optional[Exception]) -> None: # - it's set only here in connection_lost() which is called only once; # - it must never be canceled. self.connection_lost_waiter.set_result(None) - super().connection_lost(exc) + + if True: # pragma: no cover + + # Copied from asyncio.StreamReaderProtocol + if self.reader is not None: + if exc is None: + self.reader.feed_eof() + else: + self.reader.set_exception(exc) + + # Copied from asyncio.FlowControlMixin + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + def pause_writing(self) -> None: # pragma: no cover + assert not self._paused + self._paused = True + + def resume_writing(self) -> None: # pragma: no cover + assert self._paused + self._paused = False + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def data_received(self, data: bytes) -> None: + logger.debug("%s - event = data_received(<%d bytes>)", self.side, len(data)) + self.reader.feed_data(data) + + def eof_received(self) -> None: + """ + Close the transport after receiving EOF. + + The WebSocket protocol has its own closing handshake: endpoints close + the TCP or TLS connection after sending and receiving a close frame. + + As a consequence, they never need to write after receiving EOF, so + there's no reason to keep the transport open by returning ``True``. + + Besides, that doesn't work on TLS connections. + + """ + logger.debug("%s - event = eof_received()", self.side) + self.reader.feed_eof() diff --git a/src/websockets/server.py b/src/websockets/server.py index b220a1b88..4f5e9e0ef 100644 --- a/src/websockets/server.py +++ b/src/websockets/server.py @@ -10,6 +10,7 @@ import http import logging import socket +import sys import warnings from types import TracebackType from typing import ( @@ -39,15 +40,10 @@ from .extensions.base import Extension, ServerExtensionFactory from .extensions.permessage_deflate import ServerPerMessageDeflateFactory from .handshake import build_response, check_request -from .headers import ( - ExtensionHeader, - build_extension, - parse_extension, - parse_subprotocol, -) +from .headers import build_extension, parse_extension, parse_subprotocol from .http import USER_AGENT, Headers, HeadersLike, MultipleValuesError, read_request from .protocol import WebSocketCommonProtocol -from .typing import Origin, Subprotocol +from .typing import ExtensionHeader, Origin, Subprotocol __all__ = ["serve", "unix_serve", "WebSocketServerProtocol", "WebSocketServer"] @@ -211,7 +207,7 @@ async def handler(self) -> None: except Exception: # Last-ditch attempt to avoid leaking connections on errors. try: - self.writer.close() + self.transport.close() except Exception: # pragma: no cover pass @@ -265,11 +261,11 @@ def write_http_response( response = f"HTTP/1.1 {status.value} {status.phrase}\r\n" response += str(headers) - self.writer.write(response.encode()) + self.transport.write(response.encode()) if body is not None: logger.debug("%s > body (%d bytes)", self.side, len(body)) - self.writer.write(body) + self.transport.write(body) async def process_request( self, path: str, request_headers: Headers @@ -662,7 +658,7 @@ def is_serving(self) -> bool: """ try: # Python ≥ 3.7 - return self.server.is_serving() # type: ignore + return self.server.is_serving() except AttributeError: # pragma: no cover # Python < 3.7 return self.server.sockets is not None @@ -703,7 +699,9 @@ async def _close(self) -> None: # Wait until all accepted connections reach connection_made() and call # register(). See https://bugs.python.org/issue34852 for details. - await asyncio.sleep(0) + await asyncio.sleep( + 0, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) # Close OPEN connections with status code 1001. Since the server was # closed, handshake() closes OPENING conections with a HTTP 503 error. @@ -712,7 +710,8 @@ async def _close(self) -> None: # asyncio.wait doesn't accept an empty first argument if self.websockets: await asyncio.wait( - [websocket.close(1001) for websocket in self.websockets], loop=self.loop + [websocket.close(1001) for websocket in self.websockets], + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) # Wait until all connection handlers are complete. @@ -721,7 +720,7 @@ async def _close(self) -> None: if self.websockets: await asyncio.wait( [websocket.handler_task for websocket in self.websockets], - loop=self.loop, + loop=self.loop if sys.version_info[:2] < (3, 8) else None, ) # Tell wait_closed() to return. diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb..dd78609f5 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,5 @@ +import logging + + +# Avoid displaying stack traces at the ERROR logging level. +logging.basicConfig(level=logging.CRITICAL) diff --git a/tests/test_client_server.py b/tests/test_client_server.py index e74ec6bf6..35913666c 100644 --- a/tests/test_client_server.py +++ b/tests/test_client_server.py @@ -2,7 +2,6 @@ import contextlib import functools import http -import logging import pathlib import random import socket @@ -37,10 +36,6 @@ from .utils import AsyncioTestCase -# Avoid displaying stack traces at the ERROR logging level. -logging.basicConfig(level=logging.CRITICAL) - - # Generate TLS certificate with: # $ openssl req -x509 -config test_localhost.cnf -days 15340 -newkey rsa:2048 \ # -out test_localhost.crt -keyout test_localhost.key @@ -178,7 +173,7 @@ async def process_request(self, path, request_headers): return http.HTTPStatus.OK, [("X-Access", "OK")], b"status = green\n" -class SlowServerProtocol(WebSocketServerProtocol): +class SlowOpeningHandshakeProtocol(WebSocketServerProtocol): async def process_request(self, path, request_headers): await asyncio.sleep(10 * MS) @@ -1166,11 +1161,11 @@ def test_server_close_crashes(self, close): def test_client_closes_connection_before_handshake(self, handshake): # We have mocked the handshake() method to prevent the client from # performing the opening handshake. Force it to close the connection. - self.client.writer.close() + self.client.transport.close() # The server should stop properly anyway. It used to hang because the # task handling the connection was waiting for the opening handshake. - @with_server(create_protocol=SlowServerProtocol) + @with_server(create_protocol=SlowOpeningHandshakeProtocol) def test_server_shuts_down_during_opening_handshake(self): self.loop.call_later(5 * MS, self.server.close) with self.assertRaises(InvalidStatusCode) as raised: @@ -1193,20 +1188,6 @@ def test_server_shuts_down_during_connection_handling(self): self.assertEqual(self.client.close_code, 1001) self.assertEqual(server_ws.close_code, 1001) - @with_server() - @unittest.mock.patch("websockets.server.WebSocketServerProtocol.close") - def test_server_shuts_down_during_connection_close(self, _close): - _close.side_effect = asyncio.CancelledError - - self.server.closing = True - with self.temp_client(): - self.loop.run_until_complete(self.client.send("Hello!")) - reply = self.loop.run_until_complete(self.client.recv()) - self.assertEqual(reply, "Hello!") - - # Websocket connection terminates abnormally. - self.assertEqual(self.client.close_code, 1006) - @with_server() def test_server_shuts_down_waits_until_handlers_terminate(self): # This handler waits a bit after the connection is closed in order @@ -1381,13 +1362,16 @@ def test_client(self): start_server = serve(handler, "localhost", 0) server = self.loop.run_until_complete(start_server) - @asyncio.coroutine - def run_client(): - # Yield from connect. - client = yield from connect(get_server_uri(server)) - self.assertEqual(client.state, State.OPEN) - yield from client.close() - self.assertEqual(client.state, State.CLOSED) + # @asyncio.coroutine is deprecated on Python ≥ 3.8 + with warnings.catch_warnings(record=True): + + @asyncio.coroutine + def run_client(): + # Yield from connect. + client = yield from connect(get_server_uri(server)) + self.assertEqual(client.state, State.OPEN) + yield from client.close() + self.assertEqual(client.state, State.CLOSED) self.loop.run_until_complete(run_client()) @@ -1395,14 +1379,17 @@ def run_client(): self.loop.run_until_complete(server.wait_closed()) def test_server(self): - @asyncio.coroutine - def run_server(): - # Yield from serve. - server = yield from serve(handler, "localhost", 0) - self.assertTrue(server.sockets) - server.close() - yield from server.wait_closed() - self.assertFalse(server.sockets) + # @asyncio.coroutine is deprecated on Python ≥ 3.8 + with warnings.catch_warnings(record=True): + + @asyncio.coroutine + def run_server(): + # Yield from serve. + server = yield from serve(handler, "localhost", 0) + self.assertTrue(server.sockets) + server.close() + yield from server.wait_closed() + self.assertFalse(server.sockets) self.loop.run_until_complete(run_server()) diff --git a/tests/test_framing.py b/tests/test_framing.py index 9e6f1871d..5def415d2 100644 --- a/tests/test_framing.py +++ b/tests/test_framing.py @@ -27,15 +27,15 @@ def decode(self, message, mask=False, max_size=None, extensions=None): return frame def encode(self, frame, mask=False, extensions=None): - writer = unittest.mock.Mock() - frame.write(writer, mask=mask, extensions=extensions) - # Ensure the entire frame is sent with a single call to writer(). + write = unittest.mock.Mock() + frame.write(write, mask=mask, extensions=extensions) + # Ensure the entire frame is sent with a single call to write(). # Multiple calls cause TCP fragmentation and degrade performance. - self.assertEqual(writer.call_count, 1) + self.assertEqual(write.call_count, 1) # The frame data is the single positional argument of that call. - self.assertEqual(len(writer.call_args[0]), 1) - self.assertEqual(len(writer.call_args[1]), 0) - return writer.call_args[0][0] + self.assertEqual(len(write.call_args[0]), 1) + self.assertEqual(len(write.call_args[1]), 0) + return write.call_args[0][0] def round_trip(self, message, expected, mask=False, extensions=None): decoded = self.decode(message, mask, extensions=extensions) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index a6c420181..66a822e79 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,6 +1,6 @@ import asyncio import contextlib -import logging +import sys import unittest import unittest.mock import warnings @@ -12,10 +12,6 @@ from .utils import MS, AsyncioTestCase -# Avoid displaying stack traces at the ERROR logging level. -logging.basicConfig(level=logging.CRITICAL) - - async def async_iterable(iterable): for item in iterable: yield item @@ -94,16 +90,18 @@ def tearDown(self): # Utilities for writing tests. def make_drain_slow(self, delay=MS): - # Process connection_made in order to initialize self.protocol.writer. + # Process connection_made in order to initialize self.protocol.transport. self.run_loop_once() - original_drain = self.protocol.writer.drain + original_drain = self.protocol._drain async def delayed_drain(): - await asyncio.sleep(delay, loop=self.loop) + await asyncio.sleep( + delay, loop=self.loop if sys.version_info[:2] < (3, 8) else None + ) await original_drain() - self.protocol.writer.drain = delayed_drain + self.protocol._drain = delayed_drain close_frame = Frame(True, OP_CLOSE, serialize_close(1000, "close")) local_close = Frame(True, OP_CLOSE, serialize_close(1000, "local")) @@ -114,9 +112,9 @@ def receive_frame(self, frame): Make the protocol receive a frame. """ - writer = self.protocol.data_received + write = self.protocol.data_received mask = not self.protocol.is_client - frame.write(writer, mask=mask) + frame.write(write, mask=mask) def receive_eof(self): """ @@ -321,32 +319,32 @@ def test_local_address(self): self.transport.get_extra_info = get_extra_info self.assertEqual(self.protocol.local_address, ("host", 4312)) - get_extra_info.assert_called_with("sockname", None) + get_extra_info.assert_called_with("sockname") def test_local_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.writer, _writer = None, self.protocol.writer - + _transport = self.protocol.transport + del self.protocol.transport try: self.assertEqual(self.protocol.local_address, None) finally: - self.protocol.writer = _writer + self.protocol.transport = _transport def test_remote_address(self): get_extra_info = unittest.mock.Mock(return_value=("host", 4312)) self.transport.get_extra_info = get_extra_info self.assertEqual(self.protocol.remote_address, ("host", 4312)) - get_extra_info.assert_called_with("peername", None) + get_extra_info.assert_called_with("peername") def test_remote_address_before_connection(self): # Emulate the situation before connection_open() runs. - self.protocol.writer, _writer = None, self.protocol.writer - + _transport = self.protocol.transport + del self.protocol.transport try: self.assertEqual(self.protocol.remote_address, None) finally: - self.protocol.writer = _writer + self.protocol.transport = _transport def test_open(self): self.assertTrue(self.protocol.open) diff --git a/tox.ini b/tox.ini index 801d4d5d1..825e34061 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36,py37,coverage,black,flake8,isort,mypy +envlist = py36,py37,py38,coverage,black,flake8,isort,mypy [testenv] commands = python -W default -m unittest {posargs} @@ -25,4 +25,4 @@ deps = isort [testenv:mypy] commands = mypy --strict src -deps = mypy==0.670 +deps = mypy