Skip to content
This repository has been archived by the owner on Aug 16, 2023. It is now read-only.

Commit

Permalink
Add API BruteForceRangeSearch() (#243)
Browse files Browse the repository at this point in the history
Signed-off-by: yudong.cai <[email protected]>
  • Loading branch information
cydrain authored Jul 11, 2022
1 parent eeacb8a commit 2d7ba84
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 21 deletions.
74 changes: 64 additions & 10 deletions knowhere/archive/BruteForce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,25 @@
#include "faiss/utils/distances.h"
#include "knowhere/archive/BruteForce.h"
#include "knowhere/common/Exception.h"
#include "knowhere/common/Log.h"

namespace knowhere {

// copied from faiss/IndexBinaryFlat.cpp::IndexBinaryFlat::search()
// lint is disabled to make further migration easier
void
BruteForceSearch(const knowhere::MetricType& metric_type,
const void* xb,
const void* xq,
const int64_t dim,
const int64_t nb,
const int64_t nq,
const int64_t topk,
int64_t* labels,
float* distances,
const faiss::BitsetView bitset) {
BruteForceSearch(
const knowhere::MetricType& metric_type,
const void* xb,
const void* xq,
const int64_t dim,
const int64_t nb,
const int64_t nq,
const int64_t topk,
int64_t* labels,
float* distances,
const faiss::BitsetView bitset) {

auto faiss_metric_type = GetFaissMetricType(metric_type);

switch (faiss_metric_type) {
Expand Down Expand Up @@ -84,4 +87,55 @@ BruteForceSearch(const knowhere::MetricType& metric_type,
}
}

// Brute force range search over all supported metric types.
//
// Dispatches on the faiss metric that corresponds to metric_type and runs the matching faiss range
// search routine for the nq queries in xq against the nb vectors in xb:
//   - L2:                 float vectors, the radius is squared before being handed to faiss
//   - inner product:      float vectors, the radius is used as is
//   - Jaccard / Tanimoto: binary vectors, dim / 8 is passed as the per-vector size
//   - Hamming:            binary vectors, the radius is converted to int
// Any other metric type raises via KNOWHERE_THROW_MSG.
//
// On return, labels / distances / lims point at the buffers that were held by the local
// faiss::RangeSearchResult; the local object's pointers are reset to nullptr before it goes out of
// scope, so the caller is the one that has to release the buffers.
void
BruteForceRangeSearch(
    const knowhere::MetricType& metric_type,
    const void* xb,
    const void* xq,
    const int64_t dim,
    const int64_t nb,
    const int64_t nq,
    const float radius,
    int64_t*& labels,
    float*& distances,
    size_t*& lims,
    const faiss::BitsetView bitset) {
    // Typed views of the raw input buffers; only the pair matching the metric is actually used.
    const auto* float_queries = static_cast<const float*>(xq);
    const auto* float_base = static_cast<const float*>(xb);
    const auto* binary_queries = static_cast<const uint8_t*>(xq);
    const auto* binary_base = static_cast<const uint8_t*>(xb);

    auto faiss_metric_type = GetFaissMetricType(metric_type);

    faiss::RangeSearchResult res(nq);
    switch (faiss_metric_type) {
        case faiss::METRIC_L2: {
            // The faiss routine is named "L2sqr", so it is given the squared radius.
            const float squared_radius = radius * radius;
            faiss::range_search_L2sqr(float_queries, float_base, dim, nq, nb, squared_radius, &res, bitset);
            break;
        }
        case faiss::METRIC_INNER_PRODUCT: {
            faiss::range_search_inner_product(float_queries, float_base, dim, nq, nb, radius, &res, bitset);
            break;
        }
        case faiss::METRIC_Jaccard: {
            faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
                faiss::METRIC_Jaccard, binary_queries, binary_base, nq, nb, radius, dim / 8, &res, bitset);
            break;
        }
        case faiss::METRIC_Tanimoto: {
            faiss::binary_range_search<faiss::CMin<float, int64_t>, float>(
                faiss::METRIC_Tanimoto, binary_queries, binary_base, nq, nb, radius, dim / 8, &res, bitset);
            break;
        }
        case faiss::METRIC_Hamming: {
            const int int_radius = static_cast<int>(radius);
            faiss::binary_range_search<faiss::CMin<int, int64_t>, int>(
                faiss::METRIC_Hamming, binary_queries, binary_base, nq, nb, int_radius, dim / 8, &res, bitset);
            break;
        }
        default: {
            KNOWHERE_THROW_MSG("BruteForce range search not support metric type: " + metric_type);
        }
    }

    // Hand the result buffers over to the caller.
    labels = res.labels;
    distances = res.distances;
    lims = res.lims;

    LOG_KNOWHERE_DEBUG_ << "Range search result num: " << lims[nq];

    // Detach the buffers from the local result object now that the caller holds them.
    res.labels = nullptr;
    res.distances = nullptr;
    res.lims = nullptr;
}

} // namespace knowhere
62 changes: 52 additions & 10 deletions knowhere/archive/BruteForce.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,58 @@

namespace knowhere {

/** Knowhere wrapper API that calls faiss brute force search for all metric types.
 *
 * @param metric_type   metric type used for the search
 * @param xb            training vectors, size nb * dim
 * @param xq            query vectors, size nq * dim
 * @param dim           vector dimension
 * @param nb            rows of training vectors
 * @param nq            rows of query vectors
 * @param topk          presumably the number of results requested per query -- TODO confirm
 * @param labels        output, memory allocated and freed by caller
 * @param distances     output, memory allocated and freed by caller
 * @param bitset        bitset view forwarded to the search; presumably filters out vectors -- TODO confirm
 */
void
BruteForceSearch(
const knowhere::MetricType& metric_type,
const void* xb,
const void* xq,
const int64_t dim,
const int64_t nb,
const int64_t nq,
const int64_t topk,
int64_t* labels,
float* distances,
const faiss::BitsetView bitset);

/** Knowhere wrapper API that calls faiss brute force range search for all metric types.
 *
 * @param metric_type   metric type used for the search; an unsupported type makes the implementation throw
 * @param xb            training vectors, size nb * dim
 * @param xq            query vectors, size nq * dim
 * @param dim           vector dimension (for binary metrics the implementation passes dim / 8 to faiss,
 *                      so presumably dim is counted in bits -- TODO confirm)
 * @param nb            rows of training vectors
 * @param nq            rows of query vectors
 * @param radius        range search radius; the implementation squares it for L2 and converts it to int
 *                      for Hamming before calling faiss
 * @param labels        output, memory allocated inside and freed by caller
 * @param distances     output, memory allocated inside and freed by caller
 * @param lims          output, memory allocated inside and freed by caller; the implementation logs
 *                      lims[nq] as the number of results
 * @param bitset        bitset view forwarded to the faiss range search routines
 */
void
BruteForceSearch(const knowhere::MetricType& metric_type,
const void* xb,
const void* xq,
const int64_t dim,
const int64_t nb,
const int64_t nq,
const int64_t topk,
int64_t* labels,
float* distances,
const faiss::BitsetView bitset);
BruteForceRangeSearch(
const knowhere::MetricType& metric_type,
const void* xb,
const void* xq,
const int64_t dim,
const int64_t nb,
const int64_t nq,
const float radius,
int64_t*& labels,
float*& distances,
size_t*& lims,
const faiss::BitsetView bitset);

} // namespace knowhere
145 changes: 144 additions & 1 deletion unittest/test_bruteforce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,147 @@ TEST_P(BruteForceTest, binary_basic) {
ASSERT_FALSE(labels[i * k] == i);
}
}
}
}

