Skip to content

Commit 704038b

Browse files
fix formatting and mypy
1 parent 51b6af3 commit 704038b

File tree

4 files changed

+78
-34
lines changed

4 files changed

+78
-34
lines changed

redisvl/extensions/llmcache/schema.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from typing import Any, Dict, List, Optional
2+
23
from pydantic.v1 import BaseModel, Field, root_validator, validator
4+
35
from redisvl.redis.utils import array_to_buffer, hashify
4-
from redisvl.utils.utils import current_timestamp, deserialize, serialize
56
from redisvl.schema import IndexSchema
7+
from redisvl.utils.utils import current_timestamp, deserialize, serialize
68

79

810
class CacheEntry(BaseModel):
9-
entry_id: str
11+
entry_id: Optional[str] = Field(default=None)
1012
prompt: str
1113
response: str
1214
prompt_vector: List[float]
@@ -103,4 +105,4 @@ def from_params(cls, name: str, prefix: str, vector_dims: int):
103105
},
104106
},
105107
],
106-
)
108+
)

redisvl/extensions/llmcache/semantic.py

+19-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
from redis import Redis
44

55
from redisvl.extensions.llmcache.base import BaseLLMCache
6-
from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit, SemanticCacheIndexSchema
6+
from redisvl.extensions.llmcache.schema import (
7+
CacheEntry,
8+
CacheHit,
9+
SemanticCacheIndexSchema,
10+
)
711
from redisvl.index import SearchIndex
812
from redisvl.query import RangeQuery
913
from redisvl.query.filter import FilterExpression, Tag
@@ -92,8 +96,13 @@ def __init__(
9296

9397
if filterable_fields is not None:
9498
for filter_field in filterable_fields:
95-
if filter_field["name"] in self.return_fields or filter_field["name"] =="key":
96-
raise ValueError(f'{filter_field["name"]} is a reserved field name for the semantic cache schema')
99+
if (
100+
filter_field["name"] in self.return_fields
101+
or filter_field["name"] == "key"
102+
):
103+
raise ValueError(
104+
f'{filter_field["name"]} is a reserved field name for the semantic cache schema'
105+
)
97106
schema.add_field(filter_field)
98107
# Add to return fields too
99108
self.return_fields.append(filter_field["name"])
@@ -285,7 +294,9 @@ def check(
285294

286295
# Create cache hit
287296
cache_hit = CacheHit(**cache_search_result)
288-
cache_hit_dict = {k: v for k, v in cache_hit.to_dict().items() if k in return_fields}
297+
cache_hit_dict = {
298+
k: v for k, v in cache_hit.to_dict().items() if k in return_fields
299+
}
289300
cache_hit_dict["key"] = key
290301
cache_hits.append(cache_hit_dict)
291302

@@ -370,7 +381,9 @@ def update(self, key: str, **kwargs) -> None:
370381
for k, v in kwargs.items():
371382

372383
# Make sure the item is in the index schema
373-
if k not in set(self._index.schema.field_names + [self.metadata_field_name]):
384+
if k not in set(
385+
self._index.schema.field_names + [self.metadata_field_name]
386+
):
374387
raise ValueError(f"{k} is not a valid field within the cache entry")
375388

376389
# Check for metadata and deserialize
@@ -384,6 +397,6 @@ def update(self, key: str, **kwargs) -> None:
384397

385398
kwargs.update({self.updated_at_field_name: current_timestamp()})
386399

387-
self._index.client.hset(key, mapping=kwargs) # type: ignore
400+
self._index.client.hset(key, mapping=kwargs) # type: ignore
388401

389402
self._refresh_ttl(key)

tests/integration/test_llmcache.py

+34-14
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from collections import namedtuple
22
from time import sleep, time
3-
from pydantic.v1 import ValidationError
4-
import pytest
53

4+
import pytest
5+
from pydantic.v1 import ValidationError
66
from redis.exceptions import ConnectionError
77

88
from redisvl.extensions.llmcache import SemanticCache
@@ -24,13 +24,14 @@ def cache(vectorizer, redis_url):
2424
yield cache_instance
2525
cache_instance._index.delete(True) # Clean up index
2626

27+
2728
@pytest.fixture
2829
def cache_with_filters(vectorizer, redis_url):
2930
cache_instance = SemanticCache(
3031
vectorizer=vectorizer,
3132
distance_threshold=0.2,
3233
filterable_fields=[{"name": "label", "type": "tag"}],
33-
redis_url=redis_url
34+
redis_url=redis_url,
3435
)
3536
yield cache_instance
3637
cache_instance._index.delete(True) # Clean up index
@@ -411,13 +412,17 @@ def test_cache_filtering(cache_with_filters):
411412
cache_with_filters.store(prompt, response, filters={"label": tags[i]})
412413

413414
# test we can specify one specific tag
414-
results = cache_with_filters.check("test prompt 1", filter_expression=filter_1, num_results=5)
415+
results = cache_with_filters.check(
416+
"test prompt 1", filter_expression=filter_1, num_results=5
417+
)
415418
assert len(results) == 1
416419
assert results[0]["prompt"] == "test prompt 0"
417420

418421
# test we can pass a list of tags
419422
combined_filter = filter_1 | filter_2 | filter_3
420-
results = cache_with_filters.check("test prompt 1", filter_expression=combined_filter, num_results=5)
423+
results = cache_with_filters.check(
424+
"test prompt 1", filter_expression=combined_filter, num_results=5
425+
)
421426
assert len(results) == 3
422427

423428
# test that default tag param searches full cache
@@ -426,7 +431,9 @@ def test_cache_filtering(cache_with_filters):
426431

427432
# test no results are returned if we pass a nonexistant tag
428433
bad_filter = Tag("label") == "bad tag"
429-
results = cache_with_filters.check("test prompt 1", filter_expression=bad_filter, num_results=5)
434+
results = cache_with_filters.check(
435+
"test prompt 1", filter_expression=bad_filter, num_results=5
436+
)
430437
assert len(results) == 0
431438

432439

@@ -436,26 +443,35 @@ def test_cache_bad_filters(vectorizer, redis_url):
436443
vectorizer=vectorizer,
437444
distance_threshold=0.2,
438445
# invalid field type
439-
filterable_fields=[{"name": "label", "type": "tag"}, {"name": "test", "type": "nothing"}],
440-
redis_url=redis_url
446+
filterable_fields=[
447+
{"name": "label", "type": "tag"},
448+
{"name": "test", "type": "nothing"},
449+
],
450+
redis_url=redis_url,
441451
)
442452

443453
with pytest.raises(ValueError):
444454
cache_instance = SemanticCache(
445455
vectorizer=vectorizer,
446456
distance_threshold=0.2,
447457
# duplicate field type
448-
filterable_fields=[{"name": "label", "type": "tag"}, {"name": "label", "type": "tag"}],
449-
redis_url=redis_url
458+
filterable_fields=[
459+
{"name": "label", "type": "tag"},
460+
{"name": "label", "type": "tag"},
461+
],
462+
redis_url=redis_url,
450463
)
451464

452465
with pytest.raises(ValueError):
453466
cache_instance = SemanticCache(
454467
vectorizer=vectorizer,
455468
distance_threshold=0.2,
456469
# reserved field name
457-
filterable_fields=[{"name": "label", "type": "tag"}, {"name": "metadata", "type": "tag"}],
458-
redis_url=redis_url
470+
filterable_fields=[
471+
{"name": "label", "type": "tag"},
472+
{"name": "metadata", "type": "tag"},
473+
],
474+
redis_url=redis_url,
459475
)
460476

461477

@@ -468,12 +484,16 @@ def test_complex_filters(cache_with_filters):
468484

469485
# test we can do range filters on inserted_at and updated_at fields
470486
range_filter = Num("inserted_at") < current_timestamp
471-
results = cache_with_filters.check("prompt 1", filter_expression=range_filter, num_results=5)
487+
results = cache_with_filters.check(
488+
"prompt 1", filter_expression=range_filter, num_results=5
489+
)
472490
assert len(results) == 2
473491

474492
# test we can combine range filters and text filters
475493
prompt_filter = Text("prompt") % "*pt 1"
476494
combined_filter = prompt_filter & range_filter
477495

478-
results = cache_with_filters.check("prompt 1", filter_expression=combined_filter, num_results=5)
496+
results = cache_with_filters.check(
497+
"prompt 1", filter_expression=combined_filter, num_results=5
498+
)
479499
assert len(results) == 1

tests/unit/test_llmcache_schema.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,51 @@
1-
import pytest
21
import json
32

3+
import pytest
44
from pydantic.v1 import ValidationError
5+
56
from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit
6-
from redisvl.redis.utils import hashify, array_to_buffer
7+
from redisvl.redis.utils import array_to_buffer, hashify
78

89

910
def test_valid_cache_entry_creation():
1011
entry = CacheEntry(
1112
prompt="What is AI?",
1213
response="AI is artificial intelligence.",
13-
prompt_vector=[0.1, 0.2, 0.3]
14+
prompt_vector=[0.1, 0.2, 0.3],
1415
)
1516
assert entry.entry_id == hashify("What is AI?")
1617
assert entry.prompt == "What is AI?"
1718
assert entry.response == "AI is artificial intelligence."
1819
assert entry.prompt_vector == [0.1, 0.2, 0.3]
1920

21+
2022
def test_cache_entry_with_given_entry_id():
2123
entry = CacheEntry(
2224
entry_id="custom_id",
2325
prompt="What is AI?",
2426
response="AI is artificial intelligence.",
25-
prompt_vector=[0.1, 0.2, 0.3]
27+
prompt_vector=[0.1, 0.2, 0.3],
2628
)
2729
assert entry.entry_id == "custom_id"
2830

31+
2932
def test_cache_entry_with_invalid_metadata():
3033
with pytest.raises(ValidationError):
3134
CacheEntry(
3235
prompt="What is AI?",
3336
response="AI is artificial intelligence.",
3437
prompt_vector=[0.1, 0.2, 0.3],
35-
metadata="invalid_metadata"
38+
metadata="invalid_metadata",
3639
)
3740

41+
3842
def test_cache_entry_to_dict():
3943
entry = CacheEntry(
4044
prompt="What is AI?",
4145
response="AI is artificial intelligence.",
4246
prompt_vector=[0.1, 0.2, 0.3],
4347
metadata={"author": "John"},
44-
filters={"category": "technology"}
48+
filters={"category": "technology"},
4549
)
4650
result = entry.to_dict()
4751
assert result["entry_id"] == hashify("What is AI?")
@@ -50,21 +54,23 @@ def test_cache_entry_to_dict():
5054
assert result["category"] == "technology"
5155
assert "filters" not in result
5256

57+
5358
def test_valid_cache_hit_creation():
5459
hit = CacheHit(
5560
entry_id="entry_1",
5661
prompt="What is AI?",
5762
response="AI is artificial intelligence.",
5863
vector_distance=0.1,
5964
inserted_at=1625819123.123,
60-
updated_at=1625819123.123
65+
updated_at=1625819123.123,
6166
)
6267
assert hit.entry_id == "entry_1"
6368
assert hit.prompt == "What is AI?"
6469
assert hit.response == "AI is artificial intelligence."
6570
assert hit.vector_distance == 0.1
6671
assert hit.inserted_at == hit.updated_at == 1625819123.123
6772

73+
6874
def test_cache_hit_with_serialized_metadata():
6975
hit = CacheHit(
7076
entry_id="entry_1",
@@ -73,10 +79,11 @@ def test_cache_hit_with_serialized_metadata():
7379
vector_distance=0.1,
7480
inserted_at=1625819123.123,
7581
updated_at=1625819123.123,
76-
metadata=json.dumps({"author": "John"})
82+
metadata=json.dumps({"author": "John"}),
7783
)
7884
assert hit.metadata == {"author": "John"}
7985

86+
8087
def test_cache_hit_to_dict():
8188
hit = CacheHit(
8289
entry_id="entry_1",
@@ -85,7 +92,7 @@ def test_cache_hit_to_dict():
8592
vector_distance=0.1,
8693
inserted_at=1625819123.123,
8794
updated_at=1625819123.123,
88-
filters={"category": "technology"}
95+
filters={"category": "technology"},
8996
)
9097
result = hit.to_dict()
9198
assert result["entry_id"] == "entry_1"
@@ -95,24 +102,26 @@ def test_cache_hit_to_dict():
95102
assert result["category"] == "technology"
96103
assert "filters" not in result
97104

105+
98106
def test_cache_entry_with_empty_optional_fields():
99107
entry = CacheEntry(
100108
prompt="What is AI?",
101109
response="AI is artificial intelligence.",
102-
prompt_vector=[0.1, 0.2, 0.3]
110+
prompt_vector=[0.1, 0.2, 0.3],
103111
)
104112
result = entry.to_dict()
105113
assert "metadata" not in result
106114
assert "filters" not in result
107115

116+
108117
def test_cache_hit_with_empty_optional_fields():
109118
hit = CacheHit(
110119
entry_id="entry_1",
111120
prompt="What is AI?",
112121
response="AI is artificial intelligence.",
113122
vector_distance=0.1,
114123
inserted_at=1625819123.123,
115-
updated_at=1625819123.123
124+
updated_at=1625819123.123,
116125
)
117126
result = hit.to_dict()
118127
assert "metadata" not in result

0 commit comments

Comments
 (0)