Skip to content

Commit 994da43

Browse files
committed
Add unittest for PubSub reconnect
1 parent 5c99e27 commit 994da43

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

tests/test_asyncio/test_pubsub.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import functools
23
import sys
34
from typing import Optional
45

@@ -20,6 +21,18 @@
2021
pytestmark = pytest.mark.asyncio(forbid_global_loop=True)
2122

2223

24+
def with_timeout(t):
25+
def wrapper(corofunc):
26+
@functools.wraps(corofunc)
27+
async def run(*args, **kwargs):
28+
async with async_timeout.timeout(t):
29+
return await corofunc(*args, **kwargs)
30+
31+
return run
32+
33+
return wrapper
34+
35+
2336
async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
2437
now = asyncio.get_event_loop().time()
2538
timeout = now + timeout
@@ -603,6 +616,76 @@ async def test_get_message_with_timeout_returns_none(self, r: redis.Redis):
603616
assert await p.get_message(timeout=0.01) is None
604617

605618

619+
@pytest.mark.onlynoncluster
620+
class TestPubSubReconnect:
621+
# @pytest.mark.xfail
622+
@with_timeout(2)
623+
async def test_reconnect_listen(self, r: redis.Redis):
624+
"""
625+
Test that a loop processing PubSub messages can survive
626+
a disconnect, by issuing a connect() call.
627+
"""
628+
messages = asyncio.Queue()
629+
pubsub = r.pubsub()
630+
interrupt = False
631+
632+
async def loop():
633+
# must make sure the task exits
634+
async with async_timeout.timeout(2):
635+
nonlocal interrupt
636+
await pubsub.subscribe("foo")
637+
while True:
638+
# print("loop")
639+
try:
640+
try:
641+
await pubsub.connect()
642+
await loop_step()
643+
# print("succ")
644+
except redis.ConnectionError:
645+
err = True
646+
# print("err")
647+
await asyncio.sleep(0.1)
648+
except asyncio.CancelledError:
649+
# we use a cancel to interrupt the "listen" when we perform a disconnect
650+
# print("cancel", interrupt)
651+
if interrupt:
652+
interrupt = False
653+
else:
654+
raise
655+
656+
async def loop_step():
657+
# get a single message via listen()
658+
async for message in pubsub.listen():
659+
await messages.put(message)
660+
break
661+
662+
task = asyncio.get_event_loop().create_task(loop())
663+
# get the initial connect message
664+
async with async_timeout.timeout(1):
665+
message = await messages.get()
666+
assert message == {
667+
"channel": b"foo",
668+
"data": 1,
669+
"pattern": None,
670+
"type": "subscribe",
671+
}
672+
# now, disconnect the connection.
673+
await pubsub.connection.disconnect()
674+
interrupt = True
675+
task.cancel() # interrupt the listen call
676+
# await another auto-connect message
677+
message = await messages.get()
678+
assert message == {
679+
"channel": b"foo",
680+
"data": 1,
681+
"pattern": None,
682+
"type": "subscribe",
683+
}
684+
task.cancel()
685+
with pytest.raises(asyncio.CancelledError):
686+
await task
687+
688+
606689
@pytest.mark.onlynoncluster
607690
class TestPubSubRun:
608691
async def _subscribe(self, p, *args, **kwargs):

0 commit comments

Comments
 (0)