Skip to content

Commit afbed2b

Browse files
committed
fix #417: fix setting ttl on async pipeline
BEFORE: - await pipe.expire() triggers ClusterPipeline.__await__ - this triggers .initialize() which delete the pipeline commands AFTER: - pipe.expire() in case of pipeline - await pipe.expire() in case of classic client
1 parent 29da9b9 commit afbed2b

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

redisvl/index/storage.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,20 @@ async def _aget(
175175
"""Asynchronously get data from Redis using the provided client or pipeline."""
176176
raise NotImplementedError
177177

178+
@staticmethod
179+
async def _aexpire(client: AsyncRedisClientOrPipeline, key: str, ttl: int):
180+
"""Asynchronously set TTL on a key using the provided client or pipeline
181+
182+
Args:
183+
client (AsyncRedisClientOrPipeline): The async Redis client or pipeline instance.
184+
key (str): The key for which to set the TTL.
185+
ttl (int): Time-to-live in seconds for each key.
186+
"""
187+
if isinstance(client, (AsyncPipeline, AsyncClusterPipeline)):
188+
client.expire(key, ttl)
189+
else:
190+
await client.expire(key, ttl)
191+
178192
def _validate(self, obj: Dict[str, Any]) -> Dict[str, Any]:
179193
"""
180194
Validate an object against the schema using Pydantic-based validation.
@@ -490,7 +504,7 @@ async def awrite(
490504

491505
# Set TTL if provided
492506
if ttl:
493-
await pipe.expire(key, ttl)
507+
await self._aexpire(pipe, key, ttl)
494508

495509
added_keys.append(key)
496510

@@ -615,7 +629,7 @@ async def _aset(client: AsyncRedisClientOrPipeline, key: str, obj: Dict[str, Any
615629
"""Asynchronously set a hash value in Redis for the given key.
616630
617631
Args:
618-
client (AsyncClientOrPipeline): The async Redis client or pipeline instance.
632+
client (AsyncRedisClientOrPipeline): The async Redis client or pipeline instance.
619633
key (str): The key under which to store the hash.
620634
obj (Dict[str, Any]): The hash to store in Redis.
621635
"""
@@ -644,7 +658,7 @@ async def _aget(
644658
"""Asynchronously retrieve a hash value from Redis for the given key.
645659
646660
Args:
647-
client (AsyncRedisClient): The async Redis client or pipeline instance.
661+
client (AsyncRedisClientOrPipeline): The async Redis client or pipeline instance.
648662
key (str): The key for which to retrieve the hash.
649663
650664
Returns:

0 commit comments

Comments
 (0)