Skip to content

Add calibration #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions engine/base_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

from benchmark import ROOT_DIR
from benchmark.dataset import Dataset
from dataset_reader.base_reader import BaseReader
from engine.base_client.configure import BaseConfigurator
from engine.base_client.distances import Distance
from engine.base_client.search import BaseSearcher
from engine.base_client.upload import BaseUploader

Expand Down Expand Up @@ -168,6 +170,23 @@ def run_experiment(
if filter_ef_runtime and isinstance(ef, int) and (ef not in ef_runtime):
print(f"\tSkipping ef runtime: {ef}; #clients {client_count} (not in ef_runtime filter)")
continue

if (precision := search_params.get("calibration_precision", None)) is not None:
top = search_params["top"]
calibration_param = search_params["calibration_param"]
calibration_value, calibration_precision = calibrate(
searcher,
calibration_param,
top,
precision,
dataset.config.distance,
reader,
)
print(
f"Calibrated {top=} {precision=} {calibration_value=} {calibration_precision=!s}"
)
searcher.search_params["search_params"][calibration_param] = calibration_value

for repetition in range(1, REPETITIONS + 1):
print(
f"\tRunning repetition {repetition} ef runtime: {ef}; #clients {client_count}"
Expand Down Expand Up @@ -196,3 +215,52 @@ def delete_client(self):

for s in self.searchers:
s.delete_client()

def calibrate(
    searcher: BaseSearcher,
    calibration_param: str,
    min_value: int,
    precision: float,
    distance: Distance,
    reader: BaseReader,
    max_value: int = 1000,
) -> tuple[int, float]:
    """Tune ``calibration_param`` so the measured precision approaches ``precision``.

    Bisects the integer range ``[min_value, max_value]``. Each probe writes the
    candidate into ``searcher.search_params["search_params"][calibration_param]``,
    runs ``searcher.search_all(distance, reader.read_queries())`` and reads
    ``"mean_precisions"`` from the returned stats. A probe above the target
    lowers the upper bound, a probe below it raises the lower bound, i.e. the
    search treats precision as growing with the parameter value.

    Args:
        searcher: Object exposing ``search_params`` and ``search_all``.
        calibration_param: Key of the search parameter being tuned.
        min_value: Smallest candidate value (inclusive).
        precision: Target mean precision.
        distance: Passed through unchanged to ``searcher.search_all``.
        reader: Source of queries; ``read_queries()`` is called once per probe.
        max_value: Upper end of the range. Because midpoints are floored it is
            only probed when it equals ``min_value``.

    Returns:
        ``(value, measured_precision)`` for the probed value whose measured
        precision is closest to the target. An exact hit returns immediately;
        on ties the most recent probe wins. The searcher is left configured
        with the returned value.

    Raises:
        ValueError: If ``min_value`` is greater than ``max_value``.
    """
    if min_value > max_value:
        raise ValueError(
            f"{min_value=} cannot be greater than {max_value=}"
        )

    lower_bound = min_value
    upper_bound = max_value
    # A bound counts as "visited" only after a probe landed on it; the initial
    # min/max have not been measured and must not stop the search.
    lower_bound_visited = False
    upper_bound_visited = False
    # Closest measured probe so far. Only real measurements are stored here, so
    # a placeholder precision can never leak into the result.
    best_value = None
    best_precision = None
    current = (lower_bound + upper_bound) // 2

    while True:
        searcher.search_params["search_params"][calibration_param] = current
        search_stats = searcher.search_all(distance, reader.read_queries())
        current_precision = search_stats["mean_precisions"]
        if current_precision == precision:
            return current, current_precision

        # `<=` keeps the later probe on ties.
        current_error = abs(current_precision - precision)
        if (
            best_precision is None
            or current_error <= abs(best_precision - precision)
        ):
            best_value = current
            best_precision = current_precision

        if current_precision > precision:
            upper_bound = current
            upper_bound_visited = True
        else:
            lower_bound = current
            lower_bound_visited = True

        next_value = (lower_bound + upper_bound) // 2
        # The midpoint collapsing onto an already-measured bound means there
        # is no untested value left between the bounds.
        if (lower_bound_visited and next_value == lower_bound) or (
            upper_bound_visited and next_value == upper_bound
        ):
            # The last probe is not necessarily the winner; leave the searcher
            # configured with the value that is actually returned.
            searcher.search_params["search_params"][calibration_param] = best_value
            return best_value, best_precision
        current = next_value