Skip to content

Commit

Permalink
fix iterator: set omp to 1
Browse files Browse the repository at this point in the history
Signed-off-by: min.tian <[email protected]>
  • Loading branch information
alwayslove2013 committed Dec 16, 2024
1 parent aef09cb commit d38114d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
28 changes: 18 additions & 10 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -736,17 +736,24 @@ class FaissHnswIterator : public IndexIterator {

void
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
if (workspace.bitset.empty()) {
using filter_type = faiss::IDSelectorAll;
filter_type sel;

next_batch(batch_handler, sel);
} else {
using filter_type = knowhere::BitsetViewIDSelector;
filter_type sel(workspace.bitset);
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(1);
std::shared_ptr<ThreadPool> search_pool = ThreadPool::GetGlobalSearchThreadPool();
futs.emplace_back(search_pool->push([&] {
ThreadPool::ScopedSearchOmpSetter setter(1);
if (workspace.bitset.empty()) {
using filter_type = faiss::IDSelectorAll;
filter_type sel;

next_batch(batch_handler, sel);
} else {
using filter_type = knowhere::BitsetViewIDSelector;
filter_type sel(workspace.bitset);

next_batch(batch_handler, sel);
}
next_batch(batch_handler, sel);
}
}));
WaitAllSuccess(futs);
}

float
Expand Down Expand Up @@ -1231,6 +1238,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
futs.reserve(n_queries);
for (int64_t i = 0; i < n_queries; i++) {
futs.emplace_back(search_pool->push([&, idx = i] {
ThreadPool::ScopedSearchOmpSetter setter(1);
// The query data is always cloned
std::unique_ptr<float[]> cur_query = std::make_unique<float[]>(dim);

Expand Down
13 changes: 10 additions & 3 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,15 @@ class IvfIndexNode : public IndexNode {
protected:
void
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
index_->getIteratorNextBatch(workspace_.get(), res_.size());
batch_handler(workspace_->dists);
workspace_->dists.clear();
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(1);
futs.emplace_back(search_pool_->push([&] {
ThreadPool::ScopedSearchOmpSetter setter(1);
index_->getIteratorNextBatch(workspace_.get(), res_.size());
batch_handler(workspace_->dists);
workspace_->dists.clear();
}));
WaitAllSuccess(futs);
}

private:
Expand Down Expand Up @@ -947,6 +953,7 @@ IvfIndexNode<DataType, IndexType>::AnnIterator(const DataSetPtr dataset, std::un
futs.reserve(rows);
for (int i = 0; i < rows; ++i) {
futs.emplace_back(search_pool_->push([&, index = i] {
ThreadPool::ScopedSearchOmpSetter setter(1);
auto cur_query = (const float*)data + index * dim;
// if cosine, need normalize
std::unique_ptr<float[]> copied_query = nullptr;
Expand Down

0 comments on commit d38114d

Please sign in to comment.