Skip to content

Commit

Permalink
sparse: encapsulate cursor operation so that more algorithms can be i…
Browse files Browse the repository at this point in the history
…ntroduced

Signed-off-by: Shawn Wang <[email protected]>
  • Loading branch information
sparknack committed Jan 9, 2025
1 parent 770faac commit 7f5f76e
Showing 1 changed file with 101 additions and 98 deletions.
199 changes: 101 additions & 98 deletions src/index/sparse/sparse_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,20 +439,26 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
if (drop_ratio_search == 0) {
refine_factor = 1;
}

auto q_vec = parse_query(query, q_threshold);
if (q_vec.empty()) {
return;
}

MaxMinHeap<float> heap(k * refine_factor);
// DAAT_WAND and DAAT_MAXSCORE are based on the implementation in PISA.
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) {
search_daat_wand(query, q_threshold, heap, bitset, computer);
search_daat_wand(q_vec, heap, bitset, computer);
} else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
search_daat_maxscore(query, q_threshold, heap, bitset, computer);
search_daat_maxscore(q_vec, heap, bitset, computer);
} else {
search_taat_naive(query, q_threshold, heap, bitset, computer);
search_taat_naive(q_vec, heap, bitset, computer);
}

if (refine_factor == 1) {
collect_result(heap, distances, labels);
} else {
refine_and_collect(query, heap, k, distances, labels, computer);
refine_and_collect(q_vec, heap, k, distances, labels, computer);
}
}

Expand All @@ -468,7 +474,9 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
values[i] = std::abs(query[i].val);
}
auto q_threshold = get_threshold(values, drop_ratio_search);
auto distances = compute_all_distances(query, q_threshold, computer);
auto q_vec = parse_query(query, q_threshold);

