Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve DiskANN Hybrid Search via Alpha-Strategy #143

Merged
merged 1 commit into from
Oct 12, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 60 additions & 52 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,10 @@
((((_u64) (id)) % nvecs_per_sector) * data_dim * sizeof(float))

namespace {
constexpr size_t kReadBatchSize = 32;
constexpr _u64 kRefineBeamWidthFactor = 2;
constexpr _u64 kBruteForceTopkRefineExpansionFactor = 2;
auto calcFilterThreshold = [](const auto topk) -> const float {
return std::max(-0.04570166137874405f * log2(topk + 58.96422392240403) +
1.1982775974217197,
0.5);
};
constexpr float kFilterThreshold = 0.93f;
constexpr float kAlpha = 0.15f;
} // namespace

namespace diskann {
Expand Down Expand Up @@ -1016,7 +1012,7 @@ namespace diskann {

if (!bitset_view.empty()) {
const auto filter_threshold =
filter_ratio_in < 0 ? calcFilterThreshold(k_search) : filter_ratio_in;
filter_ratio_in < 0 ? kFilterThreshold : filter_ratio_in;
const auto bv_cnt = bitset_view.count();
if (bitset_view.size() == bv_cnt) {
for (_u64 i = 0; i < k_search; i++) {
Expand Down Expand Up @@ -1116,6 +1112,31 @@ namespace diskann {
unsigned num_ios = 0;
unsigned k = 0;

float accumulative_alpha = 0;
std::vector<unsigned> filtered_nbrs;
filtered_nbrs.reserve(this->max_degree);
auto filter_nbrs = [&](_u64 nnbrs,
unsigned *node_nbrs) -> std::pair<_u64, unsigned *> {
filtered_nbrs.clear();
for (_u64 m = 0; m < nnbrs; ++m) {
unsigned id = node_nbrs[m];
if (visited.find(id) != visited.end()) {
continue;
}
visited.insert(id);
if (!bitset_view.empty() && bitset_view.test(id)) {
accumulative_alpha += kAlpha;
if (accumulative_alpha < 1.0f) {
continue;
}
accumulative_alpha -= 1.0f;
}
cmps++;
filtered_nbrs.push_back(id);
}
return {filtered_nbrs.size(), filtered_nbrs.data()};
};

while (k < cur_list_size) {
auto nk = cur_list_size;
// clear iteration state
Expand Down Expand Up @@ -1219,8 +1240,8 @@ namespace diskann {
feder->id_set_.insert(cached_nhood.first);
}
}
_u64 nnbrs = cached_nhood.second.first;
unsigned *node_nbrs = cached_nhood.second.second;
auto [nnbrs, node_nbrs] =
filter_nbrs(cached_nhood.second.first, cached_nhood.second.second);

// compute node_nbrs <-> query dists in PQ space
cpu_timer.reset();
Expand All @@ -1241,26 +1262,20 @@ namespace diskann {
feder->id_set_.insert(id);
}

if (visited.find(id) != visited.end()) {
float dist = dist_scratch[m];
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
continue;
} else {
visited.insert(id);
cmps++;
float dist = dist_scratch[m];
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
continue;
Neighbor nn(id, dist, true);
// Return position in sorted list where nn inserted.
auto r = InsertIntoPool(retset.data(), cur_list_size, nn);
if (cur_list_size < l_search)
++cur_list_size;
if (r < nk)
// nk logs the best position in the retset that was
// updated due to neighbors of n.
nk = r;
}
Neighbor nn(id, dist, true);
// Return position in sorted list where nn inserted.
auto r = InsertIntoPool(retset.data(), cur_list_size, nn);
if (cur_list_size < l_search)
++cur_list_size;
if (r < nk)
// nk logs the best position in the retset that was
// updated due to neighbors of n.
nk = r;
}
}
#ifdef USE_BING_INFRA
Expand All @@ -1282,7 +1297,6 @@ namespace diskann {
char *node_disk_buf =
get_offset_to_node(frontier_nhood.second, frontier_nhood.first);
unsigned *node_buf = OFFSET_TO_NODE_NHOOD(node_disk_buf);
_u64 nnbrs = (_u64) (*node_buf);
T *node_fp_coords = OFFSET_TO_NODE_COORDS(node_disk_buf);

T *node_fp_coords_copy = data_buf;
Expand Down Expand Up @@ -1312,7 +1326,7 @@ namespace diskann {
feder->id_set_.insert(frontier_nhood.first);
}
}
unsigned *node_nbrs = (node_buf + 1);
auto [nnbrs, node_nbrs] = filter_nbrs(*node_buf, (node_buf + 1));
// compute node_nbrs <-> query dist in PQ space
cpu_timer.reset();
compute_dists(node_nbrs, nnbrs, dist_scratch);
Expand All @@ -1333,29 +1347,23 @@ namespace diskann {
feder->id_set_.insert(frontier_nhood.first);
}

if (visited.find(id) != visited.end()) {
continue;
} else {
visited.insert(id);
cmps++;
float dist = dist_scratch[m];
if (stats != nullptr) {
stats->n_cmps++;
}
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
continue;
Neighbor nn(id, dist, true);
auto r = InsertIntoPool(
retset.data(), cur_list_size,
nn); // Return position in sorted list where nn inserted.
if (cur_list_size < l_search)
++cur_list_size;
if (r < nk)
nk = r; // nk logs the best position in the retset that was
// updated due to neighbors of n.
float dist = dist_scratch[m];
if (stats != nullptr) {
stats->n_cmps++;
}
if (cur_list_size > 0 &&
dist >= retset[cur_list_size - 1].distance &&
(cur_list_size == l_search))
continue;
Neighbor nn(id, dist, true);
auto r = InsertIntoPool(
retset.data(), cur_list_size,
nn); // Return position in sorted list where nn inserted.
if (cur_list_size < l_search)
++cur_list_size;
if (r < nk)
nk = r; // nk logs the best position in the retset that was
// updated due to neighbors of n.
}

if (stats != nullptr) {
Expand Down
Loading