diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index f71ec7e04d..4037040f7c 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -1,4 +1,5 @@ import asyncio +import functools import sys from typing import Optional @@ -20,6 +21,18 @@ pytestmark = pytest.mark.asyncio(forbid_global_loop=True) +def with_timeout(t): + def wrapper(corofunc): + @functools.wraps(corofunc) + async def run(*args, **kwargs): + async with async_timeout.timeout(t): + return await corofunc(*args, **kwargs) + + return run + + return wrapper + + async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): now = asyncio.get_event_loop().time() timeout = now + timeout @@ -603,6 +616,75 @@ async def test_get_message_with_timeout_returns_none(self, r: redis.Redis): assert await p.get_message(timeout=0.01) is None +@pytest.mark.onlynoncluster +class TestPubSubReconnect: + # @pytest.mark.xfail + @with_timeout(2) + async def test_reconnect_listen(self, r: redis.Redis): + """ + Test that a loop processing PubSub messages can survive + a disconnect, by issuing a connect() call. + """ + messages = asyncio.Queue() + pubsub = r.pubsub() + interrupt = False + + async def loop(): + # must make sure the task exits + async with async_timeout.timeout(2): + nonlocal interrupt + await pubsub.subscribe("foo") + while True: + # print("loop") + try: + try: + await pubsub.connect() + await loop_step() + # print("succ") + except redis.ConnectionError: + await asyncio.sleep(0.1) + except asyncio.CancelledError: + # we use a cancel to interrupt the "listen" + # when we perform a disconnect + # print("cancel", interrupt) + if interrupt: + interrupt = False + else: + raise + + async def loop_step(): + # get a single message via listen() + async for message in pubsub.listen(): + await messages.put(message) + break + + task = asyncio.get_event_loop().create_task(loop()) + # get the initial connect message + async with async_timeout.timeout(1): + message = await messages.get() + assert message == { + "channel": b"foo", + "data": 1, + "pattern": None, + "type": "subscribe", + } + # now, disconnect the connection. + await pubsub.connection.disconnect() + interrupt = True + task.cancel() # interrupt the listen call + # await another auto-connect message + message = await messages.get() + assert message == { + "channel": b"foo", + "data": 1, + "pattern": None, + "type": "subscribe", + } + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + @pytest.mark.onlynoncluster class TestPubSubRun: async def _subscribe(self, p, *args, **kwargs):