Skip to content

Commit dc5bcf8

Browse files
authored
Fix missing lock acquisition in pool (#952)
Fix some functions in the pool manipulating the routing information without holding the right lock.
1 parent 80b7be4 commit dc5bcf8

File tree

8 files changed

+36
-20
lines changed

8 files changed

+36
-20
lines changed

src/neo4j/_async/io/_bolt3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,9 @@ async def _process_message(self, tag, fields):
407407
raise
408408
except (NotALeader, ForbiddenOnReadOnlyDatabase):
409409
if self.pool:
410-
self.pool.on_write_failure(address=self.unresolved_address)
410+
await self.pool.on_write_failure(
411+
address=self.unresolved_address
412+
)
411413
raise
412414
except Neo4jError as e:
413415
await self.pool.on_neo4j_error(e, self)

src/neo4j/_async/io/_bolt4.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,9 @@ async def _process_message(self, tag, fields):
356356
raise
357357
except (NotALeader, ForbiddenOnReadOnlyDatabase):
358358
if self.pool:
359-
self.pool.on_write_failure(address=self.unresolved_address)
359+
await self.pool.on_write_failure(
360+
address=self.unresolved_address
361+
)
360362
raise
361363
except Neo4jError as e:
362364
if self.pool:

src/neo4j/_async/io/_bolt5.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,9 @@ async def _process_message(self, tag, fields):
356356
raise
357357
except (NotALeader, ForbiddenOnReadOnlyDatabase):
358358
if self.pool:
359-
self.pool.on_write_failure(address=self.unresolved_address)
359+
await self.pool.on_write_failure(
360+
address=self.unresolved_address
361+
)
360362
raise
361363
except Neo4jError as e:
362364
if self.pool:

src/neo4j/_async/io/_pool.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -451,7 +451,7 @@ async def deactivate(self, address):
451451

452452
await self._close_connections(closable_connections)
453453

454-
def on_write_failure(self, address):
454+
async def on_write_failure(self, address):
455455
raise WriteServiceUnavailable(
456456
"No write service available for pool {}".format(self)
457457
)
@@ -949,17 +949,19 @@ async def deactivate(self, address):
949949
log.debug("[#0000] _: <POOL> deactivating address %r", address)
950950
# We use `discard` instead of `remove` here since the former
951951
# will not fail if the address has already been removed.
952-
for database in self.routing_tables.keys():
953-
self.routing_tables[database].routers.discard(address)
954-
self.routing_tables[database].readers.discard(address)
955-
self.routing_tables[database].writers.discard(address)
952+
async with self.refresh_lock:
953+
for database in self.routing_tables.keys():
954+
self.routing_tables[database].routers.discard(address)
955+
self.routing_tables[database].readers.discard(address)
956+
self.routing_tables[database].writers.discard(address)
956957
log.debug("[#0000] _: <POOL> table=%r", self.routing_tables)
957958
await super(AsyncNeo4jPool, self).deactivate(address)
958959

959-
def on_write_failure(self, address):
960+
async def on_write_failure(self, address):
960961
""" Remove a writer address from the routing table, if present.
961962
"""
962963
log.debug("[#0000] _: <POOL> removing writer %r", address)
963-
for database in self.routing_tables.keys():
964-
self.routing_tables[database].writers.discard(address)
964+
async with self.refresh_lock:
965+
for database in self.routing_tables.keys():
966+
self.routing_tables[database].writers.discard(address)
965967
log.debug("[#0000] _: <POOL> table=%r", self.routing_tables)

src/neo4j/_sync/io/_bolt3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,9 @@ def _process_message(self, tag, fields):
407407
raise
408408
except (NotALeader, ForbiddenOnReadOnlyDatabase):
409409
if self.pool:
410-
self.pool.on_write_failure(address=self.unresolved_address)
410+
self.pool.on_write_failure(
411+
address=self.unresolved_address
412+
)
411413
raise
412414
except Neo4jError as e:
413415
self.pool.on_neo4j_error(e, self)

src/neo4j/_sync/io/_bolt4.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,9 @@ def _process_message(self, tag, fields):
356356
raise
357357
except (NotALeader, ForbiddenOnReadOnlyDatabase):
358358
if self.pool:
359-
self.pool.on_write_failure(address=self.unresolved_address)
359+
self.pool.on_write_failure(
360+
address=self.unresolved_address
361+
)
360362
raise
361363
except Neo4jError as e:
362364
if self.pool:

src/neo4j/_sync/io/_bolt5.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,9 @@ def _process_message(self, tag, fields):
356356
raise
357357
except (NotALeader, ForbiddenOnReadOnlyDatabase):
358358
if self.pool:
359-
self.pool.on_write_failure(address=self.unresolved_address)
359+
self.pool.on_write_failure(
360+
address=self.unresolved_address
361+
)
360362
raise
361363
except Neo4jError as e:
362364
if self.pool:

src/neo4j/_sync/io/_pool.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -946,17 +946,19 @@ def deactivate(self, address):
946946
log.debug("[#0000] _: <POOL> deactivating address %r", address)
947947
# We use `discard` instead of `remove` here since the former
948948
# will not fail if the address has already been removed.
949-
for database in self.routing_tables.keys():
950-
self.routing_tables[database].routers.discard(address)
951-
self.routing_tables[database].readers.discard(address)
952-
self.routing_tables[database].writers.discard(address)
949+
with self.refresh_lock:
950+
for database in self.routing_tables.keys():
951+
self.routing_tables[database].routers.discard(address)
952+
self.routing_tables[database].readers.discard(address)
953+
self.routing_tables[database].writers.discard(address)
953954
log.debug("[#0000] _: <POOL> table=%r", self.routing_tables)
954955
super(Neo4jPool, self).deactivate(address)
955956

956957
def on_write_failure(self, address):
957958
""" Remove a writer address from the routing table, if present.
958959
"""
959960
log.debug("[#0000] _: <POOL> removing writer %r", address)
960-
for database in self.routing_tables.keys():
961-
self.routing_tables[database].writers.discard(address)
961+
with self.refresh_lock:
962+
for database in self.routing_tables.keys():
963+
self.routing_tables[database].writers.discard(address)
962964
log.debug("[#0000] _: <POOL> table=%r", self.routing_tables)

0 commit comments

Comments
 (0)