Skip to content

Commit

Permalink
wand: optimize cursor class to get ~10% performance improvement; (#968)
Browse files Browse the repository at this point in the history
simplify code to drop

Signed-off-by: Buqian Zheng <[email protected]>
  • Loading branch information
zhengbuqian authored Dec 2, 2024
1 parent 1cb9937 commit 75538c5
Showing 1 changed file with 73 additions and 67 deletions.
140 changes: 73 additions & 67 deletions src/index/sparse/sparse_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,15 +360,7 @@ class InvertedIndex : public BaseInvertedIndex<T> {
vals.push_back(fabs(data[i][j].val));
}
}
auto pos = vals.begin() + static_cast<size_t>(drop_ratio_build * vals.size());
// pos may be vals.end() if drop_ratio_build is 1.0, in that case we use
// the largest value as the threshold.
if (pos == vals.end()) {
pos--;
}
std::nth_element(vals.begin(), pos, vals.end());

value_threshold_ = *pos;
value_threshold_ = get_threshold(vals, drop_ratio_build);
drop_during_build_ = true;
return Status::success;
}
Expand Down Expand Up @@ -413,9 +405,7 @@ class InvertedIndex : public BaseInvertedIndex<T> {
for (size_t i = 0; i < query.size(); ++i) {
values[i] = std::abs(query[i].val);
}
auto pos = values.begin() + static_cast<size_t>(drop_ratio_search * values.size());
std::nth_element(values.begin(), pos, values.end());
auto q_threshold = *pos;
auto q_threshold = get_threshold(values, drop_ratio_search);

// if no data was dropped during both build and search, no refinement is
// needed.
Expand Down Expand Up @@ -447,9 +437,7 @@ class InvertedIndex : public BaseInvertedIndex<T> {
for (size_t i = 0; i < query.size(); ++i) {
values[i] = std::abs(query[i].val);
}
auto pos = values.begin() + static_cast<size_t>(drop_ratio_search * values.size());
std::nth_element(values.begin(), pos, values.end());
auto q_threshold = *pos;
auto q_threshold = get_threshold(values, drop_ratio_search);
auto distances = compute_all_distances(query, q_threshold, computer);
for (size_t i = 0; i < distances.size(); ++i) {
if (bitset.empty() || !bitset.test(i)) {
Expand Down Expand Up @@ -512,6 +500,22 @@ class InvertedIndex : public BaseInvertedIndex<T> {
}

private:
// Given a vector of values, returns the threshold value.
// All values strictly smaller than the threshold will be ignored.
// values will be modified in this function.
inline T
get_threshold(std::vector<T>& values, float drop_ratio) const {
// drop_ratio is in [0, 1) thus drop_count is guaranteed to be less
// than values.size().
auto drop_count = static_cast<size_t>(drop_ratio * values.size());
if (drop_count == 0) {
return 0;
}
auto pos = values.begin() + drop_count;
std::nth_element(values.begin(), pos, values.end());
return *pos;
}

size_t
n_rows_internal() const {
return raw_data_.size();
Expand Down Expand Up @@ -561,74 +565,77 @@ class InvertedIndex : public BaseInvertedIndex<T> {

// LUT supports size() and operator[] which returns an SparseIdVal.
template <typename LUT>
class Cursor {
struct Cursor {
public:
Cursor(const LUT& lut, size_t num_vec, float max_score, float q_value, const BitsetView bitset)
: lut_(lut), num_vec_(num_vec), max_score_(max_score), q_value_(q_value), bitset_(bitset) {
while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) {
: lut_(lut),
lut_size_(lut.size()),
total_num_vec_(num_vec),
max_score_(max_score),
q_value_(q_value),
bitset_(bitset) {
while (loc_ < lut_size_ && !bitset_.empty() && bitset_.test(lut_[loc_].id)) {
loc_++;
}
update_cur_vec_id();
}
Cursor(const Cursor& rhs) = delete;

void
next() {
loc_++;
while (loc_ < lut_.size() && !bitset_.empty() && bitset_.test(cur_vec_id())) {
loc_++;
}
next_internal();
update_cur_vec_id();
}
// advance loc until cur_vec_id() >= vec_id

// advance loc until cur_vec_id_ >= vec_id
void
seek(table_t vec_id) {
while (loc_ < lut_.size() && cur_vec_id() < vec_id) {
next();
while (loc_ < lut_size_ && lut_[loc_].id < vec_id) {
next_internal();
}
update_cur_vec_id();
}
[[nodiscard]] table_t
cur_vec_id() const {
if (is_end()) {
return num_vec_;
}
return lut_[loc_].id;
}

T
cur_vec_val() const {
return lut_[loc_].val;
}
[[nodiscard]] bool
is_end() const {
return loc_ >= size();
}
[[nodiscard]] float
q_value() const {
return q_value_;
}
[[nodiscard]] size_t
size() const {
return lut_.size();
}
[[nodiscard]] float
max_score() const {
return max_score_;
}

private:
const LUT& lut_;
const size_t lut_size_;
size_t loc_ = 0;
size_t num_vec_ = 0;
size_t total_num_vec_ = 0;
float max_score_ = 0.0f;
float q_value_ = 0.0f;
const BitsetView bitset_;
}; // class Cursor
table_t cur_vec_id_ = 0;

private:
inline void
update_cur_vec_id() {
if (loc_ >= lut_size_) {
cur_vec_id_ = total_num_vec_;
} else {
cur_vec_id_ = lut_[loc_].id;
}
}

inline void
next_internal() {
loc_++;
while (loc_ < lut_size_ && !bitset_.empty() && bitset_.test(lut_[loc_].id)) {
loc_++;
}
}
}; // struct Cursor

// any value in q_vec that is smaller than q_threshold will be ignored.
void
search_wand(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset,
const DocValueComputer<T>& computer) const {
auto q_dim = q_vec.size();
std::vector<std::shared_ptr<Cursor<const typename decltype(inverted_lut_)::value_type&>>> cursors(q_dim);
auto valid_q_dim = 0;
size_t valid_q_dim = 0;
for (size_t i = 0; i < q_dim; ++i) {
auto [idx, val] = q_vec[i];
auto dim_id = dim_map_.find(idx);
Expand All @@ -644,48 +651,47 @@ class InvertedIndex : public BaseInvertedIndex<T> {
}
cursors.resize(valid_q_dim);
auto sort_cursors = [&cursors] {
std::sort(cursors.begin(), cursors.end(),
[](auto& x, auto& y) { return x->cur_vec_id() < y->cur_vec_id(); });
std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->cur_vec_id_ < y->cur_vec_id_; });
};
sort_cursors();
auto score_above_threshold = [&heap](float x) { return !heap.full() || x > heap.top().val; };
while (true) {
float threshold = heap.full() ? heap.top().val : 0;
float upper_bound = 0;
size_t pivot;
bool found_pivot = false;
for (pivot = 0; pivot < cursors.size(); ++pivot) {
if (cursors[pivot]->is_end()) {
for (pivot = 0; pivot < valid_q_dim; ++pivot) {
if (cursors[pivot]->loc_ >= cursors[pivot]->lut_size_) {
break;
}
upper_bound += cursors[pivot]->max_score();
if (score_above_threshold(upper_bound)) {
upper_bound += cursors[pivot]->max_score_;
if (upper_bound > threshold) {
found_pivot = true;
break;
}
}
if (!found_pivot) {
break;
}
table_t pivot_id = cursors[pivot]->cur_vec_id();
if (pivot_id == cursors[0]->cur_vec_id()) {
table_t pivot_id = cursors[pivot]->cur_vec_id_;
if (pivot_id == cursors[0]->cur_vec_id_) {
float score = 0;
for (auto& cursor : cursors) {
if (cursor->cur_vec_id() != pivot_id) {
if (cursor->cur_vec_id_ != pivot_id) {
break;
}
T cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id()) : 0;
score += cursor->q_value() * computer(cursor->cur_vec_val(), cur_vec_sum);
T cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursor->cur_vec_id_) : 0;
score += cursor->q_value_ * computer(cursor->cur_vec_val(), cur_vec_sum);
cursor->next();
}
heap.push(pivot_id, score);
sort_cursors();
} else {
size_t next_list = pivot;
for (; cursors[next_list]->cur_vec_id() == pivot_id; --next_list) {
for (; cursors[next_list]->cur_vec_id_ == pivot_id; --next_list) {
}
cursors[next_list]->seek(pivot_id);
for (size_t i = next_list + 1; i < cursors.size(); ++i) {
if (cursors[i]->cur_vec_id() >= cursors[i - 1]->cur_vec_id()) {
for (size_t i = next_list + 1; i < valid_q_dim; ++i) {
if (cursors[i]->cur_vec_id_ >= cursors[i - 1]->cur_vec_id_) {
break;
}
std::swap(cursors[i], cursors[i - 1]);
Expand Down

0 comments on commit 75538c5

Please sign in to comment.