Skip to content

Commit 4b0e313

Browse files
authored
Merge pull request #174 from neo4j/1.5-least-connected
Load balancing strategies and least connected
2 parents 871009c + bb7fc5e commit 4b0e313

File tree

6 files changed

+314
-87
lines changed

6 files changed

+314
-87
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ htmlcov
99
.coverage
1010
.test
1111
.tox
12+
.benchmarks
1213
.cache
1314

1415
docs/build

neo4j/bolt/connection.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,17 @@ def release(self, connection):
418418
with self.lock:
419419
connection.in_use = False
420420

421+
def in_use_connection_count(self, address):
422+
""" Count the number of connections currently in use to a given
423+
address.
424+
"""
425+
try:
426+
connections = self.connections[address]
427+
except KeyError:
428+
return 0
429+
else:
430+
return sum(1 if connection.in_use else 0 for connection in connections)
431+
421432
def remove(self, address):
422433
""" Remove an address from the connection pool, if present, closing
423434
all connections to that address.

neo4j/v1/routing.py

Lines changed: 105 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,28 @@
1717
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1818
# See the License for the specific language governing permissions and
1919
# limitations under the License.
20-
21-
20+
from abc import abstractmethod
21+
from sys import maxsize
2222
from threading import Lock
2323
from time import clock
2424

2525
from neo4j.addressing import SocketAddress, resolve
2626
from neo4j.bolt import ConnectionPool, ServiceUnavailable, ProtocolError, DEFAULT_PORT, connect
2727
from neo4j.compat.collections import MutableSet, OrderedDict
2828
from neo4j.exceptions import CypherError
29+
from neo4j.util import ServerVersion
2930
from neo4j.v1.api import Driver, READ_ACCESS, WRITE_ACCESS, fix_statement, fix_parameters
3031
from neo4j.v1.exceptions import SessionExpired
3132
from neo4j.v1.security import SecurityPlan
3233
from neo4j.v1.session import BoltSession
33-
from neo4j.util import ServerVersion
3434

3535

36-
class RoundRobinSet(MutableSet):
36+
LOAD_BALANCING_STRATEGY_LEAST_CONNECTED = 0
37+
LOAD_BALANCING_STRATEGY_ROUND_ROBIN = 1
38+
LOAD_BALANCING_STRATEGY_DEFAULT = LOAD_BALANCING_STRATEGY_LEAST_CONNECTED
39+
40+
41+
class OrderedSet(MutableSet):
3742

3843
def __init__(self, elements=()):
3944
self._elements = OrderedDict.fromkeys(elements)
@@ -45,22 +50,15 @@ def __repr__(self):
4550
def __contains__(self, element):
4651
return element in self._elements
4752

48-
def __next__(self):
49-
current = None
50-
if self._elements:
51-
if self._current is None:
52-
self._current = 0
53-
else:
54-
self._current = (self._current + 1) % len(self._elements)
55-
current = list(self._elements.keys())[self._current]
56-
return current
57-
5853
def __iter__(self):
5954
return iter(self._elements)
6055

6156
def __len__(self):
6257
return len(self._elements)
6358

59+
def __getitem__(self, index):
60+
return list(self._elements.keys())[index]
61+
6462
def add(self, element):
6563
self._elements[element] = None
6664

@@ -73,9 +71,6 @@ def discard(self, element):
7371
except KeyError:
7472
pass
7573

76-
def next(self):
77-
return self.__next__()
78-
7974
def remove(self, element):
8075
try:
8176
del self._elements[element]
@@ -126,9 +121,9 @@ def parse_routing_info(cls, records):
126121
return cls(routers, readers, writers, ttl)
127122

128123
def __init__(self, routers=(), readers=(), writers=(), ttl=0):
129-
self.routers = RoundRobinSet(routers)
130-
self.readers = RoundRobinSet(readers)
131-
self.writers = RoundRobinSet(writers)
124+
self.routers = OrderedSet(routers)
125+
self.readers = OrderedSet(readers)
126+
self.writers = OrderedSet(writers)
132127
self.last_updated_time = self.timer()
133128
self.ttl = ttl
134129

@@ -168,17 +163,102 @@ def __run__(self, ignored, routing_context):
168163
return self._run(fix_statement(statement), fix_parameters(parameters))
169164

170165

166+
class LoadBalancingStrategy(object):
167+
168+
@classmethod
169+
def build(cls, connection_pool, **config):
170+
load_balancing_strategy = config.get("load_balancing_strategy", LOAD_BALANCING_STRATEGY_DEFAULT)
171+
if load_balancing_strategy == LOAD_BALANCING_STRATEGY_LEAST_CONNECTED:
172+
return LeastConnectedLoadBalancingStrategy(connection_pool)
173+
elif load_balancing_strategy == LOAD_BALANCING_STRATEGY_ROUND_ROBIN:
174+
return RoundRobinLoadBalancingStrategy()
175+
else:
176+
raise ValueError("Unknown load balancing strategy '%s'" % load_balancing_strategy)
177+
178+
@abstractmethod
179+
def select_reader(self, known_readers):
180+
raise NotImplementedError()
181+
182+
@abstractmethod
183+
def select_writer(self, known_writers):
184+
raise NotImplementedError()
185+
186+
187+
class RoundRobinLoadBalancingStrategy(LoadBalancingStrategy):
188+
189+
_readers_offset = 0
190+
_writers_offset = 0
191+
192+
def select_reader(self, known_readers):
193+
address = self._select(self._readers_offset, known_readers)
194+
self._readers_offset += 1
195+
return address
196+
197+
def select_writer(self, known_writers):
198+
address = self._select(self._writers_offset, known_writers)
199+
self._writers_offset += 1
200+
return address
201+
202+
@classmethod
203+
def _select(cls, offset, addresses):
204+
if not addresses:
205+
return None
206+
return addresses[offset % len(addresses)]
207+
208+
209+
class LeastConnectedLoadBalancingStrategy(LoadBalancingStrategy):
210+
211+
def __init__(self, connection_pool):
212+
self._readers_offset = 0
213+
self._writers_offset = 0
214+
self._connection_pool = connection_pool
215+
216+
def select_reader(self, known_readers):
217+
address = self._select(self._readers_offset, known_readers)
218+
self._readers_offset += 1
219+
return address
220+
221+
def select_writer(self, known_writers):
222+
address = self._select(self._writers_offset, known_writers)
223+
self._writers_offset += 1
224+
return address
225+
226+
def _select(self, offset, addresses):
227+
if not addresses:
228+
return None
229+
num_addresses = len(addresses)
230+
start_index = offset % num_addresses
231+
index = start_index
232+
233+
least_connected_address = None
234+
least_in_use_connections = maxsize
235+
236+
while True:
237+
address = addresses[index]
238+
index = (index + 1) % num_addresses
239+
240+
in_use_connections = self._connection_pool.in_use_connection_count(address)
241+
242+
if in_use_connections < least_in_use_connections:
243+
least_connected_address = address
244+
least_in_use_connections = in_use_connections
245+
246+
if index == start_index:
247+
return least_connected_address
248+
249+
171250
class RoutingConnectionPool(ConnectionPool):
172251
""" Connection pool with routing table.
173252
"""
174253

175-
def __init__(self, connector, initial_address, routing_context, *routers):
254+
def __init__(self, connector, initial_address, routing_context, *routers, **config):
176255
super(RoutingConnectionPool, self).__init__(connector)
177256
self.initial_address = initial_address
178257
self.routing_context = routing_context
179258
self.routing_table = RoutingTable(routers)
180259
self.missing_writer = False
181260
self.refresh_lock = Lock()
261+
self.load_balancing_strategy = LoadBalancingStrategy.build(self, **config)
182262

183263
def fetch_routing_info(self, address):
184264
""" Fetch raw routing info from a given router address.
@@ -304,14 +384,16 @@ def acquire(self, access_mode=None):
304384
access_mode = WRITE_ACCESS
305385
if access_mode == READ_ACCESS:
306386
server_list = self.routing_table.readers
387+
server_selector = self.load_balancing_strategy.select_reader
307388
elif access_mode == WRITE_ACCESS:
308389
server_list = self.routing_table.writers
390+
server_selector = self.load_balancing_strategy.select_writer
309391
else:
310392
raise ValueError("Unsupported access mode {}".format(access_mode))
311393

