Skip to content

Fix data race in routing pool #852

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion neo4j/_async/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def open(cls, *targets, auth=None, routing_context=None, **config):
return cls(pool, default_workspace_config)

def __init__(self, pool, default_workspace_config):
_Routing.__init__(self, pool.get_default_database_initial_router_addresses())
_Routing.__init__(self, [pool.address])
AsyncDriver.__init__(self, pool, default_workspace_config)

if not t.TYPE_CHECKING:
Expand Down
84 changes: 31 additions & 53 deletions neo4j/_async/io/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,9 @@ def in_use_connection_count(self, address):
""" Count the number of connections currently in use to a given
address.
"""
try:
connections = self.connections[address]
except KeyError:
return 0
else:
return sum(1 if connection.in_use else 0 for connection in connections)
with self.lock:
connections = self.connections.get(address, ())
return sum(connection.in_use for connection in connections)

async def mark_all_stale(self):
with self.lock:
Expand Down Expand Up @@ -447,7 +444,7 @@ def __init__(self, opener, pool_config, workspace_config, address):
# Each database have a routing table, the default database is a special case.
log.debug("[#0000] C: <NEO4J POOL> routing address %r", address)
self.address = address
self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])}
self.routing_tables = {}
self.refresh_lock = AsyncRLock()

def __repr__(self):
Expand All @@ -456,37 +453,15 @@ def __repr__(self):
:return: The representation
:rtype: str
"""
return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses())

@property
def first_initial_routing_address(self):
return self.get_default_database_initial_router_addresses()[0]

def get_default_database_initial_router_addresses(self):
""" Get the initial router addresses for the default database.

:return:
:rtype: OrderedSet
"""
return self.get_routing_table_for_default_database().initial_routers

def get_default_database_router_addresses(self):
""" Get the router addresses for the default database.

:return:
:rtype: OrderedSet
"""
return self.get_routing_table_for_default_database().routers

def get_routing_table_for_default_database(self):
return self.routing_tables[self.workspace_config.database]
return "<{} address={!r}>".format(self.__class__.__name__,
self.address)

async def get_or_create_routing_table(self, database):
async with self.refresh_lock:
if database not in self.routing_tables:
self.routing_tables[database] = RoutingTable(
database=database,
routers=self.get_default_database_initial_router_addresses()
routers=[self.address]
)
return self.routing_tables[database]

Expand Down Expand Up @@ -651,15 +626,15 @@ async def update_routing_table(
if prefer_initial_routing_address:
# TODO: Test this state
if await self._update_routing_table_from(
self.first_initial_routing_address, database=database,
self.address, database=database,
imp_user=imp_user, bookmarks=bookmarks,
acquisition_timeout=acquisition_timeout,
database_callback=database_callback
):
# Why is only the first initial routing address used?
return
if await self._update_routing_table_from(
*(existing_routers - {self.first_initial_routing_address}),
*(existing_routers - {self.address}),
database=database, imp_user=imp_user, bookmarks=bookmarks,
acquisition_timeout=acquisition_timeout,
database_callback=database_callback
Expand All @@ -668,7 +643,7 @@ async def update_routing_table(

if not prefer_initial_routing_address:
if await self._update_routing_table_from(
self.first_initial_routing_address, database=database,
self.address, database=database,
imp_user=imp_user, bookmarks=bookmarks,
acquisition_timeout=acquisition_timeout,
database_callback=database_callback
Expand Down Expand Up @@ -705,6 +680,14 @@ async def ensure_routing_table_is_fresh(
"""
from neo4j.api import READ_ACCESS
async with self.refresh_lock:
for database_ in list(self.routing_tables.keys()):
# Remove unused databases in the routing table
# Remove the routing table after a timeout = TTL + 30s
log.debug("[#0000] C: <ROUTING AGED> database=%s", database_)
routing_table = self.routing_tables[database_]
if routing_table.should_be_purged_from_memory():
del self.routing_tables[database_]

routing_table = await self.get_or_create_routing_table(database)
if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)):
# Readers are fresh.
Expand All @@ -717,25 +700,21 @@ async def ensure_routing_table_is_fresh(
)
await self.update_connection_pool(database=database)

for database in list(self.routing_tables.keys()):
# Remove unused databases in the routing table
# Remove the routing table after a timeout = TTL + 30s
log.debug("[#0000] C: <ROUTING AGED> database=%s", database)
if (self.routing_tables[database].should_be_purged_from_memory()
and database != self.workspace_config.database):
del self.routing_tables[database]

return True

