Skip to content

Commit eda1379

Browse files
formatting and linting updates
1 parent a8cd496 commit eda1379

File tree

5 files changed

+123
-56
lines changed

5 files changed

+123
-56
lines changed

redisvl/extensions/router/schema.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,9 @@ class RoutingConfig(BaseModel):
6464
"""The threshold for semantic distance."""
6565
max_k: int = Field(default=1)
6666
"""The maximum number of top matches to return."""
67-
aggregation_method: DistanceAggregationMethod = Field(default=DistanceAggregationMethod.avg)
67+
aggregation_method: DistanceAggregationMethod = Field(
68+
default=DistanceAggregationMethod.avg
69+
)
6870
"""Aggregation method to use to classify queries."""
6971

7072
@validator("max_k")

redisvl/extensions/router/semantic.py

Lines changed: 63 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11
import hashlib
2-
from typing import Any, Dict, List, Optional
2+
from typing import Any, Dict, List, Optional, Type
33

44
import redis.commands.search.reducers as reducers
55
from pydantic.v1 import BaseModel, Field, PrivateAttr
66
from redis import Redis
7-
from redis.commands.search.aggregation import AggregateRequest, AggregateResult
8-
9-
from redisvl.extensions.router.schema import Route, RoutingConfig, RouteMatch, DistanceAggregationMethod
7+
from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer
8+
9+
from redisvl.extensions.router.schema import (
10+
DistanceAggregationMethod,
11+
Route,
12+
RouteMatch,
13+
RoutingConfig,
14+
)
1015
from redisvl.index import SearchIndex
1116
from redisvl.query import RangeQuery
1217
from redisvl.redis.utils import convert_bytes, make_dict
@@ -18,7 +23,7 @@ class SemanticRouterIndexSchema(IndexSchema):
1823
"""Customized index schema for SemanticRouter."""
1924

