Skip to content

Commit 13dcd66

Browse files
Support distance threshold override (#202)
At runtime, callers should be able to override the distance threshold by passing one explicitly; when none is provided, the configured default is used. This is also how the semantic router works, and functional parity between the extensions is key here.
1 parent 38f2fe1 commit 13dcd66

File tree

4 files changed

+22
-6
lines changed

4 files changed

+22
-6
lines changed

redisvl/extensions/llmcache/semantic.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def check(
233233
num_results: int = 1,
234234
return_fields: Optional[List[str]] = None,
235235
filter_expression: Optional[FilterExpression] = None,
236+
distance_threshold: Optional[float] = None,
236237
) -> List[Dict[str, Any]]:
237238
"""Checks the semantic cache for results similar to the specified prompt
238239
or vector.
@@ -255,6 +256,8 @@ def check(
255256
filter_expression (Optional[FilterExpression]) : Optional filter expression
256257
that can be used to filter cache results. Defaults to None and
257258
the full cache will be searched.
259+
distance_threshold (Optional[float]): The threshold for semantic
260+
vector distance.
258261
259262
Returns:
260263
List[Dict[str, Any]]: A list of dicts containing the requested
@@ -274,9 +277,12 @@ def check(
274277
if not (prompt or vector):
275278
raise ValueError("Either prompt or vector must be specified.")
276279

280+
# overrides
281+
distance_threshold = distance_threshold or self._distance_threshold
282+
return_fields = return_fields or self.return_fields
277283
vector = vector or self._vectorize_prompt(prompt)
284+
278285
self._check_vector_dims(vector)
279-
return_fields = return_fields or self.return_fields
280286

281287
if not isinstance(return_fields, list):
282288
raise TypeError("return_fields must be a list of field names")
@@ -285,7 +291,7 @@ def check(
285291
vector=vector,
286292
vector_field_name=self.vector_field_name,
287293
return_fields=self.return_fields,
288-
distance_threshold=self._distance_threshold,
294+
distance_threshold=distance_threshold,
289295
num_results=num_results,
290296
return_score=True,
291297
filter_expression=filter_expression,

redisvl/extensions/session_manager/semantic_session.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def get_relevant(
137137
fall_back: bool = False,
138138
session_tag: Optional[str] = None,
139139
raw: bool = False,
140+
distance_threshold: Optional[float] = None,
140141
) -> Union[List[str], List[Dict[str, str]]]:
141142
"""Searches the chat history for information semantically related to
142143
the specified prompt.
@@ -151,10 +152,12 @@ def get_relevant(
151152
as_text (bool): Whether to return the prompts and responses as text
152153
or as JSON
153154
top_k (int): The number of previous messages to return. Default is 5.
154-
fall_back (bool): Whether to drop back to recent conversation history
155-
if no relevant context is found.
156155
session_tag (Optional[str]): Tag to be added to entries to link to a specific
157156
session. Defaults to instance uuid.
157+
distance_threshold (Optional[float]): The threshold for semantic
158+
vector distance.
159+
fall_back (bool): Whether to drop back to recent conversation history
160+
if no relevant context is found.
158161
raw (bool): Whether to return the full Redis hash entry or just the
159162
message.
160163
@@ -169,6 +172,9 @@ def get_relevant(
169172
if top_k == 0:
170173
return []
171174

175+
# override distance threshold
176+
distance_threshold = distance_threshold or self._distance_threshold
177+
172178
return_fields = [
173179
self.session_field_name,
174180
self.role_field_name,
@@ -187,7 +193,7 @@ def get_relevant(
187193
vector=self._vectorizer.embed(prompt),
188194
vector_field_name=self.vector_field_name,
189195
return_fields=return_fields,
190-
distance_threshold=self._distance_threshold,
196+
distance_threshold=distance_threshold,
191197
num_results=top_k,
192198
return_score=True,
193199
filter_expression=session_filter,

tests/integration/test_llmcache.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_store_and_check(cache, vectorizer):
9595
vector = vectorizer.embed(prompt)
9696

9797
cache.store(prompt, response, vector=vector)
98-
check_result = cache.check(vector=vector)
98+
check_result = cache.check(vector=vector, distance_threshold=0.4)
9999

100100
assert len(check_result) == 1
101101
print(check_result, flush=True)

tests/integration/test_session_manager.py

+4
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,10 @@ def test_semantic_add_and_get_relevant(semantic_session):
463463
semantic_session.set_distance_threshold(0.5)
464464
default_context = semantic_session.get_relevant("list of fruits and vegetables")
465465
assert len(default_context) == 5 # 2 pairs of prompt:response, and system
466+
assert default_context == semantic_session.get_relevant(
467+
"list of fruits and vegetables",
468+
distance_threshold=0.5
469+
)
466470

467471
# test tool calls can also be returned
468472
context = semantic_session.get_relevant("winter sports like skiing")

0 commit comments

Comments
 (0)