Skip to content

implement redis cluster broker #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions taskiq_redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
RedisAsyncResultBackend,
)
from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
from taskiq_redis.schedule_source import RedisScheduleSource

__all__ = [
"RedisAsyncClusterResultBackend",
"RedisAsyncResultBackend",
"ListQueueBroker",
"PubSubBroker",
"ListQueueClusterBroker",
"RedisScheduleSource",
]
67 changes: 67 additions & 0 deletions taskiq_redis/redis_cluster_broker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from typing import Any, AsyncGenerator

from redis.asyncio import RedisCluster
from taskiq.abc.broker import AsyncBroker
from taskiq.message import BrokerMessage


class BaseRedisClusterBroker(AsyncBroker):
"""Base broker that works with Redis Cluster."""

def __init__(
self,
url: str,
queue_name: str = "taskiq",
max_connection_pool_size: int = 2**31,
**connection_kwargs: Any,
) -> None:
"""
Constructs a new broker.

:param url: url to redis.
:param queue_name: name for a list in redis.
:param max_connection_pool_size: maximum number of connections in pool.
:param connection_kwargs: additional arguments for aio-redis ConnectionPool.
"""
super().__init__()

self.redis: RedisCluster[bytes] = RedisCluster.from_url(
url=url,
max_connections=max_connection_pool_size,
**connection_kwargs,
)

self.queue_name = queue_name

async def shutdown(self) -> None:
"""Closes redis connection pool."""
await self.redis.aclose() # type: ignore[attr-defined]
await super().shutdown()


class ListQueueClusterBroker(BaseRedisClusterBroker):
"""Broker that works with Redis Cluster and distributes tasks between workers."""

async def kick(self, message: BrokerMessage) -> None:
"""
Put a message in a list.

This method appends a message to the list of all messages.

:param message: message to append.
"""
await self.redis.lpush(self.queue_name, message.message) # type: ignore[attr-defined]

async def listen(self) -> AsyncGenerator[bytes, None]:
"""
Listen redis queue for new messages.

This function listens to the queue
and yields new messages if they have BrokerMessage type.

:yields: broker messages.
"""
redis_brpop_data_position = 1
while True:
value = await self.redis.brpop([self.queue_name]) # type: ignore[attr-defined]
yield value[redis_brpop_data_position]
29 changes: 28 additions & 1 deletion tests/test_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from taskiq import AckableMessage, AsyncBroker, BrokerMessage

from taskiq_redis import ListQueueBroker, PubSubBroker
from taskiq_redis import ListQueueBroker, ListQueueClusterBroker, PubSubBroker


def test_no_url_should_raise_typeerror() -> None:
Expand Down Expand Up @@ -96,3 +96,30 @@ async def test_list_queue_broker(
worker1_task.cancel()
worker2_task.cancel()
await broker.shutdown()


@pytest.mark.anyio
async def test_list_queue_cluster_broker(
valid_broker_message: BrokerMessage,
redis_cluster_url: str,
) -> None:
"""
Test that messages are published and read correctly by ListQueueClusterBroker.

We create two workers that listen and send a message to them.
Expect only one worker to receive the same message we sent.
"""
broker = ListQueueClusterBroker(
url=redis_cluster_url,
queue_name=uuid.uuid4().hex,
)
worker_task = asyncio.create_task(get_message(broker))
await asyncio.sleep(0.3)

await broker.kick(valid_broker_message)
await asyncio.sleep(0.3)

assert worker_task.done()
assert worker_task.result() == valid_broker_message.message
worker_task.cancel()
await broker.shutdown()