From 27dea7f097f3eb238eb38767f55650123173f295 Mon Sep 17 00:00:00 2001 From: Alexander Guzhva Date: Fri, 8 Nov 2024 19:56:19 -0500 Subject: [PATCH] fix_hnsw_pq_iterator (#932) Signed-off-by: Alexandr Guzhva --- src/index/hnsw/faiss_hnsw.cc | 63 +++++++++++-------- src/index/hnsw/faiss_hnsw_config.h | 8 ++- .../hnsw/impl/IndexConditionalWrapper.cc | 10 ++- src/index/hnsw/impl/IndexConditionalWrapper.h | 8 ++- tests/ut/test_iterator.cc | 29 ++++----- thirdparty/faiss/faiss/IndexRefine.cpp | 41 ++++++++++++ thirdparty/faiss/faiss/IndexRefine.h | 7 +++ 7 files changed, 116 insertions(+), 50 deletions(-) diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 0f43c83e0..bfec433fd 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -551,24 +551,32 @@ class FaissHnswIterator : public IndexIterator { // wrap a sign, if needed workspace.qdis = std::unique_ptr(storage_distance_computer(index_hnsw)); - // a tricky point here. - // Basically, if out hnsw index's storage is HasInverseL2Norms, then - // this is a cosine index. But because refine always keeps original - // data, then we need to use a wrapper over a distance computer - const faiss::HasInverseL2Norms* has_l2_norms = - dynamic_cast(index_hnsw->storage); - if (has_l2_norms != nullptr) { - // add a cosine wrapper over it - // DO NOT WRAP A SIGN, by design - workspace.qdis_refine = - std::unique_ptr(new faiss::WithCosineNormDistanceComputer( - has_l2_norms->get_inverse_l2_norms(), index->d, - std::unique_ptr(index_refine->refine_index->get_distance_computer()))); + if (refine_ratio != 0) { + // the refine is needed + + // a tricky point here. + // Basically, if out hnsw index's storage is HasInverseL2Norms, then + // this is a cosine index. But because refine always keeps original + // data, then we need to use a wrapper over a distance computer + const faiss::HasInverseL2Norms* has_l2_norms = + dynamic_cast(index_hnsw->storage); + if (has_l2_norms != nullptr) { + // add a cosine wrapper over it + // DO NOT WRAP A SIGN, by design + workspace.qdis_refine = + std::unique_ptr(new faiss::WithCosineNormDistanceComputer( + has_l2_norms->get_inverse_l2_norms(), index->d, + std::unique_ptr( + index_refine->refine_index->get_distance_computer()))); + } else { + // use it as is + // DO NOT WRAP A SIGN, by design + workspace.qdis_refine = + std::unique_ptr(index_refine->refine_index->get_distance_computer()); + } } else { - // use it as is - // DO NOT WRAP A SIGN, by design - workspace.qdis_refine = - std::unique_ptr(index_refine->refine_index->get_distance_computer()); + // the refine is not needed + workspace.qdis_refine = nullptr; } } else { const faiss::IndexHNSW* index_hnsw = dynamic_cast(index.get()); @@ -882,9 +890,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { return expected::Err(Status::invalid_args, "k parameter is missing"); } + // whether a user wants a refine + const bool whether_to_enable_refine = hnsw_cfg.refine_k.has_value(); + // set up an index wrapper - auto [index_wrapper, is_refined] = - create_conditional_hnsw_wrapper(index.get(), hnsw_cfg, whether_bf_search.value_or(false)); + auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper( + index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); if (index_wrapper == nullptr) { return expected::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW"); @@ -1011,9 +1022,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { return expected::Err(Status::invalid_args, "ef parameter is missing"); } + // whether a user wants a refine + const bool whether_to_enable_refine = true; + // set up an index wrapper - auto [index_wrapper, is_refined] = - create_conditional_hnsw_wrapper(index.get(), hnsw_cfg, whether_bf_search.value_or(false)); + auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper( + index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine); if (index_wrapper == nullptr) { return expected::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW"); @@ -1229,11 +1243,10 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { throw; } - // + const bool should_use_refine = (dynamic_cast(index.get()) != nullptr); + const float iterator_refine_ratio = - (dynamic_cast(index.get()) != nullptr) - ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) - : 0; + should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0; // create an iterator and initialize it auto it = diff --git a/src/index/hnsw/faiss_hnsw_config.h b/src/index/hnsw/faiss_hnsw_config.h index 48710a9be..9e1e9966d 100644 --- a/src/index/hnsw/faiss_hnsw_config.h +++ b/src/index/hnsw/faiss_hnsw_config.h @@ -43,7 +43,7 @@ class FaissHnswConfig : public BaseHnswConfig { .for_static(); KNOWHERE_CONFIG_DECLARE_FIELD(refine_k) .description("refine k") - .allow_empty_without_default() + .set_default(1) .set_range(1, std::numeric_limits::max()) .for_search(); KNOWHERE_CONFIG_DECLARE_FIELD(refine_type) @@ -83,7 +83,7 @@ class FaissHnswFlatConfig : public FaissHnswConfig { // check our parameters if (param_type == PARAM_TYPE::TRAIN) { // prohibit refine - if (refine.value_or(false) || refine_type.has_value() || refine_k.has_value()) { + if (refine.value_or(false) || refine_type.has_value()) { if (err_msg) { *err_msg = "refine is not supported for this index"; LOG_KNOWHERE_ERROR_ << *err_msg; @@ -189,6 +189,8 @@ class FaissHnswPqConfig : public FaissHnswConfig { } } } + default: + break; } return Status::success; } @@ -232,6 +234,8 @@ class FaissHnswPrqConfig : public FaissHnswConfig { } } } + default: + break; } return Status::success; } diff --git a/src/index/hnsw/impl/IndexConditionalWrapper.cc b/src/index/hnsw/impl/IndexConditionalWrapper.cc index 5660f10e2..e3ab59e82 100644 --- a/src/index/hnsw/impl/IndexConditionalWrapper.cc +++ b/src/index/hnsw/impl/IndexConditionalWrapper.cc @@ -94,9 +94,13 @@ WhetherPerformBruteForceRangeSearch(const faiss::Index* index, const FaissHnswCo return false; } -// returns nullptr in case of invalid index +// returns nullptr in case of invalid index. +// +// `whether_to_enable_refine` allows to enable the refine for the search if the +// index was trained with the refine. std::tuple, bool> -create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search) { +create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search, + const bool whether_to_enable_refine) { const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), knowhere::metric::COSINE); // check if we have a refine available. @@ -126,7 +130,7 @@ create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw } // check if a user wants a refined result - if (hnsw_cfg.refine_k.has_value()) { + if (whether_to_enable_refine) { // yes, a user wants to perform a refine // thus, we need to define a new refine index and pass diff --git a/src/index/hnsw/impl/IndexConditionalWrapper.h b/src/index/hnsw/impl/IndexConditionalWrapper.h index 0f140bf7d..84d86ee13 100644 --- a/src/index/hnsw/impl/IndexConditionalWrapper.h +++ b/src/index/hnsw/impl/IndexConditionalWrapper.h @@ -42,8 +42,12 @@ std::optional WhetherPerformBruteForceRangeSearch(const faiss::Index* index, const FaissHnswConfig& cfg, const BitsetView& bitset); // first return arg: returns nullptr in case of invalid index -// second return arg: returns whether an index does refine +// second return arg: returns whether an index does the refine +// +// `whether_to_enable_refine` allows to enable the refine for the search if the +// index was trained with the refine. std::tuple, bool> -create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search); +create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search, + const bool whether_to_enable_refine); } // namespace knowhere diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index 9e89cd130..17a166114 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -197,9 +197,6 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { SECTION("Test Search using iterator") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen), @@ -207,11 +204,11 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen), + // make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen), - // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen), - // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen), + // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen), + // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); auto cfg_json = gen().dump(); @@ -286,8 +283,6 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { SECTION("Test Search with Bitset using iterator") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen), @@ -295,11 +290,11 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen), + // make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen), - // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen), - // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), + make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen), + // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen), + // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); auto cfg_json = gen().dump(); @@ -334,8 +329,6 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { SECTION("Test Search with Bitset using iterator insufficient results") { using std::make_tuple; auto [name, gen] = GENERATE_REF(table>({ - make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen), make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen), @@ -343,11 +336,11 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen), + // make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen), make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen), - make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), + // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen), + // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); auto cfg_json = gen().dump(); diff --git a/thirdparty/faiss/faiss/IndexRefine.cpp b/thirdparty/faiss/faiss/IndexRefine.cpp index a65664cc6..bbeaff0cc 100644 --- a/thirdparty/faiss/faiss/IndexRefine.cpp +++ b/thirdparty/faiss/faiss/IndexRefine.cpp @@ -189,6 +189,47 @@ void IndexRefine::search( } } +void IndexRefine::range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params_in) const +{ + const IndexRefineSearchParameters* params = nullptr; + if (params_in) { + params = dynamic_cast(params_in); + FAISS_THROW_IF_NOT_MSG( + params, "IndexRefine params have incorrect type"); + } + + SearchParameters* base_index_params = + (params != nullptr) ? params->base_index_params : nullptr; + + base_index->range_search( + n, x, radius, result, base_index_params); + +#pragma omp parallel if (n > 1) + { + std::unique_ptr dc( + refine_index->get_distance_computer()); + +#pragma omp for + for (idx_t i = 0; i < n; i++) { + dc->set_query(x + i * d); + + // reevaluate distances + const size_t idx_start = result->lims[i]; + const size_t idx_end = result->lims[i + 1]; + + for (size_t j = idx_start; j < idx_end; j++) { + const auto label = result->labels[j]; + result->distances[j] = (*dc)(label); + } + } + } +} + void IndexRefine::reconstruct(idx_t key, float* recons) const { refine_index->reconstruct(key, recons); } diff --git a/thirdparty/faiss/faiss/IndexRefine.h b/thirdparty/faiss/faiss/IndexRefine.h index 23687af9f..f912df957 100644 --- a/thirdparty/faiss/faiss/IndexRefine.h +++ b/thirdparty/faiss/faiss/IndexRefine.h @@ -54,6 +54,13 @@ struct IndexRefine : Index { idx_t* labels, const SearchParameters* params = nullptr) const override; + void range_search( + idx_t n, + const float* x, + float radius, + RangeSearchResult* result, + const SearchParameters* params = nullptr) const override; + // reconstruct is routed to the refine_index void reconstruct(idx_t key, float* recons) const override;