// Compares knowhere::BruteForceRangeSearch with the L2 metric against the reference results from
// RunFloatRangeSearchBF, with and without the fixture bitset, for several radii and for two BLAS
// threshold settings. The original BLAS threshold is restored at the end.
TEST_P(BruteForceTest, float_range_search_l2) {
    Init_with_default();
    auto metric_type = knowhere::metric::L2;

    // Not referenced by the checks below.
    auto qd = knowhere::GenDataset(nq, dim, xq.data());

    auto verify_l2 = [&](float radius, const faiss::BitsetView bitset_view) {
        // Reference answer.
        std::vector<int64_t> expected_labels;
        std::vector<float> expected_distances;
        std::vector<size_t> expected_lims;
        RunFloatRangeSearchBF<CMin<float>>(expected_labels, expected_distances, expected_lims, metric_type,
                                           xb.data(), nb, xq.data(), nq, dim, radius, bitset_view);

        // Output buffers are filled in by BruteForceRangeSearch.
        int64_t* out_labels = nullptr;
        float* out_distances = nullptr;
        size_t* out_lims = nullptr;
        knowhere::BruteForceRangeSearch(metric_type, xb.data(), xq.data(), dim, nb, nq, radius, out_labels,
                                        out_distances, out_lims, bitset_view);

        auto result = knowhere::GenResultDataset(out_labels, out_distances, out_lims);
        // For L2 the check is given the squared radius.
        CheckRangeSearchResult<CMin<float>>(result, nq, radius * radius, expected_labels.data(),
                                            expected_lims.data(), true);
    };

    auto saved_blas_threshold = knowhere::KnowhereConfig::GetBlasThreshold();
    for (int64_t blas_threshold : {0, 20}) {
        knowhere::KnowhereConfig::SetBlasThreshold(blas_threshold);
        for (float radius : {4.1f, 4.2f, 4.3f}) {
            verify_l2(radius, nullptr);
            verify_l2(radius, *bitset);
        }
    }
    knowhere::KnowhereConfig::SetBlasThreshold(saved_blas_threshold);
}

