Skip to content

Commit

Permalink
sparse: add daat maxscore algorithm support
Browse files Browse the repository at this point in the history
Signed-off-by: Shawn Wang <[email protected]>
  • Loading branch information
sparknack committed Jan 9, 2025
1 parent 104b211 commit 770faac
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 30 deletions.
3 changes: 2 additions & 1 deletion include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ constexpr const char* HNSW_REFINE_TYPE = "refine_type";
constexpr const char* SQ_TYPE = "sq_type"; // for IVF_SQ and HNSW_SQ
constexpr const char* PRQ_NUM = "nrq"; // for PRQ, number of residual quantizers

// Sparse Params
// Sparse Inverted Index Params
constexpr const char* INVERTED_INDEX_ALGO = "inverted_index_algo";
constexpr const char* DROP_RATIO_BUILD = "drop_ratio_build";
constexpr const char* DROP_RATIO_SEARCH = "drop_ratio_search";
} // namespace indexparam
Expand Down
40 changes: 34 additions & 6 deletions src/index/sparse/sparse_index_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

namespace knowhere {

// Inverted Index impl for sparse vectors. May optionally use WAND algorithm to speed up search.
// Inverted Index impl for sparse vectors.
//
// Not overriding RangeSearch, will use the default implementation in IndexNode.
//
Expand Down Expand Up @@ -348,8 +348,6 @@ class SparseInvertedIndexNode : public IndexNode {
expected<sparse::BaseInvertedIndex<T>*>
CreateIndex(const SparseInvertedIndexConfig& cfg) const {
if (IsMetricType(cfg.metric_type.value(), metric::BM25)) {
// quantize float to uint16_t when BM25 metric type is used.
auto idx = new sparse::InvertedIndex<T, uint16_t, use_wand, true, mmapped>();
if (!cfg.bm25_k1.has_value() || !cfg.bm25_b.has_value() || !cfg.bm25_avgdl.has_value()) {
return expected<sparse::BaseInvertedIndex<T>*>::Err(
Status::invalid_args, "BM25 parameters k1, b, and avgdl must be set when building/loading");
Expand All @@ -358,10 +356,40 @@ class SparseInvertedIndexNode : public IndexNode {
auto b = cfg.bm25_b.value();
auto avgdl = cfg.bm25_avgdl.value();
auto max_score_ratio = cfg.wand_bm25_max_score_ratio.value();
idx->SetBM25Params(k1, b, avgdl, max_score_ratio);
return idx;
if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") {
auto index =
new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_WAND, true, mmapped>();
index->SetBM25Params(k1, b, avgdl, max_score_ratio);
return index;
} else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") {
auto index =
new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, true, mmapped>();
index->SetBM25Params(k1, b, avgdl, max_score_ratio);
return index;
} else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") {
auto index =
new sparse::InvertedIndex<T, uint16_t, sparse::InvertedIndexAlgo::TAAT_NAIVE, true, mmapped>();
index->SetBM25Params(k1, b, avgdl, max_score_ratio);
return index;
} else {
return expected<sparse::BaseInvertedIndex<T>*>::Err(Status::invalid_args,
"Invalid search algorithm for SparseInvertedIndex");
}
} else {
return new sparse::InvertedIndex<T, T, use_wand, false, mmapped>();
if (use_wand || cfg.inverted_index_algo.value() == "DAAT_WAND") {
auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_WAND, false, mmapped>();
return index;
} else if (cfg.inverted_index_algo.value() == "DAAT_MAXSCORE") {
auto index =
new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::DAAT_MAXSCORE, false, mmapped>();
return index;
} else if (cfg.inverted_index_algo.value() == "TAAT_NAIVE") {
auto index = new sparse::InvertedIndex<T, T, sparse::InvertedIndexAlgo::TAAT_NAIVE, false, mmapped>();
return index;
} else {
return expected<sparse::BaseInvertedIndex<T>*>::Err(Status::invalid_args,
"Invalid search algorithm for SparseInvertedIndex");
}
}
}

Expand Down
160 changes: 139 additions & 21 deletions src/index/sparse/sparse_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
#include "knowhere/utils.h"

namespace knowhere::sparse {

// Search algorithm used by InvertedIndex. It is fixed at compile time through the `algo` template
// parameter of InvertedIndex. DAAT_WAND and DAAT_MAXSCORE additionally maintain a per-dimension
// max score (max_score_in_dim_); TAAT_NAIVE does not.
enum class InvertedIndexAlgo {
// Term-at-a-time: scores every row via compute_all_distances (search_taat_naive).
TAAT_NAIVE,
// Document-at-a-time WAND (search_daat_wand).
DAAT_WAND,
// Document-at-a-time MaxScore (search_daat_maxscore).
DAAT_MAXSCORE,
};

template <typename T>
class BaseInvertedIndex {
public:
Expand Down Expand Up @@ -77,7 +84,7 @@ class BaseInvertedIndex {
n_cols() const = 0;
};

template <typename DType, typename QType, bool use_wand = false, bool bm25 = false, bool mmapped = false>
template <typename DType, typename QType, InvertedIndexAlgo algo, bool bm25 = false, bool mmapped = false>
class InvertedIndex : public BaseInvertedIndex<DType> {
public:
explicit InvertedIndex() {
Expand Down Expand Up @@ -132,12 +139,13 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
"avgdl must be supplied during searching");
}
auto avgdl = cfg.bm25_avgdl.value();
if constexpr (use_wand) {
// wand: search time k1/b must equal load time config.
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
// daat_wand and daat_maxscore: search time k1/b must equal load time config.
if ((cfg.bm25_k1.has_value() && cfg.bm25_k1.value() != bm25_params_->k1) ||
((cfg.bm25_b.has_value() && cfg.bm25_b.value() != bm25_params_->b))) {
return expected<DocValueComputer<float>>::Err(
Status::invalid_args, "search time k1/b must equal load time config for WAND index.");
Status::invalid_args,
"search time k1/b must equal load time config for DAAT_WAND or DAAT_MAXSCORE algorithm.");
}
return GetDocValueBM25Computer<float>(bm25_params_->k1, bm25_params_->b, avgdl);
} else {
Expand Down Expand Up @@ -293,7 +301,7 @@ class InvertedIndex : public BaseInvertedIndex<DType> {

map_byte_size_ =
inverted_index_ids_byte_size + inverted_index_vals_byte_size + plists_ids_byte_size + plists_vals_byte_size;
if constexpr (use_wand) {
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
map_byte_size_ += max_score_in_dim_byte_size;
}
if constexpr (bm25) {
Expand Down Expand Up @@ -342,7 +350,7 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
inverted_index_vals_.initialize(ptr, inverted_index_vals_byte_size);
ptr += inverted_index_vals_byte_size;

if constexpr (use_wand) {
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
max_score_in_dim_.initialize(ptr, max_score_in_dim_byte_size);
ptr += max_score_in_dim_byte_size;
}
Expand All @@ -367,7 +375,7 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
size_t dim_id = 0;
for (const auto& [idx, count] : idx_counts) {
dim_map_[idx] = dim_id;
if constexpr (use_wand) {
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
max_score_in_dim_.emplace_back(0.0f);
}
++dim_id;
Expand Down Expand Up @@ -432,10 +440,13 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
refine_factor = 1;
}
MaxMinHeap<float> heap(k * refine_factor);
if constexpr (!use_wand) {
search_brute_force(query, q_threshold, heap, bitset, computer);
// DAAT_WAND and DAAT_MAXSCORE are based on the implementation in PISA.
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) {
search_daat_wand(query, q_threshold, heap, bitset, computer);
} else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
search_daat_maxscore(query, q_threshold, heap, bitset, computer);
} else {
search_wand(query, q_threshold, heap, bitset, computer);
search_taat_naive(query, q_threshold, heap, bitset, computer);
}

if (refine_factor == 1) {
Expand Down Expand Up @@ -510,7 +521,7 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
res += sizeof(typename decltype(inverted_index_vals_)::value_type::value_type) *
inverted_index_vals_[i].capacity();
}
if constexpr (use_wand) {
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
res += sizeof(typename decltype(max_score_in_dim_)::value_type) * max_score_in_dim_.capacity();
}
return res;
Expand Down Expand Up @@ -638,8 +649,8 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
// TODO: may switch to row-wise brute force if filter rate is high. Benchmark needed.
template <typename DocIdFilter>
void
search_brute_force(const SparseRow<DType>& q_vec, DType q_threshold, MaxMinHeap<float>& heap, DocIdFilter& filter,
const DocValueComputer<float>& computer) const {
search_taat_naive(const SparseRow<DType>& q_vec, DType q_threshold, MaxMinHeap<float>& heap, DocIdFilter& filter,
const DocValueComputer<float>& computer) const {
auto scores = compute_all_distances(q_vec, q_threshold, computer);
for (size_t i = 0; i < n_rows_internal_; ++i) {
if ((filter.empty() || !filter.test(i)) && scores[i] != 0) {
Expand All @@ -651,8 +662,8 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
// any value in q_vec that is smaller than q_threshold will be ignored.
template <typename DocIdFilter>
void
search_wand(const SparseRow<DType>& q_vec, DType q_threshold, MaxMinHeap<float>& heap, DocIdFilter& filter,
const DocValueComputer<float>& computer) const {
search_daat_wand(const SparseRow<DType>& q_vec, DType q_threshold, MaxMinHeap<float>& heap, DocIdFilter& filter,
const DocValueComputer<float>& computer) const {
auto q_dim = q_vec.size();
std::vector<std::shared_ptr<Cursor<DocIdFilter>>> cursors(q_dim);
size_t valid_q_dim = 0;
Expand Down Expand Up @@ -721,6 +732,111 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
}
}

template <typename DocIdFilter>
void
search_daat_maxscore(const SparseRow<DType>& q_vec, DType q_threshold, MaxMinHeap<float>& heap,
const DocIdFilter& filter, const DocValueComputer<float>& computer) const {
auto q_dim = q_vec.size();
std::vector<std::shared_ptr<Cursor<DocIdFilter>>> cursors(q_dim);
size_t valid_q_dim = 0;
for (size_t i = 0; i < q_dim; ++i) {
auto [idx, val] = q_vec[i];
auto dim_id = dim_map_.find(idx);
if (dim_id == dim_map_.end() || std::abs(val) < q_threshold) {
continue;
}
auto& plist_ids = inverted_index_ids_[dim_id->second];
auto& plist_vals = inverted_index_vals_[dim_id->second];
cursors[valid_q_dim++] = std::make_shared<Cursor<DocIdFilter>>(
plist_ids, plist_vals, n_rows_internal_, max_score_in_dim_[dim_id->second] * val, val, filter);
}
if (valid_q_dim == 0) {
return;
}
cursors.resize(valid_q_dim);

std::sort(cursors.begin(), cursors.end(), [](auto& x, auto& y) { return x->max_score_ > y->max_score_; });

float threshold = heap.full() ? heap.top().val : 0;

std::vector<float> upper_bounds(cursors.size());
float bound_sum = 0.0;
for (size_t i = cursors.size() - 1; i + 1 > 0; --i) {
bound_sum += cursors[i]->max_score_;
upper_bounds[i] = bound_sum;
}

uint32_t next_cand_vec_id = n_rows_internal_;
for (size_t i = 0; i < cursors.size(); ++i) {
if (cursors[i]->cur_vec_id_ < next_cand_vec_id) {
next_cand_vec_id = cursors[i]->cur_vec_id_;
}
}

// first_ne_idx is the index of the first non-essential cursor
size_t first_ne_idx = cursors.size();

while (first_ne_idx != 0 && upper_bounds[first_ne_idx - 1] <= threshold) {
--first_ne_idx;
if (first_ne_idx == 0) {
return;
}
}

float curr_cand_score = 0.0f;
uint32_t curr_cand_vec_id = 0;

while (curr_cand_vec_id < n_rows_internal_) {
auto found_cand = false;
while (found_cand == false) {
// start find from next_vec_id
if (next_cand_vec_id >= n_rows_internal_) {
return;
}
// get current candidate vector
curr_cand_vec_id = next_cand_vec_id;
curr_cand_score = 0.0f;
// update next_cand_vec_id
next_cand_vec_id = n_rows_internal_;

for (size_t i = 0; i < first_ne_idx; ++i) {
if (cursors[i]->cur_vec_id_ == curr_cand_vec_id) {
float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i]->cur_vec_id_) : 0;
curr_cand_score += cursors[i]->q_value_ * computer(cursors[i]->cur_vec_val(), cur_vec_sum);
cursors[i]->next();
}
if (cursors[i]->cur_vec_id_ < next_cand_vec_id) {
next_cand_vec_id = cursors[i]->cur_vec_id_;
}
}

found_cand = true;
for (size_t i = first_ne_idx; i < cursors.size(); ++i) {
if (curr_cand_score + upper_bounds[i] <= threshold) {
found_cand = false;
break;
}
cursors[i]->seek(curr_cand_vec_id);
if (cursors[i]->cur_vec_id_ == curr_cand_vec_id) {
float cur_vec_sum = bm25 ? bm25_params_->row_sums.at(cursors[i]->cur_vec_id_) : 0;
curr_cand_score += cursors[i]->q_value_ * computer(cursors[i]->cur_vec_val(), cur_vec_sum);
}
}
}

if (curr_cand_score > threshold) {
heap.push(curr_cand_vec_id, curr_cand_score);
threshold = heap.full() ? heap.top().val : 0;
while (first_ne_idx != 0 && upper_bounds[first_ne_idx - 1] <= threshold) {
--first_ne_idx;
if (first_ne_idx == 0) {
return;
}
}
}
}
}

void
refine_and_collect(const SparseRow<DType>& q_vec, MaxMinHeap<float>& inacc_heap, size_t k, float* distances,
label_t* labels, const DocValueComputer<float>& computer) const {
Expand All @@ -734,10 +850,12 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
}

DocIdFilterByVector filter(std::move(docids));
if (use_wand) {
search_wand(q_vec, 0, heap, filter, computer);
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND) {
search_daat_wand(q_vec, 0, heap, filter, computer);
} else if constexpr (algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
search_daat_maxscore(q_vec, 0, heap, filter, computer);
} else {
search_brute_force(q_vec, 0, heap, filter, computer);
search_taat_naive(q_vec, 0, heap, filter, computer);
}
collect_result(heap, distances, labels);
}
Expand Down Expand Up @@ -774,13 +892,13 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
dim_it = dim_map_.insert({idx, next_dim_id_++}).first;
inverted_index_ids_.emplace_back();
inverted_index_vals_.emplace_back();
if constexpr (use_wand) {
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
max_score_in_dim_.emplace_back(0.0f);
}
}
inverted_index_ids_[dim_it->second].emplace_back(vec_id);
inverted_index_vals_[dim_it->second].emplace_back(get_quant_val(val));
if constexpr (use_wand) {
if constexpr (algo == InvertedIndexAlgo::DAAT_WAND || algo == InvertedIndexAlgo::DAAT_MAXSCORE) {
auto score = static_cast<float>(val);
if constexpr (bm25) {
score = bm25_params_->max_score_ratio * bm25_params_->wand_max_score_computer(val, row_sum);
Expand Down Expand Up @@ -832,7 +950,7 @@ class InvertedIndex : public BaseInvertedIndex<DType> {
// corresponds to the document length of each doc in the BM25 formula.
Vector<float> row_sums;

// below are used only for WAND index.
// below are used only for DAAT_WAND and DAAT_MAXSCORE algorithms.
float max_score_ratio;
DocValueComputer<float> wand_max_score_computer;

Expand Down
7 changes: 7 additions & 0 deletions src/index/sparse/sparse_inverted_index_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class SparseInvertedIndexConfig : public BaseConfig {
CFG_FLOAT drop_ratio_search;
CFG_INT refine_factor;
CFG_FLOAT wand_bm25_max_score_ratio;
CFG_STRING inverted_index_algo;
KNOHWERE_DECLARE_CONFIG(SparseInvertedIndexConfig) {
// NOTE: drop_ratio_build has been deprecated, it won't change anything
KNOWHERE_CONFIG_DECLARE_FIELD(drop_ratio_build)
Expand Down Expand Up @@ -61,6 +62,12 @@ class SparseInvertedIndexConfig : public BaseConfig {
.for_train()
.for_deserialize()
.for_deserialize_from_file();
KNOWHERE_CONFIG_DECLARE_FIELD(inverted_index_algo)
.description("inverted index algorithm")
.set_default("DAAT_MAXSCORE")
.for_train_and_search()
.for_deserialize()
.for_deserialize_from_file();
}
}; // class SparseInvertedIndexConfig

Expand Down
Loading

0 comments on commit 770faac

Please sign in to comment.