@@ -159,11 +159,11 @@ def get_protocol(self):
159
159
pass
160
160
161
161
@abstractmethod
162
- def connect (self ):
162
+ def connect (self , check_health = True ):
163
163
pass
164
164
165
165
@abstractmethod
166
- def on_connect (self ):
166
+ def on_connect (self , check_health = True ):
167
167
pass
168
168
169
169
@abstractmethod
@@ -370,13 +370,14 @@ def set_parser(self, parser_class):
370
370
"""
371
371
self ._parser = parser_class (socket_read_size = self ._socket_read_size )
372
372
373
- def connect (self ):
373
+ def connect (self , check_health : bool = True ):
374
374
"Connects to the Redis server if not already connected"
375
375
if self ._sock :
376
376
return
377
377
try :
378
378
sock = self .retry .call_with_retry (
379
- lambda : self ._connect (), lambda error : self .disconnect (error )
379
+ lambda : self ._connect (),
380
+ lambda error : self .disconnect (error ),
380
381
)
381
382
except socket .timeout :
382
383
raise TimeoutError ("Timeout connecting to server" )
@@ -387,7 +388,7 @@ def connect(self):
387
388
try :
388
389
if self .redis_connect_func is None :
389
390
# Use the default on_connect function
390
- self .on_connect ()
391
+ self .on_connect (check_health = check_health )
391
392
else :
392
393
# Use the passed function redis_connect_func
393
394
self .redis_connect_func (self )
@@ -416,7 +417,7 @@ def _host_error(self):
416
417
def _error_message (self , exception ):
417
418
return format_error_message (self ._host_error (), exception )
418
419
419
- def on_connect (self ):
420
+ def on_connect (self , check_health : bool = True ):
420
421
"Initialize the connection, authenticate and select a database"
421
422
self ._parser .on_connect (self )
422
423
parser = self ._parser
@@ -475,7 +476,7 @@ def on_connect(self):
475
476
# update cluster exception classes
476
477
self ._parser .EXCEPTION_CLASSES = parser .EXCEPTION_CLASSES
477
478
self ._parser .on_connect (self )
478
- self .send_command ("HELLO" , self .protocol )
479
+ self .send_command ("HELLO" , self .protocol , check_health = check_health )
479
480
self .handshake_metadata = self .read_response ()
480
481
if (
481
482
self .handshake_metadata .get (b"proto" ) != self .protocol
@@ -485,28 +486,45 @@ def on_connect(self):
485
486
486
487
# if a client_name is given, set it
487
488
if self .client_name :
488
- self .send_command ("CLIENT" , "SETNAME" , self .client_name )
489
+ self .send_command (
490
+ "CLIENT" ,
491
+ "SETNAME" ,
492
+ self .client_name ,
493
+ check_health = check_health ,
494
+ )
489
495
if str_if_bytes (self .read_response ()) != "OK" :
490
496
raise ConnectionError ("Error setting client name" )
491
497
492
498
try :
493
499
# set the library name and version
494
500
if self .lib_name :
495
- self .send_command ("CLIENT" , "SETINFO" , "LIB-NAME" , self .lib_name )
501
+ self .send_command (
502
+ "CLIENT" ,
503
+ "SETINFO" ,
504
+ "LIB-NAME" ,
505
+ self .lib_name ,
506
+ check_health = check_health ,
507
+ )
496
508
self .read_response ()
497
509
except ResponseError :
498
510
pass
499
511
500
512
try :
501
513
if self .lib_version :
502
- self .send_command ("CLIENT" , "SETINFO" , "LIB-VER" , self .lib_version )
514
+ self .send_command (
515
+ "CLIENT" ,
516
+ "SETINFO" ,
517
+ "LIB-VER" ,
518
+ self .lib_version ,
519
+ check_health = check_health ,
520
+ )
503
521
self .read_response ()
504
522
except ResponseError :
505
523
pass
506
524
507
525
# if a database is specified, switch to it
508
526
if self .db :
509
- self .send_command ("SELECT" , self .db )
527
+ self .send_command ("SELECT" , self .db , check_health = check_health )
510
528
if str_if_bytes (self .read_response ()) != "OK" :
511
529
raise ConnectionError ("Invalid Database" )
512
530
@@ -548,7 +566,7 @@ def check_health(self):
548
566
def send_packed_command (self , command , check_health = True ):
549
567
"""Send an already packed command to the Redis server"""
550
568
if not self ._sock :
551
- self .connect ()
569
+ self .connect (check_health = False )
552
570
# guard against health check recursion
553
571
if check_health :
554
572
self .check_health ()
@@ -587,7 +605,7 @@ def can_read(self, timeout=0):
587
605
"""Poll the socket to see if there's data that can be read."""
588
606
sock = self ._sock
589
607
if not sock :
590
- self .connect ()
608
+ self .connect (check_health = True )
591
609
592
610
host_error = self ._host_error ()
593
611
@@ -806,8 +824,8 @@ def deregister_connect_callback(self, callback):
806
824
def set_parser (self , parser_class ):
807
825
self ._conn .set_parser (parser_class )
808
826
809
- def connect (self ):
810
- self ._conn .connect ()
827
+ def connect (self , check_health : bool = True ):
828
+ self ._conn .connect (check_health = check_health )
811
829
812
830
server_name = self ._conn .handshake_metadata .get (b"server" , None )
813
831
if server_name is None :
@@ -829,8 +847,8 @@ def connect(self):
829
847
"To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501
830
848
)
831
849
832
- def on_connect (self ):
833
- self ._conn .on_connect ()
850
+ def on_connect (self , check_health : bool = True ):
851
+ self ._conn .on_connect (check_health = check_health )
834
852
835
853
def disconnect (self , * args ):
836
854
with self ._cache_lock :
@@ -1482,7 +1500,7 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
1482
1500
1483
1501
try :
1484
1502
# ensure this connection is connected to Redis
1485
- connection .connect ()
1503
+ connection .connect (check_health = True )
1486
1504
# connections that the pool provides should be ready to send
1487
1505
# a command. if not, the connection was either returned to the
1488
1506
# pool before all data has been read or the socket has been
@@ -1492,7 +1510,7 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection":
1492
1510
raise ConnectionError ("Connection has data" )
1493
1511
except (ConnectionError , OSError ):
1494
1512
connection .disconnect ()
1495
- connection .connect ()
1513
+ connection .connect (check_health = True )
1496
1514
if connection .can_read ():
1497
1515
raise ConnectionError ("Connection not ready" )
1498
1516
except BaseException :
@@ -1729,7 +1747,7 @@ def get_connection(self, command_name=None, *keys, **options):
1729
1747
1730
1748
try :
1731
1749
# ensure this connection is connected to Redis
1732
- connection .connect ()
1750
+ connection .connect (check_health = True )
1733
1751
# connections that the pool provides should be ready to send
1734
1752
# a command. if not, the connection was either returned to the
1735
1753
# pool before all data has been read or the socket has been
@@ -1739,7 +1757,7 @@ def get_connection(self, command_name=None, *keys, **options):
1739
1757
raise ConnectionError ("Connection has data" )
1740
1758
except (ConnectionError , OSError ):
1741
1759
connection .disconnect ()
1742
- connection .connect ()
1760
+ connection .connect (check_health = True )
1743
1761
if connection .can_read ():
1744
1762
raise ConnectionError ("Connection not ready" )
1745
1763
except BaseException :
0 commit comments