2025
@classmethod
21-
def from_params(cls, name: str, vector_dims: int) -> 'SemanticRouterIndexSchema':
26+
def from_params(cls, name: str, vector_dims: int) -> "SemanticRouterIndexSchema":
2227
"""Create an index schema based on router name and vector dimensions.
2328
2429
Args:
@@ -30,7 +35,7 @@ def from_params(cls, name: str, vector_dims: int) -> 'SemanticRouterIndexSchema'
3035
"""
3136
return cls(
3237
index=IndexInfo(name=name, prefix=name),
33-
fields=[
38+
fields=[ # type: ignore
3439
{"name": "route_name", "type": "tag"},
3540
{"name": "reference", "type": "text"},
3641
{
@@ -87,7 +92,12 @@ def __init__(
8792
overwrite (bool, optional): Whether to overwrite existing index. Defaults to False.
8893
**kwargs: Additional arguments.
8994
"""
90-
super().__init__(name=name, routes=routes, vectorizer=vectorizer, routing_config=routing_config)
95+
super().__init__(
96+
name=name,
97+
routes=routes,
98+
vectorizer=vectorizer,
99+
routing_config=routing_config,
100+
)
91101
self._initialize_index(redis_client, redis_url, overwrite)
92102

93103
def _initialize_index(
@@ -130,7 +140,7 @@ def route_names(self) -> List[str]:
130140
return [route.name for route in self.routes]
131141

132142
@property
133-
def route_thresholds(self) -> Dict[str, float]:
143+
def route_thresholds(self) -> Dict[str, Optional[float]]:
134144
"""Get the distance thresholds for each route.
135145
136146
Returns:
@@ -168,7 +178,9 @@ def _add_routes(self, routes: List[Route]):
168178
}
169179
)
170180
reference_hash = hashlib.sha256(reference.encode("utf-8")).hexdigest()
171-
keys.append(f"{self._index.schema.index.prefix}:{route.name}:{reference_hash}")
181+
keys.append(
182+
f"{self._index.schema.index.prefix}:{route.name}:{reference_hash}"
183+
)
172184

173185
# set route if does not yet exist client side
174186
if not self.get(route.name):
@@ -204,7 +216,7 @@ def _build_aggregate_request(
204216
self,
205217
vector_range_query: RangeQuery,
206218
aggregation_method: DistanceAggregationMethod,
207-
max_k: int
219+
max_k: int,
208220
) -> AggregateRequest:
209221
"""Build the Redis aggregation request.
210222
@@ -216,6 +228,8 @@ def _build_aggregate_request(
216228
Returns:
217229
AggregateRequest: The constructed aggregation request.
218230
"""
231+
aggregation_func: Type[Reducer]
232+
219233
if aggregation_method == DistanceAggregationMethod.min:
220234
aggregation_func = reducers.min
221235
elif aggregation_method == DistanceAggregationMethod.sum:
@@ -226,7 +240,9 @@ def _build_aggregate_request(
226240
aggregate_query = str(vector_range_query).split(" RETURN")[0]
227241
aggregate_request = (
228242
AggregateRequest(aggregate_query)
229-
.group_by("@route_name", aggregation_func("vector_distance").alias("distance"))
243+
.group_by(
244+
"@route_name", aggregation_func("vector_distance").alias("distance")
245+
)
230246
.sort_by("@distance", max=max_k)
231247
.dialect(2)
232248
)
@@ -237,7 +253,7 @@ def _classify(
237253
self,
238254
vector: List[float],
239255
distance_threshold: float,
240-
aggregation_method: DistanceAggregationMethod
256+
aggregation_method: DistanceAggregationMethod,
241257
) -> List[RouteMatch]:
242258
"""Classify a single query vector.
243259
@@ -256,16 +272,20 @@ def _classify(
256272
return_fields=["route_name"],
257273
)
258274

259-
aggregate_request = self._build_aggregate_request(vector_range_query, aggregation_method, max_k=1)
260-
route_matches: AggregateResult = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params)
275+
aggregate_request = self._build_aggregate_request(
276+
vector_range_query, aggregation_method, max_k=1
277+
)
278+
route_matches: AggregateResult = self._index.client.ft( # type: ignore
279+
self._index.name
280+
).aggregate(aggregate_request, vector_range_query.params)
261281
return [self._process_route(route_match) for route_match in route_matches.rows]
262282

263283
def _classify_many(
264284
self,
265285
vector: List[float],
266286
max_k: int,
267287
distance_threshold: float,
268-
aggregation_method: DistanceAggregationMethod
288+
aggregation_method: DistanceAggregationMethod,
269289
) -> List[RouteMatch]:
270290
"""Classify multiple query vectors.
271291
@@ -284,8 +304,12 @@ def _classify_many(
284304
distance_threshold=distance_threshold,
285305
return_fields=["route_name"],
286306
)
287-
aggregate_request = self._build_aggregate_request(vector_range_query, aggregation_method, max_k)
288-
route_matches: AggregateResult = self._index.client.ft(self._index.name).aggregate(aggregate_request, vector_range_query.params)
307+
aggregate_request = self._build_aggregate_request(
308+
vector_range_query, aggregation_method, max_k
309+
)
310+
route_matches: AggregateResult = self._index.client.ft( # type: ignore
311+
self._index.name
312+
).aggregate(aggregate_request, vector_range_query.params)
289313
return [self._process_route(route_match) for route_match in route_matches.rows]
290314

291315
def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool:
@@ -297,7 +321,11 @@ def _pass_threshold(self, route_match: Optional[RouteMatch]) -> bool:
297321
Returns:
298322
bool: True if the route match passes the threshold, False otherwise.
299323
"""
300-
return route_match is not None and route_match.distance <= route_match.route.distance_threshold
324+
if route_match:
325+
if route_match.distance is not None and route_match.route is not None:
326+
if route_match.route.distance_threshold:
327+
return route_match.distance <= route_match.route.distance_threshold
328+
return False
301329

302330
def __call__(
303331
self,
@@ -320,8 +348,12 @@ def __call__(
320348
raise ValueError("Must provide a vector or statement to the router")
321349
vector = self.vectorizer.embed(statement)
322350

323-
distance_threshold = distance_threshold or self.routing_config.distance_threshold
324-
route_matches = self._classify(vector, distance_threshold, self.routing_config.aggregation_method)
351+
distance_threshold = (
352+
distance_threshold or self.routing_config.distance_threshold
353+
)
354+
route_matches = self._classify(
355+
vector, distance_threshold, self.routing_config.aggregation_method
356+
)
325357
route_match = route_matches[0] if route_matches else None
326358

327359
if route_match and self._pass_threshold(route_match):
@@ -352,8 +384,16 @@ def route_many(
352384
raise ValueError("Must provide a vector or statement to the router")
353385
vector = self.vectorizer.embed(statement)
354386

355-
distance_threshold = distance_threshold or self.routing_config.distance_threshold
387+
distance_threshold = (
388+
distance_threshold or self.routing_config.distance_threshold
389+
)
356390
max_k = max_k or self.routing_config.max_k
357-
route_matches = self._classify_many(vector, max_k, distance_threshold, self.routing_config.aggregation_method)
391+
route_matches = self._classify_many(
392+
vector, max_k, distance_threshold, self.routing_config.aggregation_method
393+
)
358394

359-
return [route_match for route_match in route_matches if self._pass_threshold(route_match)]
395+
return [
396+
route_match
397+
for route_match in route_matches
398+
if self._pass_threshold(route_match)
399+
]

redisvl/schema/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def validate_and_create_fields(cls, values):
195195
"""
196196
Validate uniqueness of field names and create valid field instances.
197197
"""
198-
# Ensure index is a dictionary for validation
198+
# Ensure index is a dictionary for validation
199199
index = values.get("index")
200200
if not isinstance(index, IndexInfo):
201201
index = IndexInfo(**index)

tests/integration/test_semantic_router.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
11
import pytest
22

3-
from redisvl.extensions.router.schema import Route, RoutingConfig
43
from redisvl.extensions.router import SemanticRouter
4+
from redisvl.extensions.router.schema import Route, RoutingConfig
55

66

77
@pytest.fixture
88
def routes():
99
return [
10-
Route(name="greeting", references=["hello", "hi"], metadata={"type": "greeting"}, distance_threshold=0.3),
11-
Route(name="farewell", references=["bye", "goodbye"], metadata={"type": "farewell"}, distance_threshold=0.3)
10+
Route(
11+
name="greeting",
12+
references=["hello", "hi"],
13+
metadata={"type": "greeting"},
14+
distance_threshold=0.3,
15+
),
16+
Route(
17+
name="farewell",
18+
references=["bye", "goodbye"],
19+
metadata={"type": "farewell"},
20+
distance_threshold=0.3,
21+
),
1222
]
1323

24+
1425
@pytest.fixture
1526
def semantic_router(client, routes):
1627
router = SemanticRouter(
1728
name="test-router",
1829
routes=routes,
1930
routing_config=RoutingConfig(distance_threshold=0.3, max_k=2),
2031
redis_client=client,
21-
overwrite=False
32+
overwrite=False,
2233
)
2334
yield router
2435
router._index.delete(drop=True)
@@ -70,6 +81,7 @@ def test_multiple_query(semantic_router):
7081
assert len(matches) > 0
7182
assert matches[0].route.name == "greeting"
7283

84+
7385
def test_update_routing_config(semantic_router):
7486
new_config = RoutingConfig(distance_threshold=0.5, max_k=1)
7587
semantic_router.update_routing_config(new_config)
@@ -85,7 +97,9 @@ def test_vector_query(semantic_router):
8597

8698

8799
def test_vector_query_no_match(semantic_router):
88-
vector = [0.0] * semantic_router.vectorizer.dims # Random vector unlikely to match any route
100+
vector = [
101+
0.0
102+
] * semantic_router.vectorizer.dims # Random vector unlikely to match any route
89103
match = semantic_router(vector=vector)
90104
assert match.route is None
91105

@@ -94,7 +108,11 @@ def test_additional_route(semantic_router):
94108
new_routes = [
95109
Route(
96110
name="politics",
97-
references=["are you liberal or conservative?", "who will you vote for?", "political speech"],
111+
references=[
112+
"are you liberal or conservative?",
113+
"who will you vote for?",
114+
"political speech",
115+
],
98116
metadata={"type": "greeting"},
99117
)
100118
]

0 commit comments

Comments
 (0)