From cceec10e5a86c2b866b283b2c6c257787cd749af Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Fri, 8 Nov 2019 16:10:38 -0800 Subject: [PATCH 1/2] Add equality test on Redis client and conn pool Currently, there's no `__eq__()` method defined on the `Redis` or `ConnectionPool` classes. Therefore, no two instances of either of these classes will ever be equal. I have a use case where it would be nice for one Redis client instance to be equal to another if they have the same connection kwargs; that is, if they're connected to the same Redis database. --- redis/client.py | 3 +++ redis/connection.py | 3 +++ tests/test_client.py | 23 +++++++++++++++++++++++ tests/test_connection_pool.py | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+) create mode 100644 tests/test_client.py diff --git a/redis/client.py b/redis/client.py index 5adaa7ca65..ffc805a2b3 100755 --- a/redis/client.py +++ b/redis/client.py @@ -707,6 +707,9 @@ def __init__(self, host='localhost', port=6379, def __repr__(self): return "%s<%s>" % (type(self).__name__, repr(self.connection_pool)) + def __eq__(self, other): + return self.connection_pool == other.connection_pool + def set_response_callback(self, command, callback): "Set a custom Response Callback" self.response_callbacks[command] = callback diff --git a/redis/connection.py b/redis/connection.py index feea041288..d60b2cf039 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1044,6 +1044,9 @@ def __repr__(self): repr(self.connection_class(**self.connection_kwargs)), ) + def __eq__(self, other): + return self.connection_kwargs == other.connection_kwargs + def reset(self): self.pid = os.getpid() self._created_connections = 0 diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000000..f9c0c89dde --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,23 @@ +import redis + + +class TestClient(object): + def test_client_equality(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://localhost:6379/9') + assert r1 == r2 + + def test_clients_unequal_if_different_hosts(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://127.0.0.1:6379/9') + assert r1 != r2 + + def test_clients_unequal_if_different_ports(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://localhost:6380/9') + assert r1 != r2 + + def test_clients_unequal_if_different_dbs(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://localhost:6380/10') + assert r1 != r2 diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index f580f71d52..9cdae5fef8 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -77,6 +77,41 @@ def test_repr_contains_db_info_unix(self): expected = 'ConnectionPool>' assert repr(pool) == expected + def test_pool_equality(self): + connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1} + pool1 = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + assert pool1 == pool2 + + def test_pools_unequal_if_different_hosts(self): + connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs2 = {'host': '127.0.0.1', 'port': 6379, 'db': 1} + pool1 = self.get_pool(connection_kwargs=connection_kwargs1, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs2, + connection_class=redis.Connection) + assert pool1 != pool2 + + def test_pools_unequal_if_different_ports(self): + connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs2 = {'host': 'localhost', 'port': 6380, 'db': 1} + pool1 = self.get_pool(connection_kwargs=connection_kwargs1, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs2, + connection_class=redis.Connection) + assert pool1 != pool2 + + def test_pools_unequal_if_different_dbs(self): + connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs2 = {'host': 'localhost', 'port': 6379, 'db': 2} + pool1 = self.get_pool(connection_kwargs=connection_kwargs1, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs2, + connection_class=redis.Connection) + assert pool1 != pool2 + class TestBlockingConnectionPool(object): def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): From 2f4e2ad6f3694c00ab31dca708390047ce18a8f6 Mon Sep 17 00:00:00 2001 From: Raj Shah Date: Fri, 8 Nov 2019 16:37:02 -0800 Subject: [PATCH 2/2] Make __eq__() methods able to accept other types As of this change, `connection_pool != None` and `redis != None` both work, and both return `False`. --- redis/client.py | 5 ++++- redis/connection.py | 5 ++++- tests/test_client.py | 4 ++++ tests/test_connection_pool.py | 6 ++++++ 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/redis/client.py b/redis/client.py index ffc805a2b3..d77ef0121e 100755 --- a/redis/client.py +++ b/redis/client.py @@ -708,7 +708,10 @@ def __repr__(self): return "%s<%s>" % (type(self).__name__, repr(self.connection_pool)) def __eq__(self, other): - return self.connection_pool == other.connection_pool + return ( + isinstance(other, self.__class__) + and self.connection_pool == other.connection_pool + ) def set_response_callback(self, command, callback): "Set a custom Response Callback" diff --git a/redis/connection.py b/redis/connection.py index d60b2cf039..44a9922b43 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1045,7 +1045,10 @@ def __repr__(self): ) def __eq__(self, other): - return self.connection_kwargs == other.connection_kwargs + return ( + isinstance(other, self.__class__) + and self.connection_kwargs == other.connection_kwargs + ) def reset(self): self.pid = os.getpid() diff --git a/tests/test_client.py b/tests/test_client.py index f9c0c89dde..e8f79b1be5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -7,6 +7,10 @@ def test_client_equality(self): r2 = redis.Redis.from_url('redis://localhost:6379/9') assert r1 == r2 + def test_clients_unequal_if_different_types(self): + r = redis.Redis.from_url('redis://localhost:6379/9') + assert r != 0 + def test_clients_unequal_if_different_hosts(self): r1 = redis.Redis.from_url('redis://localhost:6379/9') r2 = redis.Redis.from_url('redis://127.0.0.1:6379/9') diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 9cdae5fef8..406b5dbc5b 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -85,6 +85,12 @@ def test_pool_equality(self): connection_class=redis.Connection) assert pool1 == pool2 + def test_pools_unequal_if_different_types(self): + connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1} + pool = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + assert pool != 0 + def test_pools_unequal_if_different_hosts(self): connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} connection_kwargs2 = {'host': '127.0.0.1', 'port': 6379, 'db': 1}