// Compares knowhere::BruteForceRangeSearch with the IP metric against the reference results from
// RunFloatRangeSearchBF. Base and query vectors are normalized first; the check runs with and without
// the fixture bitset, for several radii and for two BLAS threshold settings. The original BLAS
// threshold is restored at the end.
TEST_P(BruteForceTest, float_range_search_ip) {
    Init_with_default();
    auto metric_type = knowhere::metric::IP;

    normalize(xb.data(), nb, dim);
    normalize(xq.data(), nq, dim);

    auto verify_ip = [&](float radius, const faiss::BitsetView bitset_view) {
        // Reference answer.
        std::vector<int64_t> expected_labels;
        std::vector<float> expected_distances;
        std::vector<size_t> expected_lims;
        RunFloatRangeSearchBF<CMax<float>>(expected_labels, expected_distances, expected_lims, metric_type,
                                           xb.data(), nb, xq.data(), nq, dim, radius, bitset_view);

        // Output buffers are filled in by BruteForceRangeSearch.
        int64_t* out_labels = nullptr;
        float* out_distances = nullptr;
        size_t* out_lims = nullptr;
        knowhere::BruteForceRangeSearch(metric_type, xb.data(), xq.data(), dim, nb, nq, radius, out_labels,
                                        out_distances, out_lims, bitset_view);

        auto result = knowhere::GenResultDataset(out_labels, out_distances, out_lims);
        CheckRangeSearchResult<CMax<float>>(result, nq, radius, expected_labels.data(), expected_lims.data(),
                                            true);
    };

    auto saved_blas_threshold = knowhere::KnowhereConfig::GetBlasThreshold();
    for (int64_t blas_threshold : {0, 20}) {
        knowhere::KnowhereConfig::SetBlasThreshold(blas_threshold);
        for (float radius : {0.75f, 0.78f, 0.81f}) {
            verify_ip(radius, nullptr);
            verify_ip(radius, *bitset);
        }
    }
    knowhere::KnowhereConfig::SetBlasThreshold(saved_blas_threshold);
}

