From 210ab2d767918256feb5ec85cafe0d9e09283bd5 Mon Sep 17 00:00:00 2001 From: Graeme Holliday Date: Mon, 23 Dec 2024 11:34:26 -0500 Subject: [PATCH] add optional prefix to redis keys --- taskiq_redis/redis_backend.py | 60 ++++++++++++++++++++++++----------- 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/taskiq_redis/redis_backend.py b/taskiq_redis/redis_backend.py index 104d8cc..1c46557 100644 --- a/taskiq_redis/redis_backend.py +++ b/taskiq_redis/redis_backend.py @@ -56,6 +56,7 @@ def __init__( result_px_time: Optional[int] = None, max_connection_pool_size: Optional[int] = None, serializer: Optional[TaskiqSerializer] = None, + prefix_str: Optional[str] = None, **connection_kwargs: Any, ) -> None: """ @@ -82,6 +83,7 @@ def __init__( self.keep_results = keep_results self.result_ex_time = result_ex_time self.result_px_time = result_px_time + self.prefix_str = prefix_str unavailable_conditions = any( ( @@ -99,6 +101,11 @@ def __init__( "Choose either result_ex_time or result_px_time.", ) + def _task_name(self, task_id: str) -> str: + if self.prefix_str is None: + return task_id + return f"{self.prefix_str}:{task_id}" + async def shutdown(self) -> None: """Closes redis connection.""" await self.redis_pool.disconnect() @@ -119,7 +126,7 @@ async def set_result( :param result: TaskiqResult instance. """ redis_set_params: Dict[str, Union[str, int, bytes]] = { - "name": task_id, + "name": self._task_name(task_id), "value": self.serializer.dumpb(model_dump(result)), } if self.result_ex_time: @@ -139,7 +146,7 @@ async def is_result_ready(self, task_id: str) -> bool: :returns: True if the result is ready else False. """ async with Redis(connection_pool=self.redis_pool) as redis: - return bool(await redis.exists(task_id)) + return bool(await redis.exists(self._task_name(task_id))) async def get_result( self, @@ -154,14 +161,15 @@ async def get_result( :raises ResultIsMissingError: if there is no result when trying to get it. :return: task's return value. """ + task_name = self._task_name(task_id) async with Redis(connection_pool=self.redis_pool) as redis: if self.keep_results: result_value = await redis.get( - name=task_id, + name=task_name, ) else: result_value = await redis.getdel( - name=task_id, + name=task_name, ) if result_value is None: @@ -192,7 +200,7 @@ async def set_progress( :param result: task's TaskProgress instance. """ redis_set_params: Dict[str, Union[str, int, bytes]] = { - "name": task_id + PROGRESS_KEY_SUFFIX, + "name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX, "value": self.serializer.dumpb(model_dump(progress)), } if self.result_ex_time: @@ -215,7 +223,7 @@ async def get_progress( """ async with Redis(connection_pool=self.redis_pool) as redis: result_value = await redis.get( - name=task_id + PROGRESS_KEY_SUFFIX, + name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX, ) if result_value is None: @@ -237,6 +245,7 @@ def __init__( result_ex_time: Optional[int] = None, result_px_time: Optional[int] = None, serializer: Optional[TaskiqSerializer] = None, + prefix_str: Optional[str] = None, **connection_kwargs: Any, ) -> None: """ @@ -261,6 +270,7 @@ def __init__( self.keep_results = keep_results self.result_ex_time = result_ex_time self.result_px_time = result_px_time + self.prefix_str = prefix_str unavailable_conditions = any( ( @@ -278,6 +288,11 @@ def __init__( "Choose either result_ex_time or result_px_time.", ) + def _task_name(self, task_id: str) -> str: + if self.prefix_str is None: + return task_id + return f"{self.prefix_str}:{task_id}" + async def shutdown(self) -> None: """Closes redis connection.""" await self.redis.aclose() # type: ignore[attr-defined] @@ -298,7 +313,7 @@ async def set_result( :param result: TaskiqResult instance. """ redis_set_params: Dict[str, Union[str, bytes, int]] = { - "name": task_id, + "name": self._task_name(task_id), "value": self.serializer.dumpb(model_dump(result)), } if self.result_ex_time: @@ -316,7 +331,7 @@ async def is_result_ready(self, task_id: str) -> bool: :returns: True if the result is ready else False. """ - return bool(await self.redis.exists(task_id)) # type: ignore[attr-defined] + return bool(await self.redis.exists(self._task_name(task_id))) # type: ignore[attr-defined] async def get_result( self, @@ -331,13 +346,14 @@ async def get_result( :raises ResultIsMissingError: if there is no result when trying to get it. :return: task's return value. """ + task_name = self._task_name(task_id) if self.keep_results: result_value = await self.redis.get( # type: ignore[attr-defined] - name=task_id, + name=task_name, ) else: result_value = await self.redis.getdel( # type: ignore[attr-defined] - name=task_id, + name=task_name, ) if result_value is None: @@ -368,7 +384,7 @@ async def set_progress( :param result: task's TaskProgress instance. """ redis_set_params: Dict[str, Union[str, int, bytes]] = { - "name": task_id + PROGRESS_KEY_SUFFIX, + "name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX, "value": self.serializer.dumpb(model_dump(progress)), } if self.result_ex_time: @@ -389,7 +405,7 @@ async def get_progress( :return: task's TaskProgress instance. """ result_value = await self.redis.get( # type: ignore[attr-defined] - name=task_id + PROGRESS_KEY_SUFFIX, + name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX, ) if result_value is None: @@ -414,6 +430,7 @@ def __init__( min_other_sentinels: int = 0, sentinel_kwargs: Optional[Any] = None, serializer: Optional[TaskiqSerializer] = None, + prefix_str: Optional[str] = None, **connection_kwargs: Any, ) -> None: """ @@ -443,6 +460,7 @@ def __init__( self.keep_results = keep_results self.result_ex_time = result_ex_time self.result_px_time = result_px_time + self.prefix_str = prefix_str unavailable_conditions = any( ( @@ -460,6 +478,11 @@ def __init__( "Choose either result_ex_time or result_px_time.", ) + def _task_name(self, task_id: str) -> str: + if self.prefix_str is None: + return task_id + return f"{self.prefix_str}:{task_id}" + @asynccontextmanager async def _acquire_master_conn(self) -> AsyncIterator[_Redis]: async with self.sentinel.master_for(self.master_name) as redis_conn: @@ -480,7 +503,7 @@ async def set_result( :param result: TaskiqResult instance. """ redis_set_params: Dict[str, Union[str, bytes, int]] = { - "name": task_id, + "name": self._task_name(task_id), "value": self.serializer.dumpb(model_dump(result)), } if self.result_ex_time: @@ -500,7 +523,7 @@ async def is_result_ready(self, task_id: str) -> bool: :returns: True if the result is ready else False. """ async with self._acquire_master_conn() as redis: - return bool(await redis.exists(task_id)) + return bool(await redis.exists(self._task_name(task_id))) async def get_result( self, @@ -515,14 +538,15 @@ async def get_result( :raises ResultIsMissingError: if there is no result when trying to get it. :return: task's return value. """ + task_name = self._task_name(task_id) async with self._acquire_master_conn() as redis: if self.keep_results: result_value = await redis.get( - name=task_id, + name=task_name, ) else: result_value = await redis.getdel( - name=task_id, + name=task_name, ) if result_value is None: @@ -553,7 +577,7 @@ async def set_progress( :param result: task's TaskProgress instance. """ redis_set_params: Dict[str, Union[str, int, bytes]] = { - "name": task_id + PROGRESS_KEY_SUFFIX, + "name": self._task_name(task_id) + PROGRESS_KEY_SUFFIX, "value": self.serializer.dumpb(model_dump(progress)), } if self.result_ex_time: @@ -576,7 +600,7 @@ async def get_progress( """ async with self._acquire_master_conn() as redis: result_value = await redis.get( - name=task_id + PROGRESS_KEY_SUFFIX, + name=self._task_name(task_id) + PROGRESS_KEY_SUFFIX, ) if result_value is None: