diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index c1cc1d310c..00865515ab 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -2,11 +2,9 @@
 import copy
 import enum
 import inspect
-import os
 import socket
 import ssl
 import sys
-import threading
 import weakref
 from abc import abstractmethod
 from itertools import chain
@@ -41,7 +39,6 @@
 from redis.exceptions import (
     AuthenticationError,
     AuthenticationWrongNumberOfArgsError,
-    ChildDeadlockedError,
     ConnectionError,
     DataError,
     RedisError,
@@ -97,7 +94,6 @@ class AbstractConnection:
     """Manages communication to and from a Redis server"""
 
     __slots__ = (
-        "pid",
         "db",
         "username",
         "client_name",
@@ -158,7 +154,6 @@ def __init__(
                 "1. 'password' and (optional) 'username'\n"
                 "2. 'credential_provider'"
             )
-        self.pid = os.getpid()
         self.db = db
         self.client_name = client_name
         self.lib_name = lib_name
@@ -381,12 +376,11 @@ async def disconnect(self, nowait: bool = False) -> None:
         if not self.is_connected:
             return
         try:
-            if os.getpid() == self.pid:
-                self._writer.close()  # type: ignore[union-attr]
-                # wait for close to finish, except when handling errors and
-                # forcefully disconnecting.
-                if not nowait:
-                    await self._writer.wait_closed()  # type: ignore[union-attr]
+            self._writer.close()  # type: ignore[union-attr]
+            # wait for close to finish, except when handling errors and
+            # forcefully disconnecting.
+            if not nowait:
+                await self._writer.wait_closed()  # type: ignore[union-attr]
         except OSError:
             pass
         finally:
@@ -1004,20 +998,8 @@ def __init__(
         self.connection_kwargs = connection_kwargs
         self.max_connections = max_connections
 
-        # a lock to protect the critical section in _checkpid().
-        # this lock is acquired when the process id changes, such as
-        # after a fork. during this time, multiple threads in the child
-        # process could attempt to acquire this lock. the first thread
-        # to acquire the lock will reset the data structures and lock
-        # object of this pool. subsequent threads acquiring this lock
-        # will notice the first thread already did the work and simply
-        # release the lock.
-        self._fork_lock = threading.Lock()
-        self._lock = asyncio.Lock()
-        self._created_connections: int
-        self._available_connections: List[AbstractConnection]
-        self._in_use_connections: Set[AbstractConnection]
-        self.reset()  # lgtm [py/init-calls-subclass]
+        self._available_connections: List[AbstractConnection] = []
+        self._in_use_connections: Set[AbstractConnection] = set()
         self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
 
     def __repr__(self):
@@ -1027,97 +1009,29 @@ def __repr__(self):
         )
 
     def reset(self):
-        self._lock = asyncio.Lock()
-        self._created_connections = 0
         self._available_connections = []
         self._in_use_connections = set()
 
-        # this must be the last operation in this method. while reset() is
-        # called when holding _fork_lock, other threads in this process
-        # can call _checkpid() which compares self.pid and os.getpid() without
-        # holding any lock (for performance reasons). keeping this assignment
-        # as the last operation ensures that those other threads will also
-        # notice a pid difference and block waiting for the first thread to
-        # release _fork_lock. when each of these threads eventually acquire
-        # _fork_lock, they will notice that another thread already called
-        # reset() and they will immediately release _fork_lock and continue on.
-        self.pid = os.getpid()
-
-    def _checkpid(self):
-        # _checkpid() attempts to keep ConnectionPool fork-safe on modern
-        # systems. this is called by all ConnectionPool methods that
-        # manipulate the pool's state such as get_connection() and release().
-        #
-        # _checkpid() determines whether the process has forked by comparing
-        # the current process id to the process id saved on the ConnectionPool
-        # instance. if these values are the same, _checkpid() simply returns.
-        #
-        # when the process ids differ, _checkpid() assumes that the process
-        # has forked and that we're now running in the child process. the child
-        # process cannot use the parent's file descriptors (e.g., sockets).
-        # therefore, when _checkpid() sees the process id change, it calls
-        # reset() in order to reinitialize the child's ConnectionPool. this
-        # will cause the child to make all new connection objects.
-        #
-        # _checkpid() is protected by self._fork_lock to ensure that multiple
-        # threads in the child process do not call reset() multiple times.
-        #
-        # there is an extremely small chance this could fail in the following
-        # scenario:
-        #   1. process A calls _checkpid() for the first time and acquires
-        #      self._fork_lock.
-        #   2. while holding self._fork_lock, process A forks (the fork()
-        #      could happen in a different thread owned by process A)
-        #   3. process B (the forked child process) inherits the
-        #      ConnectionPool's state from the parent. that state includes
-        #      a locked _fork_lock. process B will not be notified when
-        #      process A releases the _fork_lock and will thus never be
-        #      able to acquire the _fork_lock.
-        #
-        # to mitigate this possible deadlock, _checkpid() will only wait 5
-        # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in
-        # that time it is assumed that the child is deadlocked and a
-        # redis.ChildDeadlockedError error is raised.
-        if self.pid != os.getpid():
-            acquired = self._fork_lock.acquire(timeout=5)
-            if not acquired:
-                raise ChildDeadlockedError
-            # reset() the instance for the new process if another thread
-            # hasn't already done so
-            try:
-                if self.pid != os.getpid():
-                    self.reset()
-            finally:
-                self._fork_lock.release()
+    def can_get_connection(self) -> bool:
+        """Return True if a connection can be retrieved from the pool."""
+        return (
+            self._available_connections
+            or len(self._in_use_connections) < self.max_connections
+        )
 
     async def get_connection(self, command_name, *keys, **options):
         """Get a connection from the pool"""
-        self._checkpid()
-        async with self._lock:
-            try:
-                connection = self._available_connections.pop()
-            except IndexError:
-                connection = self.make_connection()
-            self._in_use_connections.add(connection)
+        try:
+            connection = self._available_connections.pop()
+        except IndexError:
+            if len(self._in_use_connections) >= self.max_connections:
+                raise ConnectionError("Too many connections") from None
+            connection = self.make_connection()
+        self._in_use_connections.add(connection)
 
         try:
-            # ensure this connection is connected to Redis
-            await connection.connect()
-            # connections that the pool provides should be ready to send
-            # a command. if not, the connection was either returned to the
-            # pool before all data has been read or the socket has been
-            # closed. either way, reconnect and verify everything is good.
-            try:
-                if await connection.can_read_destructive():
-                    raise ConnectionError("Connection has data") from None
-            except (ConnectionError, OSError):
-                await connection.disconnect()
-                await connection.connect()
-                if await connection.can_read_destructive():
-                    raise ConnectionError("Connection not ready") from None
+            await self.ensure_connection(connection)
         except BaseException:
-            # release the connection back to the pool so that we don't
-            # leak it
             await self.release(connection)
             raise
 
@@ -1133,35 +1047,31 @@ def get_encoder(self):
         )
 
     def make_connection(self):
-        """Create a new connection"""
-        if self._created_connections >= self.max_connections:
-            raise ConnectionError("Too many connections")
-        self._created_connections += 1
+        """Create a new connection. Can be overridden by child classes."""
         return self.connection_class(**self.connection_kwargs)
 
+    async def ensure_connection(self, connection: AbstractConnection):
+        """Ensure that the connection object is connected and valid"""
+        await connection.connect()
+        # connections that the pool provides should be ready to send
+        # a command. if not, the connection was either returned to the
+        # pool before all data has been read or the socket has been
+        # closed. either way, reconnect and verify everything is good.
+        try:
+            if await connection.can_read_destructive():
+                raise ConnectionError("Connection has data") from None
+        except (ConnectionError, OSError):
+            await connection.disconnect()
+            await connection.connect()
+            if await connection.can_read_destructive():
+                raise ConnectionError("Connection not ready") from None
+
     async def release(self, connection: AbstractConnection):
         """Releases the connection back to the pool"""
-        self._checkpid()
-        async with self._lock:
-            try:
-                self._in_use_connections.remove(connection)
-            except KeyError:
-                # Gracefully fail when a connection is returned to this pool
-                # that the pool doesn't actually own
-                pass
-
-            if self.owns_connection(connection):
-                self._available_connections.append(connection)
-            else:
-                # pool doesn't own this connection. do not add it back
-                # to the pool and decrement the count so that another
-                # connection can take its place if needed
-                self._created_connections -= 1
-                await connection.disconnect()
-                return
-
-    def owns_connection(self, connection: AbstractConnection):
-        return connection.pid == self.pid
+        # Connections should always be returned to the correct pool,
+        # not doing so is an error that will cause an exception here.
+        self._in_use_connections.remove(connection)
+        self._available_connections.append(connection)
 
     async def disconnect(self, inuse_connections: bool = True):
         """
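(Illustration, not part of the patch: get_connection() now validates connections through ensure_connection() before handing them out, and release() assumes every connection comes back to the pool it was taken from; returning a foreign connection raises KeyError from the set.remove() call instead of being silently ignored. A minimal sketch of the expected calling pattern, assuming `pool` is an already-created redis.asyncio ConnectionPool; the helper name and the "_" command label are illustrative only.)

    async def run_command(pool, *command):
        # The pool validates the connection via ensure_connection() here.
        conn = await pool.get_connection("_")
        try:
            await conn.send_command(*command)
            return await conn.read_response()
        finally:
            # Must be returned to the same pool instance it came from.
            await pool.release(conn)
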
""" - self._checkpid() - async with self._lock: - if inuse_connections: - connections: Iterable[AbstractConnection] = chain( - self._available_connections, self._in_use_connections - ) - else: - connections = self._available_connections - resp = await asyncio.gather( - *(connection.disconnect() for connection in connections), - return_exceptions=True, + if inuse_connections: + connections: Iterable[AbstractConnection] = chain( + self._available_connections, self._in_use_connections ) - exc = next((r for r in resp if isinstance(r, BaseException)), None) - if exc: - raise exc + else: + connections = self._available_connections + resp = await asyncio.gather( + *(connection.disconnect() for connection in connections), + return_exceptions=True, + ) + exc = next((r for r in resp if isinstance(r, BaseException)), None) + if exc: + raise exc def set_retry(self, retry: "Retry") -> None: for conn in self._available_connections: @@ -1196,21 +1104,21 @@ def set_retry(self, retry: "Retry") -> None: class BlockingConnectionPool(ConnectionPool): """ - Thread-safe blocking connection pool:: + A blocking connection pool:: - >>> from redis.client import Redis + >>> from redis.asyncio.client import Redis >>> client = Redis(connection_pool=BlockingConnectionPool()) It performs the same function as the default - :py:class:`~redis.ConnectionPool` implementation, in that, + :py:class:`~redis.asyncio.ConnectionPool` implementation, in that, it maintains a pool of reusable connections that can be shared by - multiple redis clients (safely across threads if required). + multiple async redis clients. The difference is that, in the event that a client tries to get a connection from the pool when all of connections are in use, rather than raising a :py:class:`~redis.ConnectionError` (as the default - :py:class:`~redis.ConnectionPool` implementation does), it - makes the client wait ("blocks") for a specified number of seconds until + :py:class:`~redis.asyncio.ConnectionPool` implementation does), it + makes blocks the current `Task` for a specified number of seconds until a connection becomes available. Use ``max_connections`` to increase / decrease the pool size:: @@ -1233,131 +1141,30 @@ def __init__( max_connections: int = 50, timeout: Optional[int] = 20, connection_class: Type[AbstractConnection] = Connection, - queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, + queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, # deprecated **connection_kwargs, ): - self.queue_class = queue_class - self.timeout = timeout - self._connections: List[AbstractConnection] super().__init__( connection_class=connection_class, max_connections=max_connections, **connection_kwargs, ) - - def reset(self): - # Create and fill up a thread safe queue with ``None`` values. - self.pool = self.queue_class(self.max_connections) - while True: - try: - self.pool.put_nowait(None) - except asyncio.QueueFull: - break - - # Keep a list of actual connection instances so that we can - # disconnect them later. - self._connections = [] - - # this must be the last operation in this method. while reset() is - # called when holding _fork_lock, other threads in this process - # can call _checkpid() which compares self.pid and os.getpid() without - # holding any lock (for performance reasons). keeping this assignment - # as the last operation ensures that those other threads will also - # notice a pid difference and block waiting for the first thread to - # release _fork_lock. 
-        # release _fork_lock. when each of these threads eventually acquire
-        # _fork_lock, they will notice that another thread already called
-        # reset() and they will immediately release _fork_lock and continue on.
-        self.pid = os.getpid()
-
-    def make_connection(self):
-        """Make a fresh connection."""
-        connection = self.connection_class(**self.connection_kwargs)
-        self._connections.append(connection)
-        return connection
+        self._condition = asyncio.Condition()
+        self.timeout = timeout
 
     async def get_connection(self, command_name, *keys, **options):
-        """
-        Get a connection, blocking for ``self.timeout`` until a connection
-        is available from the pool.
-
-        If the connection returned is ``None`` then creates a new connection.
-        Because we use a last-in first-out queue, the existing connections
-        (having been returned to the pool after the initial ``None`` values
-        were added) will be returned before ``None`` values. This means we only
-        create new connections when we need to, i.e.: the actual number of
-        connections will only increase in response to demand.
-        """
-        # Make sure we haven't changed process.
-        self._checkpid()
-
-        # Try and get a connection from the pool. If one isn't available within
-        # self.timeout then raise a ``ConnectionError``.
-        connection = None
+        """Gets a connection from the pool, blocking until one is available"""
         try:
             async with async_timeout(self.timeout):
-                connection = await self.pool.get()
-        except (asyncio.QueueEmpty, asyncio.TimeoutError):
-            # Note that this is not caught by the redis client and will be
-            # raised unless handled by application code. If you want never to
-            raise ConnectionError("No connection available.")
-
-        # If the ``connection`` is actually ``None`` then that's a cue to make
-        # a new connection to add to the pool.
-        if connection is None:
-            connection = self.make_connection()
-
-        try:
-            # ensure this connection is connected to Redis
-            await connection.connect()
-            # connections that the pool provides should be ready to send
-            # a command. if not, the connection was either returned to the
-            # pool before all data has been read or the socket has been
-            # closed. either way, reconnect and verify everything is good.
-            try:
-                if await connection.can_read_destructive():
-                    raise ConnectionError("Connection has data") from None
-            except (ConnectionError, OSError):
-                await connection.disconnect()
-                await connection.connect()
-                if await connection.can_read_destructive():
-                    raise ConnectionError("Connection not ready") from None
-        except BaseException:
-            # release the connection back to the pool so that we don't leak it
-            await self.release(connection)
-            raise
-
-        return connection
+                async with self._condition:
+                    await self._condition.wait_for(self.can_get_connection)
+                    return await super().get_connection(command_name, *keys, **options)
+        except asyncio.TimeoutError as err:
+            raise ConnectionError("No connection available.") from err
 
     async def release(self, connection: AbstractConnection):
         """Releases the connection back to the pool."""
-        # Make sure we haven't changed process.
-        self._checkpid()
-        if not self.owns_connection(connection):
-            # pool doesn't own this connection. do not add it back
-            # to the pool. instead add a None value which is a placeholder
-            # that will cause the pool to recreate the connection if
-            # its needed.
-            await connection.disconnect()
-            self.pool.put_nowait(None)
-            return
-
-        # Put the connection back into the pool.
-        try:
-            self.pool.put_nowait(connection)
-        except asyncio.QueueFull:
-            # perhaps the pool has been reset() after a fork? regardless,
-            # we don't want this connection
-            pass
-
-    async def disconnect(self, inuse_connections: bool = True):
-        """Disconnects all connections in the pool."""
-        self._checkpid()
-        async with self._lock:
-            resp = await asyncio.gather(
-                *(connection.disconnect() for connection in self._connections),
-                return_exceptions=True,
-            )
-            exc = next((r for r in resp if isinstance(r, BaseException)), None)
-            if exc:
-                raise exc
+        async with self._condition:
+            await super().release(connection)
+            self._condition.notify()
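(Illustration, not part of the patch: after this change the two pools differ only in what happens at the max_connections limit. ConnectionPool.get_connection() raises ConnectionError("Too many connections") immediately, while BlockingConnectionPool waits on its asyncio.Condition for up to `timeout` seconds and then raises ConnectionError("No connection available."). A rough usage sketch, assuming a Redis server reachable at localhost:6379.)

    import asyncio

    from redis.asyncio import Redis
    from redis.asyncio.connection import BlockingConnectionPool, ConnectionPool

    async def main():
        # Fails fast once max_connections connections are checked out at the same time.
        strict_pool = ConnectionPool.from_url("redis://localhost:6379", max_connections=10)

        # Waits up to `timeout` seconds for another task to release a connection.
        patient_pool = BlockingConnectionPool.from_url(
            "redis://localhost:6379", max_connections=10, timeout=5
        )

        client = Redis(connection_pool=patient_pool)
        await client.set("key", "value")
        print(await client.get("key"))

        await client.close()
        await strict_pool.disconnect()
        await patient_pool.disconnect()

    asyncio.run(main())

(With the pid bookkeeping gone, neither pool attempts to detect os.fork(); like other asyncio objects, a pool is expected to be created and used within a single process and event loop.)
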
diff --git a/tests/conftest.py b/tests/conftest.py
index 16f3fbb9db..bad9f43e42 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,6 +2,7 @@
 import random
 import time
 from typing import Callable, TypeVar
+from unittest import mock
 from unittest.mock import Mock
 from urllib.parse import urlparse
 
@@ -9,7 +10,7 @@
 import redis
 from packaging.version import Version
 from redis.backoff import NoBackoff
-from redis.connection import parse_url
+from redis.connection import Connection, parse_url
 from redis.exceptions import RedisClusterException
 from redis.retry import Retry
 
@@ -39,7 +40,6 @@ def __init__(
         help=None,
         metavar=None,
     ):
-
         _option_strings = []
         for option_string in option_strings:
             _option_strings.append(option_string)
@@ -72,7 +72,6 @@ def format_usage(self):
 
 
 def pytest_addoption(parser):
-
     parser.addoption(
         "--redis-url",
         default=default_redis_url,
@@ -354,23 +353,23 @@ def sslclient(request):
 
 
 def _gen_cluster_mock_resp(r, response):
-    connection = Mock()
+    connection = Mock(spec=Connection)
     connection.retry = Retry(NoBackoff(), 0)
     connection.read_response.return_value = response
-    r.connection = connection
-    return r
+    with mock.patch.object(r, "connection", connection):
+        yield r
 
 
 @pytest.fixture()
 def mock_cluster_resp_ok(request, **kwargs):
     r = _get_client(redis.Redis, request, **kwargs)
-    return _gen_cluster_mock_resp(r, "OK")
+    yield from _gen_cluster_mock_resp(r, "OK")
 
 
 @pytest.fixture()
 def mock_cluster_resp_int(request, **kwargs):
     r = _get_client(redis.Redis, request, **kwargs)
-    return _gen_cluster_mock_resp(r, 2)
+    yield from _gen_cluster_mock_resp(r, 2)
 
 
 @pytest.fixture()
@@ -384,7 +383,7 @@ def mock_cluster_resp_info(request, **kwargs):
         "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n"
         "cluster_stats_messages_received:105653\r\n"
     )
-    return _gen_cluster_mock_resp(r, response)
+    yield from _gen_cluster_mock_resp(r, response)
 
 
 @pytest.fixture()
@@ -408,7 +407,7 @@ def mock_cluster_resp_nodes(request, **kwargs):
         "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 "
         "master,fail - 1447829446956 1447829444948 1 disconnected\n"
     )
-    return _gen_cluster_mock_resp(r, response)
+    yield from _gen_cluster_mock_resp(r, response)
 
 
 @pytest.fixture()
@@ -419,7 +418,7 @@ def mock_cluster_resp_slaves(request, **kwargs):
         "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 "
         "1447836789290 3 connected']"
     )
-    return _gen_cluster_mock_resp(r, response)
+    yield from _gen_cluster_mock_resp(r, response)
 
 
 @pytest.fixture(scope="session")
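(Illustration, not part of the patch: the fixtures above now build the fake connection with spec=Connection, so touching an attribute the real Connection class does not define raises AttributeError instead of silently returning another Mock, and they patch r.connection inside a context manager so the original attribute is restored when the fixture finalizes. The same pattern in isolation; the names below are illustrative, not from the test suite.)

    from unittest import mock

    import redis
    from redis.connection import Connection

    def fake_connection(response):
        connection = mock.Mock(spec=Connection)  # unknown attributes fail loudly
        connection.read_response.return_value = response
        return connection

    def demo():
        r = redis.Redis()  # lazy client: no socket is opened here
        with mock.patch.object(r, "connection", fake_connection("OK")):
            ...  # code under test sees the fake connection here
        # leaving the block restores r.connection, even if the body raised
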
diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py
index c837f284f7..10ab4732c2 100644
--- a/tests/test_asyncio/conftest.py
+++ b/tests/test_asyncio/conftest.py
@@ -8,7 +8,7 @@
 from packaging.version import Version
 from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser
 from redis.asyncio.client import Monitor
-from redis.asyncio.connection import parse_url
+from redis.asyncio.connection import Connection, parse_url
 from redis.asyncio.retry import Retry
 from redis.backoff import NoBackoff
 from redis.utils import HIREDIS_AVAILABLE
@@ -138,23 +138,25 @@ async def decoded_r(create_redis):
 
 
 def _gen_cluster_mock_resp(r, response):
-    connection = mock.AsyncMock()
+    connection = mock.AsyncMock(spec=Connection)
     connection.retry = Retry(NoBackoff(), 0)
     connection.read_response.return_value = response
-    r.connection = connection
-    return r
+    with mock.patch.object(r, "connection", connection):
+        yield r
 
 
 @pytest_asyncio.fixture()
 async def mock_cluster_resp_ok(create_redis, **kwargs):
     r = await create_redis(**kwargs)
-    return _gen_cluster_mock_resp(r, "OK")
+    for mocked in _gen_cluster_mock_resp(r, "OK"):
+        yield mocked
 
 
 @pytest_asyncio.fixture()
 async def mock_cluster_resp_int(create_redis, **kwargs):
     r = await create_redis(**kwargs)
-    return _gen_cluster_mock_resp(r, 2)
+    for mocked in _gen_cluster_mock_resp(r, 2):
+        yield mocked
 
 
 @pytest_asyncio.fixture()
@@ -168,7 +170,8 @@ async def mock_cluster_resp_info(create_redis, **kwargs):
         "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n"
         "cluster_stats_messages_received:105653\r\n"
     )
-    return _gen_cluster_mock_resp(r, response)
+    for mocked in _gen_cluster_mock_resp(r, response):
+        yield mocked
 
 
 @pytest_asyncio.fixture()
@@ -192,7 +195,8 @@ async def mock_cluster_resp_nodes(create_redis, **kwargs):
         "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 "
         "master,fail - 1447829446956 1447829444948 1 disconnected\n"
     )
-    return _gen_cluster_mock_resp(r, response)
+    for mocked in _gen_cluster_mock_resp(r, response):
+        yield mocked
 
 
 @pytest_asyncio.fixture()
@@ -203,7 +207,8 @@ async def mock_cluster_resp_slaves(create_redis, **kwargs):
         "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 "
         "1447836789290 3 connected']"
    )
-    return _gen_cluster_mock_resp(r, response)
+    for mocked in _gen_cluster_mock_resp(r, response):
+        yield mocked
 
 
 async def wait_for_command(
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index 1cb1fa5195..332101edd5 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -175,7 +175,7 @@ def cmd_init_mock(self, r: ClusterNode) -> None:
 
 
 def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode:
-    connection = mock.AsyncMock()
+    connection = mock.AsyncMock(spec=Connection)
     connection.is_connected = True
     connection.read_response.return_value = response
     while node._free:
@@ -185,7 +185,7 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode:
 
 
 def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode:
-    connection = mock.AsyncMock()
+    connection = mock.AsyncMock(spec=Connection)
     connection.is_connected = True
     connection.read_response.side_effect = exc
     while node._free:
diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py
index 7672dc74b4..999c03376b 100644
--- a/tests/test_asyncio/test_connection_pool.py
+++ b/tests/test_asyncio/test_connection_pool.py
@@ -1,5 +1,4 @@
 import asyncio
-import os
 import re
 
 import pytest
@@ -94,7 +93,9 @@ class DummyConnection(Connection):
 
     def __init__(self, **kwargs):
         self.kwargs = kwargs
-        self.pid = os.getpid()
+
+    def repr_pieces(self):
+        return [("id", id(self)), ("kwargs", self.kwargs)]
 
     async def connect(self):
         pass
diff --git a/tests/test_asyncio/test_cwe_404.py b/tests/test_asyncio/test_cwe_404.py
index 35707553f8..76ec2bbd26 100644
--- a/tests/test_asyncio/test_cwe_404.py
+++ b/tests/test_asyncio/test_cwe_404.py
@@ -213,8 +213,9 @@ def all_clear():
             p.send_event.clear()
 
     async def wait_for_send():
-        asyncio.wait(
-            [p.send_event.wait() for p in proxies], return_when=asyncio.FIRST_COMPLETED
+        await asyncio.wait(
+            [asyncio.Task(p.send_event.wait()) for p in proxies],
+            return_when=asyncio.FIRST_COMPLETED,
         )
 
     @contextlib.contextmanager
@@ -228,11 +229,10 @@ def set_delay(delay: float):
         for p in proxies:
            await stack.enter_async_context(p)
 
-        with contextlib.closing(
-            RedisCluster.from_url(
-                f"redis://127.0.0.1:{remap_base}", address_remap=remap
-            )
-        ) as r:
+        r = RedisCluster.from_url(
+            f"redis://127.0.0.1:{remap_base}", address_remap=remap
+        )
+        try:
             await r.initialize()
             await r.set("foo", "foo")
             await r.set("bar", "bar")
@@ -257,3 +257,5 @@ async def doit():
                 assert await r.get("foo") == b"foo"
 
             await asyncio.gather(*[doit() for _ in range(10)])
+        finally:
+            await r.close()
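(Illustration, not part of the patch: the new BlockingConnectionPool is built on the standard asyncio.Condition recipe; release() notifies one waiter, and get_connection() blocks in wait_for() until can_get_connection() is true or the timeout expires. The same mechanism in miniature, independent of redis; TokenPool and its methods are made up for this sketch.)

    import asyncio

    class TokenPool:
        """Toy pool handing out at most `size` tokens, blocking at the limit
        the way the reworked BlockingConnectionPool does."""

        def __init__(self, size: int, timeout: float):
            self.size = size
            self.timeout = timeout
            self.in_use = 0
            self._condition = asyncio.Condition()

        def can_get(self) -> bool:
            return self.in_use < self.size

        async def get(self) -> None:
            async def _wait_and_take():
                async with self._condition:
                    await self._condition.wait_for(self.can_get)
                    self.in_use += 1

            try:
                await asyncio.wait_for(_wait_and_take(), timeout=self.timeout)
            except asyncio.TimeoutError:
                raise ConnectionError("No token available.") from None

        async def release(self) -> None:
            async with self._condition:
                self.in_use -= 1
                self._condition.notify()

    async def main():
        pool = TokenPool(size=1, timeout=1.0)
        await pool.get()

        async def release_soon():
            await asyncio.sleep(0.1)
            await pool.release()

        asyncio.create_task(release_soon())
        await pool.get()  # waits ~0.1s for the release instead of failing immediately

    asyncio.run(main())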