Skip to content

Commit 0168b44

Browse files
committed
Thread-safe connection pool
redis/redis-py#909
1 parent 36d9d21 commit 0168b44

File tree

3 files changed

+213
-45
lines changed

3 files changed

+213
-45
lines changed

README.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ Change log
9090
v2.1.0
9191
~~~~~~
9292

93+
* Thread-safe implementation of the sentinel connection pool, so only one pool per process is now used.
9394
* Added `disconnect()` method for resetting the connection pool
9495

9596
v2.0.1

flask_redis_sentinel.py

Lines changed: 210 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,175 @@
1414

1515
import six
1616
import inspect
17+
import random
18+
import threading
19+
import logging
20+
import weakref
1721
import redis
1822
import redis.sentinel
1923
import redis_sentinel_url
24+
from redis._compat import nativestr
2025
from flask import current_app
21-
from werkzeug.local import Local, LocalProxy
26+
from redis.exceptions import ConnectionError, ReadOnlyError
27+
from werkzeug.local import LocalProxy
2228
from werkzeug.utils import import_string
23-
from six.moves import queue
29+
30+
logger = logging.getLogger(__name__)
2431

2532

2633
_EXTENSION_KEY = 'redissentinel'
2734

2835

36+
class SentinelManagedConnection(redis.Connection):
37+
def __init__(self, **kwargs):
38+
self.connection_pool = kwargs.pop('connection_pool')
39+
super(SentinelManagedConnection, self).__init__(**kwargs)
40+
41+
def __repr__(self):
42+
pool = self.connection_pool
43+
s = '%s<service=%s%%s>' % (type(self).__name__, pool.service_name)
44+
if self.host:
45+
host_info = ',host=%s,port=%s' % (self.host, self.port)
46+
s = s % host_info
47+
return s
48+
49+
def connect_to(self, address):
50+
self.host, self.port = address
51+
super(SentinelManagedConnection, self).connect()
52+
if self.connection_pool.check_connection:
53+
self.send_command('PING')
54+
if nativestr(self.read_response()) != 'PONG':
55+
raise ConnectionError('PING failed')
56+
57+
def connect(self):
58+
if self._sock:
59+
return # already connected
60+
if self.connection_pool.is_master:
61+
self.connect_to(self.connection_pool.get_master_address())
62+
else:
63+
for slave in self.connection_pool.rotate_slaves():
64+
try:
65+
return self.connect_to(slave)
66+
except ConnectionError:
67+
continue
68+
raise SlaveNotFoundError # Never be here
69+
70+
def read_response(self):
71+
try:
72+
return super(SentinelManagedConnection, self).read_response()
73+
except ReadOnlyError:
74+
if self.connection_pool.is_master:
75+
# When talking to a master, a ReadOnlyError when likely
76+
# indicates that the previous master that we're still connected
77+
# to has been demoted to a slave and there's a new master.
78+
raise ConnectionError('The previous master is now a slave')
79+
raise
80+
81+
82+
class SentinelConnectionPool(redis.ConnectionPool):
83+
"""
84+
Sentinel backed connection pool.
85+
86+
If ``check_connection`` flag is set to True, SentinelManagedConnection
87+
sends a PING command right after establishing the connection.
88+
"""
89+
90+
def __init__(self, service_name, sentinel_manager, **kwargs):
91+
kwargs['connection_class'] = kwargs.get(
92+
'connection_class', SentinelManagedConnection)
93+
self.is_master = kwargs.pop('is_master', True)
94+
self.check_connection = kwargs.pop('check_connection', False)
95+
super(SentinelConnectionPool, self).__init__(**kwargs)
96+
self.connection_kwargs['connection_pool'] = weakref.proxy(self)
97+
self.service_name = service_name
98+
self.sentinel_manager = sentinel_manager
99+
100+
def __repr__(self):
101+
return "%s<service=%s(%s)" % (
102+
type(self).__name__,
103+
self.service_name,
104+
self.is_master and 'master' or 'slave',
105+
)
106+
107+
def reset(self):
108+
super(SentinelConnectionPool, self).reset()
109+
self.master_address = None
110+
self.slave_rr_counter = None
111+
112+
def get_master_address(self):
113+
"""Get the address of the current master"""
114+
master_address = self.sentinel_manager.discover_master(
115+
self.service_name)
116+
if self.is_master:
117+
if master_address != self.master_address:
118+
self.master_address = master_address
119+
return master_address
120+
121+
def rotate_slaves(self):
122+
"Round-robin slave balancer"
123+
slaves = self.sentinel_manager.discover_slaves(self.service_name)
124+
if slaves:
125+
if self.slave_rr_counter is None:
126+
self.slave_rr_counter = random.randint(0, len(slaves) - 1)
127+
for _ in xrange(len(slaves)):
128+
self.slave_rr_counter = (
129+
self.slave_rr_counter + 1) % len(slaves)
130+
slave = slaves[self.slave_rr_counter]
131+
yield slave
132+
# Fallback to the master connection
133+
try:
134+
yield self.get_master_address()
135+
except MasterNotFoundError:
136+
pass
137+
raise SlaveNotFoundError('No slave found for %r' % (self.service_name))
138+
139+
def _check_connection(self, connection):
140+
if self.is_master and self.master_address != (connection.host, connection.port):
141+
# this is not a connection to the current master, stop using it
142+
connection.disconnect()
143+
return False
144+
return True
145+
146+
def get_connection(self, command_name, *keys, **options):
147+
"Get a connection from the pool"
148+
self._checkpid()
149+
while True:
150+
try:
151+
connection = self._available_connections.pop()
152+
except IndexError:
153+
connection = self.make_connection()
154+
else:
155+
if not self._check_connection(connection):
156+
continue
157+
self._in_use_connections.add(connection)
158+
return connection
159+
160+
def release(self, connection):
161+
"Releases the connection back to the pool"
162+
self._checkpid()
163+
if connection.pid != self.pid:
164+
return
165+
self._in_use_connections.remove(connection)
166+
if not self._check_connection(connection):
167+
return
168+
self._available_connections.append(connection)
169+
170+
171+
class Sentinel(redis.sentinel.Sentinel):
172+
173+
def master_for(self, service_name, redis_class=redis.StrictRedis,
174+
connection_pool_class=SentinelConnectionPool, **kwargs):
175+
return super(Sentinel, self).master_for(
176+
service_name, redis_class=redis_class,
177+
connection_pool_class=connection_pool_class, **kwargs)
178+
179+
def slave_for(self, service_name, redis_class=redis.StrictRedis,
180+
connection_pool_class=SentinelConnectionPool, **kwargs):
181+
return super(Sentinel, self).slave_for(
182+
service_name, redis_class=redis_class,
183+
connection_pool_class=connection_pool_class, **kwargs)
184+
185+
29186
class RedisSentinelInstance(object):
30187

