|
1 | 1 | import asyncio
|
| 2 | +import functools |
2 | 3 | import sys
|
3 | 4 | from typing import Optional
|
4 | 5 |
|
|
20 | 21 | pytestmark = pytest.mark.asyncio(forbid_global_loop=True)
|
21 | 22 |
|
22 | 23 |
|
| 24 | +def with_timeout(t): |
| 25 | + def wrapper(corofunc): |
| 26 | + @functools.wraps(corofunc) |
| 27 | + async def run(*args, **kwargs): |
| 28 | + async with async_timeout.timeout(t): |
| 29 | + return await corofunc(*args, **kwargs) |
| 30 | + |
| 31 | + return run |
| 32 | + |
| 33 | + return wrapper |
| 34 | + |
| 35 | + |
23 | 36 | async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False):
|
24 | 37 | now = asyncio.get_event_loop().time()
|
25 | 38 | timeout = now + timeout
|
@@ -603,6 +616,76 @@ async def test_get_message_with_timeout_returns_none(self, r: redis.Redis):
|
603 | 616 | assert await p.get_message(timeout=0.01) is None
|
604 | 617 |
|
605 | 618 |
|
| 619 | +@pytest.mark.onlynoncluster |
| 620 | +class TestPubSubReconnect: |
| 621 | + # @pytest.mark.xfail |
| 622 | + @with_timeout(2) |
| 623 | + async def test_reconnect_listen(self, r: redis.Redis): |
| 624 | + """ |
| 625 | + Test that a loop processing PubSub messages can survive |
| 626 | + a disconnect, by issuing a connect() call. |
| 627 | + """ |
| 628 | + messages = asyncio.Queue() |
| 629 | + pubsub = r.pubsub() |
| 630 | + interrupt = False |
| 631 | + |
| 632 | + async def loop(): |
| 633 | + # must make sure the task exits |
| 634 | + async with async_timeout.timeout(2): |
| 635 | + nonlocal interrupt |
| 636 | + await pubsub.subscribe("foo") |
| 637 | + while True: |
| 638 | + # print("loop") |
| 639 | + try: |
| 640 | + try: |
| 641 | + await pubsub.connect() |
| 642 | + await loop_step() |
| 643 | + # print("succ") |
| 644 | + except redis.ConnectionError: |
| 645 | + err = True |
| 646 | + # print("err") |
| 647 | + await asyncio.sleep(0.1) |
| 648 | + except asyncio.CancelledError: |
| 649 | + # we use a cancel to interrupt the "listen" when we perform a disconnect |
| 650 | + # print("cancel", interrupt) |
| 651 | + if interrupt: |
| 652 | + interrupt = False |
| 653 | + else: |
| 654 | + raise |
| 655 | + |
| 656 | + async def loop_step(): |
| 657 | + # get a single message via listen() |
| 658 | + async for message in pubsub.listen(): |
| 659 | + await messages.put(message) |
| 660 | + break |
| 661 | + |
| 662 | + task = asyncio.get_event_loop().create_task(loop()) |
| 663 | + # get the initial connect message |
| 664 | + async with async_timeout.timeout(1): |
| 665 | + message = await messages.get() |
| 666 | + assert message == { |
| 667 | + "channel": b"foo", |
| 668 | + "data": 1, |
| 669 | + "pattern": None, |
| 670 | + "type": "subscribe", |
| 671 | + } |
| 672 | + # now, disconnect the connection. |
| 673 | + await pubsub.connection.disconnect() |
| 674 | + interrupt = True |
| 675 | + task.cancel() # interrupt the listen call |
| 676 | + # await another auto-connect message |
| 677 | + message = await messages.get() |
| 678 | + assert message == { |
| 679 | + "channel": b"foo", |
| 680 | + "data": 1, |
| 681 | + "pattern": None, |
| 682 | + "type": "subscribe", |
| 683 | + } |
| 684 | + task.cancel() |
| 685 | + with pytest.raises(asyncio.CancelledError): |
| 686 | + await task |
| 687 | + |
| 688 | + |
606 | 689 | @pytest.mark.onlynoncluster
|
607 | 690 | class TestPubSubRun:
|
608 | 691 | async def _subscribe(self, p, *args, **kwargs):
|
|
0 commit comments