Skip to content

Switch to using ConnectionManager #147

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,10 @@ _build
.idea
.vscode
*~

# tox-specific files
.tox
build

# coverage-specific files
.coverage
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ repos:
types: [python]
files: "^tests/"
args:
- --disable=missing-docstring,consider-using-f-string,duplicate-code
- --disable=missing-docstring,invalid-name,consider-using-f-string,duplicate-code
268 changes: 35 additions & 233 deletions adafruit_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
* Adafruit CircuitPython firmware for the supported boards:
https://github.com/adafruit/circuitpython/releases

* Adafruit's Connection Manager library:
https://github.com/adafruit/Adafruit_CircuitPython_ConnectionManager

"""

__version__ = "0.0.0+auto.0"
Expand All @@ -41,93 +44,16 @@

import json as json_module

if not sys.implementation.name == "circuitpython":
from ssl import SSLContext
from types import ModuleType, TracebackType
from typing import Any, Dict, Optional, Tuple, Type, Union

try:
from typing import Protocol
except ImportError:
from typing_extensions import Protocol

# Based on https://github.com/python/typeshed/blob/master/stdlib/_socket.pyi
class CommonSocketType(Protocol):
"""Describes the common structure every socket type must have."""

def send(self, data: bytes, flags: int = ...) -> None:
"""Send data to the socket. The meaning of the optional flags kwarg is
implementation-specific."""

def settimeout(self, value: Optional[float]) -> None:
"""Set a timeout on blocking socket operations."""

def close(self) -> None:
"""Close the socket."""

class CommonCircuitPythonSocketType(CommonSocketType, Protocol):
"""Describes the common structure every CircuitPython socket type must have."""

def connect(
self,
address: Tuple[str, int],
conntype: Optional[int] = ...,
) -> None:
"""Connect to a remote socket at the provided (host, port) address. The conntype
kwarg optionally may indicate SSL or not, depending on the underlying interface.
"""

class SupportsRecvWithFlags(Protocol):
"""Describes a type that posseses a socket recv() method supporting the flags kwarg."""

def recv(self, bufsize: int = ..., flags: int = ...) -> bytes:
"""Receive data from the socket. The return value is a bytes object representing
the data received. The maximum amount of data to be received at once is specified
by bufsize. The meaning of the optional flags kwarg is implementation-specific.
"""

class SupportsRecvInto(Protocol):
"""Describes a type that possesses a socket recv_into() method."""

def recv_into(
self, buffer: bytearray, nbytes: int = ..., flags: int = ...
) -> int:
"""Receive up to nbytes bytes from the socket, storing the data into the provided
buffer. If nbytes is not specified (or 0), receive up to the size available in the
given buffer. The meaning of the optional flags kwarg is implementation-specific.
Returns the number of bytes received."""

class CircuitPythonSocketType(
CommonCircuitPythonSocketType,
SupportsRecvInto,
SupportsRecvWithFlags,
Protocol,
): # pylint: disable=too-many-ancestors
"""Describes the structure every modern CircuitPython socket type must have."""

class StandardPythonSocketType(
CommonSocketType, SupportsRecvInto, SupportsRecvWithFlags, Protocol
):
"""Describes the structure every standard Python socket type must have."""

def connect(self, address: Union[Tuple[Any, ...], str, bytes]) -> None:
"""Connect to a remote socket at the provided address."""

SocketType = Union[
CircuitPythonSocketType,
StandardPythonSocketType,
]

SocketpoolModuleType = ModuleType
from adafruit_connection_manager import get_connection_manager

class InterfaceType(Protocol):
"""Describes the structure every interface type must have."""

@property
def TLS_MODE(self) -> int: # pylint: disable=invalid-name
"""Constant representing that a socket's connection mode is TLS."""

SSLContextType = Union[SSLContext, "_FakeSSLContext"]
if not sys.implementation.name == "circuitpython":
from types import TracebackType
from typing import Any, Dict, Optional, Type
from adafruit_connection_manager import (
SocketType,
SocketpoolModuleType,
SSLContextType,
)


class _RawResponse:
Expand Down Expand Up @@ -176,7 +102,7 @@ def __init__(self, sock: SocketType, session: Optional["Session"] = None) -> Non
http = self._readto(b" ")
if not http:
if session:
session._close_socket(self.socket)
session._connection_manager.close_socket(self.socket)
else:
self.socket.close()
raise RuntimeError("Unable to read HTTP response.")
Expand Down Expand Up @@ -320,7 +246,8 @@ def close(self) -> None:
self._throw_away(chunk_size + 2)
self._parse_headers()
if self._session:
self._session._free_socket(self.socket) # pylint: disable=protected-access
# pylint: disable=protected-access
self._session._connection_manager.free_socket(self.socket)
else:
self.socket.close()
self.socket = None
Expand Down Expand Up @@ -429,105 +356,26 @@ def iter_content(self, chunk_size: int = 1, decode_unicode: bool = False) -> byt
self.close()


_global_session = None # pylint: disable=invalid-name


class Session:
"""HTTP session that shares sockets and ssl context."""

def __init__(
self,
socket_pool: SocketpoolModuleType,
ssl_context: Optional[SSLContextType] = None,
set_global_session: bool = True,
) -> None:
self._socket_pool = socket_pool
self._connection_manager = get_connection_manager(socket_pool)
self._ssl_context = ssl_context
# Hang onto open sockets so that we can reuse them.
self._open_sockets = {}
self._socket_free = {}
self._last_response = None

def _free_socket(self, socket: SocketType) -> None:
if socket not in self._open_sockets.values():
raise RuntimeError("Socket not from session")
self._socket_free[socket] = True