31188
def __init__(self, url, client_class, client_options, sentinel_class, sentinel_options):
@@ -34,27 +191,34 @@ def __init__(self, url, client_class, client_options, sentinel_class, sentinel_o
34191
self.client_options = client_options
35192
self.sentinel_class = sentinel_class
36193
self.sentinel_options = sentinel_options
37-
self.local = Local()
38-
self._connections = queue.Queue()
194+
self.connection = None
195+
self.master_connections = {}
196+
self.slave_connections = {}
197+
self._connect_lock = threading.Lock()
39198
self._connect()
40-
if self.local.connection[0] is None:
41-
# if there is no sentinel, we don't need to use thread-local storage
42-
self.connection = self.local.connection
43-
self.local = self
44199

45200
def _connect(self):
46-
try:
47-
return self.local.connection
48-
except AttributeError:
201+
with self._connect_lock:
202+
if self.connection is not None:
203+
return self.connection
204+
49205
conn = redis_sentinel_url.connect(
50206
self.url,
51207
sentinel_class=self.sentinel_class, sentinel_options=self.sentinel_options,
52208
client_class=self.client_class, client_options=self.client_options)
53-
self.local.connection = conn
54-
self._connections.put(conn[0])
55-
self._connections.put(conn[1])
209+
self.connection = conn
56210
return conn
57211

212+
def _iter_connections(self):
213+
if self.connection is not None:
214+
for conn in self.connection:
215+
if conn is not None:
216+
yield conn
217+
for conn in six.itervalues(self.master_connections):
218+
yield conn
219+
for conn in six.itervalues(self.slave_connections):
220+
yield conn
221+
58222
@property
59223
def sentinel(self):
60224
return self._connect()[0]
@@ -64,50 +228,53 @@ def default_connection(self):
64228
return self._connect()[1]
65229

66230
def master_for(self, service_name, **kwargs):
67-
try:
68-
return self.local.master_connections[service_name]
69-
except AttributeError:
70-
self.local.master_connections = {}
71-
except KeyError:
72-
pass
231+
with self._connect_lock:
232+
try:
233+
return self.master_connections[service_name]
234+
except KeyError:
235+
pass
73236

74237
sentinel = self.sentinel
75238
if sentinel is None:
76239
msg = 'Cannot get master {} using non-sentinel configuration'
77240
raise RuntimeError(msg.format(service_name))
78241

79-
conn = sentinel.master_for(service_name, redis_class=self.client_class, **kwargs)
80-
self.local.master_connections[service_name] = conn
81-
self._connections.put(conn)
82-
return conn
242+
with self._connect_lock:
243+
try:
244+
return self.master_connections[service_name]
245+
except KeyError:
246+
pass
247+
248+
conn = sentinel.master_for(service_name, redis_class=self.client_class, **kwargs)
249+
self.master_connections[service_name] = conn
250+
return conn
83251

84252
def slave_for(self, service_name, **kwargs):
85-
try:
86-
return self.local.slave_connections[service_name]
87-
except AttributeError:
88-
self.local.slave_connections = {}
89-
except KeyError:
90-
pass
253+
with self._connect_lock:
254+
try:
255+
return self.slave_connections[service_name]
256+
except KeyError:
257+
pass
91258

92259
sentinel = self.sentinel
93260
if sentinel is None:
94261
msg = 'Cannot get slave {} using non-sentinel configuration'
95262
raise RuntimeError(msg.format(service_name))
96263

97-
conn = sentinel.slave_for(service_name, redis_class=self.client_class, **kwargs)
98-
self.local.slave_connections[service_name] = conn
99-
self._connections.put(conn)
100-
return conn
264+
with self._connect_lock:
265+
try:
266+
return self.slave_connections[service_name]
267+
except KeyError:
268+
pass
269+
270+
conn = sentinel.slave_for(service_name, redis_class=self.client_class, **kwargs)
271+
self.slave_connections[service_name] = conn
272+
return conn
101273

102274
def disconnect(self):
103-
while True:
104-
try:
105-
conn = self._connections.get_nowait()
106-
except queue.Empty:
107-
break
108-
else:
109-
if conn is not None:
110-
conn.connection_pool.disconnect()
275+
with self._connect_lock:
276+
for conn in self._iter_connections():
277+
conn.connection_pool.disconnect()
111278

112279

113280
class RedisSentinel(object):
@@ -143,7 +310,7 @@ def init_app(self, app, config_prefix=None, client_class=None, sentinel_class=No
143310
client_class = self._resolve_class(
144311
config, 'CLASS', 'client_class', client_class, redis.StrictRedis)
145312
sentinel_class = self._resolve_class(
146-
config, 'SENTINEL_CLASS', 'sentinel_class', sentinel_class, redis.sentinel.Sentinel)
313+
config, 'SENTINEL_CLASS', 'sentinel_class', sentinel_class, Sentinel)
147314

148315
url = config.pop('URL')
149316
client_options = self._config_from_variables(config, client_class)

test_flask_redis_sentinel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,8 @@ def test_sentinel_threads(self):
302302
sentinel.init_app(self.app)
303303

304304
connections = self._check_threads(sentinel)
305-
self.assertIsNot(connections['from_another_thread'], connections['from_main_thread'])
306-
self.assertIsNot(connections['from_another_thread'], connections['from_main_thread_later'])
305+
self.assertIs(connections['from_another_thread'], connections['from_main_thread'])
306+
self.assertIs(connections['from_another_thread'], connections['from_main_thread_later'])
307307
self.assertIs(connections['from_main_thread'], connections['from_main_thread_later'])
308308

309309
def test_redis_threads(self):

0 commit comments

Comments
 (0)