Skip to content

Commit 0bdae21

Browse files
Revert "support dynamic distance threshold"
This reverts commit 16050b7.
1 parent 2c6f3bc commit 0bdae21

File tree

4 files changed

+6
-22
lines changed

4 files changed

+6
-22
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ def check(
233233
num_results: int = 1,
234234
return_fields: Optional[List[str]] = None,
235235
filter_expression: Optional[FilterExpression] = None,
236-
distance_threshold: Optional[float] = None,
237236
) -> List[Dict[str, Any]]:
238237
"""Checks the semantic cache for results similar to the specified prompt
239238
or vector.
@@ -256,8 +255,6 @@ def check(
256255
filter_expression (Optional[FilterExpression]) : Optional filter expression
257256
that can be used to filter cache results. Defaults to None and
258257
the full cache will be searched.
259-
distance_threshold (Optional[float]): The threshold for semantic
260-
vector distance.
261258
262259
Returns:
263260
List[Dict[str, Any]]: A list of dicts containing the requested
@@ -277,12 +274,9 @@ def check(
277274
if not (prompt or vector):
278275
raise ValueError("Either prompt or vector must be specified.")
279276

280-
# overrides
281-
distance_threshold = distance_threshold or self._distance_threshold
282-
return_fields = return_fields or self.return_fields
283277
vector = vector or self._vectorize_prompt(prompt)
284-
285278
self._check_vector_dims(vector)
279+
return_fields = return_fields or self.return_fields
286280

287281
if not isinstance(return_fields, list):
288282
raise TypeError("return_fields must be a list of field names")
@@ -291,7 +285,7 @@ def check(
291285
vector=vector,
292286
vector_field_name=self.vector_field_name,
293287
return_fields=self.return_fields,
294-
distance_threshold=distance_threshold,
288+
distance_threshold=self._distance_threshold,
295289
num_results=num_results,
296290
return_score=True,
297291
filter_expression=filter_expression,

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def get_relevant(
137137
fall_back: bool = False,
138138
session_tag: Optional[str] = None,
139139
raw: bool = False,
140-
distance_threshold: Optional[float] = None,
141140
) -> Union[List[str], List[Dict[str, str]]]:
142141
"""Searches the chat history for information semantically related to
143142
the specified prompt.
@@ -152,12 +151,10 @@ def get_relevant(
152151
as_text (bool): Whether to return the prompts and responses as text
153152
or as JSON
154153
top_k (int): The number of previous messages to return. Default is 5.
155-
session_tag (Optional[str]): Tag to be added to entries to link to a specific
156-
session. Defaults to instance uuid.
157-
distance_threshold (Optional[float]): The threshold for semantic
158-
vector distance.
159154
fall_back (bool): Whether to drop back to recent conversation history
160155
if no relevant context is found.
156+
session_tag (Optional[str]): Tag to be added to entries to link to a specific
157+
session. Defaults to instance uuid.
161158
raw (bool): Whether to return the full Redis hash entry or just the
162159
message.
163160
@@ -172,9 +169,6 @@ def get_relevant(
172169
if top_k == 0:
173170
return []
174171

175-
# override distance threshold
176-
distance_threshold = distance_threshold or self._distance_threshold
177-
178172
return_fields = [
179173
self.session_field_name,
180174
self.role_field_name,
@@ -193,7 +187,7 @@ def get_relevant(
193187
vector=self._vectorizer.embed(prompt),
194188
vector_field_name=self.vector_field_name,
195189
return_fields=return_fields,
196-
distance_threshold=distance_threshold,
190+
distance_threshold=self._distance_threshold,
197191
num_results=top_k,
198192
return_score=True,
199193
filter_expression=session_filter,

tests/integration/test_llmcache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test_store_and_check(cache, vectorizer):
9595
vector = vectorizer.embed(prompt)
9696

9797
cache.store(prompt, response, vector=vector)
98-
check_result = cache.check(vector=vector, distance_threshold=0.4)
98+
check_result = cache.check(vector=vector)
9999

100100
assert len(check_result) == 1
101101
print(check_result, flush=True)

tests/integration/test_session_manager.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,6 @@ def test_semantic_add_and_get_relevant(semantic_session):
463463
semantic_session.set_distance_threshold(0.5)
464464
default_context = semantic_session.get_relevant("list of fruits and vegetables")
465465
assert len(default_context) == 5 # 2 pairs of prompt:response, and system
466-
assert default_context == semantic_session.get_relevant(
467-
"list of fruits and vegetables",
468-
distance_threshold=0.5
469-
)
470466

471467
# test tool calls can also be returned
472468
context = semantic_session.get_relevant("winter sports like skiing")

0 commit comments

Comments
 (0)