diff --git a/redisvl/utils/optimize/__init__.py b/redisvl/utils/optimize/__init__.py index c420610f..8025228f 100644 --- a/redisvl/utils/optimize/__init__.py +++ b/redisvl/utils/optimize/__init__.py @@ -1,12 +1,12 @@ from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric from redisvl.utils.optimize.cache import CacheThresholdOptimizer from redisvl.utils.optimize.router import RouterThresholdOptimizer -from redisvl.utils.optimize.schema import TestData +from redisvl.utils.optimize.schema import LabeledData __all__ = [ "CacheThresholdOptimizer", "RouterThresholdOptimizer", "EvalMetric", "BaseThresholdOptimizer", - "TestData", + "LabeledData", ] diff --git a/redisvl/utils/optimize/cache.py b/redisvl/utils/optimize/cache.py index e787bc66..f88c53ff 100644 --- a/redisvl/utils/optimize/cache.py +++ b/redisvl/utils/optimize/cache.py @@ -6,11 +6,11 @@ from redisvl.extensions.llmcache.semantic import SemanticCache from redisvl.query import RangeQuery from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric -from redisvl.utils.optimize.schema import TestData +from redisvl.utils.optimize.schema import LabeledData from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels -def _generate_run_cache(test_data: List[TestData], threshold: float) -> Run: +def _generate_run_cache(test_data: List[LabeledData], threshold: float) -> Run: """Format observed data for evaluation with ranx""" run_dict: Dict[str, Dict[str, int]] = {} @@ -30,7 +30,7 @@ def _generate_run_cache(test_data: List[TestData], threshold: float) -> Run: def _eval_cache( - test_data: List[TestData], threshold: float, qrels: Qrels, metric: str + test_data: List[LabeledData], threshold: float, qrels: Qrels, metric: str ) -> float: """Formats run data and evaluates supported metric""" run = _generate_run_cache(test_data, threshold) @@ -46,7 +46,7 @@ def _get_best_threshold(metrics: dict) -> float: def _grid_search_opt_cache( - cache: SemanticCache, test_data: List[TestData], eval_metric: EvalMetric + cache: SemanticCache, test_data: List[LabeledData], eval_metric: EvalMetric ): """Evaluates all thresholds in linspace for cache to determine optimal""" thresholds = np.linspace(0.01, 0.8, 60) diff --git a/redisvl/utils/optimize/router.py b/redisvl/utils/optimize/router.py index 30e70004..40a7ea7d 100644 --- a/redisvl/utils/optimize/router.py +++ b/redisvl/utils/optimize/router.py @@ -6,11 +6,11 @@ from redisvl.extensions.router.semantic import SemanticRouter from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric -from redisvl.utils.optimize.schema import TestData +from redisvl.utils.optimize.schema import LabeledData from redisvl.utils.optimize.utils import NULL_RESPONSE_KEY, _format_qrels -def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> Run: +def _generate_run_router(test_data: List[LabeledData], router: SemanticRouter) -> Run: """Format router results into format for ranx Run""" run_dict: Dict[Any, Any] = {} @@ -26,7 +26,7 @@ def _generate_run_router(test_data: List[TestData], router: SemanticRouter) -> R def _eval_router( - router: SemanticRouter, test_data: List[TestData], qrels: Qrels, eval_metric: str + router: SemanticRouter, test_data: List[LabeledData], qrels: Qrels, eval_metric: str ) -> float: """Evaluate acceptable metric given run and qrels data""" run = _generate_run_router(test_data, router) @@ -55,7 +55,7 @@ def _router_random_search( def _random_search_opt_router( router: SemanticRouter, - test_data: List[TestData], + test_data: List[LabeledData], qrels: Qrels, eval_metric: EvalMetric, **kwargs: Any, @@ -67,12 +67,15 @@ def _random_search_opt_router( best_thresholds = router.route_thresholds max_iterations = kwargs.get("max_iterations", 20) + search_step = kwargs.get("search_step", 0.10) for _ in range(max_iterations): route_names = router.route_names route_thresholds = router.route_thresholds thresholds = _router_random_search( - route_names=route_names, route_thresholds=route_thresholds + route_names=route_names, + route_thresholds=route_thresholds, + search_step=search_step, ) router.update_route_thresholds(thresholds) score = _eval_router(router, test_data, qrels, eval_metric.value) diff --git a/redisvl/utils/optimize/schema.py b/redisvl/utils/optimize/schema.py index dfade3cc..f71d10f6 100644 --- a/redisvl/utils/optimize/schema.py +++ b/redisvl/utils/optimize/schema.py @@ -4,7 +4,7 @@ from ulid import ULID -class TestData(BaseModel): +class LabeledData(BaseModel): id: str = Field(default_factory=lambda: str(ULID())) query: str query_match: Optional[str] diff --git a/redisvl/utils/optimize/utils.py b/redisvl/utils/optimize/utils.py index 2a56504c..bebc1c79 100644 --- a/redisvl/utils/optimize/utils.py +++ b/redisvl/utils/optimize/utils.py @@ -2,12 +2,12 @@ from ranx import Qrels -from redisvl.utils.optimize.schema import TestData +from redisvl.utils.optimize.schema import LabeledData NULL_RESPONSE_KEY = "no_match" -def _format_qrels(test_data: List[TestData]) -> Qrels: +def _format_qrels(test_data: List[LabeledData]) -> Qrels: """Utility function for creating qrels for evaluation with ranx""" qrels_dict = {} @@ -21,6 +21,6 @@ def _format_qrels(test_data: List[TestData]) -> Qrels: return Qrels(qrels_dict) -def _validate_test_dict(test_dict: List[dict]) -> List[TestData]: +def _validate_test_dict(test_dict: List[dict]) -> List[LabeledData]: """Convert/validate test_dict for use in optimizer""" - return [TestData(**d) for d in test_dict] + return [LabeledData(**d) for d in test_dict] diff --git a/tests/integration/test_threshold_optimizer.py b/tests/integration/test_threshold_optimizer.py index 09a93dd3..44871901 100644 --- a/tests/integration/test_threshold_optimizer.py +++ b/tests/integration/test_threshold_optimizer.py @@ -111,7 +111,7 @@ def test_routes_different_distance_thresholds_optimizer_default( # now run optimizer router_optimizer = RouterThresholdOptimizer(router, test_data_optimization) - router_optimizer.optimize(max_iterations=10) + router_optimizer.optimize(max_iterations=10, search_step=0.5) # test that it updated thresholds beyond the null case for route in routes: diff --git a/tests/unit/test_threshold_optimizer_utility.py b/tests/unit/test_threshold_optimizer_utility.py index 9dc844c9..61010fc3 100644 --- a/tests/unit/test_threshold_optimizer_utility.py +++ b/tests/unit/test_threshold_optimizer_utility.py @@ -7,7 +7,7 @@ from ranx import evaluate -from redisvl.utils.optimize import TestData +from redisvl.utils.optimize import LabeledData from redisvl.utils.optimize.cache import _generate_run_cache from redisvl.utils.optimize.utils import _format_qrels @@ -26,7 +26,7 @@ def test_known_precision_case(): """ # Setup test data test_data = [ - TestData( + LabeledData( query="test query 1", query_match="doc1", response=[ @@ -34,7 +34,7 @@ def test_known_precision_case(): {"id": "doc2", "vector_distance": 0.3}, ], ), - TestData( + LabeledData( query="test query 2", query_match="doc3", response=[ @@ -58,7 +58,7 @@ def test_known_precision_case(): def test_known_precision_with_no_matches(): """Test case where some queries have no matches.""" test_data = [ - TestData( + LabeledData( query="test query 2", query_match="", # Expecting no match response=[],