diff --git a/CHANGES b/CHANGES index 7b3b4c5ac2..8cfc47db18 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Add 'aclose()' methods to async classes, deprecate async close(). * Fix #2831, add auto_close_connection_pool=True arg to asyncio.Redis.from_url() * Fix incorrect redis.asyncio.Cluster type hint for `retry_on_error` * Fix dead weakref in sentinel connection causing ReferenceError (#2767) diff --git a/docs/examples/asyncio_examples.ipynb b/docs/examples/asyncio_examples.ipynb index a8ee8faf8b..5eab4db1f7 100644 --- a/docs/examples/asyncio_examples.ipynb +++ b/docs/examples/asyncio_examples.ipynb @@ -15,7 +15,7 @@ "\n", "## Connecting and Disconnecting\n", "\n", - "Utilizing asyncio Redis requires an explicit disconnect of the connection since there is no asyncio deconstructor magic method. By default, a connection pool is created on `redis.Redis()` and attached to this `Redis` instance. The connection pool closes automatically on the call to `Redis.close` which disconnects all connections." + "Utilizing asyncio Redis requires an explicit disconnect of the connection since there is no asyncio destructor magic method. By default, a connection pool is created on `redis.Redis()` and attached to this `Redis` instance. The connection pool closes automatically on the call to `Redis.aclose` which disconnects all connections." ] }, { @@ -39,9 +39,9 @@ "source": [ "import redis.asyncio as redis\n", "\n", - "connection = redis.Redis()\n", - "print(f\"Ping successful: {await connection.ping()}\")\n", - "await connection.close()" + "client = redis.Redis()\n", + "print(f\"Ping successful: {await client.ping()}\")\n", + "await client.aclose()" ] }, { @@ -60,8 +60,8 @@ "import redis.asyncio as redis\n", "\n", "pool = redis.ConnectionPool.from_url(\"redis://localhost\")\n", - "connection = redis.Redis.from_pool(pool)\n", - "await connection.close()" + "client = redis.Redis.from_pool(pool)\n", + "await client.aclose()" ] }, { @@ -91,11 +91,11 @@ "import redis.asyncio as redis\n", "\n", "pool = redis.ConnectionPool.from_url(\"redis://localhost\")\n", - "connection1 = redis.Redis(connection_pool=pool)\n", - "connection2 = redis.Redis(connection_pool=pool)\n", - "await connection1.close()\n", - "await connection2.close()\n", - "await pool.disconnect()" + "client1 = redis.Redis(connection_pool=pool)\n", + "client2 = redis.Redis(connection_pool=pool)\n", + "await client1.aclose()\n", + "await client2.aclose()\n", + "await pool.aclose()" ] }, { @@ -113,9 +113,9 @@ "source": [ "import redis.asyncio as redis\n", "\n", - "connection = redis.Redis(protocol=3)\n", - "await connection.close()\n", - "await connection.ping()" + "client = redis.Redis(protocol=3)\n", + "await client.aclose()\n", + "await client.ping()" ] }, { diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 122469a194..c340d851b1 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -14,7 +14,6 @@ List, Mapping, MutableMapping, - NoReturn, Optional, Set, Tuple, @@ -65,6 +64,7 @@ from redis.utils import ( HIREDIS_AVAILABLE, _set_info_logger, + deprecated_function, get_lib_version, safe_str, str_if_bytes, @@ -527,7 +527,7 @@ async def __aenter__(self: _RedisT) -> _RedisT: return await self.initialize() async def __aexit__(self, exc_type, exc_value, traceback): - await self.close() + await self.aclose() _DEL_MESSAGE = "Unclosed Redis client" @@ -539,7 +539,7 @@ def __del__(self, _warnings: Any = warnings) -> None: context = {"client": self, "message": self._DEL_MESSAGE}
asyncio.get_running_loop().call_exception_handler(context) - async def close(self, close_connection_pool: Optional[bool] = None) -> None: + async def aclose(self, close_connection_pool: Optional[bool] = None) -> None: """ Closes Redis client connection @@ -557,6 +557,13 @@ async def close(self, close_connection_pool: Optional[bool] = None) -> None: ): await self.connection_pool.disconnect() + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") + async def close(self, close_connection_pool: Optional[bool] = None) -> None: + """ + Alias for aclose(), for backwards compatibility + """ + await self.aclose(close_connection_pool) + async def _send_command_parse_response(self, conn, command_name, *args, **options): """ Send a command and parse the response @@ -764,13 +771,18 @@ async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): - await self.reset() + await self.aclose() def __del__(self): if self.connection: self.connection.clear_connect_callbacks() - async def reset(self): + async def aclose(self): + # In case a connection property does not yet exist + # (due to a crash earlier in the Redis() constructor), return + # immediately as there is nothing to clean-up. + if not hasattr(self, "connection"): + return async with self._lock: if self.connection: await self.connection.disconnect() @@ -782,13 +794,15 @@ async def reset(self): self.patterns = {} self.pending_unsubscribe_patterns = set() - def close(self) -> Awaitable[NoReturn]: - # In case a connection property does not yet exist - # (due to a crash earlier in the Redis() constructor), return - # immediately as there is nothing to clean-up. - if not hasattr(self, "connection"): - return - return self.reset() + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") + async def close(self) -> None: + """Alias for aclose(), for backwards compatibility""" + await self.aclose() + + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="reset") + async def reset(self) -> None: + """Alias for aclose(), for backwards compatibility""" + await self.aclose() async def on_connect(self, connection: Connection): """Re-subscribe to any channels and patterns previously subscribed to""" @@ -1232,6 +1246,10 @@ async def reset(self): await self.connection_pool.release(self.connection) self.connection = None + async def aclose(self) -> None: + """Alias for reset(), a standard method name for cleanup""" + await self.reset() + def multi(self): """ Start a transactional block of the pipeline after WATCH commands @@ -1264,14 +1282,14 @@ async def _disconnect_reset_raise(self, conn, error): # valid since this connection has died. raise a WatchError, which # indicates the user should retry this transaction. 
if self.watching: - await self.reset() + await self.aclose() raise WatchError( "A ConnectionError occurred on while watching one or more keys" ) # if retry_on_timeout is not set, or the error is not # a TimeoutError, raise it if not (conn.retry_on_timeout and isinstance(error, TimeoutError)): - await self.reset() + await self.aclose() raise async def immediate_execute_command(self, *args, **options): diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 65e1bd02ac..f4f031580d 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -62,7 +62,13 @@ TryAgainError, ) from redis.typing import AnyKeyT, EncodableT, KeyT -from redis.utils import dict_merge, get_lib_version, safe_str, str_if_bytes +from redis.utils import ( + deprecated_function, + dict_merge, + get_lib_version, + safe_str, + str_if_bytes, +) TargetNodesT = TypeVar( "TargetNodesT", str, "ClusterNode", List["ClusterNode"], Dict[Any, "ClusterNode"] @@ -395,12 +401,12 @@ async def initialize(self) -> "RedisCluster": ) self._initialize = False except BaseException: - await self.nodes_manager.close() - await self.nodes_manager.close("startup_nodes") + await self.nodes_manager.aclose() + await self.nodes_manager.aclose("startup_nodes") raise return self - async def close(self) -> None: + async def aclose(self) -> None: """Close all connections & client if initialized.""" if not self._initialize: if not self._lock: @@ -408,14 +414,19 @@ async def close(self) -> None: async with self._lock: if not self._initialize: self._initialize = True - await self.nodes_manager.close() - await self.nodes_manager.close("startup_nodes") + await self.nodes_manager.aclose() + await self.nodes_manager.aclose("startup_nodes") + + @deprecated_function(version="5.0.0", reason="Use aclose() instead", name="close") + async def close(self) -> None: + """alias for aclose() for backwards compatibility""" + await self.aclose() async def __aenter__(self) -> "RedisCluster": return await self.initialize() async def __aexit__(self, exc_type: None, exc_value: None, traceback: None) -> None: - await self.close() + await self.aclose() def __await__(self) -> Generator[Any, None, "RedisCluster"]: return self.initialize().__await__() @@ -767,13 +778,13 @@ async def _execute_command( self.nodes_manager.startup_nodes.pop(target_node.name, None) # Hard force of reinitialize of the node/slots setup # and try again with the new setup - await self.close() + await self.aclose() raise except ClusterDownError: # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command - await self.close() + await self.aclose() await asyncio.sleep(0.25) raise except MovedError as e: @@ -790,7 +801,7 @@ async def _execute_command( self.reinitialize_steps and self.reinitialize_counter % self.reinitialize_steps == 0 ): - await self.close() + await self.aclose() # Reset the counter self.reinitialize_counter = 0 else: @@ -1323,7 +1334,7 @@ async def initialize(self) -> None: # If initialize was called after a MovedError, clear it self._moved_exception = None - async def close(self, attr: str = "nodes_cache") -> None: + async def aclose(self, attr: str = "nodes_cache") -> None: self.default_node = None await asyncio.gather( *( @@ -1471,7 +1482,7 @@ async def execute( if type(e) in self.__class__.ERRORS_ALLOW_RETRY: # Try again with the new cluster setup. 
exception = e - await self._client.close() + await self._client.aclose() await asyncio.sleep(0.25) else: # All other errors should be raised. diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4f20e6ec1d..71d0e92002 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1095,6 +1095,10 @@ async def disconnect(self, inuse_connections: bool = True): if exc: raise exc + async def aclose(self) -> None: + """Close the pool, disconnecting all connections""" + await self.disconnect() + def set_retry(self, retry: "Retry") -> None: for conn in self._available_connections: conn.retry = retry diff --git a/redis/client.py b/redis/client.py index d285e1ca46..1e1ff57605 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1217,6 +1217,10 @@ def reset(self): self.connection_pool.release(self.connection) self.connection = None + def close(self): + """Close the pipeline""" + self.reset() + def multi(self): """ Start a transactional block of the pipeline after WATCH commands diff --git a/redis/connection.py b/redis/connection.py index c6a22aae76..45ecd2a370 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -1154,6 +1154,10 @@ def disconnect(self, inuse_connections=True): for connection in connections: connection.disconnect() + def close(self) -> None: + """Close the pool, disconnecting all connections""" + self.disconnect() + def set_retry(self, retry: "Retry") -> None: self.connection_kwargs.update({"retry": retry}) for conn in self._available_connections: diff --git a/tests/test_asyncio/compat.py b/tests/test_asyncio/compat.py index 5edcd4ae54..4a9778b70a 100644 --- a/tests/test_asyncio/compat.py +++ b/tests/test_asyncio/compat.py @@ -6,6 +6,18 @@ except AttributeError: import mock +try: + from contextlib import aclosing +except ImportError: + import contextlib + + @contextlib.asynccontextmanager + async def aclosing(thing): + try: + yield thing + finally: + await thing.aclose() + def create_task(coroutine): return asyncio.create_task(coroutine) diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index 10ab4732c2..5d9e0b4f2e 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -100,7 +100,7 @@ async def teardown(): # handle cases where a test disconnected a client # just manually retry the flushdb await client.flushdb() - await client.close() + await client.aclose() await client.connection_pool.disconnect() else: if flushdb: @@ -110,7 +110,7 @@ async def teardown(): # handle cases where a test disconnected a client # just manually retry the flushdb await client.flushdb(target_nodes="primaries") - await client.close() + await client.aclose() teardown_clients.append(teardown) return client diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 332101edd5..e6cf2e4ce7 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -38,7 +38,7 @@ ) from ..ssl_utils import get_ssl_filename -from .compat import mock +from .compat import aclosing, mock pytestmark = pytest.mark.onlycluster @@ -270,7 +270,38 @@ async def test_host_port_startup_node(self) -> None: cluster = await get_mocked_redis_client(host=default_host, port=default_port) assert cluster.get_node(host=default_host, port=default_port) is not None - await cluster.close() + await cluster.aclose() + + async def test_aclosing(self) -> None: + cluster = await get_mocked_redis_client(host=default_host, port=default_port) + called = 0 + + async def mock_aclose(): + nonlocal 
called + called += 1 + + with mock.patch.object(cluster, "aclose", mock_aclose): + async with aclosing(cluster): + pass + assert called == 1 + await cluster.aclose() + + async def test_close_is_aclose(self) -> None: + """ + Test that the deprecated close() method is retained as an alias + for aclose() + """ + cluster = await get_mocked_redis_client(host=default_host, port=default_port) + called = 0 + + async def mock_aclose(): + nonlocal called + called += 1 + + with mock.patch.object(cluster, "aclose", mock_aclose): + await cluster.close() + assert called == 1 + await cluster.aclose() async def test_startup_nodes(self) -> None: """ @@ -289,7 +320,7 @@ async def test_startup_nodes(self) -> None: and cluster.get_node(host=default_host, port=port_2) is not None ) - await cluster.close() + await cluster.aclose() startup_node = ClusterNode("127.0.0.1", 16379) async with RedisCluster(startup_nodes=[startup_node], client_name="test") as rc: @@ -417,7 +448,7 @@ async def read_response_mocked(*args: Any, **kwargs: Any) -> None: ) ) - await rc.close() + await rc.aclose() async def test_execute_command_errors(self, r: RedisCluster) -> None: """ @@ -461,7 +492,7 @@ async def test_execute_command_node_flag_replicas(self, r: RedisCluster) -> None conn = primary._free.pop() assert conn.read_response.called is not True - await r.close() + await r.aclose() async def test_execute_command_node_flag_all_nodes(self, r: RedisCluster) -> None: """ @@ -690,7 +721,7 @@ def execute_command_mock_third(self, *args, **options): await read_cluster.get("foo") mocks["send_command"].assert_has_calls([mock.call("READONLY")]) - await read_cluster.close() + await read_cluster.aclose() async def test_keyslot(self, r: RedisCluster) -> None: """ @@ -762,7 +793,7 @@ def raise_error(target_node, *args, **kwargs): await rc.get("bar") assert execute_command.failed_calls == rc.cluster_error_retry_attempts - await rc.close() + await rc.aclose() async def test_set_default_node_success(self, r: RedisCluster) -> None: """ @@ -843,7 +874,7 @@ async def test_can_run_concurrent_commands(self, request: FixtureRequest) -> Non *(rc.echo("i", target_nodes=RedisCluster.ALL_NODES) for i in range(100)) ) ) - await rc.close() + await rc.aclose() def test_replace_cluster_node(self, r: RedisCluster) -> None: prev_default_node = r.get_default_node() @@ -901,7 +932,7 @@ def address_remap(address): assert await r.set("byte_string", b"giraffe") assert await r.get("byte_string") == b"giraffe" finally: - await r.close() + await r.aclose() finally: await asyncio.gather(*[p.aclose() for p in proxies]) @@ -1002,7 +1033,7 @@ async def test_initialize_before_execute_multi_key_command( url = request.config.getoption("--redis-url") r = RedisCluster.from_url(url) assert 0 == await r.exists("a", "b", "c") - await r.close() + await r.aclose() @skip_if_redis_enterprise() async def test_cluster_myid(self, r: RedisCluster) -> None: @@ -1065,7 +1096,7 @@ async def test_cluster_delslots(self) -> None: assert node0._free.pop().read_response.called assert node1._free.pop().read_response.called - await r.close() + await r.aclose() @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() @@ -1076,7 +1107,7 @@ async def test_cluster_delslotsrange(self): await r.cluster_addslots(node, 1, 2, 3, 4, 5) assert await r.cluster_delslotsrange(1, 5) assert node._free.pop().read_response.called - await r.close() + await r.aclose() @skip_if_redis_enterprise() async def test_cluster_failover(self, r: RedisCluster) -> None: @@ -1286,7 +1317,7 @@ async def test_readonly(self) -> 
None: for replica in r.get_replicas(): assert replica._free.pop().read_response.called - await r.close() + await r.aclose() @skip_if_redis_enterprise() async def test_readwrite(self) -> None: @@ -1299,7 +1330,7 @@ async def test_readwrite(self) -> None: for replica in r.get_replicas(): assert replica._free.pop().read_response.called - await r.close() + await r.aclose() @skip_if_redis_enterprise() async def test_bgsave(self, r: RedisCluster) -> None: @@ -1524,7 +1555,7 @@ async def test_client_kill( ] assert len(clients) == 1 assert clients[0].get("name") == "redis-py-c1" - await r2.close() + await r2.aclose() @skip_if_server_version_lt("2.6.0") async def test_cluster_bitop_not_empty_string(self, r: RedisCluster) -> None: @@ -2302,7 +2333,7 @@ async def test_acl_log( await r.acl_deluser(username, target_nodes="primaries") - await user_client.close() + await user_client.aclose() class TestNodesManager: @@ -2359,7 +2390,7 @@ async def test_init_slots_cache_not_all_slots_covered(self) -> None: cluster_slots=cluster_slots, require_full_coverage=True, ) - await rc.close() + await rc.aclose() assert str(ex.value).startswith( "All slots are not covered after query all startup_nodes." ) @@ -2385,7 +2416,7 @@ async def test_init_slots_cache_not_require_full_coverage_success(self) -> None: assert 5460 not in rc.nodes_manager.slots_cache - await rc.close() + await rc.aclose() async def test_init_slots_cache(self) -> None: """ @@ -2416,7 +2447,7 @@ async def test_init_slots_cache(self) -> None: assert len(n_manager.nodes_cache) == 6 - await rc.close() + await rc.aclose() async def test_init_slots_cache_cluster_mode_disabled(self) -> None: """ @@ -2427,7 +2458,7 @@ async def test_init_slots_cache_cluster_mode_disabled(self) -> None: rc = await get_mocked_redis_client( host=default_host, port=default_port, cluster_enabled=False ) - await rc.close() + await rc.aclose() assert "Cluster mode is not enabled on this node" in str(e.value) async def test_empty_startup_nodes(self) -> None: @@ -2514,7 +2545,7 @@ async def test_cluster_one_instance(self) -> None: for i in range(0, REDIS_CLUSTER_HASH_SLOTS): assert n.slots_cache[i] == [n_node] - await rc.close() + await rc.aclose() async def test_init_with_down_node(self) -> None: """ diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index a954a40dbf..28e6b0d9c3 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -302,7 +302,25 @@ async def get_redis_connection(): r1 = await get_redis_connection() assert r1.auto_close_connection_pool is True - await r1.close() + await r1.aclose() + + +async def test_close_is_aclose(request): + """Verify close() calls aclose()""" + calls = 0 + + async def mock_aclose(self): + nonlocal calls + calls += 1 + + url: str = request.config.getoption("--redis-url") + r1 = await Redis.from_url(url) + with patch.object(r1, "aclose", mock_aclose): + await r1.close() + assert calls == 1 + + with pytest.deprecated_call(): + await r1.close() async def test_pool_from_url_deprecation(request): @@ -326,7 +344,7 @@ async def get_redis_connection(): r1 = await get_redis_connection() assert r1.auto_close_connection_pool is False await r1.connection_pool.disconnect() - await r1.close() + await r1.aclose() @pytest.mark.parametrize("from_url", (True, False), ids=("from_url", "from_args")) diff --git a/tests/test_asyncio/test_connection_pool.py b/tests/test_asyncio/test_connection_pool.py index 999c03376b..c93fa91a39 100644 --- 
a/tests/test_asyncio/test_connection_pool.py +++ b/tests/test_asyncio/test_connection_pool.py @@ -7,7 +7,7 @@ from redis.asyncio.connection import Connection, to_bool from tests.conftest import skip_if_redis_enterprise, skip_if_server_version_lt -from .compat import mock +from .compat import aclosing, mock from .conftest import asynccontextmanager from .test_pubsub import wait_for_message @@ -42,7 +42,7 @@ async def test_auto_disconnect_redis_created_pool(self, r: redis.Redis): new_conn = await self.create_two_conn(r) assert new_conn != r.connection assert self.get_total_connected_connections(r.connection_pool) == 2 - await r.close() + await r.aclose() assert self.has_no_connected_connections(r.connection_pool) async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): @@ -52,7 +52,7 @@ async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): ) new_conn = await self.create_two_conn(r2) assert self.get_total_connected_connections(r2.connection_pool) == 2 - await r2.close() + await r2.aclose() assert r2.connection_pool._in_use_connections == {new_conn} assert new_conn.is_connected assert len(r2.connection_pool._available_connections) == 1 @@ -61,7 +61,7 @@ async def test_do_not_auto_disconnect_redis_created_pool(self, r2: redis.Redis): async def test_auto_release_override_true_manual_created_pool(self, r: redis.Redis): assert r.auto_close_connection_pool is True, "This is from the class fixture" await self.create_two_conn(r) - await r.close() + await r.aclose() assert self.get_total_connected_connections(r.connection_pool) == 2, ( "The connection pool should not be disconnected as a manually created " "connection pool was passed in in conftest.py" @@ -72,7 +72,7 @@ async def test_auto_release_override_true_manual_created_pool(self, r: redis.Red async def test_close_override(self, r: redis.Redis, auto_close_conn_pool): r.auto_close_connection_pool = auto_close_conn_pool await self.create_two_conn(r) - await r.close(close_connection_pool=True) + await r.aclose(close_connection_pool=True) assert self.has_no_connected_connections(r.connection_pool) @pytest.mark.parametrize("auto_close_conn_pool", [True, False]) @@ -81,7 +81,7 @@ async def test_negate_auto_close_client_pool( ): r.auto_close_connection_pool = auto_close_conn_pool new_conn = await self.create_two_conn(r) - await r.close(close_connection_pool=False) + await r.aclose(close_connection_pool=False) assert not self.has_no_connected_connections(r.connection_pool) assert r.connection_pool._in_use_connections == {new_conn} assert r.connection_pool._available_connections[0].is_connected @@ -135,6 +135,16 @@ async def test_connection_creation(self): assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs + async def test_aclosing(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = redis.ConnectionPool( + connection_class=DummyConnection, + max_connections=None, + **connection_kwargs, + ) + async with aclosing(pool): + pass + async def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0]} async with self.get_pool(connection_kwargs=connection_kwargs) as pool: diff --git a/tests/test_asyncio/test_pipeline.py b/tests/test_asyncio/test_pipeline.py index edd2f6d147..0fa1204750 100644 --- a/tests/test_asyncio/test_pipeline.py +++ b/tests/test_asyncio/test_pipeline.py @@ -2,6 +2,7 @@ import redis from tests.conftest import skip_if_server_version_lt +from .compat import aclosing, mock from .conftest import 
wait_for_command @@ -286,6 +287,24 @@ async def test_watch_reset_unwatch(self, r): assert unwatch_command is not None assert unwatch_command["command"] == "UNWATCH" + @pytest.mark.onlynoncluster + async def test_aclose_is_reset(self, r): + async with r.pipeline() as pipe: + called = 0 + + async def mock_reset(): + nonlocal called + called += 1 + + with mock.patch.object(pipe, "reset", mock_reset): + await pipe.aclose() + assert called == 1 + + @pytest.mark.onlynoncluster + async def test_aclosing(self, r): + async with aclosing(r.pipeline()): + pass + @pytest.mark.onlynoncluster async def test_transaction_callable(self, r): await r.set("a", 1) diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 858576584f..8fef34d83d 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -20,7 +20,7 @@ from redis.utils import HIREDIS_AVAILABLE from tests.conftest import get_protocol_version, skip_if_server_version_lt -from .compat import create_task, mock +from .compat import aclosing, create_task, mock def with_timeout(t): @@ -84,9 +84,8 @@ def make_subscribe_test_data(pubsub, type): @pytest_asyncio.fixture() async def pubsub(r: redis.Redis): - p = r.pubsub() - yield p - await p.close() + async with r.pubsub() as p: + yield p @pytest.mark.onlynoncluster @@ -217,6 +216,46 @@ async def test_subscribe_property_with_patterns(self, pubsub): kwargs = make_subscribe_test_data(pubsub, "pattern") await self._test_subscribed_property(**kwargs) + async def test_aclosing(self, r: redis.Redis): + p = r.pubsub() + async with aclosing(p): + assert p.subscribed is False + await p.subscribe("foo") + assert p.subscribed is True + assert p.subscribed is False + + async def test_context_manager(self, r: redis.Redis): + p = r.pubsub() + async with p: + assert p.subscribed is False + await p.subscribe("foo") + assert p.subscribed is True + assert p.subscribed is False + + async def test_close_is_aclose(self, r: redis.Redis): + """ + Test backwards compatible close method + """ + p = r.pubsub() + assert p.subscribed is False + await p.subscribe("foo") + assert p.subscribed is True + with pytest.deprecated_call(): + await p.close() + assert p.subscribed is False + + async def test_reset_is_aclose(self, r: redis.Redis): + """ + Test backwards compatible reset method + """ + p = r.pubsub() + assert p.subscribed is False + await p.subscribe("foo") + assert p.subscribed is True + with pytest.deprecated_call(): + await p.reset() + assert p.subscribed is False + async def test_ignore_all_subscribe_messages(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) @@ -233,7 +272,7 @@ async def test_ignore_all_subscribe_messages(self, r: redis.Redis): assert p.subscribed is True assert await wait_for_message(p) is None assert p.subscribed is False - await p.close() + await p.aclose() async def test_ignore_individual_subscribe_messages(self, pubsub): p = pubsub @@ -350,7 +389,7 @@ async def test_channel_message_handler(self, r: redis.Redis): assert await r.publish("foo", "test message") == 1 assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") - await p.close() + await p.aclose() async def test_channel_async_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -359,7 +398,7 @@ async def test_channel_async_message_handler(self, r): assert await r.publish("foo", "test message") == 1 assert await wait_for_message(p) is None assert self.async_message == make_message("message", "foo", 
"test message") - await p.close() + await p.aclose() async def test_channel_sync_async_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) @@ -371,7 +410,7 @@ async def test_channel_sync_async_message_handler(self, r): assert await wait_for_message(p) is None assert self.message == make_message("message", "foo", "test message") assert self.async_message == make_message("message", "bar", "test message 2") - await p.close() + await p.aclose() @pytest.mark.onlynoncluster async def test_pattern_message_handler(self, r: redis.Redis): @@ -383,7 +422,7 @@ async def test_pattern_message_handler(self, r: redis.Redis): assert self.message == make_message( "pmessage", "foo", "test message", pattern="f*" ) - await p.close() + await p.aclose() async def test_unicode_channel_message_handler(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) @@ -394,7 +433,7 @@ async def test_unicode_channel_message_handler(self, r: redis.Redis): assert await r.publish(channel, "test message") == 1 assert await wait_for_message(p) is None assert self.message == make_message("message", channel, "test message") - await p.close() + await p.aclose() @pytest.mark.onlynoncluster # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html @@ -410,7 +449,7 @@ async def test_unicode_pattern_message_handler(self, r: redis.Redis): assert self.message == make_message( "pmessage", channel, "test message", pattern=pattern ) - await p.close() + await p.aclose() async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub): p = pubsub @@ -524,7 +563,7 @@ async def test_channel_message_handler(self, r: redis.Redis): await r.publish(self.channel, new_data) assert await wait_for_message(p) is None assert self.message == self.make_message("message", self.channel, new_data) - await p.close() + await p.aclose() async def test_pattern_message_handler(self, r: redis.Redis): p = r.pubsub(ignore_subscribe_messages=True) @@ -546,7 +585,7 @@ async def test_pattern_message_handler(self, r: redis.Redis): assert self.message == self.make_message( "pmessage", self.channel, new_data, pattern=self.pattern ) - await p.close() + await p.aclose() async def test_context_manager(self, r: redis.Redis): async with r.pubsub() as pubsub: @@ -556,7 +595,7 @@ async def test_context_manager(self, r: redis.Redis): assert pubsub.connection is None assert pubsub.channels == {} assert pubsub.patterns == {} - await pubsub.close() + await pubsub.aclose() @pytest.mark.onlynoncluster @@ -597,9 +636,9 @@ async def test_pubsub_numsub(self, r: redis.Redis): channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] assert await r.pubsub_numsub("foo", "bar", "baz") == channels - await p1.close() - await p2.close() - await p3.close() + await p1.aclose() + await p2.aclose() + await p3.aclose() @skip_if_server_version_lt("2.8.0") async def test_pubsub_numpat(self, r: redis.Redis): @@ -608,7 +647,7 @@ async def test_pubsub_numpat(self, r: redis.Redis): for i in range(3): assert (await wait_for_message(p))["type"] == "psubscribe" assert await r.pubsub_numpat() == 3 - await p.close() + await p.aclose() @pytest.mark.onlynoncluster @@ -621,7 +660,7 @@ async def test_send_pubsub_ping(self, r: redis.Redis): assert await wait_for_message(p) == make_message( type="pong", channel=None, data="", pattern=None ) - await p.close() + await p.aclose() @skip_if_server_version_lt("3.0.0") async def test_send_pubsub_ping_message(self, r: redis.Redis): @@ -631,7 +670,7 @@ async def test_send_pubsub_ping_message(self, r: redis.Redis): assert await 
wait_for_message(p) == make_message( type="pong", channel=None, data="hello world", pattern=None ) - await p.close() + await p.aclose() @pytest.mark.onlynoncluster diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index ab0fc6be98..ef70a8ff35 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -1,6 +1,7 @@ import os import re import time +from contextlib import closing from threading import Thread from unittest import mock @@ -51,6 +52,16 @@ def test_connection_creation(self): assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs + def test_closing(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = redis.ConnectionPool( + connection_class=DummyConnection, + max_connections=None, + **connection_kwargs, + ) + with closing(pool): + pass + def test_multiple_connections(self, master_host): connection_kwargs = {"host": master_host[0], "port": master_host[1]} pool = self.get_pool(connection_kwargs=connection_kwargs) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 7b048eec01..e64a763bae 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,3 +1,6 @@ +from contextlib import closing +from unittest import mock + import pytest import redis @@ -284,6 +287,24 @@ def test_watch_reset_unwatch(self, r): assert unwatch_command is not None assert unwatch_command["command"] == "UNWATCH" + @pytest.mark.onlynoncluster + def test_close_is_reset(self, r): + with r.pipeline() as pipe: + called = 0 + + def mock_reset(): + nonlocal called + called += 1 + + with mock.patch.object(pipe, "reset", mock_reset): + pipe.close() + assert called == 1 + + @pytest.mark.onlynoncluster + def test_closing(self, r): + with closing(r.pipeline()): + pass + @pytest.mark.onlynoncluster def test_transaction_callable(self, r): r["a"] = 1
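Migration note (not part of the patch itself): with close() deprecated in favour of aclose(), the shutdown flow for a plain async client, a from_pool()-owned pool, and an explicitly shared pool looks roughly like the sketch below. The host/URL values are illustrative; only the redis.asyncio methods touched by this diff are assumed.

import asyncio

import redis.asyncio as redis


async def main() -> None:
    # Plain client: aclose() is the new name; close() still works but warns.
    client = redis.Redis()
    print(f"Ping successful: {await client.ping()}")
    await client.aclose()

    # Client that owns its pool via from_pool(): closing the client also
    # shuts the pool down, so no separate pool cleanup is needed.
    pool = redis.ConnectionPool.from_url("redis://localhost")
    client = redis.Redis.from_pool(pool)
    await client.aclose()

    # Shared pool: aclose() each client, then aclose() the pool itself.
    shared = redis.ConnectionPool.from_url("redis://localhost")
    client1 = redis.Redis(connection_pool=shared)
    client2 = redis.Redis(connection_pool=shared)
    await client1.aclose()
    await client2.aclose()
    await shared.aclose()


asyncio.run(main())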
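The aclose() name also lines up with contextlib.aclosing() (Python 3.10+, backported by the test compat shim above), so callers can tie cleanup to a with-block. A minimal sketch, assuming a local server; the key and channel names are made up.

import asyncio
from contextlib import aclosing  # tests fall back to a local shim on < 3.10

import redis.asyncio as redis


async def main() -> None:
    # aclose() is awaited when the block exits, even if the body raises.
    async with aclosing(redis.Redis()) as client:
        await client.set("greeting", "hello")
        print(await client.get("greeting"))

    # PubSub and Pipeline objects gained aclose() too, so the same pattern applies.
    client = redis.Redis()
    async with aclosing(client.pubsub()) as pubsub:
        await pubsub.subscribe("example-channel")
        print(await pubsub.get_message(timeout=1))
    await client.aclose()


asyncio.run(main())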