diff --git a/.gitattributes b/.gitattributes index ca24c0e4..35dc3fdc 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,5 +1,5 @@ # configure github not to display generated files /src/neo4j/_sync/** linguist-generated=true -/tests/unit/sync_/** linguist-generated=true -/tests/integration/sync_/** linguist-generated=true +/tests/unit/sync/** linguist-generated=true +/tests/integration/sync/** linguist-generated=true /testkitbackend/_sync/** linguist-generated=true diff --git a/docs/source/api.rst b/docs/source/api.rst index c9f45e45..281d8c92 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -260,7 +260,8 @@ Closing a driver will immediately shut down all connections in the pool. :param database\_: Database to execute the query against. - None (default) uses the database configured on the server side. + :data:`None` (default) uses the database configured on the server + side. .. Note:: It is recommended to always specify the database explicitly @@ -1034,7 +1035,7 @@ Specifically, the following applies: all queries within that session are executed with the explicit database name 'movies' supplied. Any change to the user’s home database is reflected only in sessions created after such change takes effect. This - behavior requires additional network communication. In clustered + behavior may require additional network communication. In clustered environments, it is strongly recommended to avoid a single point of failure. For instance, by ensuring that the connection URI resolves to multiple endpoints. For older Bolt protocol versions the behavior is the diff --git a/docs/source/async_api.rst b/docs/source/async_api.rst index 74ce23a5..6c6e62e4 100644 --- a/docs/source/async_api.rst +++ b/docs/source/async_api.rst @@ -247,7 +247,8 @@ Closing a driver will immediately shut down all connections in the pool. :param database\_: Database to execute the query against. - None (default) uses the database configured on the server side. + :data:`None` (default) uses the database configured on the server + side. .. Note:: It is recommended to always specify the database explicitly diff --git a/src/neo4j/_async/driver.py b/src/neo4j/_async/driver.py index ce206233..a5420318 100644 --- a/src/neo4j/_async/driver.py +++ b/src/neo4j/_async/driver.py @@ -799,7 +799,8 @@ async def example(driver: neo4j.AsyncDriver) -> int: :param database_: Database to execute the query against. - None (default) uses the database configured on the server side. + :data:`None` (default) uses the database configured on the server + side. .. Note:: It is recommended to always specify the database explicitly diff --git a/src/neo4j/_async/home_db_cache.py b/src/neo4j/_async/home_db_cache.py new file mode 100644 index 00000000..96c2850f --- /dev/null +++ b/src/neo4j/_async/home_db_cache.py @@ -0,0 +1,150 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +import math +import typing as t +from time import monotonic + +from .._async_compat.concurrency import AsyncCooperativeLock + + +if t.TYPE_CHECKING: + import typing_extensions as te + + TKey: te.TypeAlias = t.Union[ + str, + t.Tuple[t.Tuple[str, t.Hashable], ...], + t.Tuple[None], + ] + TVal: te.TypeAlias = t.Tuple[float, str] + + +class AsyncHomeDbCache: + _ttl: float + _enabled: bool + _max_size: int | None + + def __init__( + self, + enabled: bool = True, + ttl: float = float("inf"), + max_size: int | None = None, + ) -> None: + if math.isnan(ttl) or ttl <= 0: + raise ValueError(f"home db cache ttl must be greater 0, got {ttl}") + self._enabled = enabled + self._ttl = ttl + self._cache: dict[TKey, TVal] = {} + self._lock = AsyncCooperativeLock() + self._oldest_entry = monotonic() + if max_size is not None and max_size <= 0: + raise ValueError( + f"home db cache max_size must be greater 0 or None, " + f"got {max_size}" + ) + self._max_size = max_size + self._truncate_size = ( + min(max_size, int(0.01 * max_size * math.log(max_size))) + if max_size is not None + else None + ) + + def compute_key( + self, + imp_user: str | None, + auth: dict | None, + ) -> TKey: + if not self._enabled: + return (None,) + if imp_user is not None: + return imp_user + if auth is not None: + return _consolidate_auth_token(auth) + return (None,) + + def get(self, key: TKey) -> str | None: + if not self._enabled: + return None + with self._lock: + self._clean(monotonic()) + val = self._cache.get(key) + if val is None: + return None + return val[1] + + def set(self, key: TKey, value: str | None) -> None: + if not self._enabled: + return + with self._lock: + now = monotonic() + self._clean(now) + if value is None: + self._cache.pop(key, None) + else: + self._cache[key] = (now, value) + + def clear(self) -> None: + if not self._enabled: + return + with self._lock: + self._cache = {} + self._oldest_entry = monotonic() + + def _clean(self, now: float | None = None) -> None: + now = monotonic() if now is None else now + if now - self._oldest_entry > self._ttl: + self._cache = { + k: v + for k, v in self._cache.items() + if now - v[0] < self._ttl * 0.9 + } + self._oldest_entry = min( + (v[0] for v in self._cache.values()), default=now + ) + if self._max_size and len(self._cache) > self._max_size: + self._cache = dict( + sorted( + self._cache.items(), + key=lambda item: item[1][0], + reverse=True, + )[: self._truncate_size] + ) + + def __len__(self) -> int: + return len(self._cache) + + @property + def enabled(self) -> bool: + return self._enabled + + +def _consolidate_auth_token(auth: dict) -> tuple | str: + if auth.get("scheme") == "basic" and isinstance( + auth.get("principal"), str + ): + return auth["principal"] + return _hashable_dict(auth) + + +def _hashable_dict(d: dict) -> tuple: + return tuple( + (k, _hashable_dict(v) if isinstance(v, dict) else v) + for k, v in sorted(d.items()) + ) diff --git a/src/neo4j/_async/io/__init__.py b/src/neo4j/_async/io/__init__.py index 3571ad94..7f068da7 100644 --- a/src/neo4j/_async/io/__init__.py +++ b/src/neo4j/_async/io/__init__.py @@ -22,7 +22,8 @@ """ __all__ = [ - "AcquireAuth", + "AcquisitionAuth", + "AcquisitionDatabase", "AsyncBolt", "AsyncBoltPool", "AsyncNeo4jPool", @@ -37,7 +38,8 @@ ConnectionErrorHandler, ) from ._pool import ( - AcquireAuth, + AcquisitionAuth, + AcquisitionDatabase, AsyncBoltPool, AsyncNeo4jPool, ) diff --git a/src/neo4j/_async/io/_bolt.py b/src/neo4j/_async/io/_bolt.py index b3e485d4..b2a73977 100644 --- 
a/src/neo4j/_async/io/_bolt.py +++ b/src/neo4j/_async/io/_bolt.py @@ -25,6 +25,7 @@ from ..._async_compat.network import AsyncBoltSocket from ..._async_compat.util import AsyncUtil +from ..._auth_management import to_auth_dict from ..._codec.hydration import ( HydrationHandlerABC, v1 as hydration_v1, @@ -39,12 +40,10 @@ from ..._sync.config import PoolConfig from ...addressing import ResolvedAddress from ...api import ( - Auth, ServerInfo, Version, ) from ...exceptions import ( - AuthError, ConfigurationError, DriverError, IncompleteCommit, @@ -158,10 +157,7 @@ def __init__( ), self.PROTOCOL_VERSION, ) - # so far `connection.recv_timeout_seconds` is the only available - # configuration hint that exists. Therefore, all hints can be stored at - # connection level. This might change in the future. - self.configuration_hints = {} + self.connection_hints = {} self.patch = {} self.outbox = AsyncOutbox( self.socket, @@ -187,7 +183,7 @@ def __init__( self.user_agent = USER_AGENT self.auth = auth - self.auth_dict = self._to_auth_dict(auth) + self.auth_dict = to_auth_dict(auth) self.auth_manager = auth_manager self.telemetry_disabled = telemetry_disabled @@ -206,26 +202,14 @@ def _get_server_state_manager(self) -> ServerStateManagerBase: ... @abc.abstractmethod def _get_client_state_manager(self) -> ClientStateManagerBase: ... - @classmethod - def _to_auth_dict(cls, auth): - # Determine auth details - if not auth: - return {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - return vars(Auth("basic", *auth)) - else: - try: - return vars(auth) - except (KeyError, TypeError) as e: - # TODO: 6.0 - change this to be a DriverError (or subclass) - raise AuthError( - f"Cannot determine auth details from {auth!r}" - ) from e - @property def connection_id(self): return self.server_info._metadata.get("connection_id", "") + @property + @abc.abstractmethod + def ssr_enabled(self) -> bool: ... 
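For orientation before the Bolt changes continue: the AsyncHomeDbCache added above keys resolved home databases by session identity and bounds the cache with a TTL and a maximum size. A minimal usage sketch, assuming the patched driver is importable (the module is private API, so names may change without notice):

# Illustrative only: exercises the AsyncHomeDbCache introduced by this patch.
from neo4j._async.home_db_cache import AsyncHomeDbCache

cache = AsyncHomeDbCache(ttl=300.0, max_size=10_000)

# Keys collapse to the impersonated user if one is set, else to the auth
# token (basic auth reduces to its principal), else to a shared (None,) key.
key = cache.compute_key(
    imp_user=None,
    auth={"scheme": "basic", "principal": "alice", "credentials": "s3cr3t"},
)
assert key == "alice"

cache.set(key, "alices-home-db")           # remember the resolved home db
assert cache.get(key) == "alices-home-db"  # served from cache until the TTL lapses
cache.set(key, None)                       # a None value evicts the entry
assert cache.get(key) is None

When the cache grows past max_size, _clean keeps only the newest min(max_size, 0.01 * max_size * log(max_size)) entries, so eviction runs as an occasional batch rather than on every insert.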
+ @property @abc.abstractmethod def supports_multiple_results(self): @@ -308,6 +292,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x5, AsyncBolt5x6, AsyncBolt5x7, + AsyncBolt5x8, ) handlers = { @@ -325,6 +310,7 @@ def protocol_handlers(cls, protocol_version=None): AsyncBolt5x5.PROTOCOL_VERSION: AsyncBolt5x5, AsyncBolt5x6.PROTOCOL_VERSION: AsyncBolt5x6, AsyncBolt5x7.PROTOCOL_VERSION: AsyncBolt5x7, + AsyncBolt5x8.PROTOCOL_VERSION: AsyncBolt5x8, } if protocol_version is None: @@ -461,7 +447,10 @@ async def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 7): + if protocol_version == (5, 8): + from ._bolt5 import AsyncBolt5x8 + bolt_cls = AsyncBolt5x8 + elif protocol_version == (5, 7): from ._bolt5 import AsyncBolt5x7 bolt_cls = AsyncBolt5x7 elif protocol_version == (5, 6): @@ -626,7 +615,7 @@ def re_auth( :returns: whether the auth was changed """ - new_auth_dict = self._to_auth_dict(auth) + new_auth_dict = to_auth_dict(auth) if not force and new_auth_dict == self.auth_dict: self.auth_manager = auth_manager self.auth = auth diff --git a/src/neo4j/_async/io/_bolt3.py b/src/neo4j/_async/io/_bolt3.py index 08e75abb..2997296f 100644 --- a/src/neo4j/_async/io/_bolt3.py +++ b/src/neo4j/_async/io/_bolt3.py @@ -148,6 +148,8 @@ class AsyncBolt3(AsyncBolt): PROTOCOL_VERSION = Version(3, 0) + ssr_enabled = False + supports_multiple_results = False supports_multiple_databases = False diff --git a/src/neo4j/_async/io/_bolt4.py b/src/neo4j/_async/io/_bolt4.py index 202d5570..abc9d4cb 100644 --- a/src/neo4j/_async/io/_bolt4.py +++ b/src/neo4j/_async/io/_bolt4.py @@ -64,6 +64,8 @@ class AsyncBolt4x0(AsyncBolt): PROTOCOL_VERSION = Version(4, 0) + ssr_enabled = False + supports_multiple_results = True supports_multiple_databases = True @@ -614,10 +616,10 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: diff --git a/src/neo4j/_async/io/_bolt5.py b/src/neo4j/_async/io/_bolt5.py index 06336193..a6f9e469 100644 --- a/src/neo4j/_async/io/_bolt5.py +++ b/src/neo4j/_async/io/_bolt5.py @@ -107,6 +107,10 @@ def _on_client_state_change(self, old_state, new_state): def _get_client_state_manager(self) -> ClientStateManagerBase: return self._client_state_manager + @property + def ssr_enabled(self) -> bool: + return False + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -141,10 +145,10 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -615,10 +619,10 @@ async def hello(self, 
dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -702,10 +706,10 @@ async def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -883,7 +887,7 @@ def telemetry( hydration_hooks=None, **handlers, ) -> None: - if self.telemetry_disabled or not self.configuration_hints.get( + if self.telemetry_disabled or not self.connection_hints.get( "telemetry.enabled", False ): return @@ -1225,3 +1229,11 @@ async def _process_message(self, tag, fields): ) return len(details), 1 + + +class AsyncBolt5x8(AsyncBolt5x7): + PROTOCOL_VERSION = Version(5, 8) + + @property + def ssr_enabled(self) -> bool: + return self.connection_hints.get("ssr.enabled", False) is True diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 7a520abe..be697ddb 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -59,6 +59,7 @@ WriteServiceUnavailable, ) from ..config import AsyncPoolConfig +from ..home_db_cache import AsyncHomeDbCache from ._bolt import AsyncBolt @@ -74,11 +75,48 @@ @dataclass -class AcquireAuth: +class AcquisitionAuth: auth: AsyncAuthManager | AuthManager | None force_auth: bool = False +@dataclass +class AcquisitionDatabase: + name: str | None + guessed: bool = False + + +@dataclass +class ConnectionFeatureTracker: + feature_check: t.Callable[[AsyncBolt], bool] + with_feature: int = 0 + without_feature: int = 0 + + @property + def has_feature(self): + return self.with_feature > 0 and self.without_feature == 0 + + def add_connection(self, connection): + if self.feature_check(connection): + self.with_feature += 1 + else: + self.without_feature += 1 + + def remove_connection(self, connection): + if self.feature_check(connection): + if self.with_feature == 0: + raise RuntimeError( + "No connections to be removed from feature tracker" + ) + self.with_feature -= 1 + else: + if self.without_feature == 0: + raise RuntimeError( + "No connections to be removed from feature tracker" + ) + self.without_feature -= 1 + + class AsyncIOPool(abc.ABC): """A collection of connections to one or more server addresses.""" @@ -94,11 +132,20 @@ def __init__(self, opener, pool_config, workspace_config): self.connections_reservations = defaultdict(lambda: 0) self.lock = AsyncCooperativeRLock() self.cond = AsyncCondition(self.lock) + self.home_db_cache = AsyncHomeDbCache(max_size=10_000) + self._ssr_feature_tracker = ConnectionFeatureTracker( + feature_check=lambda connection: connection.ssr_enabled + ) @property @abc.abstractmethod def is_direct_pool(self) -> bool: ... 
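The new AsyncBolt5x8 above is where the ssr.enabled connection hint is actually read: a Bolt 5.8 server may advertise server-side routing (SSR) support in its HELLO success metadata. A self-contained sketch of that check (the class here is a stand-in, not the driver's):

# How a Bolt 5.8 connection decides it may rely on server-side routing:
# the HELLO response can carry {"hints": {"ssr.enabled": true}}.
class Bolt5x8Sketch:
    def __init__(self, hello_metadata):
        self.connection_hints = hello_metadata.get("hints", {})

    @property
    def ssr_enabled(self):
        # strict identity check: only a JSON true enables SSR
        return self.connection_hints.get("ssr.enabled", False) is True

print(Bolt5x8Sketch({"hints": {"ssr.enabled": True}}).ssr_enabled)   # True
print(Bolt5x8Sketch({"hints": {"ssr.enabled": "yes"}}).ssr_enabled)  # False
print(Bolt5x8Sketch({}).ssr_enabled)                                 # False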
+ @property + def ssr_enabled(self) -> bool: + with self.lock: + return self._ssr_feature_tracker.has_feature + async def __aenter__(self): return self @@ -133,6 +180,20 @@ def _remove_connection(self, connection): # connection isn't in the pool anymore. with suppress(ValueError): self.connections.get(address, []).remove(connection) + self._ssr_feature_tracker.remove_connection(connection) + + def _add_connections(self, address, *connections): + with self.lock: + self.connections[address].extend(connections) + for connection in connections: + self._ssr_feature_tracker.add_connection(connection) + + def _remove_connections(self, address, *connections): + with self.lock: + existing_connections = self.connections.get(address, []) + for connection in connections: + existing_connections.remove(connection) + self._ssr_feature_tracker.remove_connection(connection) async def _acquire_from_pool_checked( self, address, health_check, deadline @@ -193,7 +254,7 @@ async def connection_creator(): with self.lock: self.connections_reservations[address] -= 1 released_reservation = True - self.connections[address].append(connection) + self._add_connections(address, connection) return connection finally: if not released_reservation: @@ -261,7 +322,7 @@ async def _acquire(self, address, auth, deadline, liveness_check_timeout): This method is thread safe. """ if auth is None: - auth = AcquireAuth(None) + auth = AcquisitionAuth(None) force_auth = auth.force_auth auth = auth.auth if liveness_check_timeout is None: @@ -356,8 +417,9 @@ async def acquire( timeout, database, bookmarks, - auth: AcquireAuth, + auth: AcquisitionAuth, liveness_check_timeout, + database_callback=None, ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -369,6 +431,7 @@ async def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param database_callback: """ ... @@ -493,8 +556,7 @@ async def deactivate(self, address): # First remove all connections in question, then try to close them. # If closing of a connection fails, we will end up in this method # again. - for conn in closable_connections: - connections.remove(conn) + self._remove_connections(address, *closable_connections) if not self.connections[address]: del self.connections[address] @@ -540,6 +602,8 @@ async def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] + for connection in connections: + self._ssr_feature_tracker.remove_connection(connection) await self._close_connections(connections) except TypeError: pass @@ -585,8 +649,9 @@ async def acquire( timeout, database, bookmarks, - auth: AcquireAuth, + auth: AcquisitionAuth, liveness_check_timeout, + database_callback=None, ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. 
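The pool-level ssr_enabled property above is deliberately all-or-nothing: ConnectionFeatureTracker.has_feature only reports True while at least one pooled connection supports SSR and none lacks it, because a routing decision must hold for whichever connection a session ends up with. A sketch of that accounting (private API; FakeConnection is a stand-in):

from neo4j._async.io._pool import ConnectionFeatureTracker

class FakeConnection:  # stand-in for an AsyncBolt connection
    def __init__(self, ssr_enabled):
        self.ssr_enabled = ssr_enabled

tracker = ConnectionFeatureTracker(
    feature_check=lambda connection: connection.ssr_enabled
)
new, old = FakeConnection(True), FakeConnection(False)

tracker.add_connection(new)
assert tracker.has_feature        # every tracked connection supports SSR

tracker.add_connection(old)       # one pre-Bolt-5.8 connection joins
assert not tracker.has_feature    # a single holdout disables SSR pool-wide

tracker.remove_connection(old)
assert tracker.has_feature        # re-enabled once the holdout is gone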
@@ -676,6 +741,10 @@ async def get_or_create_routing_table(self, database): ) return self.routing_tables[database] + async def get_routing_table(self, database): + async with self.refresh_lock: + return self.routing_tables.get(database) + async def fetch_routing_info( self, address, database, imp_user, bookmarks, auth, acquisition_timeout ): @@ -853,8 +922,7 @@ async def _update_routing_table_from( address, self.routing_tables[new_database], ) - if callable(database_callback): - database_callback(new_database) + await AsyncUtil.callback(database_callback, new_database) return True await self.deactivate(router) return False @@ -888,13 +956,16 @@ async def update_routing_table( :raise neo4j.exceptions.ServiceUnavailable: """ async with self.refresh_lock: - routing_table = await self.get_or_create_routing_table(database) - # copied because it can be modified - existing_routers = set(routing_table.routers) - - prefer_initial_routing_address = self.routing_tables[ - database - ].initialized_without_writers + routing_table = await self.get_routing_table(database) + if routing_table is not None: + # copied because it can be modified + existing_routers = set(routing_table.routers) + prefer_initial_routing_address = ( + routing_table.initialized_without_writers + ) + else: + existing_routers = {self.address} + prefer_initial_routing_address = True if ( prefer_initial_routing_address @@ -942,14 +1013,14 @@ async def update_routing_table( log.error("Unable to retrieve routing information") raise ServiceUnavailable("Unable to retrieve routing information") - async def update_connection_pool(self, *, database): + async def update_connection_pool(self): async with self.refresh_lock: - routing_tables = [await self.get_or_create_routing_table(database)] - for db in self.routing_tables: - if db == database: - continue - routing_tables.append(self.routing_tables[db]) - servers = set.union(*(rt.servers() for rt in routing_tables)) + routing_tables = list(self.routing_tables.values()) + + servers = set.union( + *(rt.servers() for rt in routing_tables), + {self.address}, + ) for address in list(self.connections): if address._unresolved not in servers: await super().deactivate(address) @@ -958,13 +1029,13 @@ async def ensure_routing_table_is_fresh( self, *, access_mode, - database, + database: AcquisitionDatabase, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None, - ): + ) -> bool: """ Update the routing table if stale.
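The reworked update_connection_pool above no longer takes a database argument: it prunes pooled connections against the union of the servers named in all routing tables plus the pool's initial routing address, so the seed URI can never be deactivated (which is why {self.address} must join the union as a one-element set rather than as a bare address tuple, whose host and port would otherwise be unioned element-wise). A rough, runnable model with made-up addresses:

# All names and addresses below are illustrative, not driver API.
routing_tables = {
    "neo4j": {("core1", 7687), ("core2", 7687)},
    "movies": {("core2", 7687), ("core3", 7687)},
}
initial_address = ("seed.example.com", 7687)

# keep anything any routing table still references, plus the seed address
servers = set.union(*routing_tables.values(), {initial_address})
for address in [("core1", 7687), ("stale", 7687), initial_address]:
    if address not in servers:
        print("deactivate", address)  # only ("stale", 7687) is dropped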
@@ -996,8 +1067,10 @@ async def ensure_routing_table_is_fresh( ) del self.routing_tables[database_] - routing_table = await self.get_or_create_routing_table(database) - if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)): + routing_table = await self.get_routing_table(database.name) + if routing_table is not None and routing_table.is_fresh( + readonly=(access_mode == READ_ACCESS) + ): # table is still valid log.debug( "[#0000] _: using existing routing table %r", @@ -1005,15 +1078,20 @@ async def ensure_routing_table_is_fresh( ) return False + database_request = database.name if not database.guessed else None + + async def wrapped_database_callback(database: str | None) -> None: + await AsyncUtil.callback(database_callback, database) + await self.update_connection_pool() + await self.update_routing_table( - database=database, + database=database_request, imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, - database_callback=database_callback, + database_callback=wrapped_database_callback, ) - await self.update_connection_pool(database=database) return True @@ -1050,10 +1128,11 @@ async def acquire( self, access_mode, timeout, - database, + database: AcquisitionDatabase, bookmarks, - auth: AcquireAuth | None, + auth: AcquisitionAuth | None, liveness_check_timeout, + database_callback=None, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: # TODO: 6.0 - change this to be a ValueError @@ -1067,10 +1146,14 @@ async def acquire( from ...api import check_access_mode access_mode = check_access_mode(access_mode) - # await self.ensure_routing_table_is_fresh( - # access_mode=access_mode, database=database, imp_user=None, - # bookmarks=bookmarks, acquisition_timeout=timeout - # ) + + target_database = database.name + + async def wrapped_database_callback(new_database): + nonlocal target_database + if new_database is not None: + target_database = new_database + await AsyncUtil.callback(database_callback, new_database) log.debug( "[#0000] _: acquire routing connection, " @@ -1085,6 +1168,11 @@ async def acquire( bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout, + database_callback=( + wrapped_database_callback + if database.guessed + else database_callback + ), ) while True: @@ -1092,7 +1180,7 @@ async def acquire( # Get an address for a connection that have the fewest in-use # connections. address = await self._select_address( - access_mode=access_mode, database=database + access_mode=access_mode, database=target_database ) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: raise SessionExpired( @@ -1103,7 +1191,7 @@ async def acquire( log.debug( "[#0000] _: acquire address, database=%r " "address=%r", - database, + target_database, address, ) deadline = Deadline.from_timeout_or_deadline(timeout) diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index ac62fa1f..721fac4a 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -114,6 +114,7 @@ def __init__( warn_notification_severity, on_closed, on_error, + on_database, ) -> None: self._connection_cls = connection.__class__ self._connection = ConnectionErrorHandler( @@ -122,6 +123,7 @@ def __init__( self._hydration_scope = connection.new_hydration_scope() self._on_error = on_error self._on_closed = on_closed + self._on_database = on_database self._metadata: dict = {} self._address: Address = self._connection.unresolved_address self._keys: tuple[str, ...] 
= () @@ -197,7 +199,7 @@ async def _run( } self._database = db - def on_attached(metadata): + async def on_attached(metadata): self._metadata.update(metadata) # For auto-commit there is no qid and Bolt 3 does not support qid self._raw_qid = metadata.get("qid", -1) @@ -205,6 +207,9 @@ def on_attached(metadata): self._connection.most_recent_qid = self._raw_qid self._keys = metadata.get("fields") self._attached = True + db_ = metadata.get("db") + if isinstance(db_, str): + await AsyncUtil.callback(self._on_database, db_) async def on_failed_attach(metadata): self._metadata.update(metadata) diff --git a/src/neo4j/_async/work/session.py b/src/neo4j/_async/work/session.py index 77c0a7bf..dd7324ab 100644 --- a/src/neo4j/_async/work/session.py +++ b/src/neo4j/_async/work/session.py @@ -321,6 +321,7 @@ async def run( self._config.warn_notification_severity, self._result_closed, self._result_error, + self._make_db_resolution_callback(), ) bookmarks = await self._get_bookmarks() parameters = dict(parameters or {}, **kwargs) @@ -448,6 +449,7 @@ async def _open_transaction( self._transaction_closed_handler, self._transaction_error_handler, self._transaction_cancel_handler, + self._make_db_resolution_callback(), ) bookmarks = await self._get_bookmarks() await self._transaction._begin( diff --git a/src/neo4j/_async/work/transaction.py b/src/neo4j/_async/work/transaction.py index 2a1aa062..921a29ef 100644 --- a/src/neo4j/_async/work/transaction.py +++ b/src/neo4j/_async/work/transaction.py @@ -47,6 +47,7 @@ def __init__( on_closed, on_error, on_cancel, + on_database, ): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( @@ -62,6 +63,7 @@ def __init__( self._on_closed = on_closed self._on_error = on_error self._on_cancel = on_cancel + self._on_database = on_database super().__init__() async def _enter(self) -> te.Self: @@ -92,6 +94,11 @@ async def _begin( notifications_disabled_classifications, pipelined=False, ): + async def on_begin_success(metadata_): + db = metadata_.get("db") + if isinstance(db, str): + await AsyncUtil.callback(self._on_database, db) + self._database = database self._connection.begin( bookmarks=bookmarks, @@ -101,7 +108,10 @@ async def _begin( db=database, imp_user=imp_user, notifications_min_severity=notifications_min_severity, - notifications_disabled_classifications=notifications_disabled_classifications, + notifications_disabled_classifications=( + notifications_disabled_classifications + ), + on_success=on_begin_success, ) if not pipelined: await self._error_handling_connection.send_all() @@ -188,6 +198,7 @@ async def run( self._warn_notification_severity, self._result_on_closed_handler, self._error_handler, + None, ) self._results.append(result) diff --git a/src/neo4j/_async/work/workspace.py b/src/neo4j/_async/work/workspace.py index dc044249..6f0a08c7 100644 --- a/src/neo4j/_async/work/workspace.py +++ b/src/neo4j/_async/work/workspace.py @@ -17,8 +17,10 @@ from __future__ import annotations import logging +import typing as t from ..._async_compat.util import AsyncUtil +from ..._auth_management import to_auth_dict from ..._conf import WorkspaceConfig from ..._meta import ( deprecation_warn, @@ -32,11 +34,25 @@ ) from .._debug import AsyncNonConcurrentMethodChecker from ..io import ( - AcquireAuth, - AsyncNeo4jPool, + AcquisitionAuth, + AcquisitionDatabase, ) +if t.TYPE_CHECKING: + from ...api import _TAuth + from ...auth_management import ( + AsyncAuthManager, + AuthManager, + ) + from ..home_db_cache import ( + AsyncHomeDbCache, + TKey, + 
) +else: + _TAuth = t.Any + + log = logging.getLogger("neo4j") @@ -47,8 +63,9 @@ def __init__(self, pool, config): self._config = config self._connection = None self._connection_access_mode = None + self._last_cache_key: TKey | None = None # Sessions are supposed to cache the database on which to operate. - self._cached_database = False + self._pinned_database = False self._bookmarks = () self._initial_bookmarks = () self._bookmark_manager = None @@ -87,8 +104,26 @@ async def __aenter__(self) -> AsyncWorkspace: async def __aexit__(self, exc_type, exc_value, traceback): await self.close() - def _set_cached_database(self, database): - self._cached_database = True + def _make_db_resolution_callback( + self, + ) -> t.Callable[[str | None], None] | None: + if self._pinned_database: + return None + + def _database_callback(database: str | None) -> None: + self._set_pinned_database(database) + if self._last_cache_key is None or database is None: + return + db_cache: AsyncHomeDbCache = self._pool.home_db_cache + db_cache.set(self._last_cache_key, database) + + return _database_callback + + def _set_pinned_database(self, database): + if self._pinned_database: + return + log.debug("[#0000] _: pinning database: %r", database) + self._pinned_database = True self._config.database = database def _initialize_bookmarks(self, bookmarks): @@ -138,12 +173,10 @@ async def _update_bookmark(self, bookmark): return await self._update_bookmarks((bookmark,)) - async def _connect(self, access_mode, auth=None, **acquire_kwargs): + async def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: acquisition_timeout = self._config.connection_acquisition_timeout - auth = AcquireAuth( - auth, - force_auth=acquire_kwargs.pop("force_auth", False), - ) + force_auth = acquire_kwargs.pop("force_auth", False) + acquire_auth = AcquisitionAuth(auth, force_auth=force_auth) if self._connection: # TODO: Investigate this @@ -151,40 +184,112 @@ async def _connect(self, access_mode, auth=None, **acquire_kwargs): await self._connection.send_all() await self._connection.fetch_all() await self._disconnect() - if not self._cached_database: - if self._config.database is not None or not isinstance( - self._pool, AsyncNeo4jPool - ): - self._set_cached_database(self._config.database) - else: - # This is the first time we open a connection to a server in a - # cluster environment for this session without explicitly - # configured database. Hence, we request a routing table update - # to try to fetch the home database. If provided by the server, - # we shall use this database explicitly for all subsequent - # actions within this session. 
- log.debug("[#0000] _: resolve home database") - await self._pool.update_routing_table( - database=self._config.database, - imp_user=self._config.impersonated_user, - bookmarks=await self._get_bookmarks(), - auth=auth, - acquisition_timeout=acquisition_timeout, - database_callback=self._set_cached_database, - ) + + ssr_enabled = self._pool.ssr_enabled + target_db = await self._get_routing_target_database( + acquire_auth, ssr_enabled=ssr_enabled + ) acquire_kwargs_ = { "access_mode": access_mode, "timeout": acquisition_timeout, - "database": self._config.database, + "database": target_db, "bookmarks": await self._get_bookmarks(), - "auth": auth, + "auth": acquire_auth, "liveness_check_timeout": None, + "database_callback": self._make_db_resolution_callback(), } acquire_kwargs_.update(acquire_kwargs) self._connection = await self._pool.acquire(**acquire_kwargs_) + if ( + target_db.guessed + and not self._pinned_database + and not self._connection.ssr_enabled + ): + # race condition: we now have created a connection which does not + # support SSR. + # => we need to fall back to explicit home database resolution + log.debug( + "[#0000] _: detected ssr support race; " + "falling back to explicit home database resolution", + ) + await self._disconnect() + target_db = await self._get_routing_target_database( + acquire_auth, ssr_enabled=False + ) + acquire_kwargs_["database"] = target_db + self._connection = await self._pool.acquire(**acquire_kwargs_) self._connection_access_mode = access_mode + async def _get_routing_target_database( + self, + acquire_auth: AcquisitionAuth, + ssr_enabled: bool, + ) -> AcquisitionDatabase: + if ( + self._pinned_database + or self._config.database is not None + or self._pool.is_direct_pool + ): + log.debug( + "[#0000] _: routing towards fixed database: %s", + self._config.database, + ) + self._set_pinned_database(self._config.database) + return AcquisitionDatabase(self._config.database) + + auth = acquire_auth.auth + resolved_auth = await self._resolve_session_auth(auth) + db_cache: AsyncHomeDbCache = self._pool.home_db_cache + cache_key = db_cache.compute_key( + self._config.impersonated_user, + resolved_auth, + ) + self._last_cache_key = cache_key + + if ssr_enabled: + cached_db = db_cache.get(cache_key) + if cached_db is not None: + log.debug( + ( + "[#0000] _: routing towards cached " + "database: %s" + ), + cached_db, + ) + return AcquisitionDatabase(cached_db, guessed=True) + + acquisition_timeout = self._config.connection_acquisition_timeout + log.debug("[#0000] _: resolve home database") + await self._pool.update_routing_table( + database=self._config.database, + imp_user=self._config.impersonated_user, + bookmarks=await self._get_bookmarks(), + auth=acquire_auth, + acquisition_timeout=acquisition_timeout, + database_callback=self._make_db_resolution_callback(), + ) + return AcquisitionDatabase(self._config.database) + + @staticmethod + async def _resolve_session_auth( + auth: AsyncAuthManager | AuthManager | None, + ) -> dict | None: + if auth is None: + return None + # resolved_auth = await AsyncUtil.callback(auth.get_auth) + # The above line breaks mypy + # https://github.com/python/mypy/issues/15295 + auth_getter: t.Callable[[], _TAuth | t.Awaitable[_TAuth]] = ( + auth.get_auth + ) + # so we enforce the right type here + # (explicit type annotation above added as it's a necessary assumption + # for this cast to be correct) + resolved_auth = t.cast(_TAuth, await AsyncUtil.callback(auth_getter)) + return to_auth_dict(resolved_auth) + async def 
_disconnect(self, sync=False): + self._last_cache_key = None if self._connection: if sync: try: diff --git a/src/neo4j/_auth_management.py b/src/neo4j/_auth_management.py index 79eea256..98a16150 100644 --- a/src/neo4j/_auth_management.py +++ b/src/neo4j/_auth_management.py @@ -21,6 +21,9 @@ import typing as t from dataclasses import dataclass +from .api import Auth +from .exceptions import AuthError + if t.TYPE_CHECKING: from os import PathLike @@ -306,3 +309,19 @@ async def get_certificate(self) -> ClientCertificate | None: .. seealso:: :meth:`.ClientCertificateProvider.get_certificate` """ ... + + +def to_auth_dict(auth: _TAuth) -> dict[str, t.Any]: + # Determine auth details + if not auth: + return {} + elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: + return vars(Auth("basic", *auth)) + else: + try: + return vars(auth) + except (KeyError, TypeError) as e: + # TODO: 6.0 - change this to be a DriverError (or subclass) + raise AuthError( + f"Cannot determine auth details from {auth!r}" + ) from e diff --git a/src/neo4j/_sync/driver.py b/src/neo4j/_sync/driver.py index 15a8c415..971f9be7 100644 --- a/src/neo4j/_sync/driver.py +++ b/src/neo4j/_sync/driver.py @@ -798,7 +798,8 @@ def example(driver: neo4j.Driver) -> int: :param database_: Database to execute the query against. - None (default) uses the database configured on the server side. + :data:`None` (default) uses the database configured on the server + side. .. Note:: It is recommended to always specify the database explicitly diff --git a/src/neo4j/_sync/home_db_cache.py b/src/neo4j/_sync/home_db_cache.py new file mode 100644 index 00000000..904fa662 --- /dev/null +++ b/src/neo4j/_sync/home_db_cache.py @@ -0,0 +1,150 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +import math +import typing as t +from time import monotonic + +from .._async_compat.concurrency import CooperativeLock + + +if t.TYPE_CHECKING: + import typing_extensions as te + + TKey: te.TypeAlias = t.Union[ + str, + t.Tuple[t.Tuple[str, t.Hashable], ...], + t.Tuple[None], + ] + TVal: te.TypeAlias = t.Tuple[float, str] + + +class HomeDbCache: + _ttl: float + _enabled: bool + _max_size: int | None + + def __init__( + self, + enabled: bool = True, + ttl: float = float("inf"), + max_size: int | None = None, + ) -> None: + if math.isnan(ttl) or ttl <= 0: + raise ValueError(f"home db cache ttl must be greater 0, got {ttl}") + self._enabled = enabled + self._ttl = ttl + self._cache: dict[TKey, TVal] = {} + self._lock = CooperativeLock() + self._oldest_entry = monotonic() + if max_size is not None and max_size <= 0: + raise ValueError( + f"home db cache max_size must be greater 0 or None, " + f"got {max_size}" + ) + self._max_size = max_size + self._truncate_size = ( + min(max_size, int(0.01 * max_size * math.log(max_size))) + if max_size is not None + else None + ) + + def compute_key( + self, + imp_user: str | None, + auth: dict | None, + ) -> TKey: + if not self._enabled: + return (None,) + if imp_user is not None: + return imp_user + if auth is not None: + return _consolidate_auth_token(auth) + return (None,) + + def get(self, key: TKey) -> str | None: + if not self._enabled: + return None + with self._lock: + self._clean(monotonic()) + val = self._cache.get(key) + if val is None: + return None + return val[1] + + def set(self, key: TKey, value: str | None) -> None: + if not self._enabled: + return + with self._lock: + now = monotonic() + self._clean(now) + if value is None: + self._cache.pop(key, None) + else: + self._cache[key] = (now, value) + + def clear(self) -> None: + if not self._enabled: + return + with self._lock: + self._cache = {} + self._oldest_entry = monotonic() + + def _clean(self, now: float | None = None) -> None: + now = monotonic() if now is None else now + if now - self._oldest_entry > self._ttl: + self._cache = { + k: v + for k, v in self._cache.items() + if now - v[0] < self._ttl * 0.9 + } + self._oldest_entry = min( + (v[0] for v in self._cache.values()), default=now + ) + if self._max_size and len(self._cache) > self._max_size: + self._cache = dict( + sorted( + self._cache.items(), + key=lambda item: item[1][0], + reverse=True, + )[: self._truncate_size] + ) + + def __len__(self) -> int: + return len(self._cache) + + @property + def enabled(self) -> bool: + return self._enabled + + +def _consolidate_auth_token(auth: dict) -> tuple | str: + if auth.get("scheme") == "basic" and isinstance( + auth.get("principal"), str + ): + return auth["principal"] + return _hashable_dict(auth) + + +def _hashable_dict(d: dict) -> tuple: + return tuple( + (k, _hashable_dict(v) if isinstance(v, dict) else v) + for k, v in sorted(d.items()) + ) diff --git a/src/neo4j/_sync/io/__init__.py b/src/neo4j/_sync/io/__init__.py index a1833c74..5a7ea831 100644 --- a/src/neo4j/_sync/io/__init__.py +++ b/src/neo4j/_sync/io/__init__.py @@ -22,7 +22,8 @@ """ __all__ = [ - "AcquireAuth", + "AcquisitionAuth", + "AcquisitionDatabase", "Bolt", "BoltPool", "Neo4jPool", @@ -37,7 +38,8 @@ ConnectionErrorHandler, ) from ._pool import ( - AcquireAuth, + AcquisitionAuth, + AcquisitionDatabase, BoltPool, Neo4jPool, ) diff --git a/src/neo4j/_sync/io/_bolt.py b/src/neo4j/_sync/io/_bolt.py index 2ad1790a..bcd5a6ba 100644 --- a/src/neo4j/_sync/io/_bolt.py +++ 
b/src/neo4j/_sync/io/_bolt.py @@ -25,6 +25,7 @@ from ..._async_compat.network import BoltSocket from ..._async_compat.util import Util +from ..._auth_management import to_auth_dict from ..._codec.hydration import ( HydrationHandlerABC, v1 as hydration_v1, @@ -39,12 +40,10 @@ from ..._sync.config import PoolConfig from ...addressing import ResolvedAddress from ...api import ( - Auth, ServerInfo, Version, ) from ...exceptions import ( - AuthError, ConfigurationError, DriverError, IncompleteCommit, @@ -158,10 +157,7 @@ def __init__( ), self.PROTOCOL_VERSION, ) - # so far `connection.recv_timeout_seconds` is the only available - # configuration hint that exists. Therefore, all hints can be stored at - # connection level. This might change in the future. - self.configuration_hints = {} + self.connection_hints = {} self.patch = {} self.outbox = Outbox( self.socket, @@ -187,7 +183,7 @@ def __init__( self.user_agent = USER_AGENT self.auth = auth - self.auth_dict = self._to_auth_dict(auth) + self.auth_dict = to_auth_dict(auth) self.auth_manager = auth_manager self.telemetry_disabled = telemetry_disabled @@ -206,26 +202,14 @@ def _get_server_state_manager(self) -> ServerStateManagerBase: ... @abc.abstractmethod def _get_client_state_manager(self) -> ClientStateManagerBase: ... - @classmethod - def _to_auth_dict(cls, auth): - # Determine auth details - if not auth: - return {} - elif isinstance(auth, tuple) and 2 <= len(auth) <= 3: - return vars(Auth("basic", *auth)) - else: - try: - return vars(auth) - except (KeyError, TypeError) as e: - # TODO: 6.0 - change this to be a DriverError (or subclass) - raise AuthError( - f"Cannot determine auth details from {auth!r}" - ) from e - @property def connection_id(self): return self.server_info._metadata.get("connection_id", "") + @property + @abc.abstractmethod + def ssr_enabled(self) -> bool: ... 
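to_auth_dict, used above, is the former Bolt._to_auth_dict hoisted into neo4j._auth_management so the workspace code can normalize session auth the same way the connection does. A sketch of the accepted shapes, assuming the patched driver (the helper is private API):

from neo4j import Auth
from neo4j._auth_management import to_auth_dict

print(to_auth_dict(None))                 # {} -- no auth at all
print(to_auth_dict(("neo4j", "secret")))  # 2- or 3-tuples become basic auth
print(to_auth_dict(Auth("bearer", None, "my-token")))  # vars() of Auth objects
try:
    to_auth_dict(object())                # unusable objects are rejected
except Exception as exc:                  # AuthError today, a DriverError in 6.0
    print(type(exc).__name__)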
+ @property @abc.abstractmethod def supports_multiple_results(self): @@ -308,6 +292,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x5, Bolt5x6, Bolt5x7, + Bolt5x8, ) handlers = { @@ -325,6 +310,7 @@ def protocol_handlers(cls, protocol_version=None): Bolt5x5.PROTOCOL_VERSION: Bolt5x5, Bolt5x6.PROTOCOL_VERSION: Bolt5x6, Bolt5x7.PROTOCOL_VERSION: Bolt5x7, + Bolt5x8.PROTOCOL_VERSION: Bolt5x8, } if protocol_version is None: @@ -461,7 +447,10 @@ def open( # avoid new lines after imports for better readability and conciseness # fmt: off - if protocol_version == (5, 7): + if protocol_version == (5, 8): + from ._bolt5 import Bolt5x8 + bolt_cls = Bolt5x8 + elif protocol_version == (5, 7): from ._bolt5 import Bolt5x7 bolt_cls = Bolt5x7 elif protocol_version == (5, 6): @@ -626,7 +615,7 @@ def re_auth( :returns: whether the auth was changed """ - new_auth_dict = self._to_auth_dict(auth) + new_auth_dict = to_auth_dict(auth) if not force and new_auth_dict == self.auth_dict: self.auth_manager = auth_manager self.auth = auth diff --git a/src/neo4j/_sync/io/_bolt3.py b/src/neo4j/_sync/io/_bolt3.py index e3cfd142..3f4c93a3 100644 --- a/src/neo4j/_sync/io/_bolt3.py +++ b/src/neo4j/_sync/io/_bolt3.py @@ -148,6 +148,8 @@ class Bolt3(Bolt): PROTOCOL_VERSION = Version(3, 0) + ssr_enabled = False + supports_multiple_results = False supports_multiple_databases = False diff --git a/src/neo4j/_sync/io/_bolt4.py b/src/neo4j/_sync/io/_bolt4.py index 69bb6dd6..99c04185 100644 --- a/src/neo4j/_sync/io/_bolt4.py +++ b/src/neo4j/_sync/io/_bolt4.py @@ -64,6 +64,8 @@ class Bolt4x0(Bolt): PROTOCOL_VERSION = Version(4, 0) + ssr_enabled = False + supports_multiple_results = True supports_multiple_databases = True @@ -614,10 +616,10 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: diff --git a/src/neo4j/_sync/io/_bolt5.py b/src/neo4j/_sync/io/_bolt5.py index 4138a9d5..27e0b695 100644 --- a/src/neo4j/_sync/io/_bolt5.py +++ b/src/neo4j/_sync/io/_bolt5.py @@ -107,6 +107,10 @@ def _on_client_state_change(self, old_state, new_state): def _get_client_state_manager(self) -> ClientStateManagerBase: return self._client_state_manager + @property + def ssr_enabled(self) -> bool: + return False + @property def is_reset(self): # We can't be sure of the server's state if there are still pending @@ -141,10 +145,10 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -615,10 +619,10 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + 
self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -702,10 +706,10 @@ def hello(self, dehydration_hooks=None, hydration_hooks=None): ) def on_success(metadata): - self.configuration_hints.update(metadata.pop("hints", {})) + self.connection_hints.update(metadata.pop("hints", {})) self.server_info.update(metadata) - if "connection.recv_timeout_seconds" in self.configuration_hints: - recv_timeout = self.configuration_hints[ + if "connection.recv_timeout_seconds" in self.connection_hints: + recv_timeout = self.connection_hints[ "connection.recv_timeout_seconds" ] if isinstance(recv_timeout, int) and recv_timeout > 0: @@ -883,7 +887,7 @@ def telemetry( hydration_hooks=None, **handlers, ) -> None: - if self.telemetry_disabled or not self.configuration_hints.get( + if self.telemetry_disabled or not self.connection_hints.get( "telemetry.enabled", False ): return @@ -1225,3 +1229,11 @@ def _process_message(self, tag, fields): ) return len(details), 1 + + +class Bolt5x8(Bolt5x7): + PROTOCOL_VERSION = Version(5, 8) + + @property + def ssr_enabled(self) -> bool: + return self.connection_hints.get("ssr.enabled", False) is True diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 1570e745..c04c1165 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -59,6 +59,7 @@ WriteServiceUnavailable, ) from ..config import PoolConfig +from ..home_db_cache import HomeDbCache from ._bolt import Bolt @@ -71,11 +72,48 @@ @dataclass -class AcquireAuth: +class AcquisitionAuth: auth: AuthManager | AuthManager | None force_auth: bool = False +@dataclass +class AcquisitionDatabase: + name: str | None + guessed: bool = False + + +@dataclass +class ConnectionFeatureTracker: + feature_check: t.Callable[[Bolt], bool] + with_feature: int = 0 + without_feature: int = 0 + + @property + def has_feature(self): + return self.with_feature > 0 and self.without_feature == 0 + + def add_connection(self, connection): + if self.feature_check(connection): + self.with_feature += 1 + else: + self.without_feature += 1 + + def remove_connection(self, connection): + if self.feature_check(connection): + if self.with_feature == 0: + raise RuntimeError( + "No connections to be removed from feature tracker" + ) + self.with_feature -= 1 + else: + if self.without_feature == 0: + raise RuntimeError( + "No connections to be removed from feature tracker" + ) + self.without_feature -= 1 + + class IOPool(abc.ABC): """A collection of connections to one or more server addresses.""" @@ -91,11 +129,20 @@ def __init__(self, opener, pool_config, workspace_config): self.connections_reservations = defaultdict(lambda: 0) self.lock = CooperativeRLock() self.cond = Condition(self.lock) + self.home_db_cache = HomeDbCache(max_size=10_000) + self._ssr_feature_tracker = ConnectionFeatureTracker( + feature_check=lambda connection: connection.ssr_enabled + ) @property @abc.abstractmethod def is_direct_pool(self) -> bool: ... + @property + def ssr_enabled(self) -> bool: + with self.lock: + return self._ssr_feature_tracker.has_feature + def __enter__(self): return self @@ -130,6 +177,20 @@ def _remove_connection(self, connection): # connection isn't in the pool anymore. 
with suppress(ValueError): self.connections.get(address, []).remove(connection) + self._ssr_feature_tracker.remove_connection(connection) + + def _add_connections(self, address, *connections): + with self.lock: + self.connections[address].extend(connections) + for connection in connections: + self._ssr_feature_tracker.add_connection(connection) + + def _remove_connections(self, address, *connections): + with self.lock: + existing_connections = self.connections.get(address, []) + for connection in connections: + existing_connections.remove(connection) + self._ssr_feature_tracker.remove_connection(connection) def _acquire_from_pool_checked( self, address, health_check, deadline @@ -190,7 +251,7 @@ def connection_creator(): with self.lock: self.connections_reservations[address] -= 1 released_reservation = True - self.connections[address].append(connection) + self._add_connections(address, connection) return connection finally: if not released_reservation: @@ -258,7 +319,7 @@ def _acquire(self, address, auth, deadline, liveness_check_timeout): This method is thread safe. """ if auth is None: - auth = AcquireAuth(None) + auth = AcquisitionAuth(None) force_auth = auth.force_auth auth = auth.auth if liveness_check_timeout is None: @@ -353,8 +414,9 @@ def acquire( timeout, database, bookmarks, - auth: AcquireAuth, + auth: AcquisitionAuth, liveness_check_timeout, + database_callback=None, ): """ Acquire a connection to a server that can satisfy a set of parameters. @@ -366,6 +428,7 @@ def acquire( :param bookmarks: :param auth: :param liveness_check_timeout: + :param database_callback: """ ... @@ -490,8 +553,7 @@ def deactivate(self, address): # First remove all connections in question, then try to close them. # If closing of a connection fails, we will end up in this method # again. - for conn in closable_connections: - connections.remove(conn) + self._remove_connections(address, *closable_connections) if not self.connections[address]: del self.connections[address] @@ -537,6 +599,8 @@ def close(self): for address in list(self.connections) for connection in self.connections.pop(address, ()) ] + for connection in connections: + self._ssr_feature_tracker.remove_connection(connection) self._close_connections(connections) except TypeError: pass @@ -582,8 +646,9 @@ def acquire( timeout, database, bookmarks, - auth: AcquireAuth, + auth: AcquisitionAuth, liveness_check_timeout, + database_callback=None, ): # The access_mode and database is not needed for a direct connection, # it's just there for consistency. 
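The sync pool below mirrors the async changes: acquisition now carries an AcquisitionDatabase, whose guessed flag marks a name that came from the home-database cache rather than from configuration or session pinning. As the ensure_routing_table_is_fresh hunks after this note show, a stale routing table for a guessed name is refreshed with database=None so the server re-resolves the home database instead of trusting the guess. A small sketch of that decision (private API):

from neo4j._sync.io._pool import AcquisitionDatabase

pinned = AcquisitionDatabase("movies")                         # fixed by config
guessed = AcquisitionDatabase("alices-home-db", guessed=True)  # cache hit

for database in (pinned, guessed):
    # as in ensure_routing_table_is_fresh: only non-guessed names are
    # requested verbatim; a guess falls back to server-side resolution
    database_request = database.name if not database.guessed else None
    print(database.name, "->", database_request)
# movies -> movies
# alices-home-db -> None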
@@ -673,6 +738,10 @@ def get_or_create_routing_table(self, database): ) return self.routing_tables[database] + def get_routing_table(self, database): + with self.refresh_lock: + return self.routing_tables.get(database) + def fetch_routing_info( self, address, database, imp_user, bookmarks, auth, acquisition_timeout ): @@ -850,8 +919,7 @@ def _update_routing_table_from( address, self.routing_tables[new_database], ) - if callable(database_callback): - database_callback(new_database) + Util.callback(database_callback, new_database) return True self.deactivate(router) return False @@ -885,13 +953,16 @@ def update_routing_table( :raise neo4j.exceptions.ServiceUnavailable: """ with self.refresh_lock: - routing_table = self.get_or_create_routing_table(database) - # copied because it can be modified - existing_routers = set(routing_table.routers) - - prefer_initial_routing_address = self.routing_tables[ - database - ].initialized_without_writers + routing_table = self.get_routing_table(database) + if routing_table is not None: + # copied because it can be modified + existing_routers = set(routing_table.routers) + prefer_initial_routing_address = ( + routing_table.initialized_without_writers + ) + else: + existing_routers = {self.address} + prefer_initial_routing_address = True if ( prefer_initial_routing_address @@ -939,14 +1010,14 @@ def update_routing_table( log.error("Unable to retrieve routing information") raise ServiceUnavailable("Unable to retrieve routing information") - def update_connection_pool(self, *, database): + def update_connection_pool(self): with self.refresh_lock: - routing_tables = [self.get_or_create_routing_table(database)] - for db in self.routing_tables: - if db == database: - continue - routing_tables.append(self.routing_tables[db]) - servers = set.union(*(rt.servers() for rt in routing_tables)) + routing_tables = list(self.routing_tables.values()) + + servers = set.union( + *(rt.servers() for rt in routing_tables), + {self.address}, + ) for address in list(self.connections): if address._unresolved not in servers: super().deactivate(address) @@ -955,13 +1026,13 @@ def ensure_routing_table_is_fresh( self, *, access_mode, - database, + database: AcquisitionDatabase, imp_user, bookmarks, auth=None, acquisition_timeout=None, database_callback=None, - ): + ) -> bool: """ Update the routing table if stale.
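Once a database name is actually resolved, it flows back through a chain of callbacks: the pool's wrapped_database_callback notifies the caller and then prunes the connection pool; the workspace callback produced by _make_db_resolution_callback pins the session and updates the home-db cache; and Result/Transaction feed the "db" field of RUN/BEGIN success metadata into the same hook. A condensed, hypothetical model of the workspace end of that chain (stand-in names, not driver API):

# Condensed model of _make_db_resolution_callback / _set_pinned_database.
class MiniWorkspace:
    def __init__(self, cache, cache_key):
        self.pinned_database = False
        self.database = None
        self._cache = cache
        self._cache_key = cache_key

    def database_callback(self, database):
        # first resolution pins the session, mirroring _set_pinned_database
        if not self.pinned_database:
            self.pinned_database = True
            self.database = database
        # later sessions for the same user can skip resolution via the cache
        if database is not None:
            self._cache[self._cache_key] = database

cache = {}
ws = MiniWorkspace(cache, cache_key="alice")
ws.database_callback("alices-home-db")  # e.g. from BEGIN/RUN success metadata
print(ws.database, cache)               # alices-home-db {'alice': 'alices-home-db'}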
@@ -993,8 +1064,10 @@ def ensure_routing_table_is_fresh( ) del self.routing_tables[database_] - routing_table = self.get_or_create_routing_table(database) - if routing_table.is_fresh(readonly=(access_mode == READ_ACCESS)): + routing_table = self.get_routing_table(database.name) + if routing_table is not None and routing_table.is_fresh( + readonly=(access_mode == READ_ACCESS) + ): # table is still valid log.debug( "[#0000] _: using existing routing table %r", @@ -1002,15 +1075,20 @@ def ensure_routing_table_is_fresh( ) return False + database_request = database.name if not database.guessed else None + + def wrapped_database_callback(database: str | None) -> None: + Util.callback(database_callback, database) + self.update_connection_pool() + self.update_routing_table( - database=database, + database=database_request, imp_user=imp_user, bookmarks=bookmarks, auth=auth, acquisition_timeout=acquisition_timeout, - database_callback=database_callback, + database_callback=wrapped_database_callback, ) - self.update_connection_pool(database=database) return True @@ -1047,10 +1125,11 @@ def acquire( self, access_mode, timeout, - database, + database: AcquisitionDatabase, bookmarks, - auth: AcquireAuth | None, + auth: AcquisitionAuth | None, liveness_check_timeout, + database_callback=None, ): if access_mode not in {WRITE_ACCESS, READ_ACCESS}: # TODO: 6.0 - change this to be a ValueError @@ -1064,10 +1143,14 @@ def acquire( from ...api import check_access_mode access_mode = check_access_mode(access_mode) - # await self.ensure_routing_table_is_fresh( - # access_mode=access_mode, database=database, imp_user=None, - # bookmarks=bookmarks, acquisition_timeout=timeout - # ) + + target_database = database.name + + def wrapped_database_callback(new_database): + nonlocal target_database + if new_database is not None: + target_database = new_database + Util.callback(database_callback, new_database) log.debug( "[#0000] _: acquire routing connection, " @@ -1082,6 +1165,11 @@ def acquire( bookmarks=bookmarks, auth=auth, acquisition_timeout=timeout, + database_callback=( + wrapped_database_callback + if database.guessed + else database_callback + ), ) while True: @@ -1089,7 +1177,7 @@ def acquire( # Get an address for a connection that have the fewest in-use # connections. address = self._select_address( - access_mode=access_mode, database=database + access_mode=access_mode, database=target_database ) except (ReadServiceUnavailable, WriteServiceUnavailable) as err: raise SessionExpired( @@ -1100,7 +1188,7 @@ def acquire( log.debug( "[#0000] _: acquire address, database=%r " "address=%r", - database, + target_database, address, ) deadline = Deadline.from_timeout_or_deadline(timeout) diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index 27164cf8..3e0337a6 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -114,6 +114,7 @@ def __init__( warn_notification_severity, on_closed, on_error, + on_database, ) -> None: self._connection_cls = connection.__class__ self._connection = ConnectionErrorHandler( @@ -122,6 +123,7 @@ def __init__( self._hydration_scope = connection.new_hydration_scope() self._on_error = on_error self._on_closed = on_closed + self._on_database = on_database self._metadata: dict = {} self._address: Address = self._connection.unresolved_address self._keys: tuple[str, ...] 
= () @@ -205,6 +207,9 @@ def on_attached(metadata): self._connection.most_recent_qid = self._raw_qid self._keys = metadata.get("fields") self._attached = True + db_ = metadata.get("db") + if isinstance(db_, str): + Util.callback(self._on_database, db_) def on_failed_attach(metadata): self._metadata.update(metadata) diff --git a/src/neo4j/_sync/work/session.py b/src/neo4j/_sync/work/session.py index 61bd23b8..910fe328 100644 --- a/src/neo4j/_sync/work/session.py +++ b/src/neo4j/_sync/work/session.py @@ -321,6 +321,7 @@ def run( self._config.warn_notification_severity, self._result_closed, self._result_error, + self._make_db_resolution_callback(), ) bookmarks = self._get_bookmarks() parameters = dict(parameters or {}, **kwargs) @@ -448,6 +449,7 @@ def _open_transaction( self._transaction_closed_handler, self._transaction_error_handler, self._transaction_cancel_handler, + self._make_db_resolution_callback(), ) bookmarks = self._get_bookmarks() self._transaction._begin( diff --git a/src/neo4j/_sync/work/transaction.py b/src/neo4j/_sync/work/transaction.py index f1625a24..f8a6e461 100644 --- a/src/neo4j/_sync/work/transaction.py +++ b/src/neo4j/_sync/work/transaction.py @@ -47,6 +47,7 @@ def __init__( on_closed, on_error, on_cancel, + on_database, ): self._connection = connection self._error_handling_connection = ConnectionErrorHandler( @@ -62,6 +63,7 @@ def __init__( self._on_closed = on_closed self._on_error = on_error self._on_cancel = on_cancel + self._on_database = on_database super().__init__() def _enter(self) -> te.Self: @@ -92,6 +94,11 @@ def _begin( notifications_disabled_classifications, pipelined=False, ): + def on_begin_success(metadata_): + db = metadata_.get("db") + if isinstance(db, str): + Util.callback(self._on_database, db) + self._database = database self._connection.begin( bookmarks=bookmarks, @@ -101,7 +108,10 @@ def _begin( db=database, imp_user=imp_user, notifications_min_severity=notifications_min_severity, - notifications_disabled_classifications=notifications_disabled_classifications, + notifications_disabled_classifications=( + notifications_disabled_classifications + ), + on_success=on_begin_success, ) if not pipelined: self._error_handling_connection.send_all() @@ -188,6 +198,7 @@ def run( self._warn_notification_severity, self._result_on_closed_handler, self._error_handler, + None, ) self._results.append(result) diff --git a/src/neo4j/_sync/work/workspace.py b/src/neo4j/_sync/work/workspace.py index 55ca883d..a85bdf8d 100644 --- a/src/neo4j/_sync/work/workspace.py +++ b/src/neo4j/_sync/work/workspace.py @@ -17,8 +17,10 @@ from __future__ import annotations import logging +import typing as t from ..._async_compat.util import Util +from ..._auth_management import to_auth_dict from ..._conf import WorkspaceConfig from ..._meta import ( deprecation_warn, @@ -32,11 +34,22 @@ ) from .._debug import NonConcurrentMethodChecker from ..io import ( - AcquireAuth, - Neo4jPool, + AcquisitionAuth, + AcquisitionDatabase, ) +if t.TYPE_CHECKING: + from ...api import _TAuth + from ...auth_management import AuthManager + from ..home_db_cache import ( + HomeDbCache, + TKey, + ) +else: + _TAuth = t.Any + + log = logging.getLogger("neo4j") @@ -47,8 +60,9 @@ def __init__(self, pool, config): self._config = config self._connection = None self._connection_access_mode = None + self._last_cache_key: TKey | None = None # Sessions are supposed to cache the database on which to operate. 
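The result, session, and transaction hunks above all wire up one mechanism: the session builds a single database-resolution callback (`_make_db_resolution_callback`), hands it to every result and transaction it creates, and the callback fires the first time the server reports a `db` field in SUCCESS metadata. A minimal stand-alone sketch of that flow, with simplified stand-ins (illustrative code, not the driver's):

```python
# MiniSession stands in for Workspace/Session; on_attached mirrors the
# "db" handling added to Result/Transaction in the hunks above.
class MiniSession:
    def __init__(self):
        self.pinned_database = None

    def make_db_resolution_callback(self):
        if self.pinned_database is not None:
            return None  # database already pinned, nothing left to resolve

        def callback(database):
            if database is not None and self.pinned_database is None:
                # pin the server-reported database for the session's
                # lifetime; the real code also feeds the pool's home-db
                # cache at this point
                self.pinned_database = database

        return callback


def on_attached(metadata, on_database):
    # forward the server-reported database, if any, to the session
    db = metadata.get("db")
    if isinstance(db, str) and on_database is not None:
        on_database(db)


session = MiniSession()
on_attached({"db": "neo4j", "fields": ["n"]}, session.make_db_resolution_callback())
assert session.pinned_database == "neo4j"
```

Pinning on first resolution keeps all subsequent work in the session on the same database, even if the user's home database changes concurrently on the server.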
- self._cached_database = False + self._pinned_database = False self._bookmarks = () self._initial_bookmarks = () self._bookmark_manager = None @@ -87,8 +101,26 @@ def __enter__(self) -> Workspace: def __exit__(self, exc_type, exc_value, traceback): self.close() - def _set_cached_database(self, database): - self._cached_database = True + def _make_db_resolution_callback( + self, + ) -> t.Callable[[str | None], None] | None: + if self._pinned_database: + return None + + def _database_callback(database: str | None) -> None: + self._set_pinned_database(database) + if self._last_cache_key is None or database is None: + return + db_cache: HomeDbCache = self._pool.home_db_cache + db_cache.set(self._last_cache_key, database) + + return _database_callback + + def _set_pinned_database(self, database): + if self._pinned_database: + return + log.debug("[#0000] _: pinning database: %r", database) + self._pinned_database = True self._config.database = database def _initialize_bookmarks(self, bookmarks): @@ -138,12 +170,10 @@ def _update_bookmark(self, bookmark): return self._update_bookmarks((bookmark,)) - def _connect(self, access_mode, auth=None, **acquire_kwargs): + def _connect(self, access_mode, auth=None, **acquire_kwargs) -> None: acquisition_timeout = self._config.connection_acquisition_timeout - auth = AcquireAuth( - auth, - force_auth=acquire_kwargs.pop("force_auth", False), - ) + force_auth = acquire_kwargs.pop("force_auth", False) + acquire_auth = AcquisitionAuth(auth, force_auth=force_auth) if self._connection: # TODO: Investigate this @@ -151,40 +181,112 @@ def _connect(self, access_mode, auth=None, **acquire_kwargs): self._connection.send_all() self._connection.fetch_all() self._disconnect() - if not self._cached_database: - if self._config.database is not None or not isinstance( - self._pool, Neo4jPool - ): - self._set_cached_database(self._config.database) - else: - # This is the first time we open a connection to a server in a - # cluster environment for this session without explicitly - # configured database. Hence, we request a routing table update - # to try to fetch the home database. If provided by the server, - # we shall use this database explicitly for all subsequent - # actions within this session. - log.debug("[#0000] _: resolve home database") - self._pool.update_routing_table( - database=self._config.database, - imp_user=self._config.impersonated_user, - bookmarks=self._get_bookmarks(), - auth=auth, - acquisition_timeout=acquisition_timeout, - database_callback=self._set_cached_database, - ) + + ssr_enabled = self._pool.ssr_enabled + target_db = self._get_routing_target_database( + acquire_auth, ssr_enabled=ssr_enabled + ) acquire_kwargs_ = { "access_mode": access_mode, "timeout": acquisition_timeout, - "database": self._config.database, + "database": target_db, "bookmarks": self._get_bookmarks(), - "auth": auth, + "auth": acquire_auth, "liveness_check_timeout": None, + "database_callback": self._make_db_resolution_callback(), } acquire_kwargs_.update(acquire_kwargs) self._connection = self._pool.acquire(**acquire_kwargs_) + if ( + target_db.guessed + and not self._pinned_database + and not self._connection.ssr_enabled + ): + # race condition: we now have created a connection which does not + # support SSR. 
+            # => we need to fall back to explicit home database resolution
+            log.debug(
+                "[#0000] _: detected ssr support race; "
+                "falling back to explicit home database resolution",
+            )
+            self._disconnect()
+            target_db = self._get_routing_target_database(
+                acquire_auth, ssr_enabled=False
+            )
+            acquire_kwargs_["database"] = target_db
+            self._connection = self._pool.acquire(**acquire_kwargs_)
         self._connection_access_mode = access_mode
 
+    def _get_routing_target_database(
+        self,
+        acquire_auth: AcquisitionAuth,
+        ssr_enabled: bool,
+    ) -> AcquisitionDatabase:
+        if (
+            self._pinned_database
+            or self._config.database is not None
+            or self._pool.is_direct_pool
+        ):
+            log.debug(
+                "[#0000] _: routing towards fixed database: %s",
+                self._config.database,
+            )
+            self._set_pinned_database(self._config.database)
+            return AcquisitionDatabase(self._config.database)
+
+        auth = acquire_auth.auth
+        resolved_auth = self._resolve_session_auth(auth)
+        db_cache: HomeDbCache = self._pool.home_db_cache
+        cache_key = db_cache.compute_key(
+            self._config.impersonated_user,
+            resolved_auth,
+        )
+        self._last_cache_key = cache_key
+
+        if ssr_enabled:
+            cached_db = db_cache.get(cache_key)
+            if cached_db is not None:
+                log.debug(
+                    (
+                        "[#0000] _: routing towards cached "
+                        "database: %s"
+                    ),
+                    cached_db,
+                )
+                return AcquisitionDatabase(cached_db, guessed=True)
+
+        acquisition_timeout = self._config.connection_acquisition_timeout
+        log.debug("[#0000] _: resolve home database")
+        self._pool.update_routing_table(
+            database=self._config.database,
+            imp_user=self._config.impersonated_user,
+            bookmarks=self._get_bookmarks(),
+            auth=acquire_auth,
+            acquisition_timeout=acquisition_timeout,
+            database_callback=self._make_db_resolution_callback(),
+        )
+        return AcquisitionDatabase(self._config.database)
+
+    @staticmethod
+    def _resolve_session_auth(
+        auth: AuthManager | None,
+    ) -> dict | None:
+        if auth is None:
+            return None
+        # resolved_auth = Util.callback(auth.get_auth)
+        # The above line breaks mypy
+        # https://github.com/python/mypy/issues/15295
+        auth_getter: t.Callable[[], _TAuth] = auth.get_auth
+        # so we enforce the right type here
+        # (explicit type annotation above added as it's a necessary assumption
+        # for this cast to be correct)
+        resolved_auth = t.cast(_TAuth, Util.callback(auth_getter))
+        return to_auth_dict(resolved_auth)
+
     def _disconnect(self, sync=False):
+        self._last_cache_key = None
         if self._connection:
             if sync:
                 try:
diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py
index 3e26e389..f08bb30d 100644
--- a/testkitbackend/_async/requests.py
+++ b/testkitbackend/_async/requests.py
@@ -30,6 +30,7 @@
 import neo4j.api
 import neo4j.auth_management
 from neo4j._async_compat.util import AsyncUtil
+from neo4j._routing import RoutingTable
 from neo4j.auth_management import (
     AsyncAuthManager,
     AsyncAuthManagers,
@@ -189,6 +190,7 @@ async def new_driver(backend, data):
         ("maxTxRetryTimeMs", "max_transaction_retry_time"),
         ("connectionAcquisitionTimeoutMs", "connection_acquisition_timeout"),
         ("livenessCheckTimeoutMs", "liveness_check_timeout"),
+        ("maxConnectionLifetimeMs", "max_connection_lifetime"),
     ):
         if data.get(timeout_testkit) is not None:
             kwargs[timeout_driver] = data[timeout_testkit] / 1000
@@ -991,7 +993,9 @@ async def get_routing_table(backend, data):
     driver_id = data["driverId"]
     database = data["database"]
     driver = backend.drivers[driver_id]
-    routing_table = driver._pool.routing_tables[database]
+    routing_table = await
driver._pool.get_routing_table(database) + if routing_table is None: + routing_table = RoutingTable(database=database) response_data = { "database": routing_table.database, "ttl": routing_table.ttl, diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 44a9233b..586616e7 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -30,6 +30,7 @@ import neo4j.api import neo4j.auth_management from neo4j._async_compat.util import Util +from neo4j._routing import RoutingTable from neo4j.auth_management import ( AuthManager, AuthManagers, @@ -991,7 +992,9 @@ def get_routing_table(backend, data): driver_id = data["driverId"] database = data["database"] driver = backend.drivers[driver_id] - routing_table = driver._pool.routing_tables[database] + routing_table = driver._pool.get_routing_table(database) + if routing_table is None: + routing_table = RoutingTable(database=database) response_data = { "database": routing_table.database, "ttl": routing_table.ttl, diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index bca7f0ca..0fe66cc5 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -15,7 +15,9 @@ "'neo4j.datatypes.test_temporal_types.TestDataTypes.test_date_time_cypher_created_tz_id'": "test_subtest_skips.tz_id", "stub\\.routing\\.test_routing_v[0-9x]+\\.RoutingV[0-9x]+\\.test_should_drop_connections_failing_liveness_check": - "Liveness check error handling is not (yet) unified: https://github.com/neo-technology/drivers-adr/pull/83" + "Liveness check error handling is not (yet) unified: https://github.com/neo-technology/drivers-adr/pull/83", + "'stub.homedb.test_homedb.TestHomeDbMixedCluster.test_connection_acquisition_timeout_during_fallback'": + "TODO: 6.0 - pending unification: connection acquisition timeout should count towards the total time spent waiting for a connection (including routing, home db resolution, ...)" }, "features": { "Feature:API:BookmarkManager": true, @@ -27,6 +29,7 @@ "Feature:API:Driver.VerifyAuthentication": true, "Feature:API:Driver.VerifyConnectivity": true, "Feature:API:Driver.SupportsSessionAuth": true, + "Feature:API:Driver:MaxConnectionLifetime": true, "Feature:API:Driver:NotificationsConfig": true, "Feature:API:Liveness.Check": true, "Feature:API:Result.List": true, @@ -59,6 +62,7 @@ "Feature:Bolt:5.5": true, "Feature:Bolt:5.6": true, "Feature:Bolt:5.7": true, + "Feature:Bolt:5.8": true, "Feature:Bolt:Patch:UTC": true, "Feature:Impersonation": true, "Feature:TLS:1.1": "Driver blocks TLS 1.1 for security reasons.", @@ -70,6 +74,7 @@ "Optimization:ConnectionReuse": true, "Optimization:EagerTransactionBegin": true, "Optimization:ExecuteQueryPipelining": true, + "Optimization:HomeDbCacheBasicPrincipalIsImpersonatedUser": true, "Optimization:ImplicitDefaultArguments": true, "Optimization:MinimalBookmarksSet": true, "Optimization:MinimalResets": true, diff --git a/tests/_async_compat/__init__.py b/tests/_async_compat/__init__.py index 67cd7a37..8170965a 100644 --- a/tests/_async_compat/__init__.py +++ b/tests/_async_compat/__init__.py @@ -14,6 +14,8 @@ # limitations under the License. 
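One detail worth calling out in the TestKit configuration above: the new `Optimization:HomeDbCacheBasicPrincipalIsImpersonatedUser` flag advertises that the home-database cache keys basic-auth sessions by their principal, so impersonating a user and authenticating as that same user share a single cache entry. A small usage sketch against the `AsyncHomeDbCache` introduced by this patch (the database name and principal are made up for illustration):

```python
from neo4j._async.home_db_cache import AsyncHomeDbCache

cache = AsyncHomeDbCache()

# impersonating "alice" and authenticating with basic auth as principal
# "alice" compute the same cache key ...
key_imp = cache.compute_key(imp_user="alice", auth=None)
key_basic = cache.compute_key(
    imp_user=None,
    auth={"scheme": "basic", "principal": "alice"},
)
assert key_imp == key_basic

# ... so a home database resolved for one is reused for the other
cache.set(key_imp, "alices-home-db")
assert cache.get(key_basic) == "alices-home-db"
```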
+from functools import wraps as _wraps
+
 from .mark_decorator import (
     AsyncTestDecorators,
     mark_async_test,
@@ -27,4 +29,15 @@
     "TestDecorators",
     "mark_async_test",
     "mark_sync_test",
+    "wrap_async",
 ]
+
+
+def wrap_async(func):
+    @_wraps(func)
+    async def wrapper(*args, **kwargs):  # noqa: RUF029
+        # [noqa] the whole point of this wrapper is to turn a sync function
+        # into an async one for testing purposes
+        return func(*args, **kwargs)
+
+    return wrapper
diff --git a/tests/_async_util.py b/tests/_async_util.py
new file mode 100644
index 00000000..8a032ad8
--- /dev/null
+++ b/tests/_async_util.py
@@ -0,0 +1,34 @@
+# Copyright (c) "Neo4j"
+# Neo4j Sweden AB [https://neo4j.com]
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import asyncio
+
+
+async def gather_cancel(*coros_or_futures):
+    """
+    Return a future aggregating results from the given coroutines/futures.
+
+    A thin wrapper around asyncio.gather that cancels all coroutines/futures
+    if any of them raises an exception.
+    """
+    futures = [asyncio.ensure_future(coro) for coro in coros_or_futures]
+    try:
+        await asyncio.gather(*futures)
+    except BaseException:
+        for future in futures:
+            future.cancel()
+        await asyncio.gather(*futures, return_exceptions=True)
+        raise
diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py
index 9bf96779..98c40df6 100644
--- a/tests/unit/async_/fixtures/fake_connection.py
+++ b/tests/unit/async_/fixtures/fake_connection.py
@@ -127,7 +127,14 @@ async def callback():
                 return func
 
         method_mock = parent.__getattr__(name)
-        if name in {"run", "commit", "pull", "rollback", "discard"}:
+        if name in {
+            "begin",
+            "run",
+            "commit",
+            "pull",
+            "rollback",
+            "discard",
+        }:
             method_mock.side_effect = build_message_handler(name)
         return method_mock
@@ -218,7 +225,14 @@ async def callback():
                 return func
 
         method_mock = parent.__getattr__(name)
-        if name in {"run", "commit", "pull", "rollback", "discard"}:
+        if name in {
+            "begin",
+            "run",
+            "commit",
+            "pull",
+            "rollback",
+            "discard",
+        }:
             method_mock.side_effect = build_message_handler(name)
         return method_mock
diff --git a/tests/unit/async_/fixtures/fake_pool.py b/tests/unit/async_/fixtures/fake_pool.py
index 877c22fd..892d7b0d 100644
--- a/tests/unit/async_/fixtures/fake_pool.py
+++ b/tests/unit/async_/fixtures/fake_pool.py
@@ -17,6 +17,7 @@
 import pytest
 
 from neo4j._async.config import AsyncPoolConfig
+from neo4j._async.home_db_cache import AsyncHomeDbCache
 from neo4j._async.io._pool import AsyncIOPool
@@ -32,6 +33,9 @@ def async_fake_pool(async_fake_connection_generator, mocker):
     pool.buffered_connection_mocks = []
     pool.acquired_connection_mocks = []
     pool.pool_config = AsyncPoolConfig()
+    pool.ssr_enabled = False
+    pool.is_direct_pool = True
+    pool.home_db_cache = AsyncHomeDbCache(enabled=False)
 
     def acquire_side_effect(*_, **__):
         if pool.buffered_connection_mocks:
diff --git a/tests/unit/async_/io/test_class_bolt.py b/tests/unit/async_/io/test_class_bolt.py
index b0ddbc96..4469b045 100644
---
a/tests/unit/async_/io/test_class_bolt.py +++ b/tests/unit/async_/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3, 0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -69,7 +69,8 @@ def test_class_method_protocol_handlers(): ((5, 5), 1), ((5, 6), 1), ((5, 7), 1), - ((5, 8), 0), + ((5, 8), 1), + ((5, 9), 0), ((6, 0), 0), ], ) @@ -92,7 +93,7 @@ def test_class_method_get_handshake(): handshake = AsyncBolt.get_handshake() assert ( handshake - == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -143,6 +144,7 @@ async def test_cancel_hello_in_open(mocker, none_auth): ((5, 5), "neo4j._async.io._bolt5.AsyncBolt5x5"), ((5, 6), "neo4j._async.io._bolt5.AsyncBolt5x6"), ((5, 7), "neo4j._async.io._bolt5.AsyncBolt5x7"), + ((5, 8), "neo4j._async.io._bolt5.AsyncBolt5x8"), ), ) @mark_async_test @@ -181,7 +183,7 @@ async def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 8), + (5, 9), (6, 0), ), ) diff --git a/tests/unit/async_/io/test_class_bolt3.py b/tests/unit/async_/io/test_class_bolt3.py index e2f56ff9..509b6c76 100644 --- a/tests/unit/async_/io/test_class_bolt3.py +++ b/tests/unit/async_/io/test_class_bolt3.py @@ -129,7 +129,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -569,3 +569,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt3.PACKER_CLS, + unpacker_cls=AsyncBolt3.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt3( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt4x0.py b/tests/unit/async_/io/test_class_bolt4x0.py index fa555fd1..e3af55ad 100644 --- a/tests/unit/async_/io/test_class_bolt4x0.py +++ b/tests/unit/async_/io/test_class_bolt4x0.py @@ -232,7 +232,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -660,3 +660,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt4x0.PACKER_CLS, + unpacker_cls=AsyncBolt4x0.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + 
meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt4x0( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt4x1.py b/tests/unit/async_/io/test_class_bolt4x1.py index e7ca17e0..7f33fe01 100644 --- a/tests/unit/async_/io/test_class_bolt4x1.py +++ b/tests/unit/async_/io/test_class_bolt4x1.py @@ -254,7 +254,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -682,3 +682,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt4x1.PACKER_CLS, + unpacker_cls=AsyncBolt4x1.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt4x1( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt4x2.py b/tests/unit/async_/io/test_class_bolt4x2.py index bffb4424..d243aef6 100644 --- a/tests/unit/async_/io/test_class_bolt4x2.py +++ b/tests/unit/async_/io/test_class_bolt4x2.py @@ -254,7 +254,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -682,3 +682,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt4x2.PACKER_CLS, + unpacker_cls=AsyncBolt4x2.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt4x2( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt4x3.py b/tests/unit/async_/io/test_class_bolt4x3.py index 1f249feb..e1a23e01 100644 --- a/tests/unit/async_/io/test_class_bolt4x3.py +++ b/tests/unit/async_/io/test_class_bolt4x3.py @@ -254,7 +254,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await 
connection.send_all() @@ -711,3 +711,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt4x3.PACKER_CLS, + unpacker_cls=AsyncBolt4x3.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt4x3( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt4x4.py b/tests/unit/async_/io/test_class_bolt4x4.py index 695ac7c9..c24f06a9 100644 --- a/tests/unit/async_/io/test_class_bolt4x4.py +++ b/tests/unit/async_/io/test_class_bolt4x4.py @@ -278,7 +278,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -671,3 +671,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt4x4.PACKER_CLS, + unpacker_cls=AsyncBolt4x4.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt4x4( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x0.py b/tests/unit/async_/io/test_class_bolt5x0.py index d1f09dcc..8a5de715 100644 --- a/tests/unit/async_/io/test_class_bolt5x0.py +++ b/tests/unit/async_/io/test_class_bolt5x0.py @@ -278,7 +278,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -735,3 +735,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x0.PACKER_CLS, + unpacker_cls=AsyncBolt5x0.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x0( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert 
connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x1.py b/tests/unit/async_/io/test_class_bolt5x1.py index 003263aa..847ff059 100644 --- a/tests/unit/async_/io/test_class_bolt5x1.py +++ b/tests/unit/async_/io/test_class_bolt5x1.py @@ -280,7 +280,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -789,3 +789,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x1.PACKER_CLS, + unpacker_cls=AsyncBolt5x1.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x1( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x2.py b/tests/unit/async_/io/test_class_bolt5x2.py index 345c9a52..c4c08eda 100644 --- a/tests/unit/async_/io/test_class_bolt5x2.py +++ b/tests/unit/async_/io/test_class_bolt5x2.py @@ -278,7 +278,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -826,3 +826,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x2.PACKER_CLS, + unpacker_cls=AsyncBolt5x2.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x2( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x3.py b/tests/unit/async_/io/test_class_bolt5x3.py index c70a3df4..e3a76563 100644 --- a/tests/unit/async_/io/test_class_bolt5x3.py +++ b/tests/unit/async_/io/test_class_bolt5x3.py @@ -281,7 +281,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -713,3 +713,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 
7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x3.PACKER_CLS, + unpacker_cls=AsyncBolt5x3.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x3( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x4.py b/tests/unit/async_/io/test_class_bolt5x4.py index 7ff21e09..48e74114 100644 --- a/tests/unit/async_/io/test_class_bolt5x4.py +++ b/tests/unit/async_/io/test_class_bolt5x4.py @@ -281,7 +281,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -718,3 +718,25 @@ def on_success(metadata): await connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x4.PACKER_CLS, + unpacker_cls=AsyncBolt5x4.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x4( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x5.py b/tests/unit/async_/io/test_class_bolt5x5.py index 77d748de..60ea25ee 100644 --- a/tests/unit/async_/io/test_class_bolt5x5.py +++ b/tests/unit/async_/io/test_class_bolt5x5.py @@ -281,7 +281,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -756,3 +756,25 @@ def extend_diag_record(r): } assert received_metadata == expected_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x5.PACKER_CLS, + unpacker_cls=AsyncBolt5x5.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x5( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x6.py b/tests/unit/async_/io/test_class_bolt5x6.py index a5106572..1f11cf75 100644 --- a/tests/unit/async_/io/test_class_bolt5x6.py +++ b/tests/unit/async_/io/test_class_bolt5x6.py @@ -281,7 +281,7 @@ async def test_telemetry_message( 
telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -760,3 +760,25 @@ def extend_diag_record(r): } assert received_metadata == expected_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x6.PACKER_CLS, + unpacker_cls=AsyncBolt5x6.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x6( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x7.py b/tests/unit/async_/io/test_class_bolt5x7.py index 97a8b4ea..09752de1 100644 --- a/tests/unit/async_/io/test_class_bolt5x7.py +++ b/tests/unit/async_/io/test_class_bolt5x7.py @@ -282,7 +282,7 @@ async def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) await connection.send_all() @@ -848,3 +848,25 @@ def _build_error_hierarchy_metadata(diag_records_metadata): if r is not ...: current_root["diagnostic_record"] = r return metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x7.PACKER_CLS, + unpacker_cls=AsyncBolt5x7.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x7( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/async_/io/test_class_bolt5x8.py b/tests/unit/async_/io/test_class_bolt5x8.py new file mode 100644 index 00000000..6a105d1e --- /dev/null +++ b/tests/unit/async_/io/test_class_bolt5x8.py @@ -0,0 +1,872 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
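The new `test_class_bolt5x8.py` module below imports a `powerset` helper from `tests/iter_util.py`, which is not part of this diff. Judging only from the call sites later in the file (`powerset(items, upper_limit=3)` and `powerset(items, lower_limit=1, upper_limit=3)`), it presumably behaves roughly like this hypothetical sketch:

```python
import itertools


def powerset(iterable, lower_limit=0, upper_limit=None):
    # yield every subset of `iterable` whose size lies in
    # [lower_limit, upper_limit]; upper_limit=None means "no cap"
    items = tuple(iterable)
    if upper_limit is None:
        upper_limit = len(items)
    for size in range(lower_limit, upper_limit + 1):
        yield from itertools.combinations(items, size)
```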
+ + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.io._bolt5 import AsyncBolt5x8 +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_async_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = AsyncBolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_async_test +async def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x11" + assert tuple(is_fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_async_test +async def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + await connection.send_all() + tag, is_fields = await socket.pop_message() + assert tag == b"\x10" + assert tuple(is_fields) == expected_fields + + +@mark_async_test +async def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert 
fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_async_test +async def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_async_test +async def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_async_test +async def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, socket, AsyncPoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + await connection.send_all() + tag, fields = await socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_async_test +async def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + await connection.hello() + tag, fields = await sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + 
assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_async_test +async def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.connection_hints["telemetry.enabled"] = True + connection.telemetry(api) + await connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = await socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + await socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_async_test +async def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + await sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + await connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_async_test +async def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + await 
connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_async_test +async def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, AsyncBolt5x8.UNPACKER_CLS) + connection = AsyncBolt5x8( + address, + socket, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + await connection.send_all() + + _, fields = await socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_async_test +async def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, + sockets.client, + AsyncPoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + 
_tag, fields = await sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_async_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +async def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = AsyncBolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + await connection.hello() + + _tag, fields = await sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_async_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +async def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + await connection.send_all() + _tag, fields = await sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_async_test +async def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + await sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + await sockets.server.send_message(b"\x70", {}) + await connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + await sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + await sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + await 
connection.send_all() + await connection.fetch_all() + assert connection.last_database == db + + await sockets.server.send_message(b"\x70", {}) + if finish == "reset": + await connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + await connection.send_all() + await connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_async_test +async def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + await sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + await connection.send_all() + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... 
+ else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_async_test +async def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + connection = AsyncBolt5x8(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + await sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + await connection.send_all() + with pytest.raises(Neo4jError): + await connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. 
{i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_async_test +async def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=AsyncBolt5x8.PACKER_CLS, + unpacker_cls=AsyncBolt5x8.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + await sockets.server.send_message(b"\x70", meta) + await sockets.server.send_message(b"\x70", {}) + connection = AsyncBolt5x8( + address, sockets.client, AsyncPoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + await connection.hello() + assert connection.ssr_enabled is bool(ssr_hint) diff --git a/tests/unit/async_/io/test_direct.py b/tests/unit/async_/io/test_direct.py index 80014266..c6add31f 100644 --- a/tests/unit/async_/io/test_direct.py +++ b/tests/unit/async_/io/test_direct.py @@ -67,6 +67,7 @@ async def acquire( bookmarks, auth, liveness_check_timeout, + database_callback=None, ): return await self._acquire( self.address, auth, timeout, liveness_check_timeout diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index c0be16ad..e1549fd0 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -24,6 +24,7 @@ ) from neo4j._async.config import AsyncPoolConfig from neo4j._async.io import ( + AcquisitionDatabase, AsyncBolt, AsyncNeo4jPool, ) @@ -53,11 +54,27 @@ WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") +def make_home_db_resolve(home_db): + def _home_db_resolve(db): + return db or home_db + + return _home_db_resolve + + +_default_db_resolve = make_home_db_resolve("neo4j") + + @pytest.fixture def custom_routing_opener(async_fake_connection_generator, mocker): - def make_opener(failures=None, get_readers=None): + def make_opener( + failures=None, + get_readers=None, + db_resolve=_default_db_resolve, + on_open=None, + ): def routing_side_effect(*args, **kwargs): nonlocal failures + opener_.route_requests.append(kwargs.get("database")) res = next(failures, None) if res is None: routers = [ @@ -70,16 +87,18 @@ def routing_side_effect(*args, **kwargs): else: readers = [str(READER1_ADDRESS)] writers = [str(WRITER1_ADDRESS)] - return [ - { - "ttl": 1000, - "servers": [ - {"addresses": routers, "role": "ROUTE"}, - {"addresses": readers, "role": "READ"}, - {"addresses": writers, "role": "WRITE"}, - ], - } - ] + rt = { + "ttl": 1000, + "servers": [ + {"addresses": routers, "role": "ROUTE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": writers, "role": "WRITE"}, + ], + } + db = db_resolve(kwargs.get("database")) + if db is not ...: + rt["db"] = db + return [rt] raise res async def open_(addr, auth, timeout): @@ -92,11 +111,16 @@ async def open_(addr, auth, timeout): route_mock.side_effect = routing_side_effect connection.attach_mock(route_mock, "route") opener_.connections.append(connection) + + if callable(on_open): + on_open(connection) + return connection failures = iter(failures or []) opener_ = mocker.AsyncMock() opener_.connections = [] + opener_.route_requests = [] opener_.side_effect = open_ return opener_ @@ -124,54 +148,101 @@ def _simple_pool(opener) -> AsyncNeo4jPool: ) +TEST_DB1 = AcquisitionDatabase("test_db1") +TEST_DB2 = AcquisitionDatabase("test_db2") 
+ + +@pytest.mark.parametrize("guessed_db", (True, False)) @mark_async_test -async def test_acquires_new_routing_table_if_deleted(opener): +async def test_acquires_new_routing_table_if_deleted( + custom_routing_opener, + guessed_db, +) -> None: + db = AcquisitionDatabase("test_db", guessed=guessed_db) + opener = custom_routing_opener(db_resolve=make_home_db_resolve(db.name)) pool = _simple_pool(opener) - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, db, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] + opener.route_requests = [] - del pool.routing_tables["test_db"] + del pool.routing_tables[db.name] - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, db, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] +@pytest.mark.parametrize("guessed_db", (True, False)) @mark_async_test -async def test_acquires_new_routing_table_if_stale(opener): +async def test_acquires_new_routing_table_if_stale( + custom_routing_opener, + guessed_db, +) -> None: + db = AcquisitionDatabase("test_db", guessed=guessed_db) + opener = custom_routing_opener(db_resolve=make_home_db_resolve(db.name)) pool = _simple_pool(opener) - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, db, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] + opener.route_requests = [] - old_value = pool.routing_tables["test_db"].last_updated_time - pool.routing_tables["test_db"].ttl = 0 + old_value = pool.routing_tables[db.name].last_updated_time + pool.routing_tables[db.name].ttl = 0 - cx = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, db, None, None, None) await pool.release(cx) - assert pool.routing_tables["test_db"].last_updated_time > old_value + assert pool.routing_tables[db.name].last_updated_time > old_value + assert opener.route_requests == [None if guessed_db else db.name] @mark_async_test async def test_removes_old_routing_table(opener): pool = _simple_pool(opener) - cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db1") - cx = await pool.acquire(READ_ACCESS, 30, "test_db2", None, None, None) + assert pool.routing_tables.get(TEST_DB1.name) + cx = await pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) await pool.release(cx) - assert pool.routing_tables.get("test_db2") + assert pool.routing_tables.get(TEST_DB2.name) - old_value = pool.routing_tables["test_db1"].last_updated_time - pool.routing_tables["test_db1"].ttl = 0 - db2_rt = pool.routing_tables["test_db2"] + old_value = pool.routing_tables[TEST_DB1.name].last_updated_time + pool.routing_tables[TEST_DB1.name].ttl = 0 + db2_rt = pool.routing_tables[TEST_DB2.name] db2_rt.ttl = -RoutingConfig.routing_table_purge_delay - cx = await pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) + cx = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, 
None, None) await pool.release(cx) - assert pool.routing_tables["test_db1"].last_updated_time > old_value - assert "test_db2" not in pool.routing_tables + assert pool.routing_tables[TEST_DB1.name].last_updated_time > old_value + assert TEST_DB2.name not in pool.routing_tables + + +@pytest.mark.parametrize("guessed_db", (True, False)) +@mark_async_test +async def test_db_resolution_callback(custom_routing_opener, guessed_db): + cb_calls = [] + + def cb(db_): + nonlocal cb_calls + cb_calls.append(db_) + + db = AcquisitionDatabase("test_db", guessed=guessed_db) + home_db = "home_db" + expected_target_db = home_db if db.guessed else db.name + + opener = custom_routing_opener(db_resolve=make_home_db_resolve(home_db)) + pool = _simple_pool(opener) + cx = await pool.acquire( + READ_ACCESS, 30, db, None, None, None, database_callback=cb + ) + await pool.release(cx) + + assert pool.routing_tables.get(expected_target_db) + assert opener.route_requests == [None if guessed_db else db.name] + assert cb_calls == [expected_target_db] @pytest.mark.parametrize("type_", ("r", "w")) @@ -181,7 +252,7 @@ async def test_chooses_right_connection_type(opener, type_): cx1 = await pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, - "test_db", + TEST_DB1, None, None, None, @@ -196,9 +267,9 @@ async def test_chooses_right_connection_type(opener, type_): @mark_async_test async def test_reuses_connection(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1 is cx2 @@ -216,7 +287,7 @@ async def break_connection(): return None pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) assert cx1 in pool.connections[cx1.unresolved_address] # simulate connection going stale (e.g. exceeding idle timeout) and then @@ -226,7 +297,7 @@ async def break_connection(): if break_on_close: cx_close_mock_side_effect = cx_close_mock.side_effect cx_close_mock.side_effect = break_connection - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx2) if break_on_close: cx1.close.assert_called() @@ -241,12 +312,12 @@ async def break_connection(): @mark_async_test async def test_does_not_close_stale_connections_in_use(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1 in pool.connections[cx1.unresolved_address] # simulate connection going stale (e.g. 
exceeding idle timeout) while being # in use cx1.stale.return_value = True - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 @@ -259,7 +330,7 @@ async def test_does_not_close_stale_connections_in_use(opener): # it should be closed when trying to acquire the next connection cx1.close.assert_not_called() - cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 @@ -271,7 +342,7 @@ async def test_does_not_close_stale_connections_in_use(opener): @mark_async_test async def test_release_resets_connections(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() await pool.release(cx1) @@ -282,7 +353,7 @@ async def test_release_resets_connections(opener): @mark_async_test async def test_release_does_not_resets_closed_connections(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() @@ -295,7 +366,7 @@ async def test_release_does_not_resets_closed_connections(opener): @mark_async_test async def test_release_does_not_resets_defunct_connections(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() @@ -457,8 +528,8 @@ async def close_side_effect(): # create pool with 2 idle connections pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) await pool.release(cx2) @@ -470,7 +541,7 @@ async def close_side_effect(): # unreachable cx1.stale.return_value = True - cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx3 is not cx1 assert cx3 is not cx2 @@ -479,11 +550,11 @@ async def close_side_effect(): @mark_async_test async def test_failing_opener_leaves_connections_in_use_alone(opener): pool = _simple_pool(opener) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) opener.side_effect = ServiceUnavailable("Server overloaded") with pytest.raises((ServiceUnavailable, SessionExpired)): - await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert not cx1.closed() @@ -505,7 +576,7 @@ async def test__acquire_new_later_without_room(opener): config = _pool_config() config.max_connection_pool_size = 1 pool = AsyncNeo4jPool(opener, config, WorkspaceConfig(), ROUTER1_ADDRESS) - _ = await pool.acquire(READ_ACCESS, 30, "test_db", 
None, None, None) + _ = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) # pool is full now assert pool.connections_reservations[READER1_ADDRESS] == 0 creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1)) @@ -559,13 +630,13 @@ async def test_discovery_is_retried(custom_routing_opener, error): WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host"), ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) - pool.routing_tables.get("test_db").ttl = 0 + pool.routing_tables.get(TEST_DB1.name).ttl = 0 - cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx2) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(TEST_DB1.name) assert cx1 is cx2 @@ -611,12 +682,12 @@ async def test_fast_failing_discovery(custom_routing_opener, error): WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host"), ) - cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) - pool.routing_tables.get("test_db").ttl = 0 + pool.routing_tables.get(TEST_DB1.name).ttl = 0 with pytest.raises(error.__class__) as exc: - await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert exc.value is error @@ -657,11 +728,11 @@ async def test_connection_error_callback( config.auth = auth_manager pool = AsyncNeo4jPool(opener, config, WorkspaceConfig(), ROUTER1_ADDRESS) cxs_read = [ - await pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) for _ in range(5) ] cxs_write = [ - await pool.acquire(WRITE_ACCESS, 30, "test_db", None, None, None) + await pool.acquire(WRITE_ACCESS, 30, TEST_DB1, None, None, None) for _ in range(5) ] @@ -690,7 +761,7 @@ async def test_connection_error_callback( @mark_async_test async def test_pool_closes_connections_dropped_from_rt(custom_routing_opener): - readers = {"db1": [str(READER1_ADDRESS)]} + readers = {TEST_DB1.name: [str(READER1_ADDRESS)]} def get_readers(database): return readers[database] @@ -700,7 +771,7 @@ def get_readers(database): pool = AsyncNeo4jPool( opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1.unresolved_address == READER1_ADDRESS await pool.release(cx1) @@ -708,10 +779,10 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 # force RT refresh, returning a different reader - del pool.routing_tables["db1"] - readers["db1"] = [str(READER2_ADDRESS)] + del pool.routing_tables[TEST_DB1.name] + readers[TEST_DB1.name] = [str(READER2_ADDRESS)] - cx2 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx2.unresolved_address == READER2_ADDRESS cx1.close.assert_awaited_once() @@ -726,8 +797,8 @@ async def test_pool_does_not_close_connections_dropped_from_rt_for_other_server( custom_routing_opener, ): readers = { - "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)], - "db2": [str(READER1_ADDRESS)], + TEST_DB1.name: [str(READER1_ADDRESS), str(READER2_ADDRESS)], + 
TEST_DB2.name: [str(READER1_ADDRESS)], } def get_readers(database): @@ -738,14 +809,14 @@ def get_readers(database): pool = AsyncNeo4jPool( opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = await pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) await pool.release(cx1) assert cx1.unresolved_address in {READER1_ADDRESS, READER2_ADDRESS} reader1_connection_count = len(pool.connections[READER1_ADDRESS]) reader2_connection_count = len(pool.connections[READER2_ADDRESS]) assert reader1_connection_count + reader2_connection_count == 1 - cx2 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) await pool.release(cx2) assert cx2.unresolved_address == READER1_ADDRESS cx1.close.assert_not_called() @@ -754,10 +825,10 @@ def get_readers(database): assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count # force RT refresh, returning a different reader - del pool.routing_tables["db2"] - readers["db2"] = [str(READER3_ADDRESS)] + del pool.routing_tables[TEST_DB2.name] + readers[TEST_DB2.name] = [str(READER3_ADDRESS)] - cx3 = await pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) await pool.release(cx3) assert cx3.unresolved_address == READER3_ADDRESS @@ -767,3 +838,79 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count assert len(pool.connections[READER3_ADDRESS]) == 1 + + +@mark_async_test +async def test_tracks_ssr_connection_hints(custom_routing_opener): + connection_count = 0 + + def on_open(connection): + if connection.unresolved_address in { + ROUTER1_ADDRESS, + ROUTER2_ADDRESS, + ROUTER3_ADDRESS, + }: + connection.ssr_enabled = True + return + nonlocal connection_count + connection_count += 1 + connection.ssr_enabled = connection_count != 2 + + opener = custom_routing_opener(on_open=on_open) + pool = AsyncNeo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + + # no connection in pool => cannot know => defensive assumption: off + assert not pool.ssr_enabled + + # open 1st reader connection (supports SSR) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert cx1.ssr_enabled # double check we got the mocking right + + assert pool.ssr_enabled + + # open 2nd reader connection (does not support SSR) + cx2 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert not cx2.ssr_enabled # double check we got the mocking right + + assert not pool.ssr_enabled + + # open 3rd reader connection (supports SSR) + cx3 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert cx3.ssr_enabled # double check we got the mocking right + + assert not pool.ssr_enabled + + await pool.release(cx1) + await pool.release(cx2) + await pool.release(cx3) + + assert not pool.ssr_enabled + + cxs = [ + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + for _ in range(3) + ] + assert sum(not c.ssr_enabled for c in cxs) == 1 # double check + + for cx in (cx for cx in cxs if not cx.ssr_enabled): + await cx.close() + + # after the single connection without SSR support is closed + for cx in cxs: + await pool.release(cx) + + # force pool cleaning up all stale connections: + cxs = [ + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + for _ in range(3) + ] + 
assert all(cx.ssr_enabled for cx in cxs) # double check + + assert pool.ssr_enabled + + for cx in cxs: + await pool.release(cx) + + assert pool.ssr_enabled diff --git a/tests/unit/async_/test_driver.py b/tests/unit/async_/test_driver.py index 71ba0e90..c26d94a8 100644 --- a/tests/unit/async_/test_driver.py +++ b/tests/unit/async_/test_driver.py @@ -277,6 +277,7 @@ async def test_driver_opens_write_session_by_default( bookmarks=mocker.ANY, auth=mocker.ANY, liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, ) tx._begin.assert_awaited_once_with( mocker.ANY, diff --git a/tests/unit/async_/test_home_db_cache.py b/tests/unit/async_/test_home_db_cache.py new file mode 100644 index 00000000..fe644ed1 --- /dev/null +++ b/tests/unit/async_/test_home_db_cache.py @@ -0,0 +1,293 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import math +import time +import typing as t +from datetime import ( + datetime, + timedelta, +) + +import freezegun +import pytest +import pytz + +from neo4j._async.config import AsyncPoolConfig +from neo4j._async.home_db_cache import AsyncHomeDbCache +from neo4j._async.io._pool import AsyncNeo4jPool +from neo4j._conf import WorkspaceConfig +from neo4j.time import DateTime + + +if t.TYPE_CHECKING: + from neo4j._async.home_db_cache import TKey + + +@pytest.mark.parametrize("enabled", (True, False)) +def test_key_none_is_none(enabled: bool) -> None: + assert AsyncHomeDbCache(enabled=enabled).compute_key(None, None) == (None,) + + +@pytest.mark.parametrize( + "auth", + ( + None, + {}, + {"scheme": "basic", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "credentials": "nice token"}, + {"foo": "bar"}, + ), +) +@pytest.mark.parametrize("enabled", (True, False)) +def test_key_imp_precedence_over_auth( + auth: dict | None, + enabled: bool, +) -> None: + cache = AsyncHomeDbCache(enabled=enabled) + assert cache.compute_key("bob", auth) == ("bob" if enabled else (None,)) + + +@pytest.mark.parametrize( + "auth", + ( + {}, + {"scheme": "basic", "principal": "neo4j", "credentials": "password"}, + {"scheme": "basic", "principal": "this is wrong, no password?"}, + {"scheme": "basic", "credentials": "this is wrong, no user?"}, + {"scheme": "none"}, + {"scheme": "none", "principal": "even though the scheme is none"}, + {"scheme": "kerberos", "principal": "", "credentials": "ticket"}, + {"scheme": "bearer", "credentials": "nice SSO token"}, + {"scheme": "custom", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "credentials": "bar", "parameters": {"oh": "hi"}}, + {"foo": "bar"}, + ), +) +def test_key_reduces_basic_auth_to_principal(auth: dict) -> None: + key = AsyncHomeDbCache().compute_key(None, auth) + if auth.get("scheme") == "basic" and "principal" in auth: + assert isinstance(key, str) + assert key == auth["principal"] + else: + assert isinstance(key, 
tuple) + for e in key: + assert isinstance(e, tuple) and len(e) == 2 + assert isinstance(e[0], str) + + +_NAN = float("nan") +_NOW = pytz.timezone("Europe/Stockholm").localize( + DateTime(2021, 8, 12, 12, 34, 57, 123456789) +) + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + ( + ({}, {}), + ({"foo": "bar"}, {"foo": "bar"}), + ({"a": 1, "b": 2}, {"b": 2, "a": 1}), + ( + { + "scheme": "funky", + "credentials": "t0pS3cr3t!!11", + "parameters": { + "how much": 1.5, + # Note: for special values (NaN, temporal types, etc.), + # equality may rely on object identity. + "why": "because", + "difficult": _NAN, + "also difficult 🔥": _NOW, + }, + }, + { + "parameters": { + "also difficult 🔥": _NOW, + "difficult": _NAN, + "why": "because", + "how much": 1.5, + }, + "credentials": "t0pS3cr3t!!11", + "scheme": "funky", + }, + ), + ), +) +def test_key_auth_equality(auth1: dict, auth2: dict) -> None: + cache = AsyncHomeDbCache() + key1 = cache.compute_key(None, auth1) + key2 = cache.compute_key(None, auth2) + + assert len(cache) == 0 + + cache.set(key1, "value") + assert len(cache) == 1 + assert cache.get(key1) == "value" + + cache.set(key2, "value2") + assert len(cache) == 1 + assert cache.get(key1) == "value2" + assert cache.get(key2) == "value2" + + assert key1 == key2 + + +def _assert_entries( + cache: AsyncHomeDbCache, + expected_entries: t.Collection[tuple[TKey, str]], + allow_subset: bool = False, +) -> None: + __tracebackhide__ = True + if not allow_subset: + assert len(cache) == len(expected_entries) + for key, value in expected_entries: + assert cache.get(key) == value + else: + hits = sum(cache.get(key) == value for key, value in expected_entries) + assert hits == len(cache) + + +def _force_cache_clean( + cache: AsyncHomeDbCache, + now: float | None = None, +) -> None: + cache._clean(now) + + +def test_cache_ttl() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = AsyncHomeDbCache(ttl=1) + + entries = [] + for i in range(1, 11): + time.move_to(t0 + timedelta(seconds=0.25) * (i - 1)) + + entries.append((cache.compute_key(f"{i}", None), f"value{i}")) + key, value = entries[-1] + cache.set(key, value) + + _force_cache_clean(cache) + _assert_entries(cache, entries) + + time.move_to( + t0 + timedelta(seconds=0.25) * i - timedelta(milliseconds=1) + ) + _force_cache_clean(cache) + _assert_entries(cache, entries) + + time.move_to( + t0 + timedelta(seconds=0.25) * i + timedelta(milliseconds=1) + ) + _force_cache_clean(cache) + entries = entries[-3:] + _assert_entries(cache, entries) + + +def test_cache_ttl_empty_cache() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = AsyncHomeDbCache(ttl=1) + assert len(cache) == 0 + _force_cache_clean(cache) + assert len(cache) == 0 + + time.move_to(t0 + timedelta(seconds=1, milliseconds=1)) + _force_cache_clean(cache) + assert len(cache) == 0 + + +def test_does_not_return_expired_entries() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = AsyncHomeDbCache(ttl=1) + key = cache.compute_key("key", None) + value = "value" + + cache.set(cache.compute_key("key", None), value) + assert cache.get(key) == value + + time.move_to(t0 + timedelta(seconds=1, milliseconds=1)) + assert cache.get(key) is None + + +def test_cache_max_size() -> None: + cache = AsyncHomeDbCache(max_size=4) + + entries = [] + for i in range(1, 11): + entries.append((cache.compute_key(f"{i}", None), f"value{i}")) + entries = entries[-4:] + key, value = entries[-1] + cache.set(key, 
value) + + _force_cache_clean(cache) + _assert_entries(cache, entries, allow_subset=True) + + +def test_cache_max_size_empty_cache() -> None: + cache = AsyncHomeDbCache(max_size=1) + assert len(cache) == 0 + _force_cache_clean(cache) + assert len(cache) == 0 + + +def test_clean_up_time() -> None: + def get_default_cache(): + pool = AsyncNeo4jPool( + lambda: None, AsyncPoolConfig(), WorkspaceConfig(), None + ) + return pool.home_db_cache + + repetitions = 5 + scenario_timings = [] + + # Test assumes that by default the driver uses a home db cache only limited + # by its size. + default_cache = get_default_cache() + default_max_size = default_cache._max_size + assert isinstance(default_max_size, int) + # If ttl ever gets used, this test needs to be updated to also test pruning + # by TTL. + assert math.isinf(default_cache._ttl) and default_cache._ttl > 0 + + for max_size, count in ( + # no pruning needed + (default_max_size * 10, default_max_size * 10), + # pruning needed + (default_max_size, default_max_size * 10), + ): + cache = AsyncHomeDbCache(max_size=max_size) + keys = [cache.compute_key(f"key{i}", None) for i in range(count)] + rep_timings = [] + for _ in range(repetitions): + t0 = time.perf_counter() + for key in keys: + cache.set(key, "value") + t1 = time.perf_counter() + rep_timings.append(t1 - t0) + scenario_timings.append(sum(rep_timings) / len(rep_timings)) + + # pruning shouldn't take more than 20 times the time of no pruning + # N.B., the pruning takes O(n * log(n)) where n is max_size. By only + # pruning O(n * log(n)) elements, we get an amortized pruning overhead of + # O(1) (as long as max_size is small enough to be able to choose a positive + # pruning size). + assert scenario_timings[1] <= 20 * scenario_timings[0] diff --git a/tests/unit/async_/work/test_result.py b/tests/unit/async_/work/test_result.py index 1291d95e..fb71105f 100644 --- a/tests/unit/async_/work/test_result.py +++ b/tests/unit/async_/work/test_result.py @@ -60,7 +60,10 @@ Neo4jWarning, ) -from ...._async_compat import mark_async_test +from ...._async_compat import ( + mark_async_test, + wrap_async, +) if t.TYPE_CHECKING: @@ -315,7 +318,7 @@ async def fetch_and_compare_all_records( @mark_async_test async def test_result_iteration(method, records): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, 2, None, noop, noop) + result = AsyncResult(connection, 2, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) await fetch_and_compare_all_records(result, "x", records, method) @@ -324,7 +327,7 @@ async def test_result_iteration_mixed_methods(): records = [[i] for i in range(10)] connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, 4, None, noop, noop) + result = AsyncResult(connection, 4, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) iter1 = AsyncUtil.iter(result) iter2 = AsyncUtil.iter(result) @@ -372,9 +375,9 @@ async def test_parallel_result_iteration(method, invert_fetch): connection = AsyncConnectionStub( records=(Records(["x"], records1), Records(["x"], records2)) ) - result1 = AsyncResult(connection, 2, None, noop, noop) + result1 = AsyncResult(connection, 2, None, noop, noop, None) await result1._run("CYPHER1", {}, None, None, "r", None, None, None) - result2 = AsyncResult(connection, 2, None, noop, noop) + result2 = AsyncResult(connection, 2, None, noop, noop, None) + await 
result2._run("CYPHER2", {}, None, None, "r", None, None, None) if invert_fetch: await fetch_and_compare_all_records(result2, "x", records2, method) @@ -395,9 +398,9 @@ async def test_interwoven_result_iteration(method, invert_fetch): connection = AsyncConnectionStub( records=(Records(["x"], records1), Records(["y"], records2)) ) - result1 = AsyncResult(connection, 2, None, noop, noop) + result1 = AsyncResult(connection, 2, None, noop, noop, None) await result1._run("CYPHER1", {}, None, None, "r", None, None, None) - result2 = AsyncResult(connection, 2, None, noop, noop) + result2 = AsyncResult(connection, 2, None, noop, noop, None) await result2._run("CYPHER2", {}, None, None, "r", None, None, None) start = 0 for n in (1, 2, 3, 1, None): @@ -424,7 +427,7 @@ async def test_interwoven_result_iteration(method, invert_fetch): @mark_async_test async def test_result_peek(records, fetch_size): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) for i in range(len(records) + 1): record = await result.peek() @@ -447,7 +450,7 @@ async def test_result_single_non_strict(records, fetch_size, default): kwargs["strict"] = False connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) if len(records) == 0: assert await result.single(**kwargs) is None @@ -466,7 +469,7 @@ async def test_result_single_non_strict(records, fetch_size, default): @mark_async_test async def test_result_single_strict(records, fetch_size): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) if len(records) != 1: with pytest.raises(ResultNotSingleError) as exc: @@ -490,7 +493,7 @@ async def test_result_single_strict(records, fetch_size): @mark_async_test async def test_result_single_exhausts_records(records, fetch_size, strict): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) try: with warnings.catch_warnings(): @@ -512,7 +515,7 @@ async def test_result_single_exhausts_records(records, fetch_size, strict): @mark_async_test async def test_result_fetch(records, fetch_size, strict): connection = AsyncConnectionStub(records=Records(["x"], records)) - result = AsyncResult(connection, fetch_size, None, noop, noop) + result = AsyncResult(connection, fetch_size, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) assert await result.fetch(0) == [] assert await result.fetch(-1) == [] @@ -524,7 +527,7 @@ async def test_result_fetch(records, fetch_size, strict): @mark_async_test async def test_keys_are_available_before_and_after_stream(): connection = AsyncConnectionStub(records=Records(["x"], [[1], [2]])) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, 
None, None, "r", None, None, None) assert list(result.keys()) == ["x"] await AsyncUtil.list(result) @@ -540,7 +543,7 @@ async def test_consume(records, consume_one, summary_meta, consume_times): connection = AsyncConnectionStub( records=Records(["x"], records), summary_meta=summary_meta ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) if consume_one: with suppress(StopAsyncIteration): @@ -574,7 +577,7 @@ async def test_time_in_summary(t_first, t_last): summary_meta=summary_meta, ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) summary = await result.consume() @@ -596,7 +599,7 @@ async def test_time_in_summary(t_first, t_last): async def test_counts_in_summary(): connection = AsyncConnectionStub(records=Records(["n"], [[1], [2]])) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) summary = await result.consume() @@ -610,7 +613,7 @@ async def test_query_type(query_type): records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type} ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) summary = await result.consume() @@ -625,7 +628,7 @@ async def test_data(num_records): records=Records(["n"], [[i + 1] for i in range(num_records)]) ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) await result._buffer_all() records = result._record_buffer.copy() @@ -667,7 +670,7 @@ async def test_data(num_records): @mark_async_test async def test_result_graph(records): connection = AsyncConnectionStub(records=records) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) graph = await result.graph() assert isinstance(graph, Graph) @@ -760,7 +763,7 @@ async def test_result_graph(records): async def test_to_eager_result(records): summary = {"test_to_eager_result": uuid.uuid4()} connection = AsyncConnectionStub(records=records, summary_meta=summary) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) eager_result = await result.to_eager_result() @@ -850,7 +853,7 @@ async def test_to_eager_result(records): @mark_async_test async def test_to_df(keys, values, types, instances, test_default_expand): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) if test_default_expand: df = await result.to_df() @@ -1061,7 +1064,7 @@ async def test_to_df_expand( keys, values, expected_columns, expected_rows, expected_types ): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, 
None) await result._run("CYPHER", {}, None, None, "r", None, None, None) df = await result.to_df(expand=True) @@ -1299,7 +1302,7 @@ async def test_to_df_expand( @mark_async_test async def test_to_df_parse_dates(keys, values, expected_df, expand): connection = AsyncConnectionStub(records=Records(keys, values)) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) df = await result.to_df(expand=expand, parse_dates=True) @@ -1314,7 +1317,7 @@ async def test_broken_hydration(nested): value_in = [value_in] records_in = Records(["foo", "bar"], [["foobar", value_in]]) connection = AsyncConnectionStub(records=records_in) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) await result._run("CYPHER", {}, None, None, "r", None, None, None) records_out = await AsyncUtil.list(result) assert len(records_out) == 1 @@ -1366,7 +1369,9 @@ async def test_notification_warning( ] }, ) - result = AsyncResult(connection, 1, warn_notification_severity, noop, noop) + result = AsyncResult( + connection, 1, warn_notification_severity, noop, noop, None + ) if expected_warning is None: with warnings.catch_warnings(): warnings.simplefilter("error") # assert not warnings are emitted @@ -1408,7 +1413,7 @@ async def test_notification_logging( records=Records(["foo"], ()), summary_meta={"notifications": [notification_data]}, ) - result = AsyncResult(connection, 1, None, noop, noop) + result = AsyncResult(connection, 1, None, noop, noop, None) with caplog.at_level(logging.INFO, logger="neo4j.notifications"): await result._run("CYPHER", {}, None, None, "r", None, None, None) await result.consume() @@ -1420,3 +1425,35 @@ async def test_notification_logging( f"Received notification from DBMS server: {formatted_notification}" ) assert caplog.messages[0] == expected_message + + +@pytest.mark.parametrize( + "async_cb", + (True, False) if AsyncUtil.is_async_code else (False,), +) +@pytest.mark.parametrize("resolved_db", (..., None, "resolved_db")) +@mark_async_test +async def test_on_database_callback(async_cb, resolved_db): + cb_calls = [] + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + if async_cb: + db_callback = wrap_async(db_callback) + + run_meta = {} + if resolved_db is not ...: + run_meta["db"] = resolved_db + connection = AsyncConnectionStub( + records=Records(["foo"], ()), run_meta=run_meta + ) + + result = AsyncResult(connection, 1, None, noop, noop, db_callback) + await result._run("CYPHER", {}, None, None, "r", None, None, None) + + if resolved_db in {..., None}: + assert cb_calls == [] + else: + assert cb_calls == [resolved_db] diff --git a/tests/unit/async_/work/test_session.py b/tests/unit/async_/work/test_session.py index 6ec6fac2..06ba7046 100644 --- a/tests/unit/async_/work/test_session.py +++ b/tests/unit/async_/work/test_session.py @@ -22,14 +22,19 @@ AsyncManagedTransaction, AsyncSession, AsyncTransaction, + Auth, Bookmarks, unit_of_work, ) from neo4j._api import TelemetryAPI +from neo4j._async.home_db_cache import AsyncHomeDbCache from neo4j._async.io import ( + AcquisitionDatabase, AsyncBoltPool, AsyncNeo4jPool, ) +from neo4j._async_compat.util import AsyncUtil +from neo4j._auth_management import to_auth_dict from neo4j._conf import SessionConfig from neo4j.api import ( AsyncBookmarkManager, @@ -490,8 +495,10 @@ async def bmm_get_bookmarks(): async_fake_pool.update_routing_table.side_effect 
= ( update_routing_table_side_effect ) + async_fake_pool.is_direct_pool = False else: async_fake_pool.mock_add_spec(AsyncBoltPool) + async_fake_pool.is_direct_pool = True config = SessionConfig() config.bookmark_manager = bmm @@ -699,3 +706,174 @@ async def work(_): connection_mock.telemetry.assert_called_once() call_args = connection_mock.telemetry.call_args.args assert call_args[0] == TelemetryAPI.DRIVER + + +@pytest.mark.parametrize( + ("db", "pool_ssr", "pool_routing", "expect_cache_usage"), + ( + (db, ssr, routing, ssr and routing and not db) + for ssr in (True, False) + for routing in (True, False) + for db in (None, "mydb") + ), +) +@pytest.mark.parametrize("imp_user", (None, "imp_user")) +@pytest.mark.parametrize( + "auth", + ( + None, + Auth(scheme="magic-auth", principal=None, credentials="tada"), + ), +) +@mark_async_test +async def test_uses_home_db_cache_when_expected( + async_fake_pool, + mocker, + db, + pool_ssr, + pool_routing, + expect_cache_usage, + imp_user, + auth, +): + async_fake_pool.ssr_enabled = pool_ssr + if pool_routing: + async_fake_pool.is_direct_pool = False + async_fake_pool.mock_add_spec(AsyncNeo4jPool) + cache_spy = mocker.Mock(spec=AsyncHomeDbCache, wraps=AsyncHomeDbCache()) + cached_db = "nice_cached_home_db" + key = object() + cache_spy.compute_key.return_value = key + cache_spy.get.return_value = cached_db + async_fake_pool.home_db_cache = cache_spy + + config = SessionConfig() + config.impersonated_user = imp_user + config.auth = auth + config.database = db + + async with AsyncSession(async_fake_pool, config) as session: + await session.run("RETURN 1") + + if expect_cache_usage: + # assert using cache + assert cache_spy.mock_calls == [ + mocker.call.compute_key( + imp_user, to_auth_dict(auth) if auth else None + ), + mocker.call.get(key), + ] + # assert passing cache result as a guess to the pool + async_fake_pool.acquire.assert_awaited_once_with( + access_mode=mocker.ANY, + timeout=mocker.ANY, + database=AcquisitionDatabase(cached_db, guessed=True), + bookmarks=mocker.ANY, + auth=mocker.ANY, + liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, + ) + else: + # assert not using cache + cache_spy.get.assert_not_called() + # assert passing a non-guess to the pool + async_fake_pool.acquire.assert_awaited_once_with( + access_mode=mocker.ANY, + timeout=mocker.ANY, + database=AcquisitionDatabase(db, guessed=False), + bookmarks=mocker.ANY, + auth=mocker.ANY, + liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, + ) + + +@pytest.mark.parametrize( + ("db", "pool_ssr", "pool_routing", "expect_cache_usage"), + ( + (db, ssr, routing, ssr and routing and not db) + for ssr in (True, False) + for routing in (True, False) + for db in (None, "mydb") + ), +) +@pytest.mark.parametrize("resolution_at", ("route", "run", "begin")) +@mark_async_test +async def test_pins_session_db_with_cache( + async_fake_pool, + mocker, + db, + pool_ssr, + pool_routing, + expect_cache_usage, + resolution_at, +): + async def resolve_db(): + if resolution_at == "route": + database_callback = async_fake_pool.acquire.call_args.kwargs[ + "database_callback" + ] + await AsyncUtil.callback(database_callback, resolved_db) + elif resolution_at == "run": + database_callback = res_mock.call_args.args[-1] + await AsyncUtil.callback(database_callback, resolved_db) + elif resolution_at == "begin": + database_callback = tx_mock.call_args.args[-1] + await AsyncUtil.callback(database_callback, resolved_db) + else: + raise ValueError(f"Unknown resolution_at: {resolution_at}") + 
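+    # Depending on where resolution happens, patch the class whose
+    # constructor receives the database callback so that `resolve_db` above
+    # can look the callback up in the mock's call args and invoke it.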
+ if resolution_at == "run": + res_mock = mocker.patch( + "neo4j._async.work.session.AsyncResult", autospec=True + ) + elif resolution_at == "begin": + tx_mock = mocker.patch( + "neo4j._async.work.session.AsyncTransaction", autospec=True + ) + + resolved_db = "resolved_db" + async_fake_pool.ssr_enabled = pool_ssr + if pool_routing: + async_fake_pool.is_direct_pool = False + async_fake_pool.mock_add_spec(AsyncNeo4jPool) + cache_spy = mocker.Mock(spec=AsyncHomeDbCache, wraps=AsyncHomeDbCache()) + key = object() + cache_spy.compute_key.return_value = key + async_fake_pool.home_db_cache = cache_spy + + config = SessionConfig() + config.database = db + + async with AsyncSession(async_fake_pool, config) as session: + if resolution_at == "begin": + async with await session.begin_transaction() as tx: + await tx.run("RETURN 1") + else: + await session.run("RETURN 1") + + if expect_cache_usage: + # assert never using cache to pin a database + assert not session._pinned_database + assert config.database == db + + await resolve_db() + + assert session._pinned_database + assert config.database == resolved_db + cache_spy.set.assert_called_once_with(key, resolved_db) + else: + if not pool_routing or db: + assert session._pinned_database + assert config.database == db + + await resolve_db() + + if not pool_routing or db: + assert session._pinned_database + assert config.database == db + cache_spy.set.assert_not_called() + else: + cache_spy.set.assert_called_once_with(key, resolved_db) + assert session._pinned_database + assert config.database == resolved_db diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 81fba42b..787fae82 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - from unittest.mock import MagicMock import pytest @@ -52,7 +50,7 @@ async def test_transaction_context_when_committing( on_error = mocker.AsyncMock() on_cancel = mocker.Mock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -88,7 +86,7 @@ async def test_transaction_context_with_explicit_rollback( on_error = mocker.AsyncMock() on_cancel = mocker.Mock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -120,7 +118,7 @@ class OopsError(RuntimeError): on_error = MagicMock() on_cancel = MagicMock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -141,7 +139,7 @@ async def test_transaction_run_takes_no_query_object(async_fake_connection): on_error = MagicMock() on_cancel = MagicMock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) with pytest.raises(ValueError): await tx.run(Query("RETURN 1")) @@ -165,7 +163,7 @@ async def test_transaction_run_parameters( on_error = MagicMock() on_cancel = MagicMock() tx = AsyncTransaction( - async_fake_connection, 2, None, on_closed, on_error, on_cancel + async_fake_connection, 2, None, on_closed, on_error, on_cancel, None ) if not as_kwargs: params = {"parameters": params} @@ -187,7 +185,9 @@ async def test_transaction_run_parameters( async def test_transaction_rollbacks_on_open_connections( async_fake_connection, ): - tx = AsyncTransaction(async_fake_connection, 2, None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) async with tx as tx_: async_fake_connection.is_reset_mock.return_value = False async_fake_connection.is_reset_mock.reset_mock() @@ -201,7 +201,9 @@ async def test_transaction_rollbacks_on_open_connections( async def test_transaction_no_rollback_on_reset_connections( async_fake_connection, ): - tx = AsyncTransaction(async_fake_connection, 2, None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) async with tx as tx_: async_fake_connection.is_reset_mock.return_value = True async_fake_connection.is_reset_mock.reset_mock() @@ -215,7 +217,9 @@ async def test_transaction_no_rollback_on_reset_connections( async def test_transaction_no_rollback_on_closed_connections( async_fake_connection, ): - tx = AsyncTransaction(async_fake_connection, 2, None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) async with tx as tx_: async_fake_connection.closed.return_value = True async_fake_connection.closed.reset_mock() @@ -231,7 +235,9 @@ async def test_transaction_no_rollback_on_closed_connections( async def test_transaction_no_rollback_on_defunct_connections( async_fake_connection, ): - tx = AsyncTransaction(async_fake_connection, 2, 
None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) async with tx as tx_: async_fake_connection.defunct.return_value = True async_fake_connection.defunct.reset_mock() @@ -246,9 +252,13 @@ async def test_transaction_no_rollback_on_defunct_connections( @pytest.mark.parametrize("pipeline", (True, False)) @mark_async_test async def test_transaction_begin_pipelining( - async_fake_connection, pipeline + async_fake_connection, + pipeline, + mocker, ) -> None: - tx = AsyncTransaction(async_fake_connection, 2, None, noop, noop, noop) + tx = AsyncTransaction( + async_fake_connection, 2, None, noop, noop, noop, None + ) database = "db" imp_user = None bookmarks = ["bookmark1", "bookmark2"] @@ -283,6 +293,7 @@ async def test_transaction_begin_pipelining( "notifications_disabled_classifications": ( notifications_disabled_classifications ), + "on_success": mocker.ANY, }, ), ] @@ -333,7 +344,7 @@ async def test_server_error_propagates(async_scripted_connection, error): raise ValueError(f"Unknown error type {error}") connection.set_script(script) - tx = AsyncTransaction(connection, 2, None, noop, noop, noop) + tx = AsyncTransaction(connection, 2, None, noop, noop, noop, None) res1 = await tx.run("UNWIND range(1, 1000) AS n RETURN n") assert await res1.__anext__() == {"n": 1} @@ -349,3 +360,45 @@ async def test_server_error_propagates(async_scripted_connection, error): await res1.__anext__() assert exc1.value is exc2.value.__cause__ + + +@pytest.mark.parametrize("async_cb", (True, False)) +@pytest.mark.parametrize("resolved_db", (..., None, "resolved_db")) +@mark_async_test +async def test_on_database_callback( + async_scripted_connection, async_cb, resolved_db +): + cb_calls = [] + + if async_cb: + + async def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + else: + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + begin_meta = {} + if resolved_db is not ...: + begin_meta["db"] = resolved_db + connection = async_scripted_connection + connection.set_script( + [ + ("begin", {"on_success": (begin_meta,), "on_summary": None}), + ] + ) + + result = AsyncTransaction( + connection, 1, None, noop, noop, noop, db_callback + ) + await result._begin( + None, None, None, None, None, None, None, None, pipelined=False + ) + + if resolved_db in {..., None}: + assert cb_calls == [] + else: + assert cb_calls == [resolved_db] diff --git a/tests/unit/common/work/test_summary.py b/tests/unit/common/work/test_summary.py index 74f2059a..46c85c53 100644 --- a/tests/unit/common/work/test_summary.py +++ b/tests/unit/common/work/test_summary.py @@ -890,6 +890,7 @@ def test_summary_result_counters(summary_args_kwargs, counters_set) -> None: ((5, 5), "t_first"), ((5, 6), "t_first"), ((5, 7), "t_first"), + ((5, 8), "t_first"), ), ) def test_summary_result_available_after( @@ -927,6 +928,7 @@ def test_summary_result_available_after( ((5, 5), "t_last"), ((5, 6), "t_last"), ((5, 7), "t_last"), + ((5, 8), "t_last"), ), ) def test_summary_result_consumed_after( diff --git a/tests/unit/mixed/async_compat/test_concurrency.py b/tests/unit/mixed/async_compat/test_concurrency.py index 03356352..f4c17374 100644 --- a/tests/unit/mixed/async_compat/test_concurrency.py +++ b/tests/unit/mixed/async_compat/test_concurrency.py @@ -20,6 +20,8 @@ from neo4j._async_compat.concurrency import AsyncRLock +from ...._async_util import gather_cancel + @pytest.mark.asyncio async def test_async_r_lock(): @@ -36,7 +38,7 @@ async def worker(): assert counter == counter_ + 1 
assert not lock.locked() - await asyncio.gather(worker(), worker(), worker()) + await gather_cancel(worker(), worker(), worker()) assert not lock.locked() @@ -52,7 +54,7 @@ async def worker(): assert lock.locked() assert not lock.locked() - await asyncio.gather(worker(), worker(), worker()) + await gather_cancel(worker(), worker(), worker()) assert not lock.locked() @@ -69,7 +71,7 @@ async def waiter(): assert not await lock.acquire(timeout=0.1) assert not lock.locked() - await asyncio.gather(blocker(), waiter()) + await gather_cancel(blocker(), waiter()) assert lock.locked() # blocker still owns it! @@ -90,7 +92,7 @@ async def waiter(): # blocker: lock.release() assert not lock.locked() - await asyncio.gather(blocker(), waiter()) + await gather_cancel(blocker(), waiter()) assert lock.locked() # waiter still owns it! @@ -162,7 +164,7 @@ async def waiter(): lock.release() assert not lock.locked() - await asyncio.gather(blocker(), waiter_non_blocking(), waiter()) + await gather_cancel(blocker(), waiter_non_blocking(), waiter()) assert lock.locked() # waiter_non_blocking still owns it! @@ -225,7 +227,7 @@ async def waiter_non_blocking(): awaits += 1 assert not lock.locked() - await asyncio.gather(blocker(), waiter_non_blocking()) + await gather_cancel(blocker(), waiter_non_blocking()) assert not lock.locked() diff --git a/tests/unit/mixed/io/test_direct.py b/tests/unit/mixed/io/test_direct.py index f330fdfb..ed4a9428 100644 --- a/tests/unit/mixed/io/test_direct.py +++ b/tests/unit/mixed/io/test_direct.py @@ -29,10 +29,11 @@ import pytest -from neo4j._async.io._pool import AcquireAuth as AsyncAcquireAuth +from neo4j._async.io._pool import AcquisitionAuth as AsyncAcquisitionAuth from neo4j._deadline import Deadline -from neo4j._sync.io._pool import AcquireAuth +from neo4j._sync.io._pool import AcquisitionAuth +from ...._async_util import gather_cancel from ...async_.io.test_direct import AsyncFakeBoltPool from ...async_.test_auth_management import ( static_auth_manager as static_async_auth_manager, @@ -128,11 +129,11 @@ def acquire_release_conn( def test_full_pool_re_auth(self, fake_connection_generator, mocker): address = ("127.0.0.1", 7687) - acquire_auth1 = AcquireAuth( + acquire_auth1 = AcquisitionAuth( auth=static_auth_manager(("user1", "pass1")) ) auth2 = ("user2", "pass2") - acquire_auth2 = AcquireAuth(auth=static_auth_manager(auth2)) + acquire_auth2 = AcquisitionAuth(auth=static_auth_manager(auth2)) acquire1_event = threading.Event() cx1 = None @@ -193,7 +194,7 @@ async def acquire_release_conn( async def waiter(pool_, acquired_counter_, release_event_): nonlocal pre_populated_connections, connections - if not await acquired_counter_.wait(5, timeout=1): + if not await acquired_counter_.wait(5, timeout=5): raise RuntimeError("Acquire coroutines not fast enough") # The pool size should be 5, all are in-use self.assert_pool_size(address, 5, 0, pool_) @@ -205,7 +206,7 @@ async def waiter(pool_, acquired_counter_, release_event_): release_event_.set() # wait for all coroutines to release connections back to pool - if not await acquired_counter_.wait(10, timeout=5): + if not await acquired_counter_.wait(10, timeout=10): raise RuntimeError("Acquire coroutines not fast enough") # The pool size is still 5, but all are free self.assert_pool_size(address, 0, 5, pool_) @@ -234,7 +235,7 @@ async def waiter(pool_, acquired_counter_, release_event_): ) for _ in range(10) ] - await asyncio.gather( + await gather_cancel( waiter(pool, acquired_counter, release_event), *coroutines ) @@ -243,11 +244,13 @@ 
async def test_full_pool_re_auth_async( self, async_fake_connection_generator, mocker ): address = ("127.0.0.1", 7687) - acquire_auth1 = AsyncAcquireAuth( + acquire_auth1 = AsyncAcquisitionAuth( + auth=static_async_auth_manager(("user1", "pass1")) ) auth2 = ("user2", "pass2") - acquire_auth2 = AsyncAcquireAuth(auth=static_async_auth_manager(auth2)) + acquire_auth2 = AsyncAcquisitionAuth( + auth=static_async_auth_manager(auth2) + ) cx1 = None async def acquire1(pool_): @@ -274,4 +277,4 @@ async def acquire2(pool_): async with AsyncFakeBoltPool( async_fake_connection_generator, (), max_connection_pool_size=1 ) as pool: - await asyncio.gather(acquire1(pool), acquire2(pool)) + await gather_cancel(acquire1(pool), acquire2(pool)) diff --git a/tests/unit/mixed/test_home_db_cache.py b/tests/unit/mixed/test_home_db_cache.py new file mode 100644 index 00000000..466f1481 --- /dev/null +++ b/tests/unit/mixed/test_home_db_cache.py @@ -0,0 +1,88 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import random +from collections import defaultdict +from concurrent.futures import ( + as_completed, + ThreadPoolExecutor, +) +from time import monotonic + +from neo4j._sync.home_db_cache import HomeDbCache + + +# No async equivalent exists: the async home db cache is not actually async. +# As there's no IO involved, there's no need for locking in the async world.
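+# The test below hammers a single HomeDbCache from many threads: the very +# short TTL and a max_size below the number of keys force the expiry and +# truncation paths to race with concurrent get/set calls. Each worker tags +# its values with its own id so it can assert it never reads back a stale +# value of its own.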
+def test_concurrent_home_db_cache_access() -> None: + workers = 25 + duration = 5 + value_pool_size = 50 + + cache = HomeDbCache(ttl=0.001, max_size=value_pool_size - 2) + keys = tuple( + cache.compute_key(user, None) + for user in map(str, range(1, value_pool_size + 1)) + ) + + def worker(worker_id, end): + non_checks = checks = 0 + + value_counter = defaultdict(int) + while monotonic() < end: + for _ in range(20): # to not check time too often + i = random.randint(0, len(keys) - 1) + value_count = value_counter[i] + key = keys[i] + rand = random.random() + if rand < 0.1: + cache.set(key, None) + res = cache.get(key) + # Never want to read back this worker's own value + assert res is None or not res.startswith(f"{worker_id}-") + elif rand < 0.55: + value_counter[i] += 1 + value = f"{worker_id}-{value_count + 1}" + cache.set(key, value) + res = cache.get(key) + if res is not None and res.startswith(f"{worker_id}-"): + # never want to read back an old value of this worker + checks += 1 + assert res == value + else: + non_checks += 1 + else: + res = cache.get(key) + if res is not None and res.startswith(f"{worker_id}-"): + # never want to read back an old value of this worker + checks += 1 + assert res == f"{worker_id}-{value_count}" + else: + non_checks += 1 + + # import json + # print( + # f"{worker_id}:\n" + # f"{json.dumps(value_counter, indent=2)}\n" + # f"checks: {checks}, non_checks: {non_checks}\n", + # flush=True, + # ) + + with ThreadPoolExecutor(max_workers=workers) as executor: + end = monotonic() + duration + futures = (executor.submit(worker, i, end) for i in range(workers)) + for future in as_completed(futures): + future.result() diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index 8785badb..ca9a5c80 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -127,7 +127,14 @@ def callback(): return func method_mock = parent.__getattr__(name) - if name in {"run", "commit", "pull", "rollback", "discard"}: + if name in { + "begin", + "run", + "commit", + "pull", + "rollback", + "discard", + }: method_mock.side_effect = build_message_handler(name) return method_mock @@ -218,7 +225,14 @@ def callback(): return func method_mock = parent.__getattr__(name) - if name in {"run", "commit", "pull", "rollback", "discard"}: + if name in { + "begin", + "run", + "commit", + "pull", + "rollback", + "discard", + }: method_mock.side_effect = build_message_handler(name) return method_mock diff --git a/tests/unit/sync/fixtures/fake_pool.py b/tests/unit/sync/fixtures/fake_pool.py index 38d2ac4d..855d935a 100644 --- a/tests/unit/sync/fixtures/fake_pool.py +++ b/tests/unit/sync/fixtures/fake_pool.py @@ -17,6 +17,7 @@ import pytest from neo4j._sync.config import PoolConfig +from neo4j._sync.home_db_cache import HomeDbCache from neo4j._sync.io._pool import IOPool @@ -32,6 +33,9 @@ def fake_pool(fake_connection_generator, mocker): pool.buffered_connection_mocks = [] pool.acquired_connection_mocks = [] pool.pool_config = PoolConfig() + pool.ssr_enabled = False + pool.is_direct_pool = True + pool.home_db_cache = HomeDbCache(enabled=False) def acquire_side_effect(*_, **__): if pool.buffered_connection_mocks: diff --git a/tests/unit/sync/io/test_class_bolt.py b/tests/unit/sync/io/test_class_bolt.py index f3b06303..c3d6bace 100644 --- a/tests/unit/sync/io/test_class_bolt.py +++ b/tests/unit/sync/io/test_class_bolt.py @@ -38,7 +38,7 @@ def test_class_method_protocol_handlers(): expected_handlers = { (3,
0), (4, 1), (4, 2), (4, 3), (4, 4), - (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), + (5, 0), (5, 1), (5, 2), (5, 3), (5, 4), (5, 5), (5, 6), (5, 7), (5, 8), } # fmt: on @@ -69,7 +69,8 @@ def test_class_method_protocol_handlers(): ((5, 5), 1), ((5, 6), 1), ((5, 7), 1), - ((5, 8), 0), + ((5, 8), 1), + ((5, 9), 0), ((6, 0), 0), ], ) @@ -92,7 +93,7 @@ def test_class_method_get_handshake(): handshake = Bolt.get_handshake() assert ( handshake - == b"\x00\x07\x07\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" + == b"\x00\x08\x08\x05\x00\x02\x04\x04\x00\x00\x01\x04\x00\x00\x00\x03" ) @@ -143,6 +144,7 @@ def test_cancel_hello_in_open(mocker, none_auth): ((5, 5), "neo4j._sync.io._bolt5.Bolt5x5"), ((5, 6), "neo4j._sync.io._bolt5.Bolt5x6"), ((5, 7), "neo4j._sync.io._bolt5.Bolt5x7"), + ((5, 8), "neo4j._sync.io._bolt5.Bolt5x8"), ), ) @mark_sync_test @@ -181,7 +183,7 @@ def test_version_negotiation( (2, 0), (4, 0), (3, 1), - (5, 8), + (5, 9), (6, 0), ), ) diff --git a/tests/unit/sync/io/test_class_bolt3.py b/tests/unit/sync/io/test_class_bolt3.py index ba80ce81..af3d3c6b 100644 --- a/tests/unit/sync/io/test_class_bolt3.py +++ b/tests/unit/sync/io/test_class_bolt3.py @@ -129,7 +129,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -569,3 +569,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt3.PACKER_CLS, + unpacker_cls=Bolt3.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt3( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt4x0.py b/tests/unit/sync/io/test_class_bolt4x0.py index a0ad36e8..f9dfef4b 100644 --- a/tests/unit/sync/io/test_class_bolt4x0.py +++ b/tests/unit/sync/io/test_class_bolt4x0.py @@ -232,7 +232,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -660,3 +660,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt4x0.PACKER_CLS, + unpacker_cls=Bolt4x0.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt4x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git 
a/tests/unit/sync/io/test_class_bolt4x1.py b/tests/unit/sync/io/test_class_bolt4x1.py index c4b0208a..219a9fda 100644 --- a/tests/unit/sync/io/test_class_bolt4x1.py +++ b/tests/unit/sync/io/test_class_bolt4x1.py @@ -254,7 +254,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -682,3 +682,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt4x1.PACKER_CLS, + unpacker_cls=Bolt4x1.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt4x1( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt4x2.py b/tests/unit/sync/io/test_class_bolt4x2.py index b6ac961a..944f7c28 100644 --- a/tests/unit/sync/io/test_class_bolt4x2.py +++ b/tests/unit/sync/io/test_class_bolt4x2.py @@ -254,7 +254,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -682,3 +682,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt4x2.PACKER_CLS, + unpacker_cls=Bolt4x2.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt4x2( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt4x3.py b/tests/unit/sync/io/test_class_bolt4x3.py index c5da8700..2e53fc42 100644 --- a/tests/unit/sync/io/test_class_bolt4x3.py +++ b/tests/unit/sync/io/test_class_bolt4x3.py @@ -254,7 +254,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -711,3 +711,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt4x3.PACKER_CLS, + unpacker_cls=Bolt4x3.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": 
ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt4x3( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt4x4.py b/tests/unit/sync/io/test_class_bolt4x4.py index 164372b0..9378e8bf 100644 --- a/tests/unit/sync/io/test_class_bolt4x4.py +++ b/tests/unit/sync/io/test_class_bolt4x4.py @@ -278,7 +278,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -671,3 +671,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt4x4.PACKER_CLS, + unpacker_cls=Bolt4x4.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt4x4( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt5x0.py b/tests/unit/sync/io/test_class_bolt5x0.py index 6f26b97a..2390112d 100644 --- a/tests/unit/sync/io/test_class_bolt5x0.py +++ b/tests/unit/sync/io/test_class_bolt5x0.py @@ -278,7 +278,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -735,3 +735,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x0.PACKER_CLS, + unpacker_cls=Bolt5x0.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x0( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt5x1.py b/tests/unit/sync/io/test_class_bolt5x1.py index dfe638a9..7b2804c4 100644 --- a/tests/unit/sync/io/test_class_bolt5x1.py +++ b/tests/unit/sync/io/test_class_bolt5x1.py @@ -280,7 +280,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -789,3 +789,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def 
test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x1.PACKER_CLS, + unpacker_cls=Bolt5x1.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x1( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt5x2.py b/tests/unit/sync/io/test_class_bolt5x2.py index 5dc09be8..165d1776 100644 --- a/tests/unit/sync/io/test_class_bolt5x2.py +++ b/tests/unit/sync/io/test_class_bolt5x2.py @@ -278,7 +278,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -826,3 +826,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x2.PACKER_CLS, + unpacker_cls=Bolt5x2.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x2( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt5x3.py b/tests/unit/sync/io/test_class_bolt5x3.py index af852710..d0d17131 100644 --- a/tests/unit/sync/io/test_class_bolt5x3.py +++ b/tests/unit/sync/io/test_class_bolt5x3.py @@ -281,7 +281,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -713,3 +713,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x3.PACKER_CLS, + unpacker_cls=Bolt5x3.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x3( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt5x4.py b/tests/unit/sync/io/test_class_bolt5x4.py index 5773d1f6..ea938cc2 100644 --- a/tests/unit/sync/io/test_class_bolt5x4.py +++ b/tests/unit/sync/io/test_class_bolt5x4.py @@ -281,7 +281,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + 
connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -718,3 +718,25 @@ def on_success(metadata): connection.fetch_all() assert received_metadata == sent_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x4.PACKER_CLS, + unpacker_cls=Bolt5x4.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x4( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt5x5.py b/tests/unit/sync/io/test_class_bolt5x5.py index 361a9c14..e5cc6e74 100644 --- a/tests/unit/sync/io/test_class_bolt5x5.py +++ b/tests/unit/sync/io/test_class_bolt5x5.py @@ -281,7 +281,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -756,3 +756,25 @@ def extend_diag_record(r): } assert received_metadata == expected_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x5.PACKER_CLS, + unpacker_cls=Bolt5x5.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x5( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt5x6.py b/tests/unit/sync/io/test_class_bolt5x6.py index 15f37872..a472ef5f 100644 --- a/tests/unit/sync/io/test_class_bolt5x6.py +++ b/tests/unit/sync/io/test_class_bolt5x6.py @@ -281,7 +281,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -760,3 +760,25 @@ def extend_diag_record(r): } assert received_metadata == expected_metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x6.PACKER_CLS, + unpacker_cls=Bolt5x6.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x6( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt5x7.py b/tests/unit/sync/io/test_class_bolt5x7.py index 
cf999cc6..95890c79 100644 --- a/tests/unit/sync/io/test_class_bolt5x7.py +++ b/tests/unit/sync/io/test_class_bolt5x7.py @@ -282,7 +282,7 @@ def test_telemetry_message( telemetry_disabled=driver_disabled, ) if serv_enabled: - connection.configuration_hints["telemetry.enabled"] = True + connection.connection_hints["telemetry.enabled"] = True connection.telemetry(api) connection.send_all() @@ -848,3 +848,25 @@ def _build_error_hierarchy_metadata(diag_records_metadata): if r is not ...: current_root["diagnostic_record"] = r return metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x7.PACKER_CLS, + unpacker_cls=Bolt5x7.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x7( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is False diff --git a/tests/unit/sync/io/test_class_bolt5x8.py b/tests/unit/sync/io/test_class_bolt5x8.py new file mode 100644 index 00000000..25de8c2b --- /dev/null +++ b/tests/unit/sync/io/test_class_bolt5x8.py @@ -0,0 +1,872 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
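+# Unit tests for the Bolt 5.8 protocol handler. The suite mirrors those for +# earlier 5.x versions; the notable addition is test_ssr_enabled at the end +# of the file, where, unlike for older protocol versions, the ssr.enabled +# hint from the server is expected to be reflected in connection.ssr_enabled +# after HELLO.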
+ + +import itertools +import logging + +import pytest + +import neo4j +from neo4j._api import TelemetryAPI +from neo4j._meta import ( + BOLT_AGENT_DICT, + USER_AGENT, +) +from neo4j._sync.config import PoolConfig +from neo4j._sync.io._bolt5 import Bolt5x8 +from neo4j.exceptions import Neo4jError + +from ...._async_compat import mark_sync_test +from ....iter_util import powerset + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is True + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale_if_not_enabled(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = -1 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize("set_stale", (True, False)) +def test_conn_is_not_stale(fake_socket, set_stale): + address = neo4j.Address(("127.0.0.1", 7687)) + max_connection_lifetime = 999999999 + connection = Bolt5x8( + address, fake_socket(address), max_connection_lifetime + ) + if set_stale: + connection.set_stale() + assert connection.stale() is set_stale + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ({"db": "something"},)), + (("", {}), {"imp_user": "imposter"}, ({"imp_user": "imposter"},)), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ({"db": "something", "imp_user": "imposter"},), + ), + ), +) +@mark_sync_test +def test_extra_in_begin(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.begin(*args, **kwargs) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x11" + assert tuple(fields) == expected_fields + + +@pytest.mark.parametrize( + ("args", "kwargs", "expected_fields"), + ( + (("", {}), {"db": "something"}, ("", {}, {"db": "something"})), + ( + ("", {}), + {"imp_user": "imposter"}, + ("", {}, {"imp_user": "imposter"}), + ), + ( + ("", {}), + {"db": "something", "imp_user": "imposter"}, + ("", {}, {"db": "something", "imp_user": "imposter"}), + ), + ), +) +@mark_sync_test +def test_extra_in_run(fake_socket, args, kwargs, expected_fields): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.run(*args, **kwargs) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x10" + assert tuple(fields) == expected_fields + + +@mark_sync_test +def test_n_extra_in_discard(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == {"n": 666} + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": -1, "qid": 666}), + (-1,
{"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": 666, "qid": 777}), + (-1, {"n": 666}), + ], +) +@mark_sync_test +def test_n_and_qid_extras_in_discard(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.discard(n=666, qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x2f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (666, {"n": 666}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_n_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (777, {"n": -1, "qid": 777}), + (-1, {"n": -1}), + ], +) +@mark_sync_test +def test_qid_extra_in_pull(fake_socket, test_input, expected): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(qid=test_input) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == expected + + +@mark_sync_test +def test_n_and_qid_extras_in_pull(fake_socket): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, socket, PoolConfig.max_connection_lifetime + ) + connection.pull(n=666, qid=777) + connection.send_all() + tag, fields = socket.pop_message() + assert tag == b"\x3f" + assert len(fields) == 1 + assert fields[0] == {"n": 666, "qid": 777} + + +@mark_sync_test +def test_hello_passes_routing_metadata(fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.4.0"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + routing_context={"foo": "bar"}, + ) + connection.hello() + tag, fields = sockets.server.pop_message() + assert tag == b"\x01" + assert len(fields) == 1 + assert fields[0]["routing"] == {"foo": "bar"} + + +@pytest.mark.parametrize("api", TelemetryAPI) +@pytest.mark.parametrize("serv_enabled", (True, False)) +@pytest.mark.parametrize("driver_disabled", (True, False)) +@mark_sync_test +def test_telemetry_message( + fake_socket, api, serv_enabled, driver_disabled +): + address = 
neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, + socket, + PoolConfig.max_connection_lifetime, + telemetry_disabled=driver_disabled, + ) + if serv_enabled: + connection.connection_hints["telemetry.enabled"] = True + connection.telemetry(api) + connection.send_all() + + if serv_enabled and not driver_disabled: + tag, fields = socket.pop_message() + assert tag == b"\x54" + assert fields == [int(api)] + else: + with pytest.raises(OSError): + socket.pop_message() + + +@pytest.mark.parametrize( + ("hints", "valid"), + ( + ({"connection.recv_timeout_seconds": 1}, True), + ({"connection.recv_timeout_seconds": 42}, True), + ({}, True), + ({"whatever_this_is": "ignore me!"}, True), + ({"connection.recv_timeout_seconds": -1}, False), + ({"connection.recv_timeout_seconds": 0}, False), + ({"connection.recv_timeout_seconds": 2.5}, False), + ({"connection.recv_timeout_seconds": None}, False), + ({"connection.recv_timeout_seconds": False}, False), + ({"connection.recv_timeout_seconds": "1"}, False), + ), +) +@mark_sync_test +def test_hint_recv_timeout_seconds( + fake_socket_pair, hints, valid, caplog, mocker +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.client.settimeout = mocker.Mock() + sockets.server.send_message( + b"\x70", {"server": "Neo4j/4.3.4", "hints": hints} + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + with caplog.at_level(logging.INFO): + connection.hello() + if valid: + if "connection.recv_timeout_seconds" in hints: + sockets.client.settimeout.assert_called_once_with( + hints["connection.recv_timeout_seconds"] + ) + else: + sockets.client.settimeout.assert_not_called() + assert not any( + "recv_timeout_seconds" in msg and "invalid" in msg + for msg in caplog.messages + ) + else: + sockets.client.settimeout.assert_not_called() + assert any( + repr(hints["connection.recv_timeout_seconds"]) in msg + and "recv_timeout_seconds" in msg + and "invalid" in msg + for msg in caplog.messages + ) + + +CREDENTIALS = "+++super-secret-sauce+++" + + +@pytest.mark.parametrize( + "auth", + ( + ("user", CREDENTIALS), + neo4j.basic_auth("user", CREDENTIALS), + neo4j.kerberos_auth(CREDENTIALS), + neo4j.bearer_auth(CREDENTIALS), + neo4j.custom_auth("user", CREDENTIALS, "realm", "scheme"), + neo4j.Auth("scheme", "principal", CREDENTIALS, "realm", foo="bar"), + ), +) +@mark_sync_test +def test_credentials_are_not_logged(auth, fake_socket_pair, caplog): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/4.3.4"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + auth=auth, + ) + with caplog.at_level(logging.DEBUG): + connection.hello() + + if isinstance(auth, tuple): + auth = neo4j.basic_auth(*auth) + for field in ("scheme", "principal", "realm", "parameters"): + value = getattr(auth, field, None) + if value: + assert repr(value) in caplog.text + assert CREDENTIALS not in caplog.text + + +def _assert_notifications_in_extra(extra, expected): + for key in expected: + assert key in extra + assert extra[key] == expected[key] + + +@pytest.mark.parametrize( + ("method", "args", 
"extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("cls_min_sev", "method_min_sev"), + itertools.product((None, "WARNING", "OFF"), repeat=2), +) +@pytest.mark.parametrize( + ("cls_dis_clss", "method_dis_clss"), + itertools.product((None, [], ["HINT"], ["HINT", "DEPRECATION"]), repeat=2), +) +@mark_sync_test +def test_supports_notification_filters( + fake_socket, + method, + args, + extra_idx, + cls_min_sev, + method_min_sev, + cls_dis_clss, + method_dis_clss, +): + address = neo4j.Address(("127.0.0.1", 7687)) + socket = fake_socket(address, Bolt5x8.UNPACKER_CLS) + connection = Bolt5x8( + address, + socket, + PoolConfig.max_connection_lifetime, + notifications_min_severity=cls_min_sev, + notifications_disabled_classifications=cls_dis_clss, + ) + method = getattr(connection, method) + + method( + *args, + notifications_min_severity=method_min_sev, + notifications_disabled_classifications=method_dis_clss, + ) + connection.send_all() + + _, fields = socket.pop_message() + extra = fields[extra_idx] + expected = {} + if method_min_sev is not None: + expected["notifications_minimum_severity"] = method_min_sev + if method_dis_clss is not None: + expected["notifications_disabled_classifications"] = method_dis_clss + _assert_notifications_in_extra(extra, expected) + + +@pytest.mark.parametrize("min_sev", (None, "WARNING", "OFF")) +@pytest.mark.parametrize( + "dis_clss", (None, [], ["HINT"], ["HINT", "DEPRECATION"]) +) +@mark_sync_test +def test_hello_supports_notification_filters( + fake_socket_pair, min_sev, dis_clss +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, + sockets.client, + PoolConfig.max_connection_lifetime, + notifications_min_severity=min_sev, + notifications_disabled_classifications=dis_clss, + ) + + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + expected = {} + if min_sev is not None: + expected["notifications_minimum_severity"] = min_sev + if dis_clss is not None: + expected["notifications_disabled_classifications"] = dis_clss + _assert_notifications_in_extra(extra, expected) + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_user_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + if not user_agent: + assert extra["user_agent"] == USER_AGENT + else: + assert extra["user_agent"] == user_agent + + +@mark_sync_test +@pytest.mark.parametrize( + "user_agent", (None, "test user agent", "", USER_AGENT) +) +def test_sends_bolt_agent(fake_socket_pair, user_agent): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + 
sockets.server.send_message(b"\x70", {}) + max_connection_lifetime = 0 + connection = Bolt5x8( + address, sockets.client, max_connection_lifetime, user_agent=user_agent + ) + connection.hello() + + _tag, fields = sockets.server.pop_message() + extra = fields[0] + assert extra["bolt_agent"] == BOLT_AGENT_DICT + + +@mark_sync_test +@pytest.mark.parametrize( + ("func", "args", "extra_idx"), + ( + ("run", ("RETURN 1",), 2), + ("begin", (), 0), + ), +) +@pytest.mark.parametrize( + ("timeout", "res"), + ( + (None, None), + (0, 0), + (0.1, 100), + (0.001, 1), + (1e-15, 1), + (0.0005, 1), + (0.0001, 1), + (1.0015, 1002), + (1.000499, 1000), + (1.0025, 1002), + (3.0005, 3000), + (3.456, 3456), + (1, 1000), + ( + -1e-15, + ValueError("Timeout must be a positive number or 0"), + ), + ( + "foo", + ValueError("Timeout must be specified as a number of seconds"), + ), + ( + [1, 2], + TypeError("Timeout must be specified as a number of seconds"), + ), + ), +) +def test_tx_timeout( + fake_socket_pair, func, args, extra_idx, timeout, res +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8(address, sockets.client, 0) + func = getattr(connection, func) + if isinstance(res, Exception): + with pytest.raises(type(res), match=str(res)): + func(*args, timeout=timeout) + else: + func(*args, timeout=timeout) + connection.send_all() + _tag, fields = sockets.server.pop_message() + extra = fields[extra_idx] + if timeout is None: + assert "tx_timeout" not in extra + else: + assert extra["tx_timeout"] == res + + +@pytest.mark.parametrize( + "actions", + itertools.combinations_with_replacement( + itertools.product( + ("run", "begin", "begin_run"), + ("reset", "commit", "rollback"), + (None, "some_db", "another_db"), + ), + 2, + ), +) +@mark_sync_test +def test_tracks_last_database(fake_socket_pair, actions): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + sockets.server.send_message(b"\x70", {"server": "Neo4j/1.2.3"}) + sockets.server.send_message(b"\x70", {}) + connection.hello() + assert connection.last_database is None + for action, finish, db in actions: + sockets.server.send_message(b"\x70", {}) + if action == "run": + connection.run("RETURN 1", db=db) + elif action == "begin": + connection.begin(db=db) + elif action == "begin_run": + connection.begin(db=db) + assert connection.last_database == db + sockets.server.send_message(b"\x70", {}) + connection.run("RETURN 1") + else: + raise ValueError(action) + + assert connection.last_database == db + connection.send_all() + connection.fetch_all() + assert connection.last_database == db + + sockets.server.send_message(b"\x70", {}) + if finish == "reset": + connection.reset() + elif finish == "commit": + if action == "run": + connection.pull() + else: + connection.commit() + elif finish == "rollback": + if action == "run": + connection.pull() + else: + connection.rollback() + else: + raise ValueError(finish) + + connection.send_all() + connection.fetch_all() + + assert connection.last_database == db + + +DEFAULT_DIAG_REC_PAIRS = ( + ("OPERATION", ""), + ("OPERATION_CODE", "0"), + ("CURRENT_SCHEMA", "/"), +) + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + 
{"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + upper_limit=3, + ), +) +@pytest.mark.parametrize("method", ("pull", "discard")) +@mark_sync_test +def test_enriches_statuses( + sent_diag_records, + method, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + + sent_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in sent_diag_records + ] + } + sockets.server.send_message(b"\x70", sent_metadata) + + received_metadata = None + + def on_success(metadata): + nonlocal received_metadata + received_metadata = metadata + + getattr(connection, method)(on_success=on_success) + connection.send_all() + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = { + "statuses": [ + { + "status_description": "the status description", + "description": "description", + "diagnostic_record": r, + } + if r is not ... + else { + "status_description": "the status description", + "description": "description", + } + for r in expected_diag_records + ] + } + + assert received_metadata == expected_metadata + + +@pytest.mark.parametrize( + "sent_diag_records", + powerset( + ( + ..., + None, + {}, + [], + "1", + 1, + {"OPERATION_CODE": "0"}, + {"OPERATION": "", "OPERATION_CODE": "0", "CURRENT_SCHEMA": "/"}, + {"OPERATION": "Foo", "OPERATION_CODE": 1, "CURRENT_SCHEMA": False}, + {"OPERATION": "", "OPERATION_CODE": "0", "bar": "baz"}, + ), + lower_limit=1, + upper_limit=3, + ), +) +@mark_sync_test +def test_enriches_error_statuses( + sent_diag_records, + fake_socket_pair, +): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + connection = Bolt5x8(address, sockets.client, 0) + sent_diag_records = [ + {**r, "_classification": "CLIENT_ERROR", "_status_parameters": {}} + if isinstance(r, dict) + else r + for r in sent_diag_records + ] + + sent_metadata = _build_error_hierarchy_metadata(sent_diag_records) + + sockets.server.send_message(b"\x7f", sent_metadata) + + received_metadata = None + + def on_failure(metadata): + nonlocal received_metadata + received_metadata = metadata + + connection.run("RETURN 1", on_failure=on_failure) + connection.send_all() + with pytest.raises(Neo4jError): + connection.fetch_all() + + def extend_diag_record(r): + if r is ...: + return dict(DEFAULT_DIAG_REC_PAIRS) + if isinstance(r, dict): + return dict((*DEFAULT_DIAG_REC_PAIRS, *r.items())) + return r + + expected_diag_records = [extend_diag_record(r) for r in sent_diag_records] + expected_metadata = _build_error_hierarchy_metadata(expected_diag_records) + + assert received_metadata == expected_metadata + + +def _build_error_hierarchy_metadata(diag_records_metadata): + metadata = { + "gql_status": "FOO12", + "description": "but 
have you tried not doing that?!", + "message": "some people just can't be helped", + "neo4j_code": "Neo.ClientError.Generic.YouSuck", + } + if diag_records_metadata[0] is not ...: + metadata["diagnostic_record"] = diag_records_metadata[0] + current_root = metadata + for i, r in enumerate(diag_records_metadata[1:]): + current_root["cause"] = { + "description": f"error cause nr. {i + 1}", + "message": f"cause message {i + 1}", + } + current_root = current_root["cause"] + if r is not ...: + current_root["diagnostic_record"] = r + return metadata + + +@pytest.mark.parametrize("ssr_hint", (True, False, None)) +@mark_sync_test +def test_ssr_enabled(ssr_hint, fake_socket_pair): + address = neo4j.Address(("127.0.0.1", 7687)) + sockets = fake_socket_pair( + address, + packer_cls=Bolt5x8.PACKER_CLS, + unpacker_cls=Bolt5x8.UNPACKER_CLS, + ) + meta = {"server": "Neo4j/4.3.4"} + if ssr_hint is not None: + meta["hints"] = {"ssr.enabled": ssr_hint} + sockets.server.send_message(b"\x70", meta) + sockets.server.send_message(b"\x70", {}) + connection = Bolt5x8( + address, sockets.client, PoolConfig.max_connection_lifetime + ) + assert connection.ssr_enabled is False + connection.hello() + assert connection.ssr_enabled is bool(ssr_hint) diff --git a/tests/unit/sync/io/test_direct.py b/tests/unit/sync/io/test_direct.py index a899ae49..64b7d9b5 100644 --- a/tests/unit/sync/io/test_direct.py +++ b/tests/unit/sync/io/test_direct.py @@ -67,6 +67,7 @@ def acquire( bookmarks, auth, liveness_check_timeout, + database_callback=None, ): return self._acquire( self.address, auth, timeout, liveness_check_timeout diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 89b4d16b..13b9be4e 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -30,6 +30,7 @@ from neo4j._deadline import Deadline from neo4j._sync.config import PoolConfig from neo4j._sync.io import ( + AcquisitionDatabase, Bolt, Neo4jPool, ) @@ -53,11 +54,27 @@ WRITER1_ADDRESS = ResolvedAddress(("1.2.3.1", 9020), host_name="host") +def make_home_db_resolve(home_db): + def _home_db_resolve(db): + return db or home_db + + return _home_db_resolve + + +_default_db_resolve = make_home_db_resolve("neo4j") + + @pytest.fixture def custom_routing_opener(fake_connection_generator, mocker): - def make_opener(failures=None, get_readers=None): + def make_opener( + failures=None, + get_readers=None, + db_resolve=_default_db_resolve, + on_open=None, + ): def routing_side_effect(*args, **kwargs): nonlocal failures + opener_.route_requests.append(kwargs.get("database")) res = next(failures, None) if res is None: routers = [ @@ -70,16 +87,18 @@ def routing_side_effect(*args, **kwargs): else: readers = [str(READER1_ADDRESS)] writers = [str(WRITER1_ADDRESS)] - return [ - { - "ttl": 1000, - "servers": [ - {"addresses": routers, "role": "ROUTE"}, - {"addresses": readers, "role": "READ"}, - {"addresses": writers, "role": "WRITE"}, - ], - } - ] + rt = { + "ttl": 1000, + "servers": [ + {"addresses": routers, "role": "ROUTE"}, + {"addresses": readers, "role": "READ"}, + {"addresses": writers, "role": "WRITE"}, + ], + } + db = db_resolve(kwargs.get("database")) + if db is not ...: + rt["db"] = db + return [rt] raise res def open_(addr, auth, timeout): @@ -92,11 +111,16 @@ def open_(addr, auth, timeout): route_mock.side_effect = routing_side_effect connection.attach_mock(route_mock, "route") opener_.connections.append(connection) + + if callable(on_open): + on_open(connection) + return connection failures 
= iter(failures or []) opener_ = mocker.MagicMock() opener_.connections = [] + opener_.route_requests = [] opener_.side_effect = open_ return opener_ @@ -124,54 +148,101 @@ def _simple_pool(opener) -> Neo4jPool: ) +TEST_DB1 = AcquisitionDatabase("test_db1") +TEST_DB2 = AcquisitionDatabase("test_db2") + + +@pytest.mark.parametrize("guessed_db", (True, False)) @mark_sync_test -def test_acquires_new_routing_table_if_deleted(opener): +def test_acquires_new_routing_table_if_deleted( + custom_routing_opener, + guessed_db, +) -> None: + db = AcquisitionDatabase("test_db", guessed=guessed_db) + opener = custom_routing_opener(db_resolve=make_home_db_resolve(db.name)) pool = _simple_pool(opener) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, db, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] + opener.route_requests = [] - del pool.routing_tables["test_db"] + del pool.routing_tables[db.name] - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, db, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] +@pytest.mark.parametrize("guessed_db", (True, False)) @mark_sync_test -def test_acquires_new_routing_table_if_stale(opener): +def test_acquires_new_routing_table_if_stale( + custom_routing_opener, + guessed_db, +) -> None: + db = AcquisitionDatabase("test_db", guessed=guessed_db) + opener = custom_routing_opener(db_resolve=make_home_db_resolve(db.name)) pool = _simple_pool(opener) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, db, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(db.name) + assert opener.route_requests == [None if guessed_db else db.name] + opener.route_requests = [] - old_value = pool.routing_tables["test_db"].last_updated_time - pool.routing_tables["test_db"].ttl = 0 + old_value = pool.routing_tables[db.name].last_updated_time + pool.routing_tables[db.name].ttl = 0 - cx = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, db, None, None, None) pool.release(cx) - assert pool.routing_tables["test_db"].last_updated_time > old_value + assert pool.routing_tables[db.name].last_updated_time > old_value + assert opener.route_requests == [None if guessed_db else db.name] @mark_sync_test def test_removes_old_routing_table(opener): pool = _simple_pool(opener) - cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db1") - cx = pool.acquire(READ_ACCESS, 30, "test_db2", None, None, None) + assert pool.routing_tables.get(TEST_DB1.name) + cx = pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) pool.release(cx) - assert pool.routing_tables.get("test_db2") + assert pool.routing_tables.get(TEST_DB2.name) - old_value = pool.routing_tables["test_db1"].last_updated_time - pool.routing_tables["test_db1"].ttl = 0 - db2_rt = pool.routing_tables["test_db2"] + old_value = pool.routing_tables[TEST_DB1.name].last_updated_time + pool.routing_tables[TEST_DB1.name].ttl = 0 + db2_rt = pool.routing_tables[TEST_DB2.name] db2_rt.ttl = 
-RoutingConfig.routing_table_purge_delay - cx = pool.acquire(READ_ACCESS, 30, "test_db1", None, None, None) + cx = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx) - assert pool.routing_tables["test_db1"].last_updated_time > old_value - assert "test_db2" not in pool.routing_tables + assert pool.routing_tables[TEST_DB1.name].last_updated_time > old_value + assert TEST_DB2.name not in pool.routing_tables + + +@pytest.mark.parametrize("guessed_db", (True, False)) +@mark_sync_test +def test_db_resolution_callback(custom_routing_opener, guessed_db): + cb_calls = [] + + def cb(db_): + nonlocal cb_calls + cb_calls.append(db_) + + db = AcquisitionDatabase("test_db", guessed=guessed_db) + home_db = "home_db" + expected_target_db = home_db if db.guessed else db.name + + opener = custom_routing_opener(db_resolve=make_home_db_resolve(home_db)) + pool = _simple_pool(opener) + cx = pool.acquire( + READ_ACCESS, 30, db, None, None, None, database_callback=cb + ) + pool.release(cx) + + assert pool.routing_tables.get(expected_target_db) + assert opener.route_requests == [None if guessed_db else db.name] + assert cb_calls == [expected_target_db] @pytest.mark.parametrize("type_", ("r", "w")) @@ -181,7 +252,7 @@ def test_chooses_right_connection_type(opener, type_): cx1 = pool.acquire( READ_ACCESS if type_ == "r" else WRITE_ACCESS, 30, - "test_db", + TEST_DB1, None, None, None, @@ -196,9 +267,9 @@ def test_chooses_right_connection_type(opener, type_): @mark_sync_test def test_reuses_connection(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1 is cx2 @@ -216,7 +287,7 @@ def break_connection(): return None pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) assert cx1 in pool.connections[cx1.unresolved_address] # simulate connection going stale (e.g. exceeding idle timeout) and then @@ -226,7 +297,7 @@ def break_connection(): if break_on_close: cx_close_mock_side_effect = cx_close_mock.side_effect cx_close_mock.side_effect = break_connection - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx2) if break_on_close: cx1.close.assert_called() @@ -241,12 +312,12 @@ def break_connection(): @mark_sync_test def test_does_not_close_stale_connections_in_use(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1 in pool.connections[cx1.unresolved_address] # simulate connection going stale (e.g. 
exceeding idle timeout) while being # in use cx1.stale.return_value = True - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 @@ -259,7 +330,7 @@ def test_does_not_close_stale_connections_in_use(opener): # it should be closed when trying to acquire the next connection cx1.close.assert_not_called() - cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx3 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 @@ -271,7 +342,7 @@ def test_does_not_close_stale_connections_in_use(opener): @mark_sync_test def test_release_resets_connections(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() pool.release(cx1) @@ -282,7 +353,7 @@ def test_release_resets_connections(opener): @mark_sync_test def test_release_does_not_resets_closed_connections(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() @@ -295,7 +366,7 @@ def test_release_does_not_resets_closed_connections(opener): @mark_sync_test def test_release_does_not_resets_defunct_connections(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() @@ -457,8 +528,8 @@ def close_side_effect(): # create pool with 2 idle connections pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) pool.release(cx2) @@ -470,7 +541,7 @@ def close_side_effect(): # unreachable cx1.stale.return_value = True - cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx3 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx3 is not cx1 assert cx3 is not cx2 @@ -479,11 +550,11 @@ def close_side_effect(): @mark_sync_test def test_failing_opener_leaves_connections_in_use_alone(opener): pool = _simple_pool(opener) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) opener.side_effect = ServiceUnavailable("Server overloaded") with pytest.raises((ServiceUnavailable, SessionExpired)): - pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert not cx1.closed() @@ -505,7 +576,7 @@ def test__acquire_new_later_without_room(opener): config = _pool_config() config.max_connection_pool_size = 1 pool = Neo4jPool(opener, config, WorkspaceConfig(), ROUTER1_ADDRESS) - _ = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + _ = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) # pool is full now assert pool.connections_reservations[READER1_ADDRESS] == 0 creator = pool._acquire_new_later(READER1_ADDRESS, None, Deadline(1)) @@ -559,13 
+630,13 @@ def test_discovery_is_retried(custom_routing_opener, error): WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host"), ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) - pool.routing_tables.get("test_db").ttl = 0 + pool.routing_tables.get(TEST_DB1.name).ttl = 0 - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx2) - assert pool.routing_tables.get("test_db") + assert pool.routing_tables.get(TEST_DB1.name) assert cx1 is cx2 @@ -611,12 +682,12 @@ def test_fast_failing_discovery(custom_routing_opener, error): WorkspaceConfig(), ResolvedAddress(("1.2.3.1", 9999), host_name="host"), ) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) - pool.routing_tables.get("test_db").ttl = 0 + pool.routing_tables.get(TEST_DB1.name).ttl = 0 with pytest.raises(error.__class__) as exc: - pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert exc.value is error @@ -657,11 +728,11 @@ def test_connection_error_callback( config.auth = auth_manager pool = Neo4jPool(opener, config, WorkspaceConfig(), ROUTER1_ADDRESS) cxs_read = [ - pool.acquire(READ_ACCESS, 30, "test_db", None, None, None) + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) for _ in range(5) ] cxs_write = [ - pool.acquire(WRITE_ACCESS, 30, "test_db", None, None, None) + pool.acquire(WRITE_ACCESS, 30, TEST_DB1, None, None, None) for _ in range(5) ] @@ -690,7 +761,7 @@ def test_connection_error_callback( @mark_sync_test def test_pool_closes_connections_dropped_from_rt(custom_routing_opener): - readers = {"db1": [str(READER1_ADDRESS)]} + readers = {TEST_DB1.name: [str(READER1_ADDRESS)]} def get_readers(database): return readers[database] @@ -700,7 +771,7 @@ def get_readers(database): pool = Neo4jPool( opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx1.unresolved_address == READER1_ADDRESS pool.release(cx1) @@ -708,10 +779,10 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 # force RT refresh, returning a different reader - del pool.routing_tables["db1"] - readers["db1"] = [str(READER2_ADDRESS)] + del pool.routing_tables[TEST_DB1.name] + readers[TEST_DB1.name] = [str(READER2_ADDRESS)] - cx2 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) assert cx2.unresolved_address == READER2_ADDRESS cx1.close.assert_called_once() @@ -726,8 +797,8 @@ def test_pool_does_not_close_connections_dropped_from_rt_for_other_server( # no custom_routing_opener, ): readers = { - "db1": [str(READER1_ADDRESS), str(READER2_ADDRESS)], - "db2": [str(READER1_ADDRESS)], + TEST_DB1.name: [str(READER1_ADDRESS), str(READER2_ADDRESS)], + TEST_DB2.name: [str(READER1_ADDRESS)], } def get_readers(database): @@ -738,14 +809,14 @@ def get_readers(database): pool = Neo4jPool( opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS ) - cx1 = pool.acquire(READ_ACCESS, 30, "db1", None, None, None) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) pool.release(cx1) assert cx1.unresolved_address in {READER1_ADDRESS, READER2_ADDRESS} 
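# cx1 may target either reader advertised for db1, so count connections on both addresses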
reader1_connection_count = len(pool.connections[READER1_ADDRESS]) reader2_connection_count = len(pool.connections[READER2_ADDRESS]) assert reader1_connection_count + reader2_connection_count == 1 - cx2 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) pool.release(cx2) assert cx2.unresolved_address == READER1_ADDRESS cx1.close.assert_not_called() @@ -754,10 +825,10 @@ def get_readers(database): assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count # force RT refresh, returning a different reader - del pool.routing_tables["db2"] - readers["db2"] = [str(READER3_ADDRESS)] + del pool.routing_tables[TEST_DB2.name] + readers[TEST_DB2.name] = [str(READER3_ADDRESS)] - cx3 = pool.acquire(READ_ACCESS, 30, "db2", None, None, None) + cx3 = pool.acquire(READ_ACCESS, 30, TEST_DB2, None, None, None) pool.release(cx3) assert cx3.unresolved_address == READER3_ADDRESS @@ -767,3 +838,79 @@ def get_readers(database): assert len(pool.connections[READER1_ADDRESS]) == 1 assert len(pool.connections[READER2_ADDRESS]) == reader2_connection_count assert len(pool.connections[READER3_ADDRESS]) == 1 + + +@mark_sync_test +def test_tracks_ssr_connection_hints(custom_routing_opener): + connection_count = 0 + + def on_open(connection): + if connection.unresolved_address in { + ROUTER1_ADDRESS, + ROUTER2_ADDRESS, + ROUTER3_ADDRESS, + }: + connection.ssr_enabled = True + return + nonlocal connection_count + connection_count += 1 + connection.ssr_enabled = connection_count != 2 + + opener = custom_routing_opener(on_open=on_open) + pool = Neo4jPool( + opener, _pool_config(), WorkspaceConfig(), ROUTER1_ADDRESS + ) + + # no connection in pool => cannot know => defensive assumption: off + assert not pool.ssr_enabled + + # open 1st reader connection (supports SSR) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert cx1.ssr_enabled # double check we got the mocking right + + assert pool.ssr_enabled + + # open 2nd reader connection (does not support SSR) + cx2 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert not cx2.ssr_enabled # double check we got the mocking right + + assert not pool.ssr_enabled + + # open 3rd reader connection (supports SSR) + cx3 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + assert cx3.ssr_enabled # double check we got the mocking right + + assert not pool.ssr_enabled + + pool.release(cx1) + pool.release(cx2) + pool.release(cx3) + + assert not pool.ssr_enabled + + cxs = [ + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + for _ in range(3) + ] + assert sum(not c.ssr_enabled for c in cxs) == 1 # double check + + for cx in (cx for cx in cxs if not cx.ssr_enabled): + cx.close() + + # after the single connection without SSR support is closed + for cx in cxs: + pool.release(cx) + + # force pool cleaning up all stale connections: + cxs = [ + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + for _ in range(3) + ] + assert all(cx.ssr_enabled for cx in cxs) # double check + + assert pool.ssr_enabled + + for cx in cxs: + pool.release(cx) + + assert pool.ssr_enabled diff --git a/tests/unit/sync/test_driver.py b/tests/unit/sync/test_driver.py index b90ebd52..e1173f6b 100644 --- a/tests/unit/sync/test_driver.py +++ b/tests/unit/sync/test_driver.py @@ -276,6 +276,7 @@ def test_driver_opens_write_session_by_default( bookmarks=mocker.ANY, auth=mocker.ANY, liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, ) tx._begin.assert_called_once_with( 
mocker.ANY, diff --git a/tests/unit/sync/test_home_db_cache.py b/tests/unit/sync/test_home_db_cache.py new file mode 100644 index 00000000..cd656676 --- /dev/null +++ b/tests/unit/sync/test_home_db_cache.py @@ -0,0 +1,293 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import math +import time +import typing as t +from datetime import ( + datetime, + timedelta, +) + +import freezegun +import pytest +import pytz + +from neo4j._conf import WorkspaceConfig +from neo4j._sync.config import PoolConfig +from neo4j._sync.home_db_cache import HomeDbCache +from neo4j._sync.io._pool import Neo4jPool +from neo4j.time import DateTime + + +if t.TYPE_CHECKING: + from neo4j._sync.home_db_cache import TKey + + +@pytest.mark.parametrize("enabled", (True, False)) +def test_key_none_is_none(enabled: bool) -> None: + assert HomeDbCache(enabled=enabled).compute_key(None, None) == (None,) + + +@pytest.mark.parametrize( + "auth", + ( + None, + {}, + {"scheme": "basic", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "credentials": "nice token"}, + {"foo": "bar"}, + ), +) +@pytest.mark.parametrize("enabled", (True, False)) +def test_key_imp_precedence_over_auth( + auth: dict | None, + enabled: bool, +) -> None: + cache = HomeDbCache(enabled=enabled) + assert cache.compute_key("bob", auth) == ("bob" if enabled else (None,)) + + +@pytest.mark.parametrize( + "auth", + ( + {}, + {"scheme": "basic", "principal": "neo4j", "credentials": "password"}, + {"scheme": "basic", "principal": "this is wrong, no password?"}, + {"scheme": "basic", "credentials": "this is wrong, no user?"}, + {"scheme": "none"}, + {"scheme": "none", "principal": "even though the scheme is none"}, + {"scheme": "kerberos", "principal": "", "credentials": "ticket"}, + {"scheme": "bearer", "credentials": "nice SSO token"}, + {"scheme": "custom", "principal": "neo4j", "credentials": "password"}, + {"scheme": "custom", "credentials": "bar", "parameters": {"oh": "hi"}}, + {"foo": "bar"}, + ), +) +def test_key_reduces_basic_auth_to_principal(auth: dict) -> None: + key = HomeDbCache().compute_key(None, auth) + if auth.get("scheme") == "basic" and "principal" in auth: + assert isinstance(key, str) + assert key == auth["principal"] + else: + assert isinstance(key, tuple) + for e in key: + assert isinstance(e, tuple) and len(e) == 2 + assert isinstance(e[0], str) + + +_NAN = float("nan") +_NOW = pytz.timezone("Europe/Stockholm").localize( + DateTime(2021, 8, 12, 12, 34, 57, 123456789) +) + + +@pytest.mark.parametrize( + ("auth1", "auth2"), + ( + ({}, {}), + ({"foo": "bar"}, {"foo": "bar"}), + ({"a": 1, "b": 2}, {"b": 2, "a": 1}), + ( + { + "scheme": "funky", + "credentials": "t0pS3cr3t!!11", + "parameters": { + "how much": 1.5, + # Note: for special values (NaN, temporal types, etc.), + # equality may rely on object identity.
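+ # (both parametrized dicts reuse the very same _NAN and _NOW objects, so identity-based equality is sufficient here)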
+ "why": "because", + "difficult": _NAN, + "also difficult 🔥": _NOW, + }, + }, + { + "parameters": { + "also difficult 🔥": _NOW, + "difficult": _NAN, + "why": "because", + "how much": 1.5, + }, + "credentials": "t0pS3cr3t!!11", + "scheme": "funky", + }, + ), + ), +) +def test_key_auth_equality(auth1: dict, auth2: dict) -> None: + cache = HomeDbCache() + key1 = cache.compute_key(None, auth1) + key2 = cache.compute_key(None, auth2) + + assert len(cache) == 0 + + cache.set(key1, "value") + assert len(cache) == 1 + assert cache.get(key1) == "value" + + cache.set(key2, "value2") + assert len(cache) == 1 + assert cache.get(key1) == "value2" + assert cache.get(key2) == "value2" + + assert key1 == key2 + + +def _assert_entries( + cache: HomeDbCache, + expected_entries: t.Collection[tuple[TKey, str]], + allow_subset: bool = False, +) -> None: + __tracebackhide__ = True + if not allow_subset: + assert len(cache) == len(expected_entries) + for key, value in expected_entries: + assert cache.get(key) == value + else: + hits = sum(cache.get(key) == value for key, value in expected_entries) + assert hits == len(cache) + + +def _force_cache_clean( + cache: HomeDbCache, + now: float | None = None, +) -> None: + cache._clean(now) + + +def test_cache_ttl() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = HomeDbCache(ttl=1) + + entries = [] + for i in range(1, 11): + time.move_to(t0 + timedelta(seconds=0.25) * (i - 1)) + + entries.append((cache.compute_key(f"{i}", None), f"value{i}")) + key, value = entries[-1] + cache.set(key, value) + + _force_cache_clean(cache) + _assert_entries(cache, entries) + + time.move_to( + t0 + timedelta(seconds=0.25) * i - timedelta(milliseconds=1) + ) + _force_cache_clean(cache) + _assert_entries(cache, entries) + + time.move_to( + t0 + timedelta(seconds=0.25) * i + timedelta(milliseconds=1) + ) + _force_cache_clean(cache) + entries = entries[-3:] + _assert_entries(cache, entries) + + +def test_cache_ttl_empty_cache() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = HomeDbCache(ttl=1) + assert len(cache) == 0 + _force_cache_clean(cache) + assert len(cache) == 0 + + time.move_to(t0 + timedelta(seconds=1, milliseconds=1)) + _force_cache_clean(cache) + assert len(cache) == 0 + + +def test_does_not_return_expired_entries() -> None: + t0 = datetime(1970, 1, 1) + with freezegun.freeze_time(t0) as time: + cache = HomeDbCache(ttl=1) + key = cache.compute_key("key", None) + value = "value" + + cache.set(cache.compute_key("key", None), value) + assert cache.get(key) == value + + time.move_to(t0 + timedelta(seconds=1, milliseconds=1)) + assert cache.get(key) is None + + +def test_cache_max_size() -> None: + cache = HomeDbCache(max_size=4) + + entries = [] + for i in range(1, 11): + entries.append((cache.compute_key(f"{i}", None), f"value{i}")) + entries = entries[-4:] + key, value = entries[-1] + cache.set(key, value) + + _force_cache_clean(cache) + _assert_entries(cache, entries, allow_subset=True) + + +def test_cache_max_size_empty_cache() -> None: + cache = HomeDbCache(max_size=1) + assert len(cache) == 0 + _force_cache_clean(cache) + assert len(cache) == 0 + + +def test_clean_up_time() -> None: + def get_default_cache(): + pool = Neo4jPool( + lambda: None, PoolConfig(), WorkspaceConfig(), None + ) + return pool.home_db_cache + + repetitions = 5 + scenario_timings = [] + + # Test assumes that by default the driver uses a home db cache only limited + # by its size. 
+ default_cache = get_default_cache() + default_max_size = default_cache._max_size + assert isinstance(default_max_size, int) + # If ttl ever gets used, this test needs to be updated to also test pruning + # by TTL. + assert math.isinf(default_cache._ttl) and default_cache._ttl > 0 + + for max_size, count in ( + # no pruning needed + (default_max_size * 10, default_max_size * 10), + # pruning needed + (default_max_size, default_max_size * 10), + ): + cache = HomeDbCache(max_size=max_size) + keys = [cache.compute_key(f"key{i}", None) for i in range(count)] + rep_timings = [] + for _ in range(repetitions): + t0 = time.perf_counter() + for key in keys: + cache.set(key, "value") + t1 = time.perf_counter() + rep_timings.append(t1 - t0) + scenario_timings.append(sum(rep_timings) / len(rep_timings)) + + # pruning shouldn't take more than 20 times the time of no pruning + # N.B., the pruning takes O(n * log(n)) where n is max_size. By only + # pruning O(n * log(n)) elements, we get an amortized pruning overhead of + # O(1) (as long as max_size is large enough to be able to choose a positive + # pruning size). + assert scenario_timings[1] <= 20 * scenario_timings[0] diff --git a/tests/unit/sync/work/test_result.py b/tests/unit/sync/work/test_result.py index 623d5014..65249303 100644 --- a/tests/unit/sync/work/test_result.py +++ b/tests/unit/sync/work/test_result.py @@ -60,7 +60,10 @@ Neo4jWarning, ) -from ...._async_compat import mark_sync_test +from ...._async_compat import ( + mark_sync_test, + wrap_async, +) if t.TYPE_CHECKING: @@ -315,7 +318,7 @@ def fetch_and_compare_all_records( @mark_sync_test def test_result_iteration(method, records): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, 2, None, noop, noop) + result = Result(connection, 2, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) fetch_and_compare_all_records(result, "x", records, method) @@ -324,7 +327,7 @@ def test_result_iteration_mixed_methods(): records = [[i] for i in range(10)] connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, 4, None, noop, noop) + result = Result(connection, 4, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) iter1 = Util.iter(result) iter2 = Util.iter(result) @@ -372,9 +375,9 @@ def test_parallel_result_iteration(method, invert_fetch): connection = ConnectionStub( records=(Records(["x"], records1), Records(["x"], records2)) ) - result1 = Result(connection, 2, None, noop, noop) + result1 = Result(connection, 2, None, noop, noop, None) result1._run("CYPHER1", {}, None, None, "r", None, None, None) - result2 = Result(connection, 2, None, noop, noop) + result2 = Result(connection, 2, None, noop, noop, None) result2._run("CYPHER2", {}, None, None, "r", None, None, None) if invert_fetch: fetch_and_compare_all_records(result2, "x", records2, method) @@ -395,9 +398,9 @@ def test_interwoven_result_iteration(method, invert_fetch): connection = ConnectionStub( records=(Records(["x"], records1), Records(["y"], records2)) ) - result1 = Result(connection, 2, None, noop, noop) + result1 = Result(connection, 2, None, noop, noop, None) result1._run("CYPHER1", {}, None, None, "r", None, None, None) - result2 = Result(connection, 2, None, noop, noop) + result2 = Result(connection, 2, None, noop, noop, None) result2._run("CYPHER2", {}, None, None, "r", None, None, None) start = 0 for n in (1, 2, 3, 1, None): @@ -424,7 +427,7 @@ def
test_interwoven_result_iteration(method, invert_fetch): @mark_sync_test def test_result_peek(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) for i in range(len(records) + 1): record = result.peek() @@ -447,7 +450,7 @@ def test_result_single_non_strict(records, fetch_size, default): kwargs["strict"] = False connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) if len(records) == 0: assert result.single(**kwargs) is None @@ -466,7 +469,7 @@ def test_result_single_non_strict(records, fetch_size, default): @mark_sync_test def test_result_single_strict(records, fetch_size): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) if len(records) != 1: with pytest.raises(ResultNotSingleError) as exc: @@ -490,7 +493,7 @@ def test_result_single_strict(records, fetch_size): @mark_sync_test def test_result_single_exhausts_records(records, fetch_size, strict): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) try: with warnings.catch_warnings(): @@ -512,7 +515,7 @@ def test_result_single_exhausts_records(records, fetch_size, strict): @mark_sync_test def test_result_fetch(records, fetch_size, strict): connection = ConnectionStub(records=Records(["x"], records)) - result = Result(connection, fetch_size, None, noop, noop) + result = Result(connection, fetch_size, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) assert result.fetch(0) == [] assert result.fetch(-1) == [] @@ -524,7 +527,7 @@ def test_result_fetch(records, fetch_size, strict): @mark_sync_test def test_keys_are_available_before_and_after_stream(): connection = ConnectionStub(records=Records(["x"], [[1], [2]])) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) assert list(result.keys()) == ["x"] Util.list(result) @@ -540,7 +543,7 @@ def test_consume(records, consume_one, summary_meta, consume_times): connection = ConnectionStub( records=Records(["x"], records), summary_meta=summary_meta ) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) if consume_one: with suppress(StopIteration): @@ -574,7 +577,7 @@ def test_time_in_summary(t_first, t_last): summary_meta=summary_meta, ) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) summary = result.consume() @@ -596,7 +599,7 @@ def test_time_in_summary(t_first, t_last): def test_counts_in_summary(): connection = ConnectionStub(records=Records(["n"], [[1], [2]])) - result = Result(connection, 1, None, noop, noop) + result = 
Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) summary = result.consume() @@ -610,7 +613,7 @@ def test_query_type(query_type): records=Records(["n"], [[1], [2]]), summary_meta={"type": query_type} ) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) summary = result.consume() @@ -625,7 +628,7 @@ def test_data(num_records): records=Records(["n"], [[i + 1] for i in range(num_records)]) ) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) result._buffer_all() records = result._record_buffer.copy() @@ -667,7 +670,7 @@ def test_data(num_records): @mark_sync_test def test_result_graph(records): connection = ConnectionStub(records=records) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) graph = result.graph() assert isinstance(graph, Graph) @@ -760,7 +763,7 @@ def test_result_graph(records): def test_to_eager_result(records): summary = {"test_to_eager_result": uuid.uuid4()} connection = ConnectionStub(records=records, summary_meta=summary) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) eager_result = result.to_eager_result() @@ -850,7 +853,7 @@ def test_to_eager_result(records): @mark_sync_test def test_to_df(keys, values, types, instances, test_default_expand): connection = ConnectionStub(records=Records(keys, values)) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) if test_default_expand: df = result.to_df() @@ -1061,7 +1064,7 @@ def test_to_df_expand( keys, values, expected_columns, expected_rows, expected_types ): connection = ConnectionStub(records=Records(keys, values)) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) df = result.to_df(expand=True) @@ -1299,7 +1302,7 @@ def test_to_df_expand( @mark_sync_test def test_to_df_parse_dates(keys, values, expected_df, expand): connection = ConnectionStub(records=Records(keys, values)) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) df = result.to_df(expand=expand, parse_dates=True) @@ -1314,7 +1317,7 @@ def test_broken_hydration(nested): value_in = [value_in] records_in = Records(["foo", "bar"], [["foobar", value_in]]) connection = ConnectionStub(records=records_in) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) result._run("CYPHER", {}, None, None, "r", None, None, None) records_out = Util.list(result) assert len(records_out) == 1 @@ -1366,7 +1369,9 @@ def test_notification_warning( ] }, ) - result = Result(connection, 1, warn_notification_severity, noop, noop) + result = Result( + connection, 1, warn_notification_severity, noop, noop, None + ) if expected_warning is None: with warnings.catch_warnings(): warnings.simplefilter("error") # assert not warnings are emitted @@ -1408,7 +1413,7 @@ def 
test_notification_logging( records=Records(["foo"], ()), summary_meta={"notifications": [notification_data]}, ) - result = Result(connection, 1, None, noop, noop) + result = Result(connection, 1, None, noop, noop, None) with caplog.at_level(logging.INFO, logger="neo4j.notifications"): result._run("CYPHER", {}, None, None, "r", None, None, None) result.consume() @@ -1420,3 +1425,35 @@ def test_notification_logging( f"Received notification from DBMS server: {formatted_notification}" ) assert caplog.messages[0] == expected_message + + +@pytest.mark.parametrize( + "cb", + (True, False) if Util.is_async_code else (False,), +) +@pytest.mark.parametrize("resolved_db", (..., None, "resolved_db")) +@mark_sync_test +def test_on_database_callback(cb, resolved_db): + cb_calls = [] + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + if cb: + db_callback = wrap_async(db_callback) + + run_meta = {} + if resolved_db is not ...: + run_meta["db"] = resolved_db + connection = ConnectionStub( + records=Records(["foo"], ()), run_meta=run_meta + ) + + result = Result(connection, 1, None, noop, noop, db_callback) + result._run("CYPHER", {}, None, None, "r", None, None, None) + + if resolved_db in {..., None}: + assert cb_calls == [] + else: + assert cb_calls == [resolved_db] diff --git a/tests/unit/sync/work/test_session.py b/tests/unit/sync/work/test_session.py index 52843be1..94543cde 100644 --- a/tests/unit/sync/work/test_session.py +++ b/tests/unit/sync/work/test_session.py @@ -19,6 +19,7 @@ import pytest from neo4j import ( + Auth, Bookmarks, ManagedTransaction, Session, @@ -26,8 +27,12 @@ unit_of_work, ) from neo4j._api import TelemetryAPI +from neo4j._async_compat.util import Util +from neo4j._auth_management import to_auth_dict from neo4j._conf import SessionConfig +from neo4j._sync.home_db_cache import HomeDbCache from neo4j._sync.io import ( + AcquisitionDatabase, BoltPool, Neo4jPool, ) @@ -490,8 +495,10 @@ def bmm_get_bookmarks(): fake_pool.update_routing_table.side_effect = ( update_routing_table_side_effect ) + fake_pool.is_direct_pool = False else: fake_pool.mock_add_spec(BoltPool) + fake_pool.is_direct_pool = True config = SessionConfig() config.bookmark_manager = bmm @@ -699,3 +706,174 @@ def work(_): connection_mock.telemetry.assert_called_once() call_args = connection_mock.telemetry.call_args.args assert call_args[0] == TelemetryAPI.DRIVER + + +@pytest.mark.parametrize( + ("db", "pool_ssr", "pool_routing", "expect_cache_usage"), + ( + (db, ssr, routing, ssr and routing and not db) + for ssr in (True, False) + for routing in (True, False) + for db in (None, "mydb") + ), +) +@pytest.mark.parametrize("imp_user", (None, "imp_user")) +@pytest.mark.parametrize( + "auth", + ( + None, + Auth(scheme="magic-auth", principal=None, credentials="tada"), + ), +) +@mark_sync_test +def test_uses_home_db_cache_when_expected( + fake_pool, + mocker, + db, + pool_ssr, + pool_routing, + expect_cache_usage, + imp_user, + auth, +): + fake_pool.ssr_enabled = pool_ssr + if pool_routing: + fake_pool.is_direct_pool = False + fake_pool.mock_add_spec(Neo4jPool) + cache_spy = mocker.Mock(spec=HomeDbCache, wraps=HomeDbCache()) + cached_db = "nice_cached_home_db" + key = object() + cache_spy.compute_key.return_value = key + cache_spy.get.return_value = cached_db + fake_pool.home_db_cache = cache_spy + + config = SessionConfig() + config.impersonated_user = imp_user + config.auth = auth + config.database = db + + with Session(fake_pool, config) as session: + session.run("RETURN 1") + + if expect_cache_usage: 
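+ # cache usage is expected only with a routing pool, SSR enabled, and no explicitly configured database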
+ # assert using cache + assert cache_spy.mock_calls == [ + mocker.call.compute_key( + imp_user, to_auth_dict(auth) if auth else None + ), + mocker.call.get(key), + ] + # assert passing cache result as a guess to the pool + fake_pool.acquire.assert_called_once_with( + access_mode=mocker.ANY, + timeout=mocker.ANY, + database=AcquisitionDatabase(cached_db, guessed=True), + bookmarks=mocker.ANY, + auth=mocker.ANY, + liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, + ) + else: + # assert not using cache + cache_spy.get.assert_not_called() + # assert passing a non-guess to the pool + fake_pool.acquire.assert_called_once_with( + access_mode=mocker.ANY, + timeout=mocker.ANY, + database=AcquisitionDatabase(db, guessed=False), + bookmarks=mocker.ANY, + auth=mocker.ANY, + liveness_check_timeout=mocker.ANY, + database_callback=mocker.ANY, + ) + + +@pytest.mark.parametrize( + ("db", "pool_ssr", "pool_routing", "expect_cache_usage"), + ( + (db, ssr, routing, ssr and routing and not db) + for ssr in (True, False) + for routing in (True, False) + for db in (None, "mydb") + ), +) +@pytest.mark.parametrize("resolution_at", ("route", "run", "begin")) +@mark_sync_test +def test_pins_session_db_with_cache( + fake_pool, + mocker, + db, + pool_ssr, + pool_routing, + expect_cache_usage, + resolution_at, +): + def resolve_db(): + if resolution_at == "route": + database_callback = fake_pool.acquire.call_args.kwargs[ + "database_callback" + ] + Util.callback(database_callback, resolved_db) + elif resolution_at == "run": + database_callback = res_mock.call_args.args[-1] + Util.callback(database_callback, resolved_db) + elif resolution_at == "begin": + database_callback = tx_mock.call_args.args[-1] + Util.callback(database_callback, resolved_db) + else: + raise ValueError(f"Unknown resolution_at: {resolution_at}") + + if resolution_at == "run": + res_mock = mocker.patch( + "neo4j._sync.work.session.Result", autospec=True + ) + elif resolution_at == "begin": + tx_mock = mocker.patch( + "neo4j._sync.work.session.Transaction", autospec=True + ) + + resolved_db = "resolved_db" + fake_pool.ssr_enabled = pool_ssr + if pool_routing: + fake_pool.is_direct_pool = False + fake_pool.mock_add_spec(Neo4jPool) + cache_spy = mocker.Mock(spec=HomeDbCache, wraps=HomeDbCache()) + key = object() + cache_spy.compute_key.return_value = key + fake_pool.home_db_cache = cache_spy + + config = SessionConfig() + config.database = db + + with Session(fake_pool, config) as session: + if resolution_at == "begin": + with session.begin_transaction() as tx: + tx.run("RETURN 1") + else: + session.run("RETURN 1") + + if expect_cache_usage: + # assert never using cache to pin a database + assert not session._pinned_database + assert config.database == db + + resolve_db() + + assert session._pinned_database + assert config.database == resolved_db + cache_spy.set.assert_called_once_with(key, resolved_db) + else: + if not pool_routing or db: + assert session._pinned_database + assert config.database == db + + resolve_db() + + if not pool_routing or db: + assert session._pinned_database + assert config.database == db + cache_spy.set.assert_not_called() + else: + cache_spy.set.assert_called_once_with(key, resolved_db) + assert session._pinned_database + assert config.database == resolved_db diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 53aeba1e..b78768a2 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -12,8 +12,6 @@ # WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - from unittest.mock import MagicMock import pytest @@ -52,7 +50,7 @@ def test_transaction_context_when_committing( on_error = mocker.MagicMock() on_cancel = mocker.Mock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -88,7 +86,7 @@ def test_transaction_context_with_explicit_rollback( on_error = mocker.MagicMock() on_cancel = mocker.Mock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -120,7 +118,7 @@ class OopsError(RuntimeError): on_error = MagicMock() on_cancel = MagicMock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) mock_commit = mocker.patch.object(tx, "_commit", wraps=tx._commit) mock_rollback = mocker.patch.object(tx, "_rollback", wraps=tx._rollback) @@ -141,7 +139,7 @@ def test_transaction_run_takes_no_query_object(fake_connection): on_error = MagicMock() on_cancel = MagicMock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) with pytest.raises(ValueError): tx.run(Query("RETURN 1")) @@ -165,7 +163,7 @@ def test_transaction_run_parameters( on_error = MagicMock() on_cancel = MagicMock() tx = Transaction( - fake_connection, 2, None, on_closed, on_error, on_cancel + fake_connection, 2, None, on_closed, on_error, on_cancel, None ) if not as_kwargs: params = {"parameters": params} @@ -187,7 +185,9 @@ def test_transaction_run_parameters( def test_transaction_rollbacks_on_open_connections( fake_connection, ): - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) with tx as tx_: fake_connection.is_reset_mock.return_value = False fake_connection.is_reset_mock.reset_mock() @@ -201,7 +201,9 @@ def test_transaction_rollbacks_on_open_connections( def test_transaction_no_rollback_on_reset_connections( fake_connection, ): - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) with tx as tx_: fake_connection.is_reset_mock.return_value = True fake_connection.is_reset_mock.reset_mock() @@ -215,7 +217,9 @@ def test_transaction_no_rollback_on_reset_connections( def test_transaction_no_rollback_on_closed_connections( fake_connection, ): - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) with tx as tx_: fake_connection.closed.return_value = True fake_connection.closed.reset_mock() @@ -231,7 +235,9 @@ def test_transaction_no_rollback_on_closed_connections( def test_transaction_no_rollback_on_defunct_connections( fake_connection, ): - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) with tx as tx_: fake_connection.defunct.return_value = 
True fake_connection.defunct.reset_mock() @@ -246,9 +252,13 @@ def test_transaction_no_rollback_on_defunct_connections( @pytest.mark.parametrize("pipeline", (True, False)) @mark_sync_test def test_transaction_begin_pipelining( - fake_connection, pipeline + fake_connection, + pipeline, + mocker, ) -> None: - tx = Transaction(fake_connection, 2, None, noop, noop, noop) + tx = Transaction( + fake_connection, 2, None, noop, noop, noop, None + ) database = "db" imp_user = None bookmarks = ["bookmark1", "bookmark2"] @@ -283,6 +293,7 @@ def test_transaction_begin_pipelining( "notifications_disabled_classifications": ( notifications_disabled_classifications ), + "on_success": mocker.ANY, }, ), ] @@ -333,7 +344,7 @@ def test_server_error_propagates(scripted_connection, error): raise ValueError(f"Unknown error type {error}") connection.set_script(script) - tx = Transaction(connection, 2, None, noop, noop, noop) + tx = Transaction(connection, 2, None, noop, noop, noop, None) res1 = tx.run("UNWIND range(1, 1000) AS n RETURN n") assert res1.__next__() == {"n": 1} @@ -349,3 +360,45 @@ def test_server_error_propagates(scripted_connection, error): res1.__next__() assert exc1.value is exc2.value.__cause__ + + +@pytest.mark.parametrize("cb", (True, False)) +@pytest.mark.parametrize("resolved_db", (..., None, "resolved_db")) +@mark_sync_test +def test_on_database_callback( + scripted_connection, cb, resolved_db +): + cb_calls = [] + + if cb: + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + else: + + def db_callback(db): + nonlocal cb_calls + cb_calls.append(db) + + begin_meta = {} + if resolved_db is not ...: + begin_meta["db"] = resolved_db + connection = scripted_connection + connection.set_script( + [ + ("begin", {"on_success": (begin_meta,), "on_summary": None}), + ] + ) + + result = Transaction( + connection, 1, None, noop, noop, noop, db_callback + ) + result._begin( + None, None, None, None, None, None, None, None, pipelined=False + ) + + if resolved_db in {..., None}: + assert cb_calls == [] + else: + assert cb_calls == [resolved_db]
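For reference, the home-db cache behavior pinned down by the tests above can be sketched in isolation. A minimal sketch, assuming the driver-internal import path the tests use (neo4j._sync.home_db_cache); this illustrates the tested semantics and is not public driver API:

from neo4j._sync.home_db_cache import HomeDbCache

# Keys are derived from the impersonated user or the auth token;
# basic-auth tokens reduce to their principal.
cache = HomeDbCache(max_size=100)
auth = {"scheme": "basic", "principal": "neo4j", "credentials": "secret"}
key = cache.compute_key(None, auth)
assert key == "neo4j"

# First session: cache miss, so the session acquires without a home-db guess
# and the server-resolved home db is reported back via database_callback.
assert cache.get(key) is None
cache.set(key, "mydb")

# Later sessions with the same principal start from the cached guess.
assert cache.get(key) == "mydb"

A cache hit is still only a guess (AcquisitionDatabase(..., guessed=True)); as the session tests show, the session is pinned to a database only once the server confirms the resolution.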