2
2
from typing import Any , List , Dict , Optional , Union
3
3
4
4
from redis import Redis
5
- from redis .commands .search .aggregation import AggregateRequest , AggregateResult , Reducer
5
+ from redis .commands .search .aggregation import AggregateRequest , AggregateResult
6
6
import redis .commands .search .reducers as reducers
7
7
8
8
from redisvl .index import SearchIndex
9
9
from redisvl .query import VectorQuery , RangeQuery
10
10
from redisvl .schema import IndexSchema , IndexInfo
11
11
from redisvl .utils .vectorize import BaseVectorizer , HFTextVectorizer
12
- from redisvl .extensions .router .routes import Route , RoutingConfig , AccumulationMethod
12
+ from redisvl .extensions .router .routes import Route , RoutingConfig , RouteSortingMethod
13
13
14
14
from redisvl .redis .utils import make_dict , convert_bytes
15
15
@@ -20,6 +20,9 @@ class SemanticRouterIndexSchema(IndexSchema):
20
20
21
21
@classmethod
22
22
def from_params (cls , name : str , vector_dims : int ):
23
+ """Load the semantic router index schema from the router name and
24
+ vector dimensionality.
25
+ """
23
26
return cls (
24
27
index = IndexInfo (name = name , prefix = name ),
25
28
fields = [
@@ -50,7 +53,6 @@ class SemanticRouter(BaseModel):
50
53
"""Configuration for routing behavior"""
51
54
52
55
_index : SearchIndex = PrivateAttr ()
53
- # _accumulation_method: AccumulationMethod = PrivateAttr()
54
56
55
57
class Config :
56
58
arbitrary_types_allowed = True
@@ -85,7 +87,6 @@ def __init__(
85
87
routing_config = routing_config
86
88
)
87
89
self ._initialize_index (redis_client , redis_url , overwrite )
88
- # self._accumulation_method = self._pick_accumulation_method()
89
90
90
91
def _initialize_index (
91
92
self ,
@@ -116,20 +117,6 @@ def _initialize_index(
116
117
if not existed or overwrite :
117
118
self ._add_routes (self .routes )
118
119
119
- # def _pick_accumulation_method(self) -> AccumulationMethod:
120
- # """Pick the accumulation method based on the routing configuration."""
121
- # if self.routing_config.accumulation_method != AccumulationMethod.auto:
122
- # return self.routing_config.accumulation_method
123
-
124
- # num_route_references = [len(route.references) for route in self.routes]
125
- # avg_num_references = sum(num_route_references) / len(num_route_references)
126
- # variance = sum((x - avg_num_references) ** 2 for x in num_route_references) / len(num_route_references)
127
-
128
- # if variance < 1: # TODO: Arbitrary threshold for low variance
129
- # return AccumulationMethod.sum
130
- # else:
131
- # return AccumulationMethod.avg
132
-
133
120
def update_routing_config (self , routing_config : RoutingConfig ):
134
121
"""Update the routing configuration.
135
122
@@ -165,25 +152,24 @@ def __call__(
165
152
statement : str ,
166
153
max_k : Optional [int ] = None ,
167
154
distance_threshold : Optional [float ] = None ,
155
+ sort_by : Optional [str ] = None
168
156
) -> List [Dict [str , Any ]]:
169
157
"""Query the semantic router with a given statement.
170
158
171
159
Args:
172
160
statement (str): The input statement to be queried.
173
161
max_k (Optional[int]): The maximum number of top matches to return.
174
162
distance_threshold (Optional[float]): The threshold for semantic distance.
163
+ sort_by (Optional[str]): The technique used to sort the final route matches before truncating.
175
164
176
165
Returns:
177
166
List[Dict[str, Any]]: The matching routes and their details.
178
167
"""
179
168
vector = self .vectorizer .embed (statement )
180
169
max_k = max_k if max_k is not None else self .routing_config .max_k
181
170
distance_threshold = distance_threshold if distance_threshold is not None else self .routing_config .distance_threshold
171
+ sort_by = RouteSortingMethod (sort_by ) if sort_by is not None else self .routing_config .sort_by
182
172
183
- # # get the total number of route references in the index
184
- # num_route_references = sum(
185
- # [len(route.references) for route in self.routes]
186
- # )
187
173
# define the baseline range query to fetch relevant route references
188
174
vector_range_query = RangeQuery (
189
175
vector = vector ,
@@ -198,42 +184,40 @@ def __call__(
198
184
AggregateRequest (aggregate_query )
199
185
.group_by (
200
186
"@route_name" ,
201
- reducers .avg ("vector_distance" ).alias ("avg " ),
202
- reducers .min ("vector_distance" ).alias ("score " )
187
+ reducers .avg ("vector_distance" ).alias ("avg_distance " ),
188
+ reducers .min ("vector_distance" ).alias ("min_distance " )
203
189
)
204
- .apply (avg_score = "1 - @avg" , score = "1 - @score" )
205
190
.dialect (2 )
206
191
)
207
192
208
- top_routes_and_scores = []
209
- aggregate_results = self ._index .client .ft (self ._index .name ).aggregate (aggregate_request , vector_range_query .params )
210
-
211
- for result in aggregate_results .rows :
212
- top_routes_and_scores .append (make_dict (convert_bytes (result )))
193
+ # run the aggregation query in Redis
194
+ aggregate_result : AggregateResult = (
195
+ self ._index .client
196
+ .ft (self ._index .name )
197
+ .aggregate (aggregate_request , vector_range_query .params )
198
+ )
213
199
214
- top_routes = self ._fetch_routes (top_routes_and_scores )
200
+ top_routes_and_scores = sorted ([
201
+ self ._process_result (result ) for result in aggregate_result .rows
202
+ ], key = lambda r : r [sort_by .value ])
215
203
216
- return top_routes
204
+ return top_routes_and_scores [: max_k ]
217
205
218
206
219
def _process_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
    """Process a single aggregation result row into a route match.

    Args:
        result: Aggregation query result row. It is normalized with
            ``convert_bytes`` and ``make_dict`` before any field is read.

    Returns:
        Dict[str, Any]: The matching route's fields (``route.dict()``)
            merged with ``avg_distance`` and ``min_distance`` as floats.

    Raises:
        ValueError: If the row names a route that is not present in
            ``self.routes``.
    """
    result_dict = make_dict(convert_bytes(result))
    route_name = result_dict["route_name"]
    route = next((r for r in self.routes if r.name == route_name), None)
    if route is None:
        # Without this guard a lookup miss surfaces as an opaque
        # AttributeError on ``None.dict()``; fail with a message that
        # names the offending route instead.
        raise ValueError(
            f"Route '{route_name}' was returned by the index but is not "
            "defined on this router"
        )
    return {
        **route.dict(),
        # Aggregation values are converted explicitly so callers can sort
        # and compare the distances numerically.
        "avg_distance": float(result_dict["avg_distance"]),
        "min_distance": float(result_dict["min_distance"]),
    }
0 commit comments