Skip to content

Commit d6cb997

Browse files
authored
Fixing read race condition during pubsub (#1737)
1 parent ddc51c4 commit d6cb997

File tree

2 files changed

+102
-15
lines changed

2 files changed

+102
-15
lines changed

redis/client.py

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,18 +1288,17 @@ def __init__(
12881288
self.shard_hint = shard_hint
12891289
self.ignore_subscribe_messages = ignore_subscribe_messages
12901290
self.connection = None
1291+
self.subscribed_event = threading.Event()
12911292
# we need to know the encoding options for this connection in order
12921293
# to lookup channel and pattern names for callback handlers.
12931294
self.encoder = encoder
12941295
if self.encoder is None:
12951296
self.encoder = self.connection_pool.get_encoder()
1297+
self.health_check_response_b = self.encoder.encode(self.HEALTH_CHECK_MESSAGE)
12961298
if self.encoder.decode_responses:
12971299
self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE]
12981300
else:
1299-
self.health_check_response = [
1300-
b"pong",
1301-
self.encoder.encode(self.HEALTH_CHECK_MESSAGE),
1302-
]
1301+
self.health_check_response = [b"pong", self.health_check_response_b]
13031302
self.reset()
13041303

13051304
def __enter__(self):
@@ -1324,9 +1323,11 @@ def reset(self):
13241323
self.connection_pool.release(self.connection)
13251324
self.connection = None
13261325
self.channels = {}
1326+
self.health_check_response_counter = 0
13271327
self.pending_unsubscribe_channels = set()
13281328
self.patterns = {}
13291329
self.pending_unsubscribe_patterns = set()
1330+
self.subscribed_event.clear()
13301331

13311332
def close(self):
13321333
self.reset()
@@ -1352,7 +1353,7 @@ def on_connect(self, connection):
13521353
@property
13531354
def subscribed(self):
13541355
"Indicates if there are subscriptions to any channels or patterns"
1355-
return bool(self.channels or self.patterns)
1356+
return self.subscribed_event.is_set()
13561357

13571358
def execute_command(self, *args):
13581359
"Execute a publish/subscribe command"
@@ -1370,8 +1371,28 @@ def execute_command(self, *args):
13701371
self.connection.register_connect_callback(self.on_connect)
13711372
connection = self.connection
13721373
kwargs = {"check_health": not self.subscribed}
1374+
if not self.subscribed:
1375+
self.clean_health_check_responses()
13731376
self._execute(connection, connection.send_command, *args, **kwargs)
13741377

1378+
def clean_health_check_responses(self):
    """
    Drain health check replies that are still pending on the connection.

    Makes at most 10 attempts. Each attempt waits up to
    ``conn.socket_timeout`` for readable data and, if something arrives,
    reads one response. A health check reply decrements
    ``health_check_response_counter``; the loop stops as soon as the
    counter reaches zero or the attempts are used up.

    Raises PubSubError if a response that is not a health check reply is
    read while draining.
    """
    conn = self.connection
    for _ in range(10):
        if self.health_check_response_counter <= 0:
            # Nothing (left) to drain.
            break
        if not self._execute(conn, conn.can_read, timeout=conn.socket_timeout):
            # Nothing readable within the timeout; spend one attempt.
            continue
        response = self._execute(conn, conn.read_response)
        if not self.is_health_check_response(response):
            raise PubSubError(
                "A non health check response was cleaned by "
                "execute_command: {0}".format(response)
            )
        self.health_check_response_counter -= 1
1395+
13751396
def _disconnect_raise_connect(self, conn, error):
13761397
"""
13771398
Close the connection and raise an exception
@@ -1411,11 +1432,23 @@ def parse_response(self, block=True, timeout=0):
14111432
return None
14121433
response = self._execute(conn, conn.read_response)
14131434

1414-
if conn.health_check_interval and response == self.health_check_response:
1435+
if self.is_health_check_response(response):
14151436
# ignore the health check message as user might not expect it
1437+
self.health_check_response_counter -= 1
14161438
return None
14171439
return response
14181440

1441+
def is_health_check_response(self, response):
    """
    Check if the response is a health check response.

    If there is a subscription, redis responds to PING with a multi-bulk
    reply holding "pong" and the health check message. If there are no
    subscriptions it responds with a bulk reply that carries only the
    message. That bulk reply is compared against the plain (str)
    ``HEALTH_CHECK_MESSAGE`` when the encoder decodes responses, and
    against the encoded ``health_check_response_b`` otherwise.

    Returns True if ``response`` equals one of the expected replies,
    False otherwise.
    """
    if self.encoder.decode_responses:
        # A decoded bulk reply is a str; it can never equal the encoded
        # bytes payload, so compare against the unencoded message.
        bulk_reply = self.HEALTH_CHECK_MESSAGE
    else:
        bulk_reply = self.health_check_response_b
    return response in [
        self.health_check_response,  # If there was a subscription
        bulk_reply,  # If there wasn't
    ]
1451+
14191452
def check_health(self):
14201453
conn = self.connection
14211454
if conn is None:
@@ -1426,6 +1459,7 @@ def check_health(self):
14261459

14271460
if conn.health_check_interval and time.time() > conn.next_health_check:
14281461
conn.send_command("PING", self.HEALTH_CHECK_MESSAGE, check_health=False)
1462+
self.health_check_response_counter += 1
14291463

14301464
def _normalize_keys(self, data):
14311465
"""
@@ -1455,6 +1489,11 @@ def psubscribe(self, *args, **kwargs):
14551489
# for the reconnection.
14561490
new_patterns = self._normalize_keys(new_patterns)
14571491
self.patterns.update(new_patterns)
1492+
if not self.subscribed:
1493+
# Set the subscribed_event flag to True
1494+
self.subscribed_event.set()
1495+
# Clear the health check counter
1496+
self.health_check_response_counter = 0
14581497
self.pending_unsubscribe_patterns.difference_update(new_patterns)
14591498
return ret_val
14601499

@@ -1489,6 +1528,11 @@ def subscribe(self, *args, **kwargs):
14891528
# for the reconnection.
14901529
new_channels = self._normalize_keys(new_channels)
14911530
self.channels.update(new_channels)
1531+
if not self.subscribed:
1532+
# Set the subscribed_event flag to True
1533+
self.subscribed_event.set()
1534+
# Clear the health check counter
1535+
self.health_check_response_counter = 0
14921536
self.pending_unsubscribe_channels.difference_update(new_channels)
14931537
return ret_val
14941538

@@ -1520,6 +1564,20 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0):
15201564
before returning. Timeout should be specified as a floating point
15211565
number.
15221566
"""
1567+
if not self.subscribed:
1568+
# Wait for subscription
1569+
start_time = time.time()
1570+
if self.subscribed_event.wait(timeout) is True:
1571+
# The connection was subscribed during the timeout time frame.
1572+
# The timeout should be adjusted based on the time spent
1573+
# waiting for the subscription
1574+
time_spent = time.time() - start_time
1575+
timeout = max(0.0, timeout - time_spent)
1576+
else:
1577+
# The connection isn't subscribed to any channels or patterns,
1578+
# so no messages are available
1579+
return None
1580+
15231581
response = self.parse_response(block=False, timeout=timeout)
15241582
if response:
15251583
return self.handle_message(response, ignore_subscribe_messages)
@@ -1575,6 +1633,10 @@ def handle_message(self, response, ignore_subscribe_messages=False):
15751633
if channel in self.pending_unsubscribe_channels:
15761634
self.pending_unsubscribe_channels.remove(channel)
15771635
self.channels.pop(channel, None)
1636+
if not self.channels and not self.patterns:
1637+
# There are no subscriptions anymore, set subscribed_event flag
1638+
# to false
1639+
self.subscribed_event.clear()
15781640

15791641
if message_type in self.PUBLISH_MESSAGE_TYPES:
15801642
# if there's a message handler, invoke it

tests/test_pubsub.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import threading
33
import time
44
from unittest import mock
5+
from unittest.mock import patch
56

67
import pytest
78

@@ -348,15 +349,6 @@ def test_unicode_pattern_message_handler(self, r):
348349
"pmessage", channel, "test message", pattern=pattern
349350
)
350351

351-
def test_get_message_without_subscribe(self, r):
352-
p = r.pubsub()
353-
with pytest.raises(RuntimeError) as info:
354-
p.get_message()
355-
expect = (
356-
"connection not set: " "did you forget to call subscribe() or psubscribe()?"
357-
)
358-
assert expect in info.exconly()
359-
360352

361353
class TestPubSubAutoDecoding:
362354
"These tests only validate that we get unicode values back"
@@ -549,6 +541,39 @@ def test_get_message_with_timeout_returns_none(self, r):
549541
assert wait_for_message(p) == make_message("subscribe", "foo", 1)
550542
assert p.get_message(timeout=0.01) is None
551543

544+
def test_get_message_not_subscribed_return_none(self, r):
    """get_message() returns None while the PubSub has no subscriptions."""
    pubsub = r.pubsub()
    assert pubsub.subscribed is False
    # Neither the default call nor a short timed wait yields a message.
    assert pubsub.get_message() is None
    assert pubsub.get_message(timeout=0.1) is None
    # Force Event.wait to report "not subscribed" and check that
    # get_message() actually goes through it.
    with patch.object(threading.Event, "wait", return_value=False) as wait_mock:
        assert pubsub.get_message(timeout=0.01) is None
        assert wait_mock.called
553+
554+
def test_get_message_subscribe_during_waiting(self, r):
    """
    A get_message() call that is waiting for a subscription picks up the
    subscribe confirmation once another thread subscribes.
    """
    p = r.pubsub()
    results = []

    def poll(ps):
        # Only record what the poller observes. An AssertionError raised
        # inside a worker thread is not propagated to the test, so the
        # checks are made in the main thread after join().
        results.append(ps.get_message())
        results.append(ps.get_message(timeout=1))

    poller = threading.Thread(target=poll, args=(p,))
    poller.start()
    time.sleep(0.2)
    p.subscribe("foo")
    poller.join()

    # If the poller died early, results is short and this fails as well.
    assert results == [None, make_message("subscribe", "foo", 1)]
568+
569+
def test_get_message_wait_for_subscription_not_being_called(self, r):
    """An already subscribed PubSub never waits on the subscription event."""
    pubsub = r.pubsub()
    pubsub.subscribe("foo")
    with patch.object(threading.Event, "wait") as wait_mock:
        assert pubsub.subscribed is True
        expected = make_message("subscribe", "foo", 1)
        assert wait_for_message(pubsub) == expected
        # Reading the message must not have touched Event.wait.
        assert wait_mock.called is False
576+
552577

553578
class TestPubSubWorkerThread:
554579
@pytest.mark.skipif(

0 commit comments

Comments
 (0)