17
17
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
18
# See the License for the specific language governing permissions and
19
19
# limitations under the License.
20
-
21
-
20
+ from abc import abstractmethod
21
+ from sys import maxsize
22
22
from threading import Lock
23
23
from time import clock
24
24
25
25
from neo4j .addressing import SocketAddress , resolve
26
26
from neo4j .bolt import ConnectionPool , ServiceUnavailable , ProtocolError , DEFAULT_PORT , connect
27
27
from neo4j .compat .collections import MutableSet , OrderedDict
28
28
from neo4j .exceptions import CypherError
29
+ from neo4j .util import ServerVersion
29
30
from neo4j .v1 .api import Driver , READ_ACCESS , WRITE_ACCESS , fix_statement , fix_parameters
30
31
from neo4j .v1 .exceptions import SessionExpired
31
32
from neo4j .v1 .security import SecurityPlan
32
33
from neo4j .v1 .session import BoltSession
33
- from neo4j .util import ServerVersion
34
34
35
35
36
- class RoundRobinSet (MutableSet ):
36
+ LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
37
+ LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
38
+ LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
39
+
40
+
41
+ class OrderedSet (MutableSet ):
37
42
38
43
def __init__ (self , elements = ()):
39
44
self ._elements = OrderedDict .fromkeys (elements )
@@ -45,22 +50,15 @@ def __repr__(self):
45
50
def __contains__ (self , element ):
46
51
return element in self ._elements
47
52
48
- def __next__ (self ):
49
- current = None
50
- if self ._elements :
51
- if self ._current is None :
52
- self ._current = 0
53
- else :
54
- self ._current = (self ._current + 1 ) % len (self ._elements )
55
- current = list (self ._elements .keys ())[self ._current ]
56
- return current
57
-
58
53
def __iter__ (self ):
59
54
return iter (self ._elements )
60
55
61
56
def __len__ (self ):
62
57
return len (self ._elements )
63
58
59
+ def __getitem__ (self , index ):
60
+ return list (self ._elements .keys ())[index ]
61
+
64
62
def add (self , element ):
65
63
self ._elements [element ] = None
66
64
@@ -73,9 +71,6 @@ def discard(self, element):
73
71
except KeyError :
74
72
pass
75
73
76
- def next (self ):
77
- return self .__next__ ()
78
-
79
74
def remove (self , element ):
80
75
try :
81
76
del self ._elements [element ]
@@ -126,9 +121,9 @@ def parse_routing_info(cls, records):
126
121
return cls (routers , readers , writers , ttl )
127
122
128
123
def __init__ (self , routers = (), readers = (), writers = (), ttl = 0 ):
129
- self .routers = RoundRobinSet (routers )
130
- self .readers = RoundRobinSet (readers )
131
- self .writers = RoundRobinSet (writers )
124
+ self .routers = OrderedSet (routers )
125
+ self .readers = OrderedSet (readers )
126
+ self .writers = OrderedSet (writers )
132
127
self .last_updated_time = self .timer ()
133
128
self .ttl = ttl
134
129
@@ -168,17 +163,102 @@ def __run__(self, ignored, routing_context):
168
163
return self ._run (fix_statement (statement ), fix_parameters (parameters ))
169
164
170
165
166
+ class LoadBalancingStrategy (object ):
167
+
168
+ @classmethod
169
+ def build (cls , connection_pool , ** config ):
170
+ load_balancing_strategy = config .get ("load_balancing_strategy" , LOAD_BALANCING_STRATEGY_DEFAULT )
171
+ if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED :
172
+ return LeastConnectedLoadBalancingStrategy (connection_pool )
173
+ elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN :
174
+ return RoundRobinLoadBalancingStrategy ()
175
+ else :
176
+ raise ValueError ("Unknown load balancing strategy '%s'" % load_balancing_strategy )
177
+
178
+ @abstractmethod
179
+ def select_reader (self , known_readers ):
180
+ raise NotImplementedError ()
181
+
182
+ @abstractmethod
183
+ def select_writer (self , known_writers ):
184
+ raise NotImplementedError ()
185
+
186
+
187
+ class RoundRobinLoadBalancingStrategy (LoadBalancingStrategy ):
188
+
189
+ _readers_offset = 0
190
+ _writers_offset = 0
191
+
192
+ def select_reader (self , known_readers ):
193
+ address = self ._select (self ._readers_offset , known_readers )
194
+ self ._readers_offset += 1
195
+ return address
196
+
197
+ def select_writer (self , known_writers ):
198
+ address = self ._select (self ._writers_offset , known_writers )
199
+ self ._writers_offset += 1
200
+ return address
201
+
202
+ @classmethod
203
+ def _select (cls , offset , addresses ):
204
+ if not addresses :
205
+ return None
206
+ return addresses [offset % len (addresses )]
207
+
208
+
209
+ class LeastConnectedLoadBalancingStrategy (LoadBalancingStrategy ):
210
+
211
+ def __init__ (self , connection_pool ):
212
+ self ._readers_offset = 0
213
+ self ._writers_offset = 0
214
+ self ._connection_pool = connection_pool
215
+
216
+ def select_reader (self , known_readers ):
217
+ address = self ._select (self ._readers_offset , known_readers )
218
+ self ._readers_offset += 1
219
+ return address
220
+
221
+ def select_writer (self , known_writers ):
222
+ address = self ._select (self ._writers_offset , known_writers )
223
+ self ._writers_offset += 1
224
+ return address
225
+
226
+ def _select (self , offset , addresses ):
227
+ if not addresses :
228
+ return None
229
+ num_addresses = len (addresses )
230
+ start_index = offset % num_addresses
231
+ index = start_index
232
+
233
+ least_connected_address = None
234
+ least_in_use_connections = maxsize
235
+
236
+ while True :
237
+ address = addresses [index ]
238
+ index = (index + 1 ) % num_addresses
239
+
240
+ in_use_connections = self ._connection_pool .in_use_connection_count (address )
241
+
242
+ if in_use_connections < least_in_use_connections :
243
+ least_connected_address = address
244
+ least_in_use_connections = in_use_connections
245
+
246
+ if index == start_index :
247
+ return least_connected_address
248
+
249
+
171
250
class RoutingConnectionPool (ConnectionPool ):
172
251
""" Connection pool with routing table.
173
252
"""
174
253
175
- def __init__ (self , connector , initial_address , routing_context , * routers ):
254
+ def __init__ (self , connector , initial_address , routing_context , * routers , ** config ):
176
255
super (RoutingConnectionPool , self ).__init__ (connector )
177
256
self .initial_address = initial_address
178
257
self .routing_context = routing_context
179
258
self .routing_table = RoutingTable (routers )
180
259
self .missing_writer = False
181
260
self .refresh_lock = Lock ()
261
+ self .load_balancing_strategy = LoadBalancingStrategy .build (self , ** config )
182
262
183
263
def fetch_routing_info (self , address ):
184
264
""" Fetch raw routing info from a given router address.
@@ -304,14 +384,16 @@ def acquire(self, access_mode=None):
304
384
access_mode = WRITE_ACCESS
305
385
if access_mode == READ_ACCESS :
306
386
server_list = self .routing_table .readers
387
+ server_selector = self .load_balancing_strategy .select_reader
307
388
elif access_mode == WRITE_ACCESS :
308
389
server_list = self .routing_table .writers
390
+ server_selector = self .load_balancing_strategy .select_writer
309
391
else :
310
392
raise ValueError ("Unsupported access mode {}" .format (access_mode ))
311
393
312
394
self .ensure_routing_table_is_fresh (access_mode )
313
395
while True :
314
- address = next (server_list )
396
+ address = server_selector (server_list )
315
397
if address is None :
316
398
break
317
399
try :
@@ -354,7 +436,7 @@ def __init__(self, uri, **config):
354
436
def connector (a ):
355
437
return connect (a , security_plan .ssl_context , ** config )
356
438
357
- pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ))
439
+ pool = RoutingConnectionPool (connector , initial_address , routing_context , * resolve (initial_address ), ** config )
358
440
try :
359
441
pool .update_routing_table ()
360
442
except :
0 commit comments