diff --git a/src/index/sparse/sparse_inverted_index.h b/src/index/sparse/sparse_inverted_index.h index 15b8fe8ca..1033a46a1 100644 --- a/src/index/sparse/sparse_inverted_index.h +++ b/src/index/sparse/sparse_inverted_index.h @@ -427,20 +427,26 @@ class InvertedIndex : public BaseInvertedIndex { if (drop_ratio_search == 0) { refine_factor = 1; } + + auto q_vec = parse_query(query, q_threshold); + if (q_vec.empty()) { + return; + } + MaxMinHeap heap(k * refine_factor); // DAAT_WAND and DAAT_MAXSCORE are based on the implementation in PISA. if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) { - search_daat_wand(query, q_threshold, heap, bitset, computer); + search_daat_wand(q_vec, heap, bitset, computer); } else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) { - search_daat_maxscore(query, q_threshold, heap, bitset, computer); + search_daat_maxscore(q_vec, heap, bitset, computer); } else { - search_taat_naive(query, q_threshold, heap, bitset, computer); + search_taat_naive(q_vec, heap, bitset, computer); } if (refine_factor == 1) { collect_result(heap, distances, labels); } else { - refine_and_collect(query, heap, k, distances, labels, computer); + refine_and_collect(q_vec, heap, k, distances, labels, computer); } } @@ -456,7 +462,9 @@ class InvertedIndex : public BaseInvertedIndex { values[i] = std::abs(query[i].val); } auto q_threshold = get_threshold(values, drop_ratio_search); - auto distances = compute_all_distances(query, q_threshold, computer); + auto q_vec = parse_query(query, q_threshold); + + auto distances = compute_all_distances(q_vec, computer); if (!bitset.empty()) { for (size_t i = 0; i < distances.size(); ++i) { if (bitset.test(i)) { @@ -544,26 +552,16 @@ class InvertedIndex : public BaseInvertedIndex { } std::vector - compute_all_distances(const SparseRow& q_vec, DType q_threshold, - const DocValueComputer& computer) const { + compute_all_distances(std::vector>& q_vec, const DocValueComputer& computer) const { std::vector scores(n_rows_internal_, 0.0f); - for (size_t idx = 0; idx < q_vec.size(); ++idx) { - auto [i, v] = q_vec[idx]; - if (v < q_threshold || i >= max_dim_) { - continue; - } - auto dim_id = dim_map_.find(i); - if (dim_id == dim_map_.end()) { - continue; - } - auto& plist_ids = inverted_index_ids_[dim_id->second]; - auto& plist_vals = inverted_index_vals_[dim_id->second]; + for (size_t i = 0; i < q_vec.size(); ++i) { + auto& plist_ids = inverted_index_ids_[q_vec[i].first]; + auto& plist_vals = inverted_index_vals_[q_vec[i].first]; // TODO: improve with SIMD for (size_t j = 0; j < plist_ids.size(); ++j) { auto doc_id = plist_ids[j]; - auto val = plist_vals[j]; float val_sum = bm25 ? bm25_params_->row_sums.at(doc_id) : 0; - scores[doc_id] += v * computer(val, val_sum); + scores[doc_id] += q_vec[i].second * computer(plist_vals[j], val_sum); } } return scores; @@ -632,14 +630,43 @@ class InvertedIndex : public BaseInvertedIndex { } }; // struct Cursor + std::vector> + parse_query(const SparseRow& query, DType q_threshold) const { + std::vector> filtered_query; + for (size_t i = 0; i < query.size(); ++i) { + auto [idx, val] = query[i]; + auto dim_id = dim_map_.find(idx); + if (dim_id == dim_map_.end() || std::abs(val) < q_threshold) { + continue; + } + filtered_query.emplace_back(dim_id->second, val); + } + return filtered_query; + } + + template + std::vector> + make_cursors(const std::vector>& q_vec, const DocValueComputer& computer, + DocIdFilter& filter) const { + std::vector> cursors; + cursors.reserve(q_vec.size()); + for (auto q_dim : q_vec) { + auto& plist_ids = inverted_index_ids_[q_dim.first]; + auto& plist_vals = inverted_index_vals_[q_dim.first]; + cursors.emplace_back(plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[q_dim.first] * q_dim.second, + q_dim.second, filter); + } + return cursors; + } + // find the top-k candidates using brute force search, k as specified by the capacity of the heap. // any value in q_vec that is smaller than q_threshold and any value with dimension >= n_cols() will be ignored. // TODO: may switch to row-wise brute force if filter rate is high. Benchmark needed. template void - search_taat_naive(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, DocIdFilter& filter, + search_taat_naive(std::vector>& q_vec, MaxMinHeap& heap, DocIdFilter& filter, const DocValueComputer& computer) const { - auto scores = compute_all_distances(q_vec, q_threshold, computer); + auto scores = compute_all_distances(q_vec, computer); for (size_t i = 0; i < n_rows_internal_; ++i) { if ((filter.empty() || !filter.test(i)) && scores[i] != 0) { heap.push(i, scores[i]); @@ -650,28 +677,16 @@ class InvertedIndex : public BaseInvertedIndex { // any value in q_vec that is smaller than q_threshold will be ignored. template void - search_daat_wand(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, DocIdFilter& filter, + search_daat_wand(std::vector>& q_vec, MaxMinHeap& heap, DocIdFilter& filter, const DocValueComputer& computer) const { - auto q_dim = q_vec.size(); - std::vector>> cursors(q_dim); - size_t valid_q_dim = 0; - for (size_t i = 0; i < q_dim; ++i) { - auto [idx, val] = q_vec[i]; - auto dim_id = dim_map_.find(idx); - if (dim_id == dim_map_.end() || std::abs(val) < q_threshold) { - continue; - } - auto& plist_ids = inverted_index_ids_[dim_id->second]; - auto& plist_vals = inverted_index_vals_[dim_id->second]; - cursors[valid_q_dim++] = std::make_shared>( - plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[dim_id->second] * val, val, filter); - } - if (valid_q_dim == 0) { - return; + std::vector> cursors = make_cursors(q_vec, computer, filter); + std::vector*> cursor_ptrs(cursors.size()); + for (size_t i = 0; i < cursors.size(); ++i) { + cursor_ptrs[i] = &cursors[i]; } - cursors.resize(valid_q_dim); - auto sort_cursors = [&cursors] { - std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->cur_vec_id_ < y->cur_vec_id_; }); + auto sort_cursors = [&cursor_ptrs] { + std::sort(cursor_ptrs.begin(), cursor_ptrs.end(), + [](auto& x, auto& y) { return x->cur_vec_id_ < y->cur_vec_id_; }); }; sort_cursors(); while (true) { @@ -679,11 +694,11 @@ class InvertedIndex : public BaseInvertedIndex { float upper_bound = 0; size_t pivot; bool found_pivot = false; - for (pivot = 0; pivot < valid_q_dim; ++pivot) { - if (cursors[pivot]->loc_ >= cursors[pivot]->plist_size_) { + for (pivot = 0; pivot < q_vec.size(); ++pivot) { + if (cursor_ptrs[pivot]->cur_vec_id_ >= n_rows_internal_) { break; } - upper_bound += cursors[pivot]->max_score_; + upper_bound += cursor_ptrs[pivot]->max_score_; if (upper_bound > threshold) { found_pivot = true; break; @@ -692,29 +707,29 @@ class InvertedIndex : public BaseInvertedIndex { if (!found_pivot) { break; } - table_t pivot_id = cursors[pivot]->cur_vec_id_; - if (pivot_id == cursors[0]->cur_vec_id_) { + table_t pivot_id = cursor_ptrs[pivot]->cur_vec_id_; + if (pivot_id == cursor_ptrs[0]->cur_vec_id_) { float score = 0; - for (auto& cursor : cursors) { - if (cursor->cur_vec_id_ != pivot_id) { + for (auto& cursor_ptr : cursor_ptrs) { + if (cursor_ptr->cur_vec_id_ != pivot_id) { break; } - float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id_) : 0; - score += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum); - cursor->next(); + float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor_ptr->cur_vec_id_) : 0; + score += cursor_ptr->q_value_ * computer(cursor_ptr->cur_vec_val(), cur_vec_sum); + cursor_ptr->next(); } heap.push(pivot_id, score); sort_cursors(); } else { size_t next_list = pivot; - for (; cursors[next_list]->cur_vec_id_ == pivot_id; --next_list) { + for (; cursor_ptrs[next_list]->cur_vec_id_ == pivot_id; --next_list) { } - cursors[next_list]->seek(pivot_id); - for (size_t i = next_list + 1; i < valid_q_dim; ++i) { - if (cursors[i]->cur_vec_id_ >= cursors[i - 1]->cur_vec_id_) { + cursor_ptrs[next_list]->seek(pivot_id); + for (size_t i = next_list + 1; i < q_vec.size(); ++i) { + if (cursor_ptrs[i]->cur_vec_id_ >= cursor_ptrs[i - 1]->cur_vec_id_) { break; } - std::swap(cursors[i], cursors[i - 1]); + std::swap(cursor_ptrs[i], cursor_ptrs[i - 1]); } } } @@ -722,42 +737,27 @@ class InvertedIndex : public BaseInvertedIndex { template void - search_daat_maxscore(const SparseRow& q_vec, DType q_threshold, MaxMinHeap& heap, - const DocIdFilter& filter, const DocValueComputer& computer) const { - auto q_dim = q_vec.size(); - std::vector>> cursors(q_dim); - size_t valid_q_dim = 0; - for (size_t i = 0; i < q_dim; ++i) { - auto [idx, val] = q_vec[i]; - auto dim_id = dim_map_.find(idx); - if (dim_id == dim_map_.end() || std::abs(val) < q_threshold) { - continue; - } - auto& plist_ids = inverted_index_ids_[dim_id->second]; - auto& plist_vals = inverted_index_vals_[dim_id->second]; - cursors[valid_q_dim++] = std::make_shared>( - plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[dim_id->second] * val, val, filter); - } - if (valid_q_dim == 0) { - return; - } - cursors.resize(valid_q_dim); + search_daat_maxscore(std::vector>& q_vec, MaxMinHeap& heap, DocIdFilter& filter, + const DocValueComputer& computer) const { + std::sort(q_vec.begin(), q_vec.end(), [this](auto& a, auto& b) { + return a.second * max_score_in_dim_[a.first] > b.second * max_score_in_dim_[b.first]; + }); - std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->max_score_ > y->max_score_; }); + std::vector> cursors = make_cursors(q_vec, computer, filter); float threshold = heap.full() ? heap.top().val : 0; std::vector upper_bounds(cursors.size()); float bound_sum = 0.0; for (size_t i = cursors.size() - 1; i + 1 > 0; --i) { - bound_sum += cursors[i]->max_score_; + bound_sum += cursors[i].max_score_; upper_bounds[i] = bound_sum; } - uint32_t next_cand_vec_id = n_rows_internal_; + table_t next_cand_vec_id = n_rows_internal_; for (size_t i = 0; i < cursors.size(); ++i) { - if (cursors[i]->cur_vec_id_ < next_cand_vec_id) { - next_cand_vec_id = cursors[i]->cur_vec_id_; + if (cursors[i].cur_vec_id_ < next_cand_vec_id) { + next_cand_vec_id = cursors[i].cur_vec_id_; } } @@ -772,7 +772,7 @@ class InvertedIndex : public BaseInvertedIndex { } float curr_cand_score = 0.0f; - uint32_t curr_cand_vec_id = 0; + table_t curr_cand_vec_id = 0; while (curr_cand_vec_id < n_rows_internal_) { auto found_cand = false; @@ -788,13 +788,13 @@ class InvertedIndex : public BaseInvertedIndex { next_cand_vec_id = n_rows_internal_; for (size_t i = 0; i < first_ne_idx; ++i) { - if (cursors[i]->cur_vec_id_ == curr_cand_vec_id) { - float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i]->cur_vec_id_) : 0; - curr_cand_score += cursors[i]->q_value_ * computer(cursors[i]->cur_vec_val(), cur_vec_sum); - cursors[i]->next(); + if (cursors[i].cur_vec_id_ == curr_cand_vec_id) { + float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i].cur_vec_id_) : 0; + curr_cand_score += cursors[i].q_value_ * computer(cursors[i].cur_vec_val(), cur_vec_sum); + cursors[i].next(); } - if (cursors[i]->cur_vec_id_ < next_cand_vec_id) { - next_cand_vec_id = cursors[i]->cur_vec_id_; + if (cursors[i].cur_vec_id_ < next_cand_vec_id) { + next_cand_vec_id = cursors[i].cur_vec_id_; } } @@ -804,10 +804,10 @@ class InvertedIndex : public BaseInvertedIndex { found_cand = false; break; } - cursors[i]->seek(curr_cand_vec_id); - if (cursors[i]->cur_vec_id_ == curr_cand_vec_id) { - float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i]->cur_vec_id_) : 0; - curr_cand_score += cursors[i]->q_value_ * computer(cursors[i]->cur_vec_val(), cur_vec_sum); + cursors[i].seek(curr_cand_vec_id); + if (cursors[i].cur_vec_id_ == curr_cand_vec_id) { + float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i].cur_vec_id_) : 0; + curr_cand_score += cursors[i].q_value_ * computer(cursors[i].cur_vec_val(), cur_vec_sum); } } } @@ -826,8 +826,8 @@ class InvertedIndex : public BaseInvertedIndex { } void - refine_and_collect(const SparseRow& q_vec, MaxMinHeap& inacc_heap, size_t k, float* distances, - label_t* labels, const DocValueComputer& computer) const { + refine_and_collect(std::vector>& q_vec, MaxMinHeap& inacc_heap, size_t k, + float* distances, label_t* labels, const DocValueComputer& computer) const { std::vector docids; MaxMinHeap heap(k); @@ -839,11 +839,11 @@ class InvertedIndex : public BaseInvertedIndex { DocIdFilterByVector filter(std::move(docids)); if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) { - search_daat_wand(q_vec, 0, heap, filter, computer); + search_daat_wand(q_vec, heap, filter, computer); } else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) { - search_daat_maxscore(q_vec, 0, heap, filter, computer); + search_daat_maxscore(q_vec, heap, filter, computer); } else { - search_taat_naive(q_vec, 0, heap, filter, computer); + search_taat_naive(q_vec, heap, filter, computer); } collect_result(heap, distances, labels); }