diff --git a/apps/build_disk_index.cpp b/apps/build_disk_index.cpp index f48b61726..c94df0866 100644 --- a/apps/build_disk_index.cpp +++ b/apps/build_disk_index.cpp @@ -16,11 +16,13 @@ namespace po = boost::program_options; int main(int argc, char **argv) { std::string data_type, dist_fn, data_path, index_path_prefix, codebook_prefix, label_file, universal_label, - label_type; - uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold; + label_type, seller_file; + uint32_t num_threads, R, L, disk_PQ, build_PQ, QD, Lf, filter_threshold, num_diverse_build; float B, M; bool append_reorder_data = false; bool use_opq = false; + bool diverse_index = false; + po::options_description desc{ program_options_utils::make_program_description("build_disk_index", "Build a disk-based index.")}; @@ -78,6 +80,11 @@ int main(int argc, char **argv) "internally where each node has a maximum F labels."); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()("seller_file", po::value(&seller_file)->default_value(""), + "In case of diverse index, need the seller file"); + optional_configs.add_options()("NumDiverse", po::value(&num_diverse_build)->default_value(0), + program_options_utils::NUM_DIVERSE); + // Merge required and optional parameters desc.add(required_configs).add(optional_configs); @@ -138,23 +145,24 @@ int main(int argc, char **argv) std::string(std::to_string(num_threads)) + " " + std::string(std::to_string(disk_PQ)) + " " + std::string(std::to_string(append_reorder_data)) + " " + std::string(std::to_string(build_PQ)) + " " + std::string(std::to_string(QD)); - + if(seller_file != "") + diverse_index = true; try { if (label_file != "" && label_type == "ushort") { if (data_type == std::string("int8")) - return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), - metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, codebook_prefix, use_filters, label_file, + universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else if (data_type == std::string("uint8")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); + use_filters, label_file, universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else if (data_type == std::string("float")) return diskann::build_disk_index( data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, - use_filters, label_file, universal_label, filter_threshold, Lf); + use_filters, label_file, universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else { diskann::cerr << "Error. Unsupported data type" << std::endl; @@ -166,15 +174,15 @@ int main(int argc, char **argv) if (data_type == std::string("int8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else if (data_type == std::string("uint8")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else if (data_type == std::string("float")) return diskann::build_disk_index(data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, use_opq, codebook_prefix, use_filters, label_file, - universal_label, filter_threshold, Lf); + universal_label, filter_threshold, Lf, diverse_index, seller_file, num_diverse_build); else { diskann::cerr << "Error. Unsupported data type" << std::endl; diff --git a/apps/build_memory_index.cpp b/apps/build_memory_index.cpp index 544e42dee..435c7f3a7 100644 --- a/apps/build_memory_index.cpp +++ b/apps/build_memory_index.cpp @@ -24,10 +24,11 @@ namespace po = boost::program_options; int main(int argc, char **argv) { - std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type; - uint32_t num_threads, R, L, Lf, build_PQ_bytes; + std::string data_type, dist_fn, data_path, index_path_prefix, label_file, universal_label, label_type, seller_file; + uint32_t num_threads, R, L, Lf, build_PQ_bytes, num_diverse_build; float alpha; bool use_pq_build, use_opq; + bool diverse_index=false; po::options_description desc{ program_options_utils::make_program_description("build_memory_index", "Build a memory-based DiskANN index.")}; @@ -70,6 +71,12 @@ int main(int argc, char **argv) program_options_utils::FILTERED_LBUILD); optional_configs.add_options()("label_type", po::value(&label_type)->default_value("uint"), program_options_utils::LABEL_TYPE_DESCRIPTION); + optional_configs.add_options()("seller_file", po::value(&seller_file)->default_value(""), + program_options_utils::DIVERSITY_FILE); + optional_configs.add_options()("NumDiverse", po::value(&num_diverse_build)->default_value(1), + program_options_utils::NUM_DIVERSE); + + // Merge required and optional parameters desc.add(required_configs).add(optional_configs); @@ -112,6 +119,9 @@ int main(int argc, char **argv) return -1; } + if(seller_file != "") + diverse_index = true; + try { diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha @@ -120,11 +130,16 @@ int main(int argc, char **argv) size_t data_num, data_dim; diskann::get_bin_metadata(data_path, data_num, data_dim); + std::cout<<"Num diverse build: " << num_diverse_build << std::endl; + auto index_build_params = diskann::IndexWriteParametersBuilder(L, R) .with_filter_list_size(Lf) .with_alpha(alpha) .with_saturate_graph(false) .with_num_threads(num_threads) + .with_diverse_index(diverse_index) + .with_seller_file(seller_file) + .with_num_diverse_build(num_diverse_build) .build(); auto filter_params = diskann::IndexFilterParamsBuilder() diff --git a/apps/search_disk_index.cpp b/apps/search_disk_index.cpp index 7e2a7ac6d..203079e8e 100644 --- a/apps/search_disk_index.cpp +++ b/apps/search_disk_index.cpp @@ -53,7 +53,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre const uint32_t num_threads, const uint32_t recall_at, const uint32_t beamwidth, const uint32_t num_nodes_to_cache, const uint32_t search_io_limit, const std::vector &Lvec, const float fail_if_recall_below, - const std::vector &query_filters, const bool use_reorder_data = false) + const std::vector &query_filters, const bool use_reorder_data = false, const uint32_t max_K_per_seller = std::numeric_limits::max()) { diskann::cout << "Search parameters: #threads: " << num_threads << ", "; if (beamwidth <= 0) @@ -232,7 +232,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre _pFlashIndex->cached_beam_search(query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), query_result_dists[test_id].data() + (i * recall_at), - optimized_beamwidth, use_reorder_data, stats + i); + optimized_beamwidth, max_K_per_seller, use_reorder_data, stats + i); } else { @@ -247,7 +247,7 @@ int search_disk_index(diskann::Metric &metric, const std::string &index_path_pre } _pFlashIndex->cached_beam_search( query + (i * query_aligned_dim), recall_at, L, query_result_ids_64.data() + (i * recall_at), - query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, + query_result_dists[test_id].data() + (i * recall_at), optimized_beamwidth, true, label_for_search, std::numeric_limits::max(), use_reorder_data, stats + i); } } @@ -314,7 +314,8 @@ int main(int argc, char **argv) { std::string data_type, dist_fn, index_path_prefix, result_path_prefix, query_file, gt_file, filter_label, label_type, query_filters_file; - uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; + uint32_t num_threads, K, W, num_nodes_to_cache, search_io_limit; + uint32_t max_K_per_seller = std::numeric_limits::max(); std::vector Lvec; bool use_reorder_data = false; float fail_if_recall_below = 0.0f; @@ -372,6 +373,9 @@ int main(int argc, char **argv) optional_configs.add_options()("fail_if_recall_below", po::value(&fail_if_recall_below)->default_value(0.0f), program_options_utils::FAIL_IF_RECALL_BELOW); + optional_configs.add_options()("max_K_per_seller", po::value(&max_K_per_seller)->default_value(std::numeric_limits::max()), + "Diverse search, max number of results per seller"); + // Merge required and optional parameters desc.add(required_configs).add(optional_configs); @@ -451,15 +455,15 @@ int main(int argc, char **argv) if (data_type == std::string("float")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("int8")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("uint8")) return search_disk_index( metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, - num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data); + num_nodes_to_cache, search_io_limit, Lvec, fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else { std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; @@ -471,15 +475,15 @@ int main(int argc, char **argv) if (data_type == std::string("float")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); + fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("int8")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); + fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else if (data_type == std::string("uint8")) return search_disk_index(metric, index_path_prefix, result_path_prefix, query_file, gt_file, num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - fail_if_recall_below, query_filters, use_reorder_data); + fail_if_recall_below, query_filters, use_reorder_data, max_K_per_seller); else { std::cerr << "Unsupported data type. Use float or int8 or uint8" << std::endl; diff --git a/apps/search_memory_index.cpp b/apps/search_memory_index.cpp index 1a9acc285..12fc918b5 100644 --- a/apps/search_memory_index.cpp +++ b/apps/search_memory_index.cpp @@ -25,13 +25,69 @@ namespace po = boost::program_options; + +void parse_seller_file(const std::string &label_file, size_t &num_points, std::vector &location_to_seller) +{ + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + uint32_t line_cnt = 0; + std::set sellers; + while (std::getline(infile, line)) + { + line_cnt++; + } + location_to_seller.resize(line_cnt); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + uint32_t seller; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + uint32_t token_as_num = (uint32_t)std::stoul(token); + seller = token_as_num; + sellers.insert(seller); + } + + location_to_seller[line_cnt] = seller; + line_cnt++; + } + num_points = (size_t)line_cnt; + diskann::cout << " Search code: Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl; +} + + + template int search_memory_index(diskann::Metric &metric, const std::string &index_path, const std::string &result_path_prefix, const std::string &query_file, const std::string &truthset_file, const uint32_t num_threads, const uint32_t recall_at, const bool print_all_recalls, const std::vector &Lvec, const bool dynamic, const bool tags, const bool show_qps_per_thread, - const std::vector &query_filters, const float fail_if_recall_below) + const std::vector &query_filters, const float fail_if_recall_below, const uint32_t max_K_per_seller = std::numeric_limits::max(), const bool diverse_search = false, const bool scale_seller_limits = false, const bool post_process = false) { + std::cout< location_to_sellers; + std::string seller_file = index_path +"_sellers.txt"; + if (file_exists(seller_file)) { + std::cout<<"Here" << std::endl; + uint64_t num_pts_seller_file; + parse_seller_file(seller_file, num_pts_seller_file, location_to_sellers); + } using TagT = uint32_t; // Load the query file T *query = nullptr; @@ -68,6 +124,7 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, } } + //query_num = 2; const size_t num_frozen_pts = diskann::get_graph_num_frozen_points(index_path); auto config = diskann::IndexConfigBuilder() @@ -153,10 +210,19 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, continue; } - query_result_ids[test_id].resize(recall_at * query_num); - query_result_dists[test_id].resize(recall_at * query_num); + query_result_ids[test_id].resize(recall_at * query_num, std::numeric_limits::max()); + query_result_dists[test_id].resize(recall_at * query_num, std::numeric_limits::max()); std::vector res = std::vector(); + //uint32_t maxLperSeller = (max_L_per_seller > 0) ? max_L_per_seller : L; + + //maxLperSeller = (maxLperSeller == 0)? 1 : maxLperSeller; + uint32_t maxLperSeller = max_K_per_seller; + if (diverse_search && scale_seller_limits) { + maxLperSeller = (1.0*L* max_K_per_seller)/(1.0*recall_at); + // std::cout< results(L,std::numeric_limits::max()); + std::vector dists(L,std::numeric_limits::max()); + uint32_t K_to_use = (post_process == true) ? L : recall_at; + + if (diverse_search) { + cmp_stats[i] = index - ->search(query + i * query_aligned_dim, recall_at, L, - query_result_ids[test_id].data() + i * recall_at) + ->diverse_search(query + i * query_aligned_dim, K_to_use, L, maxLperSeller, + results.data(), dists.data()) .second; + } else { +// { + cmp_stats[i] = index + ->search(query + i * query_aligned_dim, K_to_use, L, + results.data(), dists.data()) + .second; + } + if (post_process) { + diskann::bestCandidates final_results(recall_at, max_K_per_seller, location_to_sellers); + for (uint32_t rr = 0; rr < L; rr++) { + final_results.insert(results[rr], dists[rr]); + } + + for (uint32_t ctr = 0; ctr < std::min(final_results.best_L_nodes.size(), (uint64_t)recall_at); ctr++) { + query_result_ids[test_id][recall_at * i + ctr] = final_results.best_L_nodes._data[ctr].id; + query_result_dists[test_id][recall_at * i + ctr] = final_results.best_L_nodes._data[ctr].distance; + } + } else { + for (uint32_t ctr = 0; ctr < std::min(results.size(),(uint64_t)recall_at); ctr++) { + query_result_ids[test_id][recall_at * i + ctr] = results[ctr]; + query_result_dists[test_id][recall_at * i + ctr] = dists[ctr]; + } + } } auto qe = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = qe - qs; @@ -222,7 +317,10 @@ int search_memory_index(diskann::Metric &metric, const std::string &index_path, for (uint32_t curr_recall = first_recall; curr_recall <= recall_at; curr_recall++) { recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, - query_result_ids[test_id].data(), recall_at, curr_recall)); + query_result_ids[test_id].data(), recall_at, curr_recall, query_result_dists[test_id].data())); +// recalls.push_back(diskann::calculate_recall((uint32_t)query_num, gt_ids, gt_dists, (uint32_t)gt_dim, +// query_result_ids[test_id].data(), recall_at, curr_recall)); + } } @@ -279,9 +377,9 @@ int main(int argc, char **argv) { std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type, query_filters_file; - uint32_t num_threads, K; + uint32_t num_threads, K, max_L_per_seller; std::vector Lvec; - bool print_all_recalls, dynamic, tags, show_qps_per_thread; + bool print_all_recalls, dynamic, tags, show_qps_per_thread, post_process, diverse_search, scale_seller_limits; float fail_if_recall_below = 0.0f; po::options_description desc{ @@ -323,6 +421,20 @@ int main(int argc, char **argv) optional_configs.add_options()("num_threads,T", po::value(&num_threads)->default_value(omp_get_num_procs()), program_options_utils::NUMBER_THREADS_DESCRIPTION); + optional_configs.add_options()("max_K_per_seller", + po::value(&max_L_per_seller)->default_value(0), + "How many results per seller we want search results to contain"); + optional_configs.add_options()("diverse_search", + po::value(&diverse_search)->default_value(false), + "Whether to run diverse search or baseline search"); + optional_configs.add_options()("scale_seller_limits", + po::value(&scale_seller_limits)->default_value(false), + "Whether to run scale the max_L_per_seller based on the L value"); + optional_configs.add_options()("post_process", + po::value(&post_process)->default_value(false), + "Whether to post-processing to ensure correct diversity"); + + optional_configs.add_options()( "dynamic", po::value(&dynamic)->default_value(false), "Whether the index is dynamic. Dynamic indices must have associated tags. Default false."); @@ -421,19 +533,19 @@ int main(int argc, char **argv) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else if (data_type == std::string("uint8")) { return search_memory_index( metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, - Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below); + Lvec, dynamic, tags, show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else { @@ -447,19 +559,19 @@ int main(int argc, char **argv) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else if (data_type == std::string("uint8")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else if (data_type == std::string("float")) { return search_memory_index(metric, index_path_prefix, result_path, query_file, gt_file, num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, query_filters, fail_if_recall_below); + show_qps_per_thread, query_filters, fail_if_recall_below, max_L_per_seller, diverse_search, scale_seller_limits, post_process); } else { diff --git a/apps/utils/CMakeLists.txt b/apps/utils/CMakeLists.txt index 3b8cf223c..98edfce5d 100644 --- a/apps/utils/CMakeLists.txt +++ b/apps/utils/CMakeLists.txt @@ -51,6 +51,12 @@ add_executable(compute_groundtruth compute_groundtruth.cpp) target_include_directories(compute_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) target_link_libraries(compute_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options) + +# Compute ground truth thing outside of DiskANN main source that depends on MKL. +add_executable(compute_diverse_groundtruth compute_diverse_groundtruth.cpp) +target_include_directories(compute_diverse_groundtruth PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) +target_link_libraries(compute_diverse_groundtruth ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options) + add_executable(compute_groundtruth_for_filters compute_groundtruth_for_filters.cpp) target_include_directories(compute_groundtruth_for_filters PRIVATE ${DISKANN_MKL_INCLUDE_DIRECTORIES}) target_link_libraries(compute_groundtruth_for_filters ${PROJECT_NAME} ${DISKANN_MKL_LINK_LIBRARIES} ${DISKANN_ASYNC_LIB} Boost::program_options) diff --git a/apps/utils/compute_diverse_groundtruth.cpp b/apps/utils/compute_diverse_groundtruth.cpp new file mode 100644 index 000000000..48eac0a52 --- /dev/null +++ b/apps/utils/compute_diverse_groundtruth.cpp @@ -0,0 +1,681 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WINDOWS +#include +#else +#include +#endif +#include "filter_utils.h" +#include "utils.h" +#include "neighbor.h" + +// WORKS FOR UPTO 2 BILLION POINTS (as we use INT INSTEAD OF UNSIGNED) + +#define PARTSIZE 10000000 +#define ALIGNMENT 512 + +// custom types (for readability) +typedef tsl::robin_set label_set; +typedef std::string path; + +namespace po = boost::program_options; + +template T div_round_up(const T numerator, const T denominator) +{ + return (numerator % denominator == 0) ? (numerator / denominator) : 1 + (numerator / denominator); +} + +using pairIF = std::pair; +struct cmpmaxstruct +{ + bool operator()(const pairIF &l, const pairIF &r) + { + return l.second < r.second; + }; +}; + +using maxPQIFCS = std::priority_queue, cmpmaxstruct>; + +template T *aligned_malloc(const size_t n, const size_t alignment) +{ +#ifdef _WINDOWS + return (T *)_aligned_malloc(sizeof(T) * n, alignment); +#else + return static_cast(aligned_alloc(alignment, sizeof(T) * n)); +#endif +} + +inline bool custom_dist(const std::pair &a, const std::pair &b) +{ + return a.second < b.second; +} + +void compute_l2sq(float *const points_l2sq, const float *const matrix, const int64_t num_points, const uint64_t dim) +{ + assert(points_l2sq != NULL); +#pragma omp parallel for schedule(static, 65536) + for (int64_t d = 0; d < num_points; ++d) + points_l2sq[d] = cblas_sdot((int64_t)dim, matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1, + matrix + (ptrdiff_t)d * (ptrdiff_t)dim, 1); +} + +void distsq_to_points(const size_t dim, + float *dist_matrix, // Col Major, cols are queries, rows are points + size_t npoints, const float *const points, + const float *const points_l2sq, // points in Col major + size_t nqueries, const float *const queries, + const float *const queries_l2sq, // queries in Col major + float *ones_vec = NULL) // Scratchspace of num_data size and init to 1.0 +{ + bool ones_vec_alloc = false; + if (ones_vec == NULL) + { + ones_vec = new float[nqueries > npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-2.0, points, dim, queries, dim, + (float)0.0, dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, points_l2sq, npoints, + ones_vec, nqueries, (float)1.0, dist_matrix, npoints); + cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, npoints, nqueries, 1, (float)1.0, ones_vec, npoints, + queries_l2sq, nqueries, (float)1.0, dist_matrix, npoints); + if (ones_vec_alloc) + delete[] ones_vec; +} + +/* +struct bestCandidates { + diskann::NeighborPriorityQueue best_L_nodes; + tsl::robin_map color_to_nodes; + uint32_t _Lsize = 0; + uint32_t _maxLperSeller = 0; + std::vector &_location_to_seller; + + bestCandidates(uint32_t Lsize, uint32_t maxLperSeller, std::vector &location_to_seller) : _location_to_seller(location_to_seller) { + _Lsize = Lsize; + _maxLperSeller = maxLperSeller; + best_L_nodes = diskann::NeighborPriorityQueue(_Lsize); + } + void insert(uint32_t cur_id, float cur_dist) { + //std::cout< npoints ? nqueries : npoints]; + std::fill_n(ones_vec, nqueries > npoints ? nqueries : npoints, (float)1.0); + ones_vec_alloc = true; + } + cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, npoints, nqueries, dim, (float)-1.0, points, dim, queries, dim, + (float)0.0, dist_matrix, npoints); + + if (ones_vec_alloc) + delete[] ones_vec; +} + + + + + + + +void exact_knn(const size_t dim, const size_t k, + size_t *const closest_points, // k * num_queries preallocated, col + // major, queries columns + float *const dist_closest_points, // k * num_queries + // preallocated, Dist to + // corresponding closes_points + size_t npoints, size_t start_id, + float *points_in, // points in Col major + size_t nqueries, float *queries_in, uint32_t kPerSeller, std::vector &running_results, + diskann::Metric metric = diskann::Metric::L2) // queries in Col major +{ + float *points_l2sq = new float[npoints]; + float *queries_l2sq = new float[nqueries]; + compute_l2sq(points_l2sq, points_in, npoints, dim); + compute_l2sq(queries_l2sq, queries_in, nqueries, dim); + + float *points = points_in; + float *queries = queries_in; + + if (metric == diskann::Metric::COSINE) + { // we convert cosine distance as + // normalized L2 distnace + points = new float[npoints * dim]; + queries = new float[nqueries * dim]; +#pragma omp parallel for schedule(static, 4096) + for (int64_t i = 0; i < (int64_t)npoints; i++) + { + float norm = std::sqrt(points_l2sq[i]); + if (norm == 0) + { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) + { + points[i * dim + j] = points_in[i * dim + j] / norm; + } + } + +#pragma omp parallel for schedule(static, 4096) + for (int64_t i = 0; i < (int64_t)nqueries; i++) + { + float norm = std::sqrt(queries_l2sq[i]); + if (norm == 0) + { + norm = std::numeric_limits::epsilon(); + } + for (uint32_t j = 0; j < dim; j++) + { + queries[i * dim + j] = queries_in[i * dim + j] / norm; + } + } + // recalculate norms after normalizing, they should all be one. + compute_l2sq(points_l2sq, points, npoints, dim); + compute_l2sq(queries_l2sq, queries, nqueries, dim); + } + + std::cout << "Going to compute " << k << " NNs for " << nqueries << " queries over " << npoints << " points in " + << dim << " dimensions using"; + if (metric == diskann::Metric::INNER_PRODUCT) + std::cout << " MIPS "; + else if (metric == diskann::Metric::COSINE) + std::cout << " Cosine "; + else + std::cout << " L2 "; + std::cout << "distance fn. " << std::endl; + + size_t q_batch_size = (1 << 9); + float *dist_matrix = new float[(size_t)q_batch_size * (size_t)npoints]; + + for (size_t b = 0; b < div_round_up(nqueries, q_batch_size); ++b) + { + int64_t q_b = b * q_batch_size; + int64_t q_e = ((b + 1) * q_batch_size > nqueries) ? nqueries : (b + 1) * q_batch_size; + + if (metric == diskann::Metric::L2 || metric == diskann::Metric::COSINE) + { + distsq_to_points(dim, dist_matrix, npoints, points, points_l2sq, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim, queries_l2sq + q_b); + } + else + { + inner_prod_to_points(dim, dist_matrix, npoints, points, q_e - q_b, + queries + (ptrdiff_t)q_b * (ptrdiff_t)dim); + } + std::cout << "Computed distances for queries: [" << q_b << "," << q_e << ")" << std::endl; + +#pragma omp parallel for schedule(dynamic, 16) + for (long long q = q_b; q < q_e; q++) + { + diskann::bestCandidates & cur_query_best_results = running_results[q]; +// for (size_t p = 0; p < k; p++) +// point_dist.emplace(p, dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]); + for (size_t p = 0; p < npoints; p++) + { + uint32_t cur_id = p + start_id; + float cur_dist = dist_matrix[(ptrdiff_t)p + (ptrdiff_t)(q - q_b) * (ptrdiff_t)npoints]; + cur_query_best_results.insert(cur_id, cur_dist); + } + } + std::cout << "Computed exact k-NN for queries: [" << q_b << "," << q_e << ")" << std::endl; + } + + delete[] dist_matrix; + + delete[] points_l2sq; + delete[] queries_l2sq; + + if (metric == diskann::Metric::COSINE) + { + delete[] points; + delete[] queries; + } +} + +template inline int get_num_parts(const char *filename) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + std::cout << "#pts = " << npts_i32 << ", #dims = " << ndims_i32 << std::endl; + reader.close(); + uint32_t num_parts = + (npts_i32 % PARTSIZE) == 0 ? npts_i32 / PARTSIZE : (uint32_t)std::floor(npts_i32 / PARTSIZE) + 1; + std::cout << "Number of parts: " << num_parts << std::endl; + return num_parts; +} + +template +inline void load_bin_as_float(const char *filename, float *&data, size_t &npts, size_t &ndims, int part_num) +{ + std::ifstream reader; + reader.exceptions(std::ios::failbit | std::ios::badbit); + reader.open(filename, std::ios::binary); + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t)npts_i32); + npts = end_id - start_id; + ndims = (uint64_t)ndims_i32; + std::cout << "#pts in part = " << npts << ", #dims = " << ndims << ", size = " << npts * ndims * sizeof(T) << "B" + << std::endl; + + reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), std::ios::beg); + T *data_T = new T[npts * ndims]; + reader.read((char *)data_T, sizeof(T) * npts * ndims); + std::cout << "Finished reading part of the bin file." << std::endl; + reader.close(); + data = aligned_malloc(npts * ndims, ALIGNMENT); +#pragma omp parallel for schedule(dynamic, 32768) + for (int64_t i = 0; i < (int64_t)npts; i++) + { + for (int64_t j = 0; j < (int64_t)ndims; j++) + { + float cur_val_float = (float)data_T[i * ndims + j]; + std::memcpy((char *)(data + i * ndims + j), (char *)&cur_val_float, sizeof(float)); + } + } + delete[] data_T; + std::cout << "Finished converting part data to float." << std::endl; +} + +template inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims) +{ + std::ofstream writer; + writer.exceptions(std::ios::failbit | std::ios::badbit); + writer.open(filename, std::ios::binary | std::ios::out); + std::cout << "Writing bin: " << filename << "\n"; + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "bin: #pts = " << npts << ", #dims = " << ndims + << ", size = " << npts * ndims * sizeof(T) + 2 * sizeof(int) << "B" << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(T)); + writer.close(); + std::cout << "Finished writing bin" << std::endl; +} + +inline void save_groundtruth_as_one_file(const std::string filename, int32_t *data, float *distances, size_t npts, + size_t ndims) +{ + std::ofstream writer(filename, std::ios::binary | std::ios::out); + int npts_i32 = (int)npts, ndims_i32 = (int)ndims; + writer.write((char *)&npts_i32, sizeof(int)); + writer.write((char *)&ndims_i32, sizeof(int)); + std::cout << "Saving truthset in one file (npts, dim, npts*dim id-matrix, " + "npts*dim dist-matrix) with npts = " + << npts << ", dim = " << ndims << ", size = " << 2 * npts * ndims * sizeof(uint32_t) + 2 * sizeof(int) + << "B" << std::endl; + + writer.write((char *)data, npts * ndims * sizeof(uint32_t)); + writer.write((char *)distances, npts * ndims * sizeof(float)); + writer.close(); + std::cout << "Finished writing truthset" << std::endl; +} + +template +std::vector>> processUnfilteredParts(const std::string &base_file, + size_t &nqueries, size_t &npoints, + size_t &dim, size_t &k, float *query_data, + const diskann::Metric &metric, + std::vector &location_to_tag, std::vector &location_to_seller, uint32_t kperseller) +{ + float *base_data = nullptr; + int num_parts = get_num_parts(base_file.c_str()); + std::vector>> res(nqueries); + std::vector running_results(nqueries, diskann::bestCandidates(k, kperseller, location_to_seller)); + for (int p = 0; p < num_parts; p++) + { + size_t start_id = p * PARTSIZE; + load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); + + size_t *closest_points_part = new size_t[nqueries * k]; + float *dist_closest_points_part = new float[nqueries * k]; + + auto part_k = k < npoints ? k : npoints; + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, npoints, start_id, base_data, nqueries, query_data, kperseller, running_results, + metric); + + + delete[] closest_points_part; + delete[] dist_closest_points_part; + + diskann::aligned_free(base_data); + } + + for (size_t i = 0; i < nqueries; i++) + { + auto & cur_results = running_results[i]; +// for (size_t j = 0; j < part_k; j++) + for (uint32_t x = 0; x < cur_results.best_L_nodes.size(); x++) + { + auto &nbr = cur_results.best_L_nodes[x]; + if (!location_to_tag.empty()) + if (location_to_tag[nbr.id] == 0) + continue; + + res[i].push_back(std::make_pair((uint32_t)(nbr.id), + nbr.distance)); + } + } + + return res; +}; + + +void parse_seller_file(const std::string &label_file, size_t &num_points, std::vector &location_to_seller) +{ + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + uint32_t line_cnt = 0; + std::set sellers; + while (std::getline(infile, line)) + { + line_cnt++; + } + location_to_seller.resize(line_cnt); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + uint32_t seller; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + uint32_t token_as_num = (uint32_t)std::stoul(token); + seller = token_as_num; + sellers.insert(seller); + } + + location_to_seller[line_cnt] = seller; + line_cnt++; + } + num_points = (size_t)line_cnt; + diskann::cout << "Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl; +} + + +template +int aux_main(const std::string &base_file, const std::string &query_file, const std::string &seller_file, const std::string >_file, size_t k, size_t kperseller, + const diskann::Metric &metric, const std::string &tags_file = std::string("")) +{ + size_t npoints, nqueries, dim; + + float *query_data; + std::cout<<"Inside k=" << k <<", and kPerSeller=" << kperseller << std::endl; + load_bin_as_float(query_file.c_str(), query_data, nqueries, dim, 0); + if (nqueries > PARTSIZE) + std::cerr << "WARNING: #Queries provided (" << nqueries << ") is greater than " << PARTSIZE + << ". Computing GT only for the first " << PARTSIZE << " queries." << std::endl; + + // load tags + const bool tags_enabled = tags_file.empty() ? false : true; + std::vector location_to_tag = diskann::loadTags(tags_file, base_file); + + int *closest_points = new int[nqueries * k]; + float *dist_closest_points = new float[nqueries * k]; + + std::vector location_to_seller; + uint64_t num_pts_seller_file; + parse_seller_file(seller_file, num_pts_seller_file, location_to_seller); + + std::vector>> results = + processUnfilteredParts(base_file, nqueries, npoints, dim, k, query_data, metric, location_to_tag, location_to_seller, kperseller); + + for (size_t i = 0; i < nqueries; i++) + { + std::vector> &cur_res = results[i]; + std::sort(cur_res.begin(), cur_res.end(), custom_dist); + size_t j = 0; + for (auto iter : cur_res) + { + if (j == k) + break; + if (tags_enabled) + { + std::uint32_t index_with_tag = location_to_tag[iter.first]; + closest_points[i * k + j] = (int32_t)index_with_tag; + } + else + { + closest_points[i * k + j] = (int32_t)iter.first; + } + + if (metric == diskann::Metric::INNER_PRODUCT) + dist_closest_points[i * k + j] = -iter.second; + else + dist_closest_points[i * k + j] = iter.second; + + ++j; + } + if (j < k) { + std::cout << "WARNING: found less than k GT entries for query " << i << std::endl; + while (j::max(); + closest_points[i * k + j] = std::numeric_limits::max(); + j++; + } + } + } + + save_groundtruth_as_one_file(gt_file, closest_points, dist_closest_points, nqueries, k); + delete[] closest_points; + delete[] dist_closest_points; + diskann::aligned_free(query_data); + + return 0; +} + +void load_truthset(const std::string &bin_file, uint32_t *&ids, float *&dists, size_t &npts, size_t &dim) +{ + size_t read_blk_size = 64 * 1024 * 1024; + cached_ifstream reader(bin_file, read_blk_size); + diskann::cout << "Reading truthset file " << bin_file.c_str() << " ..." << std::endl; + size_t actual_file_size = reader.get_file_size(); + + int npts_i32, dim_i32; + reader.read((char *)&npts_i32, sizeof(int)); + reader.read((char *)&dim_i32, sizeof(int)); + npts = (uint32_t)npts_i32; + dim = (uint32_t)dim_i32; + + diskann::cout << "Metadata: #pts = " << npts << ", #dims = " << dim << "... " << std::endl; + + int truthset_type = -1; // 1 means truthset has ids and distances, 2 means + // only ids, -1 is error + size_t expected_file_size_with_dists = 2 * npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_with_dists) + truthset_type = 1; + + size_t expected_file_size_just_ids = npts * dim * sizeof(uint32_t) + 2 * sizeof(uint32_t); + + if (actual_file_size == expected_file_size_just_ids) + truthset_type = 2; + + if (truthset_type == -1) + { + std::stringstream stream; + stream << "Error. File size mismatch. File should have bin format, with " + "npts followed by ngt followed by npts*ngt ids and optionally " + "followed by npts*ngt distance values; actual size: " + << actual_file_size << ", expected: " << expected_file_size_with_dists << " or " + << expected_file_size_just_ids; + diskann::cout << stream.str(); + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); + } + + ids = new uint32_t[npts * dim]; + reader.read((char *)ids, npts * dim * sizeof(uint32_t)); + + if (truthset_type == 1) + { + dists = new float[npts * dim]; + reader.read((char *)dists, npts * dim * sizeof(float)); + } +} + +int main(int argc, char **argv) +{ + std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, seller_file; + uint64_t K, KperSeller; + + try + { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + + desc.add_options()("data_type", po::value(&data_type)->required(), "data type "); + desc.add_options()("dist_fn", po::value(&dist_fn)->required(), + "distance function "); + desc.add_options()("base_file", po::value(&base_file)->required(), + "File containing the base vectors in binary format"); + desc.add_options()("query_file", po::value(&query_file)->required(), + "File containing the query vectors in binary format"); + desc.add_options()("seller_file", po::value(&seller_file)->required(), + "File containing the seller per point"); + desc.add_options()("gt_file", po::value(>_file)->required(), + "File name for the writing ground truth in binary " + "format, please don' append .bin at end if " + "no filter_label or filter_label_file is provided it " + "will save the file with '.bin' at end." + "else it will save the file as filename_label.bin"); + desc.add_options()("K", po::value(&K)->required(), + "Number of ground truth nearest neighbors to compute"); + desc.add_options()("KperSeller", po::value(&KperSeller)->required(), + "Number of ground truth nearest neighbors to compute per Seller"); + desc.add_options()("tags_file", po::value(&tags_file)->default_value(std::string()), + "File containing the tags in binary format"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) + { + std::cout << desc; + return 0; + } + po::notify(vm); + } + catch (const std::exception &ex) + { + std::cerr << ex.what() << '\n'; + return -1; + } + + if (data_type != std::string("float") && data_type != std::string("int8") && data_type != std::string("uint8")) + { + std::cout << "Unsupported type. float, int8 and uint8 types are supported." << std::endl; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("l2")) + { + metric = diskann::Metric::L2; + } + else if (dist_fn == std::string("mips")) + { + metric = diskann::Metric::INNER_PRODUCT; + } + else if (dist_fn == std::string("cosine")) + { + metric = diskann::Metric::COSINE; + } + else + { + std::cerr << "Unsupported distance function. Use l2/mips/cosine." << std::endl; + return -1; + } + + try + { + if (data_type == std::string("float")) + aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller, metric, tags_file); + if (data_type == std::string("int8")) + aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller, metric, tags_file); + if (data_type == std::string("uint8")) + aux_main(base_file, query_file, seller_file, gt_file, K, KperSeller, metric, tags_file); + } + catch (const std::exception &e) + { + std::cout << std::string(e.what()) << std::endl; + diskann::cerr << "Compute GT failed." << std::endl; + return -1; + } +} diff --git a/apps/utils/generate_pq.cpp b/apps/utils/generate_pq.cpp index a881b1104..50540e0fb 100644 --- a/apps/utils/generate_pq.cpp +++ b/apps/utils/generate_pq.cpp @@ -31,7 +31,7 @@ bool generate_pq(const std::string &data_path, const std::string &index_prefix_p (uint32_t)num_pq_chunks, KMEANS_ITERS_FOR_PQ, pq_pivots_path); } diskann::generate_pq_data_from_pivots(data_path, (uint32_t)num_pq_centers, (uint32_t)num_pq_chunks, - pq_pivots_path, pq_compressed_vectors_path, true); + pq_pivots_path, pq_compressed_vectors_path, opq); delete[] train_data; diff --git a/include/abstract_index.h b/include/abstract_index.h index 059866f7c..c8b01105c 100644 --- a/include/abstract_index.h +++ b/include/abstract_index.h @@ -72,6 +72,10 @@ class AbstractIndex std::pair search(const data_type *query, const size_t K, const uint32_t L, IDType *indices, float *distances = nullptr); + template + std::pair diverse_search(const data_type *query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType *indices, + float *distances = nullptr); + // Filter support search // IndexType is either uint32_t or uint64_t template @@ -110,6 +114,8 @@ class AbstractIndex virtual void _build(const DataType &data, const size_t num_points_to_load, TagVector &tags) = 0; virtual std::pair _search(const DataType &query, const size_t K, const uint32_t L, std::any &indices, float *distances = nullptr) = 0; + virtual std::pair _diverse_search(const DataType &query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, + std::any &indices, float *distances = nullptr) = 0; virtual std::pair _search_with_filters(const DataType &query, const std::string &filter_label, const size_t K, const uint32_t L, std::any &indices, float *distances) = 0; diff --git a/include/defaults.h b/include/defaults.h index ef1750fcf..fc10c4fb2 100644 --- a/include/defaults.h +++ b/include/defaults.h @@ -30,5 +30,9 @@ const uint32_t MAX_DEGREE = 64; const uint32_t BUILD_LIST_SIZE = 100; const uint32_t SATURATE_GRAPH = false; const uint32_t SEARCH_LIST_SIZE = 100; + +const bool DIVERSE_INDEX = false; +const std::string EMPTY_STRING = ""; +const bool NUM_DIVERSE_BUILD = 1; } // namespace defaults } // namespace diskann diff --git a/include/disk_utils.h b/include/disk_utils.h index 08f046dcd..35e60dfea 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -82,7 +82,8 @@ DISKANN_DLLEXPORT int build_merged_vamana_index(std::string base_file, diskann:: uint32_t num_threads, bool use_filters = false, const std::string &label_file = std::string(""), const std::string &labels_to_medoids_file = std::string(""), - const std::string &universal_label = "", const uint32_t Lf = 0); + const std::string &universal_label = "", const uint32_t Lf = 0, + bool diverse_index = false, const std::string &seller_file = std::string(""), size_t num_diverse_build = 0) ; template DISKANN_DLLEXPORT uint32_t optimize_beamwidth(std::unique_ptr> &_pFlashIndex, @@ -98,7 +99,8 @@ DISKANN_DLLEXPORT int build_disk_index( bool use_filters = false, const std::string &label_file = std::string(""), // default is empty string for no label_file const std::string &universal_label = "", const uint32_t filter_threshold = 0, - const uint32_t Lf = 0); // default is empty string for no universal label + const uint32_t Lf = 0, // default is empty string for no universal label + bool diverse_index = false, const std::string &seller_file = std::string(""), const uint32_t num_diverse_build = 0); template DISKANN_DLLEXPORT void create_disk_layout(const std::string base_file, const std::string mem_index_file, diff --git a/include/index.h b/include/index.h index b9bf4f384..71a1c37f2 100644 --- a/include/index.h +++ b/include/index.h @@ -132,7 +132,11 @@ template clas // can customize L on a per-query basis without tampering with "Parameters" template DISKANN_DLLEXPORT std::pair search(const T *query, const size_t K, const uint32_t L, - IDType *indices, float *distances = nullptr); + IDType *indices, float *distances = nullptr, const uint32_t maxLperSeller = 0); + + template + std::pair diverse_search(const T *query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, IDType *indices, + float *distances = nullptr); // Initialize space for res_vectors before calling. DISKANN_DLLEXPORT size_t search_with_tags(const T *query, const uint64_t K, const uint32_t L, TagT *tags, @@ -210,6 +214,8 @@ template clas const std::string &filter_label_raw, const size_t K, const uint32_t L, std::any &indices, float *distances) override; + virtual std::pair _diverse_search(const DataType &query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, + std::any &indices, float *distances = nullptr) override; virtual int _insert_point(const DataType &data_point, const TagType tag) override; virtual int _insert_point(const DataType &data_point, const TagType tag, Labelvector &labels) override; @@ -248,6 +254,7 @@ template clas uint32_t calculate_entry_point(); void parse_label_file(const std::string &label_file, size_t &num_pts_labels); + void parse_seller_file(const std::string &label_file, size_t &num_pts_labels); std::unordered_map load_label_map(const std::string &map_file); @@ -258,7 +265,7 @@ template clas // The query to use is placed in scratch->aligned_query std::pair iterate_to_fixed_point(InMemQueryScratch *scratch, const uint32_t Lindex, const std::vector &init_ids, bool use_filter, - const std::vector &filters, bool search_invocation); + const std::vector &filters, bool search_invocation, uint32_t maxLperSeller = 0); void search_for_point_and_prune(int location, uint32_t Lindex, std::vector &pruned_list, InMemQueryScratch *scratch, bool use_filter = false, @@ -384,6 +391,13 @@ template clas std::unordered_map _label_to_start_id; std::unordered_map _medoid_counts; + bool _diverse_index = false; + uint32_t _num_diverse_build =1; + uint32_t _max_L_per_seller = 0; + std::vector _location_to_seller; + std::string _seller_file; + + bool _use_universal_label = false; LabelT _universal_label = 0; uint32_t _filterIndexingQueueSize; diff --git a/include/neighbor.h b/include/neighbor.h index d7c0c25ed..5cd852f4f 100644 --- a/include/neighbor.h +++ b/include/neighbor.h @@ -6,6 +6,7 @@ #include #include #include +#include #include "utils.h" namespace diskann @@ -43,10 +44,14 @@ class NeighborPriorityQueue { } - explicit NeighborPriorityQueue(size_t capacity) : _size(0), _capacity(capacity), _cur(0), _data(capacity + 1) + explicit NeighborPriorityQueue(size_t capacity) : _size(0), _capacity(capacity), _cur(0), _data(capacity + 1, Neighbor(std::numeric_limits::max(), std::numeric_limits::max())) { } + void setup(uint32_t capacity) { + _data.resize(capacity+1,Neighbor(std::numeric_limits::max(), std::numeric_limits::max())); + _capacity = capacity; + } // Inserts the item ordered into the set up to the sets capacity. // The item will be dropped if it is the same id as an exiting // set item or it has a greated distance than the final @@ -75,7 +80,7 @@ class NeighborPriorityQueue else { lo = mid + 1; - } + } } if (lo < _capacity) @@ -93,6 +98,64 @@ class NeighborPriorityQueue } } + + // Deletes the item if found. + void delete_id(const Neighbor &nbr) + { + size_t lo = 0, hi = _size; + size_t loc = std::numeric_limits::max(); + while ((lo < hi) && loc == std::numeric_limits::max()) + { + size_t mid = (lo + hi) >> 1; + if (nbr.distance < _data[mid].distance) + { + hi = mid; + } + else if (nbr.distance > _data[mid].distance) + { + lo = mid+1; + } + else + { + uint32_t itr = 0; + for (;; itr++) { + if (mid + itr < hi) { + if (_data[mid+itr].id == nbr.id) { + loc = mid+itr; + break; + } + } + if(mid - itr >= lo) { + if (_data[mid-itr].id == nbr.id) { + loc = mid-itr; + break; + } + } + } + } + } + + if (loc != std::numeric_limits::max()) + { + std::memmove(&_data[loc], &_data[loc+1], (_size - loc - 1) * sizeof(Neighbor)); + _size--; + _cur = 0; + while (_cur < _size && _data[_cur].expanded) // RK: inefficient! + { + _cur++; + } + } else { + std::cout<<"Found a problem! " << lo <<" " << hi <<" " <_data) { + std::cout< _data; }; + +struct bestCandidates { + NeighborPriorityQueue best_L_nodes; + tsl::robin_map color_to_nodes; + uint32_t _Lsize = 0; + uint32_t _maxLperSeller = 0; + std::vector &_location_to_seller; + + bestCandidates(std::vector &location_to_seller) : _location_to_seller(location_to_seller) { + } + + bestCandidates(uint32_t Lsize, uint32_t maxLperSeller, std::vector &location_to_seller) : _location_to_seller(location_to_seller) { + _Lsize = Lsize; + _maxLperSeller = maxLperSeller; + best_L_nodes = NeighborPriorityQueue(_Lsize); + } + + void clear() { + best_L_nodes.clear(); + color_to_nodes.clear(); + } + + void setup(uint32_t Lsize, uint32_t maxLperSeller) { + _Lsize = Lsize; + _maxLperSeller = maxLperSeller; + best_L_nodes = NeighborPriorityQueue(_Lsize); + } + + void insert(uint32_t cur_id, float cur_dist) { + //std::cout< class PQFlashIndex const bool shuffle = false); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, - uint64_t *res_ids, float *res_dists, const uint64_t beam_width, + uint64_t *res_ids, float *res_dists, const uint64_t beam_width, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const LabelT &filter_label, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, - const uint32_t io_limit, const bool use_reorder_data = false, + const uint32_t io_limit, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT void cached_beam_search(const T *query, const uint64_t k_search, const uint64_t l_search, uint64_t *res_ids, float *res_dists, const uint64_t beam_width, const bool use_filter, const LabelT &filter_label, - const uint32_t io_limit, const bool use_reorder_data = false, + const uint32_t io_limit, const uint32_t max_l_per_seller = std::numeric_limits::max(), const bool use_reorder_data = false, QueryStats *stats = nullptr); DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &filter_label); @@ -118,10 +118,14 @@ template class PQFlashIndex DISKANN_DLLEXPORT inline bool point_has_label(uint32_t point_id, LabelT label_id); std::unordered_map load_label_map(std::basic_istream &infile); DISKANN_DLLEXPORT void parse_label_file(std::basic_istream &infile, size_t &num_pts_labels); + DISKANN_DLLEXPORT void get_label_file_metadata(const std::string &fileContent, uint32_t &num_pts, uint32_t &num_total_labels); DISKANN_DLLEXPORT void generate_random_labels(std::vector &labels, const uint32_t num_labels, const uint32_t nthreads); + + DISKANN_DLLEXPORT void parse_seller_file(const std::string &label_file, size_t &num_pts_labels); + void reset_stream_for_reading(std::basic_istream &infile); // sector # on disk where node_id is present with in the graph part @@ -234,6 +238,13 @@ template class PQFlashIndex tsl::robin_map> _real_to_dummy_map; std::unordered_map _label_map; + + bool _diverse_index = false; + uint32_t _max_L_per_seller = 0; + std::vector _location_to_seller; + std::string _seller_file; + + #ifdef EXEC_ENV_OLS // Set to a larger value than the actual header to accommodate // any additions we make to the header. This is an outer limit diff --git a/include/program_options_utils.hpp b/include/program_options_utils.hpp index 2be60595b..383bdadaa 100644 --- a/include/program_options_utils.hpp +++ b/include/program_options_utils.hpp @@ -68,6 +68,7 @@ const char *GRAPH_BUILD_ALPHA = "Alpha controls density and diameter of graph, s "denser graphs with lower diameter"; const char *BUIlD_GRAPH_PQ_BYTES = "Number of PQ bytes to build the index; 0 for full precision build"; const char *USE_OPQ = "Use Optimized Product Quantization (OPQ)."; +const char *DIVERSE_INDEX = "Build Diverse Index"; const char *LABEL_FILE = "Input label file in txt format for Filtered Index build. The file should contain comma " "separated filters for each node with each line corresponding to a graph node"; const char *UNIVERSAL_LABEL = @@ -77,5 +78,7 @@ const char *UNIVERSAL_LABEL = "in the labels file instead of listing all labels for a node. DiskANN will not automatically assign a " "universal label to a node."; const char *FILTERED_LBUILD = "Build complexity for filtered points, higher value results in better graphs"; +const char *DIVERSITY_FILE = "Seller diversity file for diverse index"; +const char *NUM_DIVERSE = "Number of diverse edges needed per node in each local region"; } // namespace program_options_utils diff --git a/include/scratch.h b/include/scratch.h index 2f43e3365..af1c6e421 100644 --- a/include/scratch.h +++ b/include/scratch.h @@ -29,7 +29,7 @@ template class InMemQueryScratch : public AbstractScratch public: ~InMemQueryScratch(); InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, size_t aligned_dim, - size_t alignment_factor, bool init_pq_scratch = false); + size_t alignment_factor, std::vector &location_to_sellers, bool init_pq_scratch = false); void resize_for_new_L(uint32_t new_search_l); void clear(); @@ -61,6 +61,10 @@ template class InMemQueryScratch : public AbstractScratch { return _best_l_nodes; } + inline bestCandidates &best_diverse_nodes() + { + return _best_diverse_nodes; + } inline std::vector &occlude_factor() { return _occlude_factor; @@ -107,6 +111,7 @@ template class InMemQueryScratch : public AbstractScratch // _best_l_nodes is reserved for storing best L entries // Underlying storage is L+1 to support inserts NeighborPriorityQueue _best_l_nodes; + bestCandidates _best_diverse_nodes; // _occlude_factor.size() >= pool.size() in occlude_list function // _pool is clipped to maxc in occlude_list before affecting _occlude_factor @@ -149,8 +154,9 @@ template class SSDQueryScratch : public AbstractScratch tsl::robin_set visited; NeighborPriorityQueue retset; std::vector full_retset; + bestCandidates best_diverse_nodes; - SSDQueryScratch(size_t aligned_dim, size_t visited_reserve); + SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers); ~SSDQueryScratch(); void reset(); @@ -162,7 +168,7 @@ template class SSDThreadData SSDQueryScratch scratch; IOContext ctx; - SSDThreadData(size_t aligned_dim, size_t visited_reserve); + SSDThreadData(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers); void clear(); }; diff --git a/include/utils.h b/include/utils.h index d3af5c3a9..532bffd87 100644 --- a/include/utils.h +++ b/include/utils.h @@ -673,7 +673,7 @@ inline void copy_file(std::string in_file, std::string out_file) } DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs, - unsigned *our_results, unsigned dim_or, unsigned recall_at); + unsigned *our_results, unsigned dim_or, unsigned recall_at, float* algo_distances = nullptr); DISKANN_DLLEXPORT double calculate_recall(unsigned num_queries, unsigned *gold_std, float *gs_dist, unsigned dim_gs, unsigned *our_results, unsigned dim_or, unsigned recall_at, diff --git a/python/src/static_disk_index.cpp b/python/src/static_disk_index.cpp index 9e86b0ad5..6bf307e28 100644 --- a/python/src/static_disk_index.cpp +++ b/python/src/static_disk_index.cpp @@ -65,7 +65,7 @@ NeighborsAndDistances StaticDiskIndex
::search( std::vector u64_ids(knn); diskann::QueryStats stats; - _index.cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(), beam_width, false, + _index.cached_beam_search(query.data(), knn, complexity, u64_ids.data(), dists.mutable_data(), beam_width, std::numeric_limits::max(), false, &stats); auto r = ids.mutable_unchecked<1>(); diff --git a/src/abstract_index.cpp b/src/abstract_index.cpp index 92665825f..fd2aafa20 100644 --- a/src/abstract_index.cpp +++ b/src/abstract_index.cpp @@ -22,6 +22,15 @@ std::pair AbstractIndex::search(const data_type *query, cons return _search(any_query, K, L, any_indices, distances); } +template +std::pair AbstractIndex::diverse_search(const data_type *query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, + IDType *indices, float *distances) +{ + auto any_indices = std::any(indices); + auto any_query = std::any(query); + return _diverse_search(any_query, K, L, maxLperSeller, any_indices, distances); +} + template size_t AbstractIndex::search_with_tags(const data_type *query, const uint64_t K, const uint32_t L, tag_type *tags, float *distances, std::vector &res_vectors, bool use_filters, @@ -155,6 +164,22 @@ template DISKANN_DLLEXPORT std::pair AbstractIndex::search AbstractIndex::search( const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const float *query, const size_t K, const uint32_t L, const uint32_t maxL, uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const uint8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint32_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const int8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint32_t *indices, float *distances); + +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const float *query, const size_t K, const uint32_t L, const uint32_t maxL,uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const uint8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint64_t *indices, float *distances); +template DISKANN_DLLEXPORT std::pair AbstractIndex::diverse_search( + const int8_t *query, const size_t K, const uint32_t L, const uint32_t maxL,uint64_t *indices, float *distances); + + + template DISKANN_DLLEXPORT std::pair AbstractIndex::search_with_filters( const DataType &query, const std::string &raw_label, const size_t K, const uint32_t L, uint32_t *indices, float *distances); diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 016560217..f1af2abd9 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -630,7 +630,7 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::string medoids_file, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, const std::string &labels_to_medoids_file, const std::string &universal_label, - const uint32_t Lf) + const uint32_t Lf, bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build) { size_t base_num, base_dim; diskann::get_bin_metadata(base_file, base_num, base_dim); @@ -647,6 +647,9 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr .with_filter_list_size(Lf) .with_saturate_graph(!use_filters) .with_num_threads(num_threads) + .with_diverse_index(diverse_index) + .with_seller_file(seller_file) + .with_num_diverse_build(num_diverse_build) .build(); using TagT = uint32_t; diskann::Index _index(compareMetric, base_dim, base_num, @@ -706,15 +709,24 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::string shard_ids_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_ids_uint32.bin"; std::string shard_labels_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt"; + std::string shard_sellers_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_sellers.txt"; retrieve_shard_data_from_ids(base_file, shard_ids_file, shard_base_file); std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; + if(diverse_index) + { + diskann::extract_shard_labels(seller_file, shard_ids_file, shard_sellers_file); + } + diskann::IndexWriteParameters low_degree_params = diskann::IndexWriteParametersBuilder(L, 2 * R / 3) .with_filter_list_size(Lf) .with_saturate_graph(false) .with_num_threads(num_threads) + .with_diverse_index(diverse_index) + .with_seller_file(shard_sellers_file) + .with_num_diverse_build(num_diverse_build) .build(); uint64_t shard_base_dim, shard_base_pts; @@ -724,8 +736,9 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::make_shared(low_degree_params), nullptr, defaults::NUM_FROZEN_POINTS_STATIC, false, false, false, build_pq_bytes > 0, build_pq_bytes, use_opq); + if (!use_filters) - { + { _index.build(shard_base_file.c_str(), shard_base_pts); } else @@ -736,7 +749,9 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr LabelT unv_label_as_num = 0; _index.set_universal_label(unv_label_as_num); } + _index.build_filtered_index(shard_base_file.c_str(), shard_labels_file, shard_base_pts); + } _index.save(shard_index_file.c_str()); // copy universal label file from first shard to the final destination @@ -768,11 +783,13 @@ int build_merged_vamana_index(std::string base_file, diskann::Metric compareMetr std::string shard_labels_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_labels.txt"; std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; std::string shard_index_file_data = shard_index_file + ".data"; + std::string shard_sellers_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_sellers.txt"; std::remove(shard_base_file.c_str()); std::remove(shard_id_file.c_str()); std::remove(shard_index_file.c_str()); std::remove(shard_index_file_data.c_str()); + std::remove(shard_sellers_file.c_str()); if (use_filters) { std::string shard_index_label_file = shard_index_file + "_labels.txt"; @@ -813,7 +830,7 @@ uint32_t optimize_beamwidth(std::unique_ptr> &p { pFlashIndex->cached_beam_search(tuning_sample + (i * tuning_sample_aligned_dim), 1, L, tuning_sample_result_ids_64.data() + (i * 1), - tuning_sample_result_dists.data() + (i * 1), cur_bw, false, stats + i); + tuning_sample_result_dists.data() + (i * 1), cur_bw, std::numeric_limits::max(), false, stats + i); } auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; @@ -1101,7 +1118,7 @@ template int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, const uint32_t filter_threshold, - const uint32_t Lf) + const uint32_t Lf, bool diverse_index, const std::string &seller_file, const uint32_t num_diverse_build) { std::stringstream parser; parser << std::string(indexBuildParameters); @@ -1326,7 +1343,7 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const diskann::build_merged_vamana_index(data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path, build_pq_bytes, use_opq, num_threads, use_filters, labels_file_to_use, - labels_to_medoids_path, universal_label, Lf); + labels_to_medoids_path, universal_label, Lf, diverse_index, seller_file, num_diverse_build); diskann::cout << timer.elapsed_seconds_for_step("building merged vamana index") << std::endl; timer.reset(); @@ -1368,6 +1385,11 @@ int build_disk_index(const char *dataFilePath, const char *indexFilePath, const if (use_disk_pq) std::remove(disk_pq_compressed_vectors_path.c_str()); + if (seller_file != "") { + std::string disk_index_seller_file = disk_index_path + "_sellers.txt"; + copy_file(seller_file, disk_index_seller_file); + } + auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; diskann::cout << "Indexing time: " << diff.count() << std::endl; @@ -1432,21 +1454,24 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); // LabelT = uint16 template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, @@ -1454,51 +1479,61 @@ template DISKANN_DLLEXPORT int build_disk_index(const char *da const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, bool use_opq, const std::string &codebook_prefix, bool use_filters, const std::string &label_file, const std::string &universal_label, - const uint32_t filter_threshold, const uint32_t Lf); + const uint32_t filter_threshold, const uint32_t Lf, bool diverse_index, + const std::string &seller_file, const uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); + template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); // Label=16_t template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, uint32_t L, uint32_t R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, std::string centroids_file, size_t build_pq_bytes, bool use_opq, uint32_t num_threads, bool use_filters, const std::string &label_file, - const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf); + const std::string &labels_to_medoids_file, const std::string &universal_label, const uint32_t Lf, + bool diverse_index, const std::string &seller_file, uint32_t num_diverse_build); }; // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index bf93344fa..f12415dcf 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -4,7 +4,7 @@ #include #include - +#include #include "boost/dynamic_bitset.hpp" #include "index_factory.h" #include "memory_mapper.h" @@ -101,6 +101,10 @@ Index::Index(const IndexConfig &index_config, std::shared_ptrfilter_list_size; _indexingThreads = index_config.index_write_params->num_threads; _saturate_graph = index_config.index_write_params->saturate_graph; + _diverse_index = index_config.index_write_params->diversity_index; + _seller_file = index_config.index_write_params->base_seller_labels; + _num_diverse_build = index_config.index_write_params->num_diverse_sellers; + std::cout<<"Set _num_diverse_build to " << _num_diverse_build << std::endl; if (index_config.index_search_params != nullptr) { @@ -110,7 +114,7 @@ Index::Index(const IndexConfig &index_config, std::shared_ptr Index::Index(Metric m, const size_t dim, const size_t max_points, const std::shared_ptr index_parameters, @@ -188,7 +192,7 @@ void Index::initialize_query_scratch(uint32_t num_threads, uint for (uint32_t i = 0; i < num_threads; i++) { auto scratch = new InMemQueryScratch(search_l, indexing_l, r, maxc, dim, _data_store->get_aligned_dim(), - _data_store->get_alignment_factor(), _pq_dist); + _data_store->get_alignment_factor(), _location_to_seller, _pq_dist); _query_scratch.push(scratch); } } @@ -290,6 +294,12 @@ void Index::save(const char *filename, bool compact_before_save if (!_save_as_one_file) { + if(_diverse_index) { + std::string index_seller_file = std::string(filename) + "_sellers.txt"; + std::filesystem::copy(_seller_file, index_seller_file); + std::cout<<"Saved seller file to " << index_seller_file <<"." << std::endl; + } + if (_filtered_index) { if (_label_to_start_id.size() > 0) @@ -588,6 +598,13 @@ void Index::load(const char *filename, uint32_t num_threads, ui throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, __LINE__); } + std::string index_seller_file = std::string(filename) + "_sellers.txt"; + if(file_exists(index_seller_file)) { + uint64_t nrows_seller_file; + parse_seller_file(index_seller_file, nrows_seller_file); + _diverse_index = true; + } + if (file_exists(labels_file)) { _label_map = load_label_map(labels_map_file); @@ -791,11 +808,26 @@ bool Index::detect_common_filters(uint32_t point_id, bool searc template std::pair Index::iterate_to_fixed_point( InMemQueryScratch *scratch, const uint32_t Lsize, const std::vector &init_ids, bool use_filter, - const std::vector &filter_labels, bool search_invocation) + const std::vector &filter_labels, bool search_invocation, uint32_t maxLperSeller) { + bool diverse_search = false; + if (maxLperSeller == 0) + maxLperSeller = Lsize; + else + diverse_search = true; std::vector &expanded_nodes = scratch->pool(); - NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); - best_L_nodes.reserve(Lsize); + NeighborPriorityQueue &best_L_nodes_ref = scratch->best_l_nodes(); + bestCandidates &best_diverse_nodes_ref = scratch->best_diverse_nodes(); + best_L_nodes_ref.reserve(Lsize); + best_diverse_nodes_ref.setup(Lsize, maxLperSeller); + + NeighborPriorityQueue* best_L_nodes; + if(diverse_search) { + best_L_nodes = &(best_diverse_nodes_ref.best_L_nodes); + } else { + best_L_nodes = &(best_L_nodes_ref); + } + tsl::robin_set &inserted_into_pool_rs = scratch->inserted_into_pool_rs(); boost::dynamic_bitset<> &inserted_into_pool_bs = scratch->inserted_into_pool_bs(); std::vector &id_scratch = scratch->id_scratch(); @@ -873,17 +905,22 @@ std::pair Index::iterate_to_fixed_point( distance = distances[0]; Neighbor nn = Neighbor(id, distance); - best_L_nodes.insert(nn); + if (diverse_search) { + best_diverse_nodes_ref.insert(id, distance); + } else { + best_L_nodes->insert(nn); + } } } uint32_t hops = 0; uint32_t cmps = 0; - while (best_L_nodes.has_unexpanded_node()) + while (best_L_nodes->has_unexpanded_node()) { - auto nbr = best_L_nodes.closest_unexpanded(); + auto nbr = best_L_nodes->closest_unexpanded(); auto n = nbr.id; +// std::cout< Index::iterate_to_fixed_point( if (is_not_visited(id)) { id_scratch.push_back(id); - } - } - } - - // Mark nodes visited - for (auto id : id_scratch) - { if (fast_iterate) { inserted_into_pool_bs[id] = 1; @@ -960,8 +990,12 @@ std::pair Index::iterate_to_fixed_point( { inserted_into_pool_rs.insert(id); } + + } + } } + assert(dist_scratch.capacity() >= id_scratch.size()); compute_dists(id_scratch, dist_scratch); cmps += (uint32_t)id_scratch.size(); @@ -969,7 +1003,22 @@ std::pair Index::iterate_to_fixed_point( // Insert pairs into the pool of candidates for (size_t m = 0; m < id_scratch.size(); ++m) { - best_L_nodes.insert(Neighbor(id_scratch[m], dist_scratch[m])); +/* std::cout<<"Going to insert " << id_scratch[m] << " (nbr of " << n <<"), color " << _location_to_seller[id_scratch[m]] << std::endl; + for (auto &x : color_to_nodes) { + if (x.second.size() > 0) { + std::cout<insert(Neighbor(id_scratch[m], dist_scratch[m])); + } } } return std::make_pair(hops, cmps); @@ -984,10 +1033,15 @@ void Index::search_for_point_and_prune(int location, uint32_t L const std::vector init_ids = get_init_ids(); const std::vector unused_filter_label; + uint32_t maxLperSeller = 0; + if (_diverse_index) { + maxLperSeller = (Lindex/_num_diverse_build > 0)? Lindex/_num_diverse_build : 1; + } + if (!use_filter) { _data_store->get_vector(location, scratch->aligned_query()); - iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false); + iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false, maxLperSeller); } else { @@ -1003,7 +1057,7 @@ void Index::search_for_point_and_prune(int location, uint32_t L _data_store->get_vector(location, scratch->aligned_query()); iterate_to_fixed_point(scratch, filteredLindex, filter_specific_start_nodes, true, - _location_to_labels[location], false); + _location_to_labels[location], false, maxLperSeller); // combine candidate pools obtained with filter and unfiltered criteria. std::set best_candidate_pool; @@ -1016,7 +1070,7 @@ void Index::search_for_point_and_prune(int location, uint32_t L scratch->clear(); _data_store->get_vector(location, scratch->aligned_query()); - iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false); + iterate_to_fixed_point(scratch, Lindex, init_ids, false, unused_filter_label, false, maxLperSeller); for (auto unfiltered_neighbour : scratch->pool()) { @@ -1053,6 +1107,7 @@ void Index::search_for_point_and_prune(int location, uint32_t L assert(_graph_store->get_total_points() == _max_points + _num_frozen_pts); } +/* template void Index::occlude_list(const uint32_t location, std::vector &pool, const float alpha, const uint32_t degree, const uint32_t maxc, std::vector &result, @@ -1148,6 +1203,132 @@ void Index::occlude_list(const uint32_t location, std::vector +void Index::occlude_list(const uint32_t location, std::vector &pool, const float alpha, + const uint32_t degree, const uint32_t maxc, std::vector &result, + InMemQueryScratch *scratch, + const tsl::robin_set *const delete_set_ptr) +{ + if (pool.size() == 0) + return; + + // Truncate pool at maxc and initialize scratch spaces + assert(std::is_sorted(pool.begin(), pool.end())); + assert(result.size() == 0); + if (pool.size() > maxc) + pool.resize(maxc); + std::vector &occlude_factor = scratch->occlude_factor(); + std::vector> blockers(pool.size()); + // occlude_list can be called with the same scratch more than once by + // search_for_point_and_add_link through inter_insert. + occlude_factor.clear(); + // Initialize occlude_factor to pool.size() many 0.0f values for correctness + occlude_factor.insert(occlude_factor.end(), pool.size(), 0.0f); + float cur_alpha = 1; + while (cur_alpha <= alpha && result.size() < degree) + { + std::vector> blockers(pool.size()); + // used for MIPS, where we store a value of eps in cur_alpha to + // denote pruned out entries which we can skip in later rounds. + float eps = cur_alpha + 0.01f; + for (auto iter = pool.begin(); result.size() < degree && iter != pool.end(); ++iter) + { + bool need_to_add_edge= true; + bool edge_added = false; + if (occlude_factor[iter - pool.begin()] == std::numeric_limits::min()) { + need_to_add_edge = false; // added as an edge in earlier round + edge_added =true; + } + if (occlude_factor[iter - pool.begin()] > cur_alpha) + { + if (_diverse_index) { + if (blockers[iter - pool.begin()].size() >= _num_diverse_build) + need_to_add_edge = false; + else if (blockers[iter - pool.begin()].find(_location_to_seller[iter->id]) != blockers[iter - pool.begin()].end()) + need_to_add_edge = false; + } else { + need_to_add_edge = false; + } + } + + // Set the entry to float::max so that is not considered again, similarly add its own color as a blocking color +// blockers[iter - pool.begin()].insert(_location_to_seller[iter->id]); + + if (need_to_add_edge) { + occlude_factor[iter - pool.begin()] = std::numeric_limits::min(); + // Add the entry to the result if its not been deleted, and doesn't + // add a self loop + if (delete_set_ptr == nullptr || delete_set_ptr->find(iter->id) == delete_set_ptr->end()) + { + if (iter->id != location) + { + result.push_back(iter->id); + } + } + } + + if (need_to_add_edge || edge_added) { + // Update occlude factor for points from iter+1 to pool.end() + for (auto iter2 = iter + 1; iter2 != pool.end(); iter2++) + { + auto t = iter2 - pool.begin(); +// if (occlude_factor[t] > alpha) +// continue; + + bool prune_allowed = true; + if (_filtered_index) + { + uint32_t a = iter->id; + uint32_t b = iter2->id; + if (_location_to_labels.size() < b || _location_to_labels.size() < a) + continue; + for (auto &x : _location_to_labels[b]) + { + if (std::find(_location_to_labels[a].begin(), _location_to_labels[a].end(), x) == + _location_to_labels[a].end()) + { + prune_allowed = false; + } + if (!prune_allowed) + break; + } + } + if (!prune_allowed) + continue; + + float djk = _data_store->get_distance(iter2->id, iter->id); + if (_dist_metric == diskann::Metric::L2 || _dist_metric == diskann::Metric::COSINE) + { + occlude_factor[t] = (djk == 0) ? std::numeric_limits::max() + : std::max(occlude_factor[t], iter2->distance / djk); + if (_diverse_index) { + if (iter2->distance / djk > cur_alpha) { + blockers[t].insert(_location_to_seller[iter->id]); + } + } + } + else if (_dist_metric == diskann::Metric::INNER_PRODUCT) + { + // Improvization for flipping max and min dist for MIPS + float x = -iter2->distance; + float y = -djk; + if (y > cur_alpha * x) + { + occlude_factor[t] = std::max(occlude_factor[t], eps); + if (_diverse_index) + blockers[t].insert(_location_to_seller[iter->id]); + } + } + } + } + } + cur_alpha *= 1.2f; + } +} + template void Index::prune_neighbors(const uint32_t location, std::vector &pool, @@ -1530,6 +1711,12 @@ void Index::build_with_data_populated(const std::vector & } } + if (_diverse_index) { + uint64_t nrows; + parse_seller_file(_seller_file, nrows); + std::cout<<"Parsed seller file with " << nrows <<" rows" << std::endl; + } + uint32_t index_R = _indexingRange; uint32_t num_threads_index = _indexingThreads; uint32_t index_L = _indexingQueueSize; @@ -1842,6 +2029,53 @@ void Index::parse_label_file(const std::string &label_file, siz diskann::cout << "Identified " << _labels.size() << " distinct label(s)" << std::endl; } + +template +void Index::parse_seller_file(const std::string &label_file, size_t &num_points) +{ + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + uint32_t line_cnt = 0; + std::set sellers; + while (std::getline(infile, line)) + { + line_cnt++; + } + _location_to_seller.resize(line_cnt); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + uint32_t seller; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + uint32_t token_as_num = (uint32_t)std::stoul(token); + seller = token_as_num; + sellers.insert(seller); + } + + _location_to_seller[line_cnt] = seller; + line_cnt++; + } + num_points = (size_t)line_cnt; + diskann::cout << "Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl; +} + template void Index::_set_universal_label(const LabelType universal_label) { @@ -1955,10 +2189,46 @@ std::pair Index::_search(const DataType &qu } } + +template +std::pair Index::_diverse_search(const DataType &query, const size_t K, const uint32_t L, const uint32_t maxLperSeller, + std::any &indices, float *distances) +{ + try + { + auto typed_query = std::any_cast(query); + if (typeid(uint32_t *) == indices.type()) + { + auto u32_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u32_ptr, distances, maxLperSeller); + } + else if (typeid(uint64_t *) == indices.type()) + { + auto u64_ptr = std::any_cast(indices); + return this->search(typed_query, K, L, u64_ptr, distances, maxLperSeller); + } + else + { + throw ANNException("Error: indices type can only be uint64_t or uint32_t.", -1); + } + } + catch (const std::bad_any_cast &e) + { + throw ANNException("Error: bad any cast while searching. " + std::string(e.what()), -1); + } + catch (const std::exception &e) + { + throw ANNException("Error: " + std::string(e.what()), -1); + } +} + + + + template template std::pair Index::search(const T *query, const size_t K, const uint32_t L, - IdType *indices, float *distances) + IdType *indices, float *distances, const uint32_t maxLperSeller) { if (K > (uint64_t)L) { @@ -1983,26 +2253,31 @@ std::pair Index::search(const T *query, con _data_store->preprocess_query(query, scratch); - auto retval = iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true); + auto retval = iterate_to_fixed_point(scratch, L, init_ids, false, unused_filter_label, true, maxLperSeller); - NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); + NeighborPriorityQueue *best_L_nodes; + if (maxLperSeller == 0) { + best_L_nodes = &(scratch->best_l_nodes()); + } else { + best_L_nodes = &(scratch->best_diverse_nodes().best_L_nodes); + } size_t pos = 0; - for (size_t i = 0; i < best_L_nodes.size(); ++i) + for (size_t i = 0; i < best_L_nodes->size(); ++i) { - if (best_L_nodes[i].id < _max_points) + if ((*best_L_nodes)[i].id < _max_points) { // safe because Index uses uint32_t ids internally // and IDType will be uint32_t or uint64_t - indices[pos] = (IdType)best_L_nodes[i].id; + indices[pos] = (IdType)(*best_L_nodes)[i].id; if (distances != nullptr) { #ifdef EXEC_ENV_OLS // DLVS expects negative distances distances[pos] = best_L_nodes[i].distance; #else - distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * best_L_nodes[i].distance - : best_L_nodes[i].distance; + distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT ? -1 * ((*best_L_nodes)[i].distance) + : (*best_L_nodes)[i].distance; #endif } pos++; @@ -2010,9 +2285,11 @@ std::pair Index::search(const T *query, con if (pos == K) break; } - if (pos < K) + while (pos < K) { - diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl; + indices[pos] = std::numeric_limits::max(); + pos++; +// diskann::cerr << "Found pos: " << pos << "fewer than K elements " << K << " for query" << std::endl; } return retval; @@ -3355,30 +3632,31 @@ template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT class Index; template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); + template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint32_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, @@ -3419,30 +3697,30 @@ template DISKANN_DLLEXPORT std::pair Index Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const float *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const uint8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint64_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search( - const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances); + const int8_t *query, const size_t K, const uint32_t L, uint32_t *indices, float *distances, const uint32_t maxLperSeller); template DISKANN_DLLEXPORT std::pair Index::search_with_filters< uint64_t>(const float *query, const uint16_t &filter_label, const size_t K, const uint32_t L, uint64_t *indices, diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index fbb81d55f..5510e0b15 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -131,7 +131,7 @@ void PQFlashIndex::setup_thread_data(uint64_t nthreads, uint64_t visi { #pragma omp critical { - SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve); + SSDThreadData *data = new SSDThreadData(this->_aligned_dim, visited_reserve, this->_location_to_seller); this->reader->register_thread(); data->ctx = this->reader->get_ctx(); this->_thread_data.push(data); @@ -326,7 +326,7 @@ void PQFlashIndex::generate_cache_list_from_sample_queries(std::strin // concurrently update the node_visit_counter to track most visited nodes. The last false is to not use the // "use_reorder_data" option which enables a final reranking if the disk index itself contains only PQ data. cached_beam_search(samples + (i * sample_aligned_dim), 1, l_search, tmp_result_ids_64.data() + i, - tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, false); + tmp_result_dists.data() + i, beamwidth, filtered_search, label_for_search, std::numeric_limits::max(), false); } std::sort(this->_node_visit_counter.begin(), _node_visit_counter.end(), @@ -752,6 +752,53 @@ void PQFlashIndex::parse_label_file(std::basic_istream &infile, reset_stream_for_reading(infile); } +template +void PQFlashIndex::parse_seller_file(const std::string &label_file, size_t &num_points) +{ + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) + { + throw diskann::ANNException(std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + uint32_t line_cnt = 0; + std::set sellers; + while (std::getline(infile, line)) + { + line_cnt++; + } + _location_to_seller.resize(line_cnt); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) + { + std::istringstream iss(line); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + uint32_t seller; + while (getline(new_iss, token, ',')) + { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + uint32_t token_as_num = (uint32_t)std::stoul(token); + seller = token_as_num; + sellers.insert(seller); + } + + _location_to_seller[line_cnt] = seller; + line_cnt++; + } + num_points = (size_t)line_cnt; + diskann::cout << "Identified " << sellers.size() << " distinct seller(s) across " << num_points <<" points." << std::endl; +} + + template void PQFlashIndex::set_universal_label(const LabelT &label) { _use_universal_label = true; @@ -1013,6 +1060,19 @@ int PQFlashIndex::load_from_separate_paths(uint32_t num_threads, cons << std::endl; } + +#ifndef EXEC_ENV_OLS +// TODO: Make this friendly for DLVS + this->_seller_file = std ::string(index_filepath) + "_sellers.txt"; + std::cout<_seller_file << std::endl; + if(file_exists(this->_seller_file)) { + uint64_t nrows_seller_file; + parse_seller_file(this->_seller_file, nrows_seller_file); + this->_diverse_index = true; + } +#endif + + // read index metadata #ifdef EXEC_ENV_OLS // This is a bit tricky. We have to read the header from the @@ -1241,31 +1301,31 @@ bool getNextCompletedRequest(std::shared_ptr &reader, IOConte template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, + uint64_t *indices, float *distances, const uint64_t beam_width, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits::max(), + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits::max(), max_l_per_seller, use_reorder_data, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, - const bool use_filter, const LabelT &filter_label, + const bool use_filter, const LabelT &filter_label, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, use_filter, filter_label, - std::numeric_limits::max(), use_reorder_data, stats); + std::numeric_limits::max(), max_l_per_seller, use_reorder_data, stats); } template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, - uint64_t *indices, float *distances, const uint64_t beam_width, - const uint32_t io_limit, const bool use_reorder_data, + uint64_t *indices, float *distances, const uint64_t beam_width, + const uint32_t io_limit, const uint32_t max_l_per_seller, const bool use_reorder_data, QueryStats *stats) { LabelT dummy_filter = 0; - cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, + cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, false, dummy_filter, io_limit, max_l_per_seller, use_reorder_data, stats); } @@ -1273,10 +1333,18 @@ template void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t k_search, const uint64_t l_search, uint64_t *indices, float *distances, const uint64_t beam_width, const bool use_filter, const LabelT &filter_label, - const uint32_t io_limit, const bool use_reorder_data, + const uint32_t io_limit, const uint32_t max_k_per_seller, const bool use_reorder_data, QueryStats *stats) { + bool diverse_search = false; + uint32_t max_l_per_seller = std::numeric_limits::max(); + if (max_k_per_seller != std::numeric_limits::max()) + { + diverse_search = true; + max_l_per_seller = max_k_per_seller * (l_search / k_search); + } + uint64_t num_sector_per_nodes = DIV_ROUND_UP(_max_node_len, defaults::SECTOR_LEN); if (beam_width > num_sector_per_nodes * defaults::MAX_N_SECTOR_READS) throw ANNException("Beamwidth can not be higher than defaults::MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, @@ -1358,8 +1426,18 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t Timer query_timer, io_timer, cpu_timer; tsl::robin_set &visited = query_scratch->visited; - NeighborPriorityQueue &retset = query_scratch->retset; - retset.reserve(l_search); + //NeighborPriorityQueue &retset = query_scratch->retset; + bestCandidates &best_diverse_nodes_ref = query_scratch->best_diverse_nodes; + + NeighborPriorityQueue* retset; + if(diverse_search) { + best_diverse_nodes_ref.setup(l_search, max_l_per_seller); + retset = &(best_diverse_nodes_ref.best_L_nodes); + } else { + retset = &(query_scratch->retset); + retset->reserve(l_search); + } + std::vector &full_retset = query_scratch->full_retset; uint32_t best_medoid = 0; @@ -1402,7 +1480,13 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } compute_dists(&best_medoid, 1, dist_scratch); - retset.insert(Neighbor(best_medoid, dist_scratch[0])); + if (diverse_search) { + best_diverse_nodes_ref.insert(best_medoid, dist_scratch[0]); + } else { + retset->insert(Neighbor(best_medoid, dist_scratch[0])); + } + + //retset->insert(Neighbor(best_medoid, dist_scratch[0])); visited.insert(best_medoid); uint32_t cmps = 0; @@ -1419,7 +1503,7 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::vector>> cached_nhoods; cached_nhoods.reserve(2 * beam_width); - while (retset.has_unexpanded_node() && num_ios < io_limit) + while (retset->has_unexpanded_node() && num_ios < io_limit) { // clear iteration state frontier.clear(); @@ -1429,9 +1513,9 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t sector_scratch_idx = 0; // find new beam uint32_t num_seen = 0; - while (retset.has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) + while (retset->has_unexpanded_node() && frontier.size() < beam_width && num_seen < beam_width) { - auto nbr = retset.closest_unexpanded(); + auto nbr = retset->closest_unexpanded(); num_seen++; auto iter = _nhood_cache.find(nbr.id); if (iter != _nhood_cache.end()) @@ -1533,8 +1617,13 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t continue; cmps++; float dist = dist_scratch[m]; - Neighbor nn(id, dist); - retset.insert(nn); + +// retset->insert(nn); + if (diverse_search) { + best_diverse_nodes_ref.insert(id, dist); + } else { + retset->insert(Neighbor(id, dist)); + } } } } @@ -1602,7 +1691,12 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } Neighbor nn(id, dist); - retset.insert(nn); +// retset->insert(nn); + if (diverse_search) { + best_diverse_nodes_ref.insert(id, dist); + } else { + retset->insert(Neighbor(id, dist)); + } } } @@ -1616,12 +1710,13 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t } // re-sort by distance + std::sort(full_retset.begin(), full_retset.end()); if (use_reorder_data) { if (!(this->_reorder_data_exists)) - { + { throw ANNException("Requested use of reordering data which does " "not exist in index " "file", @@ -1668,6 +1763,17 @@ void PQFlashIndex::cached_beam_search(const T *query1, const uint64_t std::sort(full_retset.begin(), full_retset.end()); } + + if (diverse_search) { + best_diverse_nodes_ref.clear(); + best_diverse_nodes_ref.setup(k_search, max_k_per_seller); + + for (auto &x : full_retset) { + best_diverse_nodes_ref.insert(x.id, x.distance); + } + full_retset = best_diverse_nodes_ref.best_L_nodes._data; + } + // copy k_search values for (uint64_t i = 0; i < k_search; i++) { @@ -1725,7 +1831,7 @@ uint32_t PQFlashIndex::range_search(const T *query1, const double ran cur_bw = (cur_bw > 100) ? 100 : cur_bw; for (auto &x : distances) x = std::numeric_limits::max(); - this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, false, stats); + this->cached_beam_search(query1, l_search, l_search, indices.data(), distances.data(), cur_bw, std::numeric_limits::max(), false, stats); for (uint32_t i = 0; i < l_search; i++) { if (distances[i] > (float)range) diff --git a/src/scratch.cpp b/src/scratch.cpp index 1f8a34bb1..a7a7e9c98 100644 --- a/src/scratch.cpp +++ b/src/scratch.cpp @@ -14,8 +14,8 @@ namespace diskann // template InMemQueryScratch::InMemQueryScratch(uint32_t search_l, uint32_t indexing_l, uint32_t r, uint32_t maxc, size_t dim, - size_t aligned_dim, size_t alignment_factor, bool init_pq_scratch) - : _L(0), _R(r), _maxc(maxc) + size_t aligned_dim, size_t alignment_factor, std::vector &location_to_sellers, bool init_pq_scratch) + : _L(0), _R(r), _maxc(maxc), _best_diverse_nodes(location_to_sellers) { if (search_l == 0 || indexing_l == 0 || r == 0 || dim == 0) { @@ -56,6 +56,8 @@ template void InMemQueryScratch::clear() _expanded_nodes_set.clear(); _expanded_nghrs_vec.clear(); _occlude_list_output.clear(); + _best_diverse_nodes.clear(); + } template void InMemQueryScratch::resize_for_new_L(uint32_t new_l) @@ -91,9 +93,10 @@ template void SSDQueryScratch::reset() visited.clear(); retset.clear(); full_retset.clear(); + best_diverse_nodes.clear(); } -template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve) +template SSDQueryScratch::SSDQueryScratch(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers) : best_diverse_nodes(location_to_sellers) { size_t coord_alloc_size = ROUND_UP(sizeof(T) * aligned_dim, 256); @@ -121,7 +124,7 @@ template SSDQueryScratch::~SSDQueryScratch() } template -SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve) : scratch(aligned_dim, visited_reserve) +SSDThreadData::SSDThreadData(size_t aligned_dim, size_t visited_reserve, std::vector &location_to_sellers) : scratch(aligned_dim, visited_reserve, location_to_sellers) { } diff --git a/src/utils.cpp b/src/utils.cpp index 3773cda22..74875d4be 100644 --- a/src/utils.cpp +++ b/src/utils.cpp @@ -126,14 +126,19 @@ void normalize_data_file(const std::string &inFileName, const std::string &outFi diskann::cout << "Wrote normalized points to file: " << outFileName << std::endl; } + double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist, uint32_t dim_gs, - uint32_t *our_results, uint32_t dim_or, uint32_t recall_at) + uint32_t *our_results, uint32_t dim_or, uint32_t recall_at, float* algo_distances) { + bool use_distances_to_break_ties = false; + if (algo_distances != nullptr) { + use_distances_to_break_ties = true; + } double total_recall = 0; std::set gt, res; - for (size_t i = 0; i < num_queries; i++) { + if (!use_distances_to_break_ties) { gt.clear(); res.clear(); uint32_t *gt_vec = gold_std + dim_gs * i; @@ -160,6 +165,14 @@ double calculate_recall(uint32_t num_queries, uint32_t *gold_std, float *gs_dist } } total_recall += cur_recall; + } else { // only works if dim_or == dim_gs. Not for the k-recall@k' regime. + uint32_t cur_recall =0; + for (uint32_t rr = 0; rr < std::min(dim_or, dim_gs); rr++) { + if (algo_distances[i*dim_or + rr] <= gs_dist[i*dim_gs + (recall_at-1)]) + cur_recall++; + } + total_recall += cur_recall; + } } return total_recall / (num_queries) * (100.0 / recall_at); }