// Compares knowhere::BruteForceRangeSearch with the HAMMING metric on the binary fixture data against
// the reference results from RunBinaryRangeSearchBF, with and without the fixture bitset.
TEST_P(BruteForceTest, binary_range_search_hamming) {
    Init_with_default(true);
    auto metric_type = knowhere::metric::HAMMING;
    int hamming_radius = 50;

    auto verify_hamming = [&](float radius, const faiss::BitsetView bitset_view) {
        // Reference answer.
        std::vector<int64_t> expected_labels;
        std::vector<float> expected_distances;
        std::vector<size_t> expected_lims;
        RunBinaryRangeSearchBF<CMin<float>>(expected_labels, expected_distances, expected_lims, metric_type,
                                            xb_bin.data(), nb, xq_bin.data(), nq, dim, radius, bitset_view);

        // Output buffers are filled in by BruteForceRangeSearch.
        int64_t* out_labels = nullptr;
        float* out_distances = nullptr;
        size_t* out_lims = nullptr;
        knowhere::BruteForceRangeSearch(metric_type, xb_bin.data(), xq_bin.data(), dim, nb, nq, radius,
                                        out_labels, out_distances, out_lims, bitset_view);

        auto result = knowhere::GenResultDataset(out_labels, out_distances, out_lims);
        CheckRangeSearchResult<CMin<float>>(result, nq, radius, expected_labels.data(), expected_lims.data(),
                                            true);
    };

    verify_hamming(hamming_radius, nullptr);
    verify_hamming(hamming_radius, *bitset);
}

// Compares knowhere::BruteForceRangeSearch with the JACCARD metric on the binary fixture data against
// the reference results from RunBinaryRangeSearchBF, with and without the fixture bitset.
TEST_P(BruteForceTest, binary_range_search_jaccard) {
    Init_with_default(true);
    auto metric_type = knowhere::metric::JACCARD;
    float jaccard_radius = 0.5;

    auto verify_jaccard = [&](float radius, const faiss::BitsetView bitset_view) {
        // Reference answer.
        std::vector<int64_t> expected_labels;
        std::vector<float> expected_distances;
        std::vector<size_t> expected_lims;
        RunBinaryRangeSearchBF<CMin<float>>(expected_labels, expected_distances, expected_lims, metric_type,
                                            xb_bin.data(), nb, xq_bin.data(), nq, dim, radius, bitset_view);

        // Output buffers are filled in by BruteForceRangeSearch.
        int64_t* out_labels = nullptr;
        float* out_distances = nullptr;
        size_t* out_lims = nullptr;
        knowhere::BruteForceRangeSearch(metric_type, xb_bin.data(), xq_bin.data(), dim, nb, nq, radius,
                                        out_labels, out_distances, out_lims, bitset_view);

        auto result = knowhere::GenResultDataset(out_labels, out_distances, out_lims);
        CheckRangeSearchResult<CMin<float>>(result, nq, radius, expected_labels.data(), expected_lims.data(),
                                            true);
    };

    verify_jaccard(jaccard_radius, nullptr);
    verify_jaccard(jaccard_radius, *bitset);
}

// Compares knowhere::BruteForceRangeSearch with the TANIMOTO metric on the binary fixture data against
// the reference results from RunBinaryRangeSearchBF, with and without the fixture bitset.
TEST_P(BruteForceTest, binary_range_search_tanimoto) {
    Init_with_default(true);
    auto metric_type = knowhere::metric::TANIMOTO;
    float tanimoto_radius = 1.0;

    auto verify_tanimoto = [&](float radius, const faiss::BitsetView bitset_view) {
        // Reference answer.
        std::vector<int64_t> expected_labels;
        std::vector<float> expected_distances;
        std::vector<size_t> expected_lims;
        RunBinaryRangeSearchBF<CMin<float>>(expected_labels, expected_distances, expected_lims, metric_type,
                                            xb_bin.data(), nb, xq_bin.data(), nq, dim, radius, bitset_view);

        // Output buffers are filled in by BruteForceRangeSearch.
        int64_t* out_labels = nullptr;
        float* out_distances = nullptr;
        size_t* out_lims = nullptr;
        knowhere::BruteForceRangeSearch(metric_type, xb_bin.data(), xq_bin.data(), dim, nb, nq, radius,
                                        out_labels, out_distances, out_lims, bitset_view);

        auto result = knowhere::GenResultDataset(out_labels, out_distances, out_lims);
        CheckRangeSearchResult<CMin<float>>(result, nq, radius, expected_labels.data(), expected_lims.data(),
                                            true);
    };

    verify_tanimoto(tanimoto_radius, nullptr);
    verify_tanimoto(tanimoto_radius, *bitset);
}

0 comments on commit 2d7ba84

Please sign in to comment.