diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 1e89942ba..3d5573df8 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -736,17 +736,24 @@ class FaissHnswIterator : public IndexIterator { void next_batch(std::function&)> batch_handler) override { - if (workspace.bitset.empty()) { - using filter_type = faiss::IDSelectorAll; - filter_type sel; - - next_batch(batch_handler, sel); - } else { - using filter_type = knowhere::BitsetViewIDSelector; - filter_type sel(workspace.bitset); + std::vector> futs; + futs.reserve(1); + std::shared_ptr search_pool = ThreadPool::GetGlobalSearchThreadPool(); + futs.emplace_back(search_pool->push([&] { + ThreadPool::ScopedSearchOmpSetter setter(1); + if (workspace.bitset.empty()) { + using filter_type = faiss::IDSelectorAll; + filter_type sel; + + next_batch(batch_handler, sel); + } else { + using filter_type = knowhere::BitsetViewIDSelector; + filter_type sel(workspace.bitset); - next_batch(batch_handler, sel); - } + next_batch(batch_handler, sel); + } + })); + WaitAllSuccess(futs); } float @@ -1231,6 +1238,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { futs.reserve(n_queries); for (int64_t i = 0; i < n_queries; i++) { futs.emplace_back(search_pool->push([&, idx = i] { + ThreadPool::ScopedSearchOmpSetter setter(1); // The query data is always cloned std::unique_ptr cur_query = std::make_unique(dim); diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 8d43db8a0..a9e617c31 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -334,9 +334,15 @@ class IvfIndexNode : public IndexNode { protected: void next_batch(std::function&)> batch_handler) override { - index_->getIteratorNextBatch(workspace_.get(), res_.size()); - batch_handler(workspace_->dists); - workspace_->dists.clear(); + std::vector> futs; + futs.reserve(1); + futs.emplace_back(search_pool_->push([&] { + ThreadPool::ScopedSearchOmpSetter setter(1); + index_->getIteratorNextBatch(workspace_.get(), res_.size()); + batch_handler(workspace_->dists); + workspace_->dists.clear(); + })); + WaitAllSuccess(futs); } private: @@ -947,6 +953,7 @@ IvfIndexNode::AnnIterator(const DataSetPtr dataset, std::un futs.reserve(rows); for (int i = 0; i < rows; ++i) { futs.emplace_back(search_pool_->push([&, index = i] { + ThreadPool::ScopedSearchOmpSetter setter(1); auto cur_query = (const float*)data + index * dim; // if cosine, need normalize std::unique_ptr copied_query = nullptr;