diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 1e89942ba..7eced8fe6 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -522,6 +522,7 @@ class FaissHnswIterator : public IndexIterator { const BitsetView& bitset_in, const int32_t ef_in, bool larger_is_closer, const float refine_ratio = 0.5f) : IndexIterator(larger_is_closer, refine_ratio), index{index_in} { + search_pool = ThreadPool::GetGlobalSearchThreadPool(); // workspace.accumulated_alpha = (bitset_in.count() >= (index->ntotal * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold)) @@ -736,17 +737,22 @@ class FaissHnswIterator : public IndexIterator { void next_batch(std::function&)> batch_handler) override { - if (workspace.bitset.empty()) { - using filter_type = faiss::IDSelectorAll; - filter_type sel; + std::vector> futs; + futs.emplace_back(search_pool->push([&] { + ThreadPool::ScopedSearchOmpSetter setter(1); + if (workspace.bitset.empty()) { + using filter_type = faiss::IDSelectorAll; + filter_type sel; - next_batch(batch_handler, sel); - } else { - using filter_type = knowhere::BitsetViewIDSelector; - filter_type sel(workspace.bitset); + next_batch(batch_handler, sel); + } else { + using filter_type = knowhere::BitsetViewIDSelector; + filter_type sel(workspace.bitset); - next_batch(batch_handler, sel); - } + next_batch(batch_handler, sel); + } + })); + WaitAllSuccess(futs); } float @@ -759,6 +765,7 @@ class FaissHnswIterator : public IndexIterator { std::shared_ptr index; FaissHnswIteratorWorkspace workspace; + std::shared_ptr search_pool; }; // @@ -1227,44 +1234,36 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { const auto ef = hnsw_cfg.ef.value_or(kIteratorSeedEf); try { - std::vector> futs; - futs.reserve(n_queries); for (int64_t i = 0; i < n_queries; i++) { - futs.emplace_back(search_pool->push([&, idx = i] { - // The query data is always cloned - std::unique_ptr cur_query = std::make_unique(dim); + // The query data is always cloned + std::unique_ptr cur_query = std::make_unique(dim); - if (data_format == DataFormatEnum::fp32) { - std::copy_n(reinterpret_cast(data) + idx * dim, dim, cur_query.get()); - } else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) { - convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim); - } else { - // invalid one. Should not be triggered, bcz input parameters are validated - throw; - } + if (data_format == DataFormatEnum::fp32) { + std::copy_n(reinterpret_cast(data) + i * dim, dim, cur_query.get()); + } else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) { + convert_rows_to_fp32(data, cur_query.get(), data_format, i, 1, dim); + } else { + // invalid one. Should not be triggered, bcz input parameters are validated + throw; + } - const bool should_use_refine = (dynamic_cast(index.get()) != nullptr); + const bool should_use_refine = (dynamic_cast(index.get()) != nullptr); - const float iterator_refine_ratio = - should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0; + const float iterator_refine_ratio = + should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0; - // create an iterator and initialize it - auto it = - std::make_shared(index, std::move(cur_query), bitset, ef, larger_is_closer, - // // refine is not needed for flat - // hnsw_cfg.iterator_refine_ratio.value_or(0.5f) - iterator_refine_ratio); + // create an iterator and initialize it + auto it = std::make_shared(index, std::move(cur_query), bitset, ef, larger_is_closer, + // // refine is not needed for flat + // hnsw_cfg.iterator_refine_ratio.value_or(0.5f) + iterator_refine_ratio); - it->initialize(); + it->initialize(); - // store - vec[idx] = it; - })); + // store + vec[i] = it; } - // wait for the completion - WaitAllSuccess(futs); - } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return expected>::Err(Status::faiss_inner_error, e.what()); diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 8d43db8a0..d67b1da6e 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -329,14 +329,21 @@ class IvfIndexNode : public IndexNode { ivf_search_params_.max_codes = 0; workspace_ = index_->getIteratorWorkspace(copied_query_.get(), &ivf_search_params_); + + search_pool = ThreadPool::GetGlobalSearchThreadPool(); } protected: void next_batch(std::function&)> batch_handler) override { - index_->getIteratorNextBatch(workspace_.get(), res_.size()); - batch_handler(workspace_->dists); - workspace_->dists.clear(); + std::vector> futs; + futs.emplace_back(search_pool->push([&] { + ThreadPool::ScopedSearchOmpSetter setter(1); + index_->getIteratorNextBatch(workspace_.get(), res_.size()); + batch_handler(workspace_->dists); + workspace_->dists.clear(); + })); + WaitAllSuccess(futs); } private: @@ -345,6 +352,7 @@ class IvfIndexNode : public IndexNode { std::unique_ptr copied_query_ = nullptr; std::unique_ptr bw_idselector_ = nullptr; faiss::IVFSearchParameters ivf_search_params_; + std::shared_ptr search_pool; }; std::unique_ptr index_; @@ -943,32 +951,24 @@ IvfIndexNode::AnnIterator(const DataSetPtr dataset, std::un // TODO: if SCANN support Iterator, iterator_refine_ratio should be set. float iterator_refine_ratio = 0.0f; try { - std::vector> futs; - futs.reserve(rows); for (int i = 0; i < rows; ++i) { - futs.emplace_back(search_pool_->push([&, index = i] { - auto cur_query = (const float*)data + index * dim; - // if cosine, need normalize - std::unique_ptr copied_query = nullptr; - if (is_cosine) { - copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); - } else { - copied_query = std::make_unique(dim); - std::copy_n(cur_query, dim, copied_query.get()); - } + auto cur_query = (const float*)data + i * dim; + // if cosine, need normalize + std::unique_ptr copied_query = nullptr; + if (is_cosine) { + copied_query = CopyAndNormalizeVecs(cur_query, 1, dim); + } else { + copied_query = std::make_unique(dim); + std::copy_n(cur_query, dim, copied_query.get()); + } - // iterator only own the copied_query. - auto it = std::make_shared(index_.get(), std::move(copied_query), bitset, nprobe, - larger_is_closer, iterator_refine_ratio); - it->initialize(); - vec[index] = it; - })); + // iterator only own the copied_query. + auto it = std::make_shared(index_.get(), std::move(copied_query), bitset, nprobe, + larger_is_closer, iterator_refine_ratio); + it->initialize(); + vec[i] = it; } - // wait for the completion - // initial search - scan at least (nprobe/nlist)% codes - WaitAllSuccess(futs); - } catch (const std::exception& e) { LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return expected>::Err(Status::faiss_inner_error, e.what());