Skip to content

Commit

Permalink
enhance: BF supports ids of base_data starting from a specific value
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Nov 13, 2024
1 parent 5935c1f commit beff789
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 12 deletions.
6 changes: 4 additions & 2 deletions include/knowhere/bitsetview_idselector.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ namespace knowhere {

struct BitsetViewIDSelector final : faiss::IDSelector {
const BitsetView bitset_view;
const size_t id_offset;

inline BitsetViewIDSelector(BitsetView bitset_view) : bitset_view{bitset_view} {
inline BitsetViewIDSelector(BitsetView bitset_view, const size_t offset = 0)
: bitset_view{bitset_view}, id_offset(offset) {
}

inline bool
is_member(faiss::idx_t id) const override final {
// it is by design that bitset_view.empty() is not tested here
return (!bitset_view.test(id));
return (!bitset_view.test(id + id_offset));
}
};

Expand Down
1 change: 1 addition & 0 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ constexpr const char* RETAIN_ITERATOR_ORDER = "retain_iterator_order";
constexpr const char* RADIUS = "radius";
constexpr const char* RANGE_FILTER = "range_filter";
constexpr const char* INPUT_IDS = "input_ids";
constexpr const char* INPUT_BEG_ID = "input_begin_id";
constexpr const char* OUTPUT_TENSOR = "output_tensor";
constexpr const char* DEVICE_ID = "gpu_id";
constexpr const char* NUM_BUILD_THREAD = "num_build_thread";
Expand Down
20 changes: 19 additions & 1 deletion include/knowhere/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
this->data_[meta::DIM] = Var(std::in_place_index<4>, dim);
}

void
SetTensorBeginId(const int64_t offset) {
std::unique_lock lock(mutex_);
this->data_[meta::INPUT_BEG_ID] = Var(std::in_place_index<4>, offset);
}

void
SetJsonInfo(const std::string& info) {
std::unique_lock lock(mutex_);
Expand Down Expand Up @@ -260,6 +266,17 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
this->is_sparse = is_sparse;
}

int64_t
GetTensorBeginId() const {
std::shared_lock lock(mutex_);
auto it = this->data_.find(meta::INPUT_BEG_ID);
if (it != this->data_.end()) {
int64_t res = *std::get_if<4>(&it->second);
return res;
}
return 0;
}

// deprecated API
template <typename T>
void
Expand Down Expand Up @@ -288,12 +305,13 @@ class DataSet : public std::enable_shared_from_this<const DataSet> {
using DataSetPtr = std::shared_ptr<DataSet>;

inline DataSetPtr
GenDataSet(const int64_t nb, const int64_t dim, const void* xb) {
GenDataSet(const int64_t nb, const int64_t dim, const void* xb, const int64_t beg_id = 0) {
auto ret_ds = std::make_shared<DataSet>();
ret_ds->SetRows(nb);
ret_ds->SetDim(dim);
ret_ds->SetTensor(xb);
ret_ds->SetIsOwner(false);
ret_ds->SetTensorBeginId(beg_id);
return ret_ds;
}

Expand Down
43 changes: 34 additions & 9 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
namespace knowhere {

/* knowhere wrapper API to call faiss brute force search for all metric types */
/* If the ids of base_dataset does not start from 0, the BF functions will filter based on the real ids and return the
* real ids.*/

class BruteForceConfig : public BaseConfig {};

Expand Down Expand Up @@ -70,6 +72,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset

auto xb = base->GetTensor();
auto nb = base->GetRows();
auto xb_id_offset = base->GetTensorBeginId();
auto dim = base->GetDim();

auto xq = query->GetTensor();
Expand Down Expand Up @@ -121,7 +124,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
auto cur_labels = labels_ptr + topk * index;
auto cur_distances = distances_ptr + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
BitsetViewIDSelector bw_idselector(bitset, xb_id_offset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
Expand Down Expand Up @@ -179,6 +182,11 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset
if (ret != Status::success) {
return expected<DataSetPtr>::Err(ret, "failed to brute force search");
}
if (xb_id_offset != 0) {
for (auto i = 0; i < nq * topk; i++) {
labels[i] = labels[i] == -1 ? -1 : labels[i] + xb_id_offset;
}
}
auto res = GenResultDataSet(nq, cfg.k.value(), std::move(labels), std::move(distances));

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
Expand All @@ -202,6 +210,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
auto xb = base->GetTensor();
auto nb = base->GetRows();
auto dim = base->GetDim();
auto xb_id_offset = base->GetTensorBeginId();

auto xq = query->GetTensor();
auto nq = query->GetRows();
Expand Down Expand Up @@ -248,7 +257,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
auto cur_labels = labels + topk * index;
auto cur_distances = distances + topk * index;

BitsetViewIDSelector bw_idselector(bitset);
BitsetViewIDSelector bw_idselector(bitset, xb_id_offset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
Expand Down Expand Up @@ -311,6 +320,11 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_
}
// LCOV_EXCL_STOP
#endif
if (xb_id_offset != 0) {
for (auto i = 0; i < nq * topk; i++) {
labels[i] = labels[i] == -1 ? -1 : labels[i] + xb_id_offset;
}
}

return Status::success;
}
Expand All @@ -331,6 +345,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
auto xb = base->GetTensor();
auto nb = base->GetRows();
auto dim = base->GetDim();
auto xb_id_offset = base->GetTensorBeginId();

auto xq = query->GetTensor();
auto nq = query->GetRows();
Expand Down Expand Up @@ -423,7 +438,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
ThreadPool::ScopedSearchOmpSetter setter(1);
faiss::RangeSearchResult res(1);

BitsetViewIDSelector bw_idselector(bitset);
BitsetViewIDSelector bw_idselector(bitset, xb_id_offset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;

switch (faiss_metric_type) {
Expand Down Expand Up @@ -469,7 +484,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da
result_id_array[index].resize(elem_cnt);
for (size_t j = 0; j < elem_cnt; j++) {
result_dist_array[index][j] = res.distances[j];
result_id_array[index][j] = res.labels[j];
result_id_array[index][j] = res.labels[j] + xb_id_offset;
}
if (cfg.range_filter.value() != defaultRangeFilter) {
FilterRangeSearchResultForOneNq(result_dist_array[index], result_id_array[index], is_ip, radius,
Expand Down Expand Up @@ -504,6 +519,7 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr
auto base = static_cast<const sparse::SparseRow<float>*>(base_dataset->GetTensor());
auto rows = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
auto xb_id_offset = base_dataset->GetTensorBeginId();

auto xq = static_cast<const sparse::SparseRow<float>*>(query_dataset->GetTensor());
auto nq = query_dataset->GetRows();
Expand Down Expand Up @@ -561,7 +577,8 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr
}
sparse::MaxMinHeap<float> heap(topk);
for (int64_t j = 0; j < rows; ++j) {
if (!bitset.empty() && bitset.test(j)) {
auto x_id = j + xb_id_offset;
if (!bitset.empty() && bitset.test(x_id)) {
continue;
}
float row_sum = 0;
Expand All @@ -573,7 +590,7 @@ BruteForce::SearchSparseWithBuf(const DataSetPtr base_dataset, const DataSetPtr
}
float dist = row.dot(base[j], computer, row_sum);
if (dist > 0) {
heap.push(j, dist);
heap.push(x_id, dist);
}
}
int result_size = heap.size();
Expand Down Expand Up @@ -626,6 +643,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
auto xb = base->GetTensor();
auto nb = base->GetRows();
auto dim = base->GetDim();
auto xb_id_offset = base->GetTensorBeginId();

auto xq = query->GetTensor();
auto nq = query->GetRows();
Expand Down Expand Up @@ -669,7 +687,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
futs.emplace_back(pool->push([&, index = i] {
ThreadPool::ScopedSearchOmpSetter setter(1);

BitsetViewIDSelector bw_idselector(bitset);
BitsetViewIDSelector bw_idselector(bitset, xb_id_offset);
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;
auto larger_is_closer = faiss::is_similarity_metric(faiss_metric_type) || is_cosine;
auto max_dis = larger_is_closer ? std::numeric_limits<float>::lowest() : std::numeric_limits<float>::max();
Expand Down Expand Up @@ -697,6 +715,11 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
return Status::invalid_metric_type;
}
}
if (xb_id_offset != 0) {
for (auto i = 0; i < distances_ids.size(); i++) {
distances_ids[i].id = distances_ids[i].id == -1 ? -1 : distances_ids[i].id + xb_id_offset;
}
}
vec[index] = std::make_shared<PrecomputedDistanceIterator>(std::move(distances_ids), larger_is_closer);

return Status::success;
Expand Down Expand Up @@ -726,6 +749,7 @@ BruteForce::AnnIterator<knowhere::sparse::SparseRow<float>>(const DataSetPtr bas
auto base = static_cast<const sparse::SparseRow<float>*>(base_dataset->GetTensor());
auto rows = base_dataset->GetRows();
auto dim = base_dataset->GetDim();
auto xb_id_offset = base_dataset->GetTensorBeginId();

auto xq = static_cast<const sparse::SparseRow<float>*>(query_dataset->GetTensor());
auto nq = query_dataset->GetRows();
Expand Down Expand Up @@ -776,7 +800,8 @@ BruteForce::AnnIterator<knowhere::sparse::SparseRow<float>>(const DataSetPtr bas
std::vector<DistId> distances_ids;
if (row.size() > 0) {
for (int64_t j = 0; j < rows; ++j) {
if (!bitset.empty() && bitset.test(j)) {
auto xb_id = j + xb_id_offset;
if (!bitset.empty() && bitset.test(xb_id)) {
continue;
}
float row_sum = 0;
Expand All @@ -788,7 +813,7 @@ BruteForce::AnnIterator<knowhere::sparse::SparseRow<float>>(const DataSetPtr bas
}
auto dist = row.dot(base[j], computer, row_sum);
if (dist > 0) {
distances_ids.emplace_back(j, dist);
distances_ids.emplace_back(xb_id, dist);
}
}
}
Expand Down
47 changes: 47 additions & 0 deletions tests/ut/test_bruteforce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "catch2/catch_approx.hpp"
#include "catch2/catch_test_macros.hpp"
#include "catch2/generators/catch_generators.hpp"
#include "faiss/utils/Heap.h"
#include "knowhere/comp/brute_force.h"
#include "knowhere/comp/index_param.h"
#include "knowhere/utils.h"
Expand Down Expand Up @@ -196,3 +197,49 @@ TEST_CASE("Test Brute Force", "[binary vector]") {
}
}
}

TEST_CASE("Test Brute Force with input ids", "[float vector]") {
using Catch::Approx;
const int64_t nb = 1000;
const int64_t nq = 1;
const int64_t dim = 128;
const int64_t k = 10;
const knowhere::Json conf = {
{knowhere::meta::DIM, dim},
{knowhere::meta::METRIC_TYPE, "L2"},
{knowhere::meta::TOPK, k},
};
std::vector<int64_t> block_prefix = {0, 333, 500, 555, 1000};

// generate filter id and data
auto filter_bits = GenerateBitsetWithRandomTbitsSet(nb, 100);
knowhere::BitsetView bitset(filter_bits.data(), nb);

const auto total_train_ds = GenDataSet(nb, dim);
const auto query_ds = GenDataSet(nq, dim);

std::vector<float> dis(nq * k, std::numeric_limits<float>::quiet_NaN());
std::vector<int64_t> ids(nq * k, -1);
faiss::float_maxheap_array_t heaps{nq, k, ids.data(), dis.data()};
heaps.heapify();
for (auto i = 0; i < block_prefix.size() - 1; i++) {
auto begin_id = block_prefix[i];
auto end_id = block_prefix[i + 1];
auto blk_rows = end_id - begin_id;
auto tensor = (const float*)total_train_ds->GetTensor() + dim * begin_id;
auto blk_train_ds = knowhere::GenDataSet(blk_rows, dim, tensor, begin_id);
auto partial_v = knowhere::BruteForce::Search<knowhere::fp32>(blk_train_ds, query_ds, conf, bitset);
REQUIRE(partial_v.has_value());
auto partial_res = partial_v.value();
heaps.addn_with_ids(k, partial_res->GetDistance(), partial_res->GetIds(), k, 0, nq);
}
heaps.reorder();

auto gt = knowhere::BruteForce::Search<knowhere::fp32>(total_train_ds, query_ds, conf, bitset);
auto gt_ids = gt.value()->GetIds();
auto gt_dis = gt.value()->GetDistance();
for (auto i = 0; i < nq * k; i++) {
REQUIRE(gt_ids[i] == ids[i]);
REQUIRE(gt_dis[i] == dis[i]);
}
}

0 comments on commit beff789

Please sign in to comment.