diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h
index 8525c5ad7..4051962c1 100644
--- a/include/knowhere/comp/index_param.h
+++ b/include/knowhere/comp/index_param.h
@@ -57,6 +57,8 @@ constexpr const char* INDEX_DISKANN = "DISKANN";
 constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
 constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND";
+constexpr const char* INDEX_SPARSE_INVERTED_INDEX_CC = "SPARSE_INVERTED_INDEX_CC";
+constexpr const char* INDEX_SPARSE_WAND_CC = "SPARSE_WAND_CC";
 }  // namespace IndexEnum
 
 namespace ClusterEnum {
diff --git a/include/knowhere/expected.h b/include/knowhere/expected.h
index 10a5f41b1..0c46df4cd 100644
--- a/include/knowhere/expected.h
+++ b/include/knowhere/expected.h
@@ -48,6 +48,7 @@ enum class Status {
     timeout = 26,
     internal_error = 27,
     invalid_serialized_index_type = 28,
+    sparse_inner_error = 29,
 };
 
 inline std::string
@@ -101,6 +102,8 @@ Status2String(knowhere::Status status) {
         case knowhere::Status::internal_error:
             return "internal error (something that must not have happened at all)";
         case knowhere::Status::invalid_serialized_index_type:
             return "the serialized index type is not recognized";
+        case knowhere::Status::sparse_inner_error:
+            return "sparse index inner error";
         default:
             return "unexpected status";
     }
 }
diff --git a/src/index/sparse/sparse_index_node.cc b/src/index/sparse/sparse_index_node.cc
index 6e47b6c59..0df4de7b4 100644
--- a/src/index/sparse/sparse_index_node.cc
+++ b/src/index/sparse/sparse_index_node.cc
@@ -30,14 +30,16 @@ namespace knowhere {
 
 // Inverted Index impl for sparse vectors. May optionally use WAND algorithm to speed up search.
 //
-// Not overriding RangeSerach, will use the default implementation in IndexNode.
+// Not overriding RangeSearch, will use the default implementation in IndexNode.
+//
+// Thread safety: not thread safe.
 template <typename T, bool use_wand>
 class SparseInvertedIndexNode : public IndexNode {
     static_assert(std::is_same_v<T, fp32>, "SparseInvertedIndexNode only support float");
 
  public:
     explicit SparseInvertedIndexNode(const int32_t& /*version*/, const Object& /*object*/)
-        : search_pool_(ThreadPool::GetGlobalSearchThreadPool()) {
+        : search_pool_(ThreadPool::GetGlobalSearchThreadPool()), build_pool_(ThreadPool::GetGlobalBuildThreadPool()) {
     }
 
     ~SparseInvertedIndexNode() override {
@@ -74,8 +76,17 @@ class SparseInvertedIndexNode : public IndexNode {
             LOG_KNOWHERE_ERROR_ << "Could not add data to empty " << Type();
             return Status::empty_index;
         }
-        return index_->Add(static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor()), dataset->GetRows(),
-                           dataset->GetDim());
+        auto tryObj = build_pool_
+                          ->push([&] {
+                              return index_->Add(static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor()),
+                                                 dataset->GetRows(), dataset->GetDim());
+                          })
+                          .getTry();
+        if (!tryObj.hasValue()) {
+            LOG_KNOWHERE_WARNING_ << "failed to add data to index " << Type() << ": " << tryObj.exception().what();
+            return Status::sparse_inner_error;
+        }
+        return tryObj.value();
     }
 
     [[nodiscard]] expected<DataSetPtr>
@@ -316,14 +327,127 @@
     sparse::BaseInvertedIndex<T>* index_{};
     std::shared_ptr<ThreadPool> search_pool_;
+    std::shared_ptr<ThreadPool> build_pool_;
     // if map_ is not nullptr, it means the index is mmapped from disk.
     char* map_ = nullptr;
     size_t map_size_ = 0;
 };  // class SparseInvertedIndexNode
 
+// Concurrent version of SparseInvertedIndexNode
+//
+// Thread safety: only the overridden methods are allowed to be called concurrently.
+template <typename T, bool use_wand>
+class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
+ public:
+    explicit SparseInvertedIndexNodeCC(const int32_t& version, const Object& object)
+        : SparseInvertedIndexNode<T, use_wand>(version, object) {
+    }
+
+    Status
+    Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
+        std::unique_lock<std::mutex> lock(mutex_);
+        uint64_t task_id = next_task_id_++;
+        add_tasks_.push(task_id);
+
+        // add task is allowed to run only after all tasks(adds and searches) that come before it have finished.
+        cv_.wait(lock, [this, task_id]() { return current_task_id_ == task_id && active_readers_ == 0; });
+
+        auto res = SparseInvertedIndexNode<T, use_wand>::Add(dataset, cfg);
+
+        add_tasks_.pop();
+        current_task_id_++;
+        lock.unlock();
+        cv_.notify_all();
+        return res;
+    }
+
+    expected<DataSetPtr>
+    Search(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
+        ReadPermission permission(*this);
+        return SparseInvertedIndexNode<T, use_wand>::Search(dataset, std::move(cfg), bitset);
+    }
+
+    expected<std::vector<IndexNode::IteratorPtr>>
+    AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
+        ReadPermission permission(*this);
+        return SparseInvertedIndexNode<T, use_wand>::AnnIterator(dataset, std::move(cfg), bitset);
+    }
+
+    expected<DataSetPtr>
+    RangeSearch(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
+        ReadPermission permission(*this);
+        return SparseInvertedIndexNode<T, use_wand>::RangeSearch(dataset, std::move(cfg), bitset);
+    }
+
+    expected<DataSetPtr>
+    GetVectorByIds(const DataSetPtr dataset) const override {
+        ReadPermission permission(*this);
+        return SparseInvertedIndexNode<T, use_wand>::GetVectorByIds(dataset);
+    }
+
+    int64_t
+    Dim() const override {
+        ReadPermission permission(*this);
+        return SparseInvertedIndexNode<T, use_wand>::Dim();
+    }
+
+    int64_t
+    Size() const override {
+        ReadPermission permission(*this);
+        return SparseInvertedIndexNode<T, use_wand>::Size();
+    }
+
+    int64_t
+    Count() const override {
+        ReadPermission permission(*this);
+        return SparseInvertedIndexNode<T, use_wand>::Count();
+    }
+
+    std::string
+    Type() const override {
+        return use_wand ? knowhere::IndexEnum::INDEX_SPARSE_WAND_CC
+                        : knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC;
+    }
+
+ private:
+    struct ReadPermission {
+        ReadPermission(const SparseInvertedIndexNodeCC& node) : node_(node) {
+            std::unique_lock<std::mutex> lock(node_.mutex_);
+            uint64_t task_id = node_.next_task_id_++;
+            // read task may execute only after all add tasks that come before it have finished.
+            if (!node_.add_tasks_.empty() && task_id > node_.add_tasks_.front()) {
+                node_.cv_.wait(lock, [this, task_id]() {
+                    return node_.add_tasks_.empty() || task_id < node_.add_tasks_.front();
+                });
+            }
+            // read task is allowed to run; its presence blocks all add tasks
+            node_.active_readers_++;
+        }
+
+        ~ReadPermission() {
+            std::unique_lock<std::mutex> lock(node_.mutex_);
+            node_.active_readers_--;
+            node_.current_task_id_++;
+            node_.cv_.notify_all();
+        }
+        const SparseInvertedIndexNodeCC& node_;
+    };
+
+    mutable std::mutex mutex_;
+    mutable std::condition_variable cv_;
+    mutable int64_t active_readers_ = 0;
+    mutable std::queue<uint64_t> add_tasks_;
+    mutable uint64_t next_task_id_ = 0;
+    mutable uint64_t current_task_id_ = 0;
+};  // class SparseInvertedIndexNodeCC
+
 KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_INVERTED_INDEX, SparseInvertedIndexNode, knowhere::feature::MMAP,
                                              /*use_wand=*/false)
 KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_WAND, SparseInvertedIndexNode, knowhere::feature::MMAP,
                                              /*use_wand=*/true)
+KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_INVERTED_INDEX_CC, SparseInvertedIndexNodeCC,
+                                             knowhere::feature::MMAP,
+                                             /*use_wand=*/false)
+KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_WAND_CC, SparseInvertedIndexNodeCC, knowhere::feature::MMAP,
+                                             /*use_wand=*/true)
 }  // namespace knowhere
diff --git a/tests/ut/test_sparse.cc b/tests/ut/test_sparse.cc
index 88e79d2c9..99743c7ac 100644
--- a/tests/ut/test_sparse.cc
+++ b/tests/ut/test_sparse.cc
@@ -10,6 +10,7 @@
 // or implied. See the License for the specific language governing permissions and limitations under the License.
 
 #include <atomic>
+#include <future>
 
 #include "catch2/catch_test_macros.hpp"
 #include "catch2/generators/catch_generators.hpp"
@@ -547,3 +548,120 @@ TEST_CASE("Test Mem Sparse Index Handle Empty Vector", "[float metrics]") {
         }
     }
 }
+
+TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {
+    std::atomic<int32_t> value_base(0);
+    // each time a new batch of vectors is generated, the base value is increased by 1.
+    // the sparse vectors are also all full, so newly generated vectors are guaranteed
+    // to have larger IP than old vectors.
+    auto doc_vector_gen = [&](int32_t nb, int32_t dim) {
+        auto base = value_base.fetch_add(1);
+        std::vector<std::map<int32_t, float>> data(nb);
+        for (int32_t i = 0; i < nb; ++i) {
+            for (int32_t j = 0; j < dim; ++j) {
+                data[i][j] = base + static_cast<float>(rand()) / RAND_MAX * 0.8 + 0.1;
+            }
+        }
+        return GenSparseDataSet(data, dim);
+    };
+
+    auto nb = 1000;
+    auto dim = 30;
+    auto topk = 50;
+    int64_t nq = 100;
+
+    auto query_ds = doc_vector_gen(nq, dim);
+
+    // drop ratio build is not supported in CC index
+    auto drop_ratio_build = 0.0;
+    auto drop_ratio_search = GENERATE(0.0, 0.3);
+
+    auto metric = GENERATE(knowhere::metric::IP);
+    auto version = GenTestVersionList();
+
+    auto base_gen = [=, dim = dim]() {
+        knowhere::Json json;
+        json[knowhere::meta::DIM] = dim;
+        json[knowhere::meta::METRIC_TYPE] = metric;
+        json[knowhere::meta::TOPK] = topk;
+        json[knowhere::meta::BM25_K1] = 1.2;
+        json[knowhere::meta::BM25_B] = 0.75;
+        json[knowhere::meta::BM25_AVGDL] = 100;
+        return json;
+    };
+
+    auto sparse_inverted_index_gen = [base_gen, drop_ratio_build = drop_ratio_build,
+                                      drop_ratio_search = drop_ratio_search]() {
+        knowhere::Json json = base_gen();
+        json[knowhere::indexparam::DROP_RATIO_BUILD] = drop_ratio_build;
+        json[knowhere::indexparam::DROP_RATIO_SEARCH] = drop_ratio_search;
+        return json;
+    };
+
+    const knowhere::Json conf = {
+        {knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, {knowhere::meta::BM25_K1, 1.2},
+        {knowhere::meta::BM25_B, 0.75},        {knowhere::meta::BM25_AVGDL, 100},
+    };
+
+    // since all newly inserted vectors are guaranteed to have larger IP than old vectors,
+    // the result ids of each search request should be from the same batch of inserted vectors.
+    auto check_result = [&](const knowhere::DataSet& ds) {
+        auto nq = ds.GetRows();
+        auto k = ds.GetDim();
+        auto* ids = ds.GetIds();
+        auto expected_id_base = ids[0] / nb;
+        for (auto i = 0; i < nq; ++i) {
+            for (auto j = 0; j < k; ++j) {
+                auto base = ids[i * k + j] / nb;
+                REQUIRE(base == expected_id_base);
+            }
+        }
+    };
+
+    auto test_time = 10;  // seconds
+
+    SECTION("Test Search") {
+        using std::make_tuple;
+        auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
+            make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC, sparse_inverted_index_gen),
+            make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND_CC, sparse_inverted_index_gen),
+        }));
+
+        auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
+        auto cfg_json = gen().dump();
+        CAPTURE(name, cfg_json);
+        knowhere::Json json = knowhere::Json::parse(cfg_json);
+        REQUIRE(idx.Type() == name);
+        // build the index with some initial data
+        REQUIRE(idx.Build(doc_vector_gen(nb, dim), json) == knowhere::Status::success);
+
+        auto add_task = [&]() {
+            auto start = std::chrono::steady_clock::now();
+            while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start)
+                       .count() < test_time) {
+                auto doc_ds = doc_vector_gen(nb, dim);
+                auto res = idx.Add(doc_ds, json);
+                REQUIRE(res == knowhere::Status::success);
+            }
+        };
+
+        auto search_task = [&]() {
+            auto start = std::chrono::steady_clock::now();
+            while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start)
+                       .count() < test_time) {
+                auto results = idx.Search(query_ds, json, nullptr);
+                REQUIRE(results.has_value());
+                check_result(*results.value());
+            }
+        };
+
+        std::vector<std::future<void>> task_list;
+        for (int thread = 0; thread < 5; thread++) {
+            task_list.push_back(std::async(std::launch::async, search_task));
+        }
+        task_list.push_back(std::async(std::launch::async, add_task));
+        for (auto& task : task_list) {
+            task.wait();
+        }
+    }
+}
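
The gating scheme in SparseInvertedIndexNodeCC is easier to study outside of the index class. Below is a minimal standalone sketch of the same ticket-based protocol, under hypothetical names (TaskGate, RunWrite, RunRead) that are not part of the knowhere API: every task draws a monotonically increasing ticket; a write waits until every earlier ticket has retired and no reader is inside, while a read waits only for writes holding earlier tickets, so reads admitted between two writes still run concurrently with one another.

#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <queue>

// Hypothetical sketch mirroring SparseInvertedIndexNodeCC's synchronization.
// fn is assumed to return a value (as Add and Search return Status/expected).
class TaskGate {
 public:
    // Mirrors SparseInvertedIndexNodeCC::Add: writes are exclusive and FIFO.
    template <typename Fn>
    auto
    RunWrite(Fn&& fn) {
        std::unique_lock<std::mutex> lock(mutex_);
        uint64_t ticket = next_ticket_++;
        pending_writes_.push(ticket);
        // run only once all earlier tickets have retired and no reader is active
        cv_.wait(lock, [&] { return finished_ == ticket && active_readers_ == 0; });
        auto res = fn();  // executed under the lock: nothing else may run
        pending_writes_.pop();
        finished_++;
        lock.unlock();
        cv_.notify_all();
        return res;
    }

    // Mirrors ReadPermission: waits only for writes that drew earlier tickets.
    template <typename Fn>
    auto
    RunRead(Fn&& fn) {
        std::unique_lock<std::mutex> lock(mutex_);
        uint64_t ticket = next_ticket_++;
        cv_.wait(lock, [&] { return pending_writes_.empty() || ticket < pending_writes_.front(); });
        active_readers_++;
        lock.unlock();
        auto res = fn();  // concurrent region: many reads may be here at once
        lock.lock();
        active_readers_--;
        finished_++;  // reads retire their ticket too, unblocking queued writes
        lock.unlock();
        cv_.notify_all();
        return res;
    }

 private:
    std::mutex mutex_;
    std::condition_variable cv_;
    std::queue<uint64_t> pending_writes_;  // tickets of waiting writes, in FIFO order
    int64_t active_readers_ = 0;
    uint64_t next_ticket_ = 0;  // next ticket to hand out
    uint64_t finished_ = 0;     // number of tickets already retired
};

In SparseInvertedIndexNodeCC, Add plays the role of RunWrite and every const override (Search, RangeSearch, AnnIterator, GetVectorByIds, Dim, Size, Count) acquires a ReadPermission, the RAII form of RunRead. Because a read never waits on writes queued after it, a continuous stream of searches cannot starve an Add; conversely, once an Add has drawn its ticket it blocks only the searches issued after it, and the searches issued before it drain first.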