Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,15 @@ async def _select_next_server(self) -> None:
if self.options["max_reconnect_attempts"] > 0:
if s.reconnects > self.options["max_reconnect_attempts"]:
# Discard server since already tried to reconnect too many times
# Check if all remaining servers have also exceeded max reconnect attempts
if len(self._server_pool) == 0 or all(
srv.reconnects > self.options["max_reconnect_attempts"]
for srv in self._server_pool
):
# No more servers available or all have exceeded max attempts
raise errors.MaxReconnectAttemptsExceededError(
self.options["max_reconnect_attempts"]
)
continue

# Not yet exceeded max_reconnect_attempts so can still use
Expand Down
11 changes: 11 additions & 0 deletions nats/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,14 @@ def __init__(self, msg=None) -> None:

def __str__(self) -> str:
return f"nats: message was already acknowledged: {self._msg}"


class MaxReconnectAttemptsExceededError(Error):

def __init__(self, max_attempts: int | None = None) -> None:
self.max_attempts = max_attempts

def __str__(self) -> str:
if self.max_attempts is not None:
return f"nats: maximum reconnection attempts exceeded: {self.max_attempts}"
return "nats: maximum reconnection attempts exceeded"
59 changes: 59 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,65 @@ async def worker_handler(msg):
self.assertEqual(1, len(errors))
self.assertTrue(type(errors[0]) is nats.errors.UnexpectedEOF)

@async_test
async def test_max_reconnect_attempts_exceeded(self):
nc = NATS()

disconnected_count = 0
closed_count = 0
reconnected_count = 0
errors = []

async def disconnected_cb():
nonlocal disconnected_count
disconnected_count += 1

async def closed_cb():
nonlocal closed_count
closed_count += 1

async def reconnected_cb():
nonlocal reconnected_count
reconnected_count += 1

async def err_cb(e):
nonlocal errors
errors.append(e)

options = {
"servers": [
"nats://127.0.0.1:4222",
"nats://127.0.0.1:4223",
],
"disconnected_cb": disconnected_cb,
"closed_cb": closed_cb,
"reconnected_cb": reconnected_cb,
"error_cb": err_cb,
"max_reconnect_attempts": 2,
"reconnect_time_wait": 0.1,
}
await nc.connect(**options)

# Stop all servers to force reconnection attempts
for server in self.server_pool:
asyncio.get_running_loop().run_in_executor(None, server.stop)

# Wait for the client to exhaust all reconnection attempts
await asyncio.sleep(1.0)

# Should have raised MaxReconnectAttemptsExceededError
self.assertTrue(nc.is_closed)
self.assertIsNotNone(nc.last_error)
self.assertIsInstance(nc.last_error, nats.errors.MaxReconnectAttemptsExceededError)
self.assertEqual(nc.last_error.max_attempts, 2)

# Verify error was also passed to error callback
max_reconnect_errors = [e for e in errors if isinstance(e, nats.errors.MaxReconnectAttemptsExceededError)]
self.assertEqual(len(max_reconnect_errors), 1)
self.assertEqual(max_reconnect_errors[0].max_attempts, 2)

await nc.close()


class ClientAuthTokenTest(MultiServerAuthTokenTestCase):

Expand Down
Loading