Skip to content

Commit 0b3a5ce

Browse files
authored
Rename for test data class for pytest conflict (#302)
1 parent 494e5e2 commit 0b3a5ce

File tree

7 files changed

+24
-21
lines changed

7 files changed

+24
-21
lines changed

redisvl/utils/optimize/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
22
from redisvl.utils.optimize.cache import CacheThresholdOptimizer
33
from redisvl.utils.optimize.router import RouterThresholdOptimizer
4-
from redisvl.utils.optimize.schema import TestData
4+
from redisvl.utils.optimize.schema import LabeledData
55

66
__all__ = [
77
"CacheThresholdOptimizer",
88
"RouterThresholdOptimizer",
99
"EvalMetric",
1010
"BaseThresholdOptimizer",
11-
"TestData",
11+
"LabeledData",
1212
]

redisvl/utils/optimize/cache.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
from redisvl.extensions.llmcache.semantic import SemanticCache
77
from redisvl.query import RangeQuery
88
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
9-
from redisvl.utils.optimize.schema import TestData
9+
from redisvl.utils.optimize.schema import LabeledData
1010
from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels
1111

1212

13-
def _generate_run_cache(test_data: List[TestData], threshold: float) -> Run:
13+
def _generate_run_cache(test_data: List[LabeledData], threshold: float) -> Run:
1414
"""Format observed data for evaluation with ranx"""
1515
run_dict: Dict[str, Dict[str, int]] = {}
1616

@@ -30,7 +30,7 @@ def _generate_run_cache(test_data: List[TestData], threshold: float) -> Run:
3030

3131

3232
def _eval_cache(
33-
test_data: List[TestData], threshold: float, qrels: Qrels, metric: str
33+
test_data: List[LabeledData], threshold: float, qrels: Qrels, metric: str
3434
) -> float:
3535
"""Formats run data and evaluates supported metric"""
3636
run = _generate_run_cache(test_data, threshold)
@@ -46,7 +46,7 @@ def _get_best_threshold(metrics: dict) -> float:
4646

4747

4848
def _grid_search_opt_cache(
49-
cache: SemanticCache, test_data: List[TestData], eval_metric: EvalMetric
49+
cache: SemanticCache, test_data: List[LabeledData], eval_metric: EvalMetric
5050
):
5151
"""Evaluates all thresholds in linspace for cache to determine optimal"""
5252
thresholds = np.linspace(0.01, 0.8, 60)

redisvl/utils/optimize/router.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
from redisvl.extensions.router.semantic import SemanticRouter
88
from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric
9-
from redisvl.utils.optimize.schema import TestData
9+
from redisvl.utils.optimize.schema import LabeledData
1010
from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels
1111

1212

13-
def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> Run:
13+
def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -> Run:
1414
"""Format router results into format for ranx Run"""
1515
run_dict: Dict[Any, Any] = {}
1616

@@ -26,7 +26,7 @@ def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> R
2626

2727

2828
def _eval_router(
29-
router: SemanticRouter, test_data: List[TestData], qrels: Qrels, eval_metric: str
29+
router: SemanticRouter, test_data: List[LabeledData], qrels: Qrels, eval_metric: str
3030
) -> float:
3131
"""Evaluate acceptable metric given run and qrels data"""
3232
run = _generate_run_router(test_data, router)
@@ -55,7 +55,7 @@ def _router_random_search(
5555

5656
def _random_search_opt_router(
5757
router: SemanticRouter,
58-
test_data: List[TestData],
58+
test_data: List[LabeledData],
5959
qrels: Qrels,
6060
eval_metric: EvalMetric,
6161
**kwargs: Any,
@@ -67,12 +67,15 @@ def _random_search_opt_router(
6767
best_thresholds = router.route_thresholds
6868

6969
max_iterations = kwargs.get("max_iterations", 20)
70+
search_step = kwargs.get("search_step", 0.10)
7071

7172
for _ in range(max_iterations):
7273
route_names = router.route_names
7374
route_thresholds = router.route_thresholds
7475
thresholds = _router_random_search(
75-
route_names=route_names, route_thresholds=route_thresholds
76+
route_names=route_names,
77+
route_thresholds=route_thresholds,
78+
search_step=search_step,
7679
)
7780
router.update_route_thresholds(thresholds)
7881
score = _eval_router(router, test_data, qrels, eval_metric.value)

redisvl/utils/optimize/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ulid import ULID
55

66

7-
class TestData(BaseModel):
7+
class LabeledData(BaseModel):
88
id: str = Field(default_factory=lambda: str(ULID()))
99
query: str
1010
query_match: Optional[str]

redisvl/utils/optimize/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from ranx import Qrels
44

5-
from redisvl.utils.optimize.schema import TestData
5+
from redisvl.utils.optimize.schema import LabeledData
66

77
NULL_RESPONSE_KEY = "no_match"
88

99

10-
def _format_qrels(test_data: List[TestData]) -> Qrels:
10+
def _format_qrels(test_data: List[LabeledData]) -> Qrels:
1111
"""Utility function for creating qrels for evaluation with ranx"""
1212
qrels_dict = {}
1313

@@ -21,6 +21,6 @@ def _format_qrels(test_data: List[TestData]) -> Qrels:
2121
return Qrels(qrels_dict)
2222

2323

24-
def _validate_test_dict(test_dict: List[dict]) -> List[TestData]:
24+
def _validate_test_dict(test_dict: List[dict]) -> List[LabeledData]:
2525
"""Convert/validate test_dict for use in optimizer"""
26-
return [TestData(**d) for d in test_dict]
26+
return [LabeledData(**d) for d in test_dict]

tests/integration/test_threshold_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def test_routes_different_distance_thresholds_optimizer_default(
111111

112112
# now run optimizer
113113
router_optimizer = RouterThresholdOptimizer(router, test_data_optimization)
114-
router_optimizer.optimize(max_iterations=10)
114+
router_optimizer.optimize(max_iterations=10, search_step=0.5)
115115

116116
# test that it updated thresholds beyond the null case
117117
for route in routes:

tests/unit/test_threshold_optimizer_utility.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from ranx import evaluate
99

10-
from redisvl.utils.optimize import TestData
10+
from redisvl.utils.optimize import LabeledData
1111
from redisvl.utils.optimize.cache import _generate_run_cache
1212
from redisvl.utils.optimize.utils import _format_qrels
1313

@@ -26,15 +26,15 @@ def test_known_precision_case():
2626
"""
2727
# Setup test data
2828
test_data = [
29-
TestData(
29+
LabeledData(
3030
query="test query 1",
3131
query_match="doc1",
3232
response=[
3333
{"id": "doc1", "vector_distance": 0.2},
3434
{"id": "doc2", "vector_distance": 0.3},
3535
],
3636
),
37-
TestData(
37+
LabeledData(
3838
query="test query 2",
3939
query_match="doc3",
4040
response=[
@@ -58,7 +58,7 @@ def test_known_precision_case():
5858
def test_known_precision_with_no_matches():
5959
"""Test case where some queries have no matches."""
6060
test_data = [
61-
TestData(
61+
LabeledData(
6262
query="test query 2",
6363
query_match="", # Expecting no match
6464
response=[],

0 commit comments

Comments
 (0)