diff --git a/nats/aio/client.py b/nats/aio/client.py
index c8a9bfe1..35d8cad6 100644
--- a/nats/aio/client.py
+++ b/nats/aio/client.py
@@ -1367,6 +1367,15 @@ async def _select_next_server(self) -> None:
             if self.options["max_reconnect_attempts"] > 0:
                 if s.reconnects > self.options["max_reconnect_attempts"]:
                     # Discard server since already tried to reconnect too many times
+                    # Give up once every server still in the pool is over the limit too.
+                    # all() is True for an empty pool, which covers "no servers left".
+                    if all(
+                        srv.reconnects > self.options["max_reconnect_attempts"]
+                        for srv in self._server_pool
+                    ):
+                        raise errors.MaxReconnectAttemptsExceededError(
+                            self.options["max_reconnect_attempts"]
+                        )
                     continue
 
             # Not yet exceeded max_reconnect_attempts so can still use
diff --git a/nats/errors.py b/nats/errors.py
index a5d29287..fe7af179 100644
--- a/nats/errors.py
+++ b/nats/errors.py
@@ -196,3 +196,15 @@ def __init__(self, msg=None) -> None:
 
     def __str__(self) -> str:
         return f"nats: message was already acknowledged: {self._msg}"
+
+
+class MaxReconnectAttemptsExceededError(Error):
+    """Raised when all servers in the pool exceeded max_reconnect_attempts."""
+
+    def __init__(self, max_attempts: int | None = None) -> None:
+        self.max_attempts = max_attempts
+
+    def __str__(self) -> str:
+        if self.max_attempts is not None:
+            return f"nats: maximum reconnection attempts exceeded: {self.max_attempts}"
+        return "nats: maximum reconnection attempts exceeded"
diff --git a/tests/test_client.py b/tests/test_client.py
index e16ee09a..e61bbac7 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -1625,6 +1625,72 @@ async def worker_handler(msg):
         self.assertEqual(1, len(errors))
         self.assertTrue(type(errors[0]) is nats.errors.UnexpectedEOF)
 
+    @async_test
+    async def test_max_reconnect_attempts_exceeded(self):
+        nc = NATS()
+
+        disconnected_count = 0
+        closed_count = 0
+        reconnected_count = 0
+        errors = []
+
+        async def disconnected_cb():
+            nonlocal disconnected_count
+            disconnected_count += 1
+
+        async def closed_cb():
+            nonlocal closed_count
+            closed_count += 1
+
+        async def reconnected_cb():
+            nonlocal reconnected_count
+            reconnected_count += 1
+
+        async def err_cb(e):
+            errors.append(e)
+
+        options = {
+            "servers": [
+                "nats://127.0.0.1:4222",
+                "nats://127.0.0.1:4223",
+            ],
+            "disconnected_cb": disconnected_cb,
+            "closed_cb": closed_cb,
+            "reconnected_cb": reconnected_cb,
+            "error_cb": err_cb,
+            "max_reconnect_attempts": 2,
+            "reconnect_time_wait": 0.1,
+        }
+        await nc.connect(**options)
+
+        # Stop all servers to force reconnection attempts. Await the executor
+        # futures so the servers are really down (and stop() errors surface)
+        # before we start waiting on the client.
+        loop = asyncio.get_running_loop()
+        await asyncio.gather(
+            *(loop.run_in_executor(None, server.stop) for server in self.server_pool)
+        )
+
+        # Wait (bounded) for the client to exhaust all reconnection attempts;
+        # polling avoids depending on one fixed sleep being long enough.
+        for _ in range(100):
+            if nc.is_closed:
+                break
+            await asyncio.sleep(0.1)
+
+        # Should have raised MaxReconnectAttemptsExceededError
+        self.assertTrue(nc.is_closed)
+        self.assertIsNotNone(nc.last_error)
+        self.assertIsInstance(nc.last_error, nats.errors.MaxReconnectAttemptsExceededError)
+        self.assertEqual(nc.last_error.max_attempts, 2)
+
+        # Verify error was also passed to error callback
+        max_reconnect_errors = [e for e in errors if isinstance(e, nats.errors.MaxReconnectAttemptsExceededError)]
+        self.assertEqual(len(max_reconnect_errors), 1)
+        self.assertEqual(max_reconnect_errors[0].max_attempts, 2)
+
+        await nc.close()
+
 
 class ClientAuthTokenTest(MultiServerAuthTokenTestCase):
 