From daea0844d1b055580a2a666afd97a56144e5c20e Mon Sep 17 00:00:00 2001 From: Petya Slavova Date: Fri, 13 Jun 2025 16:31:36 +0300 Subject: [PATCH] Removing the threading.Lock locks and replacing them with RLock objects to avoid deadlocks. --- redis/client.py | 8 ++------ redis/connection.py | 10 ++-------- redis/event.py | 8 ++++---- tests/test_cluster_transaction.py | 4 ++-- tests/test_connection.py | 6 +++--- tests/test_credentials.py | 10 +++++----- 6 files changed, 18 insertions(+), 28 deletions(-) diff --git a/redis/client.py b/redis/client.py index 28e9a82f76..0e05b6f542 100755 --- a/redis/client.py +++ b/redis/client.py @@ -368,9 +368,7 @@ def __init__( ]: raise RedisError("Client caching is only supported with RESP version 3") - # TODO: To avoid breaking changes during the bug fix, we have to keep non-reentrant lock. - # TODO: Remove this before next major version (7.0.0) - self.single_connection_lock = threading.Lock() + self.single_connection_lock = threading.RLock() self.connection = None self._single_connection_client = single_connection_client if self._single_connection_client: @@ -776,9 +774,7 @@ def __init__( else: self._event_dispatcher = event_dispatcher - # TODO: To avoid breaking changes during the bug fix, we have to keep non-reentrant lock. - # TODO: Remove this before next major version (7.0.0) - self._lock = threading.Lock() + self._lock = threading.RLock() if self.encoder is None: self.encoder = self.connection_pool.get_encoder() self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE) diff --git a/redis/connection.py b/redis/connection.py index 131ae68c61..d457b1015c 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -810,7 +810,7 @@ def __init__( self, conn: ConnectionInterface, cache: CacheInterface, - pool_lock: threading.Lock, + pool_lock: threading.RLock, ): self.pid = os.getpid() self._conn = conn @@ -1422,13 +1422,7 @@ def __init__( # release the lock. self._fork_lock = threading.RLock() - - if self.cache is None: - self._lock = threading.RLock() - else: - # TODO: To avoid breaking changes during the bug fix, we have to keep non-reentrant lock. - # TODO: Remove this before next major version (7.0.0) - self._lock = threading.Lock() + self._lock = threading.RLock() self.reset() diff --git a/redis/event.py b/redis/event.py index 5cc6c0017c..b86c66b082 100644 --- a/redis/event.py +++ b/redis/event.py @@ -152,7 +152,7 @@ def __init__( self, connection, client_type: ClientType, - connection_lock: Union[threading.Lock, asyncio.Lock], + connection_lock: Union[threading.RLock, asyncio.Lock], ): self._connection = connection self._client_type = client_type @@ -167,7 +167,7 @@ def client_type(self) -> ClientType: return self._client_type @property - def connection_lock(self) -> Union[threading.Lock, asyncio.Lock]: + def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]: return self._connection_lock @@ -177,7 +177,7 @@ def __init__( pubsub_connection, connection_pool, client_type: ClientType, - connection_lock: Union[threading.Lock, asyncio.Lock], + connection_lock: Union[threading.RLock, asyncio.Lock], ): self._pubsub_connection = pubsub_connection self._connection_pool = connection_pool @@ -197,7 +197,7 @@ def client_type(self) -> ClientType: return self._client_type @property - def connection_lock(self) -> Union[threading.Lock, asyncio.Lock]: + def connection_lock(self) -> Union[threading.RLock, asyncio.Lock]: return self._connection_lock diff --git a/tests/test_cluster_transaction.py b/tests/test_cluster_transaction.py index cd43441f4c..6ebd6df566 100644 --- a/tests/test_cluster_transaction.py +++ b/tests/test_cluster_transaction.py @@ -285,7 +285,7 @@ def test_retry_transaction_on_connection_error(self, r, mock_connection): mock_pool = Mock(spec=ConnectionPool) mock_pool.get_connection.return_value = mock_connection mock_pool._available_connections = [mock_connection] - mock_pool._lock = threading.Lock() + mock_pool._lock = threading.RLock() _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) node_importing.redis_connection.connection_pool = mock_pool @@ -310,7 +310,7 @@ def test_retry_transaction_on_connection_error_with_watched_keys( mock_pool = Mock(spec=ConnectionPool) mock_pool.get_connection.return_value = mock_connection mock_pool._available_connections = [mock_connection] - mock_pool._lock = threading.Lock() + mock_pool._lock = threading.RLock() _node_migrating, node_importing = _find_source_and_target_node_for_slot(r, slot) node_importing.redis_connection.connection_pool = mock_pool diff --git a/tests/test_connection.py b/tests/test_connection.py index 9664146ce5..89ea04df75 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -442,7 +442,7 @@ def test_clears_cache_on_disconnect(self, mock_connection, cache_conf): mock_connection.credential_provider = UsernamePasswordCredentialProvider() proxy_connection = CacheProxyConnection( - mock_connection, cache, threading.Lock() + mock_connection, cache, threading.RLock() ) proxy_connection.disconnect() @@ -492,7 +492,7 @@ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection): mock_connection.can_read.return_value = False proxy_connection = CacheProxyConnection( - mock_connection, mock_cache, threading.Lock() + mock_connection, mock_cache, threading.RLock() ) proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) assert proxy_connection.read_response() == b"bar" @@ -554,7 +554,7 @@ def test_triggers_invalidation_processing_on_another_connection( mock_connection.can_read.return_value = False proxy_connection = CacheProxyConnection( - mock_connection, mock_cache, threading.Lock() + mock_connection, mock_cache, threading.RLock() ) proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]}) diff --git a/tests/test_credentials.py b/tests/test_credentials.py index 58bbd01f28..8b8e0cfc2c 100644 --- a/tests/test_credentials.py +++ b/tests/test_credentials.py @@ -323,7 +323,7 @@ def test_re_auth_all_connections(self, credential_provider): } mock_pool.get_connection.return_value = mock_connection mock_pool._available_connections = [mock_connection, mock_another_connection] - mock_pool._lock = threading.Lock() + mock_pool._lock = threading.RLock() auth_token = None def re_auth_callback(token): @@ -382,7 +382,7 @@ def test_re_auth_partial_connections(self, credential_provider): mock_another_connection, mock_failed_connection, ] - mock_pool._lock = threading.Lock() + mock_pool._lock = threading.RLock() def _raise(error: RedisError): pass @@ -442,7 +442,7 @@ def test_re_auth_pub_sub_in_resp3(self, credential_provider): mock_another_connection, ] mock_pool._available_connections = [mock_another_connection] - mock_pool._lock = threading.Lock() + mock_pool._lock = threading.RLock() auth_token = None def re_auth_callback(token): @@ -502,7 +502,7 @@ def test_do_not_re_auth_pub_sub_in_resp2(self, credential_provider): mock_another_connection, ] mock_pool._available_connections = [mock_another_connection] - mock_pool._lock = threading.Lock() + mock_pool._lock = threading.RLock() auth_token = None def re_auth_callback(token): @@ -560,7 +560,7 @@ def test_fails_on_token_renewal(self, credential_provider): } mock_pool.get_connection.return_value = mock_connection mock_pool._available_connections = [mock_connection, mock_another_connection] - mock_pool._lock = threading.Lock() + mock_pool._lock = threading.RLock() Redis( connection_pool=mock_pool,