From aabc066d73f7b4f4ef2aff93044793c7b57846f6 Mon Sep 17 00:00:00 2001 From: Patrick Weizhi Xu Date: Thu, 12 Oct 2023 16:34:22 +0800 Subject: [PATCH] Improve DiskANN Hybrid Search via Alpha-Strategy Signed-off-by: Patrick Weizhi Xu --- thirdparty/DiskANN/src/pq_flash_index.cpp | 112 ++++++++++++---------- 1 file changed, 60 insertions(+), 52 deletions(-) diff --git a/thirdparty/DiskANN/src/pq_flash_index.cpp b/thirdparty/DiskANN/src/pq_flash_index.cpp index 5415a1f4a..8e5219c7f 100644 --- a/thirdparty/DiskANN/src/pq_flash_index.cpp +++ b/thirdparty/DiskANN/src/pq_flash_index.cpp @@ -54,14 +54,10 @@ ((((_u64) (id)) % nvecs_per_sector) * data_dim * sizeof(float)) namespace { - constexpr size_t kReadBatchSize = 32; constexpr _u64 kRefineBeamWidthFactor = 2; constexpr _u64 kBruteForceTopkRefineExpansionFactor = 2; - auto calcFilterThreshold = [](const auto topk) -> const float { - return std::max(-0.04570166137874405f * log2(topk + 58.96422392240403) + - 1.1982775974217197, - 0.5); - }; + constexpr float kFilterThreshold = 0.93f; + constexpr float kAlpha = 0.15f; } // namespace namespace diskann { @@ -1016,7 +1012,7 @@ namespace diskann { if (!bitset_view.empty()) { const auto filter_threshold = - filter_ratio_in < 0 ? calcFilterThreshold(k_search) : filter_ratio_in; + filter_ratio_in < 0 ? 
kFilterThreshold : filter_ratio_in; const auto bv_cnt = bitset_view.count(); if (bitset_view.size() == bv_cnt) { for (_u64 i = 0; i < k_search; i++) { @@ -1116,6 +1112,31 @@ namespace diskann { unsigned num_ios = 0; unsigned k = 0; + float accumulative_alpha = 0; + std::vector<unsigned> filtered_nbrs; + filtered_nbrs.reserve(this->max_degree); + auto filter_nbrs = [&](_u64 nnbrs, + unsigned *node_nbrs) -> std::pair<_u64, unsigned *> { + filtered_nbrs.clear(); + for (_u64 m = 0; m < nnbrs; ++m) { + unsigned id = node_nbrs[m]; + if (visited.find(id) != visited.end()) { + continue; + } + visited.insert(id); + if (!bitset_view.empty() && bitset_view.test(id)) { + accumulative_alpha += kAlpha; + if (accumulative_alpha < 1.0f) { + continue; + } + accumulative_alpha -= 1.0f; + } + cmps++; + filtered_nbrs.push_back(id); + } + return {filtered_nbrs.size(), filtered_nbrs.data()}; + }; + while (k < cur_list_size) { auto nk = cur_list_size; // clear iteration state @@ -1219,8 +1240,8 @@ namespace diskann { feder->id_set_.insert(cached_nhood.first); } } - _u64 nnbrs = cached_nhood.second.first; - unsigned *node_nbrs = cached_nhood.second.second; + auto [nnbrs, node_nbrs] = + filter_nbrs(cached_nhood.second.first, cached_nhood.second.second); // compute node_nbrs <-> query dists in PQ space cpu_timer.reset(); @@ -1241,26 +1262,20 @@ namespace diskann { feder->id_set_.insert(id); } - if (visited.find(id) != visited.end()) { + float dist = dist_scratch[m]; + if (cur_list_size > 0 && + dist >= retset[cur_list_size - 1].distance && + (cur_list_size == l_search)) continue; - } else { - visited.insert(id); - cmps++; - float dist = dist_scratch[m]; - if (cur_list_size > 0 && - dist >= retset[cur_list_size - 1].distance && - (cur_list_size == l_search)) - continue; - Neighbor nn(id, dist, true); - // Return position in sorted list where nn inserted.
- auto r = InsertIntoPool(retset.data(), cur_list_size, nn); - if (cur_list_size < l_search) - ++cur_list_size; - if (r < nk) - // nk logs the best position in the retset that was - // updated due to neighbors of n. - nk = r; - } + Neighbor nn(id, dist, true); + // Return position in sorted list where nn inserted. + auto r = InsertIntoPool(retset.data(), cur_list_size, nn); + if (cur_list_size < l_search) + ++cur_list_size; + if (r < nk) + // nk logs the best position in the retset that was + // updated due to neighbors of n. + nk = r; } } #ifdef USE_BING_INFRA @@ -1282,7 +1297,6 @@ namespace diskann { char *node_disk_buf = get_offset_to_node(frontier_nhood.second, frontier_nhood.first); unsigned *node_buf = OFFSET_TO_NODE_NHOOD(node_disk_buf); - _u64 nnbrs = (_u64) (*node_buf); T *node_fp_coords = OFFSET_TO_NODE_COORDS(node_disk_buf); T *node_fp_coords_copy = data_buf; @@ -1312,7 +1326,7 @@ namespace diskann { feder->id_set_.insert(frontier_nhood.first); } } - unsigned *node_nbrs = (node_buf + 1); + auto [nnbrs, node_nbrs] = filter_nbrs(*node_buf, (node_buf + 1)); // compute node_nbrs <-> query dist in PQ space cpu_timer.reset(); compute_dists(node_nbrs, nnbrs, dist_scratch); @@ -1333,29 +1347,23 @@ namespace diskann { feder->id_set_.insert(frontier_nhood.first); } - if (visited.find(id) != visited.end()) { - continue; - } else { - visited.insert(id); - cmps++; - float dist = dist_scratch[m]; - if (stats != nullptr) { - stats->n_cmps++; - } - if (cur_list_size > 0 && - dist >= retset[cur_list_size - 1].distance && - (cur_list_size == l_search)) - continue; - Neighbor nn(id, dist, true); - auto r = InsertIntoPool( - retset.data(), cur_list_size, - nn); // Return position in sorted list where nn inserted. - if (cur_list_size < l_search) - ++cur_list_size; - if (r < nk) - nk = r; // nk logs the best position in the retset that was - // updated due to neighbors of n. 
+ float dist = dist_scratch[m]; + if (stats != nullptr) { + stats->n_cmps++; } + if (cur_list_size > 0 && + dist >= retset[cur_list_size - 1].distance && + (cur_list_size == l_search)) + continue; + Neighbor nn(id, dist, true); + auto r = InsertIntoPool( + retset.data(), cur_list_size, + nn); // Return position in sorted list where nn inserted. + if (cur_list_size < l_search) + ++cur_list_size; + if (r < nk) + nk = r; // nk logs the best position in the retset that was + // updated due to neighbors of n. } if (stats != nullptr) {