Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unified search layer #310

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 31 additions & 107 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ namespace hnswlib {
VisitedListPool *visited_list_pool_;
std::mutex cur_element_count_guard_;

std::vector<std::mutex> link_list_locks_;
mutable std::vector<std::mutex> link_list_locks_;

// Locks to prevent race condition during update/insert of an element at same time.
// Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel.
Expand Down Expand Up @@ -158,95 +158,12 @@ namespace hnswlib {
return (int) r;
}


std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayer(tableint ep_id, const void *data_point, int layer) {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidateSet;

dist_t lowerBound;
if (!isMarkedDeleted(ep_id)) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
top_candidates.emplace(dist, ep_id);
lowerBound = dist;
candidateSet.emplace(-dist, ep_id);
} else {
lowerBound = std::numeric_limits<dist_t>::max();
candidateSet.emplace(-lowerBound, ep_id);
}
visited_array[ep_id] = visited_array_tag;

while (!candidateSet.empty()) {
std::pair<dist_t, tableint> curr_el_pair = candidateSet.top();
if ((-curr_el_pair.first) > lowerBound) {
break;
}
candidateSet.pop();

tableint curNodeNum = curr_el_pair.second;

std::unique_lock <std::mutex> lock(link_list_locks_[curNodeNum]);

int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_);
if (layer == 0) {
data = (int*)get_linklist0(curNodeNum);
} else {
data = (int*)get_linklist(curNodeNum, layer);
// data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_);
}
size_t size = getListCount((linklistsizeint*)data);
tableint *datal = (tableint *) (data + 1);
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
_mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0);
#endif

for (size_t j = 0; j < size; j++) {
tableint candidate_id = *(datal + j);
// if (candidate_id == 0) continue;
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0);
#endif
if (visited_array[candidate_id] == visited_array_tag) continue;
visited_array[candidate_id] = visited_array_tag;
char *currObj1 = (getDataByInternalId(candidate_id));

dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_);
if (top_candidates.size() < ef_construction_ || lowerBound > dist1) {
candidateSet.emplace(-dist1, candidate_id);
#ifdef USE_SSE
_mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0);
#endif

if (!isMarkedDeleted(candidate_id))
top_candidates.emplace(dist1, candidate_id);

if (top_candidates.size() > ef_construction_)
top_candidates.pop();

if (!top_candidates.empty())
lowerBound = top_candidates.top().first;
}
}
}
visited_list_pool_->releaseVisitedList(vl);

return top_candidates;
}

mutable std::atomic<long> metric_distance_computations;
mutable std::atomic<long> metric_hops;

template <bool has_deletions, bool collect_metrics=false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
searchLayer(tableint ep_id, const void *data_point, size_t ef, int layer, bool is_st) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
Expand Down Expand Up @@ -277,18 +194,28 @@ namespace hnswlib {
candidate_set.pop();

tableint current_node_id = current_node_pair.second;
int *data = (int *) get_linklist0(current_node_id);

std::unique_lock<std::mutex> lk(link_list_locks_[current_node_id]);
if (is_st)
lk.unlock();

int *data;
if (layer == 0) {
data = (int*)get_linklist0(current_node_id);
} else {
data = (int*)get_linklist(current_node_id, layer);
}
size_t size = getListCount((linklistsizeint*)data);
// bool cur_node_deleted = isMarkedDeleted(current_node_id);

if(collect_metrics){
metric_hops++;
metric_distance_computations+=size;
}

#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0);
_mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0);
_mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0);
_mm_prefetch(getDataByInternalId((tableint)(*(data + 1))), _MM_HINT_T0);
_mm_prefetch(getDataByInternalId((tableint)(*(data + 2))), _MM_HINT_T0);
_mm_prefetch((char *) (data + 2), _MM_HINT_T0);
#endif

Expand All @@ -297,10 +224,9 @@ namespace hnswlib {
// if (candidate_id == 0) continue;
#ifdef USE_SSE
_mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0);
_mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_,
_MM_HINT_T0);////////////
_mm_prefetch(getDataByInternalId((tableint)(*(data + j + 1))), _MM_HINT_T0);////////////
#endif
if (!(visited_array[candidate_id] == visited_array_tag)) {
if (visited_array[candidate_id] != visited_array_tag) {

visited_array[candidate_id] = visited_array_tag;

Expand All @@ -310,9 +236,7 @@ namespace hnswlib {
if (top_candidates.size() < ef || lowerBound > dist) {
candidate_set.emplace(-dist, candidate_id);
#ifdef USE_SSE
_mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ +
offsetLevel0_,///////////
_MM_HINT_T0);////////////////////////
_mm_prefetch(getDataByInternalId(candidate_set.top().second),_MM_HINT_T0);////////////////////////
#endif

if (!has_deletions || !isMarkedDeleted(candidate_id))
Expand Down Expand Up @@ -548,13 +472,13 @@ namespace hnswlib {
}

if (has_deletions_) {
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<true>(currObj, query_data,
ef_);
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchLayer<true>(currObj, query_data,
ef_, 0, true);
top_candidates.swap(top_candidates1);
}
else{
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchBaseLayerST<false>(currObj, query_data,
ef_);
std::priority_queue<std::pair<dist_t, tableint >> top_candidates1=searchLayer<false>(currObj, query_data,
ef_, 0, true);
top_candidates.swap(top_candidates1);
}

Expand Down Expand Up @@ -941,8 +865,8 @@ namespace hnswlib {
throw std::runtime_error("Level of item to be updated cannot be bigger than max level");

for (int level = dataPointLevel; level >= 0; level--) {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchBaseLayer(
currObj, dataPoint, level);
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> topCandidates = searchLayer<false>(
currObj, dataPoint, ef_construction_, level, false);

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> filteredTopCandidates;
while (topCandidates.size() > 0) {
Expand Down Expand Up @@ -1073,8 +997,8 @@ namespace hnswlib {
if (level > maxlevelcopy || level < 0) // possible?
throw std::runtime_error("Level error");

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchBaseLayer(
currObj, data_point, level);
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates = searchLayer<false>(
currObj, data_point, ef_construction_, level, false);
if (epDeleted) {
top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy);
if (top_candidates.size() > ef_construction_)
Expand Down Expand Up @@ -1136,12 +1060,12 @@ namespace hnswlib {

std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
if (has_deletions_) {
top_candidates=searchBaseLayerST<true,true>(
currObj, query_data, std::max(ef_, k));
top_candidates=searchLayer<true,true>(
currObj, query_data, std::max(ef_, k), 0, true);
}
else{
top_candidates=searchBaseLayerST<false,true>(
currObj, query_data, std::max(ef_, k));
top_candidates=searchLayer<false,true>(
currObj, query_data, std::max(ef_, k), 0, true);
}

while (top_candidates.size() > k) {
Expand Down