Add SparseInvertedIndexNodeCC #933

Merged 1 commit on Nov 11, 2024
2 changes: 2 additions & 0 deletions include/knowhere/comp/index_param.h
@@ -57,6 +57,8 @@ constexpr const char* INDEX_DISKANN = "DISKANN";

constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND";
constexpr const char* INDEX_SPARSE_INVERTED_INDEX_CC = "SPARSE_INVERTED_INDEX_CC";
constexpr const char* INDEX_SPARSE_WAND_CC = "SPARSE_WAND_CC";
} // namespace IndexEnum

namespace ClusterEnum {
3 changes: 3 additions & 0 deletions include/knowhere/expected.h
@@ -48,6 +48,7 @@ enum class Status {
timeout = 26,
internal_error = 27,
invalid_serialized_index_type = 28,
sparse_inner_error = 29,
};

inline std::string
@@ -101,6 +102,8 @@ Status2String(knowhere::Status status) {
return "internal error (something that must not have happened at all)";
case knowhere::Status::invalid_serialized_index_type:
return "the serialized index type is not recognized";
case knowhere::Status::sparse_inner_error:
return "sparse index inner error";
default:
return "unexpected status";
}
132 changes: 128 additions & 4 deletions src/index/sparse/sparse_index_node.cc
@@ -30,14 +30,16 @@ namespace knowhere {

// Inverted Index impl for sparse vectors. May optionally use WAND algorithm to speed up search.
//
// Not overriding RangeSerach, will use the default implementation in IndexNode.
// Not overriding RangeSearch, will use the default implementation in IndexNode.
//
// Thread safety: not thread safe.
template <typename T, bool use_wand>
class SparseInvertedIndexNode : public IndexNode {
static_assert(std::is_same_v<T, fp32>, "SparseInvertedIndexNode only support float");

public:
explicit SparseInvertedIndexNode(const int32_t& /*version*/, const Object& /*object*/)
: search_pool_(ThreadPool::GetGlobalSearchThreadPool()) {
: search_pool_(ThreadPool::GetGlobalSearchThreadPool()), build_pool_(ThreadPool::GetGlobalBuildThreadPool()) {
}

~SparseInvertedIndexNode() override {
@@ -74,8 +76,17 @@ class SparseInvertedIndexNode : public IndexNode {
LOG_KNOWHERE_ERROR_ << "Could not add data to empty " << Type();
return Status::empty_index;
}
return index_->Add(static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor()), dataset->GetRows(),
dataset->GetDim());
auto tryObj = build_pool_
->push([&] {
return index_->Add(static_cast<const sparse::SparseRow<T>*>(dataset->GetTensor()),
dataset->GetRows(), dataset->GetDim());
})
.getTry();
if (!tryObj.hasValue()) {
LOG_KNOWHERE_WARNING_ << "failed to add data to index " << Type() << ": " << tryObj.exception().what();
return Status::sparse_inner_error;
}
return tryObj.value();
}

[[nodiscard]] expected<DataSetPtr>
@@ -316,14 +327,127 @@ class SparseInvertedIndexNode : public IndexNode {

sparse::BaseInvertedIndex<T>* index_{};
std::shared_ptr<ThreadPool> search_pool_;
std::shared_ptr<ThreadPool> build_pool_;

// if map_ is not nullptr, it means the index is mmapped from disk.
char* map_ = nullptr;
size_t map_size_ = 0;
}; // class SparseInvertedIndexNode

// Concurrent version of SparseInvertedIndexNode
//
// Thread safety: only the overridden methods are allowed to be called concurrently.
template <typename T, bool use_wand>
class SparseInvertedIndexNodeCC : public SparseInvertedIndexNode<T, use_wand> {
public:
explicit SparseInvertedIndexNodeCC(const int32_t& version, const Object& object)
: SparseInvertedIndexNode<T, use_wand>(version, object) {
}

Status
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
std::unique_lock<std::mutex> lock(mutex_);
[Review thread on this line]
Collaborator: Can std::shared_mutex from C++17 (https://en.cppreference.com/w/cpp/thread/shared_mutex) solve the problem of concurrent access for readers and writers in this particular use case?

Collaborator (author): We want to avoid starvation of writers. This implementation guarantees that new read requests have to wait for already-waiting write requests to finish.

Collaborator: LGTM, I'm assuming you've checked that it won't deadlock.
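For illustration only (not part of this diff): a minimal, self-contained sketch of the writer-preference idea discussed above. std::shared_mutex leaves reader/writer fairness implementation-defined, so a steady stream of readers could starve a writer; counting waiting writers explicitly, as in this hypothetical WriterPreferenceLock, makes new readers block while a writer is queued. The PR itself goes further and orders operations with per-task ticket IDs (see the code below), so this is a simplified analogue, not the actual implementation.

// Hypothetical sketch, not part of the PR: a writer-preference lock built from
// std::mutex + std::condition_variable.
#include <condition_variable>
#include <mutex>

class WriterPreferenceLock {
 public:
    void
    LockShared() {
        std::unique_lock<std::mutex> lk(mtx_);
        // New readers wait while any writer is waiting or active.
        cv_.wait(lk, [this] { return waiting_writers_ == 0 && !writer_active_; });
        ++active_readers_;
    }

    void
    UnlockShared() {
        std::unique_lock<std::mutex> lk(mtx_);
        if (--active_readers_ == 0) {
            cv_.notify_all();
        }
    }

    void
    Lock() {
        std::unique_lock<std::mutex> lk(mtx_);
        ++waiting_writers_;
        // The writer only waits for readers/writers that are already running.
        cv_.wait(lk, [this] { return active_readers_ == 0 && !writer_active_; });
        --waiting_writers_;
        writer_active_ = true;
    }

    void
    Unlock() {
        std::unique_lock<std::mutex> lk(mtx_);
        writer_active_ = false;
        cv_.notify_all();
    }

 private:
    std::mutex mtx_;
    std::condition_variable cv_;
    int active_readers_ = 0;
    int waiting_writers_ = 0;
    bool writer_active_ = false;
};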

uint64_t task_id = next_task_id_++;
add_tasks_.push(task_id);

// add task is allowed to run only after all search tasks that come before it have finished.
cv_.wait(lock, [this, task_id]() { return current_task_id_ == task_id && active_readers_ == 0; });

auto res = SparseInvertedIndexNode<T, use_wand>::Add(dataset, cfg);

add_tasks_.pop();
current_task_id_++;
lock.unlock();
cv_.notify_all();
return res;
}

expected<DataSetPtr>
Search(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
ReadPermission permission(*this);
return SparseInvertedIndexNode<T, use_wand>::Search(dataset, std::move(cfg), bitset);
}

expected<std::vector<IndexNode::IteratorPtr>>
AnnIterator(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
ReadPermission permission(*this);
return SparseInvertedIndexNode<T, use_wand>::AnnIterator(dataset, std::move(cfg), bitset);
}

expected<DataSetPtr>
RangeSearch(const DataSetPtr dataset, std::unique_ptr<Config> cfg, const BitsetView& bitset) const override {
ReadPermission permission(*this);
return SparseInvertedIndexNode<T, use_wand>::RangeSearch(dataset, std::move(cfg), bitset);
}

expected<DataSetPtr>
GetVectorByIds(const DataSetPtr dataset) const override {
ReadPermission permission(*this);
return SparseInvertedIndexNode<T, use_wand>::GetVectorByIds(dataset);
}

int64_t
Dim() const override {
ReadPermission permission(*this);
return SparseInvertedIndexNode<T, use_wand>::Dim();
}

int64_t
Size() const override {
ReadPermission permission(*this);
return SparseInvertedIndexNode<T, use_wand>::Size();
}

int64_t
Count() const override {
ReadPermission permission(*this);
return SparseInvertedIndexNode<T, use_wand>::Count();
}

std::string
Type() const override {
return use_wand ? knowhere::IndexEnum::INDEX_SPARSE_WAND_CC
: knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC;
}

private:
struct ReadPermission {
ReadPermission(const SparseInvertedIndexNodeCC& node) : node_(node) {
std::unique_lock<std::mutex> lock(node_.mutex_);
uint64_t task_id = node_.next_task_id_++;
// read task may execute only after all add tasks that come before it have finished.
if (!node_.add_tasks_.empty() && task_id > node_.add_tasks_.front()) {
node_.cv_.wait(
lock, [this, task_id]() { return node_.add_tasks_.empty() || task_id < node_.add_tasks_.front(); });
}
// read task is allowed to run, block all add tasks
node_.active_readers_++;
}

~ReadPermission() {
std::unique_lock<std::mutex> lock(node_.mutex_);
node_.active_readers_--;
node_.current_task_id_++;
node_.cv_.notify_all();
}
const SparseInvertedIndexNodeCC& node_;
};

mutable std::mutex mutex_;
mutable std::condition_variable cv_;
mutable int64_t active_readers_ = 0;
mutable std::queue<uint64_t> add_tasks_;
mutable uint64_t next_task_id_ = 0;
mutable uint64_t current_task_id_ = 0;
}; // class SparseInvertedIndexNodeCC

KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_INVERTED_INDEX, SparseInvertedIndexNode, knowhere::feature::MMAP,
/*use_wand=*/false)
KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_WAND, SparseInvertedIndexNode, knowhere::feature::MMAP,
/*use_wand=*/true)
KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_INVERTED_INDEX_CC, SparseInvertedIndexNodeCC,
knowhere::feature::MMAP,
/*use_wand=*/false)
KNOWHERE_SIMPLE_REGISTER_SPARSE_FLOAT_GLOBAL(SPARSE_WAND_CC, SparseInvertedIndexNodeCC, knowhere::feature::MMAP,
/*use_wand=*/true)
} // namespace knowhere
118 changes: 118 additions & 0 deletions tests/ut/test_sparse.cc
@@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.

#include <future>
#include <thread>

#include "catch2/catch_test_macros.hpp"
#include "catch2/generators/catch_generators.hpp"
@@ -547,3 +548,120 @@ TEST_CASE("Test Mem Sparse Index Handle Empty Vector", "[float metrics]") {
}
}
}

TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {
std::atomic<int32_t> value_base(0);
// each time a new batch of vectors is generated, the base value is increased by 1.
// also the sparse vectors are all full, so newly generated vectors are guaranteed
// to have larger IP than old vectors.
auto doc_vector_gen = [&](int32_t nb, int32_t dim) {
auto base = value_base.fetch_add(1);
std::vector<std::map<int32_t, float>> data(nb);
for (int32_t i = 0; i < nb; ++i) {
for (int32_t j = 0; j < dim; ++j) {
data[i][j] = base + static_cast<float>(rand()) / RAND_MAX * 0.8 + 0.1;
}
}
return GenSparseDataSet(data, dim);
};

auto nb = 1000;
auto dim = 30;
auto topk = 50;
int64_t nq = 100;

auto query_ds = doc_vector_gen(nq, dim);

// drop ratio build is not supported in CC index
auto drop_ratio_build = 0.0;
auto drop_ratio_search = GENERATE(0.0, 0.3);

auto metric = GENERATE(knowhere::metric::IP);
auto version = GenTestVersionList();

auto base_gen = [=, dim = dim]() {
knowhere::Json json;
json[knowhere::meta::DIM] = dim;
json[knowhere::meta::METRIC_TYPE] = metric;
json[knowhere::meta::TOPK] = topk;
json[knowhere::meta::BM25_K1] = 1.2;
json[knowhere::meta::BM25_B] = 0.75;
json[knowhere::meta::BM25_AVGDL] = 100;
return json;
};

auto sparse_inverted_index_gen = [base_gen, drop_ratio_build = drop_ratio_build,
drop_ratio_search = drop_ratio_search]() {
knowhere::Json json = base_gen();
json[knowhere::indexparam::DROP_RATIO_BUILD] = drop_ratio_build;
json[knowhere::indexparam::DROP_RATIO_SEARCH] = drop_ratio_search;
return json;
};

const knowhere::Json conf = {
{knowhere::meta::METRIC_TYPE, metric}, {knowhere::meta::TOPK, topk}, {knowhere::meta::BM25_K1, 1.2},
{knowhere::meta::BM25_B, 0.75}, {knowhere::meta::BM25_AVGDL, 100},
};

// since all newly inserted vectors are guaranteed to have larger IP than old vectors,
// the result ids of each search request should be from the same batch of inserted vectors.
auto check_result = [&](const knowhere::DataSet& ds) {
auto nq = ds.GetRows();
auto k = ds.GetDim();
auto* ids = ds.GetIds();
auto expected_id_base = ids[0] / nb;
for (auto i = 0; i < nq; ++i) {
for (auto j = 0; j < k; ++j) {
auto base = ids[i * k + j] / nb;
REQUIRE(base == expected_id_base);
}
}
};

auto test_time = 10;

SECTION("Test Search") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_INVERTED_INDEX_CC, sparse_inverted_index_gen),
make_tuple(knowhere::IndexEnum::INDEX_SPARSE_WAND_CC, sparse_inverted_index_gen),
}));

auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
CAPTURE(name, cfg_json);
knowhere::Json json = knowhere::Json::parse(cfg_json);
REQUIRE(idx.Type() == name);
// build the index with some initial data
REQUIRE(idx.Build(doc_vector_gen(nb, dim), json) == knowhere::Status::success);

auto add_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto doc_ds = doc_vector_gen(nb, dim);
auto res = idx.Add(doc_ds, json);
REQUIRE(res == knowhere::Status::success);
}
};

auto search_task = [&]() {
auto start = std::chrono::steady_clock::now();
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
test_time) {
auto results = idx.Search(query_ds, json, nullptr);
REQUIRE(results.has_value());
check_result(*results.value());
}
};

std::vector<std::future<void>> task_list;
for (int thread = 0; thread < 5; thread++) {
task_list.push_back(std::async(std::launch::async, search_task));
}
task_list.push_back(std::async(std::launch::async, add_task));
for (auto& task : task_list) {
task.wait();
}
}
}