diff --git a/redis/client.py b/redis/client.py index 5adaa7ca65..d77ef0121e 100755 --- a/redis/client.py +++ b/redis/client.py @@ -707,6 +707,12 @@ def __init__(self, host='localhost', port=6379, def __repr__(self): return "%s<%s>" % (type(self).__name__, repr(self.connection_pool)) + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.connection_pool == other.connection_pool + ) + def set_response_callback(self, command, callback): "Set a custom Response Callback" self.response_callbacks[command] = callback diff --git a/redis/connection.py b/redis/connection.py index feea041288..44a9922b43 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1044,6 +1044,12 @@ def __repr__(self): repr(self.connection_class(**self.connection_kwargs)), ) + def __eq__(self, other): + return ( + isinstance(other, self.__class__) + and self.connection_kwargs == other.connection_kwargs + ) + def reset(self): self.pid = os.getpid() self._created_connections = 0 diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000000..e8f79b1be5 --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,27 @@ +import redis + + +class TestClient(object): + def test_client_equality(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://localhost:6379/9') + assert r1 == r2 + + def test_clients_unequal_if_different_types(self): + r = redis.Redis.from_url('redis://localhost:6379/9') + assert r != 0 + + def test_clients_unequal_if_different_hosts(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://127.0.0.1:6379/9') + assert r1 != r2 + + def test_clients_unequal_if_different_ports(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://localhost:6380/9') + assert r1 != r2 + + def test_clients_unequal_if_different_dbs(self): + r1 = redis.Redis.from_url('redis://localhost:6379/9') + r2 = redis.Redis.from_url('redis://localhost:6380/10') + assert r1 != r2 diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index f580f71d52..406b5dbc5b 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -77,6 +77,47 @@ def test_repr_contains_db_info_unix(self): expected = 'ConnectionPool>' assert repr(pool) == expected + def test_pool_equality(self): + connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1} + pool1 = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + assert pool1 == pool2 + + def test_pools_unequal_if_different_types(self): + connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1} + pool = self.get_pool(connection_kwargs=connection_kwargs, + connection_class=redis.Connection) + assert pool != 0 + + def test_pools_unequal_if_different_hosts(self): + connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs2 = {'host': '127.0.0.1', 'port': 6379, 'db': 1} + pool1 = self.get_pool(connection_kwargs=connection_kwargs1, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs2, + connection_class=redis.Connection) + assert pool1 != pool2 + + def test_pools_unequal_if_different_ports(self): + connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs2 = {'host': 'localhost', 'port': 6380, 'db': 1} + pool1 = self.get_pool(connection_kwargs=connection_kwargs1, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs2, + connection_class=redis.Connection) + assert pool1 != pool2 + + def test_pools_unequal_if_different_dbs(self): + connection_kwargs1 = {'host': 'localhost', 'port': 6379, 'db': 1} + connection_kwargs2 = {'host': 'localhost', 'port': 6379, 'db': 2} + pool1 = self.get_pool(connection_kwargs=connection_kwargs1, + connection_class=redis.Connection) + pool2 = self.get_pool(connection_kwargs=connection_kwargs2, + connection_class=redis.Connection) + assert pool1 != pool2 + class TestBlockingConnectionPool(object): def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):