Skip to content

Commit e11879e

Browse files
authored
Fix pool closing connections too aggressively (#955)
Whenever a new routing table was fetched, the pool would close all connections to servers that were not part of that routing table. However, a server missing from it may well still be present in the routing table for another database. Hence, the pool now checks the routing tables for all databases before deciding which connections are no longer needed.
1 parent dc5bcf8 commit e11879e

File tree

4 files changed

+262
-72
lines changed

4 files changed

+262
-72
lines changed

src/neo4j/_async/io/_pool.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -813,8 +813,13 @@ async def update_routing_table(
813813
raise ServiceUnavailable("Unable to retrieve routing information")
814814

815815
async def update_connection_pool(self, *, database):
816-
routing_table = await self.get_or_create_routing_table(database)
817-
servers = routing_table.servers()
816+
async with self.refresh_lock:
817+
routing_tables = [await self.get_or_create_routing_table(database)]
818+
for db in self.routing_tables.keys():
819+
if db == database:
820+
continue
821+
routing_tables.append(self.routing_tables[db])
822+
servers = set.union(*(rt.servers() for rt in routing_tables))
818823
for address in list(self.connections):
819824
if address._unresolved not in servers:
820825
await super(AsyncNeo4jPool, self).deactivate(address)
@@ -960,6 +965,7 @@ async def deactivate(self, address):
960965
async def on_write_failure(self, address):
961966
""" Remove a writer address from the routing table, if present.
962967
"""
968+
# FIXME: only need to remove the writer for a specific database
963969
log.debug("[#0000] _: <POOL> removing writer %r", address)
964970
async with self.refresh_lock:
965971
for database in self.routing_tables.keys():

src/neo4j/_sync/io/_pool.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -810,8 +810,13 @@ def update_routing_table(
810810
raise ServiceUnavailable("Unable to retrieve routing information")
811811

812812
def update_connection_pool(self, *, database):
813-
routing_table = self.get_or_create_routing_table(database)
814-
servers = routing_table.servers()
813+
with self.refresh_lock:
814+
routing_tables = [self.get_or_create_routing_table(database)]
815+
for db in self.routing_tables.keys():
816+
if db == database:
817+
continue
818+
routing_tables.append(self.routing_tables[db])
819+
servers = set.union(*(rt.servers() for rt in routing_tables))
815820
for address in list(self.connections):
816821
if address._unresolved not in servers:
817822
super(Neo4jPool, self).deactivate(address)
@@ -957,6 +962,7 @@ def deactivate(self, address):
957962
def on_write_failure(self, address):
958963
""" Remove a writer address from the routing table, if present.
959964
"""
965+
# FIXME: only need to remove the writer for a specific database
960966
log.debug("[#0000] _: <POOL> removing writer %r", address)
961967
with self.refresh_lock:
962968
for database in self.routing_tables.keys():

tests/unit/async_/io/test_neo4j_pool.py

Lines changed: 123 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
import inspect
20+
from collections import defaultdict
2021

2122
import pytest
2223

@@ -50,26 +51,32 @@
5051
ROUTER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9000), host_name="host")
5152
ROUTER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9001), host_name="host")
5253
ROUTER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9002), host_name="host")
53-
READER_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host")
54-
WRITER_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host")
54+
READER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9010), host_name="host")
55+
READER2_ADDRESS = ResolvedAddress(("1.2.3.1", 9011), host_name="host")
56+
READER3_ADDRESS = ResolvedAddress(("1.2.3.1", 9012), host_name="host")
57+
WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host")
5558

5659

5760
@pytest.fixture
58-
def routing_failure_opener(async_fake_connection_generator, mocker):
59-
def make_opener(failures=None):
61+
def custom_routing_opener(async_fake_connection_generator, mocker):
62+
def make_opener(failures=None, get_readers=None):
6063
def routing_side_effect(*args, **kwargs):
6164
nonlocal failures
6265
res = next(failures, None)
6366
if res is None:
67+
if get_readers is not None:
68+
readers = get_readers(kwargs.get("database"))
69+
else:
70+
readers = [str(READER1_ADDRESS)]
6471
return [{
6572
"ttl": 1000,
6673
"servers": [
6774
{"addresses": [str(ROUTER1_ADDRESS),
6875
str(ROUTER2_ADDRESS),
6976
str(ROUTER3_ADDRESS)],
7077
"role": "ROUTE"},
71-
{"addresses": [str(READER_ADDRESS)], "role": "READ"},
72-
{"addresses": [str(WRITER_ADDRESS)], "role": "WRITE"},
78+
{"addresses": readers, "role": "READ"},
79+
{"addresses": [str(WRITER1_ADDRESS)], "role": "WRITE"},
7380
],
7481
}]
7582
raise res
@@ -96,8 +103,8 @@ async def open_(addr, auth, timeout):
96103

97104

98105
@pytest.fixture
99-
def opener(routing_failure_opener):
100-
return routing_failure_opener()
106+
def opener(custom_routing_opener):
107+
return custom_routing_opener()
101108

102109

