Skip to content

Commit a8cd496

Browse files
update router and improve tests
1 parent 64cd6ad commit a8cd496

File tree

7 files changed

+522
-239
lines changed

7 files changed

+522
-239
lines changed

redisvl/extensions/router/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1+
from redisvl.extensions.router.schema import Route, RoutingConfig
12
from redisvl.extensions.router.semantic import SemanticRouter
2-
from redisvl.extensions.router.routes import Route, RoutingConfig
3-
43

54
__all__ = ["SemanticRouter", "Route", "RoutingConfig"]

redisvl/extensions/router/routes.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

redisvl/extensions/router/schema.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from enum import Enum
2+
from typing import Dict, List, Optional
3+
4+
from pydantic.v1 import BaseModel, Field, validator
5+
6+
7+
class Route(BaseModel):
8+
"""Model representing a routing path with associated metadata and thresholds."""
9+
10+
name: str
11+
"""The name of the route."""
12+
references: List[str]
13+
"""List of reference phrases for the route."""
14+
metadata: Dict[str, str] = Field(default={})
15+
"""Metadata associated with the route."""
16+
distance_threshold: Optional[float] = Field(default=None)
17+
"""Distance threshold for matching the route."""
18+
19+
@validator("name")
20+
def name_must_not_be_empty(cls, v):
21+
if not v or not v.strip():
22+
raise ValueError("Route name must not be empty")
23+
return v
24+
25+
@validator("references")
26+
def references_must_not_be_empty(cls, v):
27+
if not v:
28+
raise ValueError("References must not be empty")
29+
if any(not ref.strip() for ref in v):
30+
raise ValueError("All references must be non-empty strings")
31+
return v
32+
33+
@validator("distance_threshold")
34+
def distance_threshold_must_be_positive(cls, v):
35+
if v is not None and v <= 0:
36+
raise ValueError("Route distance threshold must be greater than zero")
37+
return v
38+
39+
40+
class RouteMatch(BaseModel):
41+
"""Model representing a matched route with distance information."""
42+
43+
route: Optional[Route] = None
44+
"""The matched route."""
45+
distance: Optional[float] = Field(default=None)
46+
"""The distance of the match."""
47+
48+
49+
class DistanceAggregationMethod(Enum):
50+
"""Enumeration for distance aggregation methods."""
51+
52+
avg = "avg"
53+
"""Compute the average of the vector distances."""
54+
min = "min"
55+
"""Compute the minimum of the vector distances."""
56+
sum = "sum"
57+
"""Compute the sum of the vector distances."""
58+
59+
60+
class RoutingConfig(BaseModel):
61+
"""Configuration for routing behavior."""
62+
63+
distance_threshold: float = Field(default=0.5)
64+
"""The threshold for semantic distance."""
65+
max_k: int = Field(default=1)
66+
"""The maximum number of top matches to return."""
67+
aggregation_method: DistanceAggregationMethod = Field(default=DistanceAggregationMethod.avg)
68+
"""Aggregation method to use to classify queries."""
69+
70+
@validator("max_k")
71+
def max_k_must_be_positive(cls, v):
72+
if v <= 0:
73+
raise ValueError("max_k must be a positive integer")
74+
return v
75+
76+
@validator("distance_threshold")
77+
def distance_threshold_must_be_valid(cls, v):
78+
if v <= 0 or v > 1:
79+
raise ValueError("distance_threshold must be between 0 and 1")
80+
return v

0 commit comments

Comments
 (0)