From a63a4392bd5f944de3e58859817975f11343857d Mon Sep 17 00:00:00 2001 From: cqy123456 Date: Wed, 18 Dec 2024 15:56:06 +0800 Subject: [PATCH] enhance: scann support iterator Signed-off-by: cqy123456 --- src/index/ivf/ivf.cc | 24 ++++- tests/ut/test_iterator.cc | 20 ++++ thirdparty/faiss/faiss/IndexIVF.cpp | 12 +++ thirdparty/faiss/faiss/IndexIVF.h | 12 ++- thirdparty/faiss/faiss/IndexIVFFastScan.cpp | 100 ++++++++++++++++++ thirdparty/faiss/faiss/IndexIVFFastScan.h | 40 +++++++ thirdparty/faiss/faiss/IndexScaNN.cpp | 35 ++++++ thirdparty/faiss/faiss/IndexScaNN.h | 8 ++ .../faiss/faiss/impl/simd_result_handlers.h | 81 ++++++++++++++ 9 files changed, 324 insertions(+), 8 deletions(-) diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 8d43db8a0..eab652362 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -311,7 +311,7 @@ class IvfIndexNode : public IndexNode { } private: - // only support IVFFlat,IVFFlatCC, IVFSQ and IVFSQCC + // only support IVFFlat,IVFFlatCC, IVFSQ, IVFSQCC and SCANN // iterator will own the copied_norm_query // TODO: iterator should copy and own query data. // TODO: If SCANN support Iterator, raw_distance() function should be override. @@ -339,6 +339,18 @@ class IvfIndexNode : public IndexNode { workspace_->dists.clear(); } + float + raw_distance(int64_t id) override { + if constexpr (std::is_same_v) { + if (refine_) { + return workspace_->dis_refine->operator()(id); + } else { + throw std::runtime_error("raw_distance should not be called if refine == false"); + } + } + throw std::runtime_error("raw_distance not implemented"); + } + private: const IndexType* index_ = nullptr; std::unique_ptr workspace_ = nullptr; @@ -923,9 +935,10 @@ IvfIndexNode::AnnIterator(const DataSetPtr dataset, std::un if constexpr (!std::is_same::value && !std::is_same::value && !std::is_same::value && - !std::is_same::value) { + !std::is_same::value && + !std::is_same::value) { LOG_KNOWHERE_WARNING_ << "Current index_type: " << Type() - << ", only IVFFlat, IVFFlatCC, IVF_SQ8 and IVF_SQ_CC support Iterator."; + << ", only IVFFlat, IVFFlatCC, IVF_SQ8, IVF_SQ_CC and SCANN support Iterator."; return expected>::Err(Status::not_implemented, "index not supported"); } else { auto dim = dataset->GetDim(); @@ -942,6 +955,11 @@ IvfIndexNode::AnnIterator(const DataSetPtr dataset, std::un // set iterator_refine_ratio = 0.0. If quantizer != flat, faiss:indexivf will not keep raw data; // TODO: if SCANN support Iterator, iterator_refine_ratio should be set. float iterator_refine_ratio = 0.0f; + if constexpr (std::is_same_v) { + if (HasRawData(ivf_cfg.metric_type.value())) { + iterator_refine_ratio = ivf_cfg.iterator_refine_ratio.value(); + } + } try { std::vector> futs; futs.reserve(rows); diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index 17a166114..611996b1a 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -187,6 +187,20 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { return json; }; + auto scann_gen = [ivf_base_gen]() { + knowhere::Json json = ivf_base_gen(); + json[knowhere::indexparam::NPROBE] = 14; + json[knowhere::indexparam::REORDER_K] = 200; + json[knowhere::indexparam::WITH_RAW_DATA] = true; + return json; + }; + + auto scann_gen2 = [ivf_base_gen]() { + knowhere::Json json = ivf_base_gen(); + json[knowhere::indexparam::WITH_RAW_DATA] = false; + return json; + }; + auto rand = GENERATE(1, 2); const auto train_ds = GenDataSet(nb, dim, rand); @@ -209,6 +223,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen), // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen), // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); auto cfg_json = gen().dump(); @@ -295,6 +311,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen), // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen), // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); auto cfg_json = gen().dump(); @@ -341,6 +359,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") { make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen), // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen), // make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen), + make_tuple(knowhere::IndexEnum::INDEX_FAISS_SCANN, scann_gen2), })); auto idx = knowhere::IndexFactory::Instance().Create(name, version).value(); auto cfg_json = gen().dump(); diff --git a/thirdparty/faiss/faiss/IndexIVF.cpp b/thirdparty/faiss/faiss/IndexIVF.cpp index 4ea287753..7170acc0d 100644 --- a/thirdparty/faiss/faiss/IndexIVF.cpp +++ b/thirdparty/faiss/faiss/IndexIVF.cpp @@ -152,6 +152,18 @@ idx_t Level1Quantizer::decode_listno(const uint8_t* code) const { return list_no; } +/***************************************** + * IVFIteratorWorkspace implementation + ******************************************/ +IVFIteratorWorkspace::IVFIteratorWorkspace( + const float* query_data, + const IVFSearchParameters* search_params) + : query_data(query_data), + search_params(search_params), + dis_refine(nullptr) {} + +IVFIteratorWorkspace::~IVFIteratorWorkspace() {} + /***************************************** * IndexIVF implementation ******************************************/ diff --git a/thirdparty/faiss/faiss/IndexIVF.h b/thirdparty/faiss/faiss/IndexIVF.h index 34ad84371..8fa948e2c 100644 --- a/thirdparty/faiss/faiss/IndexIVF.h +++ b/thirdparty/faiss/faiss/IndexIVF.h @@ -93,12 +93,13 @@ struct SearchParametersIVF : SearchParameters { // the new convention puts the index type after SearchParameters using IVFSearchParameters = SearchParametersIVF; - +struct DistanceComputer; struct IVFIteratorWorkspace { + IVFIteratorWorkspace() = default; IVFIteratorWorkspace( const float* query_data, - const IVFSearchParameters* search_params) - : query_data(query_data), search_params(search_params) {} + const IVFSearchParameters* search_params); + virtual ~IVFIteratorWorkspace(); const float* query_data = nullptr; // single query const IVFSearchParameters* search_params = nullptr; @@ -112,6 +113,7 @@ struct IVFIteratorWorkspace { nullptr; // backup coarse centroids ids (heap) std::unique_ptr coarse_list_sizes = nullptr; // snapshot of the list_size + std::unique_ptr dis_refine; }; struct InvertedListScanner; @@ -245,7 +247,7 @@ struct IndexIVF : Index, IndexIVFInterface { size_t code_size, MetricType metric = METRIC_L2); - std::unique_ptr getIteratorWorkspace( + virtual std::unique_ptr getIteratorWorkspace( const float* query_data, const IVFSearchParameters* ivfsearchParams) const; @@ -255,7 +257,7 @@ struct IndexIVF : Index, IndexIVFInterface { // iterator `Next()` operation. // When there are not enough nodes in the heap, iterator will scan the // next coarse list. - void getIteratorNextBatch( + virtual void getIteratorNextBatch( IVFIteratorWorkspace* workspace, size_t current_backup_count) const; diff --git a/thirdparty/faiss/faiss/IndexIVFFastScan.cpp b/thirdparty/faiss/faiss/IndexIVFFastScan.cpp index d93ac1481..d2eca4914 100644 --- a/thirdparty/faiss/faiss/IndexIVFFastScan.cpp +++ b/thirdparty/faiss/faiss/IndexIVFFastScan.cpp @@ -395,6 +395,59 @@ void IndexIVFFastScan::range_search( range_search_dispatch_implem(n, x, radius, *result, cq, nullptr, params); } +std::unique_ptr IndexIVFFastScan::getIteratorWorkspace( + const float* query_data, + const IVFSearchParameters* ivfsearchParams) const { + auto base_workspace = + IndexIVF::getIteratorWorkspace(query_data, ivfsearchParams); + + auto ivf_fast_scan_workspace = + std::make_unique( + std::move(base_workspace)); + + ivf_fast_scan_workspace->dim12 = ksub * M2; + CoarseQuantized cq{ + ivf_fast_scan_workspace->nprobe, + ivf_fast_scan_workspace->coarse_dis.get(), + ivf_fast_scan_workspace->coarse_idx.get()}; + compute_LUT_uint8( + 1, + ivf_fast_scan_workspace->query_data, + cq, + ivf_fast_scan_workspace->dis_tables, + ivf_fast_scan_workspace->biases, + ivf_fast_scan_workspace->normalizers); + return ivf_fast_scan_workspace; +} + +void IndexIVFFastScan::getIteratorNextBatch( + IVFIteratorWorkspace* workspace, + size_t current_backup_count) const { + auto ivf_fast_scan_workspace = + dynamic_cast(workspace); + ivf_fast_scan_workspace->dists.clear(); + + std::unique_ptr handler; + bool is_max = !is_similarity_metric(metric_type); + auto id_selector = ivf_fast_scan_workspace->search_params->sel + ? ivf_fast_scan_workspace->search_params->sel + : nullptr; + if (is_max) { + handler.reset(new SingleQueryResultCollectHandler< + CMax, + true>( + ivf_fast_scan_workspace->dists, ntotal, id_selector)); + } else { + handler.reset(new SingleQueryResultCollectHandler< + CMin, + true>( + ivf_fast_scan_workspace->dists, ntotal, id_selector)); + } + + get_interator_next_batch_implem_10( + *handler.get(), ivf_fast_scan_workspace, current_backup_count); +} + namespace { template @@ -1701,6 +1754,53 @@ void IndexIVFFastScan::reconstruct_orig_invlists() { } } +void IndexIVFFastScan::get_interator_next_batch_implem_10( + SIMDResultHandlerToFloat& handler, + IVFFastScanIteratorWorkspace* workspace, + size_t current_backup_count) const { + bool single_LUT = !lookup_table_is_3d(); + handler.begin(skip & 16 ? nullptr : workspace->normalizers); + auto dim12 = workspace->dim12; + const uint8_t* LUT = nullptr; + + if (single_LUT) { + LUT = workspace->dis_tables.get(); + } + while (current_backup_count + workspace->dists.size() < + workspace->backup_count_threshold && + workspace->next_visit_coarse_list_idx < nlist) { + auto next_list_idx = workspace->next_visit_coarse_list_idx; + workspace->next_visit_coarse_list_idx++; + if (!single_LUT) { + LUT = workspace->dis_tables.get() + next_list_idx * dim12; + } + invlists->prefetch_lists( + workspace->coarse_idx.get() + next_list_idx, 1); + if (workspace->biases.get()) { + handler.dbias = workspace->biases.get() + next_list_idx; + } + idx_t list_no = workspace->coarse_idx[next_list_idx]; + size_t ls = invlists->list_size(list_no); + if (list_no < 0 || ls == 0) + continue; + + InvertedLists::ScopedCodes codes(invlists, list_no); + InvertedLists::ScopedIds ids(invlists, list_no); + handler.ntotal = ls; + handler.id_map = ids.get(); + pq4_accumulate_loop( + 1, + roundup(ls, bbs), + bbs, + M2, + codes.get(), + LUT, + handler, + nullptr); + } + handler.end(); +} + // IVFFastScanStats IVFFastScan_stats; } // namespace faiss diff --git a/thirdparty/faiss/faiss/IndexIVFFastScan.h b/thirdparty/faiss/faiss/IndexIVFFastScan.h index 51121acd5..4ca5b5db8 100644 --- a/thirdparty/faiss/faiss/IndexIVFFastScan.h +++ b/thirdparty/faiss/faiss/IndexIVFFastScan.h @@ -37,8 +37,34 @@ struct SIMDResultHandlerToFloat; * For range search, only 10 and 12 are supported. * add 100 to the implem to force single-thread scanning (the coarse quantizer * may still use multiple threads). + * + * For search interator, only 10 are supported, one query, no qbs */ +struct IVFFastScanIteratorWorkspace : IVFIteratorWorkspace { + IVFFastScanIteratorWorkspace() = default; + IVFFastScanIteratorWorkspace( + const float* query_data, + const IVFSearchParameters* search_params) + : IVFIteratorWorkspace(query_data, search_params){}; + IVFFastScanIteratorWorkspace( + std::unique_ptr&& base_workspace) { + this->query_data = base_workspace->query_data; + this->search_params = base_workspace->search_params; + this->nprobe = base_workspace->nprobe; + this->backup_count_threshold = base_workspace->backup_count_threshold; + this->coarse_dis = std::move(base_workspace->coarse_dis); + this->coarse_idx = std::move(base_workspace->coarse_idx); + this->coarse_list_sizes = std::move(base_workspace->coarse_list_sizes); + base_workspace = nullptr; + return; + } + size_t dim12; + AlignedTable dis_tables; + AlignedTable biases; + float normalizers[2]; +}; + struct IndexIVFFastScan : IndexIVF { // size of the kernel int bbs; // set at build time @@ -147,6 +173,14 @@ struct IndexIVFFastScan : IndexIVF { const IVFSearchParameters* params = nullptr, IndexIVFStats* stats = nullptr) const override; + std::unique_ptr getIteratorWorkspace( + const float* query_data, + const IVFSearchParameters* ivfsearchParams) const override; + + void getIteratorNextBatch( + IVFIteratorWorkspace* workspace, + size_t current_backup_count) const override; + // range_search implementation was introduced in Knowhere, // diff 73f03354568b4bf5a370df6f37e8d56dfc3a9c85 void range_search( @@ -243,6 +277,12 @@ struct IndexIVFFastScan : IndexIVF { const NormTableScaler* scaler, const IVFSearchParameters* params = nullptr) const; + // one query call, no qbs + void get_interator_next_batch_implem_10( + SIMDResultHandlerToFloat& handler, + IVFFastScanIteratorWorkspace* workspace, + size_t current_backup_count) const; + // implem 14 is multithreaded internally across nprobes and queries void search_implem_14( idx_t n, diff --git a/thirdparty/faiss/faiss/IndexScaNN.cpp b/thirdparty/faiss/faiss/IndexScaNN.cpp index 015127bd5..16ff89667 100644 --- a/thirdparty/faiss/faiss/IndexScaNN.cpp +++ b/thirdparty/faiss/faiss/IndexScaNN.cpp @@ -7,6 +7,8 @@ #include #include #include +#include +#include namespace faiss { @@ -255,4 +257,37 @@ void IndexScaNN::range_search( result->lims[1] = current; } +std::unique_ptr IndexScaNN::getIteratorWorkspace( + const float* query_data, + const IVFSearchParameters* ivfsearchParams) const { + auto base = dynamic_cast(base_index); + auto iterator = base->getIteratorWorkspace(query_data, ivfsearchParams); + if (refine_index) { + auto refine = dynamic_cast(refine_index); + if (base->is_cosine) { + iterator->dis_refine = std::unique_ptr( + new faiss::WithCosineNormDistanceComputer( + base->norms.data(), + base->d, + std::unique_ptr( + refine->get_distance_computer()))); + } else { + iterator->dis_refine = std::unique_ptr( + refine->get_FlatCodesDistanceComputer()); + } + iterator->dis_refine->set_query(query_data); + } else { + iterator->dis_refine = nullptr; + } + + return iterator; +} + +void IndexScaNN::getIteratorNextBatch( + IVFIteratorWorkspace* workspace, + size_t current_backup_count) const { + auto base = dynamic_cast(base_index); + return base->getIteratorNextBatch(workspace, current_backup_count); +} + } // namespace faiss \ No newline at end of file diff --git a/thirdparty/faiss/faiss/IndexScaNN.h b/thirdparty/faiss/faiss/IndexScaNN.h index d1a2f6a13..9748953b3 100644 --- a/thirdparty/faiss/faiss/IndexScaNN.h +++ b/thirdparty/faiss/faiss/IndexScaNN.h @@ -44,6 +44,14 @@ struct IndexScaNN : IndexRefine { float radius, RangeSearchResult* result, const SearchParameters* params = nullptr) const override; + + std::unique_ptr getIteratorWorkspace( + const float* query_data, + const IVFSearchParameters* ivfsearchParams) const; + + void getIteratorNextBatch( + IVFIteratorWorkspace* workspace, + size_t current_backup_count) const; }; } // namespace faiss \ No newline at end of file diff --git a/thirdparty/faiss/faiss/impl/simd_result_handlers.h b/thirdparty/faiss/faiss/impl/simd_result_handlers.h index 26ebd88e0..c13d18071 100644 --- a/thirdparty/faiss/faiss/impl/simd_result_handlers.h +++ b/thirdparty/faiss/faiss/impl/simd_result_handlers.h @@ -434,6 +434,87 @@ struct HeapHandler : ResultHandlerCompare { } }; +/** Structure that collects results, and return all */ + +/** Structure that collects results in a min- or max-heap */ +template +struct SingleQueryResultCollectHandler : ResultHandlerCompare { + using T = typename C::T; + using TI = typename C::TI; + using RHC = ResultHandlerCompare; + using RHC::normalizers; + + std::vector idis; + std::vector iids; + std::vector& collect; + const int q_id = 0; + + SingleQueryResultCollectHandler( + std::vector& res, + size_t ntotal, + const IDSelector* sel_in) + : RHC(1, ntotal, sel_in), collect(res) { + this->q_map = &q_id; + } + + void begin(const float* norms) override { + normalizers = norms; + } + + void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) final { + if (this->disable) { + return; + } + + this->adjust_with_origin(q, d0, d1); + + uint32_t lt_mask = this->get_lt_mask(C::neutral(), b, d0, d1); + + if (!lt_mask) { + return; + } + + ALIGNED(32) uint16_t d32tab[32]; + d0.store(d32tab); + d1.store(d32tab + 16); + + if (this->sel != nullptr) { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + auto real_idx = this->adjust_id(b, j); + lt_mask -= 1 << j; + if (this->sel->is_member(real_idx)) { + T dis = d32tab[j]; + collect.emplace_back(real_idx, dis); + this->in_range_num += 1; + } + } + } + else { + while (lt_mask) { + // find first non-zero + int j = __builtin_ctz(lt_mask); + lt_mask -= 1 << j; + T dis = d32tab[j]; + int64_t idx = this->adjust_id(b, j); + collect.emplace_back(idx, dis); + this->in_range_num += 1; + + } + } + } + + void end() override { + if (normalizers) { + float one_a = 1 / normalizers[0]; + float b = normalizers[1]; + for (auto i = 0; i < collect.size(); i++) { + collect[i].val = collect[i].val * one_a + b; + } + } + } +}; /** Simple top-N implementation using a reservoir. * * Results are stored when they are below the threshold until the capacity is