103110
def _pool_config():
@@ -177,9 +184,9 @@ async def test_chooses_right_connection_type(opener, type_):
177184
)
178185
await pool.release(cx1)
179186
if type_ == "r":
180-
assert cx1.unresolved_address == READER_ADDRESS
187+
assert cx1.unresolved_address == READER1_ADDRESS
181188
else:
182-
assert cx1.unresolved_address == WRITER_ADDRESS
189+
assert cx1.unresolved_address == WRITER1_ADDRESS
183190

184191

185192
@mark_async_test
@@ -298,9 +305,9 @@ async def test_acquire_performs_no_liveness_check_on_fresh_connection(
298305
opener, liveness_timeout
299306
):
300307
pool = _simple_pool(opener)
301-
cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
308+
cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
302309
liveness_timeout)
303-
assert cx1.unresolved_address == READER_ADDRESS
310+
assert cx1.unresolved_address == READER1_ADDRESS
304311
cx1.reset.assert_not_called()
305312

306313

@@ -311,11 +318,11 @@ async def test_acquire_performs_liveness_check_on_existing_connection(
311318
):
312319
pool = _simple_pool(opener)
313320
# populate the pool with a connection
314-
cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
321+
cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
315322
liveness_timeout)
316323

317324
# make sure we assume the right state
318-
assert cx1.unresolved_address == READER_ADDRESS
325+
assert cx1.unresolved_address == READER1_ADDRESS
319326
cx1.is_idle_for.assert_not_called()
320327
cx1.reset.assert_not_called()
321328

@@ -326,7 +333,7 @@ async def test_acquire_performs_liveness_check_on_existing_connection(
326333
cx1.reset.assert_not_called()
327334

328335
# then acquire it again and assert the liveness check was performed
329-
cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
336+
cx2 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
330337
liveness_timeout)
331338
assert cx1 is cx2
332339
cx1.is_idle_for.assert_called_once_with(liveness_timeout)
@@ -345,11 +352,11 @@ def liveness_side_effect(*args, **kwargs):
345352
liveness_timeout = 1
346353
pool = _simple_pool(opener)
347354
# populate the pool with a connection
348-
cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
355+
cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
349356
liveness_timeout)
350357

351358
# make sure we assume the right state
352-
assert cx1.unresolved_address == READER_ADDRESS
359+
assert cx1.unresolved_address == READER1_ADDRESS
353360
cx1.is_idle_for.assert_not_called()
354361
cx1.reset.assert_not_called()
355362

@@ -362,7 +369,7 @@ def liveness_side_effect(*args, **kwargs):
362369
cx1.reset.assert_not_called()
363370

364371
# then acquire it again and assert the liveness check was performed
365-
cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
372+
cx2 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
366373
liveness_timeout)
367374
assert cx1 is not cx2
368375
assert cx1.unresolved_address == cx2.unresolved_address
@@ -384,14 +391,14 @@ def liveness_side_effect(*args, **kwargs):
384391
liveness_timeout = 1
385392
pool = _simple_pool(opener)
386393
# populate the pool with a connection
387-
cx1 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
394+
cx1 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
388395
liveness_timeout)
389-
cx2 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
396+
cx2 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
390397
liveness_timeout)
391398

392399
# make sure we assume the right state
393-
assert cx1.unresolved_address == READER_ADDRESS
394-
assert cx2.unresolved_address == READER_ADDRESS
400+
assert cx1.unresolved_address == READER1_ADDRESS
401+
assert cx2.unresolved_address == READER1_ADDRESS
395402
assert cx1 is not cx2
396403
cx1.is_idle_for.assert_not_called()
397404
cx2.is_idle_for.assert_not_called()
@@ -409,7 +416,7 @@ def liveness_side_effect(*args, **kwargs):
409416
cx2.reset.assert_not_called()
410417

411418
# then acquire it again and assert the liveness check was performed
412-
cx3 = await pool._acquire(READER_ADDRESS, None, Deadline(30),
419+
cx3 = await pool._acquire(READER1_ADDRESS, None, Deadline(30),
413420
liveness_timeout)
414421
assert cx3 is cx2
415422
cx1.is_idle_for.assert_called_once_with(liveness_timeout)
@@ -426,7 +433,7 @@ def mock_connection_breaks_on_close(cx):
426433
async def close_side_effect():
427434
cx.closed.return_value = True
428435
cx.defunct.return_value = True
429-
await pool.deactivate(READER_ADDRESS)
436+
await pool.deactivate(READER1_ADDRESS)
430437

431438
cx.attach_mock(mocker.AsyncMock(side_effect=close_side_effect),
432439
"close")
@@ -470,9 +477,9 @@ async def test__acquire_new_later_with_room(opener):
470477
pool = AsyncNeo4jPool(
471478
opener, config, WorkspaceConfig(), ROUTER1_ADDRESS
472479
)
473-
assert pool.connections_reservations[READER_ADDRESS] == 0
474-
creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1))
475-
assert pool.connections_reservations[READER_ADDRESS] == 1
480+
assert pool.connections_reservations[READER1_ADDRESS] == 0
481+
creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1))
482+
assert pool.connections_reservations[READER1_ADDRESS] == 1
476483
assert callable(creator)
477484
if AsyncUtil.is_async_code:
478485
assert inspect.iscoroutinefunction(creator)
@@ -487,9 +494,9 @@ async def test__acquire_new_later_without_room(opener):
487494
)
488495
_ = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None)
489496
# pool is full now
490-
assert pool.connections_reservations[READER_ADDRESS] == 0
491-
creator = pool._acquire_new_later(READER_ADDRESS, None, Deadline(1))
492-
assert pool.connections_reservations[READER_ADDRESS] == 0
497+
assert pool.connections_reservations[READER1_ADDRESS] == 0
498+
creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1))
499+
assert pool.connections_reservations[READER1_ADDRESS] == 0
493500
assert creator is None
494501

