1
+ import asyncio
1
2
from typing import Any , Dict , List , Optional
2
3
3
4
from redis import Redis
@@ -341,8 +342,10 @@ def check(
341
342
prompt="What is the captial city of France?"
342
343
)
343
344
"""
344
- if not ( prompt or vector ):
345
+ if not any ([ prompt , vector ] ):
345
346
raise ValueError ("Either prompt or vector must be specified." )
347
+ if return_fields and not isinstance (return_fields , list ):
348
+ raise TypeError ("Return fields must be a list of values." )
346
349
347
350
# overrides
348
351
distance_threshold = distance_threshold or self ._distance_threshold
@@ -359,25 +362,14 @@ def check(
359
362
filter_expression = filter_expression ,
360
363
)
361
364
362
- cache_hits : List [Dict [Any , str ]] = []
363
-
364
365
# Search the cache!
365
366
cache_search_results = self ._index .query (query )
366
-
367
- for cache_search_result in cache_search_results :
368
- redis_key = cache_search_result .pop ("id" )
369
- self ._refresh_ttl (redis_key )
370
-
371
- # Create and process cache hit
372
- cache_hit = CacheHit (** cache_search_result )
373
- cache_hit_dict = cache_hit .to_dict ()
374
- # Filter down to only selected return fields if needed
375
- if isinstance (return_fields , list ) and len (return_fields ) > 0 :
376
- cache_hit_dict = {
377
- k : v for k , v in cache_hit_dict .items () if k in return_fields
378
- }
379
- cache_hit_dict [self .redis_key_field_name ] = redis_key
380
- cache_hits .append (cache_hit_dict )
367
+ redis_keys , cache_hits = self ._process_cache_results (
368
+ cache_search_results , return_fields # type: ignore
369
+ )
370
+ # Extend TTL on keys
371
+ for key in redis_keys :
372
+ self ._refresh_ttl (key )
381
373
382
374
return cache_hits
383
375
@@ -431,19 +423,16 @@ async def acheck(
431
423
"""
432
424
aindex = await self ._get_async_index ()
433
425
434
- if not ( prompt or vector ):
426
+ if not any ([ prompt , vector ] ):
435
427
raise ValueError ("Either prompt or vector must be specified." )
428
+ if return_fields and not isinstance (return_fields , list ):
429
+ raise TypeError ("Return fields must be a list of values." )
436
430
437
431
# overrides
438
432
distance_threshold = distance_threshold or self ._distance_threshold
439
- return_fields = return_fields or self .return_fields
440
433
vector = vector or await self ._avectorize_prompt (prompt )
441
-
442
434
self ._check_vector_dims (vector )
443
435
444
- if not isinstance (return_fields , list ):
445
- raise TypeError ("return_fields must be a list of field names" )
446
-
447
436
query = RangeQuery (
448
437
vector = vector ,
449
438
vector_field_name = self .vector_field_name ,
@@ -454,24 +443,36 @@ async def acheck(
454
443
filter_expression = filter_expression ,
455
444
)
456
445
457
- cache_hits : List [Dict [Any , str ]] = []
458
-
459
446
# Search the cache!
460
447
cache_search_results = await aindex .query (query )
448
+ redis_keys , cache_hits = self ._process_cache_results (
449
+ cache_search_results , return_fields # type: ignore
450
+ )
451
+ # Extend TTL on keys
452
+ asyncio .gather (* [self ._async_refresh_ttl (key ) for key in redis_keys ])
461
453
462
- for cache_search_result in cache_search_results :
463
- key = cache_search_result ["id" ]
464
- await self ._async_refresh_ttl (key )
454
+ return cache_hits
465
455
466
- # Create cache hit
456
+ def _process_cache_results (
457
+ self , cache_search_results : List [Dict [str , Any ]], return_fields : List [str ]
458
+ ):
459
+ redis_keys : List [str ] = []
460
+ cache_hits : List [Dict [Any , str ]] = []
461
+ for cache_search_result in cache_search_results :
462
+ # Pop the redis key from the result
463
+ redis_key = cache_search_result .pop ("id" )
464
+ redis_keys .append (redis_key )
465
+ # Create and process cache hit
467
466
cache_hit = CacheHit (** cache_search_result )
468
- cache_hit_dict = {
469
- k : v for k , v in cache_hit .to_dict ().items () if k in return_fields
470
- }
471
- cache_hit_dict ["key" ] = key
467
+ cache_hit_dict = cache_hit .to_dict ()
468
+ # Filter down to only selected return fields if needed
469
+ if isinstance (return_fields , list ) and len (return_fields ) > 0 :
470
+ cache_hit_dict = {
471
+ k : v for k , v in cache_hit_dict .items () if k in return_fields
472
+ }
473
+ cache_hit_dict [self .redis_key_field_name ] = redis_key
472
474
cache_hits .append (cache_hit_dict )
473
-
474
- return cache_hits
475
+ return redis_keys , cache_hits
475
476
476
477
def store (
477
478
self ,
0 commit comments