Skip to content

Commit d41a87f

Browse files
committed
Fixed infinitely recursive health checks.
1 parent 601c1aa commit d41a87f

File tree

2 files changed

+68
-33
lines changed

2 files changed

+68
-33
lines changed

redis/asyncio/connection.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:
282282
"""
283283
self._parser = parser_class(socket_read_size=self._socket_read_size)
284284

285-
async def connect(self):
285+
async def connect(self, check_health: bool = True):
286286
"""Connects to the Redis server if not already connected"""
287287
if self.is_connected:
288288
return
@@ -302,7 +302,7 @@ async def connect(self):
302302
try:
303303
if not self.redis_connect_func:
304304
# Use the default on_connect function
305-
await self.on_connect()
305+
await self.on_connect(check_health=check_health)
306306
else:
307307
# Use the passed function redis_connect_func
308308
(
@@ -339,7 +339,7 @@ def _error_message(self, exception: BaseException) -> str:
339339
def get_protocol(self):
340340
return self.protocol
341341

342-
async def on_connect(self) -> None:
342+
async def on_connect(self, check_health: bool = True) -> None:
343343
"""Initialize the connection, authenticate and select a database"""
344344
self._parser.on_connect(self)
345345
parser = self._parser
@@ -398,7 +398,7 @@ async def on_connect(self) -> None:
398398
# update cluster exception classes
399399
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
400400
self._parser.on_connect(self)
401-
await self.send_command("HELLO", self.protocol)
401+
await self.send_command("HELLO", self.protocol, check_health=check_health)
402402
response = await self.read_response()
403403
# if response.get(b"proto") != self.protocol and response.get(
404404
# "proto"
@@ -407,18 +407,35 @@ async def on_connect(self) -> None:
407407

408408
# if a client_name is given, set it
409409
if self.client_name:
410-
await self.send_command("CLIENT", "SETNAME", self.client_name)
410+
await self.send_command(
411+
"CLIENT",
412+
"SETNAME",
413+
self.client_name,
414+
check_health=check_health,
415+
)
411416
if str_if_bytes(await self.read_response()) != "OK":
412417
raise ConnectionError("Error setting client name")
413418

414419
# set the library name and version, pipeline for lower startup latency
415420
if self.lib_name:
416-
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
421+
await self.send_command(
422+
"CLIENT",
423+
"SETINFO",
424+
"LIB-NAME",
425+
self.lib_name,
426+
check_health=check_health,
427+
)
417428
if self.lib_version:
418-
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
429+
await self.send_command(
430+
"CLIENT",
431+
"SETINFO",
432+
"LIB-VER",
433+
self.lib_version,
434+
check_health=check_health,
435+
)
419436
# if a database is specified, switch to it. Also pipeline this
420437
if self.db:
421-
await self.send_command("SELECT", self.db)
438+
await self.send_command("SELECT", self.db, check_health=check_health)
422439

423440
# read responses from pipeline
424441
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
@@ -480,8 +497,8 @@ async def send_packed_command(
480497
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
481498
) -> None:
482499
if not self.is_connected:
483-
await self.connect()
484-
elif check_health:
500+
await self.connect(check_health=False)
501+
if check_health:
485502
await self.check_health()
486503

487504
try:
@@ -1134,7 +1151,7 @@ def make_connection(self):
11341151

11351152
async def ensure_connection(self, connection: AbstractConnection):
11361153
"""Ensure that the connection object is connected and valid"""
1137-
await connection.connect()
1154+
await connection.connect(check_health=True)
11381155
# connections that the pool provides should be ready to send
11391156
# a command. if not, the connection was either returned to the
11401157
# pool before all data has been read or the socket has been
@@ -1144,7 +1161,7 @@ async def ensure_connection(self, connection: AbstractConnection):
11441161
raise ConnectionError("Connection has data") from None
11451162
except (ConnectionError, OSError):
11461163
await connection.disconnect()
1147-
await connection.connect()
1164+
await connection.connect(check_health=True)
11481165
if await connection.can_read_destructive():
11491166
raise ConnectionError("Connection not ready") from None
11501167

redis/connection.py

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -159,11 +159,11 @@ def get_protocol(self):
159159
pass
160160

161161
@abstractmethod
162-
def connect(self):
162+
def connect(self, check_health=True):
163163
pass
164164

165165
@abstractmethod
166-
def on_connect(self):
166+
def on_connect(self, check_health=True):
167167
pass
168168

169169
@abstractmethod
@@ -370,13 +370,14 @@ def set_parser(self, parser_class):
370370
"""
371371
self._parser = parser_class(socket_read_size=self._socket_read_size)
372372

373-
def connect(self):
373+
def connect(self, check_health: bool = True):
374374
"Connects to the Redis server if not already connected"
375375
if self._sock:
376376
return
377377
try:
378378
sock = self.retry.call_with_retry(
379-
lambda: self._connect(), lambda error: self.disconnect(error)
379+
lambda: self._connect(),
380+
lambda error: self.disconnect(error),
380381
)
381382
except socket.timeout:
382383
raise TimeoutError("Timeout connecting to server")
@@ -387,7 +388,7 @@ def connect(self):
387388
try:
388389
if self.redis_connect_func is None:
389390
# Use the default on_connect function
390-
self.on_connect()
391+
self.on_connect(check_health=check_health)
391392
else:
392393
# Use the passed function redis_connect_func
393394
self.redis_connect_func(self)
@@ -416,7 +417,7 @@ def _host_error(self):
416417
def _error_message(self, exception):
417418
return format_error_message(self._host_error(), exception)
418419

