@@ -174,9 +174,7 @@ def _add_routes(self, routes: List[Route]):
174
174
keys : List [str ] = []
175
175
176
176
for route in routes :
177
- if route .distance_threshold is None :
178
- route .distance_threshold = self .routing_config .distance_threshold
179
- # set route reference
177
+ # set route references
180
178
for reference in route .references :
181
179
route_references .append (
182
180
{
@@ -338,33 +336,37 @@ def _classify_many(
338
336
)
339
337
raise e
340
338
341
- def _pass_threshold (self , route_match : Optional [RouteMatch ]) -> bool :
339
+ def _pass_threshold (self , route_match : Optional [RouteMatch ], distance_threshold : float ) -> bool :
342
340
"""Check if a route match passes the distance threshold.
343
341
344
342
Args:
345
343
route_match (Optional[RouteMatch]): The route match to check.
344
+ distance_threshold (float): The fallback distance threshold to use if not assigned to a route.
346
345
347
346
Returns:
348
347
bool: True if the route match passes the threshold, False otherwise.
349
348
"""
350
349
if route_match :
351
350
if route_match .distance is not None and route_match .route is not None :
352
- if route_match .route .distance_threshold :
353
- return route_match .distance <= route_match .route .distance_threshold
351
+ _distance_threshold = route_match .route .distance_threshold or distance_threshold
352
+ if _distance_threshold :
353
+ return route_match .distance <= _distance_threshold
354
354
return False
355
355
356
356
def __call__ (
357
357
self ,
358
358
statement : Optional [str ] = None ,
359
359
vector : Optional [List [float ]] = None ,
360
360
distance_threshold : Optional [float ] = None ,
361
+ aggregation_method : Optional [DistanceAggregationMethod ] = None
361
362
) -> RouteMatch :
362
363
"""Query the semantic router with a given statement or vector.
363
364
364
365
Args:
365
366
statement (Optional[str]): The input statement to be queried.
366
367
vector (Optional[List[float]]): The input vector to be queried.
367
368
distance_threshold (Optional[float]): The threshold for semantic distance.
369
+ aggregation_method (Optional[DistanceAggregationMethod]): The aggregation method used for vector distances.
368
370
369
371
Returns:
370
372
RouteMatch: The matching route.
@@ -377,12 +379,13 @@ def __call__(
377
379
distance_threshold = (
378
380
distance_threshold or self .routing_config .distance_threshold
379
381
)
382
+ aggregation_method = aggregation_method or self .routing_config .aggregation_method
380
383
route_matches = self ._classify (
381
- vector , distance_threshold , self . routing_config . aggregation_method
384
+ vector , distance_threshold , aggregation_method
382
385
)
383
386
route_match = route_matches [0 ] if route_matches else None
384
387
385
- if route_match and self ._pass_threshold (route_match ):
388
+ if route_match and self ._pass_threshold (route_match , distance_threshold ):
386
389
return route_match
387
390
388
391
return RouteMatch ()
@@ -393,6 +396,7 @@ def route_many(
393
396
vector : Optional [List [float ]] = None ,
394
397
max_k : Optional [int ] = None ,
395
398
distance_threshold : Optional [float ] = None ,
399
+ aggregation_method : Optional [DistanceAggregationMethod ] = None
396
400
) -> List [RouteMatch ]:
397
401
"""Query the semantic router with a given statement or vector for multiple matches.
398
402
@@ -401,6 +405,7 @@ def route_many(
401
405
vector (Optional[List[float]]): The input vector to be queried.
402
406
max_k (Optional[int]): The maximum number of top matches to return.
403
407
distance_threshold (Optional[float]): The threshold for semantic distance.
408
+ aggregation_method (Optional[DistanceAggregationMethod]): The aggregation method used for vector distances.
404
409
405
410
Returns:
406
411
List[RouteMatch]: The matching routes and their details.
@@ -414,12 +419,19 @@ def route_many(
414
419
distance_threshold or self .routing_config .distance_threshold
415
420
)
416
421
max_k = max_k or self .routing_config .max_k
422
+ aggregation_method = aggregation_method or self .routing_config .aggregation_method
417
423
route_matches = self ._classify_many (
418
- vector , max_k , distance_threshold , self . routing_config . aggregation_method
424
+ vector , max_k , distance_threshold , aggregation_method
419
425
)
420
426
421
427
return [
422
428
route_match
423
429
for route_match in route_matches
424
- if self ._pass_threshold (route_match )
430
+ if self ._pass_threshold (route_match , distance_threshold )
425
431
]
432
+
433
+ def delete (self ):
434
+ self ._index .delete (drop = True )
435
+
436
+ def clear (self ):
437
+ self ._index .clear ()
0 commit comments