Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions apps/build_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ namespace po = boost::program_options;
int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label,
label_type;
uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold;
label_type, seller_file;
uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold, num_diverse_build;
float B, M;
bool append_reorder_data = false;
bool use_opq = false;
bool diverse_index = false;


po::options_description desc{
program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")};
Expand Down Expand Up @@ -78,6 +80,11 @@ int main(int argc, char **argv)
"internally where each node has a maximum F labels.");
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
optional_configs.add_options()("seller_file", po::value<std::string>(&seller_file)->default_value(""),
"In case of diverse index, need the seller file");
optional_configs.add_options()("NumDiverse", po::value<uint32_t>(&num_diverse_build)->default_value(0),
program_options_utils::NUM_DIVERSE);


// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
Expand Down Expand Up @@ -138,23 +145,24 @@ int main(int argc, char **argv)
std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " +
std::string(std::to_string(append_reorder_data)) + " " +
std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD));

if(seller_file != "")
diverse_index = true;
try
{
if (label_file != "" && label_type == "ushort")
{
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
use_filters, label_file, universal_label, filter_threshold, Lf);
use_filters, label_file, universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float, uint16_t>(
data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix,
use_filters, label_file, universal_label, filter_threshold, Lf);
use_filters, label_file, universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build);
else
{
diskann::cerr << "Error. Unsupported data type" << std::endl;
Expand All @@ -166,15 +174,15 @@ int main(int argc, char **argv)
if (data_type == std::string("int8"))
return diskann::build_disk_index<int8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build);
else if (data_type == std::string("uint8"))
return diskann::build_disk_index<uint8_t>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build);
else if (data_type == std::string("float"))
return diskann::build_disk_index<float>(data_path.c_str(), index_path_prefix.c_str(), params.c_str(),
metric, use_opq, codebook_prefix, use_filters, label_file,
universal_label, filter_threshold, Lf);
universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build);
else
{
diskann::cerr << "Error. Unsupported data type" << std::endl;
Expand Down
19 changes: 17 additions & 2 deletions apps/build_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ namespace po = boost::program_options;

int main(int argc, char **argv)
{
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type;
uint32_t num_threads, R, L, Lf, build_PQ_bytes;
std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type, seller_file;
uint32_t num_threads, R, L, Lf, build_PQ_bytes, num_diverse_build;
float alpha;
bool use_pq_build, use_opq;
bool diverse_index=false;

po::options_description desc{
program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")};
Expand Down Expand Up @@ -70,6 +71,12 @@ int main(int argc, char **argv)
program_options_utils::FILTERED_LBUILD);
optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
optional_configs.add_options()("seller_file", po::value<std::string>(&seller_file)->default_value(""),
program_options_utils::DIVERSITY_FILE);
optional_configs.add_options()("NumDiverse", po::value<uint32_t>(&num_diverse_build)->default_value(1),
program_options_utils::NUM_DIVERSE);



// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
Expand Down Expand Up @@ -112,6 +119,9 @@ int main(int argc, char **argv)
return -1;
}

if(seller_file != "")
diverse_index = true;

try
{
diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha
Expand All @@ -120,11 +130,16 @@ int main(int argc, char **argv)
size_t data_num, data_dim;
diskann::get_bin_metadata(data_path, data_num, data_dim);

std::cout<<"Num diverse build: " << num_diverse_build << std::endl;

auto index_build_params = diskann::IndexWriteParametersBuilder(L, R)
.with_filter_list_size(Lf)
.with_alpha(alpha)
.with_saturate_graph(false)
.with_num_threads(num_threads)
.with_diverse_index(diverse_index)
.with_seller_file(seller_file)
.with_num_diverse_build(num_diverse_build)
.build();

auto filter_params = diskann::IndexFilterParamsBuilder()
Expand Down
24 changes: 14 additions & 10 deletions apps/search_disk_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth,
const uint32_t num_nodes_to_cache, const uint32_t search_io_limit,
const std::vector<uint32_t> &Lvec, const float fail_if_recall_below,
const std::vector<std::string> &query_filters, const bool use_reorder_data = false)
const std::vector<std::string> &query_filters, const bool use_reorder_data = false, const uint32_t max_K_per_seller = std::numeric_limits<uint32_t>::max())
{
diskann::cout << "Search parameters: #threads: " << num_threads << ", ";
if (beamwidth <= 0)
Expand Down Expand Up @@ -232,7 +232,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
_pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L,
query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at),
optimized_beamwidth, use_reorder_data, stats + i);
optimized_beamwidth, max_K_per_seller, use_reorder_data, stats + i);
}
else
{
Expand All @@ -247,7 +247,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre
}
_pFlashIndex->cached_beam_search(
query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at),
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search,
query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, std::numeric_limits<uint32_t>::max(),
use_reorder_data, stats + i);
}
}
Expand Down Expand Up @@ -314,7 +314,8 @@ int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label,
label_type, query_filters_file;
uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit;
uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit;
uint32_t max_K_per_seller = std::numeric_limits<uint32_t>::max();
std::vector<uint32_t> Lvec;
bool use_reorder_data = false;
float fail_if_recall_below = 0.0f;
Expand Down Expand Up @@ -372,6 +373,9 @@ int main(int argc, char **argv)
optional_configs.add_options()("fail_if_recall_below",
po::value<float>(&fail_if_recall_below)->default_value(0.0f),
program_options_utils::FAIL_IF_RECALL_BELOW);
optional_configs.add_options()("max_K_per_seller", po::value<uint32_t>(&max_K_per_seller)->default_value(std::numeric_limits<uint32_t>::max()),
"Diverse search, max number of results per seller");


// Merge required and optional parameters
desc.add(required_configs).add(optional_configs);
Expand Down Expand Up @@ -451,15 +455,15 @@ int main(int argc, char **argv)
if (data_type == std::string("float"))
return search_disk_index<float, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t, uint16_t>(
metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W,
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data);
num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
Expand All @@ -471,15 +475,15 @@ int main(int argc, char **argv)
if (data_type == std::string("float"))
return search_disk_index<float>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller);
else if (data_type == std::string("int8"))
return search_disk_index<int8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller);
else if (data_type == std::string("uint8"))
return search_disk_index<uint8_t>(metric, index_path_prefix, result_path_prefix, query_file, gt_file,
num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec,
fail_if_recall_below, query_filters, use_reorder_data);
fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller);
else
{
std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl;
Expand Down
Loading
Loading