diff --git a/src/neo4j/_async/io/_pool.py b/src/neo4j/_async/io/_pool.py index 270bcae8..d61f6f04 100644 --- a/src/neo4j/_async/io/_pool.py +++ b/src/neo4j/_async/io/_pool.py @@ -20,6 +20,7 @@ import asyncio import logging import math +import sys from collections import ( defaultdict, deque, @@ -53,6 +54,7 @@ DriverError, Neo4jError, ReadServiceUnavailable, + RoutingServiceUnavailable, ServiceUnavailable, SessionExpired, WriteServiceUnavailable, @@ -810,6 +812,7 @@ async def fetch_routing_table( imp_user, bookmarks, auth, + ignored_errors=None, ): """ Fetch a routing table from a given router address. @@ -823,6 +826,7 @@ async def fetch_routing_table( :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table :param auth: auth + :param ignored_errors: optional list to accumulate ignored errors in :returns: a new RoutingTable instance or None if the given router is currently unable to provide routing information @@ -843,15 +847,16 @@ async def fetch_routing_table( # router. Hence, the driver should fail fast during discovery. if e._is_fatal_during_discovery(): raise - except (ServiceUnavailable, SessionExpired): - pass + if ignored_errors is not None: + ignored_errors.append(e) + except (ServiceUnavailable, SessionExpired) as e: + if ignored_errors is not None: + ignored_errors.append(e) if not new_routing_info: log.debug( "[#0000] _: failed to fetch routing info from %r", address, ) - # TODO: 7.0 - when Python 3.11+ is the minimum, - # use exception groups instead of swallowing discovery errors return None else: servers = new_routing_info[0]["servers"] @@ -876,6 +881,12 @@ async def fetch_routing_table( "server %s", address, ) + if ignored_errors is not None: + ignored_errors.append( + RoutingServiceUnavailable( + "Rejected routing table: no routers" + ) + ) return None # No readers @@ -884,6 +895,12 @@ async def fetch_routing_table( "[#0000] _: no read servers returned from server %s", address, ) + if ignored_errors is not None: + ignored_errors.append( + ReadServiceUnavailable( + "Rejected routing table: no readers" + ) + ) return None # At least one of each is fine, so return this table @@ -898,6 +915,7 @@ async def _update_routing_table_from( auth, acquisition_timeout, database_callback, + ignored_errors=None, ): """ Try to update routing tables with the given routers. @@ -924,6 +942,7 @@ async def _update_routing_table_from( imp_user=imp_user, bookmarks=bookmarks, auth=auth, + ignored_errors=ignored_errors, ) if new_routing_table is not None: new_database = new_routing_table.database @@ -973,6 +992,7 @@ async def update_routing_table( acquisition_timeout = acquisition_timeout_to_deadline( acquisition_timeout ) + errors = [] async with self.refresh_lock: routing_table = await self.get_routing_table(database) if routing_table is not None: @@ -997,6 +1017,7 @@ async def update_routing_table( auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback, + ignored_errors=errors, ) ): # Why is only the first initial routing address used? @@ -1009,6 +1030,7 @@ async def update_routing_table( auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback, + ignored_errors=errors, ): return @@ -1022,6 +1044,7 @@ async def update_routing_table( auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback, + ignored_errors=errors, ) ): # Why is only the first initial routing address used? @@ -1029,7 +1052,33 @@ async def update_routing_table( # None of the routers have been successful, so just fail log.error("Unable to retrieve routing information") - raise ServiceUnavailable("Unable to retrieve routing information") + if sys.version_info >= (3, 11): + e = ExceptionGroup( # noqa: F821 # version guard in place + "All routing table requests failed", errors + ) + else: + e = None + for error in errors: + if e is None: + e = error + continue + cause = error + seen_causes = {id(cause)} + while True: + next_cause = getattr(cause, "__cause__", None) + if next_cause is None: + break + if id(next_cause) in seen_causes: + # Avoid infinite recursion in case of circular + # references. + break + cause = next_cause + seen_causes.add(id(cause)) + cause.__cause__ = e + e = error + raise ServiceUnavailable( + "Unable to retrieve routing information" + ) from e async def update_connection_pool(self): async with self.refresh_lock: diff --git a/src/neo4j/_sync/io/_pool.py b/src/neo4j/_sync/io/_pool.py index 49380742..a94af11f 100644 --- a/src/neo4j/_sync/io/_pool.py +++ b/src/neo4j/_sync/io/_pool.py @@ -20,6 +20,7 @@ import asyncio import logging import math +import sys from collections import ( defaultdict, deque, @@ -53,6 +54,7 @@ DriverError, Neo4jError, ReadServiceUnavailable, + RoutingServiceUnavailable, ServiceUnavailable, SessionExpired, WriteServiceUnavailable, @@ -807,6 +809,7 @@ def fetch_routing_table( imp_user, bookmarks, auth, + ignored_errors=None, ): """ Fetch a routing table from a given router address. @@ -820,6 +823,7 @@ def fetch_routing_table( :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table :param auth: auth + :param ignored_errors: optional list to accumulate ignored errors in :returns: a new RoutingTable instance or None if the given router is currently unable to provide routing information @@ -840,15 +844,16 @@ def fetch_routing_table( # router. Hence, the driver should fail fast during discovery. if e._is_fatal_during_discovery(): raise - except (ServiceUnavailable, SessionExpired): - pass + if ignored_errors is not None: + ignored_errors.append(e) + except (ServiceUnavailable, SessionExpired) as e: + if ignored_errors is not None: + ignored_errors.append(e) if not new_routing_info: log.debug( "[#0000] _: failed to fetch routing info from %r", address, ) - # TODO: 7.0 - when Python 3.11+ is the minimum, - # use exception groups instead of swallowing discovery errors return None else: servers = new_routing_info[0]["servers"] @@ -873,6 +878,12 @@ def fetch_routing_table( "server %s", address, ) + if ignored_errors is not None: + ignored_errors.append( + RoutingServiceUnavailable( + "Rejected routing table: no routers" + ) + ) return None # No readers @@ -881,6 +892,12 @@ def fetch_routing_table( "[#0000] _: no read servers returned from server %s", address, ) + if ignored_errors is not None: + ignored_errors.append( + ReadServiceUnavailable( + "Rejected routing table: no readers" + ) + ) return None # At least one of each is fine, so return this table @@ -895,6 +912,7 @@ def _update_routing_table_from( auth, acquisition_timeout, database_callback, + ignored_errors=None, ): """ Try to update routing tables with the given routers. @@ -921,6 +939,7 @@ def _update_routing_table_from( imp_user=imp_user, bookmarks=bookmarks, auth=auth, + ignored_errors=ignored_errors, ) if new_routing_table is not None: new_database = new_routing_table.database @@ -970,6 +989,7 @@ def update_routing_table( acquisition_timeout = acquisition_timeout_to_deadline( acquisition_timeout ) + errors = [] with self.refresh_lock: routing_table = self.get_routing_table(database) if routing_table is not None: @@ -994,6 +1014,7 @@ def update_routing_table( auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback, + ignored_errors=errors, ) ): # Why is only the first initial routing address used? @@ -1006,6 +1027,7 @@ def update_routing_table( auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback, + ignored_errors=errors, ): return @@ -1019,6 +1041,7 @@ def update_routing_table( auth=auth, acquisition_timeout=acquisition_timeout, database_callback=database_callback, + ignored_errors=errors, ) ): # Why is only the first initial routing address used? @@ -1026,7 +1049,33 @@ def update_routing_table( # None of the routers have been successful, so just fail log.error("Unable to retrieve routing information") - raise ServiceUnavailable("Unable to retrieve routing information") + if sys.version_info >= (3, 11): + e = ExceptionGroup( # noqa: F821 # version guard in place + "All routing table requests failed", errors + ) + else: + e = None + for error in errors: + if e is None: + e = error + continue + cause = error + seen_causes = {id(cause)} + while True: + next_cause = getattr(cause, "__cause__", None) + if next_cause is None: + break + if id(next_cause) in seen_causes: + # Avoid infinite recursion in case of circular + # references. + break + cause = next_cause + seen_causes.add(id(cause)) + cause.__cause__ = e + e = error + raise ServiceUnavailable( + "Unable to retrieve routing information" + ) from e def update_connection_pool(self): with self.refresh_lock: diff --git a/tests/unit/async_/io/test_neo4j_pool.py b/tests/unit/async_/io/test_neo4j_pool.py index a758ac3c..d10abcde 100644 --- a/tests/unit/async_/io/test_neo4j_pool.py +++ b/tests/unit/async_/io/test_neo4j_pool.py @@ -16,6 +16,7 @@ import contextlib import inspect +import sys import pytest @@ -650,6 +651,68 @@ async def test_discovery_is_retried(custom_routing_opener, error): assert len(opener.connections) == 4 +@mark_async_test +async def test_failed_discovery_chains_errors(custom_routing_opener) -> None: + error1 = Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Security.AuthorizationExpired", + message="message", + ) + error2 = ServiceUnavailable("message") + error2_cause = Exception("just (be)cause") + error2.__cause__ = error2_cause + error3 = SessionExpired("message") + error4 = Neo4jError._hydrate_neo4j( + code="Neo.Made.Up.Code", + message="message", + ) + opener = custom_routing_opener( + [ + None, # first call to router for seeding the RT with more routers + error1, # router 1 fails + error2, # router 2 fails + error3, # router 3 fails + error4, # initial router fails + ] + ) + pool = AsyncNeo4jPool( + opener, + _pool_config(), + WorkspaceConfig(), + ResolvedAddress(("1.2.3.1", 9999), host_name="host"), + ) + cx1 = await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + await pool.release(cx1) + pool.routing_tables.get(TEST_DB1.name).ttl = 0 + + with pytest.raises(ServiceUnavailable) as exc_info: + await pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + + exc = exc_info.value + if sys.version_info >= (3, 11): + group = exc.__cause__ + assert isinstance(group, ExceptionGroup) # noqa: F821 + assert all( + a is b + for a, b in zip( + group.exceptions, + [error1, error2, error3, error4], + strict=True, + ) + ) + assert error4.__cause__ is None + assert error3.__cause__ is None + assert error2.__cause__ is error2_cause + assert error2_cause.__cause__ is None + assert error1.__cause__ is None + else: + assert exc.__cause__ is error4 + assert error4.__cause__ is error3 + assert error3.__cause__ is error2 + assert error2.__cause__ is error2_cause + assert error2_cause.__cause__ is error1 + assert error1.__cause__ is None + + @pytest.mark.parametrize( "error", map( diff --git a/tests/unit/sync/io/test_neo4j_pool.py b/tests/unit/sync/io/test_neo4j_pool.py index 5228f896..70c8401b 100644 --- a/tests/unit/sync/io/test_neo4j_pool.py +++ b/tests/unit/sync/io/test_neo4j_pool.py @@ -16,6 +16,7 @@ import contextlib import inspect +import sys import pytest @@ -650,6 +651,68 @@ def test_discovery_is_retried(custom_routing_opener, error): assert len(opener.connections) == 4 +@mark_sync_test +def test_failed_discovery_chains_errors(custom_routing_opener) -> None: + error1 = Neo4jError._hydrate_neo4j( + code="Neo.ClientError.Security.AuthorizationExpired", + message="message", + ) + error2 = ServiceUnavailable("message") + error2_cause = Exception("just (be)cause") + error2.__cause__ = error2_cause + error3 = SessionExpired("message") + error4 = Neo4jError._hydrate_neo4j( + code="Neo.Made.Up.Code", + message="message", + ) + opener = custom_routing_opener( + [ + None, # first call to router for seeding the RT with more routers + error1, # router 1 fails + error2, # router 2 fails + error3, # router 3 fails + error4, # initial router fails + ] + ) + pool = Neo4jPool( + opener, + _pool_config(), + WorkspaceConfig(), + ResolvedAddress(("1.2.3.1", 9999), host_name="host"), + ) + cx1 = pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + pool.release(cx1) + pool.routing_tables.get(TEST_DB1.name).ttl = 0 + + with pytest.raises(ServiceUnavailable) as exc_info: + pool.acquire(READ_ACCESS, 30, TEST_DB1, None, None, None) + + exc = exc_info.value + if sys.version_info >= (3, 11): + group = exc.__cause__ + assert isinstance(group, ExceptionGroup) # noqa: F821 + assert all( + a is b + for a, b in zip( + group.exceptions, + [error1, error2, error3, error4], + strict=True, + ) + ) + assert error4.__cause__ is None + assert error3.__cause__ is None + assert error2.__cause__ is error2_cause + assert error2_cause.__cause__ is None + assert error1.__cause__ is None + else: + assert exc.__cause__ is error4 + assert error4.__cause__ is error3 + assert error3.__cause__ is error2 + assert error2.__cause__ is error2_cause + assert error2_cause.__cause__ is error1 + assert error1.__cause__ is None + + @pytest.mark.parametrize( "error", map(