def _close_socket(self, sock: SocketType) -> None:
sock.close()
del self._socket_free[sock]
key = None
for k in self._open_sockets: # pylint: disable=consider-using-dict-items
if self._open_sockets[k] == sock:
key = k
break
if key:
del self._open_sockets[key]

def _free_sockets(self) -> None:
free_sockets = []
for sock, val in self._socket_free.items():
if val:
free_sockets.append(sock)
for sock in free_sockets:
self._close_socket(sock)

def _get_socket(
self, host: str, port: int, proto: str, *, timeout: float = 1
) -> CircuitPythonSocketType:
# pylint: disable=too-many-branches
key = (host, port, proto)
if key in self._open_sockets:
sock = self._open_sockets[key]
if self._socket_free[sock]:
self._socket_free[sock] = False
return sock
if proto == "https:" and not self._ssl_context:
raise RuntimeError(
"ssl_context must be set before using adafruit_requests for https"
)
addr_info = self._socket_pool.getaddrinfo(
host, port, 0, self._socket_pool.SOCK_STREAM
)[0]
retry_count = 0
sock = None
last_exc = None
while retry_count < 5 and sock is None:
if retry_count > 0:
if any(self._socket_free.items()):
self._free_sockets()
else:
raise RuntimeError("Sending request failed") from last_exc
retry_count += 1

try:
sock = self._socket_pool.socket(addr_info[0], addr_info[1])
except OSError as exc:
last_exc = exc
continue
except RuntimeError as exc:
last_exc = exc
continue

connect_host = addr_info[-1][0]
if proto == "https:":
sock = self._ssl_context.wrap_socket(sock, server_hostname=host)
connect_host = host
sock.settimeout(timeout) # socket read timeout

try:
sock.connect((connect_host, port))
except MemoryError as exc:
last_exc = exc
sock.close()
sock = None
except OSError as exc:
last_exc = exc
sock.close()
sock = None

if sock is None:
raise RuntimeError("Repeated socket failures") from last_exc

self._open_sockets[key] = sock
self._socket_free[sock] = False
return sock
if set_global_session:
# pylint: disable=global-statement
global _global_session
_global_session = self

@staticmethod
def _send(socket: SocketType, data: bytes):
Expand Down Expand Up @@ -668,7 +516,9 @@ def request(
last_exc = None
while retry_count < 2:
retry_count += 1
socket = self._get_socket(host, port, proto, timeout=timeout)
socket = self._connection_manager.get_socket(
host, port, proto, timeout=timeout, ssl_context=self._ssl_context
)
ok = True
try:
self._send_request(socket, host, method, path, headers, data, json)
Expand All @@ -689,7 +539,7 @@ def request(
if result == b"H":
# Things seem to be ok so break with socket set.
break
self._close_socket(socket)
self._connection_manager.close_socket(socket)
socket = None

if not socket:
Expand Down Expand Up @@ -748,54 +598,6 @@ def delete(self, url: str, **kw) -> Response:
return self.request("DELETE", url, **kw)


# Backwards compatible API:

_default_session = None # pylint: disable=invalid-name


class _FakeSSLSocket:
def __init__(self, socket: CircuitPythonSocketType, tls_mode: int) -> None:
self._socket = socket
self._mode = tls_mode
self.settimeout = socket.settimeout
self.send = socket.send
self.recv = socket.recv
self.close = socket.close
self.recv_into = socket.recv_into

def connect(self, address: Tuple[str, int]) -> None:
"""connect wrapper to add non-standard mode parameter"""
try:
return self._socket.connect(address, self._mode)
except RuntimeError as error:
raise OSError(errno.ENOMEM) from error


class _FakeSSLContext:
def __init__(self, iface: InterfaceType) -> None:
self._iface = iface

def wrap_socket(
self, socket: CircuitPythonSocketType, server_hostname: Optional[str] = None
) -> _FakeSSLSocket:
"""Return the same socket"""
# pylint: disable=unused-argument
return _FakeSSLSocket(socket, self._iface.TLS_MODE)


def set_socket(
sock: SocketpoolModuleType, iface: Optional[InterfaceType] = None
) -> None:
"""Legacy API for setting the socket and network interface. Use a `Session` instead."""
global _default_session # pylint: disable=global-statement,invalid-name
if not iface:
# pylint: disable=protected-access
_default_session = Session(sock, _FakeSSLContext(sock._the_interface))
else:
_default_session = Session(sock, _FakeSSLContext(iface))
sock.set_interface(iface)


def request(
method: str,
url: str,
Expand All @@ -807,7 +609,7 @@ def request(
) -> None:
"""Send HTTP request"""
# pylint: disable=too-many-arguments
_default_session.request(
_global_session.request(
method,
url,
data=data,
Expand All @@ -820,29 +622,29 @@ def request(

def head(url: str, **kw):
"""Send HTTP HEAD request"""
return _default_session.request("HEAD", url, **kw)
return _global_session.request("HEAD", url, **kw)


def get(url: str, **kw):
"""Send HTTP GET request"""
return _default_session.request("GET", url, **kw)
return _global_session.request("GET", url, **kw)


def post(url: str, **kw):
"""Send HTTP POST request"""
return _default_session.request("POST", url, **kw)
return _global_session.request("POST", url, **kw)


def put(url: str, **kw):
"""Send HTTP PUT request"""
return _default_session.request("PUT", url, **kw)
return _global_session.request("PUT", url, **kw)


def patch(url: str, **kw):
"""Send HTTP PATCH request"""
return _default_session.request("PATCH", url, **kw)
return _global_session.request("PATCH", url, **kw)


def delete(url: str, **kw):
"""Send HTTP DELETE request"""
return _default_session.request("DELETE", url, **kw)
return _global_session.request("DELETE", url, **kw)
Loading