Commit df41114

Update cache check logic (#216)
Refactors the cache check logic into a single helper method and adjusts the check and result-processing flow to better handle `return_fields` configurations.
1 parent 973d431 commit df41114
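
For context, a minimal usage sketch of the behavior this commit targets (the cache name, Redis URL, and prompts below are illustrative placeholders, not values from this commit):

```python
from redisvl.extensions.llmcache import SemanticCache

# Assumed setup; only the check()/return_fields semantics come from this commit.
llmcache = SemanticCache(name="llmcache", redis_url="redis://localhost:6379")
llmcache.store(prompt="What is the capital city of France?", response="Paris")

# Restrict each cache hit to selected fields; the Redis key is still attached.
hits = llmcache.check(
    prompt="What is the capital city of France?",
    return_fields=["response"],
)

# A non-list return_fields is now rejected up front with a TypeError.
try:
    llmcache.check(prompt="What is the capital city of France?", return_fields="response")
except TypeError as err:
    print(f"Rejected: {err}")
```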

2 files changed: +42 −37 lines

redisvl/extensions/llmcache/semantic.py

+37 −36

@@ -1,3 +1,4 @@
+import asyncio
 from typing import Any, Dict, List, Optional
 
 from redis import Redis
@@ -341,8 +342,10 @@ def check(
                 prompt="What is the captial city of France?"
             )
         """
-        if not (prompt or vector):
+        if not any([prompt, vector]):
             raise ValueError("Either prompt or vector must be specified.")
+        if return_fields and not isinstance(return_fields, list):
+            raise TypeError("Return fields must be a list of values.")
 
         # overrides
         distance_threshold = distance_threshold or self._distance_threshold
@@ -359,25 +362,14 @@ def check(
             filter_expression=filter_expression,
         )
 
-        cache_hits: List[Dict[Any, str]] = []
-
         # Search the cache!
         cache_search_results = self._index.query(query)
-
-        for cache_search_result in cache_search_results:
-            redis_key = cache_search_result.pop("id")
-            self._refresh_ttl(redis_key)
-
-            # Create and process cache hit
-            cache_hit = CacheHit(**cache_search_result)
-            cache_hit_dict = cache_hit.to_dict()
-            # Filter down to only selected return fields if needed
-            if isinstance(return_fields, list) and len(return_fields) > 0:
-                cache_hit_dict = {
-                    k: v for k, v in cache_hit_dict.items() if k in return_fields
-                }
-            cache_hit_dict[self.redis_key_field_name] = redis_key
-            cache_hits.append(cache_hit_dict)
+        redis_keys, cache_hits = self._process_cache_results(
+            cache_search_results, return_fields  # type: ignore
+        )
+        # Extend TTL on keys
+        for key in redis_keys:
+            self._refresh_ttl(key)
 
         return cache_hits
 
@@ -431,19 +423,16 @@ async def acheck(
         """
         aindex = await self._get_async_index()
 
-        if not (prompt or vector):
+        if not any([prompt, vector]):
             raise ValueError("Either prompt or vector must be specified.")
+        if return_fields and not isinstance(return_fields, list):
+            raise TypeError("Return fields must be a list of values.")
 
         # overrides
        distance_threshold = distance_threshold or self._distance_threshold
-        return_fields = return_fields or self.return_fields
         vector = vector or await self._avectorize_prompt(prompt)
-
         self._check_vector_dims(vector)
 
-        if not isinstance(return_fields, list):
-            raise TypeError("return_fields must be a list of field names")
-
         query = RangeQuery(
             vector=vector,
             vector_field_name=self.vector_field_name,
@@ -454,24 +443,36 @@ async def acheck(
             filter_expression=filter_expression,
         )
 
-        cache_hits: List[Dict[Any, str]] = []
-
         # Search the cache!
         cache_search_results = await aindex.query(query)
+        redis_keys, cache_hits = self._process_cache_results(
+            cache_search_results, return_fields  # type: ignore
+        )
+        # Extend TTL on keys
+        asyncio.gather(*[self._async_refresh_ttl(key) for key in redis_keys])
 
-        for cache_search_result in cache_search_results:
-            key = cache_search_result["id"]
-            await self._async_refresh_ttl(key)
+        return cache_hits
 
-            # Create cache hit
+    def _process_cache_results(
+        self, cache_search_results: List[Dict[str, Any]], return_fields: List[str]
+    ):
+        redis_keys: List[str] = []
+        cache_hits: List[Dict[Any, str]] = []
+        for cache_search_result in cache_search_results:
+            # Pop the redis key from the result
+            redis_key = cache_search_result.pop("id")
+            redis_keys.append(redis_key)
+            # Create and process cache hit
             cache_hit = CacheHit(**cache_search_result)
-            cache_hit_dict = {
-                k: v for k, v in cache_hit.to_dict().items() if k in return_fields
-            }
-            cache_hit_dict["key"] = key
+            cache_hit_dict = cache_hit.to_dict()
+            # Filter down to only selected return fields if needed
+            if isinstance(return_fields, list) and len(return_fields) > 0:
+                cache_hit_dict = {
+                    k: v for k, v in cache_hit_dict.items() if k in return_fields
+                }
+            cache_hit_dict[self.redis_key_field_name] = redis_key
             cache_hits.append(cache_hit_dict)
-
-        return cache_hits
+        return redis_keys, cache_hits
 
     def store(
         self,
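
The new shared helper can be illustrated in isolation. Below is a minimal sketch of the same result-processing pattern on plain dictionaries, assuming each search result carries its Redis key under `"id"`; the standalone function name and the `entry_id` stand-in for `self.redis_key_field_name` are illustrative, not part of the commit:

```python
from typing import Any, Dict, List, Optional, Tuple

def process_cache_results(
    results: List[Dict[str, Any]],
    return_fields: Optional[List[str]] = None,
    key_field_name: str = "entry_id",  # stand-in for self.redis_key_field_name
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Split raw search results into (redis_keys, cache_hits)."""
    redis_keys: List[str] = []
    cache_hits: List[Dict[str, Any]] = []
    for result in results:
        hit = dict(result)           # copy so the caller's data is not mutated
        redis_key = hit.pop("id")    # collect the key for the TTL refresh pass
        redis_keys.append(redis_key)
        # Trim the payload only when specific return fields were requested.
        if isinstance(return_fields, list) and len(return_fields) > 0:
            hit = {k: v for k, v in hit.items() if k in return_fields}
        hit[key_field_name] = redis_key
        cache_hits.append(hit)
    return redis_keys, cache_hits

keys, hits = process_cache_results(
    [{"id": "llmcache:abc", "prompt": "hi", "response": "hello", "vector_distance": "0.1"}],
    return_fields=["response"],
)
print(keys)  # ['llmcache:abc']
print(hits)  # [{'response': 'hello', 'entry_id': 'llmcache:abc'}]
```

Returning the keys alongside the hits is what lets `check` refresh TTLs in a simple loop while `acheck` fans the refreshes out with `asyncio.gather`.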

tests/integration/test_llmcache.py

+5 −1

@@ -1,3 +1,4 @@
+import asyncio
 from collections import namedtuple
 from time import sleep, time
 
@@ -297,7 +298,7 @@ async def test_async_ttl_refresh(cache_with_ttl, vectorizer):
     await cache_with_ttl.astore(prompt, response, vector=vector)
 
     for _ in range(3):
-        sleep(1)
+        await asyncio.sleep(1)
         check_result = await cache_with_ttl.acheck(vector=vector)
 
         assert len(check_result) == 1
@@ -465,6 +466,9 @@ def test_check_invalid_input(cache):
     with pytest.raises(ValueError):
         cache.check()
 
+    with pytest.raises(TypeError):
+        cache.check(prompt="test", return_fields="bad value")
+
 
 @pytest.mark.asyncio
 async def test_async_check_invalid_input(cache):
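
Since `acheck` shares the same validation, the TypeError case presumably carries over to the async path as well. A hedged sketch of what such an async test could look like (the test name is an assumption; the `cache` fixture is assumed to provide a configured SemanticCache):

```python
import pytest

@pytest.mark.asyncio
async def test_async_check_rejects_bad_return_fields(cache):
    # Mirrors the sync TypeError case above, via the async acheck() path.
    with pytest.raises(TypeError):
        await cache.acheck(prompt="test", return_fields="bad value")
```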
