From 05cf89dfc7cc43b33a0828134f786e0ece4f75b7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jan=20Mus=C3=ADlek?=
Date: Tue, 5 Dec 2023 11:00:38 +0100
Subject: [PATCH] Add cluster schedule source

---
 taskiq_redis/__init__.py        |   6 +-
 taskiq_redis/schedule_source.py |  85 +++++++++++++++++-
 tests/test_schedule_source.py   | 150 +++++++++++++++++++++++++++++++-
 3 files changed, 238 insertions(+), 3 deletions(-)

diff --git a/taskiq_redis/__init__.py b/taskiq_redis/__init__.py
index 5a17567..b8262a1 100644
--- a/taskiq_redis/__init__.py
+++ b/taskiq_redis/__init__.py
@@ -5,7 +5,10 @@
 )
 from taskiq_redis.redis_broker import ListQueueBroker, PubSubBroker
 from taskiq_redis.redis_cluster_broker import ListQueueClusterBroker
-from taskiq_redis.schedule_source import RedisScheduleSource
+from taskiq_redis.schedule_source import (
+    RedisClusterScheduleSource,
+    RedisScheduleSource,
+)
 
 __all__ = [
     "RedisAsyncClusterResultBackend",
@@ -14,4 +17,5 @@
     "PubSubBroker",
     "ListQueueClusterBroker",
     "RedisScheduleSource",
+    "RedisClusterScheduleSource",
 ]
diff --git a/taskiq_redis/schedule_source.py b/taskiq_redis/schedule_source.py
index 57ddc61..17ed1ee 100644
--- a/taskiq_redis/schedule_source.py
+++ b/taskiq_redis/schedule_source.py
@@ -1,6 +1,6 @@
 from typing import Any, List, Optional
 
-from redis.asyncio import ConnectionPool, Redis
+from redis.asyncio import ConnectionPool, Redis, RedisCluster
 from taskiq import ScheduleSource
 from taskiq.abc.serializer import TaskiqSerializer
 from taskiq.compat import model_dump, model_validate
@@ -95,3 +95,86 @@ async def post_send(self, task: ScheduledTask) -> None:
     async def shutdown(self) -> None:
         """Shut down the schedule source."""
         await self.connection_pool.disconnect()
+
+
+class RedisClusterScheduleSource(ScheduleSource):
+    """
+    Source of schedules for redis cluster.
+
+    This class allows you to store schedules in redis.
+    Also it supports dynamic schedules.
+
+    :param url: url to redis cluster.
+    :param prefix: prefix for redis schedule keys.
+    :param buffer_size: buffer size for redis scan.
+        This is how many keys will be fetched at once.
+    :param serializer: serializer for data.
+    :param connection_kwargs: additional arguments for RedisCluster.
+    """
+
+    def __init__(
+        self,
+        url: str,
+        prefix: str = "schedule",
+        buffer_size: int = 50,
+        serializer: Optional[TaskiqSerializer] = None,
+        **connection_kwargs: Any,
+    ) -> None:
+        self.prefix = prefix
+        self.redis: RedisCluster[bytes] = RedisCluster.from_url(
+            url,
+            **connection_kwargs,
+        )
+        self.buffer_size = buffer_size
+        if serializer is None:
+            serializer = PickleSerializer()
+        self.serializer = serializer
+
+    async def delete_schedule(self, schedule_id: str) -> None:
+        """Remove schedule by id."""
+        await self.redis.delete(f"{self.prefix}:{schedule_id}")  # type: ignore[attr-defined]
+
+    async def add_schedule(self, schedule: ScheduledTask) -> None:
+        """
+        Add schedule to redis.
+
+        :param schedule: schedule to add.
+        """
+        await self.redis.set(  # type: ignore[attr-defined]
+            f"{self.prefix}:{schedule.schedule_id}",
+            self.serializer.dumpb(model_dump(schedule)),
+        )
+
+    async def get_schedules(self) -> List[ScheduledTask]:
+        """
+        Get all schedules from redis.
+
+        This method is used by scheduler to get all schedules.
+
+        :return: list of schedules.
+        """
+        schedules = []
+        buffer = []
+        async for key in self.redis.scan_iter(f"{self.prefix}:*"):  # type: ignore[attr-defined]
+            buffer.append(key)
+            if len(buffer) >= self.buffer_size:
+                # Schedule keys hash to different slots, so a plain MGET
+                # over the whole buffer would be rejected by the cluster client.
+                schedules.extend(await self.redis.mget_nonatomic(buffer))  # type: ignore[attr-defined]
+                buffer = []
+        if buffer:
+            schedules.extend(await self.redis.mget_nonatomic(buffer))  # type: ignore[attr-defined]
+        return [
+            model_validate(ScheduledTask, self.serializer.loadb(schedule))
+            for schedule in schedules
+            if schedule
+        ]
+
+    async def post_send(self, task: ScheduledTask) -> None:
+        """Delete a task after it's completed."""
+        if task.time is not None:
+            await self.delete_schedule(task.schedule_id)
+
+    async def shutdown(self) -> None:
+        """Shut down the schedule source and close cluster connections."""
+        await self.redis.aclose()  # type: ignore[attr-defined]
diff --git a/tests/test_schedule_source.py b/tests/test_schedule_source.py
index b257204..b9c1685 100644
--- a/tests/test_schedule_source.py
+++ b/tests/test_schedule_source.py
@@ -1,9 +1,10 @@
+import datetime as dt
 import uuid
 
 import pytest
 from taskiq import ScheduledTask
 
