diff --git a/neo4j/_async/driver.py b/neo4j/_async/driver.py index 37062581..a8d4514c 100644 --- a/neo4j/_async/driver.py +++ b/neo4j/_async/driver.py @@ -669,7 +669,7 @@ def open(cls, *targets, auth=None, routing_context=None, **config): return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): - _Routing.__init__(self, pool.get_default_database_initial_router_addresses()) + _Routing.__init__(self, [pool.address]) AsyncDriver.__init__(self, pool, default_workspace_config) if not t.TYPE_CHECKING: diff --git a/neo4j/_async/io/_pool.py b/neo4j/_async/io/_pool.py index f61ebd46..53a5d5e7 100644 --- a/neo4j/_async/io/_pool.py +++ b/neo4j/_async/io/_pool.py @@ -288,12 +288,9 @@ def in_use_connection_count(self, address): """ Count the number of connections currently in use to a given address. """ - try: - connections = self.connections[address] - except KeyError: - return 0 - else: - return sum(1 if connection.in_use else 0 for connection in connections) + with self.lock: + connections = self.connections.get(address, ()) + return sum(connection.in_use for connection in connections) async def mark_all_stale(self): with self.lock: @@ -447,7 +444,7 @@ def __init__(self, opener, pool_config, workspace_config, address): # Each database have a routing table, the default database is a special case. log.debug("[#0000] C: routing address %r", address) self.address = address - self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])} + self.routing_tables = {} self.refresh_lock = AsyncRLock() def __repr__(self): @@ -456,37 +453,15 @@ def __repr__(self): :return: The representation :rtype: str """ - return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses()) - - @property - def first_initial_routing_address(self): - return self.get_default_database_initial_router_addresses()[0] - - def get_default_database_initial_router_addresses(self): - """ Get the initial router addresses for the default database. - - :return: - :rtype: OrderedSet - """ - return self.get_routing_table_for_default_database().initial_routers - - def get_default_database_router_addresses(self): - """ Get the router addresses for the default database. - - :return: - :rtype: OrderedSet - """ - return self.get_routing_table_for_default_database().routers - - def get_routing_table_for_default_database(self): - return self.routing_tables[self.workspace_config.database] + return "<{} address={!r}>".format(self.__class__.__name__, + self.address) async def get_or_create_routing_table(self, database): async with self.refresh_lock: if database not in self.routing_tables: self.routing_tables[database] = RoutingTable( database=database, - routers=self.get_default_database_initial_router_addresses() + routers=[self.address] ) return self.routing_tables[database] @@ -651,7 +626,7 @@ async def update_routing_table( if prefer_initial_routing_address: # TODO: Test this state if await self._update_routing_table_from( - self.first_initial_routing_address, database=database, + self.address, database=database, imp_user=imp_user, bookmarks=bookmarks, acquisition_timeout=acquisition_timeout, database_callback=database_callback @@ -659,7 +634,7 @@ async def update_routing_table( # Why is only the first initial routing address used? return if await self._update_routing_table_from( - *(existing_routers - {self.first_initial_routing_address}), + *(existing_routers - {self.address}), database=database, imp_user=imp_user, bookmarks=bookmarks, acquisition_timeout=acquisition_timeout, database_callback=database_callback @@ -668,7 +643,7 @@ async def update_routing_table( if not prefer_initial_routing_address: if await self._update_routing_table_from( - self.first_initial_routing_address, database=database, + self.address, database=database, imp_user=imp_user, bookmarks=bookmarks, acquisition_timeout=acquisition_timeout, database_callback=database_callback @@ -705,6 +680,14 @@ async def ensure_routing_table_is_fresh( """ from neo4j.api import READ_ACCESS async with self.refresh_lock: + for database_ in list(self.routing_tables.keys()): + # Remove unused databases in the routing table + # Remove the routing table after a timeout = TTL + 30s + log.debug("[#0000] C: database=%s", database_) + routing_table = self.routing_tables[database_] + if routing_table.should_be_purged_from_memory(): + del self.routing_tables[database_] + routing_table = await self.get_or_create_routing_table(database) if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)): # Readers are fresh. @@ -717,14 +700,6 @@ async def ensure_routing_table_is_fresh( ) await self.update_connection_pool(database=database) - for database in list(self.routing_tables.keys()): - # Remove unused databases in the routing table - # Remove the routing table after a timeout = TTL + 30s - log.debug("[#0000] C: database=%s", database) - if (self.routing_tables[database].should_be_purged_from_memory() - and database != self.workspace_config.database): - del self.routing_tables[database] - return True async def _select_address(self, *, access_mode, database): @@ -732,10 +707,14 @@ async def _select_address(self, *, access_mode, database): """ Selects the address with the fewest in-use connections. """ async with self.refresh_lock: - if access_mode == READ_ACCESS: - addresses = self.routing_tables[database].readers + routing_table = self.routing_tables.get(database) + if routing_table: + if access_mode == READ_ACCESS: + addresses = routing_table.readers + else: + addresses = routing_table.writers else: - addresses = self.routing_tables[database].writers + addresses = () addresses_by_usage = {} for address in addresses: addresses_by_usage.setdefault( @@ -763,13 +742,12 @@ async def acquire( from neo4j.api import check_access_mode access_mode = check_access_mode(access_mode) - async with self.refresh_lock: - log.debug("[#0000] C: %r", - self.routing_tables) - await self.ensure_routing_table_is_fresh( - access_mode=access_mode, database=database, imp_user=None, - bookmarks=bookmarks, acquisition_timeout=timeout - ) + + await self.ensure_routing_table_is_fresh( + access_mode=access_mode, database=database, + imp_user=None, bookmarks=bookmarks, + acquisition_timeout=timeout + ) while True: try: diff --git a/neo4j/_sync/driver.py b/neo4j/_sync/driver.py index 18d2c8de..baa7a593 100644 --- a/neo4j/_sync/driver.py +++ b/neo4j/_sync/driver.py @@ -668,7 +668,7 @@ def open(cls, *targets, auth=None, routing_context=None, **config): return cls(pool, default_workspace_config) def __init__(self, pool, default_workspace_config): - _Routing.__init__(self, pool.get_default_database_initial_router_addresses()) + _Routing.__init__(self, [pool.address]) Driver.__init__(self, pool, default_workspace_config) if not t.TYPE_CHECKING: diff --git a/neo4j/_sync/io/_pool.py b/neo4j/_sync/io/_pool.py index 0a379009..468ce1dc 100644 --- a/neo4j/_sync/io/_pool.py +++ b/neo4j/_sync/io/_pool.py @@ -288,12 +288,9 @@ def in_use_connection_count(self, address): """ Count the number of connections currently in use to a given address. """ - try: - connections = self.connections[address] - except KeyError: - return 0 - else: - return sum(1 if connection.in_use else 0 for connection in connections) + with self.lock: + connections = self.connections.get(address, ()) + return sum(connection.in_use for connection in connections) def mark_all_stale(self): with self.lock: @@ -447,7 +444,7 @@ def __init__(self, opener, pool_config, workspace_config, address): # Each database have a routing table, the default database is a special case. log.debug("[#0000] C: routing address %r", address) self.address = address - self.routing_tables = {workspace_config.database: RoutingTable(database=workspace_config.database, routers=[address])} + self.routing_tables = {} self.refresh_lock = RLock() def __repr__(self): @@ -456,37 +453,15 @@ def __repr__(self): :return: The representation :rtype: str """ - return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses()) - - @property - def first_initial_routing_address(self): - return self.get_default_database_initial_router_addresses()[0] - - def get_default_database_initial_router_addresses(self): - """ Get the initial router addresses for the default database. - - :return: - :rtype: OrderedSet - """ - return self.get_routing_table_for_default_database().initial_routers - - def get_default_database_router_addresses(self): - """ Get the router addresses for the default database. - - :return: - :rtype: OrderedSet - """ - return self.get_routing_table_for_default_database().routers - - def get_routing_table_for_default_database(self): - return self.routing_tables[self.workspace_config.database] + return "<{} address={!r}>".format(self.__class__.__name__, + self.address) def get_or_create_routing_table(self, database): with self.refresh_lock: if database not in self.routing_tables: self.routing_tables[database] = RoutingTable( database=database, - routers=self.get_default_database_initial_router_addresses() + routers=[self.address] ) return self.routing_tables[database] @@ -651,7 +626,7 @@ def update_routing_table( if prefer_initial_routing_address: # TODO: Test this state if self._update_routing_table_from( - self.first_initial_routing_address, database=database, + self.address, database=database, imp_user=imp_user, bookmarks=bookmarks, acquisition_timeout=acquisition_timeout, database_callback=database_callback @@ -659,7 +634,7 @@ def update_routing_table( # Why is only the first initial routing address used? return if self._update_routing_table_from( - *(existing_routers - {self.first_initial_routing_address}), + *(existing_routers - {self.address}), database=database, imp_user=imp_user, bookmarks=bookmarks, acquisition_timeout=acquisition_timeout, database_callback=database_callback @@ -668,7 +643,7 @@ def update_routing_table( if not prefer_initial_routing_address: if self._update_routing_table_from( - self.first_initial_routing_address, database=database, + self.address, database=database, imp_user=imp_user, bookmarks=bookmarks, acquisition_timeout=acquisition_timeout, database_callback=database_callback @@ -705,6 +680,14 @@ def ensure_routing_table_is_fresh( """ from neo4j.api import READ_ACCESS with self.refresh_lock: + for database_ in list(self.routing_tables.keys()): + # Remove unused databases in the routing table + # Remove the routing table after a timeout = TTL + 30s + log.debug("[#0000] C: database=%s", database_) + routing_table = self.routing_tables[database_] + if routing_table.should_be_purged_from_memory(): + del self.routing_tables[database_] + routing_table = self.get_or_create_routing_table(database) if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)): # Readers are fresh. @@ -717,14 +700,6 @@ def ensure_routing_table_is_fresh( ) self.update_connection_pool(database=database) - for database in list(self.routing_tables.keys()): - # Remove unused databases in the routing table - # Remove the routing table after a timeout = TTL + 30s - log.debug("[#0000] C: database=%s", database) - if (self.routing_tables[database].should_be_purged_from_memory() - and database != self.workspace_config.database): - del self.routing_tables[database] - return True def _select_address(self, *, access_mode, database): @@ -732,10 +707,14 @@ def _select_address(self, *, access_mode, database): """ Selects the address with the fewest in-use connections. """ with self.refresh_lock: - if access_mode == READ_ACCESS: - addresses = self.routing_tables[database].readers + routing_table = self.routing_tables.get(database) + if routing_table: + if access_mode == READ_ACCESS: + addresses = routing_table.readers + else: + addresses = routing_table.writers else: - addresses = self.routing_tables[database].writers + addresses = () addresses_by_usage = {} for address in addresses: addresses_by_usage.setdefault( @@ -763,13 +742,12 @@ def acquire( from neo4j.api import check_access_mode access_mode = check_access_mode(access_mode) - with self.refresh_lock: - log.debug("[#0000] C: %r", - self.routing_tables) - self.ensure_routing_table_is_fresh( - access_mode=access_mode, database=database, imp_user=None, - bookmarks=bookmarks, acquisition_timeout=timeout - ) + + self.ensure_routing_table_is_fresh( + access_mode=access_mode, database=database, + imp_user=None, bookmarks=bookmarks, + acquisition_timeout=timeout + ) while True: try: