Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix iterator: set omp to 1 #989

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
75 changes: 37 additions & 38 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,7 @@ class FaissHnswIterator : public IndexIterator {
const BitsetView& bitset_in, const int32_t ef_in, bool larger_is_closer,
const float refine_ratio = 0.5f)
: IndexIterator(larger_is_closer, refine_ratio), index{index_in} {
search_pool = ThreadPool::GetGlobalSearchThreadPool();
//
workspace.accumulated_alpha =
(bitset_in.count() >= (index->ntotal * HnswSearchThresholds::kHnswSearchKnnBFFilterThreshold))
Expand Down Expand Up @@ -736,17 +737,22 @@ class FaissHnswIterator : public IndexIterator {

void
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
if (workspace.bitset.empty()) {
using filter_type = faiss::IDSelectorAll;
filter_type sel;
std::vector<folly::Future<folly::Unit>> futs;
futs.emplace_back(search_pool->push([&] {
ThreadPool::ScopedSearchOmpSetter setter(1);
if (workspace.bitset.empty()) {
using filter_type = faiss::IDSelectorAll;
filter_type sel;

next_batch(batch_handler, sel);
} else {
using filter_type = knowhere::BitsetViewIDSelector;
filter_type sel(workspace.bitset);
next_batch(batch_handler, sel);
} else {
using filter_type = knowhere::BitsetViewIDSelector;
filter_type sel(workspace.bitset);

next_batch(batch_handler, sel);
}
next_batch(batch_handler, sel);
}
}));
WaitAllSuccess(futs);
}

float
Expand All @@ -759,6 +765,7 @@ class FaissHnswIterator : public IndexIterator {
std::shared_ptr<faiss::Index> index;

FaissHnswIteratorWorkspace workspace;
std::shared_ptr<ThreadPool> search_pool;
};

//
Expand Down Expand Up @@ -1227,44 +1234,36 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
const auto ef = hnsw_cfg.ef.value_or(kIteratorSeedEf);

try {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(n_queries);
for (int64_t i = 0; i < n_queries; i++) {
futs.emplace_back(search_pool->push([&, idx = i] {
// The query data is always cloned
std::unique_ptr<float[]> cur_query = std::make_unique<float[]>(dim);
// The query data is always cloned
std::unique_ptr<float[]> cur_query = std::make_unique<float[]>(dim);

if (data_format == DataFormatEnum::fp32) {
std::copy_n(reinterpret_cast<const float*>(data) + idx * dim, dim, cur_query.get());
} else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) {
convert_rows_to_fp32(data, cur_query.get(), data_format, idx, 1, dim);
} else {
// invalid one. Should not be triggered, bcz input parameters are validated
throw;
}
if (data_format == DataFormatEnum::fp32) {
std::copy_n(reinterpret_cast<const float*>(data) + i * dim, dim, cur_query.get());
} else if (data_format == DataFormatEnum::fp16 || data_format == DataFormatEnum::bf16) {
convert_rows_to_fp32(data, cur_query.get(), data_format, i, 1, dim);
} else {
// invalid one. Should not be triggered, bcz input parameters are validated
throw;
}

const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr);
const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr);

const float iterator_refine_ratio =
should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0;
const float iterator_refine_ratio =
should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0;

// create an iterator and initialize it
auto it =
std::make_shared<FaissHnswIterator>(index, std::move(cur_query), bitset, ef, larger_is_closer,
// // refine is not needed for flat
// hnsw_cfg.iterator_refine_ratio.value_or(0.5f)
iterator_refine_ratio);
// create an iterator and initialize it
auto it = std::make_shared<FaissHnswIterator>(index, std::move(cur_query), bitset, ef, larger_is_closer,
// // refine is not needed for flat
// hnsw_cfg.iterator_refine_ratio.value_or(0.5f)
iterator_refine_ratio);

it->initialize();
it->initialize();

// store
vec[idx] = it;
}));
// store
vec[i] = it;
}

// wait for the completion
WaitAllSuccess(futs);

} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::faiss_inner_error, e.what());
Expand Down
50 changes: 25 additions & 25 deletions src/index/ivf/ivf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,21 @@ class IvfIndexNode : public IndexNode {
ivf_search_params_.max_codes = 0;

workspace_ = index_->getIteratorWorkspace(copied_query_.get(), &ivf_search_params_);

search_pool = ThreadPool::GetGlobalSearchThreadPool();
}

protected:
void
next_batch(std::function<void(const std::vector<DistId>&)> batch_handler) override {
index_->getIteratorNextBatch(workspace_.get(), res_.size());
batch_handler(workspace_->dists);
workspace_->dists.clear();
std::vector<folly::Future<folly::Unit>> futs;
futs.emplace_back(search_pool->push([&] {
ThreadPool::ScopedSearchOmpSetter setter(1);
index_->getIteratorNextBatch(workspace_.get(), res_.size());
batch_handler(workspace_->dists);
workspace_->dists.clear();
}));
WaitAllSuccess(futs);
}

private:
Expand All @@ -345,6 +352,7 @@ class IvfIndexNode : public IndexNode {
std::unique_ptr<float[]> copied_query_ = nullptr;
std::unique_ptr<BitsetViewIDSelector> bw_idselector_ = nullptr;
faiss::IVFSearchParameters ivf_search_params_;
std::shared_ptr<ThreadPool> search_pool;
};

std::unique_ptr<IndexType> index_;
Expand Down Expand Up @@ -943,32 +951,24 @@ IvfIndexNode<DataType, IndexType>::AnnIterator(const DataSetPtr dataset, std::un
// TODO: if SCANN support Iterator, iterator_refine_ratio should be set.
float iterator_refine_ratio = 0.0f;
try {
std::vector<folly::Future<folly::Unit>> futs;
futs.reserve(rows);
for (int i = 0; i < rows; ++i) {
futs.emplace_back(search_pool_->push([&, index = i] {
auto cur_query = (const float*)data + index * dim;
// if cosine, need normalize
std::unique_ptr<float[]> copied_query = nullptr;
if (is_cosine) {
copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
} else {
copied_query = std::make_unique<float[]>(dim);
std::copy_n(cur_query, dim, copied_query.get());
}
auto cur_query = (const float*)data + i * dim;
// if cosine, need normalize
std::unique_ptr<float[]> copied_query = nullptr;
if (is_cosine) {
copied_query = CopyAndNormalizeVecs(cur_query, 1, dim);
} else {
copied_query = std::make_unique<float[]>(dim);
std::copy_n(cur_query, dim, copied_query.get());
}

// iterator only own the copied_query.
auto it = std::make_shared<iterator>(index_.get(), std::move(copied_query), bitset, nprobe,
larger_is_closer, iterator_refine_ratio);
it->initialize();
vec[index] = it;
}));
// iterator only own the copied_query.
auto it = std::make_shared<iterator>(index_.get(), std::move(copied_query), bitset, nprobe,
larger_is_closer, iterator_refine_ratio);
it->initialize();
vec[i] = it;
}

// wait for the completion
// initial search - scan at least (nprobe/nlist)% codes
WaitAllSuccess(futs);

} catch (const std::exception& e) {
LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what();
return expected<std::vector<IndexNode::IteratorPtr>>::Err(Status::faiss_inner_error, e.what());
Expand Down
Loading