Skip to content

Commit b6f6d69

Browse files
wip on docs
1 parent e5d6b04 commit b6f6d69

File tree

1 file changed

+22
-10
lines changed

1 file changed

+22
-10
lines changed

redisvl/extensions/router/semantic.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,7 @@ def _add_routes(self, routes: List[Route]):
174174
keys: List[str] = []
175175

176176
for route in routes:
177-
if route.distance_threshold is None:
178-
route.distance_threshold = self.routing_config.distance_threshold
179-
# set route reference
177+
# set route references
180178
for reference in route.references:
181179
route_references.append(
182180
{
@@ -338,33 +336,37 @@ def _classify_many(
338336
)
339337
raise e
340338

341-
def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool:
339+
def _pass_threshold(self, route_match: Optional[RouteMatch], distance_threshold: float) -> bool:
342340
"""Check if a route match passes the distance threshold.
343341
344342
Args:
345343
route_match (Optional[RouteMatch]): The route match to check.
344+
distance_threshold (float): The fallback distance threshold to use if not assigned to a route.
346345
347346
Returns:
348347
bool: True if the route match passes the threshold, False otherwise.
349348
"""
350349
if route_match:
351350
if route_match.distance is not None and route_match.route is not None:
352-
if route_match.route.distance_threshold:
353-
return route_match.distance <= route_match.route.distance_threshold
351+
_distance_threshold = route_match.route.distance_threshold or distance_threshold
352+
if _distance_threshold:
353+
return route_match.distance <= _distance_threshold
354354
return False
355355

356356
def __call__(
357357
self,
358358
statement: Optional[str] = None,
359359
vector: Optional[List[float]] = None,
360360
distance_threshold: Optional[float] = None,
361+
aggregation_method: Optional[DistanceAggregationMethod] = None
361362
) -> RouteMatch:
362363
"""Query the semantic router with a given statement or vector.
363364
364365
Args:
365366
statement (Optional[str]): The input statement to be queried.
366367
vector (Optional[List[float]]): The input vector to be queried.
367368
distance_threshold (Optional[float]): The threshold for semantic distance.
369+
aggregation_method (Optional[DistanceAggregationMethod]): The aggregation method used for vector distances.
368370
369371
Returns:
370372
RouteMatch: The matching route.
@@ -377,12 +379,13 @@ def __call__(
377379
distance_threshold = (
378380
distance_threshold or self.routing_config.distance_threshold
379381
)
382+
aggregation_method = aggregation_method or self.routing_config.aggregation_method
380383
route_matches = self._classify(
381-
vector, distance_threshold, self.routing_config.aggregation_method
384+
vector, distance_threshold, aggregation_method
382385
)
383386
route_match = route_matches[0] if route_matches else None
384387

385-
if route_match and self._pass_threshold(route_match):
388+
if route_match and self._pass_threshold(route_match, distance_threshold):
386389
return route_match
387390

388391
return RouteMatch()
@@ -393,6 +396,7 @@ def route_many(
393396
vector: Optional[List[float]] = None,
394397
max_k: Optional[int] = None,
395398
distance_threshold: Optional[float] = None,
399+
aggregation_method: Optional[DistanceAggregationMethod] = None
396400
) -> List[RouteMatch]:
397401
"""Query the semantic router with a given statement or vector for multiple matches.
398402
@@ -401,6 +405,7 @@ def route_many(
401405
vector (Optional[List[float]]): The input vector to be queried.
402406
max_k (Optional[int]): The maximum number of top matches to return.
403407
distance_threshold (Optional[float]): The threshold for semantic distance.
408+
aggregation_method (Optional[DistanceAggregationMethod]): The aggregation method used for vector distances.
404409
405410
Returns:
406411
List[RouteMatch]: The matching routes and their details.
@@ -414,12 +419,19 @@ def route_many(
414419
distance_threshold or self.routing_config.distance_threshold
415420
)
416421
max_k = max_k or self.routing_config.max_k
422+
aggregation_method = aggregation_method or self.routing_config.aggregation_method
417423
route_matches = self._classify_many(
418-
vector, max_k, distance_threshold, self.routing_config.aggregation_method
424+
vector, max_k, distance_threshold, aggregation_method
419425
)
420426

421427
return [
422428
route_match
423429
for route_match in route_matches
424-
if self._pass_threshold(route_match)
430+
if self._pass_threshold(route_match, distance_threshold)
425431
]
432+
433+
def delete(self):
434+
self._index.delete(drop=True)
435+
436+
def clear(self):
437+
self._index.clear()

0 commit comments

Comments
 (0)