@@ -23,6 +23,54 @@ class SlaveNotFoundError(ConnectionError):
2323 pass
2424
2525
26+ class AsyncSentinelConnectionPoolProxy :
27+ def __init__ (
28+ self ,
29+ connection_pool ,
30+ is_master ,
31+ check_connection ,
32+ service_name ,
33+ sentinel_manager ,
34+ ):
35+ self .connection_pool_ref = weakref .ref (connection_pool )
36+ self .is_master = is_master
37+ self .check_connection = check_connection
38+ self .service_name = service_name
39+ self .sentinel_manager = sentinel_manager
40+ self .reset ()
41+
42+ def reset (self ):
43+ self .master_address = None
44+ self .slave_rr_counter = None
45+
46+ async def get_master_address (self ):
47+ master_address = await self .sentinel_manager .discover_master (self .service_name )
48+ if self .is_master and self .master_address != master_address :
49+ self .master_address = master_address
50+ # disconnect any idle connections so that they reconnect
51+ # to the new master the next time that they are used.
52+ connection_pool = self .connection_pool_ref ()
53+ if connection_pool is not None :
54+ await connection_pool .disconnect (inuse_connections = False )
55+ return master_address
56+
57+ async def rotate_slaves (self ) -> AsyncIterator :
58+ slaves = await self .sentinel_manager .discover_slaves (self .service_name )
59+ if slaves :
60+ if self .slave_rr_counter is None :
61+ self .slave_rr_counter = random .randint (0 , len (slaves ) - 1 )
62+ for _ in range (len (slaves )):
63+ self .slave_rr_counter = (self .slave_rr_counter + 1 ) % len (slaves )
64+ slave = slaves [self .slave_rr_counter ]
65+ yield slave
66+ # Fallback to the master connection
67+ try :
68+ yield await self .get_master_address ()
69+ except MasterNotFoundError :
70+ pass
71+ raise SlaveNotFoundError (f"No slave found for { self .service_name !r} " )
72+
73+
2674class SentinelManagedConnection (Connection ):
2775 def __init__ (self , ** kwargs ):
2876 self .connection_pool = kwargs .pop ("connection_pool" )
@@ -116,12 +164,17 @@ def __init__(self, service_name, sentinel_manager, **kwargs):
116164 )
117165 self .is_master = kwargs .pop ("is_master" , True )
118166 self .check_connection = kwargs .pop ("check_connection" , False )
167+ self .proxy = AsyncSentinelConnectionPoolProxy (
168+ connection_pool = self ,
169+ is_master = self .is_master ,
170+ check_connection = self .check_connection ,
171+ service_name = service_name ,
172+ sentinel_manager = sentinel_manager ,
173+ )
119174 super ().__init__ (** kwargs )
120- self .connection_kwargs ["connection_pool" ] = weakref .proxy ( self )
175+ self .connection_kwargs ["connection_pool" ] = self .proxy
121176 self .service_name = service_name
122177 self .sentinel_manager = sentinel_manager
123- self .master_address = None
124- self .slave_rr_counter = None
125178
126179 def __repr__ (self ):
127180 return (
@@ -131,8 +184,11 @@ def __repr__(self):
131184
132185 def reset (self ):
133186 super ().reset ()
134- self .master_address = None
135- self .slave_rr_counter = None
187+ self .proxy .reset ()
188+
189+ @property
190+ def master_address (self ):
191+ return self .proxy .master_address
136192
137193 def owns_connection (self , connection : Connection ):
138194 check = not self .is_master or (
@@ -141,31 +197,12 @@ def owns_connection(self, connection: Connection):
141197 return check and super ().owns_connection (connection )
142198
143199 async def get_master_address (self ):
144- master_address = await self .sentinel_manager .discover_master (self .service_name )
145- if self .is_master :
146- if self .master_address != master_address :
147- self .master_address = master_address
148- # disconnect any idle connections so that they reconnect
149- # to the new master the next time that they are used.
150- await self .disconnect (inuse_connections = False )
151- return master_address
200+ return await self .proxy .get_master_address ()
152201
153202 async def rotate_slaves (self ) -> AsyncIterator :
154203 """Round-robin slave balancer"""
155- slaves = await self .sentinel_manager .discover_slaves (self .service_name )
156- if slaves :
157- if self .slave_rr_counter is None :
158- self .slave_rr_counter = random .randint (0 , len (slaves ) - 1 )
159- for _ in range (len (slaves )):
160- self .slave_rr_counter = (self .slave_rr_counter + 1 ) % len (slaves )
161- slave = slaves [self .slave_rr_counter ]
162- yield slave
163- # Fallback to the master connection
164- try :
165- yield await self .get_master_address ()
166- except MasterNotFoundError :
167- pass
168- raise SlaveNotFoundError (f"No slave found for { self .service_name !r} " )
204+ async for x in self .proxy .rotate_slaves ():
205+ yield x
169206
170207
171208class Sentinel (AsyncSentinelCommands ):
0 commit comments