495502

@@ -519,8 +526,8 @@ async def test_passes_pool_config_to_connection(mocker):
519526
"Neo.ClientError.Security.AuthorizationExpired"),
520527
))
521528
@mark_async_test
522-
async def test_discovery_is_retried(routing_failure_opener, error):
523-
opener = routing_failure_opener([
529+
async def test_discovery_is_retried(custom_routing_opener, error):
530+
opener = custom_routing_opener([
524531
None, # first call to router for seeding the RT with more routers
525532
error, # will be retried
526533
])
@@ -563,8 +570,8 @@ async def test_discovery_is_retried(routing_failure_opener, error):
563570
)
564571
))
565572
@mark_async_test
566-
async def test_fast_failing_discovery(routing_failure_opener, error):
567-
opener = routing_failure_opener([
573+
async def test_fast_failing_discovery(custom_routing_opener, error):
574+
opener = custom_routing_opener([
568575
None, # first call to router for seeding the RT with more routers
569576
error, # will be retried
570577
])
@@ -648,3 +655,85 @@ async def test_connection_error_callback(
648655
cx.mark_unauthenticated.assert_not_called()
649656
for cx in cxs_write:
650657
cx.mark_unauthenticated.assert_not_called()
658+
659+
660+
@mark_async_test
661+
async def test_pool_closes_connections_dropped_from_rt(custom_routing_opener):
662+
readers = {"db1": [str(READER1_ADDRESS)]}
663+
664+
def get_readers(database):
665+
return readers[database]
666+
667+
opener = custom_routing_opener(get_readers=get_readers)
668+
669+
pool = AsyncNeo4jPool(
670+
opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS
671+
)
672+
cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
673+
assert cx1.unresolved_address == READER1_ADDRESS
674+
await pool.release(cx1)
675+
676+
cx1.close.assert_not_called()
677+
assert len(pool.connections[READER1_ADDRESS]) == 1
678+
679+
# force RT refresh, returning a different reader
680+
del pool.routing_tables["db1"]
681+
readers["db1"] = [str(READER2_ADDRESS)]
682+
683+
cx2 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
684+
assert cx2.unresolved_address == READER2_ADDRESS
685+
686+
cx1.close.assert_awaited_once()
687+
assert len(pool.connections[READER1_ADDRESS]) == 0
688+
689+
await pool.release(cx2)
690+
assert len(pool.connections[READER2_ADDRESS]) == 1
691+
692+
693+
@mark_async_test
694+
async def test_pool_does_not_close_connections_dropped_from_rt_for_other_server(
695+
custom_routing_opener
696+
):
697+
readers = {
698+
"db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)],
699+
"db2": [str(READER1_ADDRESS)]
700+
}
701+
702+
def get_readers(database):
703+
return readers[database]
704+
705+
opener = custom_routing_opener(get_readers=get_readers)
706+
707+
pool = AsyncNeo4jPool(
708+
opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS
709+
)
710+
cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None)
711+
await pool.release(cx1)
712+
assert cx1.unresolved_address in (READER1_ADDRESS, READER2_ADDRESS)
713+
reader1_connection_count = len(pool.connections[READER1_ADDRESS])
714+
reader2_connection_count = len(pool.connections[READER2_ADDRESS])
715+
assert reader1_connection_count + reader2_connection_count == 1
716+
717+
cx2 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
718+
await pool.release(cx2)
719+
assert cx2.unresolved_address == READER1_ADDRESS
720+
cx1.close.assert_not_called()
721+
cx2.close.assert_not_called()
722+
assert len(pool.connections[READER1_ADDRESS]) == 1
723+
assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count
724+
725+
726+
# force RT refresh, returning a different reader
727+
del pool.routing_tables["db2"]
728+
readers["db2"] = [str(READER3_ADDRESS)]
729+
730+
cx3 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None)
731+
await pool.release(cx3)
732+
assert cx3.unresolved_address == READER3_ADDRESS
733+
734+
cx1.close.assert_not_called()
735+
cx2.close.assert_not_called()
736+
cx3.close.assert_not_called()
737+
assert len(pool.connections[READER1_ADDRESS]) == 1
738+
assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count
739+
assert len(pool.connections[READER3_ADDRESS]) == 1

0 commit comments

Comments
 (0)