diff --git a/engine/base_client/client.py b/engine/base_client/client.py index 879a6e7d..4f097bfc 100644 --- a/engine/base_client/client.py +++ b/engine/base_client/client.py @@ -7,7 +7,9 @@ from benchmark import ROOT_DIR from benchmark.dataset import Dataset +from dataset_reader.base_reader import BaseReader from engine.base_client.configure import BaseConfigurator +from engine.base_client.distances import Distance from engine.base_client.search import BaseSearcher from engine.base_client.upload import BaseUploader @@ -168,6 +170,23 @@ def run_experiment( if filter_ef_runtime and isinstance(ef, int) and (ef not in ef_runtime): print(f"\tSkipping ef runtime: {ef}; #clients {client_count} (not in ef_runtime filter)") continue + + if (precision := search_params.get("calibration_precision", None)) is not None: + top = search_params["top"] + calibration_param = search_params["calibration_param"] + calibration_value, calibration_precision = calibrate( + searcher, + calibration_param, + top, + precision, + dataset.config.distance, + reader, + ) + print( + f"Calibrated {top=} {precision=} {calibration_value=} {calibration_precision=!s}" + ) + searcher.search_params["search_params"][calibration_param] = calibration_value + for repetition in range(1, REPETITIONS + 1): print( f"\tRunning repetition {repetition} ef runtime: {ef}; #clients {client_count}" @@ -196,3 +215,52 @@ def delete_client(self): for s in self.searchers: s.delete_client() + +def calibrate( + searcher: BaseSearcher, + calibration_param: str, + min_value: int, + precision: float, + distance: Distance, + reader: BaseReader, + max_value: int = 1000, +) -> tuple[int, float]: + """Calibrate searcher for a given precision.""" + if min_value > max_value: + raise ValueError( + f"{min_value=} cannot be greater than {max_value=}" + ) + lower_bound = min_value + upper_bound = max_value + lower_bound_visited = False + upper_bound_visited = False + current = (lower_bound + upper_bound) // 2 + previous = current + current_precision = 0 + while True: + searcher.search_params["search_params"][calibration_param] = current + search_stats = searcher.search_all(distance, reader.read_queries()) + previous_precision = current_precision + current_precision = search_stats["mean_precisions"] + if current_precision == precision: + return current, current_precision + elif current_precision > precision: + upper_bound = current + upper_bound_visited = True + else: + lower_bound = current + lower_bound_visited = True + next_value = (lower_bound + upper_bound) // 2 + if ( + (lower_bound_visited and next_value == lower_bound) + or (upper_bound_visited and next_value == upper_bound) + ): + if abs(previous_precision - precision) < abs(current_precision - precision): + final_precision = previous_precision + final_value = previous + else: + final_precision = current_precision + final_value = current + return final_value, final_precision + previous = current + current = next_value