Skip to content

Commit 64cd6ad

Browse files
Add initial round of tests
1 parent d9766da commit 64cd6ad

File tree

6 files changed

+161
-350
lines changed

6 files changed

+161
-350
lines changed

redisvl/extensions/router/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from redisvl.extensions.router.semantic import SemanticRouter
2+
from redisvl.extensions.router.routes import Route, RoutingConfig
3+
4+
5+
__all__ = ["SemanticRouter", "Route", "RoutingConfig"]

redisvl/extensions/router/routes.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
2-
3-
from pydantic.v1 import BaseModel, Field, validator
4-
from typing import List, Dict, Optional
5-
61
from enum import Enum
2+
from pydantic.v1 import BaseModel, Field, validator
3+
from typing import List, Dict
74

85

96
class Route(BaseModel):
@@ -29,20 +26,18 @@ def references_must_not_be_empty(cls, v):
2926
return v
3027

3128

32-
class AccumulationMethod(Enum):
33-
simple = "simple" # Take the winner at face value
34-
avg = "avg" # Consider the avg score of all matches
35-
sum = "sum" # Consider the cumulative score of all matches
36-
auto = "auto" # Pick on the user's behalf?
29+
class RouteSortingMethod(Enum):
30+
avg_distance = "avg_distance"
31+
min_distance = "min_distance"
3732

3833

3934
class RoutingConfig(BaseModel):
40-
max_k: int = Field(default=1)
41-
"""The maximum number of top matches to return"""
4235
distance_threshold: float = Field(default=0.5)
43-
"""The threshold for semantic distance"""
44-
accumulation_method: AccumulationMethod = Field(default=AccumulationMethod.auto)
45-
"""The accumulation method used to determine the matching route"""
36+
"""The threshold for semantic distance."""
37+
max_k: int = Field(default=1)
38+
"""The maximum number of top matches to return."""
39+
sort_by: RouteSortingMethod = Field(default=RouteSortingMethod.avg_distance)
40+
"""The technique used to sort the final route matches before truncating."""
4641

4742
@validator('max_k')
4843
def max_k_must_be_positive(cls, v):

redisvl/extensions/router/semantic.py

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
from typing import Any, List, Dict, Optional, Union
33

44
from redis import Redis
5-
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
5+
from redis.commands.search.aggregation import AggregateRequest, AggregateResult
66
import redis.commands.search.reducers as reducers
77

88
from redisvl.index import SearchIndex
99
from redisvl.query import VectorQuery, RangeQuery
1010
from redisvl.schema import IndexSchema, IndexInfo
1111
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer
12-
from redisvl.extensions.router.routes import Route, RoutingConfig, AccumulationMethod
12+
from redisvl.extensions.router.routes import Route, RoutingConfig, RouteSortingMethod
1313

1414
from redisvl.redis.utils import make_dict, convert_bytes
1515

@@ -20,6 +20,9 @@ class SemanticRouterIndexSchema(IndexSchema):
2020

