diff --git a/redis/connection.py b/redis/connection.py index 6ff3650805..49a202e3af 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -1084,6 +1084,10 @@ class initializer. In the case of conflicting arguments, querystring arguments always win. """ url_options = parse_url(url) + + if "connection_class" in kwargs: + url_options["connection_class"] = kwargs["connection_class"] + kwargs.update(url_options) return cls(**kwargs) diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 288d43dfd7..b423835ec8 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -430,6 +430,15 @@ def test_extra_querystring_options(self): 'b': '2' } + def test_connection_class_override(self): + class MyConnection(redis.UnixDomainSocketConnection): + pass + + pool = redis.ConnectionPool.from_url( + 'unix:///socket', connection_class=MyConnection + ) + assert pool.connection_class == MyConnection + @pytest.mark.skipif(not ssl_available, reason="SSL not installed") class TestSSLConnectionURLParsing: @@ -440,6 +449,15 @@ def test_host(self): 'host': 'my.host', } + def test_connection_class_override(self): + class MyConnection(redis.SSLConnection): + pass + + pool = redis.ConnectionPool.from_url( + 'rediss://my.host', connection_class=MyConnection + ) + assert pool.connection_class == MyConnection + def test_cert_reqs_options(self): import ssl