From 75538c58c0b097c537f2d00f927ef10fd12ed9b9 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Mon, 2 Dec 2024 22:48:34 +0800 Subject: [PATCH] wand: optimize cursor class to get ~10% performance improvement; (#968) simplify code to drop Signed-off-by: Buqian Zheng --- src/index/sparse/sparse_inverted_index.h | 140 ++++++++++++----------- 1 file changed, 73 insertions(+), 67 deletions(-) diff --git a/src/index/sparse/sparse_inverted_index.h b/src/index/sparse/sparse_inverted_index.h index 69e12453b..06665894a 100644 --- a/src/index/sparse/sparse_inverted_index.h +++ b/src/index/sparse/sparse_inverted_index.h @@ -360,15 +360,7 @@ class InvertedIndex : public BaseInvertedIndex { vals.push_back(fabs(data[i][j].val)); } } - auto pos = vals.begin() + static_cast(drop_ratio_build * vals.size()); - // pos may be vals.end() if drop_ratio_build is 1.0, in that case we use - // the largest value as the threshold. - if (pos == vals.end()) { - pos--; - } - std::nth_element(vals.begin(), pos, vals.end()); - - value_threshold_ = *pos; + value_threshold_ = get_threshold(vals, drop_ratio_build); drop_during_build_ = true; return Status::success; } @@ -413,9 +405,7 @@ class InvertedIndex : public BaseInvertedIndex { for (size_t i = 0; i < query.size(); ++i) { values[i] = std::abs(query[i].val); } - auto pos = values.begin() + static_cast(drop_ratio_search * values.size()); - std::nth_element(values.begin(), pos, values.end()); - auto q_threshold = *pos; + auto q_threshold = get_threshold(values, drop_ratio_search); // if no data was dropped during both build and search, no refinement is // needed. @@ -447,9 +437,7 @@ class InvertedIndex : public BaseInvertedIndex { for (size_t i = 0; i < query.size(); ++i) { values[i] = std::abs(query[i].val); } - auto pos = values.begin() + static_cast(drop_ratio_search * values.size()); - std::nth_element(values.begin(), pos, values.end()); - auto q_threshold = *pos; + auto q_threshold = get_threshold(values, drop_ratio_search); auto distances = compute_all_distances(query, q_threshold, computer); for (size_t i = 0; i < distances.size(); ++i) { if (bitset.empty() || !bitset.test(i)) { @@ -512,6 +500,22 @@ class InvertedIndex : public BaseInvertedIndex { } private: + // Given a vector of values, returns the threshold value. + // All values strictly smaller than the threshold will be ignored. + // values will be modified in this function. + inline T + get_threshold(std::vector& values, float drop_ratio) const { + // drop_ratio is in [0, 1) thus drop_count is guaranteed to be less + // than values.size(). + auto drop_count = static_cast(drop_ratio * values.size()); + if (drop_count == 0) { + return 0; + } + auto pos = values.begin() + drop_count; + std::nth_element(values.begin(), pos, values.end()); + return *pos; + } + size_t n_rows_internal() const { return raw_data_.size(); @@ -561,66 +565,69 @@ class InvertedIndex : public BaseInvertedIndex { // LUT supports size() and operator[] which returns an SparseIdVal. template - class Cursor { + struct Cursor { public: Cursor(const LUT& lut, size_t num_vec, float max_score, float q_value, const BitsetView bitset) - : lut_(lut), num_vec_(num_vec), max_score_(max_score), q_value_(q_value), bitset_(bitset) { - while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) { + : lut_(lut), + lut_size_(lut.size()), + total_num_vec_(num_vec), + max_score_(max_score), + q_value_(q_value), + bitset_(bitset) { + while (loc_ < lut_size_ && !bitset_.empty() && bitset_.test(lut_[loc_].id)) { loc_++; } + update_cur_vec_id(); } Cursor(const Cursor& rhs) = delete; void next() { - loc_++; - while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) { - loc_++; - } + next_internal(); + update_cur_vec_id(); } - // advance loc until cur_vec_id() >= vec_id + + // advance loc until cur_vec_id_ >= vec_id void seek(table_t vec_id) { - while (loc_ < lut_.size() && cur_vec_id() < vec_id) { - next(); + while (loc_ < lut_size_ && lut_[loc_].id < vec_id) { + next_internal(); } + update_cur_vec_id(); } - [[nodiscard]] table_t - cur_vec_id() const { - if (is_end()) { - return num_vec_; - } - return lut_[loc_].id; - } + T cur_vec_val() const { return lut_[loc_].val; } - [[nodiscard]] bool - is_end() const { - return loc_ >= size(); - } - [[nodiscard]] float - q_value() const { - return q_value_; - } - [[nodiscard]] size_t - size() const { - return lut_.size(); - } - [[nodiscard]] float - max_score() const { - return max_score_; - } - private: const LUT& lut_; + const size_t lut_size_; size_t loc_ = 0; - size_t num_vec_ = 0; + size_t total_num_vec_ = 0; float max_score_ = 0.0f; float q_value_ = 0.0f; const BitsetView bitset_; - }; // class Cursor + table_t cur_vec_id_ = 0; + + private: + inline void + update_cur_vec_id() { + if (loc_ >= lut_size_) { + cur_vec_id_ = total_num_vec_; + } else { + cur_vec_id_ = lut_[loc_].id; + } + } + + inline void + next_internal() { + loc_++; + while (loc_ < lut_size_ && !bitset_.empty() && bitset_.test(lut_[loc_].id)) { + loc_++; + } + } + }; // struct Cursor // any value in q_vec that is smaller than q_threshold will be ignored. void @@ -628,7 +635,7 @@ class InvertedIndex : public BaseInvertedIndex { const DocValueComputer& computer) const { auto q_dim = q_vec.size(); std::vector>> cursors(q_dim); - auto valid_q_dim = 0; + size_t valid_q_dim = 0; for (size_t i = 0; i < q_dim; ++i) { auto [idx, val] = q_vec[i]; auto dim_id = dim_map_.find(idx); @@ -644,21 +651,20 @@ class InvertedIndex : public BaseInvertedIndex { } cursors.resize(valid_q_dim); auto sort_cursors = [&cursors] { - std::sort(cursors.begin(), cursors.end(), - [](auto& x, auto& y) { return x->cur_vec_id() < y->cur_vec_id(); }); + std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->cur_vec_id_ < y->cur_vec_id_; }); }; sort_cursors(); - auto score_above_threshold = [&heap](float x) { return !heap.full() || x > heap.top().val; }; while (true) { + float threshold = heap.full() ? heap.top().val : 0; float upper_bound = 0; size_t pivot; bool found_pivot = false; - for (pivot = 0; pivot < cursors.size(); ++pivot) { - if (cursors[pivot]->is_end()) { + for (pivot = 0; pivot < valid_q_dim; ++pivot) { + if (cursors[pivot]->loc_ >= cursors[pivot]->lut_size_) { break; } - upper_bound += cursors[pivot]->max_score(); - if (score_above_threshold(upper_bound)) { + upper_bound += cursors[pivot]->max_score_; + if (upper_bound > threshold) { found_pivot = true; break; } @@ -666,26 +672,26 @@ class InvertedIndex : public BaseInvertedIndex { if (!found_pivot) { break; } - table_t pivot_id = cursors[pivot]->cur_vec_id(); - if (pivot_id == cursors[0]->cur_vec_id()) { + table_t pivot_id = cursors[pivot]->cur_vec_id_; + if (pivot_id == cursors[0]->cur_vec_id_) { float score = 0; for (auto& cursor : cursors) { - if (cursor->cur_vec_id() != pivot_id) { + if (cursor->cur_vec_id_ != pivot_id) { break; } - T cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id()) : 0; - score += cursor->q_value() * computer(cursor->cur_vec_val(), cur_vec_sum); + T cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id_) : 0; + score += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum); cursor->next(); } heap.push(pivot_id, score); sort_cursors(); } else { size_t next_list = pivot; - for (; cursors[next_list]->cur_vec_id() == pivot_id; --next_list) { + for (; cursors[next_list]->cur_vec_id_ == pivot_id; --next_list) { } cursors[next_list]->seek(pivot_id); - for (size_t i = next_list + 1; i < cursors.size(); ++i) { - if (cursors[i]->cur_vec_id() >= cursors[i - 1]->cur_vec_id()) { + for (size_t i = next_list + 1; i < valid_q_dim; ++i) { + if (cursors[i]->cur_vec_id_ >= cursors[i - 1]->cur_vec_id_) { break; } std::swap(cursors[i], cursors[i - 1]);