Skip to content

Commit

Permalink
Use search_pool to handle iterator's next_batch instead of init
Browse files Browse the repository at this point in the history
Signed-off-by: min.tian <[email protected]>
  • Loading branch information
alwayslove2013 committed Dec 18, 2024
1 parent aef09cb commit 64010b8
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 63 deletions.
75 changes: 37 additions & 38 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ class FaissHnswIterator : public IndexIterator {
const BitsetView& bitset_in, const int32_t ef_in, bool larger_is_closer,
const float refine_ratio = 0.5f)
: IndexIterator(larger_is_closer, refine_ratio), index{index_in} {
search_pool = ThreadPool::GetGlobalSearchThreadPool();
//
workspace.accumulated_alpha =
(bitset_in.count() >= (index->ntotal * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold))
Expand Down Expand Up @@ -736,17 +737,22 @@ class FaissHnswIterator : public IndexIterator {

// Selects an ID selector from workspace.bitset and forwards to the two-argument
// next_batch(batch_handler, sel) overload.
//
// @param batch_handler  callback passed through unchanged to the two-argument overload.
//
// NOTE(review): this span is a rendered diff hunk without +/- markers. The direct
// dispatch on workspace.bitset and the same dispatch wrapped in a search_pool task
// appear interleaved below, so the text is not compilable as shown (one `if` is
// followed by two `else` branches). Confirm the actual method body against the file.
void
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
if (workspace.bitset.empty()) {
using filter_type = faiss::IDSelectorAll;
filter_type sel;
std::vector<folly::Future<folly::Unit>> futs;
// The task is queued on search_pool. The lambda captures by reference, so
// batch_handler and this iterator must outlive the task.
futs.emplace_back(search_pool->push([&] {
// presumably limits OpenMP parallelism to one thread while this task runs —
// TODO confirm against ThreadPool::ScopedSearchOmpSetter
ThreadPool::ScopedSearchOmpSetter setter(1);
// Empty bitset: a default-constructed faiss::IDSelectorAll is used as the selector.
if (workspace.bitset.empty()) {
using filter_type = faiss::IDSelectorAll;
filter_type sel;

next_batch(batch_handler, sel);
} else {
using filter_type = knowhere::BitsetViewIDSelector;
filter_type sel(workspace.bitset);
next_batch(batch_handler, sel);
} else {
// Non-empty bitset: the selector is built from workspace.bitset.
using filter_type = knowhere::BitsetViewIDSelector;
filter_type sel(workspace.bitset);

next_batch(batch_handler, sel);
}
next_batch(batch_handler, sel);
}
}));
// presumably blocks until the queued task has finished and reports its failure,
// which is what would make the by-reference capture above safe —
// verify WaitAllSuccess semantics
WaitAllSuccess(futs);
}

float
Expand All @@ -759,6 +765,7 @@ class FaissHnswIterator : public IndexIterator {
std::shared_ptr<faiss::Index> index;

FaissHnswIteratorWorkspace workspace;
std::shared_ptr<ThreadPool> search_pool;
};

//
Expand Down Expand Up @@ -1227,44 +1234,36 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
const auto ef = hnsw_cfg.ef.value_or(kIteratorSeedEf);

try {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(n_queries);
for (int64_t i = 0; i < n_queries; i++) {
futs.emplace_back(search_pool->push([&, idx = i] {
// The query data is always cloned
std::unique_ptr<float[]> cur_query = std::make_unique<float[]>(dim);
// The query data is always cloned
std::unique_ptr<float[]> cur_query = std::make_unique<float[]>(dim);

if (data_format == DataFormatEnum::fp32) {
std::copy_n(reinterpret_cast<const float*>(data) + idx * dim, dim, cur_query.get());
} else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) {
convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim);
} else {
// invalid one. Should not be triggered, bcz input parameters are validated
throw;
}
if (data_format == DataFormatEnum::fp32) {
std::copy_n(reinterpret_cast<const float*>(data) + i * dim, dim, cur_query.get());
} else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) {
convert_rows_to_fp32(data, cur_query.get(), data_format, i, 1, dim);
} else {
// invalid one. Should not be triggered, bcz input parameters are validated
throw;
}

const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr);
const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr);

const float iterator_refine_ratio =
should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0;
const float iterator_refine_ratio =
should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0;

// create an iterator and initialize it
auto it =
std::make_shared<FaissHnswIterator>(index, std::move(cur_query), bitset, ef, larger_is_closer,
// // refine is not needed for flat
// hnsw_cfg.iterator_refine_ratio.value_or(0.5f)
iterator_refine_ratio);
// create an iterator and initialize it
auto it = std::make_shared<FaissHnswIterator>(index, std::move(cur_query), bitset, ef, larger_is_closer,
// // refine is not needed for flat
// hnsw_cfg.iterator_refine_ratio.value_or(0.5f)
iterator_refine_ratio);

it->initialize();
it->initialize();

// store
vec[idx] = it;
}));
// store
vec[i] = it;
}

