diff --git a/CHANGES b/CHANGES index 9d82341a76..439eee8f31 100644 --- a/CHANGES +++ b/CHANGES @@ -36,6 +36,7 @@ * Fix Sentinel.execute_command doesn't execute across the entire sentinel cluster bug (#2458) * Added a replacement for the default cluster node in the event of failure (#2463) * Fix for Unhandled exception related to self.host with unix socket (#2496) + * Simplified connection allocation code for asyncio.connection.BlockingConnectionPool * 4.1.3 (Feb 8, 2022) * Fix flushdb and flushall (#1926) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 2c75d4fcf1..b99a5ed06f 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1237,6 +1237,19 @@ class ConnectionPool: ``connection_class``. """ + __slots__ = ( + "connection_class", + "connection_kwargs", + "max_connections", + "_fork_lock", + "_lock", + "_created_connections", + "_available_connections", + "_in_use_connections", + "encoder_class", + "pid", + ) + @classmethod def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: """ @@ -1519,18 +1532,24 @@ class BlockingConnectionPool(ConnectionPool): >>> pool = BlockingConnectionPool(timeout=5) """ + __slots__ = ( + "queue_class", + "timeout", + "pool", + ) + def __init__( self, max_connections: int = 50, timeout: Optional[int] = 20, connection_class: Type[Connection] = Connection, - queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, + queue_class: Type[asyncio.Queue] = asyncio.Queue, **connection_kwargs, ): - self.queue_class = queue_class self.timeout = timeout - self._connections: List[Connection] + self._in_use_connections: Set[Connection] + super().__init__( connection_class=connection_class, max_connections=max_connections, @@ -1538,17 +1557,12 @@ def __init__( ) def reset(self): - # Create and fill up a thread safe queue with ``None`` values. + # a queue of ready connections. 
populated lazily self.pool = self.queue_class(self.max_connections) - while True: - try: - self.pool.put_nowait(None) - except asyncio.QueueFull: - break - - # Keep a list of actual connection instances so that we can - # disconnect them later. - self._connections = [] + # used to decide whether we can allocate a new connection or wait + self._created_connections = 0 + # keep track of connections that are outside queue to close them + self._in_use_connections = set() # this must be the last operation in this method. while reset() is # called when holding _fork_lock, other threads in this process @@ -1562,42 +1576,40 @@ def reset(self): self.pid = os.getpid() def make_connection(self): - """Make a fresh connection.""" - connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) - return connection + """Create a new connection""" + self._created_connections += 1 + return self.connection_class(**self.connection_kwargs) async def get_connection(self, command_name, *keys, **options): """ Get a connection, blocking for ``self.timeout`` until a connection is available from the pool. - If the connection returned is ``None`` then creates a new connection. - Because we use a last-in first-out queue, the existing connections - (having been returned to the pool after the initial ``None`` values - were added) will be returned before ``None`` values. This means we only - create new connections when we need to, i.e.: the actual number of - connections will only increase in response to demand. + Checks internal connection counter to ensure connections are allocated lazily. """ # Make sure we haven't changed process. self._checkpid() - # Try and get a connection from the pool. If one isn't available within - # self.timeout then raise a ``ConnectionError``. 
- connection = None - try: - async with async_timeout.timeout(self.timeout): - connection = await self.pool.get() - except (asyncio.QueueEmpty, asyncio.TimeoutError): - # Note that this is not caught by the redis client and will be - # raised unless handled by application code. If you want never to - raise ConnectionError("No connection available.") - - # If the ``connection`` is actually ``None`` then that's a cue to make - # a new connection to add to the pool. - if connection is None: - connection = self.make_connection() - + # if we are under max_connections, try getting one immediately. if it fails + # it is ok to allocate new one + if self._created_connections < self.max_connections: + try: + connection = self.pool.get_nowait() + except asyncio.QueueEmpty: + connection = self.make_connection() + else: + # wait for available connection + try: + async with async_timeout.timeout(self.timeout): + connection = await self.pool.get() + except asyncio.TimeoutError: + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. + raise ConnectionError("No connection available.") + + # add to set before try block to ensure release does not try to .remove missing + # value + self._in_use_connections.add(connection) try: # ensure this connection is connected to Redis await connection.connect() @@ -1624,15 +1636,15 @@ async def release(self, connection: Connection): """Releases the connection back to the pool.""" # Make sure we haven't changed process. self._checkpid() + if not self.owns_connection(connection): # pool doesn't own this connection. do not add it back - # to the pool. instead add a None value which is a placeholder - # that will cause the pool to recreate the connection if - # its needed. + # to the pool await connection.disconnect() - self.pool.put_nowait(None) return + self._in_use_connections.remove(connection) + # Put the connection back into the pool. 
try: self.pool.put_nowait(connection) @@ -1646,7 +1658,15 @@ async def disconnect(self, inuse_connections: bool = True): self._checkpid() async with self._lock: resp = await asyncio.gather( - *(connection.disconnect() for connection in self._connections), + *( + self.pool.get_nowait().disconnect() + for _ in range(self.pool.qsize()) + ), + *( + connection.disconnect() + for connection in self._in_use_connections + if inuse_connections + ), return_exceptions=True, ) exc = next((r for r in resp if isinstance(r, BaseException)), None) diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index 8e4fdac309..9a6628059e 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -70,7 +70,7 @@ async def call_with_retry(self, _, __): mock_conn = mock.MagicMock() mock_conn.retry = Retry_() - async def get_conn(_): + async def get_conn(*_): # Validate only one client is created in single-client mode when # concurrent requests are made nonlocal init_call_count @@ -78,8 +78,8 @@ async def get_conn(_): init_call_count += 1 return mock_conn - with mock.patch.object(r.connection_pool, "get_connection", get_conn): - with mock.patch.object(r.connection_pool, "release"): + with mock.patch.object(type(r.connection_pool), "get_connection", get_conn): + with mock.patch.object(type(r.connection_pool), "release"): await asyncio.gather(r.set("a", "b"), r.set("c", "d")) assert init_call_count == 1