diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index e35246db5..23c90be85 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -51,3 +51,5 @@ benchmark_test(benchmark_float_bitset hdf5/benchmark_float_bitset.cpp) benchmark_test(benchmark_float_qps hdf5/benchmark_float_qps.cpp) benchmark_test(benchmark_float_range hdf5/benchmark_float_range.cpp) benchmark_test(benchmark_float_range_bitset hdf5/benchmark_float_range_bitset.cpp) + +benchmark_test(gen_hdf5_file hdf5/gen_hdf5_file.cpp) diff --git a/benchmark/hdf5/benchmark_hdf5.h b/benchmark/hdf5/benchmark_hdf5.h index 92602e578..ddc8ffe8b 100644 --- a/benchmark/hdf5/benchmark_hdf5.h +++ b/benchmark/hdf5/benchmark_hdf5.h @@ -330,6 +330,20 @@ class Benchmark_hdf5 : public Benchmark_base { return data_out; } + void + write_hdf5_dataset(hid_t file, const char* dataset_name, hid_t type_id, int32_t rows, int32_t cols, + const void* data) { + hsize_t dims[2]; + dims[0] = rows; + dims[1] = cols; + auto dataspace = H5Screate_simple(2, dims, NULL); + auto dataset = H5Dcreate2(file, dataset_name, type_id, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); + auto err = H5Dwrite(dataset, type_id, H5S_ALL, H5S_ALL, H5P_DEFAULT, data); + assert(err == 0); + H5Dclose(dataset); + H5Sclose(dataspace); + } + // For binary vector, dim should be divided by 32, since we use int32 to store binary vector data */ template void @@ -338,31 +352,18 @@ class Benchmark_hdf5 : public Benchmark_base { /* Open the file and the dataset. 
*/ hid_t file = H5Fcreate(file_name, H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); - auto write_hdf5_dataset = [](hid_t file, const char* dataset_name, hid_t type_id, int32_t rows, int32_t cols, - const void* data) { - hsize_t dims[2]; - dims[0] = rows; - dims[1] = cols; - auto dataspace = H5Screate_simple(2, dims, NULL); - auto dataset = H5Dcreate2(file, dataset_name, type_id, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); - auto err = H5Dwrite(dataset, type_id, H5S_ALL, H5S_ALL, H5P_DEFAULT, data); - assert(err == 0); - H5Dclose(dataset); - H5Sclose(dataspace); - }; - /* write train dataset */ if (!is_binary) { write_hdf5_dataset(file, HDF5_DATASET_TRAIN, H5T_NATIVE_FLOAT, nb, dim, xb); } else { - write_hdf5_dataset(file, HDF5_DATASET_TRAIN, H5T_NATIVE_INT32, nb, dim, xb); + write_hdf5_dataset(file, HDF5_DATASET_TRAIN, H5T_NATIVE_INT32, nb, dim / 32, xb); } /* write test dataset */ if (!is_binary) { write_hdf5_dataset(file, HDF5_DATASET_TEST, H5T_NATIVE_FLOAT, nq, dim, xq); } else { - write_hdf5_dataset(file, HDF5_DATASET_TEST, H5T_NATIVE_INT32, nq, dim, xq); + write_hdf5_dataset(file, HDF5_DATASET_TEST, H5T_NATIVE_INT32, nq, dim / 32, xq); } /* write ground-truth labels dataset */ @@ -388,31 +389,18 @@ class Benchmark_hdf5 : public Benchmark_base { /* Open the file and the dataset. 
*/ hid_t file = H5Fcreate(file_name, H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); - auto write_hdf5_dataset = [](hid_t file, const char* dataset_name, hid_t type_id, int32_t rows, int32_t cols, - const void* data) { - hsize_t dims[2]; - dims[0] = rows; - dims[1] = cols; - auto dataspace = H5Screate_simple(2, dims, NULL); - auto dataset = H5Dcreate2(file, dataset_name, type_id, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); - auto err = H5Dwrite(dataset, type_id, H5S_ALL, H5S_ALL, H5P_DEFAULT, data); - assert(err == 0); - H5Dclose(dataset); - H5Sclose(dataspace); - }; - /* write train dataset */ if (!is_binary) { write_hdf5_dataset(file, HDF5_DATASET_TRAIN, H5T_NATIVE_FLOAT, nb, dim, xb); } else { - write_hdf5_dataset(file, HDF5_DATASET_TRAIN, H5T_NATIVE_INT32, nb, dim, xb); + write_hdf5_dataset(file, HDF5_DATASET_TRAIN, H5T_NATIVE_INT32, nb, dim / 32, xb); } /* write test dataset */ if (!is_binary) { write_hdf5_dataset(file, HDF5_DATASET_TEST, H5T_NATIVE_FLOAT, nq, dim, xq); } else { - write_hdf5_dataset(file, HDF5_DATASET_TEST, H5T_NATIVE_INT32, nq, dim, xq); + write_hdf5_dataset(file, HDF5_DATASET_TEST, H5T_NATIVE_INT32, nq, dim / 32, xq); } /* write ground-truth radius */ @@ -431,63 +419,6 @@ class Benchmark_hdf5 : public Benchmark_base { H5Fclose(file); } - // For binary vector, dim should be divided by 32, since we use int32 to store binary vector data */ - // Write HDF5 file with following dataset: - // HDF5_DATASET_RADIUS - H5T_NATIVE_FLOAT, [1, nq] - // HDF5_DATASET_LIMS - H5T_NATIVE_INT32, [1, nq+1] - // HDF5_DATASET_NEIGHBORS - H5T_NATIVE_INT32, [1, lims[nq]] - // HDF5_DATASET_DISTANCES - H5T_NATIVE_FLOAT, [1, lims[nq]] - template - void - hdf5_write_range(const char* file_name, const int32_t dim, const void* xb, const int32_t nb, const void* xq, - const int32_t nq, const float* g_radius, const void* g_lims, const void* g_ids, - const void* g_dist) { - /* Open the file and the dataset. 
*/ - hid_t file = H5Fcreate(file_name, H5F_ACC_TRUNC, H5P_DEFAULT, H5P_DEFAULT); - - auto write_hdf5_dataset = [](hid_t file, const char* dataset_name, hid_t type_id, int32_t rows, int32_t cols, - const void* data) { - hsize_t dims[2]; - dims[0] = rows; - dims[1] = cols; - auto dataspace = H5Screate_simple(2, dims, NULL); - auto dataset = H5Dcreate2(file, dataset_name, type_id, dataspace, H5P_DEFAULT, H5P_DEFAULT, H5P_DEFAULT); - auto err = H5Dwrite(dataset, type_id, H5S_ALL, H5S_ALL, H5P_DEFAULT, data); - assert(err == 0); - H5Dclose(dataset); - H5Sclose(dataspace); - }; - - /* write train dataset */ - if (!is_binary) { - write_hdf5_dataset(file, HDF5_DATASET_TRAIN, H5T_NATIVE_FLOAT, nb, dim, xb); - } else { - write_hdf5_dataset(file, HDF5_DATASET_TRAIN, H5T_NATIVE_INT32, nb, dim, xb); - } - - /* write test dataset */ - if (!is_binary) { - write_hdf5_dataset(file, HDF5_DATASET_TEST, H5T_NATIVE_FLOAT, nq, dim, xq); - } else { - write_hdf5_dataset(file, HDF5_DATASET_TEST, H5T_NATIVE_INT32, nq, dim, xq); - } - - /* write ground-truth radius */ - write_hdf5_dataset(file, HDF5_DATASET_RADIUS, H5T_NATIVE_FLOAT, 1, nq, g_radius); - - /* write ground-truth lims dataset */ - write_hdf5_dataset(file, HDF5_DATASET_LIMS, H5T_NATIVE_INT32, 1, nq + 1, g_lims); - - /* write ground-truth labels dataset */ - write_hdf5_dataset(file, HDF5_DATASET_NEIGHBORS, H5T_NATIVE_INT32, 1, ((int32_t*)g_lims)[nq], g_ids); - - /* write ground-truth distance dataset */ - write_hdf5_dataset(file, HDF5_DATASET_DISTANCES, H5T_NATIVE_FLOAT, 1, ((int32_t*)g_lims)[nq], g_dist); - - /* Close/release resources. */ - H5Fclose(file); - } - protected: std::string ann_test_name_ = ""; std::string metric_str_; diff --git a/benchmark/hdf5/gen_hdf5_file.cpp b/benchmark/hdf5/gen_hdf5_file.cpp new file mode 100644 index 000000000..d8ddfaaef --- /dev/null +++ b/benchmark/hdf5/gen_hdf5_file.cpp @@ -0,0 +1,178 @@ +// Copyright (C) 2019-2023 Zilliz. All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software distributed under the License +// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +// or implied. See the License for the specific language governing permissions and limitations under the License. + +#include + +#include +#include +#include + +#include "benchmark_hdf5.h" +#include "knowhere/comp/brute_force.h" +#include "knowhere/comp/index_param.h" +#include "knowhere/comp/knowhere_config.h" +#include "knowhere/dataset.h" + +knowhere::DataSetPtr +GenDataSet(int rows, int dim) { + std::mt19937 rng(42); + std::uniform_real_distribution<> distrib(-1.0, 1.0); + float* ts = new float[rows * dim]; + for (int i = 0; i < rows * dim; ++i) { + ts[i] = (float)distrib(rng); + } + auto ds = knowhere::GenDataSet(rows, dim, ts); + ds->SetIsOwner(true); + return ds; +} + +knowhere::DataSetPtr +GenBinDataSet(int rows, int dim) { + std::mt19937 rng(42); + std::uniform_int_distribution<> distrib(0, 255); + int uint8_num = dim / 8; + uint8_t* ts = new uint8_t[rows * uint8_num]; + for (int i = 0; i < rows * uint8_num; ++i) { + ts[i] = (uint8_t)distrib(rng); + } + auto ds = knowhere::GenDataSet(rows, dim, ts); + ds->SetIsOwner(true); + return ds; +} + +class Create_HDF5 : public Benchmark_hdf5, public ::testing::Test { + protected: + void + SetUp() override { + } + + void + TearDown() override { + } + + template + void + create_hdf5_file(const knowhere::MetricType& metric_type, const int64_t nb, const int64_t nq, const int64_t dim, + const int64_t topk) { + std::string metric_str = metric_type; + transform(metric_str.begin(), metric_str.end(), metric_str.begin(), ::tolower); + std::string fn = "rand-" + std::to_string(dim) + "-" + 
metric_str + ".hdf5"; + + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric_type; + json[knowhere::meta::TOPK] = topk; + + knowhere::DataSetPtr xb_ds, xq_ds; + if (is_binary) { + xb_ds = GenBinDataSet(nb, dim); + xq_ds = GenBinDataSet(nq, dim); + } else { + xb_ds = GenDataSet(nb, dim); + xq_ds = GenDataSet(nq, dim); + } + + auto result = knowhere::BruteForce::Search(xb_ds, xq_ds, json, nullptr); + assert(result.has_value()); + + // convert golden_ids to int32 + auto elem_cnt = nq * topk; + std::vector gt_ids_int(elem_cnt); + for (int32_t i = 0; i < elem_cnt; i++) { + gt_ids_int[i] = result.value()->GetIds()[i]; + } + + hdf5_write(fn.c_str(), dim, topk, xb_ds->GetTensor(), nb, xq_ds->GetTensor(), nq, gt_ids_int.data(), + result.value()->GetDistance()); + } + + template + void + create_range_hdf5_file(const knowhere::MetricType& metric_type, const int64_t nb, const int64_t nq, + const int64_t dim, const float radius) { + std::string metric_str = metric_type; + transform(metric_str.begin(), metric_str.end(), metric_str.begin(), ::tolower); + std::string fn = "rand-" + std::to_string(dim) + "-" + metric_str + "-range.hdf5"; + + knowhere::Json json; + json[knowhere::meta::DIM] = dim; + json[knowhere::meta::METRIC_TYPE] = metric_type; + json[knowhere::meta::RADIUS] = radius; + + knowhere::DataSetPtr xb_ds, xq_ds; + if (is_binary) { + xb_ds = GenBinDataSet(nb, dim); + xq_ds = GenBinDataSet(nq, dim); + } else { + xb_ds = GenDataSet(nb, dim); + xq_ds = GenDataSet(nq, dim); + } + + auto result = knowhere::BruteForce::RangeSearch(xb_ds, xq_ds, json, nullptr); + assert(result.has_value()); + + // convert golden_lims to int32 + std::vector gt_lims_int(nq + 1); + for (int32_t i = 0; i <= nq; i++) { + gt_lims_int[i] = result.value()->GetLims()[i]; + } + + // convert golden_ids to int32 + auto elem_cnt = result.value()->GetLims()[nq]; + std::vector gt_ids_int(elem_cnt); + for (int32_t i = 0; i < elem_cnt; i++) { + gt_ids_int[i] = 
result.value()->GetIds()[i]; + } + + hdf5_write_range(fn.c_str(), dim, xb_ds->GetTensor(), nb, xq_ds->GetTensor(), nq, radius, + gt_lims_int.data(), gt_ids_int.data(), result.value()->GetDistance()); + } +}; + +TEST_F(Create_HDF5, CREATE_FLOAT) { + int64_t nb = 10000; + int64_t nq = 100; + int64_t dim = 128; + int64_t topk = 100; + + create_hdf5_file(knowhere::metric::L2, nb, nq, dim, topk); + create_hdf5_file(knowhere::metric::IP, nb, nq, dim, topk); + create_hdf5_file(knowhere::metric::COSINE, nb, nq, dim, topk); +} + +TEST_F(Create_HDF5, CREATE_FLOAT_RANGE) { + int64_t nb = 10000; + int64_t nq = 100; + int64_t dim = 128; + + create_range_hdf5_file(knowhere::metric::L2, nb, nq, dim, 65.0); + create_range_hdf5_file(knowhere::metric::IP, nb, nq, dim, 8.7); + create_range_hdf5_file(knowhere::metric::COSINE, nb, nq, dim, 0.2); +} + +TEST_F(Create_HDF5, CREATE_BINARY) { + int64_t nb = 10000; + int64_t nq = 100; + int64_t dim = 1024; + int64_t topk = 100; + + create_hdf5_file(knowhere::metric::HAMMING, nb, nq, dim, topk); + create_hdf5_file(knowhere::metric::JACCARD, nb, nq, dim, topk); +} + +TEST_F(Create_HDF5, CREATE_BINARY_RANGE) { + int64_t nb = 10000; + int64_t nq = 100; + int64_t dim = 1024; + + create_range_hdf5_file(knowhere::metric::HAMMING, nb, nq, dim, 476); + create_range_hdf5_file(knowhere::metric::JACCARD, nb, nq, dim, 0.63); +} diff --git a/include/knowhere/comp/index_param.h b/include/knowhere/comp/index_param.h index 986d07241..675ccc35c 100644 --- a/include/knowhere/comp/index_param.h +++ b/include/knowhere/comp/index_param.h @@ -75,6 +75,7 @@ namespace indexparam { // IVF Params constexpr const char* NPROBE = "nprobe"; constexpr const char* NLIST = "nlist"; +constexpr const char* USE_ELKAN = "use_elkan"; constexpr const char* NBITS = "nbits"; // PQ/SQ constexpr const char* M = "m"; // PQ param for IVFPQ constexpr const char* SSIZE = "ssize"; diff --git a/python/knowhere/__init__.py b/python/knowhere/__init__.py index 3b194a0c8..9d3c77d3b 100644 --- 
a/python/knowhere/__init__.py +++ b/python/knowhere/__init__.py @@ -87,3 +87,11 @@ def GetVectorDataSetToArray(ans): data = np.zeros([rows, dim]).astype(np.float32) swigknowhere.DataSetTensor2Array(ans, data) return data + + +def GetBinaryVectorDataSetToArray(ans): + dim = int(swigknowhere.DataSet_Dim(ans) / 32) + rows = swigknowhere.DataSet_Rows(ans) + data = np.zeros([rows, dim]).astype(np.int32) + swigknowhere.BinaryDataSetTensor2Array(ans, data) + return data diff --git a/python/knowhere/knowhere.i b/python/knowhere/knowhere.i index ee47cf328..ddcb5be57 100644 --- a/python/knowhere/knowhere.i +++ b/python/knowhere/knowhere.i @@ -69,6 +69,7 @@ import_array(); %apply (float* INPLACE_ARRAY2, int DIM1, int DIM2){(float *dis,int nq_1,int k_1)} %apply (int *INPLACE_ARRAY2, int DIM1, int DIM2){(int *ids,int nq_2,int k_2)} %apply (float* INPLACE_ARRAY2, int DIM1, int DIM2){(float *data,int rows,int dim)} +%apply (int32_t *INPLACE_ARRAY2, int DIM1, int DIM2){(int32_t *data,int rows,int dim)} %typemap(in, numinputs=0) knowhere::Status& status(knowhere::Status tmp) %{ $1 = &tmp; @@ -329,6 +330,17 @@ DataSetTensor2Array(knowhere::DataSetPtr result, float* data, int rows, int dim) } } +void +BinaryDataSetTensor2Array(knowhere::DataSetPtr result, int32_t* data, int rows, int dim) { + GILReleaser rel; + auto data_ = result->GetTensor(); + for (int i = 0; i < rows; i++) { + for (int j = 0; j < dim; ++j) { + *(data + i * dim + j) = *((int32_t*)(data_) + i * dim + j); + } + } +} + void DumpRangeResultIds(knowhere::DataSetPtr result, int* ids, int len) { GILReleaser rel; diff --git a/src/common/config.cc b/src/common/config.cc index f54c705a3..2e238e911 100644 --- a/src/common/config.cc +++ b/src/common/config.cc @@ -18,6 +18,7 @@ static const std::unordered_set ext_legal_json_keys = {"metric_type "dim", "nlist", // IVF param "nprobe", // IVF param + "use_elkan", // IVF param "ssize", // IVF_FLAT_CC param "nbits", // IVF_PQ param "m", // IVF_PQ param diff --git 
a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index e27a3ac61..109f84496 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -14,6 +14,7 @@ #include "faiss/IndexBinaryFlat.h" #include "faiss/IndexBinaryIVF.h" #include "faiss/IndexFlat.h" +#include "faiss/IndexFlatElkan.h" #include "faiss/IndexIVFFlat.h" #include "faiss/IndexIVFPQ.h" #include "faiss/IndexIVFPQFastScan.h" @@ -33,16 +34,6 @@ namespace knowhere { -template -struct QuantizerT { - typedef faiss::IndexFlat type; -}; - -template <> -struct QuantizerT { - using type = faiss::IndexBinaryFlat; -}; - template class IvfIndexNode : public IndexNode { public: @@ -241,6 +232,17 @@ MatchNbits(int64_t size, int64_t nbits) { return nbits; } +namespace { + +// turn IndexFlatElkan into IndexFlat +std::unique_ptr to_index_flat( + std::unique_ptr&& index) { + // C++ slicing here + return std::make_unique(std::move(*index)); +} + +} + template Status IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { @@ -271,24 +273,50 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { auto dim = dataset.GetDim(); auto data = dataset.GetTensor(); - typename QuantizerT::type* qzr = nullptr; - faiss::IndexIVFPQFastScan* base_index = nullptr; std::unique_ptr index; + // if cfg.use_elkan is used, then we'll use a temporary instance of + // IndexFlatElkan for the training. try { if constexpr (std::is_same::value) { const IvfFlatConfig& ivf_flat_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_flat_cfg.nlist.value()); - qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); - index = std::make_unique(qzr, dim, nlist, metric.value(), is_cosine); + + const bool use_elkan = ivf_flat_cfg.use_elkan.value_or(true); + + // create quantizer for the training + std::unique_ptr qzr = + std::make_unique(dim, metric.value(), false, use_elkan); + // create index. 
Index does not own qzr + index = std::make_unique(qzr.get(), dim, nlist, metric.value(), is_cosine); + // train index->train(rows, (const float*)data); + // replace quantizer with a regular IndexFlat + qzr = to_index_flat(std::move(qzr)); + index->quantizer = qzr.get(); + // transfer ownership of qzr to index + qzr.release(); + index->own_fields = true; } if constexpr (std::is_same::value) { const IvfFlatCcConfig& ivf_flat_cc_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_flat_cc_cfg.nlist.value()); - qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); - index = std::make_unique(qzr, dim, nlist, ivf_flat_cc_cfg.ssize.value(), + + const bool use_elkan = ivf_flat_cc_cfg.use_elkan.value_or(true); + + // create quantizer for the training + std::unique_ptr qzr = + std::make_unique(dim, metric.value(), false, use_elkan); + // create index. Index does not own qzr + index = std::make_unique(qzr.get(), dim, nlist, ivf_flat_cc_cfg.ssize.value(), metric.value(), is_cosine); + // train index->train(rows, (const float*)data); + // replace quantizer with a regular IndexFlat + qzr = to_index_flat(std::move(qzr)); + index->quantizer = qzr.get(); + // transfer ownership of qzr to index + qzr.release(); + index->own_fields = true; // ivfflat_cc has no serialize stage, make map at build stage index->make_direct_map(true, faiss::DirectMap::ConcurrentArray); } @@ -296,48 +324,92 @@ IvfIndexNode::Train(const DataSet& dataset, const Config& cfg) { const IvfPqConfig& ivf_pq_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_pq_cfg.nlist.value()); auto nbits = MatchNbits(rows, ivf_pq_cfg.nbits.value()); - qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); - index = std::make_unique(qzr, dim, nlist, ivf_pq_cfg.m.value(), nbits, metric.value()); + + const bool use_elkan = ivf_pq_cfg.use_elkan.value_or(true); + + // create quantizer for the training + std::unique_ptr qzr = + std::make_unique(dim, metric.value(), false, use_elkan); + 
// create index. Index does not own qzr + index = std::make_unique(qzr.get(), dim, nlist, ivf_pq_cfg.m.value(), nbits, metric.value()); + // train index->train(rows, (const float*)data); + // replace quantizer with a regular IndexFlat + qzr = to_index_flat(std::move(qzr)); + index->quantizer = qzr.get(); + // transfer ownership of qzr to index + qzr.release(); + index->own_fields = true; } if constexpr (std::is_same::value) { const ScannConfig& scann_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, scann_cfg.nlist.value()); bool is_cosine = base_cfg.metric_type.value() == metric::COSINE; - qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); - base_index = new (std::nothrow) - faiss::IndexIVFPQFastScan(qzr, dim, nlist, (dim + 1) / 2, 4, is_cosine, metric.value()); - base_index->own_fields = true; + + const bool use_elkan = scann_cfg.use_elkan.value_or(true); + + // create quantizer for the training + std::unique_ptr qzr = + std::make_unique(dim, metric.value(), false, use_elkan); + // create base index. it does not own qzr + auto base_index = + std::make_unique(qzr.get(), dim, nlist, (dim + 1) / 2, 4, is_cosine, metric.value()); + // create scann index, which does not base_index by default, + // but owns the refine index by default omg if (scann_cfg.with_raw_data.value()) { - index = std::make_unique(base_index, (const float*)data); + index = std::make_unique(base_index.get(), (const float*)data); } else { - index = std::make_unique(base_index, nullptr); + index = std::make_unique(base_index.get(), nullptr); } + // train index->train(rows, (const float*)data); + // at this moment, we still own qzr. 
+ // replace quantizer with a regular IndexFlat + qzr = to_index_flat(std::move(qzr)); + base_index->quantizer = qzr.get(); + // release qzr + qzr.release(); + base_index->own_fields = true; + // transfer ownership of the base index + base_index.release(); + index->own_fields = true; } if constexpr (std::is_same::value) { const IvfSqConfig& ivf_sq_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_sq_cfg.nlist.value()); - qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); + + const bool use_elkan = ivf_sq_cfg.use_elkan.value_or(true); + + // create quantizer for the training + std::unique_ptr qzr = + std::make_unique(dim, metric.value(), false, use_elkan); + // create index. Index does not own qzr index = std::make_unique( - qzr, dim, nlist, faiss::ScalarQuantizer::QuantizerType::QT_8bit, metric.value()); + qzr.get(), dim, nlist, faiss::ScalarQuantizer::QuantizerType::QT_8bit, metric.value()); + // train index->train(rows, (const float*)data); + // replace quantizer with a regular IndexFlat + qzr = to_index_flat(std::move(qzr)); + index->quantizer = qzr.get(); + // transfer ownership of qzr to index + qzr.release(); + index->own_fields = true; } if constexpr (std::is_same::value) { const IvfBinConfig& ivf_bin_cfg = static_cast(cfg); auto nlist = MatchNlist(rows, ivf_bin_cfg.nlist.value()); - qzr = new (std::nothrow) typename QuantizerT::type(dim, metric.value()); - index = std::make_unique(qzr, dim, nlist, metric.value()); + + // create quantizer + auto qzr = std::make_unique(dim, metric.value()); + // create index. 
Index does not own qzr + index = std::make_unique(qzr.get(), dim, nlist, metric.value()); + // train index->train(rows, (const uint8_t*)data); + // transfer ownership of qzr to index + qzr.release(); + index->own_fields = true; } - index->own_fields = true; } catch (std::exception& e) { - if (qzr) { - delete qzr; - } - if (base_index) { - delete base_index; - } LOG_KNOWHERE_WARNING_ << "faiss inner error: " << e.what(); return Status::faiss_inner_error; } diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index ee900069e..fabd1d6a2 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -20,6 +20,7 @@ class IvfConfig : public BaseConfig { public: CFG_INT nlist; CFG_INT nprobe; + CFG_BOOL use_elkan; KNOHWERE_DECLARE_CONFIG(IvfConfig) { KNOWHERE_CONFIG_DECLARE_FIELD(nlist) .set_default(128) @@ -31,6 +32,10 @@ class IvfConfig : public BaseConfig { .description("number of probes at query time.") .for_search() .set_range(1, 65536); + KNOWHERE_CONFIG_DECLARE_FIELD(use_elkan) + .set_default(true) + .description("whether to use elkan algorithm") + .for_train(); } }; diff --git a/tests/faiss_isolated/cmake/utils/platform_check.cmake b/tests/faiss_isolated/cmake/utils/platform_check.cmake new file mode 100644 index 000000000..d713a2d44 --- /dev/null +++ b/tests/faiss_isolated/cmake/utils/platform_check.cmake @@ -0,0 +1,12 @@ +include(CheckSymbolExists) + +macro(detect_target_arch) + check_symbol_exists(__aarch64__ "" __AARCH64) + check_symbol_exists(__x86_64__ "" __X86_64) + + if(NOT __AARCH64 AND NOT __X86_64) + message(FATAL "knowhere only support amd64 and arm64.") + endif() +endmacro() + +detect_target_arch() diff --git a/tests/ut/test_knowhere_init.cc b/tests/ut/test_knowhere_init.cc index 703b166c6..3e63c3c8d 100644 --- a/tests/ut/test_knowhere_init.cc +++ b/tests/ut/test_knowhere_init.cc @@ -41,7 +41,7 @@ TEST_CASE("Knowhere global config", "[init]") { } TEST_CASE("Knowhere SIMD config", "[simd]") { - std::vector v = {"AVX512", 
"AVX2", "SSE4_2", "GENERIC"}; + std::vector v = {"AVX512", "AVX2", "SSE4_2", "GENERIC", "NEON"}; std::unordered_set s(v.begin(), v.end()); auto res = knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX512); diff --git a/tests/ut/test_search.cc b/tests/ut/test_search.cc index 33edb70bf..aea7896b3 100644 --- a/tests/ut/test_search.cc +++ b/tests/ut/test_search.cc @@ -32,9 +32,9 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { const int64_t nb = 1000, nq = 10; const int64_t dim = 128; - const int64_t topk = 5; auto metric = GENERATE(as{}, knowhere::metric::L2, knowhere::metric::COSINE); + auto topk = GENERATE(as{}, 5, 120); auto version = GenTestVersionList(); auto base_gen = [=]() { @@ -89,7 +89,7 @@ TEST_CASE("Test Mem Index With Float Vector", "[float metrics]") { knowhere::Json json = base_gen(); json[knowhere::indexparam::HNSW_M] = 128; json[knowhere::indexparam::EFCONSTRUCTION] = 200; - json[knowhere::indexparam::EF] = 64; + json[knowhere::indexparam::EF] = 200; return json; }; @@ -270,9 +270,9 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { const int64_t nb = 1000, nq = 10; const int64_t dim = 1024; - const int64_t topk = 5; auto metric = GENERATE(as{}, knowhere::metric::HAMMING, knowhere::metric::JACCARD); + auto topk = GENERATE(as{}, 5, 120); auto version = GenTestVersionList(); auto base_gen = [=]() { knowhere::Json json; @@ -288,7 +288,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { auto ivfflat_gen = [base_gen]() { knowhere::Json json = base_gen(); json[knowhere::indexparam::NLIST] = 16; - json[knowhere::indexparam::NPROBE] = 8; + json[knowhere::indexparam::NPROBE] = 14; return json; }; @@ -296,7 +296,7 @@ TEST_CASE("Test Mem Index With Binary Vector", "[float metrics]") { knowhere::Json json = base_gen(); json[knowhere::indexparam::HNSW_M] = 128; json[knowhere::indexparam::EFCONSTRUCTION] = 200; - json[knowhere::indexparam::EF] = 64; + 
json[knowhere::indexparam::EF] = 200; return json; }; @@ -377,11 +377,11 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { using Catch::Approx; const int64_t nb = 1000, nq = 10; - const int64_t topk = 5; auto dim = GENERATE(as{}, 8, 16, 32, 64, 128, 256, 512, 160); auto version = GenTestVersionList(); auto metric = GENERATE(as{}, knowhere::metric::SUPERSTRUCTURE, knowhere::metric::SUBSTRUCTURE); + auto topk = GENERATE(as{}, 5, 100); auto base_gen = [=]() { knowhere::Json json; @@ -441,11 +441,17 @@ TEST_CASE("Test Mem Index With Binary Vector", "[bool metrics]") { auto code_size = dim / 8; for (int64_t i = 0; i < nq; i++) { const uint8_t* query_vector = (const uint8_t*)query_ds->GetTensor() + i * code_size; - std::vector ids_v(ids + i * topk, ids + (i + 1) * topk); - auto ds = GenIdsDataSet(topk, ids_v); + // filter out -1 when the result num less than topk + int64_t real_topk = 0; + for (; real_topk < topk; real_topk++) { + if (ids[i * topk + real_topk] < 0) + break; + } + std::vector ids_v(ids + i * topk, ids + i * topk + real_topk); + auto ds = GenIdsDataSet(real_topk, ids_v); auto gv_res = idx.GetVectorByIds(*ds); REQUIRE(gv_res.has_value()); - for (int64_t j = 0; j < topk; j++) { + for (int64_t j = 0; j < real_topk; j++) { const uint8_t* res_vector = (const uint8_t*)gv_res.value()->GetTensor() + j * code_size; if (metric == knowhere::metric::SUPERSTRUCTURE) { REQUIRE(faiss::is_subset(res_vector, query_vector, code_size)); diff --git a/thirdparty/faiss/faiss/IndexFlat.h b/thirdparty/faiss/faiss/IndexFlat.h index 42998e0f6..c045f2779 100644 --- a/thirdparty/faiss/faiss/IndexFlat.h +++ b/thirdparty/faiss/faiss/IndexFlat.h @@ -24,6 +24,9 @@ struct IndexFlat : IndexFlatCodes { explicit IndexFlat(idx_t d, MetricType metric = METRIC_L2, bool is_cosine = false); + // Be careful with overriding this function, because + // renormalized x may be used inside. + // Overridden by IndexFlat1D. 
void add(idx_t n, const float* x) override;
 
     void search(
diff --git a/thirdparty/faiss/faiss/IndexFlatElkan.cpp b/thirdparty/faiss/faiss/IndexFlatElkan.cpp
new file mode 100644
index 000000000..b9c48e271
--- /dev/null
+++ b/thirdparty/faiss/faiss/IndexFlatElkan.cpp
@@ -0,0 +1,85 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#include <faiss/IndexFlatElkan.h>
+
+#include <memory>
+
+#include <faiss/impl/FaissAssert.h>
+#include <faiss/utils/distances.h>
+
+namespace faiss {
+
+IndexFlatElkan::IndexFlatElkan(idx_t d, MetricType metric, bool is_cosine, bool use_elkan)
+        : IndexFlat(d, metric, is_cosine) {
+    this->use_elkan = use_elkan;
+}
+
+// Finds the single nearest stored vector for each of the n queries in x.
+// Only k == 1 and params == nullptr are accepted; anything else throws.
+// `distances` may be nullptr, in which case a scratch buffer of n floats
+// is used and the distances are discarded. `labels` is written directly.
+void IndexFlatElkan::search(
+        idx_t n,
+        const float* x,
+        idx_t k,
+        float* distances,
+        idx_t* labels,
+        const SearchParameters* params) const {
+    // usually used in IVF k-means algorithm
+
+    FAISS_THROW_IF_NOT_MSG(
+        k == 1,
+        "this index requires k == 1 in a search() call."
+    );
+    FAISS_THROW_IF_NOT_MSG(
+        params == nullptr,
+        "search params not supported for this index"
+    );
+
+    float* dis_inner = distances;
+    std::unique_ptr<float[]> dis_inner_deleter = nullptr;
+    if (distances == nullptr) {
+        dis_inner_deleter = std::make_unique<float[]>(n);
+        dis_inner = dis_inner_deleter.get();
+    }
+
+    switch (metric_type) {
+        case METRIC_INNER_PRODUCT:
+        case METRIC_L2: {
+            // ignore the metric_type, both use L2
+            if (use_elkan) {
+                // use elkan
+                elkan_L2_sse(x, get_xb(), d, n, ntotal, labels, dis_inner);
+            }
+            else {
+                // use L2 search. The same code as in IndexFlat::search() for L2.
+                IDSelector* sel = params ? params->sel : nullptr;
+
+                // Write into dis_inner, not distances: distances may be
+                // nullptr, in which case dis_inner is the scratch buffer.
+                float_maxheap_array_t res = {size_t(n), size_t(k), labels, dis_inner};
+                knn_L2sqr(x, get_xb(), d, n, ntotal, &res, nullptr, sel);
+            }
+
+            break;
+        }
+        default: {
+            // binary metrics
+            // There may be something wrong, but maintain the original logic
+            // now.
+            IndexFlat::search(n, x, k, dis_inner, labels, params);
+            break;
+        }
+    }
+}
+
+} // namespace faiss
diff --git a/thirdparty/faiss/faiss/IndexFlatElkan.h b/thirdparty/faiss/faiss/IndexFlatElkan.h
new file mode 100644
index 000000000..555f2d46d
--- /dev/null
+++ b/thirdparty/faiss/faiss/IndexFlatElkan.h
@@ -0,0 +1,45 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#pragma once
+
+#include <faiss/IndexFlat.h>
+
+namespace faiss {
+
+// This is a special modification of IndexFlat that does two things.
+// 1. It allows the Elkan algorithm to be used for the search. It is slower,
+//    sometimes an order of magnitude slower than the regular IndexFlat::search()
+//    implementation, but sometimes the trained index produces a better
+//    recall rate.
+// 2. It always uses L2 distance for the IP / L2 metrics in order to
+//    support an early stop strategy from Clustering.cpp. Early stop
+//    strategy is a Knowhere-specific feature.
+//
+// This index is intended to be used in Knowhere's ivf.cc file ONLY!!!
+//
+// Elkan algo was introduced into Knowhere in #2178, #2180 and #2258.
+struct IndexFlatElkan : IndexFlat {
+    bool use_elkan = true;
+
+    explicit IndexFlatElkan(idx_t d, MetricType metric = METRIC_L2,
+        bool is_cosine = false, bool use_elkan = true);
+
+    void search(
+            idx_t n,
+            const float* x,
+            idx_t k,
+            float* distances,
+            idx_t* labels,
+            const SearchParameters* params = nullptr) const override;
+};
+
+} // namespace faiss
diff --git a/thirdparty/faiss/faiss/IndexIVFFastScan.cpp b/thirdparty/faiss/faiss/IndexIVFFastScan.cpp
index a06b28f71..86d5cf84a 100644
--- a/thirdparty/faiss/faiss/IndexIVFFastScan.cpp
+++ b/thirdparty/faiss/faiss/IndexIVFFastScan.cpp
@@ -43,11 +43,14 @@ IndexIVFFastScan::IndexIVFFastScan(
         size_t d,
         size_t nlist,
         size_t code_size,
-        MetricType metric)
+        MetricType metric,
+        bool is_cosine)
         : IndexIVF(quantizer, d, nlist, code_size, metric) {
     // unlike other indexes, we prefer no residuals for performance reasons.
     by_residual = false;
     FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
+
+    this->is_cosine = is_cosine;
 }
 
 IndexIVFFastScan::IndexIVFFastScan() {
@@ -90,10 +93,42 @@ IndexIVFFastScan::~IndexIVFFastScan() = default;
  * Code management functions
  *********************************************************/
 
+// Knowhere-specific: for cosine indexes, train on a normalized copy of x
+// so that the caller's buffer is not modified.
+void IndexIVFFastScan::train(idx_t n, const float* x) {
+    if (is_cosine) {
+        auto norm_data = std::make_unique<float[]>(n * d);
+        std::memcpy(norm_data.get(), x, n * d * sizeof(float));
+        knowhere::NormalizeVecs(norm_data.get(), n, d);
+        IndexIVF::train(n, norm_data.get());
+    } else {
+        IndexIVF::train(n, x);
+    }
+}
+
 void IndexIVFFastScan::add_with_ids(
         idx_t n,
         const float* x,
         const idx_t* xids) {
+    if (is_cosine) {
+        // Normalize a private copy of x. The vector returned by
+        // NormalizeVecs is assigned to `norms` (replacing, not appending).
+        // NOTE(review): IndexScaNN indexes `norms` by label, so confirm
+        // that a cosine index is never filled by more than one call.
+        auto norm_data = std::make_unique<float[]>(n * d);
+        std::memcpy(norm_data.get(), x, n * d * sizeof(float));
+        norms = knowhere::NormalizeVecs(norm_data.get(), n, d);
+        add_with_ids_impl(n, norm_data.get(), xids);
+    } else {
+        add_with_ids_impl(n, x, xids);
+    }
+}
+
+// knowhere-specific function
+void IndexIVFFastScan::add_with_ids_impl(
+        idx_t n,
+        const float* x,
+        const idx_t* xids) {
     FAISS_THROW_IF_NOT(is_trained);
 
     // do some blocking to avoid excessive allocs
@@ -118,7 +153,7 @@ void IndexIVFFastScan::add_with_ids(
                         total_time,
                         mem);
             }
-            add_with_ids(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
+            add_with_ids_impl(i1 - i0, x + i0 * d, xids ? xids + i0 : nullptr);
         }
         return;
     }
diff --git a/thirdparty/faiss/faiss/IndexIVFFastScan.h b/thirdparty/faiss/faiss/IndexIVFFastScan.h
index b88561fed..b242bebc9 100644
--- a/thirdparty/faiss/faiss/IndexIVFFastScan.h
+++ b/thirdparty/faiss/faiss/IndexIVFFastScan.h
@@ -50,12 +50,16 @@ struct IndexIVFFastScan : IndexIVF {
     int qbs = 0;
     size_t qbs2 = 0;
 
+    // todo aguzhva: get rid of this
+    std::vector<float> norms;
+
     IndexIVFFastScan(
             Index* quantizer,
             size_t d,
             size_t nlist,
             size_t code_size,
-            MetricType metric = METRIC_L2);
+            MetricType metric = METRIC_L2,
+            bool is_cosine = false);
 
     IndexIVFFastScan();
 
@@ -74,7 +78,18 @@ struct IndexIVFFastScan : IndexIVF {
     /// orig's inverted lists (for debugging)
     InvertedLists* orig_invlists = nullptr;
 
-    void add_with_ids(idx_t n, const float* x, const idx_t* xids) override;
+    // Knowhere-specific function, needed for norms, introduced in PR #1
+    // final is needed because 'x' can be renormalized inside it,
+    // so a derived class is not allowed to override this function.
+    void add_with_ids(idx_t n, const float* x, const idx_t* xids) override final;
+
+    // This matches Faiss baseline.
+    void add_with_ids_impl(idx_t n, const float* x, const idx_t* xids);
+
+    // Knowhere-specific override.
+    // final is needed because 'x' can be renormalized inside it,
+    // so a derived class is not allowed to override this function.
+ void train(idx_t n, const float* x) override final; // prepare look-up tables diff --git a/thirdparty/faiss/faiss/IndexIVFFlat.h b/thirdparty/faiss/faiss/IndexIVFFlat.h index 88b19681b..42899708d 100644 --- a/thirdparty/faiss/faiss/IndexIVFFlat.h +++ b/thirdparty/faiss/faiss/IndexIVFFlat.h @@ -31,6 +31,9 @@ struct IndexIVFFlat : IndexIVF { void restore_codes(const uint8_t* raw_data, const size_t raw_size); + // Be careful with overriding this function, because + // renormalized x may be used inside. + // Overridden by IndexIVFFlatDedup. void train(idx_t n, const float* x) override; void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; diff --git a/thirdparty/faiss/faiss/IndexIVFPQFastScan.cpp b/thirdparty/faiss/faiss/IndexIVFPQFastScan.cpp index ce7345f1a..af9346d14 100644 --- a/thirdparty/faiss/faiss/IndexIVFPQFastScan.cpp +++ b/thirdparty/faiss/faiss/IndexIVFPQFastScan.cpp @@ -61,7 +61,7 @@ IndexIVFPQFastScan::IndexIVFPQFastScan( MetricType metric, int bbs) : IndexIVFPQFastScan(quantizer, d, nlist, M, nbits_per_idx, metric, bbs) { - is_cosine_ = is_cosine; + this->is_cosine = is_cosine; } IndexIVFPQFastScan::IndexIVFPQFastScan() { @@ -133,17 +133,6 @@ void IndexIVFPQFastScan::train_encoder( } } -void IndexIVFPQFastScan::train(idx_t n, const float* x) { - if (is_cosine_) { - auto norm_data = std::make_unique(n * d); - std::memcpy(norm_data.get(), x, n * d * sizeof(float)); - knowhere::NormalizeVecs(norm_data.get(), n, d); - IndexIVFFastScan::train(n, norm_data.get()); - } else { - IndexIVFFastScan::train(n, x); - } -} - idx_t IndexIVFPQFastScan::train_encoder_num_vectors() const { return pq.cp.max_points_per_centroid * pq.ksub; } @@ -158,20 +147,6 @@ void IndexIVFPQFastScan::precompute_table() { verbose); } -void IndexIVFPQFastScan::add_with_ids( - idx_t n, - const float* x, - const idx_t* xids) { - if (is_cosine_) { - auto norm_data = std::make_unique(n * d); - std::memcpy(norm_data.get(), x, n * d * sizeof(float)); - norms = 
std::move(knowhere::NormalizeVecs(norm_data.get(), n, d)); - IndexIVFFastScan::add_with_ids(n, norm_data.get(), xids); - } else { - IndexIVFFastScan::add_with_ids(n, x, xids); - } -} - /********************************************************* * Code management functions *********************************************************/ diff --git a/thirdparty/faiss/faiss/IndexIVFPQFastScan.h b/thirdparty/faiss/faiss/IndexIVFPQFastScan.h index 201805b79..66915ad7b 100644 --- a/thirdparty/faiss/faiss/IndexIVFPQFastScan.h +++ b/thirdparty/faiss/faiss/IndexIVFPQFastScan.h @@ -40,10 +40,6 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { /// if use_precompute_table size (nlist, pq.M, pq.ksub) AlignedTable precomputed_table; - // // todo aguzhva: get rid of this - bool is_cosine_ = false; - std::vector norms; - // todo agzuhva: add back cosine support from knowhere IndexIVFPQFastScan( Index* quantizer, @@ -71,12 +67,8 @@ struct IndexIVFPQFastScan : IndexIVFFastScan { void train_encoder(idx_t n, const float* x, const idx_t* assign) override; - void train(idx_t n, const float* x) override; - idx_t train_encoder_num_vectors() const override; - void add_with_ids(idx_t n, const float* x, const idx_t* xids) override; - /// build precomputed table, possibly updating use_precomputed_table void precompute_table(); diff --git a/thirdparty/faiss/faiss/IndexScaNN.cpp b/thirdparty/faiss/faiss/IndexScaNN.cpp index b1ffb105d..1d48cb42c 100644 --- a/thirdparty/faiss/faiss/IndexScaNN.cpp +++ b/thirdparty/faiss/faiss/IndexScaNN.cpp @@ -170,7 +170,7 @@ void IndexScaNN::search( rf->compute_distance_subset(n, x, k_base, base_distances, base_labels); - if (base->is_cosine_) { + if (base->is_cosine) { for (idx_t i = 0; i < n * k_base; i++) { if (base_labels[i] >= 0) { base_distances[i] /= base->norms[base_labels[i]]; @@ -234,7 +234,7 @@ void IndexScaNN::range_search( idx_t current = 0; for (idx_t i = 0; i < result->lims[1]; ++i) { - if (base->is_cosine_) { + if (base->is_cosine) { 
result->distances[i] /= base->norms[result->labels[i]]; } if (metric_type == METRIC_L2) { diff --git a/thirdparty/faiss/faiss/impl/index_read.cpp b/thirdparty/faiss/faiss/impl/index_read.cpp index d34451e45..b57ed9bb0 100644 --- a/thirdparty/faiss/faiss/impl/index_read.cpp +++ b/thirdparty/faiss/faiss/impl/index_read.cpp @@ -1212,8 +1212,8 @@ Index* read_index(IOReader* f, int io_flags) { READ1(ivpq->M2); READ1(ivpq->implem); READ1(ivpq->qbs2); - READ1(ivpq->is_cosine_); - if (ivpq->is_cosine_) { + READ1(ivpq->is_cosine); + if (ivpq->is_cosine) { READVECTOR(ivpq->norms); } read_ProductQuantizer(&ivpq->pq, f); diff --git a/thirdparty/faiss/faiss/impl/index_write.cpp b/thirdparty/faiss/faiss/impl/index_write.cpp index f80fc77ed..bac1bc59b 100644 --- a/thirdparty/faiss/faiss/impl/index_write.cpp +++ b/thirdparty/faiss/faiss/impl/index_write.cpp @@ -1002,8 +1002,8 @@ void write_index(const Index* idx, IOWriter* f) { WRITE1(ivpq->M2); WRITE1(ivpq->implem); WRITE1(ivpq->qbs2); - WRITE1(ivpq->is_cosine_); - if (ivpq->is_cosine_) { + WRITE1(ivpq->is_cosine); + if (ivpq->is_cosine) { WRITEVECTOR(ivpq->norms); } write_ProductQuantizer(&ivpq->pq, f); diff --git a/thirdparty/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp b/thirdparty/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp index 184d4101a..8121c7b90 100644 --- a/thirdparty/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +++ b/thirdparty/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp @@ -123,7 +123,26 @@ void accumulate_q_4step( constexpr int Q4 = (QBS >> 12) & 15; constexpr int SQ = Q1 + Q2 + Q3 + Q4; - for (int64_t j0 = 0; j0 < ntotal2; j0 += 32) { + for (int64_t j0 = 0; j0 < ntotal2; j0 += 32, codes += 32 * nsq / 2) { + res.set_block_origin(0, j0); + + // skip computing distances if all vectors inside a block are filtered out + if (res.sel != nullptr) { // we have filter here + bool skip_flag = true; + for (int64_t jj = 0; jj < std::min(32, ntotal2 - j0); + jj++) { + auto real_idx = res.adjust_id(0, jj); + if 
(res.sel->is_member(real_idx)) { // id is not filtered out, can not skip computing + skip_flag = false; + break; + } + } + + if (skip_flag) { + continue; + } + } + FixedStorageHandler res2; const uint8_t* LUT = LUT0; kernel_accumulate_block(nsq, codes, LUT, res2, scaler); @@ -142,9 +161,7 @@ void accumulate_q_4step( res2.set_block_origin(Q1 + Q2 + Q3, 0); kernel_accumulate_block(nsq, codes, LUT, res2, scaler); } - res.set_block_origin(0, j0); res2.to_other_handler(res); - codes += 32 * nsq / 2; } }