2121
@classmethod
2222
def from_params(cls, name: str, vector_dims: int):
23+
"""Load the semantic router index schema from the router name and
24+
vector dimensionality.
25+
"""
2326
return cls(
2427
index=IndexInfo(name=name, prefix=name),
2528
fields=[
@@ -50,7 +53,6 @@ class SemanticRouter(BaseModel):
5053
"""Configuration for routing behavior"""
5154

5255
_index: SearchIndex = PrivateAttr()
53-
# _accumulation_method: AccumulationMethod = PrivateAttr()
5456

5557
class Config:
5658
arbitrary_types_allowed = True
@@ -85,7 +87,6 @@ def __init__(
8587
routing_config=routing_config
8688
)
8789
self._initialize_index(redis_client, redis_url, overwrite)
88-
# self._accumulation_method = self._pick_accumulation_method()
8990

9091
def _initialize_index(
9192
self,
@@ -116,20 +117,6 @@ def _initialize_index(
116117
if not existed or overwrite:
117118
self._add_routes(self.routes)
118119

119-
# def _pick_accumulation_method(self) -> AccumulationMethod:
120-
# """Pick the accumulation method based on the routing configuration."""
121-
# if self.routing_config.accumulation_method != AccumulationMethod.auto:
122-
# return self.routing_config.accumulation_method
123-
124-
# num_route_references = [len(route.references) for route in self.routes]
125-
# avg_num_references = sum(num_route_references) / len(num_route_references)
126-
# variance = sum((x - avg_num_references) ** 2 for x in num_route_references) / len(num_route_references)
127-
128-
# if variance < 1: # TODO: Arbitrary threshold for low variance
129-
# return AccumulationMethod.sum
130-
# else:
131-
# return AccumulationMethod.avg
132-
133120
def update_routing_config(self, routing_config: RoutingConfig):
134121
"""Update the routing configuration.
135122
@@ -165,25 +152,24 @@ def __call__(
165152
statement: str,
166153
max_k: Optional[int] = None,
167154
distance_threshold: Optional[float] = None,
155+
sort_by: Optional[str] = None
168156
) -> List[Dict[str, Any]]:
169157
"""Query the semantic router with a given statement.
170158
171159
Args:
172160
statement (str): The input statement to be queried.
173161
max_k (Optional[int]): The maximum number of top matches to return.
174162
distance_threshold (Optional[float]): The threshold for semantic distance.
163+
sort_by (Optional[str]): The technique used to sort the final route matches before truncating.
175164
176165
Returns:
177166
List[Dict[str, Any]]: The matching routes and their details.
178167
"""
179168
vector = self.vectorizer.embed(statement)
180169
max_k = max_k if max_k is not None else self.routing_config.max_k
181170
distance_threshold = distance_threshold if distance_threshold is not None else self.routing_config.distance_threshold
171+
sort_by = RouteSortingMethod(sort_by) if sort_by is not None else self.routing_config.sort_by
182172

183-
# # get the total number of route references in the index
184-
# num_route_references = sum(
185-
# [len(route.references) for route in self.routes]
186-
# )
187173
# define the baseline range query to fetch relevant route references
188174
vector_range_query = RangeQuery(
189175
vector=vector,
@@ -198,42 +184,40 @@ def __call__(
198184
AggregateRequest(aggregate_query)
199185
.group_by(
200186
"@route_name",
201-
reducers.avg("vector_distance").alias("avg"),
202-
reducers.min("vector_distance").alias("score")
187+
reducers.avg("vector_distance").alias("avg_distance"),
188+
reducers.min("vector_distance").alias("min_distance")
203189
)
204-
.apply(avg_score="1 - @avg", score="1 - @score")
205190
.dialect(2)
206191
)
207192

208-
top_routes_and_scores = []
209-
aggregate_results = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params)
210-
211-
for result in aggregate_results.rows:
212-
top_routes_and_scores.append(make_dict(convert_bytes(result)))
193+
# run the aggregation query in Redis
194+
aggregate_result: AggregateResult = (
195+
self._index.client
196+
.ft(self._index.name)
197+
.aggregate(aggregate_request, vector_range_query.params)
198+
)
213199

214-
top_routes = self._fetch_routes(top_routes_and_scores)
200+
top_routes_and_scores = sorted([
201+
self._process_result(result) for result in aggregate_result.rows
202+
], key=lambda r: r[sort_by.value])
215203

216-
return top_routes
204+
return top_routes_and_scores[:max_k]
217205

218206

219-
def _fetch_routes(self, top_routes_and_scores: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
220-
"""Fetch route objects and metadata based on top matches.
207+
def _process_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
208+
"""Process resulting route objects and metadata.
221209
222210
Args:
223-
top_routes_and_scores: List of top routes and their scores.
211+
result: Aggregation query result object
224212
225213
Returns:
226214
List[Dict[str, Any]]: Routes with their metadata.
227215
"""
228-
results = []
229-
for route_info in top_routes_and_scores:
230-
route_name = route_info["route_name"]
231-
route = next((r for r in self.routes if r.name == route_name), None)
232-
if route:
233-
results.append({
234-
**route.dict(),
235-
"score": route_info["score"],
236-
"avg_score": route_info["avg_score"]
237-
})
238-
239-
return results
216+
result_dict = make_dict(convert_bytes(result))
217+
route_name = result_dict["route_name"]
218+
route = next((r for r in self.routes if r.name == route_name), None)
219+
return {
220+
**route.dict(),
221+
"avg_distance": float(result_dict["avg_distance"]),
222+
"min_distance": float(result_dict["min_distance"])
223+
}

0 commit comments

Comments
 (0)