1
1
import hashlib
2
- from typing import Any , Dict , List , Optional
2
+ from typing import Any , Dict , List , Optional , Type
3
3
4
4
import redis .commands .search .reducers as reducers
5
5
from pydantic .v1 import BaseModel , Field , PrivateAttr
6
6
from redis import Redis
7
- from redis .commands .search .aggregation import AggregateRequest , AggregateResult
8
-
9
- from redisvl .extensions .router .schema import Route , RoutingConfig , RouteMatch , DistanceAggregationMethod
7
+ from redis .commands .search .aggregation import AggregateRequest , AggregateResult , Reducer
8
+
9
+ from redisvl .extensions .router .schema import (
10
+ DistanceAggregationMethod ,
11
+ Route ,
12
+ RouteMatch ,
13
+ RoutingConfig ,
14
+ )
10
15
from redisvl .index import SearchIndex
11
16
from redisvl .query import RangeQuery
12
17
from redisvl .redis .utils import convert_bytes , make_dict
@@ -18,7 +23,7 @@ class SemanticRouterIndexSchema(IndexSchema):
18
23
"""Customized index schema for SemanticRouter."""
19
24
20
25
@classmethod
21
- def from_params (cls , name : str , vector_dims : int ) -> ' SemanticRouterIndexSchema' :
26
+ def from_params (cls , name : str , vector_dims : int ) -> " SemanticRouterIndexSchema" :
22
27
"""Create an index schema based on router name and vector dimensions.
23
28
24
29
Args:
@@ -30,7 +35,7 @@ def from_params(cls, name: str, vector_dims: int) -> 'SemanticRouterIndexSchema'
30
35
"""
31
36
return cls (
32
37
index = IndexInfo (name = name , prefix = name ),
33
- fields = [
38
+ fields = [ # type: ignore
34
39
{"name" : "route_name" , "type" : "tag" },
35
40
{"name" : "reference" , "type" : "text" },
36
41
{
@@ -87,7 +92,12 @@ def __init__(
87
92
overwrite (bool, optional): Whether to overwrite existing index. Defaults to False.
88
93
**kwargs: Additional arguments.
89
94
"""
90
- super ().__init__ (name = name , routes = routes , vectorizer = vectorizer , routing_config = routing_config )
95
+ super ().__init__ (
96
+ name = name ,
97
+ routes = routes ,
98
+ vectorizer = vectorizer ,
99
+ routing_config = routing_config ,
100
+ )
91
101
self ._initialize_index (redis_client , redis_url , overwrite )
92
102
93
103
def _initialize_index (
@@ -130,7 +140,7 @@ def route_names(self) -> List[str]:
130
140
return [route .name for route in self .routes ]
131
141
132
142
@property
133
- def route_thresholds (self ) -> Dict [str , float ]:
143
+ def route_thresholds (self ) -> Dict [str , Optional [ float ] ]:
134
144
"""Get the distance thresholds for each route.
135
145
136
146
Returns:
@@ -168,7 +178,9 @@ def _add_routes(self, routes: List[Route]):
168
178
}
169
179
)
170
180
reference_hash = hashlib .sha256 (reference .encode ("utf-8" )).hexdigest ()
171
- keys .append (f"{ self ._index .schema .index .prefix } :{ route .name } :{ reference_hash } " )
181
+ keys .append (
182
+ f"{ self ._index .schema .index .prefix } :{ route .name } :{ reference_hash } "
183
+ )
172
184
173
185
# set route if does not yet exist client side
174
186
if not self .get (route .name ):
@@ -204,7 +216,7 @@ def _build_aggregate_request(
204
216
self ,
205
217
vector_range_query : RangeQuery ,
206
218
aggregation_method : DistanceAggregationMethod ,
207
- max_k : int
219
+ max_k : int ,
208
220
) -> AggregateRequest :
209
221
"""Build the Redis aggregation request.
210
222
@@ -216,6 +228,8 @@ def _build_aggregate_request(
216
228
Returns:
217
229
AggregateRequest: The constructed aggregation request.
218
230
"""
231
+ aggregation_func : Type [Reducer ]
232
+
219
233
if aggregation_method == DistanceAggregationMethod .min :
220
234
aggregation_func = reducers .min
221
235
elif aggregation_method == DistanceAggregationMethod .sum :
@@ -226,7 +240,9 @@ def _build_aggregate_request(
226
240
aggregate_query = str (vector_range_query ).split (" RETURN" )[0 ]
227
241
aggregate_request = (
228
242
AggregateRequest (aggregate_query )
229
- .group_by ("@route_name" , aggregation_func ("vector_distance" ).alias ("distance" ))
243
+ .group_by (
244
+ "@route_name" , aggregation_func ("vector_distance" ).alias ("distance" )
245
+ )
230
246
.sort_by ("@distance" , max = max_k )
231
247
.dialect (2 )
232
248
)
@@ -237,7 +253,7 @@ def _classify(
237
253
self ,
238
254
vector : List [float ],
239
255
distance_threshold : float ,
240
- aggregation_method : DistanceAggregationMethod
256
+ aggregation_method : DistanceAggregationMethod ,
241
257
) -> List [RouteMatch ]:
242
258
"""Classify a single query vector.
243
259
@@ -256,16 +272,20 @@ def _classify(
256
272
return_fields = ["route_name" ],
257
273
)
258
274
259
- aggregate_request = self ._build_aggregate_request (vector_range_query , aggregation_method , max_k = 1 )
260
- route_matches : AggregateResult = self ._index .client .ft (self ._index .name ).aggregate (aggregate_request , vector_range_query .params )
275
+ aggregate_request = self ._build_aggregate_request (
276
+ vector_range_query , aggregation_method , max_k = 1
277
+ )
278
+ route_matches : AggregateResult = self ._index .client .ft ( # type: ignore
279
+ self ._index .name
280
+ ).aggregate (aggregate_request , vector_range_query .params )
261
281
return [self ._process_route (route_match ) for route_match in route_matches .rows ]
262
282
263
283
def _classify_many (
264
284
self ,
265
285
vector : List [float ],
266
286
max_k : int ,
267
287
distance_threshold : float ,
268
- aggregation_method : DistanceAggregationMethod
288
+ aggregation_method : DistanceAggregationMethod ,
269
289
) -> List [RouteMatch ]:
270
290
"""Classify multiple query vectors.
271
291
@@ -284,8 +304,12 @@ def _classify_many(
284
304
distance_threshold = distance_threshold ,
285
305
return_fields = ["route_name" ],
286
306
)
287
- aggregate_request = self ._build_aggregate_request (vector_range_query , aggregation_method , max_k )
288
- route_matches : AggregateResult = self ._index .client .ft (self ._index .name ).aggregate (aggregate_request , vector_range_query .params )
307
+ aggregate_request = self ._build_aggregate_request (
308
+ vector_range_query , aggregation_method , max_k
309
+ )
310
+ route_matches : AggregateResult = self ._index .client .ft ( # type: ignore
311
+ self ._index .name
312
+ ).aggregate (aggregate_request , vector_range_query .params )
289
313
return [self ._process_route (route_match ) for route_match in route_matches .rows ]
290
314
291
315
def _pass_threshold (self , route_match : Optional [RouteMatch ]) -> bool :
@@ -297,7 +321,11 @@ def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool:
297
321
Returns:
298
322
bool: True if the route match passes the threshold, False otherwise.
299
323
"""
300
- return route_match is not None and route_match .distance <= route_match .route .distance_threshold
324
+ if route_match :
325
+ if route_match .distance is not None and route_match .route is not None :
326
+ if route_match .route .distance_threshold :
327
+ return route_match .distance <= route_match .route .distance_threshold
328
+ return False
301
329
302
330
def __call__ (
303
331
self ,
@@ -320,8 +348,12 @@ def __call__(
320
348
raise ValueError ("Must provide a vector or statement to the router" )
321
349
vector = self .vectorizer .embed (statement )
322
350
323
- distance_threshold = distance_threshold or self .routing_config .distance_threshold
324
- route_matches = self ._classify (vector , distance_threshold , self .routing_config .aggregation_method )
351
+ distance_threshold = (
352
+ distance_threshold or self .routing_config .distance_threshold
353
+ )
354
+ route_matches = self ._classify (
355
+ vector , distance_threshold , self .routing_config .aggregation_method
356
+ )
325
357
route_match = route_matches [0 ] if route_matches else None
326
358
327
359
if route_match and self ._pass_threshold (route_match ):
@@ -352,8 +384,16 @@ def route_many(
352
384
raise ValueError ("Must provide a vector or statement to the router" )
353
385
vector = self .vectorizer .embed (statement )
354
386
355
- distance_threshold = distance_threshold or self .routing_config .distance_threshold
387
+ distance_threshold = (
388
+ distance_threshold or self .routing_config .distance_threshold
389
+ )
356
390
max_k = max_k or self .routing_config .max_k
357
- route_matches = self ._classify_many (vector , max_k , distance_threshold , self .routing_config .aggregation_method )
391
+ route_matches = self ._classify_many (
392
+ vector , max_k , distance_threshold , self .routing_config .aggregation_method
393
+ )
358
394
359
- return [route_match for route_match in route_matches if self ._pass_threshold (route_match )]
395
+ return [
396
+ route_match
397
+ for route_match in route_matches
398
+ if self ._pass_threshold (route_match )
399
+ ]
0 commit comments