312394
self.ensure_routing_table_is_fresh(access_mode)
313395
while True:
314-
address = next(server_list)
396+
address = server_selector(server_list)
315397
if address is None:
316398
break
317399
try:
@@ -354,7 +436,7 @@ def __init__(self, uri, **config):
354436
def connector(a):
355437
return connect(a, security_plan.ssl_context, **config)
356438

357-
pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address))
439+
pool = RoutingConnectionPool(connector, initial_address, routing_context, *resolve(initial_address), **config)
358440
try:
359441
pool.update_routing_table()
360442
except:

test/integration/test_connection.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,11 @@ def test_cannot_acquire_after_close(self):
108108
pool.close()
109109
with self.assertRaises(ServiceUnavailable):
110110
_ = pool.acquire_direct("X")
111+
112+
def test_in_use_count(self):
113+
address = ("127.0.0.1", 7687)
114+
self.assertEqual(self.pool.in_use_connection_count(address), 0)
115+
connection = self.pool.acquire_direct(address)
116+
self.assertEqual(self.pool.in_use_connection_count(address), 1)
117+
self.pool.release(connection)
118+
self.assertEqual(self.pool.in_use_connection_count(address), 0)

test/stub/test_routingdriver.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
# limitations under the License.
2020

2121

22-
from neo4j.v1 import GraphDatabase, RoutingDriver, READ_ACCESS, WRITE_ACCESS, SessionExpired
22+
from neo4j.v1 import GraphDatabase, READ_ACCESS, WRITE_ACCESS, SessionExpired, \
23+
RoutingDriver, RoutingConnectionPool, LeastConnectedLoadBalancingStrategy, LOAD_BALANCING_STRATEGY_ROUND_ROBIN, \
24+
RoundRobinLoadBalancingStrategy
2325
from neo4j.bolt import ProtocolError, ServiceUnavailable
2426

