diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py index 6c544d478..957420dc0 100644 --- a/neo4j/_async/io/_pool.py +++ b/neo4j/_async/io/_pool.py @@ -241,31 +241,24 @@ async def deactivate(self, address): connections = self.connections[address] except KeyError: # already removed from the connection pool return - for conn in list(connections): - if not conn.in_use: - connections.remove(conn) - try: - await conn.close() - except OSError: - pass - if not connections: - await self.remove(address) + closable_connections = [ + conn for conn in connections if not conn.in_use + ] + # First remove all connections in question, then try to close them. + # If closing of a connection fails, we will end up in this method + # again. + for conn in closable_connections: + connections.remove(conn) + for conn in closable_connections: + await conn.close() + if not self.connections[address]: + del self.connections[address] def on_write_failure(self, address): raise WriteServiceUnavailable( "No write service available for pool {}".format(self) ) - async def remove(self, address): - """ Remove an address from the connection pool, if present, closing - all connections to that address. - """ - async with self.lock: - for connection in self.connections.pop(address, ()): - try: - await connection.close() - except OSError: - pass async def close(self): """ Close all connections and empty the pool. @@ -274,7 +267,8 @@ async def close(self): try: async with self.lock: for address in list(self.connections): - await self.remove(address) + for connection in self.connections.pop(address, ()): + await connection.close() except TypeError: pass diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py index c4e41f7dc..55f8e4b48 100644 --- a/neo4j/_sync/io/_pool.py +++ b/neo4j/_sync/io/_pool.py @@ -241,31 +241,24 @@ def deactivate(self, address): connections = self.connections[address] except KeyError: # already removed from the connection pool return - for conn in list(connections): - if not conn.in_use: - connections.remove(conn) - try: - conn.close() - except OSError: - pass - if not connections: - self.remove(address) + closable_connections = [ + conn for conn in connections if not conn.in_use + ] + # First remove all connections in question, then try to close them. + # If closing of a connection fails, we will end up in this method + # again. + for conn in closable_connections: + connections.remove(conn) + for conn in closable_connections: + conn.close() + if not self.connections[address]: + del self.connections[address] def on_write_failure(self, address): raise WriteServiceUnavailable( "No write service available for pool {}".format(self) ) - def remove(self, address): - """ Remove an address from the connection pool, if present, closing - all connections to that address. - """ - with self.lock: - for connection in self.connections.pop(address, ()): - try: - connection.close() - except OSError: - pass def close(self): """ Close all connections and empty the pool. @@ -274,7 +267,8 @@ def close(self): try: with self.lock: for address in list(self.connections): - self.remove(address) + for connection in self.connections.pop(address, ()): + connection.close() except TypeError: pass diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index c1c212844..452fd91e0 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -384,3 +384,50 @@ def liveness_side_effect(*args, **kwargs): cx3.reset.assert_awaited_once() assert cx1 not in pool.connections[cx1.addr] assert cx3 in pool.connections[cx1.addr] + + +@mark_async_test +async def test_multiple_broken_connections_on_close(opener, mocker): + def mock_connection_breaks_on_close(cx): + async def close_side_effect(): + cx.closed.return_value = True + cx.defunct.return_value = True + await pool.deactivate(READER_ADDRESS) + + cx.attach_mock(mocker.AsyncMock(side_effect=close_side_effect), + "close") + + # create pool with 2 idle connections + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + await pool.release(cx1) + await pool.release(cx2) + + # both will loose connection + mock_connection_breaks_on_close(cx1) + mock_connection_breaks_on_close(cx2) + + # force pool to close cx1, which will make it realize that the server is + # unreachable + cx1.stale.return_value = True + + cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + + assert cx3 is not cx1 + assert cx3 is not cx2 + + +@mark_async_test +async def test_failing_opener_leaves_connections_in_use_alone(opener): + pool = AsyncNeo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None) + + opener.side_effect = ServiceUnavailable("Server overloaded") + with pytest.raises((ServiceUnavailable, SessionExpired)): + await pool.acquire(READ_ACCESS, 30, "test_db", None) + assert not cx1.closed() diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index a94b5fd53..4ca3b278c 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -384,3 +384,50 @@ def liveness_side_effect(*args, **kwargs): cx3.reset.assert_called_once() assert cx1 not in pool.connections[cx1.addr] assert cx3 in pool.connections[cx1.addr] + + +@mark_sync_test +def test_multiple_broken_connections_on_close(opener, mocker): + def mock_connection_breaks_on_close(cx): + def close_side_effect(): + cx.closed.return_value = True + cx.defunct.return_value = True + pool.deactivate(READER_ADDRESS) + + cx.attach_mock(mocker.Mock(side_effect=close_side_effect), + "close") + + # create pool with 2 idle connections + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.release(cx1) + pool.release(cx2) + + # both will loose connection + mock_connection_breaks_on_close(cx1) + mock_connection_breaks_on_close(cx2) + + # force pool to close cx1, which will make it realize that the server is + # unreachable + cx1.stale.return_value = True + + cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None) + + assert cx3 is not cx1 + assert cx3 is not cx2 + + +@mark_sync_test +def test_failing_opener_leaves_connections_in_use_alone(opener): + pool = Neo4jPool( + opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS + ) + cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + + opener.side_effect = ServiceUnavailable("Server overloaded") + with pytest.raises((ServiceUnavailable, SessionExpired)): + pool.acquire(READ_ACCESS, 30, "test_db", None) + assert not cx1.closed()