diff --git a/taskiq_redis/__init__.py b/taskiq_redis/__init__.py
index 36cec84..5a17567 100644
--- a/taskiq_redis/__init__.py
+++ b/taskiq_redis/__init__.py
@@ -4,6 +4,7 @@
     RedisAsyncResultBackend,
 )
 from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
+from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
 from taskiq_redis.schedule_source import RedisScheduleSource
 
 __all__ = [
@@ -11,5 +12,6 @@
     "RedisAsyncResultBackend",
     "ListQueueBroker",
     "PubSubBroker",
+    "ListQueueClusterBroker",
     "RedisScheduleSource",
 ]
diff --git a/taskiq_redis/redis_cluster_broker.py b/taskiq_redis/redis_cluster_broker.py
new file mode 100644
index 0000000..af2d30f
--- /dev/null
+++ b/taskiq_redis/redis_cluster_broker.py
@@ -0,0 +1,67 @@
+from typing import Any, AsyncGenerator
+
+from redis.asyncio import RedisCluster
+from taskiq.abc.broker import AsyncBroker
+from taskiq.message import BrokerMessage
+
+
+class BaseRedisClusterBroker(AsyncBroker):
+    """Base broker that works with Redis Cluster."""
+
+    def __init__(
+        self,
+        url: str,
+        queue_name: str = "taskiq",
+        max_connection_pool_size: int = 2**31,
+        **connection_kwargs: Any,
+    ) -> None:
+        """
+        Constructs a new broker.
+
+        :param url: url to redis.
+        :param queue_name: name for a list in redis.
+        :param max_connection_pool_size: maximum number of connections in pool.
+        :param connection_kwargs: additional arguments for aio-redis ConnectionPool.
+        """
+        super().__init__()
+
+        self.redis: RedisCluster[bytes] = RedisCluster.from_url(
+            url=url,
+            max_connections=max_connection_pool_size,
+            **connection_kwargs,
+        )
+
+        self.queue_name = queue_name
+
+    async def shutdown(self) -> None:
+        """Closes redis connection pool."""
+        await self.redis.aclose()  # type: ignore[attr-defined]
+        await super().shutdown()
+
+
+class ListQueueClusterBroker(BaseRedisClusterBroker):
+    """Broker that works with Redis Cluster and distributes tasks between workers."""
+
+    async def kick(self, message: BrokerMessage) -> None:
+        """
+        Put a message in a list.
+
+        This method appends a message to the list of all messages.
+
+        :param message: message to append.
+        """
+        await self.redis.lpush(self.queue_name, message.message)  # type: ignore[attr-defined]
+
+    async def listen(self) -> AsyncGenerator[bytes, None]:
+        """
+        Listen redis queue for new messages.
+
+        This function listens to the queue
+        and yields new messages if they have BrokerMessage type.
+
+        :yields: broker messages.
+        """
+        redis_brpop_data_position = 1
+        while True:
+            value = await self.redis.brpop([self.queue_name])  # type: ignore[attr-defined]
+            yield value[redis_brpop_data_position]
diff --git a/tests/test_broker.py b/tests/test_broker.py
index f664a96..813e72e 100644
--- a/tests/test_broker.py
+++ b/tests/test_broker.py
@@ -5,7 +5,7 @@
 import pytest
 from taskiq import AckableMessage, AsyncBroker, BrokerMessage
 
-from taskiq_redis import ListQueueBroker, PubSubBroker
+from taskiq_redis import ListQueueBroker, ListQueueClusterBroker, PubSubBroker
 
 
 def test_no_url_should_raise_typeerror() -> None:
@@ -96,3 +96,30 @@
     worker1_task.cancel()
     worker2_task.cancel()
     await broker.shutdown()
+
+
+@pytest.mark.anyio
+async def test_list_queue_cluster_broker(
+    valid_broker_message: BrokerMessage,
+    redis_cluster_url: str,
+) -> None:
+    """
+    Test that messages are published and read correctly by ListQueueClusterBroker.
+
+    We create a single worker that listens to the queue and send a message.
+    Expect the worker to receive the message we sent.
+    """
+    broker = ListQueueClusterBroker(
+        url=redis_cluster_url,
+        queue_name=uuid.uuid4().hex,
+    )
+    worker_task = asyncio.create_task(get_message(broker))
+    await asyncio.sleep(0.3)
+
+    await broker.kick(valid_broker_message)
+    await asyncio.sleep(0.3)
+
+    assert worker_task.done()
+    assert worker_task.result() == valid_broker_message.message
+    worker_task.cancel()
+    await broker.shutdown()