2527
from test.stub.tools import StubTestCase, StubCluster
@@ -215,3 +217,27 @@ def test_should_error_when_missing_reader(self):
215217
uri = "bolt+routing://127.0.0.1:9001"
216218
with self.assertRaises(ProtocolError):
217219
GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False)
220+
221+
def test_default_load_balancing_strategy_is_least_connected(self):
222+
with StubCluster({9001: "router.script"}):
223+
uri = "bolt+routing://127.0.0.1:9001"
224+
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False) as driver:
225+
self.assertIsInstance(driver, RoutingDriver)
226+
self.assertIsInstance(driver._pool, RoutingConnectionPool)
227+
self.assertIsInstance(driver._pool.load_balancing_strategy, LeastConnectedLoadBalancingStrategy)
228+
229+
def test_can_select_round_robin_load_balancing_strategy(self):
230+
with StubCluster({9001: "router.script"}):
231+
uri = "bolt+routing://127.0.0.1:9001"
232+
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False,
233+
load_balancing_strategy=LOAD_BALANCING_STRATEGY_ROUND_ROBIN) as driver:
234+
self.assertIsInstance(driver, RoutingDriver)
235+
self.assertIsInstance(driver._pool, RoutingConnectionPool)
236+
self.assertIsInstance(driver._pool.load_balancing_strategy, RoundRobinLoadBalancingStrategy)
237+
238+
def test_no_other_load_balancing_strategies_are_available(self):
239+
with StubCluster({9001: "router.script"}):
240+
uri = "bolt+routing://127.0.0.1:9001"
241+
with self.assertRaises(ValueError):
242+
with GraphDatabase.driver(uri, auth=self.auth_token, encrypted=False, load_balancing_strategy=-1):
243+
pass

0 commit comments

Comments
 (0)