@@ -1288,18 +1288,17 @@ def __init__(
1288
1288
self .shard_hint = shard_hint
1289
1289
self .ignore_subscribe_messages = ignore_subscribe_messages
1290
1290
self .connection = None
1291
+ self .subscribed_event = threading .Event ()
1291
1292
# we need to know the encoding options for this connection in order
1292
1293
# to lookup channel and pattern names for callback handlers.
1293
1294
self .encoder = encoder
1294
1295
if self .encoder is None :
1295
1296
self .encoder = self .connection_pool .get_encoder ()
1297
+ self .health_check_response_b = self .encoder .encode (self .HEALTH_CHECK_MESSAGE )
1296
1298
if self .encoder .decode_responses :
1297
1299
self .health_check_response = ["pong" , self .HEALTH_CHECK_MESSAGE ]
1298
1300
else :
1299
- self .health_check_response = [
1300
- b"pong" ,
1301
- self .encoder .encode (self .HEALTH_CHECK_MESSAGE ),
1302
- ]
1301
+ self .health_check_response = [b"pong" , self .health_check_response_b ]
1303
1302
self .reset ()
1304
1303
1305
1304
def __enter__ (self ):
@@ -1324,9 +1323,11 @@ def reset(self):
1324
1323
self .connection_pool .release (self .connection )
1325
1324
self .connection = None
1326
1325
self .channels = {}
1326
+ self .health_check_response_counter = 0
1327
1327
self .pending_unsubscribe_channels = set ()
1328
1328
self .patterns = {}
1329
1329
self .pending_unsubscribe_patterns = set ()
1330
+ self .subscribed_event .clear ()
1330
1331
1331
1332
def close (self ):
1332
1333
self .reset ()
@@ -1352,7 +1353,7 @@ def on_connect(self, connection):
1352
1353
@property
1353
1354
def subscribed (self ):
1354
1355
"Indicates if there are subscriptions to any channels or patterns"
1355
- return bool ( self .channels or self . patterns )
1356
+ return self .subscribed_event . is_set ( )
1356
1357
1357
1358
def execute_command (self , * args ):
1358
1359
"Execute a publish/subscribe command"
@@ -1370,8 +1371,28 @@ def execute_command(self, *args):
1370
1371
self .connection .register_connect_callback (self .on_connect )
1371
1372
connection = self .connection
1372
1373
kwargs = {"check_health" : not self .subscribed }
1374
+ if not self .subscribed :
1375
+ self .clean_health_check_responses ()
1373
1376
self ._execute (connection , connection .send_command , * args , ** kwargs )
1374
1377
1378
+ def clean_health_check_responses (self ):
1379
+ """
1380
+ If any health check responses are present, clean them
1381
+ """
1382
+ ttl = 10
1383
+ conn = self .connection
1384
+ while self .health_check_response_counter > 0 and ttl > 0 :
1385
+ if self ._execute (conn , conn .can_read , timeout = conn .socket_timeout ):
1386
+ response = self ._execute (conn , conn .read_response )
1387
+ if self .is_health_check_response (response ):
1388
+ self .health_check_response_counter -= 1
1389
+ else :
1390
+ raise PubSubError (
1391
+ "A non health check response was cleaned by "
1392
+ "execute_command: {0}" .format (response )
1393
+ )
1394
+ ttl -= 1
1395
+
1375
1396
def _disconnect_raise_connect (self , conn , error ):
1376
1397
"""
1377
1398
Close the connection and raise an exception
@@ -1411,11 +1432,23 @@ def parse_response(self, block=True, timeout=0):
1411
1432
return None
1412
1433
response = self ._execute (conn , conn .read_response )
1413
1434
1414
- if conn . health_check_interval and response == self .health_check_response :
1435
+ if self .is_health_check_response ( response ) :
1415
1436
# ignore the health check message as user might not expect it
1437
+ self .health_check_response_counter -= 1
1416
1438
return None
1417
1439
return response
1418
1440
1441
+ def is_health_check_response (self , response ):
1442
+ """
1443
+ Check if the response is a health check response.
1444
+ If there are no subscriptions redis responds to PING command with a
1445
+ bulk response, instead of a multi-bulk with "pong" and the response.
1446
+ """
1447
+ return response in [
1448
+ self .health_check_response , # If there was a subscription
1449
+ self .health_check_response_b , # If there wasn't
1450
+ ]
1451
+
1419
1452
def check_health (self ):
1420
1453
conn = self .connection
1421
1454
if conn is None :
@@ -1426,6 +1459,7 @@ def check_health(self):
1426
1459
1427
1460
if conn .health_check_interval and time .time () > conn .next_health_check :
1428
1461
conn .send_command ("PING" , self .HEALTH_CHECK_MESSAGE , check_health = False )
1462
+ self .health_check_response_counter += 1
1429
1463
1430
1464
def _normalize_keys (self , data ):
1431
1465
"""
@@ -1455,6 +1489,11 @@ def psubscribe(self, *args, **kwargs):
1455
1489
# for the reconnection.
1456
1490
new_patterns = self ._normalize_keys (new_patterns )
1457
1491
self .patterns .update (new_patterns )
1492
+ if not self .subscribed :
1493
+ # Set the subscribed_event flag to True
1494
+ self .subscribed_event .set ()
1495
+ # Clear the health check counter
1496
+ self .health_check_response_counter = 0
1458
1497
self .pending_unsubscribe_patterns .difference_update (new_patterns )
1459
1498
return ret_val
1460
1499
@@ -1489,6 +1528,11 @@ def subscribe(self, *args, **kwargs):
1489
1528
# for the reconnection.
1490
1529
new_channels = self ._normalize_keys (new_channels )
1491
1530
self .channels .update (new_channels )
1531
+ if not self .subscribed :
1532
+ # Set the subscribed_event flag to True
1533
+ self .subscribed_event .set ()
1534
+ # Clear the health check counter
1535
+ self .health_check_response_counter = 0
1492
1536
self .pending_unsubscribe_channels .difference_update (new_channels )
1493
1537
return ret_val
1494
1538
@@ -1520,6 +1564,20 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0):
1520
1564
before returning. Timeout should be specified as a floating point
1521
1565
number.
1522
1566
"""
1567
+ if not self .subscribed :
1568
+ # Wait for subscription
1569
+ start_time = time .time ()
1570
+ if self .subscribed_event .wait (timeout ) is True :
1571
+ # The connection was subscribed during the timeout time frame.
1572
+ # The timeout should be adjusted based on the time spent
1573
+ # waiting for the subscription
1574
+ time_spent = time .time () - start_time
1575
+ timeout = max (0.0 , timeout - time_spent )
1576
+ else :
1577
+ # The connection isn't subscribed to any channels or patterns,
1578
+ # so no messages are available
1579
+ return None
1580
+
1523
1581
response = self .parse_response (block = False , timeout = timeout )
1524
1582
if response :
1525
1583
return self .handle_message (response , ignore_subscribe_messages )
@@ -1575,6 +1633,10 @@ def handle_message(self, response, ignore_subscribe_messages=False):
1575
1633
if channel in self .pending_unsubscribe_channels :
1576
1634
self .pending_unsubscribe_channels .remove (channel )
1577
1635
self .channels .pop (channel , None )
1636
+ if not self .channels and not self .patterns :
1637
+ # There are no subscriptions anymore, set subscribed_event flag
1638
+ # to false
1639
+ self .subscribed_event .clear ()
1578
1640
1579
1641
if message_type in self .PUBLISH_MESSAGE_TYPES :
1580
1642
# if there's a message handler, invoke it
0 commit comments