419-
def on_connect(self):
420+
def on_connect(self, check_health: bool = True):
420421
"Initialize the connection, authenticate and select a database"
421422
self._parser.on_connect(self)
422423
parser = self._parser
@@ -475,7 +476,7 @@ def on_connect(self):
475476
# update cluster exception classes
476477
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
477478
self._parser.on_connect(self)
478-
self.send_command("HELLO", self.protocol)
479+
self.send_command("HELLO", self.protocol, check_health=check_health)
479480
self.handshake_metadata = self.read_response()
480481
if (
481482
self.handshake_metadata.get(b"proto") != self.protocol
@@ -485,28 +486,45 @@ def on_connect(self):
485486

486487
# if a client_name is given, set it
487488
if self.client_name:
488-
self.send_command("CLIENT", "SETNAME", self.client_name)
489+
self.send_command(
490+
"CLIENT",
491+
"SETNAME",
492+
self.client_name,
493+
check_health=check_health,
494+
)
489495
if str_if_bytes(self.read_response()) != "OK":
490496
raise ConnectionError("Error setting client name")
491497

492498
try:
493499
# set the library name and version
494500
if self.lib_name:
495-
self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
501+
self.send_command(
502+
"CLIENT",
503+
"SETINFO",
504+
"LIB-NAME",
505+
self.lib_name,
506+
check_health=check_health,
507+
)
496508
self.read_response()
497509
except ResponseError:
498510
pass
499511

500512
try:
501513
if self.lib_version:
502-
self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
514+
self.send_command(
515+
"CLIENT",
516+
"SETINFO",
517+
"LIB-VER",
518+
self.lib_version,
519+
check_health=check_health,
520+
)
503521
self.read_response()
504522
except ResponseError:
505523
pass
506524

507525
# if a database is specified, switch to it
508526
if self.db:
509-
self.send_command("SELECT", self.db)
527+
self.send_command("SELECT", self.db, check_health=check_health)
510528
if str_if_bytes(self.read_response()) != "OK":
511529
raise ConnectionError("Invalid Database")
512530

@@ -548,7 +566,7 @@ def check_health(self):
548566
def send_packed_command(self, command, check_health=True):
549567
"""Send an already packed command to the Redis server"""
550568
if not self._sock:
551-
self.connect()
569+
self.connect(check_health=False)
552570
# guard against health check recursion
553571
if check_health:
554572
self.check_health()
@@ -587,7 +605,7 @@ def can_read(self, timeout=0):
587605
"""Poll the socket to see if there's data that can be read."""
588606
sock = self._sock
589607
if not sock:
590-
self.connect()
608+
self.connect(check_health=True)
591609

592610
host_error = self._host_error()
593611

@@ -806,8 +824,8 @@ def deregister_connect_callback(self, callback):
806824
def set_parser(self, parser_class):
807825
self._conn.set_parser(parser_class)
808826

809-
def connect(self):
810-
self._conn.connect()
827+
def connect(self, check_health: bool = True):
828+
self._conn.connect(check_health=check_health)
811829

812830
server_name = self._conn.handshake_metadata.get(b"server", None)
813831
if server_name is None:
@@ -829,8 +847,8 @@ def connect(self):
829847
"To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501
830848
)
831849

832-
def on_connect(self):
833-
self._conn.on_connect()
850+
def on_connect(self, check_health: bool = True):
851+
self._conn.on_connect(check_health=check_health)
834852

835853
def disconnect(self, *args):
836854
with self._cache_lock:
@@ -1482,7 +1500,7 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
14821500

14831501
try:
14841502
# ensure this connection is connected to Redis
1485-
connection.connect()
1503+
connection.connect(check_health=True)
14861504
# connections that the pool provides should be ready to send
14871505
# a command. if not, the connection was either returned to the
14881506
# pool before all data has been read or the socket has been
@@ -1492,7 +1510,7 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
14921510
raise ConnectionError("Connection has data")
14931511
except (ConnectionError, OSError):
14941512
connection.disconnect()
1495-
connection.connect()
1513+
connection.connect(check_health=True)
14961514
if connection.can_read():
14971515
raise ConnectionError("Connection not ready")
14981516
except BaseException:
@@ -1729,7 +1747,7 @@ def get_connection(self, command_name=None, *keys, **options):
17291747

17301748
try:
17311749
# ensure this connection is connected to Redis
1732-
connection.connect()
1750+
connection.connect(check_health=True)
17331751
# connections that the pool provides should be ready to send
17341752
# a command. if not, the connection was either returned to the
17351753
# pool before all data has been read or the socket has been
@@ -1739,7 +1757,7 @@ def get_connection(self, command_name=None, *keys, **options):
17391757
raise ConnectionError("Connection has data")
17401758
except (ConnectionError, OSError):
17411759
connection.disconnect()
1742-
connection.connect()
1760+
connection.connect(check_health=True)
17431761
if connection.can_read():
17441762
raise ConnectionError("Connection not ready")
17451763
except BaseException:

0 commit comments

Comments
 (0)