From 1667fc3e38800ea7ead63fefabb8f91d9cb98d8b Mon Sep 17 00:00:00 2001 From: cqy123456 Date: Wed, 8 Jan 2025 14:39:13 +0800 Subject: [PATCH] enhance: knowhere support data view index node Signed-off-by: cqy123456 --- include/knowhere/comp/index_param.h | 3 +- include/knowhere/index/index_node.h | 22 +- include/knowhere/object.h | 7 +- include/knowhere/operands.h | 31 + include/knowhere/utils.h | 45 +- src/common/utils.cc | 71 +- .../data_view_dense_index.h | 658 ++++++++++++++++++ .../data_view_index_config.h | 142 ++++ .../index_node_with_data_view_refiner.h | 515 ++++++++++++++ src/index/hnsw/faiss_hnsw.cc | 26 - src/index/ivf/ivf.cc | 9 +- src/index/ivf/ivf_config.h | 14 +- tests/ut/test_data_view_index.cc | 272 ++++++++ tests/ut/test_iterator.cc | 97 +++ tests/ut/test_utils.cc | 2 - tests/ut/utils.h | 4 +- thirdparty/faiss/faiss/IndexIVFFastScan.cpp | 4 + thirdparty/faiss/faiss/IndexIVFFastScan.h | 3 + thirdparty/faiss/faiss/IndexRefine.cpp | 44 +- thirdparty/faiss/faiss/utils/Heap.cpp | 46 ++ thirdparty/faiss/faiss/utils/Heap.h | 10 + 21 files changed, 1922 insertions(+), 103 deletions(-) create mode 100644 src/index/data_view_dense_index/data_view_dense_index.h create mode 100644 src/index/data_view_dense_index/data_view_index_config.h create mode 100644 src/index/data_view_dense_index/index_node_with_data_view_refiner.h create mode 100644 tests/ut/test_data_view_index.cc diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index d389522fd..35d33a89a 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -30,6 +30,7 @@ constexpr const char* INDEX_FAISS_IVFFLAT = "IVF_FLAT"; constexpr const char* INDEX_FAISS_IVFFLAT_CC = "IVF_FLAT_CC"; constexpr const char* INDEX_FAISS_IVFPQ = "IVF_PQ"; constexpr const char* INDEX_FAISS_SCANN = "SCANN"; +constexpr const char* INDEX_FAISS_SCANN_WITH_DV_REFINER = "SCANN_WITH_DV_REFINER"; constexpr const char* INDEX_FAISS_IVFSQ8 = "IVF_SQ8"; constexpr const char* INDEX_FAISS_IVFSQ_CC = "IVF_SQ_CC"; @@ -118,7 +119,7 @@ constexpr const char* WITH_RAW_DATA = "with_raw_data"; constexpr const char* ENSURE_TOPK_FULL = "ensure_topk_full"; constexpr const char* CODE_SIZE = "code_size"; constexpr const char* RAW_DATA_STORE_PREFIX = "raw_data_store_prefix"; - +constexpr const char* SUB_DIM = "sub_dim"; // RAFT Params constexpr const char* REFINE_RATIO = "refine_ratio"; constexpr const char* CACHE_DATASET_ON_DEVICE = "cache_dataset_on_device"; diff --git a/include/knowhere/index/index_node.h b/include/knowhere/index/index_node.h index 829669970..1049b34af 100644 --- a/include/knowhere/index/index_node.h +++ b/include/knowhere/index/index_node.h @@ -524,8 +524,16 @@ class IndexIterator : public IndexNode::iterator { } protected: + inline size_t + min_refine_size() const { + // TODO: maybe make this configurable + return std::max((size_t)20, (size_t)(res_.size() * refine_ratio_)); + } + virtual void - next_batch(std::function&)> batch_handler) = 0; + next_batch(std::function&)> batch_handler) { + throw std::runtime_error("next_batch not implemented"); + } // will be called only if refine_ratio_ is not 0. virtual float raw_distance(int64_t) { @@ -537,18 +545,15 @@ class IndexIterator : public IndexNode::iterator { const float refine_ratio_; const bool refine_; + bool initialized_ = false; + bool retain_iterator_order_ = false; + const int64_t sign_; std::priority_queue, std::greater> res_; // unused if refine_ is false std::priority_queue, std::greater> refined_res_; private: - inline size_t - min_refine_size() const { - // TODO: maybe make this configurable - return std::max((size_t)20, (size_t)(res_.size() * refine_ratio_)); - } - void UpdateNext() { auto batch_handler = [this](const std::vector& batch) { @@ -569,10 +574,7 @@ class IndexIterator : public IndexNode::iterator { next_batch(batch_handler); } - bool initialized_ = false; - bool retain_iterator_order_ = false; bool use_knowhere_search_pool_ = true; - const int64_t sign_; }; // An iterator implementation that accepts a function to get distances and ids list and returns them in order. diff --git a/include/knowhere/object.h b/include/knowhere/object.h index b683dcedc..2541e0f80 100644 --- a/include/knowhere/object.h +++ b/include/knowhere/object.h @@ -14,6 +14,7 @@ #include #include +#include #include #include @@ -73,11 +74,13 @@ class Object { mutable std::atomic_uint32_t ref_counts_ = 1; }; +using ViewDataOp = std::function; + template class Pack : public Object { - static_assert(std::is_same_v>, + static_assert(std::is_same_v> || std::is_same_v, "IndexPack only support std::shared_ptr by far."); - + // todo: pack can hold more object public: Pack() { } diff --git a/include/knowhere/operands.h b/include/knowhere/operands.h index 545e03a91..b567e73b7 100644 --- a/include/knowhere/operands.h +++ b/include/knowhere/operands.h @@ -196,5 +196,36 @@ template <> struct MockData { using type = knowhere::fp32; }; + +// +enum class DataFormatEnum { fp32, fp16, bf16, int8, bin1 }; + +template +struct DataType2EnumHelper {}; + +template <> +struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::fp32; +}; +template <> +struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::fp16; +}; +template <> +struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::bf16; +}; +template <> +struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::int8; +}; +template <> +struct DataType2EnumHelper { + static constexpr DataFormatEnum value = DataFormatEnum::bin1; +}; + +template +static constexpr DataFormatEnum datatype_v = DataType2EnumHelper::value; + } // namespace knowhere #endif /* OPERANDS_H */ diff --git a/include/knowhere/utils.h b/include/knowhere/utils.h index bd8ddca24..77b7b0fdc 100644 --- a/include/knowhere/utils.h +++ b/include/knowhere/utils.h @@ -36,6 +36,14 @@ IsFlatIndex(const knowhere::IndexType& index_type) { return std::find(flat_index_list.begin(), flat_index_list.end(), index_type) != flat_index_list.end(); } +template +float +GetL2Norm(const DataType* x, int32_t d); + +template +std::vector +GetL2Norms(const DataType* x, int32_t d, int32_t n); + template extern float NormalizeVec(DataType* x, int32_t d); @@ -52,6 +60,10 @@ template extern void NormalizeDataset(const DataSetPtr dataset); +template +extern std::tuple> +CopyAndNormalizeDataset(const DataSetPtr dataset); + constexpr inline uint64_t seed = 0xc70f6907UL; inline uint64_t @@ -112,8 +124,10 @@ GetKey(const std::string& name) { template inline DataSetPtr data_type_conversion(const DataSet& src, const std::optional start = std::nullopt, - const std::optional count = std::nullopt) { - auto dim = src.GetDim(); + const std::optional count = std::nullopt, + const std::optional filling_dim = std::nullopt) { + auto in_dim = src.GetDim(); + auto out_dim = filling_dim.value_or(in_dim); auto rows = src.GetRows(); // check the acceptable range @@ -128,15 +142,18 @@ data_type_conversion(const DataSet& src, const std::optional start = st } // map - auto* des_data = new OutType[dim * count_rows]; + auto* des_data = new OutType[out_dim * count_rows]; + std::memset(des_data, 0, sizeof(OutType) * out_dim * count_rows); auto* src_data = (const InType*)src.GetTensor(); - for (auto i = 0; i < dim * count_rows; i++) { - des_data[i] = (OutType)src_data[i + start_row * dim]; + for (auto i = 0; i < count_rows; i++) { + for (auto d = 0; d < in_dim; d++) { + des_data[i * out_dim + d] = (OutType)src_data[(start_row + i) * in_dim + d]; + } } auto des = std::make_shared(); des->SetRows(count_rows); - des->SetDim(dim); + des->SetDim(out_dim); des->SetTensor(des_data); des->SetIsOwner(true); return des; @@ -152,28 +169,32 @@ data_type_conversion(const DataSet& src, const std::optional start = st template inline DataSetPtr ConvertFromDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, - const std::optional count = std::nullopt) { + const std::optional count = std::nullopt, + const std::optional filling_dim = std::nullopt) { if constexpr (std::is_same_v::type>) { - if (!start.has_value() && !count.has_value()) { + if (!start.has_value() && !count.has_value() && + (!filling_dim.has_value() || ds->GetDim() == filling_dim.value())) { return ds; } } - return data_type_conversion::type>(*ds, start, count); + return data_type_conversion::type>(*ds, start, count, filling_dim); } // Convert DataSet from float to DataType template inline DataSetPtr ConvertToDataTypeIfNeeded(const DataSetPtr& ds, const std::optional start = std::nullopt, - const std::optional count = std::nullopt) { + const std::optional count = std::nullopt, + const std::optional filling_dim = std::nullopt) { if constexpr (std::is_same_v::type>) { - if (!start.has_value() && !count.has_value()) { + if (!start.has_value() && !count.has_value() && + (!filling_dim.has_value() || ds->GetDim() == filling_dim.value())) { return ds; } } - return data_type_conversion::type, DataType>(*ds, start, count); + return data_type_conversion::type, DataType>(*ds, start, count, filling_dim); } template diff --git a/src/common/utils.cc b/src/common/utils.cc index 8e4d2f15a..48c7ad07a 100644 --- a/src/common/utils.cc +++ b/src/common/utils.cc @@ -25,11 +25,38 @@ namespace knowhere { const float FloatAccuracy = 0.00001; +template +float +GetL2Norm(const DataType* x, int32_t d) { + float norm_l2_sqr = 0.0; + if constexpr (std::is_same_v) { + norm_l2_sqr = faiss::fvec_norm_L2sqr(x, d); + } else if constexpr (std::is_same_v) { + norm_l2_sqr = faiss::fp16_vec_norm_L2sqr(x, d); + } else if constexpr (std::is_same_v) { + norm_l2_sqr = faiss::bf16_vec_norm_L2sqr(x, d); + } else { + KNOWHERE_THROW_MSG("Unknown Datatype"); + } + + if (norm_l2_sqr > 0 && std::abs(1.0f - norm_l2_sqr) > FloatAccuracy) { + float norm_l2 = std::sqrt(norm_l2_sqr); + return norm_l2; + } + return 1.0f; +} +template +std::vector +GetL2Norms(const DataType* x, int32_t d, int32_t n) { + std::vector norms(n); + for (auto i = 0; i < n; i++) { + auto x_i = x + d * i; + norms[i] = GetL2Norm(x_i, d); + } + return norms; +} // normalize one vector and return its norm -// todo(cqy123456): Template specialization for fp16/bf16; -// float16 uses the smallest representable positive float16 value(6.1 x 10^(-5)) as FloatAccuracy; -// bfloat16 uses the same FloatAccuracy as float32; template float NormalizeVec(DataType* x, int32_t d) { @@ -83,10 +110,26 @@ NormalizeDataset(const DataSetPtr dataset) { auto data = (DataType*)dataset->GetTensor(); LOG_KNOWHERE_DEBUG_ << "vector normalize, rows " << rows << ", dim " << dim; - NormalizeVecs(data, rows, dim); } +template +std::tuple> +CopyAndNormalizeDataset(const DataSetPtr dataset) { + auto rows = dataset->GetRows(); + auto dim = dataset->GetDim(); + auto data = (DataType*)dataset->GetTensor(); + + LOG_KNOWHERE_DEBUG_ << "vector normalize, rows " << rows << ", dim " << dim; + + auto x_normalized = new DataType[rows * dim]; + std::copy_n(data, rows * dim, x_normalized); + auto norms = NormalizeVecs(x_normalized, rows, dim); + auto normalize_bs = GenDataSet(rows, dim, x_normalized); + normalize_bs->SetIsOwner(true); + return std::make_tuple(normalize_bs, norms); +} + void ConvertIVFFlat(const BinarySet& binset, const MetricType metric_type, const uint8_t* raw_data, const size_t raw_size) { std::vector names = {"IVF", // compatible with knowhere-1.x @@ -135,6 +178,20 @@ UseDiskLoad(const std::string& index_type, const int32_t& version) { #endif } +template float +GetL2Norm(const fp32* x, int32_t d); +template float +GetL2Norm(const fp16* x, int32_t d); +template float +GetL2Norm(const bf16* x, int32_t d); + +template std::vector +GetL2Norms(const fp32* x, int32_t d, int32_t n); +template std::vector +GetL2Norms(const fp16* x, int32_t d, int32_t n); +template std::vector +GetL2Norms(const bf16* x, int32_t d, int32_t n); + template float NormalizeVec(fp32* x, int32_t d); template float @@ -163,4 +220,10 @@ NormalizeDataset(const DataSetPtr dataset); template void NormalizeDataset(const DataSetPtr dataset); +template std::tuple> +CopyAndNormalizeDataset(const DataSetPtr dataset); +template std::tuple> +CopyAndNormalizeDataset(const DataSetPtr dataset); +template std::tuple> +CopyAndNormalizeDataset(const DataSetPtr dataset); } // namespace knowhere diff --git a/src/index/data_view_dense_index/data_view_dense_index.h b/src/index/data_view_dense_index/data_view_dense_index.h new file mode 100644 index 000000000..b654579b5 --- /dev/null +++ b/src/index/data_view_dense_index/data_view_dense_index.h @@ -0,0 +1,658 @@ +// Copyright (C) 2019-2024 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +// knowhere-specific indices +#pragma once + +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "faiss/impl/AuxIndexStructures.h" +#include "faiss/impl/ResultHandler.h" +#include "faiss/utils/distances_if.h" +#include "knowhere/bitsetview_idselector.h" +#include "knowhere/comp/thread_pool.h" +#include "knowhere/config.h" +#include "knowhere/operands.h" +#include "knowhere/range_util.h" +#include "knowhere/utils.h" +#include "simd/hook.h" +namespace knowhere { +using idx_t = faiss::idx_t; +using CMAX = faiss::CMax; +using CMIN = faiss::CMin; +struct RangeSearchResult; +class DataViewIndexBase { + public: + DataViewIndexBase(idx_t d, DataFormatEnum data_type, MetricType metric_type, ViewDataOp view, bool is_cosine) + : d_(d), data_type_(data_type), metric_type_(metric_type), view_data_(view), is_cosine_(is_cosine) { + if (metric_type != metric::L2 && metric_type != metric::IP) { + throw std::runtime_error("DataViewIndexBase only support L2 or IP."); + } + if (data_type_ == DataFormatEnum::fp32) { + code_size_ = sizeof(fp32) * d_; + } else if (data_type_ == DataFormatEnum::fp16) { + code_size_ = sizeof(fp16) * d_; + } else if (data_type_ == DataFormatEnum::bf16) { + code_size_ = sizeof(bf16) * d_; + } else { + throw std::runtime_error("data view index only support float data type."); + } + } + virtual ~DataViewIndexBase(){}; + + virtual void + Train(idx_t n, const void* x) = 0; + + virtual void + Add(idx_t n, const void* x, const float* norms_) = 0; + + virtual void + Search(const idx_t n, const void* x, const idx_t k, float* distances, idx_t* labels, + const BitsetView& bitset) const = 0; + + /** Knn Search on set of vectors + * + * @param n nb of vectors to query + * @param x query vectors, size nx * d + * @param ids_num_lims prefix sum of different selected ids rows , size n + 1 + * @param ids selected ids for each queries + * @param k topk + * @param out_dist result ids, size nx * topk + * @param out_ids result distances, size nx * topk + */ + virtual void + SearchWithIds(const idx_t n, const void* x, const idx_t* ids_num_lims, const idx_t* ids, const idx_t k, + float* out_dist, idx_t* out_ids) const = 0; + + virtual RangeSearchResult + RangeSearch(const idx_t n, const void* x, const float radius, const float range_filter, + const BitsetView& bitset) const = 0; + + /** Range Search on set of vectors + * + * @param n nb of vectors to query + * @param x query vectors, size nx * d + * @param ids_num_lims prefix sum of different selected ids rows , size n + 1 + * @param ids selected ids for each queries + * @param radius + * @param range_filter + */ + virtual RangeSearchResult + RangeSearchWithIds(const idx_t n, const void* x, const idx_t* ids_num_lims, const idx_t* ids, const float radius, + const float range_filter) const = 0; + + virtual void + ComputeDistanceSubset(const void* x, const idx_t sub_y_n, float* x_y_distances, const idx_t* x_y_labels) const = 0; + + auto + Dim() const { + return d_; + } + bool + IsCosine() const { + return is_cosine_; + } + MetricType + Metric() const { + return metric_type_; + } + DataFormatEnum + DataFormat() const { + return data_type_; + } + ViewDataOp + GetViewData() const { + return view_data_; + } + idx_t + Count() const { + return ntotal_.load(); + } + + protected: + int d_; + DataFormatEnum data_type_; + MetricType metric_type_; + ViewDataOp view_data_; + bool is_cosine_; + int code_size_; + std::atomic ntotal_ = 0; +}; + +class DataViewIndexFlat : public DataViewIndexBase { + public: + DataViewIndexFlat(idx_t d, DataFormatEnum data_type, MetricType metric_type, ViewDataOp view, bool is_cosine) + : DataViewIndexBase(d, data_type, metric_type, view, is_cosine) { + this->ntotal_.store(0); + } + void + Train(idx_t n, const void* x) override { + // do nothing + return; + } + + void + Add(idx_t n, const void* x, const float* in_norms) override { + if (is_cosine_) { + if (in_norms == nullptr) { + std::vector l2_norms; + if (data_type_ == DataFormatEnum::fp32) { + l2_norms = GetL2Norms((const fp32*)x, d_, n); + } else if (data_type_ == DataFormatEnum::fp16) { + l2_norms = GetL2Norms((const fp16*)x, d_, n); + } else { + l2_norms = GetL2Norms((const bf16*)x, d_, n); + } + std::unique_lock lock(norms_mutex_); + norms_.insert(norms_.end(), l2_norms.begin(), l2_norms.end()); + } else { + std::unique_lock lock(norms_mutex_); + norms_.insert(norms_.end(), in_norms, in_norms + n); + } + } + ntotal_.fetch_add(n); + } + + void + Search(const idx_t n, const void* x, const idx_t k, float* distances, idx_t* labels, + const BitsetView& bitset) const override; + + void + SearchWithIds(const idx_t n, const void* x, const idx_t* ids_num_lims, const idx_t* ids, const idx_t k, + float* out_dist, idx_t* out_ids) const override; + + RangeSearchResult + RangeSearch(const idx_t n, const void* x, const float radius, const float range_filter, + const BitsetView& bitset) const override; + + RangeSearchResult + RangeSearchWithIds(const idx_t n, const void* x, const idx_t* ids_num_lims, const idx_t* ids, const float radius, + const float range_filter) const override; + + void + ComputeDistanceSubset(const void* x, const idx_t sub_y_n, float* x_y_distances, + const idx_t* x_y_labels) const override; + + float + GetDataNorm(idx_t id) const { + assert(id < ntotal_); + if (norms_.size() < id) { // maybe cosine is false, get norm in place + auto data = view_data_(id); + if (data_type_ == DataFormatEnum::fp32) { + return GetL2Norm((const fp32*)data, d_); + } else if (data_type_ == DataFormatEnum::fp16) { + return GetL2Norm((const fp16*)data, d_); + } else { + return GetL2Norm((const bf16*)data, d_); + } + } else { + std::shared_lock lock(norms_mutex_); + return norms_[id]; + } + } + + protected: + std::vector norms_; + mutable std::shared_mutex norms_mutex_; +}; + +template +struct DataViewDistanceComputer : faiss::DistanceComputer { + ViewDataOp view_data; + size_t dim; + const DataType* q; + Distance1 dist1; + Distance4 dist4; + float q_norm; + + DataViewDistanceComputer(const DataViewIndexBase* index, Distance1 dist1, Distance4 dist4, + const DataType* query = nullptr, std::optional query_norm = std::nullopt) + : view_data(index->GetViewData()), dim(index->Dim()), dist1(dist1), dist4(dist4) { + if (query != nullptr) { + this->set_query((const float*)query, query_norm); + } + return; + } + + // convert x to float* for override, still use DataType to get distance + void + set_query(const float* x) override { + q = (const DataType*)x; + if constexpr (NeedNormalize) { + q_norm = GetL2Norm(q, dim); + } + } + + void + set_query(const float* x, std::optional x_norm = std::nullopt) { + q = (const DataType*)x; + if constexpr (NeedNormalize) { + q_norm = x_norm.value_or(GetL2Norm(q, dim)); + } + } + + float + operator()(idx_t i) override { + auto code = view_data(i); + return distance_to_code(code); + } + + float + distance_to_code(const void* x) { + if constexpr (NeedNormalize) { + return dist1(q, (const DataType*)x, dim) / q_norm; + } else { + return dist1(q, (const DataType*)x, dim); + } + } + + void + distances_batch_4(const idx_t idx0, const idx_t idx1, const idx_t idx2, const idx_t idx3, float& dis0, float& dis1, + float& dis2, float& dis3) override { + auto x0 = (DataType*)view_data(idx0); + auto x1 = (DataType*)view_data(idx1); + auto x2 = (DataType*)view_data(idx2); + auto x3 = (DataType*)view_data(idx3); + dist4(q, x0, x1, x2, x3, dim, dis0, dis1, dis2, dis3); + if constexpr (NeedNormalize) { + dis0 /= q_norm; + dis1 /= q_norm; + dis2 /= q_norm; + dis3 /= q_norm; + } + } + + /// compute distance between two stored vectors + float + symmetric_dis(idx_t i, idx_t j) override { + auto x = (DataType*)view_data(i); + auto y = (DataType*)view_data(j); + return dist1(x, y, dim); + } +}; + +static std::unique_ptr +SelectDataViewComputer(const DataViewIndexBase* index) { + if (index->DataFormat() == DataFormatEnum::fp16) { + if (index->Metric() == metric::IP) { + if (index->IsCosine()) { + return std::unique_ptr( + new DataViewDistanceComputer( + index, faiss::fp16_vec_inner_product, faiss::fp16_vec_inner_product_batch_4)); + } else { + return std::unique_ptr( + new DataViewDistanceComputer( + index, faiss::fp16_vec_inner_product, faiss::fp16_vec_inner_product_batch_4)); + } + } else { + return std::unique_ptr( + new DataViewDistanceComputer(index, faiss::fp16_vec_L2sqr, + faiss::fp16_vec_L2sqr_batch_4)); + } + } else if (index->DataFormat() == DataFormatEnum::bf16) { + if (index->Metric() == metric::IP) { + if (index->IsCosine()) { + return std::unique_ptr( + new DataViewDistanceComputer( + index, faiss::bf16_vec_inner_product, faiss::bf16_vec_inner_product_batch_4)); + } else { + return std::unique_ptr( + new DataViewDistanceComputer( + index, faiss::bf16_vec_inner_product, faiss::bf16_vec_inner_product_batch_4)); + } + } else { + return std::unique_ptr( + new DataViewDistanceComputer(index, faiss::bf16_vec_L2sqr, + faiss::bf16_vec_L2sqr_batch_4)); + } + } else if (index->DataFormat() == DataFormatEnum::fp32) { + if (index->Metric() == metric::IP) { + if (index->IsCosine()) { + return std::unique_ptr( + new DataViewDistanceComputer( + index, faiss::fvec_inner_product, faiss::fvec_inner_product_batch_4)); + } else { + return std::unique_ptr( + new DataViewDistanceComputer( + index, faiss::fvec_inner_product, faiss::fvec_inner_product_batch_4)); + } + } else { + return std::unique_ptr( + new DataViewDistanceComputer( + index, faiss::fvec_L2sqr, faiss::fvec_L2sqr_batch_4)); + } + } else { + return nullptr; + } +} +namespace { +template +void +exhaustive_search_one_query_impl(const std::unique_ptr& computer, size_t ny, + SingleResultHandler& resi, const SelectorHelper& selector, + const bool is_cosine = false, const float* norms_ = nullptr) { + if (is_cosine && norms_ == nullptr) { + throw std::runtime_error("Please provide norms if is_cosine == true."); + } + auto filter = [&selector](const size_t j) { return selector.is_member(j); }; + if (is_cosine) { + auto apply = [&resi, &norms_](const float dis, const idx_t j) { + auto dist_with_norm = dis / (norms_[j]); + resi.add_result(dist_with_norm, j); + }; + faiss::distance_compute_if(ny, computer.get(), filter, apply); + } else { + auto apply = [&resi](const float dis, const idx_t j) { resi.add_result(dis, j); }; + faiss::distance_compute_if(ny, computer.get(), filter, apply); + } +} +} // namespace + +void +DataViewIndexFlat::Search(const idx_t n, const void* x, const idx_t k, float* distances, idx_t* labels, + const BitsetView& bitset) const { + // todo: need more test to check + std::shared_ptr base_norms = nullptr; + if (is_cosine_) { + // use copy to avoid concurrent add and search + std::shared_lock lock(norms_mutex_); + base_norms = std::shared_ptr(new float[norms_.size()]); + std::memcpy(base_norms.get(), norms_.data(), sizeof(float) * norms_.size()); + } + const auto& search_pool = ThreadPool::GetGlobalSearchThreadPool(); + std::vector> futs; + futs.reserve(n); + if (k < faiss::distance_compute_min_k_reservoir) { + if (metric_type_ == metric::L2) { + faiss::HeapBlockResultHandler res(n, distances, labels, k); + for (auto i = 0; i < n; i++) { + futs.emplace_back(search_pool->push([&] { + ThreadPool::ScopedSearchOmpSetter setter(1); + faiss::HeapBlockResultHandler::SingleResultHandler resi(res); + auto computer = SelectDataViewComputer(this); + computer->set_query((const float*)(x + code_size_ * i)); + resi.begin(i); + if (bitset.empty()) { + exhaustive_search_one_query_impl(computer, n, resi, faiss::IDSelectorAll(), base_norms.get()); + } else { + exhaustive_search_one_query_impl(computer, n, resi, BitsetViewIDSelector(bitset), + base_norms.get()); + } + resi.end(); + })); + } + WaitAllSuccess(futs); + } else { + faiss::HeapBlockResultHandler res(n, distances, labels, k); + for (auto i = 0; i < n; i++) { + futs.emplace_back(search_pool->push([&] { + ThreadPool::ScopedSearchOmpSetter setter(1); + faiss::HeapBlockResultHandler::SingleResultHandler resi(res); + auto computer = SelectDataViewComputer(this); + computer->set_query((const float*)(x + code_size_ * i)); + resi.begin(i); + if (bitset.empty()) { + exhaustive_search_one_query_impl(computer, n, resi, faiss::IDSelectorAll(), base_norms.get()); + } else { + exhaustive_search_one_query_impl(computer, n, resi, BitsetViewIDSelector(bitset), + base_norms.get()); + } + resi.end(); + })); + } + WaitAllSuccess(futs); + } + } else { + if (metric_type_ == metric::L2) { + faiss::ReservoirBlockResultHandler res(n, distances, labels, k); + + for (auto i = 0; i < n; i++) { + futs.emplace_back(search_pool->push([&] { + ThreadPool::ScopedSearchOmpSetter setter(1); + faiss::ReservoirBlockResultHandler::SingleResultHandler resi(res); + auto computer = SelectDataViewComputer(this); + computer->set_query((const float*)(x + code_size_ * i)); + resi.begin(i); + if (bitset.empty()) { + exhaustive_search_one_query_impl(computer, n, resi, faiss::IDSelectorAll(), base_norms.get()); + } else { + exhaustive_search_one_query_impl(computer, n, resi, BitsetViewIDSelector(bitset), + base_norms.get()); + } + resi.end(); + })); + } + WaitAllSuccess(futs); + } else { + faiss::ReservoirBlockResultHandler res(n, distances, labels, k); + for (auto i = 0; i < n; i++) { + futs.emplace_back(search_pool->push([&] { + ThreadPool::ScopedSearchOmpSetter setter(1); + faiss::ReservoirBlockResultHandler::SingleResultHandler resi(res); + auto computer = SelectDataViewComputer(this); + computer->set_query((const float*)(x + code_size_ * i)); + resi.begin(i); + if (bitset.empty()) { + exhaustive_search_one_query_impl(computer, n, resi, faiss::IDSelectorAll(), base_norms.get()); + } else { + exhaustive_search_one_query_impl(computer, n, resi, BitsetViewIDSelector(bitset), + base_norms.get()); + } + resi.end(); + })); + WaitAllSuccess(futs); + } + } + } +} + +void +DataViewIndexFlat::SearchWithIds(const idx_t n, const void* x, const idx_t* ids_num_lims, const idx_t* ids, + const idx_t k, float* out_dist, idx_t* out_ids) const { + const auto& search_pool = ThreadPool::GetGlobalSearchThreadPool(); + std::vector> futs; + futs.reserve(n); + for (auto i = 0; i < n; i++) { + futs.emplace_back(search_pool->push([&, i = i] { + ThreadPool::ScopedSearchOmpSetter setter(1); + auto base_ids = ids + ids_num_lims[i]; + auto base_n = ids_num_lims[i + 1] - ids_num_lims[i]; + auto base_dist = std::unique_ptr(new float[base_n]); + if (metric_type_ == metric::L2) { + std::fill(base_dist.get(), base_dist.get() + base_n, CMAX::neutral()); + } else { + std::fill(base_dist.get(), base_dist.get() + base_n, CMIN::neutral()); + } + auto x_i = x + code_size_ * i; + + assert(base_n >= k); + ComputeDistanceSubset(x_i, base_n, base_dist.get(), base_ids); + if (is_cosine_) { + for (auto j = 0; j < base_n; j++) { + if (base_ids[j] != -1) { + std::shared_lock lock(norms_mutex_); + base_dist[j] = base_dist[j] / norms_[base_ids[j]]; + } + } + } + if (metric_type_ == metric::L2) { + faiss::reorder_2_heaps(1, k, out_ids + i * k, out_dist + i * k, base_n, base_ids, + base_dist.get()); + } else { + faiss::reorder_2_heaps(1, k, out_ids + i * k, out_dist + i * k, base_n, base_ids, + base_dist.get()); + } + })); + } + WaitAllSuccess(futs); + return; +} + +RangeSearchResult +DataViewIndexFlat::RangeSearch(const idx_t n, const void* x, const float radius, const float range_filter, + const BitsetView& bitset) const { + // todo: need more test to check + std::vector> result_dist_array(n); + std::vector> result_id_array(n); + + std::shared_ptr base_norms = nullptr; + if (is_cosine_) { + std::shared_lock lock(norms_mutex_); + base_norms = std::shared_ptr(new float[norms_.size()]); + std::memcpy(base_norms.get(), norms_.data(), sizeof(float) * norms_.size()); + } + auto is_ip = metric_type_ == metric::IP; + + const auto& search_pool = ThreadPool::GetGlobalSearchThreadPool(); + std::vector> futs; + futs.reserve(n); + if (metric_type_ == metric::L2) { + for (auto i = 0; i < n; i++) { + futs.emplace_back(search_pool->push([&, i = i] { + ThreadPool::ScopedSearchOmpSetter setter(1); + auto computer = SelectDataViewComputer(this); + faiss::RangeSearchResult res(1); + faiss::RangeSearchBlockResultHandler resh(&res, radius); + faiss::RangeSearchBlockResultHandler::SingleResultHandler reshi(resh); + computer->set_query(((const float*)x + code_size_ * i)); + reshi.begin(i); + if (bitset.empty()) { + exhaustive_search_one_query_impl(computer, n, reshi, faiss::IDSelectorAll(), base_norms.get()); + } else { + exhaustive_search_one_query_impl(computer, n, reshi, BitsetViewIDSelector(bitset), + base_norms.get()); + } + reshi.end(); + auto elem_cnt = res.lims[1]; + result_dist_array[i].resize(elem_cnt); + result_id_array[i].resize(elem_cnt); + for (size_t j = 0; j < elem_cnt; j++) { + result_dist_array[i][j] = res.distances[j]; + result_id_array[i][j] = res.labels[j]; + } + if (range_filter != defaultRangeFilter) { + FilterRangeSearchResultForOneNq(result_dist_array[i], result_id_array[i], is_ip, radius, + range_filter); + } + })); + WaitAllSuccess(futs); + } + } else { + for (auto i = 0; i < n; i++) { + futs.emplace_back(search_pool->push([&, i = i] { + ThreadPool::ScopedSearchOmpSetter setter(1); + auto computer = SelectDataViewComputer(this); + faiss::RangeSearchResult res(1); + faiss::RangeSearchBlockResultHandler resh(&res, radius); + faiss::RangeSearchBlockResultHandler::SingleResultHandler reshi(resh); + computer->set_query(((const float*)x + code_size_ * i)); + reshi.begin(i); + if (bitset.empty()) { + exhaustive_search_one_query_impl(computer, n, reshi, faiss::IDSelectorAll(), base_norms.get()); + } else { + exhaustive_search_one_query_impl(computer, n, reshi, BitsetViewIDSelector(bitset), + base_norms.get()); + } + reshi.end(); + auto elem_cnt = res.lims[1]; + result_dist_array[i].resize(elem_cnt); + result_id_array[i].resize(elem_cnt); + for (size_t j = 0; j < elem_cnt; j++) { + result_dist_array[i][j] = res.distances[j]; + result_id_array[i][j] = res.labels[j]; + } + if (range_filter != defaultRangeFilter) { + FilterRangeSearchResultForOneNq(result_dist_array[i], result_id_array[i], is_ip, radius, + range_filter); + } + })); + WaitAllSuccess(futs); + } + } + return GetRangeSearchResult(result_dist_array, result_id_array, is_ip, n, radius, range_filter); +} + +RangeSearchResult +DataViewIndexFlat::RangeSearchWithIds(const idx_t n, const void* x, const idx_t* ids_num_lims, const idx_t* ids, + const float radius, const float range_filter) const { + std::vector> result_dist_array(n); + std::vector> result_id_array(n); + auto is_ip = metric_type_ == metric::IP; + const auto& search_pool = ThreadPool::GetGlobalSearchThreadPool(); + std::vector> futs; + futs.reserve(n); + for (auto i = 0; i < n; i++) { + futs.emplace_back(search_pool->push([&, i = i] { + ThreadPool::ScopedSearchOmpSetter setter(1); + auto base_ids = ids + ids_num_lims[i]; + auto base_n = ids_num_lims[i + 1] - ids_num_lims[i]; + auto base_dist = std::unique_ptr(new float[base_n]); + const void* x_i = x + code_size_ * i; + ComputeDistanceSubset(x_i, base_n, base_dist.get(), base_ids); + if (is_cosine_) { + std::shared_lock lock(norms_mutex_); + for (auto j = 0; j < base_n; j++) { + base_dist[j] = base_dist[j] / norms_[base_ids[j]]; + } + } + for (auto j = 0; j < base_n; j++) { + if (!is_ip) { + if (base_dist[j] < radius) { + result_dist_array[i].emplace_back(base_dist[j]); + result_id_array[i].emplace_back(base_ids[j]); + } + } else { + if (base_dist[j] > radius) { + result_dist_array[i].emplace_back(base_dist[j]); + result_id_array[i].emplace_back(base_ids[j]); + } + } + } + if (range_filter != defaultRangeFilter) { + FilterRangeSearchResultForOneNq(result_dist_array[i], result_id_array[i], is_ip, radius, range_filter); + } + })); + } + WaitAllSuccess(futs); + return GetRangeSearchResult(result_dist_array, result_id_array, is_ip, n, radius, range_filter); +} + +void +DataViewIndexFlat::ComputeDistanceSubset(const void* x, const idx_t sub_y_n, float* x_y_distances, + const idx_t* x_y_labels) const { + auto computer = SelectDataViewComputer(this); + + computer->set_query((const float*)(x)); + const idx_t* __restrict idsj = x_y_labels; + float* __restrict disj = x_y_distances; + + auto filter = [=](const size_t i) { return (idsj[i] >= 0); }; + auto apply = [=](const float dis, const size_t i) { disj[i] = dis; }; + distance_compute_by_idx_if(idsj, sub_y_n, computer.get(), filter, apply); +} +} // namespace knowhere diff --git a/src/index/data_view_dense_index/data_view_index_config.h b/src/index/data_view_dense_index/data_view_index_config.h new file mode 100644 index 000000000..1b3b51112 --- /dev/null +++ b/src/index/data_view_dense_index/data_view_index_config.h @@ -0,0 +1,142 @@ + +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#ifndef DATA_VIEW_INDEX_CONFIG_H +#define DATA_VIEW_INDEX_CONFIG_H + +#include "index/ivf/ivf_config.h" +#include "simd/hook.h" +namespace knowhere { +class IndexWithDataViewRefinerConfig : public ScannConfig { + public: + CFG_INT reorder_k; + KNOHWERE_DECLARE_CONFIG(IndexWithDataViewRefinerConfig) { + KNOWHERE_CONFIG_DECLARE_FIELD(reorder_k) + .description("reorder k used for refining") + .allow_empty_without_default() + .set_range(1, std::numeric_limits::max()) + .for_search(); + } + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + if (param_type == PARAM_TYPE::TRAIN) { + auto topk = k.value(); + if (!reorder_k.has_value()) { + reorder_k = topk; + } else if (reorder_k.value() < topk) { + if (!err_msg) { + err_msg = new std::string(); + } + std::string msg = "reorder_k(" + std::to_string(reorder_k.value()) + ") should be larger than k(" + + std::to_string(k.value()) + ")"; + return HandleError(err_msg, msg, Status::out_of_range_in_json); + } + } + return Status::success; + } +}; +class ScannWithDataViewRefinerConfig : public ScannConfig { + public: + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + if (param_type == PARAM_TYPE::TRAIN) { + auto topk = k.value(); + if (!reorder_k.has_value()) { + reorder_k = topk; + } else if (reorder_k.value() < topk) { + if (!err_msg) { + err_msg = new std::string(); + } + std::string msg = "reorder_k(" + std::to_string(reorder_k.value()) + ") should be larger than k(" + + std::to_string(k.value()) + ")"; + return HandleError(err_msg, msg, Status::out_of_range_in_json); + } + } + if (!faiss::support_pq_fast_scan) { + LOG_KNOWHERE_ERROR_ << "SCANN index is not supported on the current CPU model, avx2 support is " + "needed for x86 arch."; + return Status::invalid_instruction_set; + } else { + return Status::success; + } + return Status::success; + } +}; + +static void +AdaptToBaseIndexConfig(Config* cfg, PARAM_TYPE param_type, size_t dim) { + // config can't do copy, change the base config in place. + if (cfg == nullptr) + return; + if (auto base_cfg = dynamic_cast(cfg)) { + if (base_cfg->metric_type.value() == metric::COSINE) { + base_cfg->metric_type.value() = metric::IP; + } + switch (param_type) { + case PARAM_TYPE::TRAIN: { + base_cfg->with_raw_data = false; + int sub_dim = base_cfg->sub_dim.value(); + if (dim % sub_dim != 0) { + dim = ROUND_UP(dim, sub_dim); + base_cfg->dim = dim; + } else { + base_cfg->dim = dim; + } + break; + } + case PARAM_TYPE::SEARCH: { + if (base_cfg->reorder_k.has_value()) { + base_cfg->k = base_cfg->reorder_k.value(); + } + break; + } + case PARAM_TYPE::RANGE_SEARCH: { + base_cfg->range_filter = defaultRangeFilter; + break; + } + case PARAM_TYPE::ITERATOR: { + if (base_cfg->iterator_refine_ratio != 0.0) { + base_cfg->retain_iterator_order = false; + } + break; + } + default: + break; + } + } else if (auto base_cfg = dynamic_cast(cfg)) { + if (base_cfg->metric_type.value() == metric::COSINE) { + base_cfg->metric_type.value() = metric::IP; + } + switch (param_type) { + case PARAM_TYPE::SEARCH: { + base_cfg->k = base_cfg->reorder_k.value(); + break; + } + case PARAM_TYPE::RANGE_SEARCH: { + base_cfg->range_filter = defaultRangeFilter; + break; + } + case PARAM_TYPE::ITERATOR: { + if (base_cfg->iterator_refine_ratio != 0.0) { + base_cfg->retain_iterator_order = false; + } + break; + } + default: + break; + } + } else { + throw std::runtime_error("Not a valid config for DV(Data View) refiner index."); + } +} +} // namespace knowhere +#endif diff --git a/src/index/data_view_dense_index/index_node_with_data_view_refiner.h b/src/index/data_view_dense_index/index_node_with_data_view_refiner.h new file mode 100644 index 000000000..326e0ab33 --- /dev/null +++ b/src/index/data_view_dense_index/index_node_with_data_view_refiner.h @@ -0,0 +1,515 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. +#ifndef INDEX_NODE_WITH_DATA_VIEW_REFINER_H +#define INDEX_NODE_WITH_DATA_VIEW_REFINER_H +#include +#include + +#include "faiss/utils/random.h" +#include "index/data_view_dense_index/data_view_dense_index.h" +#include "index/data_view_dense_index/data_view_index_config.h" +#include "knowhere/index/index_node.h" +namespace knowhere { +struct DataViewIndexFlat; +/* +IndexNodeWithDataViewRefiner is a Just in time index, support fast build and search. +This kind of index will not keep raw data anymore, so init it with a get raw data function(ViewDataOp). +And it maintain a basic index (code size < raw data size) and a refiner. +If metric == Cosine, base index will normalize all vectors, and replaced with Inner product; +refine_index will compute the IP distances, and divide by ||x|| and ||y||. + +todo: basic index use fp32, we should support more type later. +*/ +template +class IndexNodeWithDataViewRefiner : public IndexNode { + static_assert(KnowhereFloatTypeCheck::value); + + public: + IndexNodeWithDataViewRefiner(const int32_t& version, const Object& object) { + auto data_view_index_pack = dynamic_cast*>(&object); + assert(data_view_index_pack != nullptr); + view_data_op_ = data_view_index_pack->GetPack(); + base_index_ = std::make_unique(version, nullptr); + } + + Status + Train(const DataSetPtr dataset, std::shared_ptr cfg) override; + + Status + Add(const DataSetPtr dataset, std::shared_ptr cfg) override; + + expected + Search(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override; + + expected + RangeSearch(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset) const override; + + expected> + AnnIterator(const DataSetPtr dataset, std::unique_ptr cfg, const BitsetView& bitset, + bool use_knowhere_search_pool) const override; + + expected + GetVectorByIds(const DataSetPtr dataset) const override { + return expected::Err(Status::not_implemented, "Data View Index not maintain raw data."); + } + + static Status + StaticConfigCheck(const Config& cfg, PARAM_TYPE paramType, std::string& msg) { + auto base_cfg = static_cast(cfg); + if constexpr (KnowhereFloatTypeCheck::value) { + if (IsMetricType(base_cfg.metric_type.value(), metric::L2) || + IsMetricType(base_cfg.metric_type.value(), metric::IP) || + IsMetricType(base_cfg.metric_type.value(), metric::COSINE)) { + } else { + msg = "metric type " + base_cfg.metric_type.value() + + " not found or not supported, supported: [L2 IP COSINE]"; + return Status::invalid_metric_type; + } + } + return Status::success; + } + + static bool + CommonHasRawData() { + return false; + } + + static bool + StaticHasRawData(const knowhere::BaseConfig& config, const IndexVersion& version) { + return false; + } + + bool + HasRawData(const std::string& metric_type) const override { + return false; + } + + expected + GetIndexMeta(std::unique_ptr cfg) const override { + if (!this->base_index_) { + return expected::Err(Status::empty_index, "Data View Index not maintain raw data."); + } + return this->base_index_->GetIndexMeta(std::move(cfg)); + } + + Status + Serialize(BinarySet& binset) const override { + LOG_KNOWHERE_ERROR_ << "Data View index is a JIT index type, do not Serialize"; + return Status::not_implemented; + } + + Status + Deserialize(const BinarySet& binset, std::shared_ptr cfg) override { + LOG_KNOWHERE_ERROR_ << "Data View index is a JIT index type, do not Deserialize"; + return Status::not_implemented; + } + + Status + DeserializeFromFile(const std::string& filename, std::shared_ptr cfg) override { + LOG_KNOWHERE_ERROR_ << "Data View index is a JIT index type, do not DeserializeFromFile"; + return Status::not_implemented; + } + + static std::unique_ptr + StaticCreateConfig() { + auto base_index_cfg = BaseIndexNode::StaticCreateConfig(); + if (dynamic_cast(base_index_cfg.get())) { + return std::make_unique(); + } else { + return std::make_unique(); + } + } + + std::unique_ptr + CreateConfig() const override { + if (base_index_->Type() == IndexEnum::INDEX_FAISS_SCANN) { + return std::make_unique(); + } else { + return std::make_unique(); + } + } + + int64_t + Dim() const override { + if (!this->refine_offset_index_) { + return -1; + } + return refine_offset_index_->Dim(); + } + + int64_t + Size() const override { + if (this->base_index_) { + return this->base_index_->Size(); + } + return 0; + } + + int64_t + Count() const override { + if (this->base_index_) { + return this->base_index_->Count(); + } + return 0; + } + + std::string + Type() const override; + + private: + class iterator : public IndexIterator { + public: + iterator(std::shared_ptr refine_offset_index, IndexNode::IteratorPtr base_workspace, + std::unique_ptr&& copied_query, bool larger_is_closer, float refine_ratio = 0.5f, + bool retain_iterator_order = false) + : IndexIterator(larger_is_closer, false, refine_ratio, retain_iterator_order), + refine_offset_index_(refine_offset_index), + copied_query_(std::move(copied_query)), + base_workspace_(base_workspace) { + refine_computer_ = SelectDataViewComputer(refine_offset_index.get()); + refine_computer_->set_query((const float*)copied_query_.get()); + } + + std::pair + Next() override { + if (!initialized_) { + initialize(); + } + if (!refine_) { + return base_workspace_->Next(); + } else { + auto ret = refined_res_.top(); + refined_res_.pop(); + UpdateNext(); + if (retain_iterator_order_) { + while (HasNext()) { + auto next_ret = refined_res_.top(); + if (next_ret.val >= ret.val) { + break; + } + refined_res_.pop(); + UpdateNext(); + } + } + return std::make_pair(ret.id, ret.val * sign_); + } + } + + [[nodiscard]] bool + HasNext() override { + if (!initialized_) { + initialize(); + } + if (!refine_) { + return base_workspace_->HasNext(); + } else { + return refined_res_.empty() || base_workspace_->HasNext(); + } + } + + void + initialize() override { + if (initialized_) { + throw std::runtime_error("initialize should not be called twice"); + } + UpdateNext(); + initialized_ = true; + } + + protected: + float + raw_distance(int64_t id) override { + if (refine_computer_ == nullptr) { + throw std::runtime_error("refine computer is null in offset refine index."); + } + if (refine_offset_index_->Count() <= id) { + throw std::runtime_error("the id of result larger than index rows count."); + } + float dis = refine_computer_->operator()(id); + dis = refine_offset_index_->IsCosine() ? dis / refine_offset_index_->GetDataNorm(id) : dis; + return dis; + } + + private: + void + UpdateNext() { + if (!base_workspace_->HasNext() || refine_ == false) { + return; + } + while (base_workspace_->HasNext() && (refined_res_.empty() || refined_res_.size() < min_refine_size())) { + auto pair = base_workspace_->Next(); + refined_res_.emplace(pair.first, raw_distance(pair.first) * sign_); + } + } + + private: + bool initialized_ = false; + std::shared_ptr refine_offset_index_ = nullptr; + std::unique_ptr copied_query_ = nullptr; + IndexNode::IteratorPtr base_workspace_ = nullptr; + std::unique_ptr refine_computer_ = nullptr; + }; + bool is_cosine_; + ViewDataOp view_data_op_; + std::shared_ptr + refine_offset_index_; // a data view flat index to maintain raw data without extra memory + std::unique_ptr base_index_; // base_index will hold data codes in memory, datatype is fp32 +}; + +namespace { +constexpr int64_t kBatchSize = 4096; +constexpr int64_t kMaxTrainSize = 5000; +constexpr int64_t kRandomSeed = 1234; +constexpr const char* kIndexNodeSuffixWithDataViewRefiner = "_WITH_DV_REFINER"; + +template +inline DataSetPtr +GenBaseIndexFp32TrainDataSet(const DataSetPtr& src, bool is_cosine = false, + std::optional filling_dim = std::nullopt) { + DataSetPtr train_ds; + auto rows = src->GetRows(); + bool cosine_need_copy = false; + auto src_dim = src->GetDim(); + auto des_dim = filling_dim.value_or(src_dim); + assert(src_dim <= des_dim); + + if (rows <= kMaxTrainSize && des_dim == src_dim) { + train_ds = ConvertFromDataTypeIfNeeded(src); + if constexpr (std::is_same_v) { + cosine_need_copy = true; + } + } else { + auto train_rows = std::min(rows, kMaxTrainSize); + std::vector random_ids(rows); + faiss::rand_perm(random_ids.data(), rows, kRandomSeed); + + const DataType* src_data = (const DataType*)src->GetTensor(); + auto* des_data = new float[des_dim * train_rows]; + std::memset(des_data, 0, sizeof(float) * des_dim * train_rows); + for (auto i = 0; i < train_rows; i++) { + auto from_id = random_ids[i]; + auto to_id = i; + if constexpr (std::is_same_v) { + std::memcpy(des_data + to_id * des_dim, src_data + from_id * src_dim, sizeof(float) * src_dim); + } else { + for (auto d = 0; d < src_dim; d++) { + // todo: optimize it with simd + des_data[to_id * des_dim + d] = (fp32)src_data[from_id * src_dim + d]; + } + } + } + auto des = std::make_shared(); + des->SetRows(train_rows); + des->SetDim(des_dim); + des->SetTensor(des_data); + des->SetIsOwner(true); + train_ds = des; + } + if (is_cosine) { + if (cosine_need_copy) { + train_ds = std::get<0>(CopyAndNormalizeDataset(train_ds)); + } else { + NormalizeDataset(train_ds); + } + } + return train_ds; +} + +template +inline std::tuple> +ConvertToBaseIndexFp32DataSet(const DataSetPtr& src, bool is_cosine = false, + const std::optional start = std::nullopt, + const std::optional count = std::nullopt, + const std::optional filling_dim = std::nullopt) { + auto src_dim = src->GetDim(); + auto des_dim = filling_dim.value_or(src_dim); + auto fp32_ds = ConvertFromDataTypeIfNeeded(src, start, count, filling_dim); + if (is_cosine) { + if (std::is_same_v && src_dim == des_dim) { + return CopyAndNormalizeDataset(fp32_ds); + } else { + auto rows = fp32_ds->GetRows(); + auto norms_vec = NormalizeVecs((float*)fp32_ds->GetTensor(), rows, des_dim); + return std::make_tuple(fp32_ds, norms_vec); + } + } + return std::make_tuple(fp32_ds, std::vector()); +} +} // namespace + +template +Status +IndexNodeWithDataViewRefiner::Train(const DataSetPtr dataset, std::shared_ptr cfg) { + BaseConfig& base_cfg = static_cast(*cfg); + this->is_cosine_ = IsMetricType(base_cfg.metric_type.value(), knowhere::metric::COSINE); + auto dim = dataset->GetDim(); + + // construct refiner + auto refine_metric = is_cosine_ ? metric::IP : base_cfg.metric_type.value(); + refine_offset_index_ = + std::make_unique(dim, datatype_v, refine_metric, this->view_data_op_, is_cosine_); + // construct quant index and train: + AdaptToBaseIndexConfig(cfg.get(), PARAM_TYPE::TRAIN, dim); + auto base_index_dim = dynamic_cast(cfg.get())->dim.value(); + + LOG_KNOWHERE_DEBUG_ << "Generate Base Index with dim: " << base_index_dim << std::endl; + auto fp32_train_ds = GenBaseIndexFp32TrainDataSet(dataset, this->is_cosine_, base_index_dim); + return base_index_->Train(fp32_train_ds, cfg); +} + +template +Status +IndexNodeWithDataViewRefiner::Add(const DataSetPtr dataset, std::shared_ptr cfg) { + auto rows = dataset->GetRows(); + auto dim = dataset->GetDim(); + AdaptToBaseIndexConfig(cfg.get(), PARAM_TYPE::TRAIN, dim); + Status add_stat; + for (auto blk_i = 0; blk_i < rows; blk_i += kBatchSize) { + auto blk_size = std::min(kBatchSize, rows - blk_i); + auto [base_ds, norms] = + ConvertToBaseIndexFp32DataSet(dataset, is_cosine_, blk_i, blk_size, base_index_->Dim()); + add_stat = base_index_->Add(base_ds, cfg); + try { + refine_offset_index_->Add(blk_size, nullptr, norms.data()); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "data view index inner error: " << e.what(); + return Status::internal_error; + } + + if (add_stat != Status::success) { + return add_stat; + } + } + return Status::success; +} + +template +expected +IndexNodeWithDataViewRefiner::Search(const DataSetPtr dataset, std::unique_ptr cfg, + const BitsetView& bitset) const { + if (this->base_index_ == nullptr || this->refine_offset_index_ == nullptr) { + LOG_KNOWHERE_WARNING_ << "search on empty index"; + return expected::Err(Status::empty_index, "index not is trained."); + } + BaseConfig& base_cfg = static_cast(*cfg); + auto nq = dataset->GetRows(); + auto dim = dataset->GetDim(); + auto topk = base_cfg.k.value(); + // basic search + AdaptToBaseIndexConfig(cfg.get(), PARAM_TYPE::SEARCH, dim); + auto base_index_ds = std::get<0>( + ConvertToBaseIndexFp32DataSet(dataset, is_cosine_, std::nullopt, std::nullopt, base_index_->Dim())); + auto quant_res = base_index_->Search(base_index_ds, std::move(cfg), bitset); + if (!quant_res.has_value()) { + return quant_res; + } + // refine + auto queries_lims = std::vector(nq + 1); + auto reorder_k = quant_res.value()->GetDim(); + for (auto i = 0; i < nq + 1; i++) { + queries_lims[i] = reorder_k * i; + } + auto refine_ids = quant_res.value()->GetIds(); + auto labels = std::make_unique(nq * topk); + auto distances = std::make_unique(nq * topk); + try { + refine_offset_index_->SearchWithIds(nq, dataset->GetTensor(), queries_lims.data(), refine_ids, topk, + distances.get(), labels.get()); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "data view index inner error: " << e.what(); + return expected::Err(Status::faiss_inner_error, e.what()); + } + return GenResultDataSet(nq, topk, std::move(labels), std::move(distances)); +} + +template +expected +IndexNodeWithDataViewRefiner::RangeSearch(const DataSetPtr dataset, + std::unique_ptr cfg, + const BitsetView& bitset) const { + if (this->base_index_ == nullptr || this->refine_offset_index_ == nullptr) { + LOG_KNOWHERE_WARNING_ << "search on empty index"; + return expected::Err(Status::empty_index, "index not is trained."); + } + const BaseConfig& base_cfg = static_cast(*cfg); + auto nq = dataset->GetRows(); + auto dim = dataset->GetDim(); + auto radius = base_cfg.radius.value(); + auto range_filter = base_cfg.range_filter.value(); + AdaptToBaseIndexConfig(cfg.get(), PARAM_TYPE::RANGE_SEARCH, dim); + auto base_index_ds = std::get<0>( + ConvertToBaseIndexFp32DataSet(dataset, is_cosine_, std::nullopt, std::nullopt, base_index_->Dim())); + auto quant_res = base_index_->RangeSearch(base_index_ds, std::move(cfg), bitset); + if (!quant_res.has_value()) { + return quant_res; + } + auto quant_res_ids = quant_res.value()->GetIds(); + auto quant_res_lims = quant_res.value()->GetLims(); + try { + auto final_res = + refine_offset_index_->RangeSearchWithIds(nq, dataset->GetTensor(), (const knowhere::idx_t*)quant_res_lims, + (const knowhere::idx_t*)quant_res_ids, radius, range_filter); + return GenResultDataSet(nq, std::move(final_res)); + } catch (const std::exception& e) { + LOG_KNOWHERE_WARNING_ << "data view index inner error: " << e.what(); + return expected::Err(Status::faiss_inner_error, e.what()); + } +} + +template +expected> +IndexNodeWithDataViewRefiner::AnnIterator(const DataSetPtr dataset, + std::unique_ptr cfg, + const BitsetView& bitset, + bool use_knowhere_search_pool) const { + if (this->base_index_ == nullptr || this->refine_offset_index_ == nullptr) { + LOG_KNOWHERE_WARNING_ << "search on empty index"; + return expected>::Err(Status::empty_index, "index not is trained."); + } + + auto dim = dataset->GetDim(); + auto nq = dataset->GetRows(); + auto data = dataset->GetTensor(); + AdaptToBaseIndexConfig(cfg.get(), PARAM_TYPE::ITERATOR, dim); + const auto& base_cfg = static_cast(*cfg); + auto refine_ratio = base_cfg.iterator_refine_ratio.value(); + auto larger_is_closer = IsMetricType(base_cfg.metric_type.value(), knowhere::metric::IP) || is_cosine_; + auto base_index_ds = std::get<0>( + ConvertToBaseIndexFp32DataSet(dataset, is_cosine_, std::nullopt, std::nullopt, base_index_->Dim())); + auto base_index_init = base_index_->AnnIterator(base_index_ds, std::move(cfg), bitset, use_knowhere_search_pool); + if (!base_index_init.has_value()) { + return base_index_init; + } + auto base_workspace_iters = base_index_init.value(); + if (base_workspace_iters.size() != nq) { + return expected>::Err( + Status::internal_error, "quant workspace is not equal to the rows count of input dataset."); + } + auto vec = std::vector(nq, nullptr); + for (auto i = 0; i < nq; i++) { + auto cur_query = (const DataType*)data + i * dim; + std::unique_ptr copied_query = nullptr; + copied_query = std::make_unique(dim); + std::copy_n(cur_query, dim, copied_query.get()); + vec[i] = std::shared_ptr(new iterator(this->refine_offset_index_, base_workspace_iters[i], + std::move(copied_query), larger_is_closer, refine_ratio)); + } + return vec; +} + +template +std::string +IndexNodeWithDataViewRefiner::Type() const { + return base_index_->Type() + kIndexNodeSuffixWithDataViewRefiner; +} + +} // namespace knowhere +#endif diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 2f933bfdd..84ade9295 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -291,32 +291,6 @@ class BaseFaissRegularIndexNode : public BaseFaissIndexNode { } }; -// -enum class DataFormatEnum { fp32, fp16, bf16, int8 }; - -template -struct DataType2EnumHelper {}; - -template <> -struct DataType2EnumHelper { - static constexpr DataFormatEnum value = DataFormatEnum::fp32; -}; -template <> -struct DataType2EnumHelper { - static constexpr DataFormatEnum value = DataFormatEnum::fp16; -}; -template <> -struct DataType2EnumHelper { - static constexpr DataFormatEnum value = DataFormatEnum::bf16; -}; -template <> -struct DataType2EnumHelper { - static constexpr DataFormatEnum value = DataFormatEnum::int8; -}; - -template -static constexpr DataFormatEnum datatype_v = DataType2EnumHelper::value; - namespace { bool diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index 57307d8b1..35a75ac04 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -21,6 +21,7 @@ #include "faiss/IndexScaNN.h" #include "faiss/IndexScalarQuantizer.h" #include "faiss/index_io.h" +#include "index/data_view_dense_index/index_node_with_data_view_refiner.h" #include "index/ivf/ivf_config.h" #include "io/memory_io.h" #include "knowhere/bitsetview_idselector.h" @@ -558,12 +559,12 @@ IvfIndexNode::TrainInternal(const DataSetPtr dataset, std:: bool is_cosine = base_cfg.metric_type.value() == metric::COSINE; const bool use_elkan = scann_cfg.use_elkan.value_or(true); - + const int sub_dim = scann_cfg.sub_dim.value_or(2); // create quantizer for the training std::unique_ptr qzr = std::make_unique(dim, metric.value(), false, use_elkan); // create base index. it does not own qzr - auto base_index = std::make_unique(qzr.get(), dim, nlist, (dim + 1) / 2, 4, + auto base_index = std::make_unique(qzr.get(), dim, nlist, (dim + 1) / sub_dim, 4, is_cosine, metric.value()); // create scann index, which does not base_index by default, // but owns the refine index by default omg @@ -957,7 +958,6 @@ IvfIndexNode::AnnIterator(const DataSetPtr dataset, std::un size_t nprobe = ivf_cfg.nprobe.value(); // set iterator_refine_ratio = 0.0. If quantizer != flat, faiss:indexivf will not keep raw data; - // TODO: if SCANN support Iterator, iterator_refine_ratio should be set. float iterator_refine_ratio = 0.0f; if constexpr (std::is_same_v) { if (HasRawData(ivf_cfg.metric_type.value())) { @@ -1261,4 +1261,7 @@ KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_SQ8, IvfIndexNode, knowhere::f KNOWHERE_MOCK_REGISTER_DENSE_FLOAT_ALL_GLOBAL(IVF_SQ_CC, IvfIndexNode, knowhere::feature::NONE, faiss::IndexIVFScalarQuantizerCC) +// faiss index + data view refiner combination +KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(SCANN_WITH_DV_REFINER, IndexNodeWithDataViewRefiner, + knowhere::feature::NONE, IvfIndexNode) } // namespace knowhere diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index f0f2dbbe9..524431535 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -102,6 +102,7 @@ class ScannConfig : public IvfFlatConfig { public: CFG_INT reorder_k; CFG_BOOL with_raw_data; + CFG_INT sub_dim; KNOHWERE_DECLARE_CONFIG(ScannConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(reorder_k) .description("reorder k used for refining") @@ -113,6 +114,11 @@ class ScannConfig : public IvfFlatConfig { .set_default(true) .for_static() .for_train(); + KNOWHERE_CONFIG_DECLARE_FIELD(sub_dim) + .description("sub dim of each sub dimension space") + .set_default(2) + .for_train() + .set_range(1, 65536); } Status @@ -122,9 +128,11 @@ class ScannConfig : public IvfFlatConfig { // TODO: handle odd dim with scann if (dim.has_value()) { int vec_dim = dim.value(); - if (vec_dim % 2 != 0) { - std::string msg = "The dimension of a vector (dim) should be a multiple of 2. Dimension:" + - std::to_string(vec_dim); + int vec_sub_dim = sub_dim.value(); + if (vec_dim % vec_sub_dim != 0) { + std::string msg = + "The dimension of a vector (dim) should be a multiple of sub_dim. Dimension:" + + std::to_string(vec_dim) + ", sub_dim:" + std::to_string(vec_sub_dim); return HandleError(err_msg, msg, Status::invalid_args); } } diff --git a/tests/ut/test_data_view_index.cc b/tests/ut/test_data_view_index.cc new file mode 100644 index 000000000..5d764025c --- /dev/null +++ b/tests/ut/test_data_view_index.cc @@ -0,0 +1,272 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include "catch2/catch_approx.hpp" +#include "catch2/catch_test_macros.hpp" +#include "catch2/generators/catch_generators.hpp" +#include "knowhere/bitsetview.h" +#include "knowhere/comp/brute_force.h" +#include "knowhere/comp/index_param.h" +#include "knowhere/comp/knowhere_check.h" +#include "knowhere/comp/knowhere_config.h" +#include "knowhere/index/index_factory.h" +#include "knowhere/log.h" +#include "knowhere/object.h" +#include "simd/hook.h" +#include "utils.h" + +namespace { +constexpr float kKnnRecallThreshold = 0.6f; +constexpr float kBruteForceRecallThreshold = 0.95f; +constexpr int kCosineMaxMissNum = 5; +} // namespace + +TEST_CASE("Test SCANN v.s. SCANN with data view refiner", "[float metrics]") { + using Catch::Approx; + auto version = GenTestVersionList(); + if (!faiss::support_pq_fast_scan) { + SKIP("pass scann test"); + } + + const int64_t nb = 1000, nq = 10; + auto metric = GENERATE(as{}, knowhere::metric::COSINE, knowhere::metric::IP, knowhere::metric::L2); + auto topk = GENERATE(as{}, 5, 120); + auto dim = GENERATE(as{}, 120); + + auto base_gen = [=]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric; + json[knowhere::meta::TOPK] = topk; + json[knowhere::meta::RADIUS] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 10.0 : 0.99; + json[knowhere::meta::RANGE_FILTER] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 0.0 : 1.01; + return json; + }; + + auto scann_gen = [base_gen, topk]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::NLIST] = 16; + json[knowhere::indexparam::NPROBE] = 12; + json[knowhere::indexparam::REORDER_K] = topk * 4; + json[knowhere::indexparam::SUB_DIM] = 2; + json[knowhere::indexparam::WITH_RAW_DATA] = true; + json[knowhere::indexparam::ENSURE_TOPK_FULL] = true; + return json; + }; + + auto rand = GENERATE(1, 2); + const auto train_ds = GenDataSet(nb, dim, rand); + const auto query_ds = GenDataSet(nq, dim, rand + 777); + + const knowhere::Json conf = { + {knowhere::meta::METRIC_TYPE, metric}, + {knowhere::meta::TOPK, topk}, + }; + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + knowhere::ViewDataOp data_view = [&train_ds, data_size = sizeof(float) * dim](size_t id) { + auto data = train_ds->GetTensor(); + return data + data_size * id; + }; + auto data_view_pack = knowhere::Pack(data_view); + SECTION("Accuraccy with refine") { + auto cfg_json = scann_gen().dump(); + knowhere::Json json = knowhere::Json::parse(cfg_json); + + auto scann_with_dv_refiner = + knowhere::IndexFactory::Instance() + .Create(knowhere::IndexEnum::INDEX_FAISS_SCANN_WITH_DV_REFINER, version, data_view_pack) + .value(); + auto scann = knowhere::IndexFactory::Instance() + .Create(knowhere::IndexEnum::INDEX_FAISS_SCANN, version) + .value(); + + REQUIRE(scann_with_dv_refiner.Type() == knowhere::IndexEnum::INDEX_FAISS_SCANN_WITH_DV_REFINER); + REQUIRE(scann_with_dv_refiner.Build(train_ds, json) == knowhere::Status::success); + REQUIRE(scann.Build(train_ds, json) == knowhere::Status::success); + REQUIRE(scann_with_dv_refiner.Count() == nb); + REQUIRE(scann_with_dv_refiner.Size() > 0); + REQUIRE(scann_with_dv_refiner.HasRawData(metric) == false); + REQUIRE(scann_with_dv_refiner.HasRawData(metric) == + knowhere::IndexStaticFaced::HasRawData( + knowhere::IndexEnum::INDEX_FAISS_SCANN_WITH_DV_REFINER, version, cfg_json)); + + SECTION("knn search") { + auto scann_with_dv_refiner_results = scann_with_dv_refiner.Search(query_ds, json, nullptr); + auto scann_results = scann.Search(query_ds, json, nullptr); + REQUIRE(scann_with_dv_refiner_results.has_value()); + REQUIRE(scann_results.has_value()); + float recall1 = GetKNNRecall(*gt.value(), *scann_with_dv_refiner_results.value()); + float recall2 = GetKNNRecall(*gt.value(), *scann_results.value()); + REQUIRE(recall1 == recall2); + REQUIRE(recall1 > kKnnRecallThreshold); + REQUIRE(recall2 > kKnnRecallThreshold); + + if (metric == knowhere::metric::COSINE) { + REQUIRE(CheckDistanceInScope(*scann_with_dv_refiner_results.value(), topk, -1.00001, 1.00001)); + } + + auto scann_with_dv_ids = scann_with_dv_refiner_results.value()->GetIds(); + auto scann_ids = scann_results.value()->GetIds(); + auto scann_with_dv_dis = scann_with_dv_refiner_results.value()->GetDistance(); + auto scann_dis = scann_with_dv_refiner_results.value()->GetDistance(); + + if (scann.HasRawData(metric)) { + if (metric == knowhere::metric::COSINE) { + // cosine distances have a little different + auto miss_counter = 0; + for (auto i = 0; i < nq * topk; i++) { + if (scann_with_dv_ids[i] != scann_ids[i]) { + miss_counter++; + } + REQUIRE(std::abs((scann_with_dv_dis[i] - scann_dis[i]) / scann_dis[i]) < 0.00001); + } + REQUIRE(miss_counter < kCosineMaxMissNum); + } else { + for (auto i = 0; i < nq * topk; i++) { + REQUIRE(scann_with_dv_ids[i] == scann_ids[i]); + REQUIRE(scann_with_dv_dis[i] == scann_dis[i]); + } + } + } + } + + SECTION("range search") { + auto scann_results = scann.RangeSearch(query_ds, json, nullptr); + auto scann_with_dv_refiner_results = scann_with_dv_refiner.RangeSearch(query_ds, json, nullptr); + REQUIRE(scann_with_dv_refiner_results.has_value() & scann_results.has_value()); + auto scann_with_dv_ids = scann_with_dv_refiner_results.value()->GetIds(); + auto scann_with_dv_lims = scann_with_dv_refiner_results.value()->GetLims(); + auto scann_ids = scann_results.value()->GetIds(); + auto scann_lims = scann_results.value()->GetLims(); + if (scann.HasRawData(metric)) { + for (auto i = 1; i < nq + 1; i++) { + REQUIRE(scann_lims[i] == scann_with_dv_lims[i]); + } + for (size_t i = 0; i < scann_lims[nq]; i++) { + REQUIRE(scann_with_dv_ids[i] == scann_ids[i]); + } + } + } + + SECTION("knn search with bitset") { + std::vector(size_t, size_t)>> gen_bitset_funcs = { + GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet}; + const auto bitset_percentages = {0.22f, 0.98f}; + for (const float percentage : bitset_percentages) { + for (const auto& gen_func : gen_bitset_funcs) { + auto bitset_data = gen_func(nb, percentage * nb); + knowhere::BitsetView bitset(bitset_data.data(), nb); + auto scann_with_dv_refiner_results = scann_with_dv_refiner.Search(query_ds, json, bitset); + auto scann_results = scann.Search(query_ds, json, bitset); + auto gt_with_filter = + knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + REQUIRE(scann_results.has_value() & scann_with_dv_refiner_results.has_value()); + float recall1 = GetKNNRecall(*gt_with_filter.value(), *scann_with_dv_refiner_results.value()); + float recall2 = GetKNNRecall(*gt_with_filter.value(), *scann_results.value()); + REQUIRE(recall1 == recall2); + REQUIRE(recall1 > kKnnRecallThreshold); + REQUIRE(recall2 > kKnnRecallThreshold); + if (metric == knowhere::metric::COSINE) { + REQUIRE(CheckDistanceInScope(*scann_with_dv_refiner_results.value(), topk, -1.00001, 1.00001)); + } + } + } + } + } +} + +template +void +BaseTest(const knowhere::DataSetPtr train_ds, const knowhere::DataSetPtr query_ds, const int64_t k, + const knowhere::MetricType metric, const knowhere::Json& conf) { + auto version = knowhere::Version::GetCurrentVersion().VersionNumber(); + auto base = knowhere::ConvertToDataTypeIfNeeded(train_ds); + auto query = knowhere::ConvertToDataTypeIfNeeded(query_ds); + auto dim = base->GetDim(); + auto nb = base->GetRows(); + auto nq = query->GetRows(); + + auto knn_gt = knowhere::BruteForce::Search(base, query, conf, nullptr); + knowhere::ViewDataOp data_view = [&base, data_size = sizeof(DataType) * dim](size_t id) { + auto data = base->GetTensor(); + return data + data_size * id; + }; + auto data_view_pack = knowhere::Pack(data_view); + auto scann_with_dv_refiner = + knowhere::IndexFactory::Instance() + .Create(knowhere::IndexEnum::INDEX_FAISS_SCANN_WITH_DV_REFINER, version, data_view_pack) + .value(); + + REQUIRE(scann_with_dv_refiner.Type() == knowhere::IndexEnum::INDEX_FAISS_SCANN_WITH_DV_REFINER); + REQUIRE(scann_with_dv_refiner.Build(base, conf) == knowhere::Status::success); + + REQUIRE(scann_with_dv_refiner.Size() > 0); + REQUIRE(scann_with_dv_refiner.HasRawData(metric) == false); + REQUIRE(scann_with_dv_refiner.HasRawData(metric) == + knowhere::IndexStaticFaced::HasRawData( + knowhere::IndexEnum::INDEX_FAISS_SCANN_WITH_DV_REFINER, version, conf)); + + // knn search + auto scann_with_dv_refiner_results = scann_with_dv_refiner.Search(query, conf, nullptr); + REQUIRE(scann_with_dv_refiner_results.has_value()); + float recall = GetKNNRecall(*knn_gt.value(), *scann_with_dv_refiner_results.value()); + REQUIRE(recall > kKnnRecallThreshold); + if (metric == knowhere::metric::COSINE) { + REQUIRE(CheckDistanceInScope(*scann_with_dv_refiner_results.value(), k, -1.00001, 1.00001)); + } + // range search + auto scann_with_dv_refiner_range_results = scann_with_dv_refiner.RangeSearch(query, conf, nullptr); + REQUIRE(scann_with_dv_refiner_range_results.has_value()); + auto scann_with_dv_ids = scann_with_dv_refiner_range_results.value()->GetIds(); + auto scann_with_dv_lims = scann_with_dv_refiner_range_results.value()->GetLims(); + if (metric == knowhere::metric::L2 || metric == knowhere::metric::COSINE) { + for (int i = 0; i < nq; ++i) { + CHECK(scann_with_dv_ids[scann_with_dv_lims[i]] == i); + } + } +} + +TEST_CASE("Test difference dim with difference data type", "[multi metrics]") { + if (!faiss::support_pq_fast_scan) { + SKIP("pass scann test"); + } + const int64_t nb = 1000, nq = 10; + auto metric = GENERATE(as{}, knowhere::metric::COSINE, knowhere::metric::IP, knowhere::metric::L2); + auto topk = GENERATE(as{}, 10); + auto dim = GENERATE(as{}, 31, 128, 511, 1024); + + auto base_gen = [=]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric; + json[knowhere::meta::TOPK] = topk; + json[knowhere::meta::RADIUS] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 10.0 : 0.9; + json[knowhere::meta::RANGE_FILTER] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 0.0 : 1.01; + return json; + }; + + auto scann_gen = [base_gen, topk]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::NLIST] = 24; + json[knowhere::indexparam::NPROBE] = 16; + json[knowhere::indexparam::REORDER_K] = topk * 4; + json[knowhere::indexparam::SUB_DIM] = 2; + return json; + }; + + const auto train_ds = GenDataSet(nb, dim); + const auto query_ds = GenDataSet(nq, dim); + auto cfg_json = scann_gen().dump(); + knowhere::Json json = knowhere::Json::parse(cfg_json); + BaseTest(train_ds, query_ds, topk, metric, json); + BaseTest(train_ds, query_ds, topk, metric, json); + BaseTest(train_ds, query_ds, topk, metric, json); +} diff --git a/tests/ut/test_iterator.cc b/tests/ut/test_iterator.cc index 611996b1a..b8f413016 100644 --- a/tests/ut/test_iterator.cc +++ b/tests/ut/test_iterator.cc @@ -26,6 +26,7 @@ namespace { constexpr float kKnnRecallThreshold = 0.8f; +constexpr float kKnnRecallWithBfThreshold = 0.6f; knowhere::DataSetPtr GetIteratorKNNResult(const std::vector>& iterators, int k, @@ -621,3 +622,99 @@ TEST_CASE("Test Iterator BruteForce With Sparse Float Vector", "[IP metric]") { } } } + +TEST_CASE("Test Scann with data view refiner", "[float metrics]") { + using Catch::Approx; + if (!faiss::support_pq_fast_scan) { + SKIP("pass scann test"); + } + auto version = GenTestVersionList(); + + const int64_t nb = 1000, nq = 10; + auto metric = GENERATE(as{}, knowhere::metric::IP, knowhere::metric::COSINE, knowhere::metric::L2); + auto topk = GENERATE(as{}, 10); + auto dim = GENERATE(as{}, 120); + + auto base_gen = [=]() { + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric; + json[knowhere::meta::TOPK] = topk; + json[knowhere::meta::RADIUS] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 10.0 : 0.99; + json[knowhere::meta::RANGE_FILTER] = knowhere::IsMetricType(metric, knowhere::metric::L2) ? 0.0 : 1.01; + return json; + }; + + auto scann_gen = [base_gen, topk]() { + knowhere::Json json = base_gen(); + json[knowhere::indexparam::NLIST] = 16; + json[knowhere::indexparam::NPROBE] = 14; + json[knowhere::indexparam::REORDER_K] = topk * 4; + json[knowhere::indexparam::SUB_DIM] = 2; + json[knowhere::indexparam::WITH_RAW_DATA] = true; + return json; + }; + + auto rand = GENERATE(1, 2); + const auto train_ds = GenDataSet(nb, dim, rand); + const auto query_ds = GenDataSet(nq, dim, rand + 777); + + const knowhere::Json conf = { + {knowhere::meta::METRIC_TYPE, metric}, + {knowhere::meta::TOPK, topk}, + }; + auto gt = knowhere::BruteForce::Search(train_ds, query_ds, conf, nullptr); + knowhere::ViewDataOp data_view = [&train_ds, data_size = sizeof(float) * dim](size_t id) { + auto data = train_ds->GetTensor(); + return data + data_size * id; + }; + auto data_view_pack = knowhere::Pack(data_view); + auto cfg_json = scann_gen().dump(); + knowhere::Json json = knowhere::Json::parse(cfg_json); + + auto scann_with_dv_refiner = + knowhere::IndexFactory::Instance() + .Create(knowhere::IndexEnum::INDEX_FAISS_SCANN_WITH_DV_REFINER, version, data_view_pack) + .value(); + + REQUIRE(scann_with_dv_refiner.Type() == knowhere::IndexEnum::INDEX_FAISS_SCANN_WITH_DV_REFINER); + REQUIRE(scann_with_dv_refiner.Build(train_ds, json) == knowhere::Status::success); + REQUIRE(scann_with_dv_refiner.Count() == nb); + REQUIRE(scann_with_dv_refiner.Size() > 0); + REQUIRE(scann_with_dv_refiner.HasRawData(metric) == false); + REQUIRE(scann_with_dv_refiner.HasRawData(metric) == + knowhere::IndexStaticFaced::HasRawData( + knowhere::IndexEnum::INDEX_FAISS_SCANN_WITH_DV_REFINER, version, cfg_json)); + + SECTION("iterator without bitset") { + auto scann_with_dv_its = scann_with_dv_refiner.AnnIterator(query_ds, json, nullptr); + REQUIRE(scann_with_dv_its.has_value()); + auto scann_with_ds_iterator_results = GetIteratorKNNResult(scann_with_dv_its.value(), topk); + + auto search_results = scann_with_dv_refiner.Search(query_ds, json, nullptr); + bool dist_less_better = knowhere::IsMetricType(metric, knowhere::metric::L2); + float recall = GetKNNRelativeRecall(*search_results.value(), *scann_with_ds_iterator_results, dist_less_better); + REQUIRE(recall > kKnnRecallThreshold); + } + SECTION("iterator with bitset") { + std::vector(size_t, size_t)>> gen_bitset_funcs = { + GenerateBitsetWithFirstTbitsSet, GenerateBitsetWithRandomTbitsSet}; + const auto bitset_percentages = {0.4f, 0.98f}; + for (const float percentage : bitset_percentages) { + for (const auto& gen_func : gen_bitset_funcs) { + auto bitset_data = gen_func(nb, percentage * nb); + knowhere::BitsetView bitset(bitset_data.data(), nb); + // Iterator doesn't have a fallback to bruteforce mechanism at high filter rate. + auto its = scann_with_dv_refiner.AnnIterator(query_ds, json, bitset); + REQUIRE(its.has_value()); + + auto iterator_results = GetIteratorKNNResult(its.value(), topk); + auto search_results = knowhere::BruteForce::Search(train_ds, query_ds, json, bitset); + REQUIRE(search_results.has_value()); + bool dist_less_better = knowhere::IsMetricType(metric, knowhere::metric::L2); + float recall = GetKNNRelativeRecall(*search_results.value(), *iterator_results, dist_less_better); + REQUIRE(recall > kKnnRecallWithBfThreshold); + } + } + } +} diff --git a/tests/ut/test_utils.cc b/tests/ut/test_utils.cc index 939a31d2a..ec23e0d17 100644 --- a/tests/ut/test_utils.cc +++ b/tests/ut/test_utils.cc @@ -33,9 +33,7 @@ CheckNormalizeDataset(int rows, int dim, float diff) { auto ds = GenDataSet(rows, dim); auto type_ds = knowhere::ConvertToDataTypeIfNeeded(ds); auto data = (T*)type_ds->GetTensor(); - knowhere::NormalizeDataset(type_ds); - for (int i = 0; i < rows; ++i) { float sum = 0.0; for (int j = 0; j < dim; ++j) { diff --git a/tests/ut/utils.h b/tests/ut/utils.h index 3bbc3d794..f7d449855 100644 --- a/tests/ut/utils.h +++ b/tests/ut/utils.h @@ -249,7 +249,7 @@ GetRelativeLoss(float gt_res, float res) { inline bool CheckDistanceInScope(const knowhere::DataSet& result, int topk, float low_bound, float high_bound) { - auto ids = result.GetDistance(); + auto ids = result.GetIds(); auto distances = result.GetDistance(); auto rows = result.GetRows(); for (int i = 0; i < rows; ++i) { @@ -267,7 +267,7 @@ CheckDistanceInScope(const knowhere::DataSet& result, int topk, float low_bound, inline bool CheckDistanceInScope(const knowhere::DataSet& result, float low_bound, float high_bound) { - auto ids = result.GetDistance(); + auto ids = result.GetIds(); auto distances = result.GetDistance(); auto lims = result.GetLims(); auto rows = result.GetRows(); diff --git a/thirdparty/faiss/faiss/IndexIVFFastScan.cpp b/thirdparty/faiss/faiss/IndexIVFFastScan.cpp index d2eca4914..13083f1aa 100644 --- a/thirdparty/faiss/faiss/IndexIVFFastScan.cpp +++ b/thirdparty/faiss/faiss/IndexIVFFastScan.cpp @@ -51,6 +51,7 @@ IndexIVFFastScan::IndexIVFFastScan( FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT); this->is_cosine = is_cosine; + mutex = std::make_shared(); } IndexIVFFastScan::IndexIVFFastScan() { @@ -58,6 +59,7 @@ IndexIVFFastScan::IndexIVFFastScan() { M2 = 0; is_trained = false; by_residual = false; + mutex = std::make_shared(); } void IndexIVFFastScan::init_fastscan( @@ -178,6 +180,7 @@ void IndexIVFFastScan::add_with_ids_impl( // TODO parallelize idx_t i0 = 0; while (i0 < n) { + std::unique_lock lock(*mutex.get()); idx_t list_no = idx[order[i0]]; idx_t i1 = i0 + 1; while (i1 < n && idx[order[i1]] == list_no) { @@ -361,6 +364,7 @@ void IndexIVFFastScan::search_preassigned( bool store_pairs, const IVFSearchParameters* params, IndexIVFStats* stats) const { + std::shared_lock lock(*mutex.get()); size_t nprobe = this->nprobe; if (params) { FAISS_THROW_IF_NOT(params->max_codes == 0); diff --git a/thirdparty/faiss/faiss/IndexIVFFastScan.h b/thirdparty/faiss/faiss/IndexIVFFastScan.h index 4ca5b5db8..2fa211c96 100644 --- a/thirdparty/faiss/faiss/IndexIVFFastScan.h +++ b/thirdparty/faiss/faiss/IndexIVFFastScan.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include @@ -88,6 +89,8 @@ struct IndexIVFFastScan : IndexIVF { // // todo aguzhva: get rid of this std::vector norms; + std::shared_ptr mutex = nullptr; + IndexIVFFastScan( Index* quantizer, size_t d, diff --git a/thirdparty/faiss/faiss/IndexRefine.cpp b/thirdparty/faiss/faiss/IndexRefine.cpp index bbeaff0cc..9b4e9b9d6 100644 --- a/thirdparty/faiss/faiss/IndexRefine.cpp +++ b/thirdparty/faiss/faiss/IndexRefine.cpp @@ -61,36 +61,6 @@ void IndexRefine::reset() { ntotal = 0; } -namespace { - -using idx_t = faiss::idx_t; - -template -static void reorder_2_heaps( - idx_t n, - idx_t k, - idx_t* __restrict labels, - float* __restrict distances, - idx_t k_base, - const idx_t* __restrict base_labels, - const float* __restrict base_distances) { -#pragma omp parallel for if (n > 1) - for (idx_t i = 0; i < n; i++) { - idx_t* idxo = labels + i * k; - float* diso = distances + i * k; - const idx_t* idxi = base_labels + i * k_base; - const float* disi = base_distances + i * k_base; - - heap_heapify(k, diso, idxo, disi, idxi, k); - if (k_base != k) { // add remaining elements - heap_addn(k, diso, idxo, disi + k, idxi + k, k_base - k); - } - heap_reorder(k, diso, idxo); - } -} - -} // anonymous namespace - void IndexRefine::search( idx_t n, const float* x, @@ -177,12 +147,12 @@ void IndexRefine::search( // sort and store result if (metric_type == METRIC_L2) { typedef CMax C; - reorder_2_heaps( + reorder_2_heaps( n, k, labels, distances, k_base, base_labels, base_distances); } else if (metric_type == METRIC_INNER_PRODUCT) { typedef CMin C; - reorder_2_heaps( + reorder_2_heaps( n, k, labels, distances, k_base, base_labels, base_distances); } else { FAISS_THROW_MSG("Metric type not supported"); @@ -194,8 +164,7 @@ void IndexRefine::range_search( const float* x, float radius, RangeSearchResult* result, - const SearchParameters* params_in) const -{ + const SearchParameters* params_in) const { const IndexRefineSearchParameters* params = nullptr; if (params_in) { params = dynamic_cast(params_in); @@ -206,8 +175,7 @@ void IndexRefine::range_search( SearchParameters* base_index_params = (params != nullptr) ? params->base_index_params : nullptr; - base_index->range_search( - n, x, radius, result, base_index_params); + base_index->range_search(n, x, radius, result, base_index_params); #pragma omp parallel if (n > 1) { @@ -349,12 +317,12 @@ void IndexRefineFlat::search( // sort and store result if (metric_type == METRIC_L2) { typedef CMax C; - reorder_2_heaps( + reorder_2_heaps( n, k, labels, distances, k_base, base_labels, base_distances); } else if (metric_type == METRIC_INNER_PRODUCT) { typedef CMin C; - reorder_2_heaps( + reorder_2_heaps( n, k, labels, distances, k_base, base_labels, base_distances); } else { FAISS_THROW_MSG("Metric type not supported"); diff --git a/thirdparty/faiss/faiss/utils/Heap.cpp b/thirdparty/faiss/faiss/utils/Heap.cpp index 1907a0b1c..92551802b 100644 --- a/thirdparty/faiss/faiss/utils/Heap.cpp +++ b/thirdparty/faiss/faiss/utils/Heap.cpp @@ -247,4 +247,50 @@ INSTANTIATE(CMax, float); INSTANTIATE(CMin, int32_t); INSTANTIATE(CMax, int32_t); +/********************************************************** + * reorder_2_heaps + **********************************************************/ +/** reduce two results: k_base result and + * k result to k result + */ +template +void reorder_2_heaps( + size_t n, + size_t k, + idx_t* __restrict labels, + float* __restrict distances, + size_t k_base, + const idx_t* __restrict base_labels, + const float* __restrict base_distances) { +#pragma omp parallel for if (n > 1) + for (size_t i = 0; i < n; i++) { + idx_t* idxo = labels + i * k; + float* diso = distances + i * k; + const idx_t* idxi = base_labels + i * k_base; + const float* disi = base_distances + i * k_base; + + heap_heapify(k, diso, idxo, disi, idxi, k); + if (k_base != k) { // add remaining elements + heap_addn(k, diso, idxo, disi + k, idxi + k, k_base - k); + } + heap_reorder(k, diso, idxo); + } +} +template void reorder_2_heaps>( + size_t n, + size_t k, + int64_t* __restrict labels, + float* __restrict distances, + size_t k_base, + const int64_t* __restrict base_labels, + const float* __restrict base_distances); + +template void reorder_2_heaps>( + size_t n, + size_t k, + int64_t* __restrict labels, + float* __restrict distances, + size_t k_base, + const int64_t* __restrict base_labels, + const float* __restrict base_distances); } // namespace faiss diff --git a/thirdparty/faiss/faiss/utils/Heap.h b/thirdparty/faiss/faiss/utils/Heap.h index b67707ecb..b214f10ee 100644 --- a/thirdparty/faiss/faiss/utils/Heap.h +++ b/thirdparty/faiss/faiss/utils/Heap.h @@ -631,6 +631,16 @@ void merge_knn_results( typename C::T* distances, idx_t* labels); +template +void reorder_2_heaps( + size_t n, + size_t k, + idx_t* __restrict labels, + float* __restrict distances, + size_t k_base, + const idx_t* __restrict base_labels, + const float* __restrict base_distances); + } // namespace faiss #endif /* FAISS_Heap_h */