async def _select_address(self, *, access_mode, database):
from ...api import READ_ACCESS
""" Selects the address with the fewest in-use connections.
"""
async with self.refresh_lock:
if access_mode == READ_ACCESS:
addresses = self.routing_tables[database].readers
routing_table = self.routing_tables.get(database)
if routing_table:
if access_mode == READ_ACCESS:
addresses = routing_table.readers
else:
addresses = routing_table.writers
else:
addresses = self.routing_tables[database].writers
addresses = ()
addresses_by_usage = {}
for address in addresses:
addresses_by_usage.setdefault(
Expand Down Expand Up @@ -763,13 +742,12 @@ async def acquire(

from neo4j.api import check_access_mode
access_mode = check_access_mode(access_mode)
async with self.refresh_lock:
log.debug("[#0000] C: <ROUTING TABLE ENSURE FRESH> %r",
self.routing_tables)
await self.ensure_routing_table_is_fresh(
access_mode=access_mode, database=database, imp_user=None,
bookmarks=bookmarks, acquisition_timeout=timeout
)

await self.ensure_routing_table_is_fresh(
access_mode=access_mode, database=database,
imp_user=None, bookmarks=bookmarks,
acquisition_timeout=timeout
)

while True:
try:
Expand Down
2 changes: 1 addition & 1 deletion neo4j/_sync/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ def open(cls, *targets, auth=None, routing_context=None, **config):
return cls(pool, default_workspace_config)

def __init__(self, pool, default_workspace_config):
_Routing.__init__(self, pool.get_default_database_initial_router_addresses())
_Routing.__init__(self, [pool.address])
Driver.__init__(self, pool, default_workspace_config)

if not t.TYPE_CHECKING:
Expand Down
84 changes: 31 additions & 53 deletions neo4j/_sync/io/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,9 @@ def in_use_connection_count(self, address):
""" Count the number of connections currently in use to a given
address.
"""
try:
connections = self.connections[address]
except KeyError:
return 0
else:
return sum(1 if connection.in_use else 0 for connection in connections)
with self.lock:
connections = self.connections.get(address, ())
return sum(connection.in_use for connection in connections)

def mark_all_stale(self):
with self.lock:
Expand Down Expand Up @@ -447,7 +444,7 @@ def __init__(self, opener, pool_config, workspace_config, address):
# Each database have a routing table, the default database is a special case.
log.debug("[#0000] C: <NEO4J POOL> routing address %r", address)
self.address = address
self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])}
self.routing_tables = {}
self.refresh_lock = RLock()

def __repr__(self):
Expand All @@ -456,37 +453,15 @@ def __repr__(self):
:return: The representation
:rtype: str
"""
return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses())

@property
def first_initial_routing_address(self):
return self.get_default_database_initial_router_addresses()[0]

def get_default_database_initial_router_addresses(self):
""" Get the initial router addresses for the default database.

:return:
:rtype: OrderedSet
"""
return self.get_routing_table_for_default_database().initial_routers

def get_default_database_router_addresses(self):
""" Get the router addresses for the default database.

:return:
:rtype: OrderedSet
"""
return self.get_routing_table_for_default_database().routers

def get_routing_table_for_default_database(self):
return self.routing_tables[self.workspace_config.database]
return "<{} address={!r}>".format(self.__class__.__name__,
self.address)

def get_or_create_routing_table(self, database):
with self.refresh_lock:
if database not in self.routing_tables:
self.routing_tables[database] = RoutingTable(
database=database,
routers=self.get_default_database_initial_router_addresses()
routers=[self.address]
)
return self.routing_tables[database]

Expand Down Expand Up @@ -651,15 +626,15 @@ def update_routing_table(
if prefer_initial_routing_address:
# TODO: Test this state
if self._update_routing_table_from(
self.first_initial_routing_address, database=database,
self.address, database=database,
imp_user=imp_user, bookmarks=bookmarks,
acquisition_timeout=acquisition_timeout,
database_callback=database_callback
):
# Why is only the first initial routing address used?
return
if self._update_routing_table_from(
*(existing_routers - {self.first_initial_routing_address}),
*(existing_routers - {self.address}),
database=database, imp_user=imp_user, bookmarks=bookmarks,
acquisition_timeout=acquisition_timeout,
database_callback=database_callback
Expand All @@ -668,7 +643,7 @@ def update_routing_table(

if not prefer_initial_routing_address:
if self._update_routing_table_from(
self.first_initial_routing_address, database=database,
self.address, database=database,
imp_user=imp_user, bookmarks=bookmarks,
acquisition_timeout=acquisition_timeout,
database_callback=database_callback
Expand Down Expand Up @@ -705,6 +680,14 @@ def ensure_routing_table_is_fresh(
"""
from neo4j.api import READ_ACCESS
with self.refresh_lock:
for database_ in list(self.routing_tables.keys()):
# Remove unused databases in the routing table
# Remove the routing table after a timeout = TTL + 30s
log.debug("[#0000] C: <ROUTING AGED> database=%s", database_)
routing_table = self.routing_tables[database_]
if routing_table.should_be_purged_from_memory():
del self.routing_tables[database_]

routing_table = self.get_or_create_routing_table(database)
if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)):
# Readers are fresh.
Expand All @@ -717,25 +700,21 @@ def ensure_routing_table_is_fresh(
)
self.update_connection_pool(database=database)

for database in list(self.routing_tables.keys()):
# Remove unused databases in the routing table
# Remove the routing table after a timeout = TTL + 30s
log.debug("[#0000] C: <ROUTING AGED> database=%s", database)
if (self.routing_tables[database].should_be_purged_from_memory()
and database != self.workspace_config.database):
del self.routing_tables[database]

return True

def _select_address(self, *, access_mode, database):
from ...api import READ_ACCESS
""" Selects the address with the fewest in-use connections.
"""
with self.refresh_lock:
if access_mode == READ_ACCESS:
addresses = self.routing_tables[database].readers
routing_table = self.routing_tables.get(database)
if routing_table:
if access_mode == READ_ACCESS:
addresses = routing_table.readers
else:
addresses = routing_table.writers
else:
addresses = self.routing_tables[database].writers
addresses = ()
addresses_by_usage = {}
for address in addresses:
addresses_by_usage.setdefault(
Expand Down Expand Up @@ -763,13 +742,12 @@ def acquire(

from neo4j.api import check_access_mode
access_mode = check_access_mode(access_mode)
with self.refresh_lock:
log.debug("[#0000] C: <ROUTING TABLE ENSURE FRESH> %r",
self.routing_tables)
self.ensure_routing_table_is_fresh(
access_mode=access_mode, database=database, imp_user=None,
bookmarks=bookmarks, acquisition_timeout=timeout
)

self.ensure_routing_table_is_fresh(
access_mode=access_mode, database=database,
imp_user=None, bookmarks=bookmarks,
acquisition_timeout=timeout
)

while True:
try:
Expand Down