11import hashlib
2- from typing import Any , Dict , List , Optional
2+ from typing import Any , Dict , List , Optional , Type
33
44import redis .commands .search .reducers as reducers
55from pydantic .v1 import BaseModel , Field , PrivateAttr
66from redis import Redis
7- from redis .commands .search .aggregation import AggregateRequest , AggregateResult
8-
9- from redisvl .extensions .router .schema import Route , RoutingConfig , RouteMatch , DistanceAggregationMethod
7+ from redis .commands .search .aggregation import AggregateRequest , AggregateResult , Reducer
8+
9+ from redisvl .extensions .router .schema import (
10+ DistanceAggregationMethod ,
11+ Route ,
12+ RouteMatch ,
13+ RoutingConfig ,
14+ )
1015from redisvl .index import SearchIndex
1116from redisvl .query import RangeQuery
1217from redisvl .redis .utils import convert_bytes , make_dict
@@ -18,7 +23,7 @@ class SemanticRouterIndexSchema(IndexSchema):
1823 """Customized index schema for SemanticRouter."""
1924
2025 @classmethod
21- def from_params (cls , name : str , vector_dims : int ) -> ' SemanticRouterIndexSchema' :
26+ def from_params (cls , name : str , vector_dims : int ) -> " SemanticRouterIndexSchema" :
2227 """Create an index schema based on router name and vector dimensions.
2328
2429 Args:
@@ -30,7 +35,7 @@ def from_params(cls, name: str, vector_dims: int) -> 'SemanticRouterIndexSchema'
3035 """
3136 return cls (
3237 index = IndexInfo (name = name , prefix = name ),
33- fields = [
38+ fields = [ # type: ignore
3439 {"name" : "route_name" , "type" : "tag" },
3540 {"name" : "reference" , "type" : "text" },
3641 {
@@ -87,7 +92,12 @@ def __init__(
8792 overwrite (bool, optional): Whether to overwrite existing index. Defaults to False.
8893 **kwargs: Additional arguments.
8994 """
90- super ().__init__ (name = name , routes = routes , vectorizer = vectorizer , routing_config = routing_config )
95+ super ().__init__ (
96+ name = name ,
97+ routes = routes ,
98+ vectorizer = vectorizer ,
99+ routing_config = routing_config ,
100+ )
91101 self ._initialize_index (redis_client , redis_url , overwrite )
92102
93103 def _initialize_index (
@@ -130,7 +140,7 @@ def route_names(self) -> List[str]:
130140 return [route .name for route in self .routes ]
131141
132142 @property
133- def route_thresholds (self ) -> Dict [str , float ]:
143+ def route_thresholds (self ) -> Dict [str , Optional [ float ] ]:
134144 """Get the distance thresholds for each route.
135145
136146 Returns:
@@ -168,7 +178,9 @@ def _add_routes(self, routes: List[Route]):
168178 }
169179 )
170180 reference_hash = hashlib .sha256 (reference .encode ("utf-8" )).hexdigest ()
171- keys .append (f"{ self ._index .schema .index .prefix } :{ route .name } :{ reference_hash } " )
181+ keys .append (
182+ f"{ self ._index .schema .index .prefix } :{ route .name } :{ reference_hash } "
183+ )
172184
173185 # set route if does not yet exist client side
174186 if not self .get (route .name ):
@@ -204,7 +216,7 @@ def _build_aggregate_request(
204216 self ,
205217 vector_range_query : RangeQuery ,
206218 aggregation_method : DistanceAggregationMethod ,
207- max_k : int
219+ max_k : int ,
208220 ) -> AggregateRequest :
209221 """Build the Redis aggregation request.
210222
@@ -216,6 +228,8 @@ def _build_aggregate_request(
216228 Returns:
217229 AggregateRequest: The constructed aggregation request.
218230 """
231+ aggregation_func : Type [Reducer ]
232+
219233 if aggregation_method == DistanceAggregationMethod .min :
220234 aggregation_func = reducers .min
221235 elif aggregation_method == DistanceAggregationMethod .sum :
@@ -226,7 +240,9 @@ def _build_aggregate_request(
226240 aggregate_query = str (vector_range_query ).split (" RETURN" )[0 ]
227241 aggregate_request = (
228242 AggregateRequest (aggregate_query )
229- .group_by ("@route_name" , aggregation_func ("vector_distance" ).alias ("distance" ))
243+ .group_by (
244+ "@route_name" , aggregation_func ("vector_distance" ).alias ("distance" )
245+ )
230246 .sort_by ("@distance" , max = max_k )
231247 .dialect (2 )
232248 )
@@ -237,7 +253,7 @@ def _classify(
237253 self ,
238254 vector : List [float ],
239255 distance_threshold : float ,
240- aggregation_method : DistanceAggregationMethod
256+ aggregation_method : DistanceAggregationMethod ,
241257 ) -> List [RouteMatch ]:
242258 """Classify a single query vector.
243259
@@ -256,16 +272,20 @@ def _classify(
256272 return_fields = ["route_name" ],
257273 )
258274
259- aggregate_request = self ._build_aggregate_request (vector_range_query , aggregation_method , max_k = 1 )
260- route_matches : AggregateResult = self ._index .client .ft (self ._index .name ).aggregate (aggregate_request , vector_range_query .params )
275+ aggregate_request = self ._build_aggregate_request (
276+ vector_range_query , aggregation_method , max_k = 1
277+ )
278+ route_matches : AggregateResult = self ._index .client .ft ( # type: ignore
279+ self ._index .name
280+ ).aggregate (aggregate_request , vector_range_query .params )
261281 return [self ._process_route (route_match ) for route_match in route_matches .rows ]
262282
263283 def _classify_many (
264284 self ,
265285 vector : List [float ],
266286 max_k : int ,
267287 distance_threshold : float ,
268- aggregation_method : DistanceAggregationMethod
288+ aggregation_method : DistanceAggregationMethod ,
269289 ) -> List [RouteMatch ]:
270290 """Classify multiple query vectors.
271291
@@ -284,8 +304,12 @@ def _classify_many(
284304 distance_threshold = distance_threshold ,
285305 return_fields = ["route_name" ],
286306 )
287- aggregate_request = self ._build_aggregate_request (vector_range_query , aggregation_method , max_k )
288- route_matches : AggregateResult = self ._index .client .ft (self ._index .name ).aggregate (aggregate_request , vector_range_query .params )
307+ aggregate_request = self ._build_aggregate_request (
308+ vector_range_query , aggregation_method , max_k
309+ )
310+ route_matches : AggregateResult = self ._index .client .ft ( # type: ignore
311+ self ._index .name
312+ ).aggregate (aggregate_request , vector_range_query .params )
289313 return [self ._process_route (route_match ) for route_match in route_matches .rows ]
290314
291315 def _pass_threshold (self , route_match : Optional [RouteMatch ]) -> bool :
@@ -297,7 +321,11 @@ def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool:
297321 Returns:
298322 bool: True if the route match passes the threshold, False otherwise.
299323 """
300- return route_match is not None and route_match .distance <= route_match .route .distance_threshold
324+ if route_match :
325+ if route_match .distance is not None and route_match .route is not None :
326+ if route_match .route .distance_threshold :
327+ return route_match .distance <= route_match .route .distance_threshold
328+ return False
301329
302330 def __call__ (
303331 self ,
@@ -320,8 +348,12 @@ def __call__(
320348 raise ValueError ("Must provide a vector or statement to the router" )
321349 vector = self .vectorizer .embed (statement )
322350
323- distance_threshold = distance_threshold or self .routing_config .distance_threshold
324- route_matches = self ._classify (vector , distance_threshold , self .routing_config .aggregation_method )
351+ distance_threshold = (
352+ distance_threshold or self .routing_config .distance_threshold
353+ )
354+ route_matches = self ._classify (
355+ vector , distance_threshold , self .routing_config .aggregation_method
356+ )
325357 route_match = route_matches [0 ] if route_matches else None
326358
327359 if route_match and self ._pass_threshold (route_match ):
@@ -352,8 +384,16 @@ def route_many(
352384 raise ValueError ("Must provide a vector or statement to the router" )
353385 vector = self .vectorizer .embed (statement )
354386
355- distance_threshold = distance_threshold or self .routing_config .distance_threshold
387+ distance_threshold = (
388+ distance_threshold or self .routing_config .distance_threshold
389+ )
356390 max_k = max_k or self .routing_config .max_k
357- route_matches = self ._classify_many (vector , max_k , distance_threshold , self .routing_config .aggregation_method )
391+ route_matches = self ._classify_many (
392+ vector , max_k , distance_threshold , self .routing_config .aggregation_method
393+ )
358394
359- return [route_match for route_match in route_matches if self ._pass_threshold (route_match )]
395+ return [
396+ route_match
397+ for route_match in route_matches
398+ if self ._pass_threshold (route_match )
399+ ]
0 commit comments