auto distances = compute_all_distances(q_vec, computer);
if (!bitset.empty()) {
for (size_t i = 0; i < distances.size(); ++i) {
if (bitset.test(i)) {
Expand Down Expand Up @@ -556,26 +564,16 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
}

std::vector<float>
compute_all_distances(const SparseRow<DType>& q_vec, DType q_threshold,
const DocValueComputer<float>& computer) const {
compute_all_distances(std::vector<std::pair<size_t, DType>>& q_vec, const DocValueComputer<float>& computer) const {
std::vector<float> scores(n_rows_internal_, 0.0f);
for (size_t idx = 0; idx < q_vec.size(); ++idx) {
auto [i, v] = q_vec[idx];
if (v < q_threshold || i >= max_dim_) {
continue;
}
auto dim_id = dim_map_.find(i);
if (dim_id == dim_map_.end()) {
continue;
}
auto& plist_ids = inverted_index_ids_[dim_id->second];
auto& plist_vals = inverted_index_vals_[dim_id->second];
for (size_t i = 0; i < q_vec.size(); ++i) {
auto& plist_ids = inverted_index_ids_[q_vec[i].first];
auto& plist_vals = inverted_index_vals_[q_vec[i].first];
// TODO: improve with SIMD
for (size_t j = 0; j < plist_ids.size(); ++j) {
auto doc_id = plist_ids[j];
auto val = plist_vals[j];
float val_sum = bm25 ? bm25_params_->row_sums.at(doc_id) : 0;
scores[doc_id] += v * computer(val, val_sum);
scores[doc_id] += q_vec[i].second * computer(plist_vals[j], val_sum);
}
}
return scores;
Expand Down Expand Up @@ -644,14 +642,43 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
}
}; // struct Cursor

std::vector<std::pair<size_t, DType>>
parse_query(const SparseRow<DType>& query, DType q_threshold) const {
std::vector<std::pair<size_t, DType>> filtered_query;
for (size_t i = 0; i < query.size(); ++i) {
auto [idx, val] = query[i];
auto dim_id = dim_map_.find(idx);
if (dim_id == dim_map_.end() || std::abs(val) < q_threshold) {
continue;
}
filtered_query.emplace_back(dim_id->second, val);
}
return filtered_query;
}

template <typename DocIdFilter>
std::vector<Cursor<DocIdFilter>>
make_cursors(const std::vector<std::pair<size_t, DType>>& q_vec, const DocValueComputer<float>& computer,
DocIdFilter& filter) const {
std::vector<Cursor<DocIdFilter>> cursors;
cursors.reserve(q_vec.size());
for (auto q_dim : q_vec) {
auto& plist_ids = inverted_index_ids_[q_dim.first];
auto& plist_vals = inverted_index_vals_[q_dim.first];
cursors.emplace_back(plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[q_dim.first] * q_dim.second,
q_dim.second, filter);
}
return cursors;
}

// find the top-k candidates using brute force search, k as specified by the capacity of the heap.
// any value in q_vec that is smaller than q_threshold and any value with dimension >= n_cols() will be ignored.
// TODO: may switch to row-wise brute force if filter rate is high. Benchmark needed.
template <typename DocIdFilter>
void
search_taat_naive(const SparseRow<DType>& q_vec, DType q_threshold, MaxMinHeap<float>& heap, DocIdFilter& filter,
search_taat_naive(std::vector<std::pair<size_t, DType>>& q_vec, MaxMinHeap<float>& heap, DocIdFilter& filter,
const DocValueComputer<float>& computer) const {
auto scores = compute_all_distances(q_vec, q_threshold, computer);
auto scores = compute_all_distances(q_vec, computer);
for (size_t i = 0; i < n_rows_internal_; ++i) {
if ((filter.empty() || !filter.test(i)) && scores[i] != 0) {
heap.push(i, scores[i]);
Expand All @@ -662,40 +689,31 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
// any value in q_vec that is smaller than q_threshold will be ignored.
template <typename DocIdFilter>
void
search_daat_wand(const SparseRow<DType>& q_vec, DType q_threshold, MaxMinHeap<float>& heap, DocIdFilter& filter,
search_daat_wand(std::vector<std::pair<size_t, DType>>& q_vec, MaxMinHeap<float>& heap, DocIdFilter& filter,
const DocValueComputer<float>& computer) const {
auto q_dim = q_vec.size();
std::vector<std::shared_ptr<Cursor<DocIdFilter>>> cursors(q_dim);
size_t valid_q_dim = 0;
for (size_t i = 0; i < q_dim; ++i) {
auto [idx, val] = q_vec[i];
auto dim_id = dim_map_.find(idx);
if (dim_id == dim_map_.end() || std::abs(val) < q_threshold) {
continue;
}
auto& plist_ids = inverted_index_ids_[dim_id->second];
auto& plist_vals = inverted_index_vals_[dim_id->second];
cursors[valid_q_dim++] = std::make_shared<Cursor<DocIdFilter>>(
plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[dim_id->second] * val, val, filter);
}
if (valid_q_dim == 0) {
return;
std::vector<Cursor<DocIdFilter>> cursors = make_cursors(q_vec, computer, filter);
std::vector<Cursor<DocIdFilter>*> cursor_ptrs(cursors.size());
for (size_t i = 0; i < cursors.size(); ++i) {
cursor_ptrs[i] = &cursors[i];
}
cursors.resize(valid_q_dim);
auto sort_cursors = [&cursors] {
std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->cur_vec_id_ < y->cur_vec_id_; });

auto sort_cursors = [&cursor_ptrs] {
std::sort(cursor_ptrs.begin(), cursor_ptrs.end(),
[](auto& x, auto& y) { return x->cur_vec_id_ < y->cur_vec_id_; });
};
sort_cursors();

while (true) {
float threshold = heap.full() ? heap.top().val : 0;
float upper_bound = 0;
size_t pivot;

bool found_pivot = false;
for (pivot = 0; pivot < valid_q_dim; ++pivot) {
if (cursors[pivot]->loc_ >= cursors[pivot]->plist_size_) {
for (pivot = 0; pivot < q_vec.size(); ++pivot) {
if (cursor_ptrs[pivot]->cur_vec_id_ >= n_rows_internal_) {
break;
}
upper_bound += cursors[pivot]->max_score_;
upper_bound += cursor_ptrs[pivot]->max_score_;
if (upper_bound > threshold) {
found_pivot = true;
break;
Expand All @@ -704,72 +722,58 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
if (!found_pivot) {
break;
}
table_t pivot_id = cursors[pivot]->cur_vec_id_;
if (pivot_id == cursors[0]->cur_vec_id_) {

table_t pivot_id = cursor_ptrs[pivot]->cur_vec_id_;
if (pivot_id == cursor_ptrs[0]->cur_vec_id_) {
float score = 0;
for (auto& cursor : cursors) {
if (cursor->cur_vec_id_ != pivot_id) {
float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(pivot_id) : 0;
for (auto& cursor_ptr : cursor_ptrs) {
if (cursor_ptr->cur_vec_id_ != pivot_id) {
break;
}
float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id_) : 0;
score += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum);
cursor->next();
score += cursor_ptr->q_value_ * computer(cursor_ptr->cur_vec_val(), cur_vec_sum);
cursor_ptr->next();
}
heap.push(pivot_id, score);
sort_cursors();
} else {
size_t next_list = pivot;
for (; cursors[next_list]->cur_vec_id_ == pivot_id; --next_list) {
for (; cursor_ptrs[next_list]->cur_vec_id_ == pivot_id; --next_list) {
}
cursors[next_list]->seek(pivot_id);
for (size_t i = next_list + 1; i < valid_q_dim; ++i) {
if (cursors[i]->cur_vec_id_ >= cursors[i - 1]->cur_vec_id_) {
cursor_ptrs[next_list]->seek(pivot_id);
for (size_t i = next_list + 1; i < q_vec.size(); ++i) {
if (cursor_ptrs[i]->cur_vec_id_ >= cursor_ptrs[i - 1]->cur_vec_id_) {
break;
}
std::swap(cursors[i], cursors[i - 1]);
std::swap(cursor_ptrs[i], cursor_ptrs[i - 1]);
}
}
}
}

template <typename DocIdFilter>
void
search_daat_maxscore(const SparseRow<DType>& q_vec, DType q_threshold, MaxMinHeap<float>& heap,
const DocIdFilter& filter, const DocValueComputer<float>& computer) const {
auto q_dim = q_vec.size();
std::vector<std::shared_ptr<Cursor<DocIdFilter>>> cursors(q_dim);
size_t valid_q_dim = 0;
for (size_t i = 0; i < q_dim; ++i) {
auto [idx, val] = q_vec[i];
auto dim_id = dim_map_.find(idx);
if (dim_id == dim_map_.end() || std::abs(val) < q_threshold) {
continue;
}
auto& plist_ids = inverted_index_ids_[dim_id->second];
auto& plist_vals = inverted_index_vals_[dim_id->second];
cursors[valid_q_dim++] = std::make_shared<Cursor<DocIdFilter>>(
plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[dim_id->second] * val, val, filter);
}
if (valid_q_dim == 0) {
return;
}
cursors.resize(valid_q_dim);
search_daat_maxscore(std::vector<std::pair<size_t, DType>>& q_vec, MaxMinHeap<float>& heap, DocIdFilter& filter,
const DocValueComputer<float>& computer) const {
std::sort(q_vec.begin(), q_vec.end(), [this](auto& a, auto& b) {
return a.second * max_score_in_dim_[a.first] > b.second * max_score_in_dim_[b.first];
});

std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->max_score_ > y->max_score_; });
std::vector<Cursor<DocIdFilter>> cursors = make_cursors(q_vec, computer, filter);

float threshold = heap.full() ? heap.top().val : 0;

std::vector<float> upper_bounds(cursors.size());
float bound_sum = 0.0;
for (size_t i = cursors.size() - 1; i + 1 > 0; --i) {
bound_sum += cursors[i]->max_score_;
bound_sum += cursors[i].max_score_;
upper_bounds[i] = bound_sum;
}

uint32_t next_cand_vec_id = n_rows_internal_;
table_t next_cand_vec_id = n_rows_internal_;
for (size_t i = 0; i < cursors.size(); ++i) {
if (cursors[i]->cur_vec_id_ < next_cand_vec_id) {
next_cand_vec_id = cursors[i]->cur_vec_id_;
if (cursors[i].cur_vec_id_ < next_cand_vec_id) {
next_cand_vec_id = cursors[i].cur_vec_id_;
}
}

Expand All @@ -784,7 +788,7 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
}

float curr_cand_score = 0.0f;
uint32_t curr_cand_vec_id = 0;
table_t curr_cand_vec_id = 0;

while (curr_cand_vec_id < n_rows_internal_) {
auto found_cand = false;
Expand All @@ -798,15 +802,15 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
curr_cand_score = 0.0f;
// update next_cand_vec_id
next_cand_vec_id = n_rows_internal_;
float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(curr_cand_vec_id) : 0;

for (size_t i = 0; i < first_ne_idx; ++i) {
if (cursors[i]->cur_vec_id_ == curr_cand_vec_id) {
float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i]->cur_vec_id_) : 0;
curr_cand_score += cursors[i]->q_value_ * computer(cursors[i]->cur_vec_val(), cur_vec_sum);
cursors[i]->next();
if (cursors[i].cur_vec_id_ == curr_cand_vec_id) {
curr_cand_score += cursors[i].q_value_ * computer(cursors[i].cur_vec_val(), cur_vec_sum);
cursors[i].next();
}
if (cursors[i]->cur_vec_id_ < next_cand_vec_id) {
next_cand_vec_id = cursors[i]->cur_vec_id_;
if (cursors[i].cur_vec_id_ < next_cand_vec_id) {
next_cand_vec_id = cursors[i].cur_vec_id_;
}
}

Expand All @@ -816,10 +820,9 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
found_cand = false;
break;
}
cursors[i]->seek(curr_cand_vec_id);
if (cursors[i]->cur_vec_id_ == curr_cand_vec_id) {
float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i]->cur_vec_id_) : 0;
curr_cand_score += cursors[i]->q_value_ * computer(cursors[i]->cur_vec_val(), cur_vec_sum);
cursors[i].seek(curr_cand_vec_id);
if (cursors[i].cur_vec_id_ == curr_cand_vec_id) {
curr_cand_score += cursors[i].q_value_ * computer(cursors[i].cur_vec_val(), cur_vec_sum);
}
}
}
Expand All @@ -838,8 +841,8 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
}

void
refine_and_collect(const SparseRow<DType>& q_vec, MaxMinHeap<float>& inacc_heap, size_t k, float* distances,
label_t* labels, const DocValueComputer<float>& computer) const {
refine_and_collect(std::vector<std::pair<size_t, DType>>& q_vec, MaxMinHeap<float>& inacc_heap, size_t k,
float* distances, label_t* labels, const DocValueComputer<float>& computer) const {
std::vector<table_t> docids;
MaxMinHeap<float> heap(k);

Expand All @@ -851,11 +854,11 @@ class InvertedIndex : public BaseInvertedIndex<DType> {

DocIdFilterByVector filter(std::move(docids));
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) {
search_daat_wand(q_vec, 0, heap, filter, computer);
search_daat_wand(q_vec, heap, filter, computer);
} else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
search_daat_maxscore(q_vec, 0, heap, filter, computer);
search_daat_maxscore(q_vec, heap, filter, computer);
} else {
search_taat_naive(q_vec, 0, heap, filter, computer);
search_taat_naive(q_vec, heap, filter, computer);
}
collect_result(heap, distances, labels);
}
Expand Down

0 comments on commit 7f5f76e

Please sign in to comment.