-from taskiq_redis import RedisScheduleSource
+from taskiq_redis import RedisClusterScheduleSource, RedisScheduleSource
 
 
 @pytest.mark.anyio
@@ -56,6 +57,153 @@ async def test_post_run_cron(redis_url: str) -> None:
         cron="* * * * *",
     )
     await source.add_schedule(schedule)
+    assert await source.get_schedules() == [schedule]
+    await source.post_send(schedule)
+    assert await source.get_schedules() == [schedule]
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_post_run_time(redis_url: str) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisScheduleSource(redis_url, prefix=prefix)
+    schedule = ScheduledTask(
+        task_name="test_task",
+        labels={},
+        args=[],
+        kwargs={},
+        time=dt.datetime(2000, 1, 1),
+    )
+    await source.add_schedule(schedule)
+    assert await source.get_schedules() == [schedule]
+    await source.post_send(schedule)
+    assert await source.get_schedules() == []
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_buffer(redis_url: str) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisScheduleSource(redis_url, prefix=prefix, buffer_size=1)
+    schedule1 = ScheduledTask(
+        task_name="test_task1",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    schedule2 = ScheduledTask(
+        task_name="test_task2",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    await source.add_schedule(schedule1)
+    await source.add_schedule(schedule2)
+    schedules = await source.get_schedules()
+    assert len(schedules) == 2
+    assert schedule1 in schedules
+    assert schedule2 in schedules
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_cluster_set_schedule(redis_cluster_url: str) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
+    schedule = ScheduledTask(
+        task_name="test_task",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    await source.add_schedule(schedule)
+    schedules = await source.get_schedules()
+    assert schedules == [schedule]
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_cluster_delete_schedule(redis_cluster_url: str) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
+    schedule = ScheduledTask(
+        task_name="test_task",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    await source.add_schedule(schedule)
     schedules = await source.get_schedules()
     assert schedules == [schedule]
+    await source.delete_schedule(schedule.schedule_id)
+    schedules = await source.get_schedules()
+    # Schedules are empty.
+    assert not schedules
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_cluster_post_run_cron(redis_cluster_url: str) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
+    schedule = ScheduledTask(
+        task_name="test_task",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    await source.add_schedule(schedule)
+    assert await source.get_schedules() == [schedule]
+    await source.post_send(schedule)
+    assert await source.get_schedules() == [schedule]
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_cluster_post_run_time(redis_cluster_url: str) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix)
+    schedule = ScheduledTask(
+        task_name="test_task",
+        labels={},
+        args=[],
+        kwargs={},
+        time=dt.datetime(2000, 1, 1),
+    )
+    await source.add_schedule(schedule)
+    assert await source.get_schedules() == [schedule]
+    await source.post_send(schedule)
+    assert await source.get_schedules() == []
+    await source.shutdown()
+
+
+@pytest.mark.anyio
+async def test_cluster_buffer(redis_cluster_url: str) -> None:
+    prefix = uuid.uuid4().hex
+    source = RedisClusterScheduleSource(redis_cluster_url, prefix=prefix, buffer_size=1)
+    schedule1 = ScheduledTask(
+        task_name="test_task1",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    schedule2 = ScheduledTask(
+        task_name="test_task2",
+        labels={},
+        args=[],
+        kwargs={},
+        cron="* * * * *",
+    )
+    await source.add_schedule(schedule1)
+    await source.add_schedule(schedule2)
+    schedules = await source.get_schedules()
+    assert len(schedules) == 2
+    assert schedule1 in schedules
+    assert schedule2 in schedules
     await source.shutdown()