diff --git a/.gitignore b/.gitignore index db3d538..a06dc67 100755 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,10 @@ _build .idea .vscode *~ + +# tox-specific files +.tox +build + +# coverage-specific files +.coverage diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 70ade69..e2c8831 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -39,4 +39,4 @@ repos: types: [python] files: "^tests/" args: - - --disable=missing-docstring,consider-using-f-string,duplicate-code + - --disable=missing-docstring,invalid-name,consider-using-f-string,duplicate-code diff --git a/adafruit_requests.py b/adafruit_requests.py index 26b40f0..e152ba5 100644 --- a/adafruit_requests.py +++ b/adafruit_requests.py @@ -31,6 +31,9 @@ * Adafruit CircuitPython firmware for the supported boards: https://github.com/adafruit/circuitpython/releases +* Adafruit's Connection Manager library: + https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager + """ __version__ = "0.0.0+auto.0" @@ -41,93 +44,16 @@ import json as json_module -if not sys.implementation.name == "circuitpython": - from ssl import SSLContext - from types import ModuleType, TracebackType - from typing import Any, Dict, Optional, Tuple, Type, Union - - try: - from typing import Protocol - except ImportError: - from typing_extensions import Protocol - - # Based on https://github.com/python/typeshed/blob/master/stdlib/_socket.pyi - class CommonSocketType(Protocol): - """Describes the common structure every socket type must have.""" - - def send(self, data: bytes, flags: int = ...) -> None: - """Send data to the socket. The meaning of the optional flags kwarg is - implementation-specific.""" - - def settimeout(self, value: Optional[float]) -> None: - """Set a timeout on blocking socket operations.""" - - def close(self) -> None: - """Close the socket.""" - - class CommonCircuitPythonSocketType(CommonSocketType, Protocol): - """Describes the common structure every CircuitPython socket type must have.""" - - def connect( - self, - address: Tuple[str, int], - conntype: Optional[int] = ..., - ) -> None: - """Connect to a remote socket at the provided (host, port) address. The conntype - kwarg optionally may indicate SSL or not, depending on the underlying interface. - """ - - class SupportsRecvWithFlags(Protocol): - """Describes a type that posseses a socket recv() method supporting the flags kwarg.""" - - def recv(self, bufsize: int = ..., flags: int = ...) -> bytes: - """Receive data from the socket. The return value is a bytes object representing - the data received. The maximum amount of data to be received at once is specified - by bufsize. The meaning of the optional flags kwarg is implementation-specific. - """ - - class SupportsRecvInto(Protocol): - """Describes a type that possesses a socket recv_into() method.""" - - def recv_into( - self, buffer: bytearray, nbytes: int = ..., flags: int = ... - ) -> int: - """Receive up to nbytes bytes from the socket, storing the data into the provided - buffer. If nbytes is not specified (or 0), receive up to the size available in the - given buffer. The meaning of the optional flags kwarg is implementation-specific. - Returns the number of bytes received.""" - - class CircuitPythonSocketType( - CommonCircuitPythonSocketType, - SupportsRecvInto, - SupportsRecvWithFlags, - Protocol, - ): # pylint: disable=too-many-ancestors - """Describes the structure every modern CircuitPython socket type must have.""" - - class StandardPythonSocketType( - CommonSocketType, SupportsRecvInto, SupportsRecvWithFlags, Protocol - ): - """Describes the structure every standard Python socket type must have.""" - - def connect(self, address: Union[Tuple[Any, ...], str, bytes]) -> None: - """Connect to a remote socket at the provided address.""" - - SocketType = Union[ - CircuitPythonSocketType, - StandardPythonSocketType, - ] - - SocketpoolModuleType = ModuleType +from adafruit_connection_manager import get_connection_manager - class InterfaceType(Protocol): - """Describes the structure every interface type must have.""" - - @property - def TLS_MODE(self) -> int: # pylint: disable=invalid-name - """Constant representing that a socket's connection mode is TLS.""" - - SSLContextType = Union[SSLContext, "_FakeSSLContext"] +if not sys.implementation.name == "circuitpython": + from types import TracebackType + from typing import Any, Dict, Optional, Type + from adafruit_connection_manager import ( + SocketType, + SocketpoolModuleType, + SSLContextType, + ) class _RawResponse: @@ -176,7 +102,7 @@ def __init__(self, sock: SocketType, session: Optional["Session"] = None) -> Non http = self._readto(b" ") if not http: if session: - session._close_socket(self.socket) + session._connection_manager.close_socket(self.socket) else: self.socket.close() raise RuntimeError("Unable to read HTTP response.") @@ -320,7 +246,8 @@ def close(self) -> None: self._throw_away(chunk_size + 2) self._parse_headers() if self._session: - self._session._free_socket(self.socket) # pylint: disable=protected-access + # pylint: disable=protected-access + self._session._connection_manager.free_socket(self.socket) else: self.socket.close() self.socket = None @@ -429,6 +356,9 @@ def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False) -> byt self.close() +_global_session = None # pylint: disable=invalid-name + + class Session: """HTTP session that shares sockets and ssl context.""" @@ -436,98 +366,16 @@ def __init__( self, socket_pool: SocketpoolModuleType, ssl_context: Optional[SSLContextType] = None, + set_global_session: bool = True, ) -> None: - self._socket_pool = socket_pool + self._connection_manager = get_connection_manager(socket_pool) self._ssl_context = ssl_context - # Hang onto open sockets so that we can reuse them. - self._open_sockets = {} - self._socket_free = {} self._last_response = None - def _free_socket(self, socket: SocketType) -> None: - if socket not in self._open_sockets.values(): - raise RuntimeError("Socket not from session") - self._socket_free[socket] = True - - def _close_socket(self, sock: SocketType) -> None: - sock.close() - del self._socket_free[sock] - key = None - for k in self._open_sockets: # pylint: disable=consider-using-dict-items - if self._open_sockets[k] == sock: - key = k - break - if key: - del self._open_sockets[key] - - def _free_sockets(self) -> None: - free_sockets = [] - for sock, val in self._socket_free.items(): - if val: - free_sockets.append(sock) - for sock in free_sockets: - self._close_socket(sock) - - def _get_socket( - self, host: str, port: int, proto: str, *, timeout: float = 1 - ) -> CircuitPythonSocketType: - # pylint: disable=too-many-branches - key = (host, port, proto) - if key in self._open_sockets: - sock = self._open_sockets[key] - if self._socket_free[sock]: - self._socket_free[sock] = False - return sock - if proto == "https:" and not self._ssl_context: - raise RuntimeError( - "ssl_context must be set before using adafruit_requests for https" - ) - addr_info = self._socket_pool.getaddrinfo( - host, port, 0, self._socket_pool.SOCK_STREAM - )[0] - retry_count = 0 - sock = None - last_exc = None - while retry_count < 5 and sock is None: - if retry_count > 0: - if any(self._socket_free.items()): - self._free_sockets() - else: - raise RuntimeError("Sending request failed") from last_exc - retry_count += 1 - - try: - sock = self._socket_pool.socket(addr_info[0], addr_info[1]) - except OSError as exc: - last_exc = exc - continue - except RuntimeError as exc: - last_exc = exc - continue - - connect_host = addr_info[-1][0] - if proto == "https:": - sock = self._ssl_context.wrap_socket(sock, server_hostname=host) - connect_host = host - sock.settimeout(timeout) # socket read timeout - - try: - sock.connect((connect_host, port)) - except MemoryError as exc: - last_exc = exc - sock.close() - sock = None - except OSError as exc: - last_exc = exc - sock.close() - sock = None - - if sock is None: - raise RuntimeError("Repeated socket failures") from last_exc - - self._open_sockets[key] = sock - self._socket_free[sock] = False - return sock + if set_global_session: + # pylint: disable=global-statement + global _global_session + _global_session = self @staticmethod def _send(socket: SocketType, data: bytes): @@ -668,7 +516,9 @@ def request( last_exc = None while retry_count < 2: retry_count += 1 - socket = self._get_socket(host, port, proto, timeout=timeout) + socket = self._connection_manager.get_socket( + host, port, proto, timeout=timeout, ssl_context=self._ssl_context + ) ok = True try: self._send_request(socket, host, method, path, headers, data, json) @@ -689,7 +539,7 @@ def request( if result == b"H": # Things seem to be ok so break with socket set. break - self._close_socket(socket) + self._connection_manager.close_socket(socket) socket = None if not socket: @@ -748,54 +598,6 @@ def delete(self, url: str, **kw) -> Response: return self.request("DELETE", url, **kw) -# Backwards compatible API: - -_default_session = None # pylint: disable=invalid-name - - -class _FakeSSLSocket: - def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None: - self._socket = socket - self._mode = tls_mode - self.settimeout = socket.settimeout - self.send = socket.send - self.recv = socket.recv - self.close = socket.close - self.recv_into = socket.recv_into - - def connect(self, address: Tuple[str, int]) -> None: - """connect wrapper to add non-standard mode parameter""" - try: - return self._socket.connect(address, self._mode) - except RuntimeError as error: - raise OSError(errno.ENOMEM) from error - - -class _FakeSSLContext: - def __init__(self, iface: InterfaceType) -> None: - self._iface = iface - - def wrap_socket( - self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None - ) -> _FakeSSLSocket: - """Return the same socket""" - # pylint: disable=unused-argument - return _FakeSSLSocket(socket, self._iface.TLS_MODE) - - -def set_socket( - sock: SocketpoolModuleType, iface: Optional[InterfaceType] = None -) -> None: - """Legacy API for setting the socket and network interface. Use a `Session` instead.""" - global _default_session # pylint: disable=global-statement,invalid-name - if not iface: - # pylint: disable=protected-access - _default_session = Session(sock, _FakeSSLContext(sock._the_interface)) - else: - _default_session = Session(sock, _FakeSSLContext(iface)) - sock.set_interface(iface) - - def request( method: str, url: str, @@ -807,7 +609,7 @@ def request( ) -> None: """Send HTTP request""" # pylint: disable=too-many-arguments - _default_session.request( + _global_session.request( method, url, data=data, @@ -820,29 +622,29 @@ def request( def head(url: str, **kw): """Send HTTP HEAD request""" - return _default_session.request("HEAD", url, **kw) + return _global_session.request("HEAD", url, **kw) def get(url: str, **kw): """Send HTTP GET request""" - return _default_session.request("GET", url, **kw) + return _global_session.request("GET", url, **kw) def post(url: str, **kw): """Send HTTP POST request""" - return _default_session.request("POST", url, **kw) + return _global_session.request("POST", url, **kw) def put(url: str, **kw): """Send HTTP PUT request""" - return _default_session.request("PUT", url, **kw) + return _global_session.request("PUT", url, **kw) def patch(url: str, **kw): """Send HTTP PATCH request""" - return _default_session.request("PATCH", url, **kw) + return _global_session.request("PATCH", url, **kw) def delete(url: str, **kw): """Send HTTP DELETE request""" - return _default_session.request("DELETE", url, **kw) + return _global_session.request("DELETE", url, **kw) diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..376dd7d --- /dev/null +++ b/conftest.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2023 Justin Myers for Adafruit Industries +# +# SPDX-License-Identifier: Unlicense + +""" PyTest Setup """ + +import pytest +import adafruit_connection_manager + + +@pytest.fixture(autouse=True) +def reset_connection_manager(monkeypatch): + """Reset the ConnectionManager, since it's a singlton and will hold data""" + monkeypatch.setattr( + "adafruit_requests.get_connection_manager", + adafruit_connection_manager.ConnectionManager, + ) diff --git a/requirements.txt b/requirements.txt index 7a984a4..d83a678 100755 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Unlicense Adafruit-Blinka +Adafruit-Circuitpython-ConnectionManager@git+https://github.com/justmobilize/Adafruit_CircuitPython_ConnectionManager@connection-manager diff --git a/tests/concurrent_test.py b/tests/concurrent_test.py index 79a32c5..ec972ef 100644 --- a/tests/concurrent_test.py +++ b/tests/concurrent_test.py @@ -17,7 +17,7 @@ RESPONSE = b"HTTP/1.0 200 OK\r\nContent-Length: 70\r\n\r\n" + TEXT -def test_second_connect_fails_memoryerror(): # pylint: disable=invalid-name +def test_second_connect_fails_memoryerror(): pool = mocket.MocketPool() pool.getaddrinfo.return_value = ((None, None, None, None, (IP, 80)),) sock = mocket.Mocket(RESPONSE) @@ -60,7 +60,7 @@ def test_second_connect_fails_memoryerror(): # pylint: disable=invalid-name assert pool.socket.call_count == 3 -def test_second_connect_fails_oserror(): # pylint: disable=invalid-name +def test_second_connect_fails_oserror(): pool = mocket.MocketPool() pool.getaddrinfo.return_value = ((None, None, None, None, (IP, 80)),) sock = mocket.Mocket(RESPONSE) diff --git a/tests/reuse_test.py b/tests/reuse_test.py index b768a58..b778c0a 100644 --- a/tests/reuse_test.py +++ b/tests/reuse_test.py @@ -209,7 +209,7 @@ def test_second_send_fails(): assert pool.socket.call_count == 2 -def test_second_send_lies_recv_fails(): # pylint: disable=invalid-name +def test_second_send_lies_recv_fails(): pool = mocket.MocketPool() pool.getaddrinfo.return_value = ((None, None, None, None, (IP, 80)),) sock = mocket.Mocket(RESPONSE) diff --git a/tox.ini b/tox.ini index ab2df5e..9d1910d 100644 --- a/tox.ini +++ b/tox.ini @@ -3,9 +3,36 @@ # SPDX-License-Identifier: MIT [tox] -envlist = py38 +envlist = py311 [testenv] -changedir = {toxinidir}/tests -deps = pytest==6.2.5 +description = run tests +deps = + pytest==7.4.3 commands = pytest + +[testenv:coverage] +description = run coverage +deps = + pytest==7.4.3 + pytest-cov==4.1.0 +package = editable +commands = + coverage run --source=. --omit=tests/* --branch {posargs} -m pytest + coverage report + coverage html + +[testenv:lint] +description = run linters +deps = + pre-commit==3.6.0 +skip_install = true +commands = pre-commit run {posargs} + +[testenv:docs] +description = build docs +deps = + -r requirements.txt + -r docs/requirements.txt +skip_install = true +commands = sphinx-build -E -W -b html docs/. _build/html