diff --git a/knowhere/archive/BruteForce.cpp b/knowhere/archive/BruteForce.cpp index 25795c4b9..5afd29ff1 100644 --- a/knowhere/archive/BruteForce.cpp +++ b/knowhere/archive/BruteForce.cpp @@ -20,22 +20,25 @@ #include "faiss/utils/distances.h" #include "knowhere/archive/BruteForce.h" #include "knowhere/common/Exception.h" +#include "knowhere/common/Log.h" namespace knowhere { // copy from faiss/IndexBinaryFlat.cpp::IndexBinaryFlat::search() // disable lint to make further migration easier void -BruteForceSearch(const knowhere::MetricType& metric_type, - const void* xb, - const void* xq, - const int64_t dim, - const int64_t nb, - const int64_t nq, - const int64_t topk, - int64_t* labels, - float* distances, - const faiss::BitsetView bitset) { +BruteForceSearch( + const knowhere::MetricType& metric_type, + const void* xb, + const void* xq, + const int64_t dim, + const int64_t nb, + const int64_t nq, + const int64_t topk, + int64_t* labels, + float* distances, + const faiss::BitsetView bitset) { + auto faiss_metric_type = GetFaissMetricType(metric_type); switch (faiss_metric_type) { @@ -84,4 +87,55 @@ BruteForceSearch(const knowhere::MetricType& metric_type, } } +void +BruteForceRangeSearch( + const knowhere::MetricType& metric_type, + const void* xb, + const void* xq, + const int64_t dim, + const int64_t nb, + const int64_t nq, + const float radius, + int64_t*& labels, + float*& distances, + size_t*& lims, + const faiss::BitsetView bitset) { + + auto faiss_metric_type = GetFaissMetricType(metric_type); + + faiss::RangeSearchResult res(nq); + switch (faiss_metric_type) { + case faiss::METRIC_L2: + faiss::range_search_L2sqr((const float*)xq, (const float*)xb, dim, nq, nb, radius * radius, &res, bitset); + break; + case faiss::METRIC_INNER_PRODUCT: + faiss::range_search_inner_product((const float*)xq, (const float*)xb, dim, nq, nb, radius, &res, bitset); + break; + case faiss::METRIC_Jaccard: + faiss::binary_range_search, float>(faiss::METRIC_Jaccard, + (const 
uint8_t*)xq, (const uint8_t*)xb, nq, nb, radius, dim / 8, &res, bitset); + break; + case faiss::METRIC_Tanimoto: + faiss::binary_range_search, float>(faiss::METRIC_Tanimoto, + (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, radius, dim / 8, &res, bitset); + break; + case faiss::METRIC_Hamming: + faiss::binary_range_search, int>(faiss::METRIC_Hamming, + (const uint8_t*)xq, (const uint8_t*)xb, nq, nb, (int)radius, dim / 8, &res, bitset); + break; + default: + KNOWHERE_THROW_MSG("BruteForce range search not support metric type: " + metric_type); + } + + labels = res.labels; + distances = res.distances; + lims = res.lims; + + LOG_KNOWHERE_DEBUG_ << "Range search result num: " << lims[nq]; + + res.distances = nullptr; + res.labels = nullptr; + res.lims = nullptr; +} + } // namespace knowhere diff --git a/knowhere/archive/BruteForce.h b/knowhere/archive/BruteForce.h index addfd0cfb..168191f93 100644 --- a/knowhere/archive/BruteForce.h +++ b/knowhere/archive/BruteForce.h @@ -24,16 +24,58 @@ namespace knowhere { +/** knowhere wrapper API to call faiss brute force search for all metric types + * + * @param metric_type metric name, mapped to a faiss metric with GetFaissMetricType() + * @param xb training vectors, size nb * dim + * @param xq query vectors, size nq * dim + * @param dim + * @param nb rows of training vectors + * @param nq rows of query vectors + * @param topk + * @param labels output, memory allocated and freed by caller + * @param distances output, memory allocated and freed by caller + * @param bitset + */ +void +BruteForceSearch( + const knowhere::MetricType& metric_type, + const void* xb, + const void* xq, + const int64_t dim, + const int64_t nb, + const int64_t nq, + const int64_t topk, + int64_t* labels, + float* distances, + const faiss::BitsetView bitset); + +/** knowhere wrapper API to call faiss brute force range search for all metric types + * + * @param metric_type metric name; L2, inner product, Jaccard, Tanimoto and Hamming are handled, any other metric goes through KNOWHERE_THROW_MSG + * @param xb training vectors, size nb * dim + * @param xq query vectors, size nq * dim + * @param dim passed to faiss as-is for float metrics and as dim / 8 for binary metrics + * @param nb rows of training vectors + * @param 
nq rows of query vectors + * @param radius range search radius + * @param labels output, memory allocated inside and freed by caller + * @param distances output, memory allocated inside and freed by coller + * @param lims output, memory allocated inside and freed by coller + * @param bitset + */ void -BruteForceSearch(const knowhere::MetricType& metric_type, - const void* xb, - const void* xq, - const int64_t dim, - const int64_t nb, - const int64_t nq, - const int64_t topk, - int64_t* labels, - float* distances, - const faiss::BitsetView bitset); +BruteForceRangeSearch( + const knowhere::MetricType& metric_type, + const void* xb, + const void* xq, + const int64_t dim, + const int64_t nb, + const int64_t nq, + const float radius, + int64_t*& labels, + float*& distances, + size_t*& lims, + const faiss::BitsetView bitset); } // namespace knowhere diff --git a/unittest/test_bruteforce.cpp b/unittest/test_bruteforce.cpp index 49e1bc688..47974b54d 100644 --- a/unittest/test_bruteforce.cpp +++ b/unittest/test_bruteforce.cpp @@ -85,4 +85,147 @@ TEST_P(BruteForceTest, binary_basic) { ASSERT_FALSE(labels[i * k] == i); } } -} \ No newline at end of file +} + +TEST_P(BruteForceTest, float_range_search_l2) { + Init_with_default(); + auto metric_type = knowhere::metric::L2; + + auto qd = knowhere::GenDataset(nq, dim, xq.data()); + + auto test_range_search_l2 = [&](float radius, const faiss::BitsetView bitset) { + std::vector golden_labels; + std::vector golden_distances; + std::vector golden_lims; + int64_t* labels = nullptr; + float* distances = nullptr; + size_t* lims = nullptr; + RunFloatRangeSearchBF>(golden_labels, golden_distances, golden_lims, metric_type, + xb.data(), nb, xq.data(), nq, dim, radius, bitset); + + knowhere::BruteForceRangeSearch(metric_type, xb.data(), xq.data(), dim, nb, nq, radius, labels, distances, + lims, bitset); + auto result = knowhere::GenResultDataset(labels, distances, lims); + CheckRangeSearchResult>(result, nq, radius * radius, 
golden_labels.data(), golden_lims.data(), true); + }; + + auto old_blas_threshold = knowhere::KnowhereConfig::GetBlasThreshold(); + for (int64_t blas_threshold : {0, 20}) { + knowhere::KnowhereConfig::SetBlasThreshold(blas_threshold); + for (float radius: {4.1f, 4.2f, 4.3f}) { + test_range_search_l2(radius, nullptr); + test_range_search_l2(radius, *bitset); + } + } + knowhere::KnowhereConfig::SetBlasThreshold(old_blas_threshold); +} + +TEST_P(BruteForceTest, float_range_search_ip) { + Init_with_default(); + auto metric_type = knowhere::metric::IP; + + normalize(xb.data(), nb, dim); + normalize(xq.data(), nq, dim); + + auto test_range_search_ip = [&](float radius, const faiss::BitsetView bitset) { + std::vector golden_labels; + std::vector golden_distances; + std::vector golden_lims; + int64_t* labels = nullptr; + float* distances = nullptr; + size_t* lims = nullptr; + RunFloatRangeSearchBF>(golden_labels, golden_distances, golden_lims, metric_type, + xb.data(), nb, xq.data(), nq, dim, radius, bitset); + + knowhere::BruteForceRangeSearch(metric_type, xb.data(), xq.data(), dim, nb, nq, radius, labels, distances, + lims, bitset); + auto result = knowhere::GenResultDataset(labels, distances, lims); + CheckRangeSearchResult>(result, nq, radius, golden_labels.data(), golden_lims.data(), true); + }; + + auto old_blas_threshold = knowhere::KnowhereConfig::GetBlasThreshold(); + for (int64_t blas_threshold : {0, 20}) { + knowhere::KnowhereConfig::SetBlasThreshold(blas_threshold); + //for (float radius: {42.0f, 43.0f, 44.0f}) { + for (float radius: {0.75f, 0.78f, 0.81f}) { + test_range_search_ip(radius, nullptr); + test_range_search_ip(radius, *bitset); + } + } + knowhere::KnowhereConfig::SetBlasThreshold(old_blas_threshold); +} + +TEST_P(BruteForceTest, binary_range_search_hamming) { + Init_with_default(true); + int hamming_radius = 50; + auto metric_type = knowhere::metric::HAMMING; + + auto test_range_search_hamming = [&](float radius, const faiss::BitsetView bitset) { + 
std::vector golden_labels; + std::vector golden_distances; + std::vector golden_lims; + int64_t* labels = nullptr; + float* distances = nullptr; + size_t* lims = nullptr; + RunBinaryRangeSearchBF>(golden_labels, golden_distances, golden_lims, metric_type, + xb_bin.data(), nb, xq_bin.data(), nq, dim, radius, bitset); + + knowhere::BruteForceRangeSearch(metric_type, xb_bin.data(), xq_bin.data(), dim, nb, nq, radius, labels, + distances, lims, bitset); + auto result = knowhere::GenResultDataset(labels, distances, lims); + CheckRangeSearchResult>(result, nq, radius, golden_labels.data(), golden_lims.data(), true); + }; + + test_range_search_hamming(hamming_radius, nullptr); + test_range_search_hamming(hamming_radius, *bitset); +} + +TEST_P(BruteForceTest, binary_range_search_jaccard) { + Init_with_default(true); + float jaccard_radius = 0.5; + auto metric_type = knowhere::metric::JACCARD; + + auto test_range_search_jaccard = [&](float radius, const faiss::BitsetView bitset) { + std::vector golden_labels; + std::vector golden_distances; + std::vector golden_lims; + int64_t* labels = nullptr; + float* distances = nullptr; + size_t* lims = nullptr; + RunBinaryRangeSearchBF>(golden_labels, golden_distances, golden_lims, knowhere::metric::JACCARD, + xb_bin.data(), nb, xq_bin.data(), nq, dim, radius, bitset); + + knowhere::BruteForceRangeSearch(metric_type, xb_bin.data(), xq_bin.data(), dim, nb, nq, radius, labels, + distances, lims, bitset); + auto result = knowhere::GenResultDataset(labels, distances, lims); + CheckRangeSearchResult>(result, nq, radius, golden_labels.data(), golden_lims.data(), true); + }; + + test_range_search_jaccard(jaccard_radius, nullptr); + test_range_search_jaccard(jaccard_radius, *bitset); +} + +TEST_P(BruteForceTest, binary_range_search_tanimoto) { + Init_with_default(true); + float tanimoto_radius = 1.0; + auto metric_type = knowhere::metric::TANIMOTO; + + auto test_range_search_tanimoto = [&](float radius, const faiss::BitsetView bitset) { + 
std::vector golden_labels; + std::vector golden_distances; + std::vector golden_lims; + int64_t* labels = nullptr; + float* distances = nullptr; + size_t* lims = nullptr; + RunBinaryRangeSearchBF>(golden_labels, golden_distances, golden_lims, metric_type, + xb_bin.data(), nb, xq_bin.data(), nq, dim, radius, bitset); + + knowhere::BruteForceRangeSearch(metric_type, xb_bin.data(), xq_bin.data(), dim, nb, nq, radius, labels, + distances, lims, bitset); + auto result = knowhere::GenResultDataset(labels, distances, lims); + CheckRangeSearchResult>(result, nq, radius, golden_labels.data(), golden_lims.data(), true); + }; + + test_range_search_tanimoto(tanimoto_radius, nullptr); + test_range_search_tanimoto(tanimoto_radius, *bitset); +}