// wait for the completion
WaitAllSuccess(futs);

} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::faiss_inner_error, e.what());
Expand Down
50 changes: 25 additions & 25 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,21 @@ class IvfIndexNode : public IndexNode {
ivf_search_params_.max_codes = 0;

workspace_ = index_->getIteratorWorkspace(copied_query_.get(), &ivf_search_params_);

search_pool = ThreadPool::GetGlobalSearchThreadPool();
}

protected:
// Asks index_ for the next iterator batch, passes workspace_->dists to
// batch_handler, then clears workspace_->dists.
//
// @param batch_handler  callback invoked with workspace_->dists after
//                       index_->getIteratorNextBatch returns.
//
// NOTE(review): this span is a rendered diff hunk without +/- markers. The three
// statements directly below and the search_pool task that repeats them are most
// likely the pre-change and post-change bodies shown together; read literally, the
// batch would be fetched and handled twice. Confirm against the actual file.
void
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
index_->getIteratorNextBatch(workspace_.get(), res_.size());
batch_handler(workspace_->dists);
workspace_->dists.clear();
std::vector<folly::Future<folly::Unit>> futs;
// The task is queued on search_pool. The lambda captures by reference, so
// batch_handler and this iterator must outlive the task.
futs.emplace_back(search_pool->push([&] {
// presumably limits OpenMP parallelism to one thread while this task runs —
// TODO confirm against ThreadPool::ScopedSearchOmpSetter
ThreadPool::ScopedSearchOmpSetter setter(1);
// res_.size() is passed as the second argument; its meaning to
// getIteratorNextBatch is not visible here — verify against the index API.
index_->getIteratorNextBatch(workspace_.get(), res_.size());
batch_handler(workspace_->dists);
// The buffer is emptied after the handler has seen it.
workspace_->dists.clear();
}));
// presumably blocks until the queued task has finished and reports its failure —
// verify WaitAllSuccess semantics
WaitAllSuccess(futs);
}

private:
Expand All @@ -345,6 +352,7 @@ class IvfIndexNode : public IndexNode {
std::unique_ptr<float[]> copied_query_ = nullptr;
std::unique_ptr<BitsetViewIDSelector> bw_idselector_ = nullptr;
faiss::IVFSearchParameters ivf_search_params_;
std::shared_ptr<ThreadPool> search_pool;
};

std::unique_ptr<IndexType> index_;
Expand Down Expand Up @@ -943,32 +951,24 @@ IvfIndexNode<DataType, IndexType>::AnnIterator(const DataSetPtr dataset, std::un
// TODO: if SCANN support Iterator, iterator_refine_ratio should be set.
float iterator_refine_ratio = 0.0f;
try {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(rows);
for (int i = 0; i < rows; ++i) {
futs.emplace_back(search_pool_->push([&, index = i] {
auto cur_query = (const float*)data + index * dim;
// if cosine, need normalize
std::unique_ptr<float[]> copied_query = nullptr;
if (is_cosine) {
copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
} else {
copied_query = std::make_unique<float[]>(dim);
std::copy_n(cur_query, dim, copied_query.get());
}
auto cur_query = (const float*)data + i * dim;
// if cosine, need normalize
std::unique_ptr<float[]> copied_query = nullptr;
if (is_cosine) {
copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
} else {
copied_query = std::make_unique<float[]>(dim);
std::copy_n(cur_query, dim, copied_query.get());
}

// iterator only own the copied_query.
auto it = std::make_shared<iterator>(index_.get(), std::move(copied_query), bitset, nprobe,
larger_is_closer, iterator_refine_ratio);
it->initialize();
vec[index] = it;
}));
// iterator only own the copied_query.
auto it = std::make_shared<iterator>(index_.get(), std::move(copied_query), bitset, nprobe,
larger_is_closer, iterator_refine_ratio);
it->initialize();
vec[i] = it;
}

// wait for the completion
// initial search - scan at least (nprobe/nlist)% codes
WaitAllSuccess(futs);

} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::faiss_inner_error, e.what());
Expand Down

0 comments on commit 64010b8

Please sign in to comment.