diff --git a/knowhere/index/VecIndex.h b/knowhere/index/VecIndex.h index 0b31b0571..5aa9d77d4 100644 --- a/knowhere/index/VecIndex.h +++ b/knowhere/index/VecIndex.h @@ -15,6 +15,7 @@ #include #include #include +#include #include "knowhere/common/Dataset.h" #include "knowhere/common/Exception.h" @@ -84,6 +85,9 @@ class VecIndex : public Index { KNOWHERE_THROW_MSG("GetVectorById not supported yet"); } + virtual bool + HasRawData(const std::string& metric_type) const = 0; + /** * @brief TopK Query. if the result size is smaller than K, this API will fill the return ids with -1 and distances * with FLOAT_MIN or FLOAT_MAX depends on the metric type. diff --git a/knowhere/index/VecIndexThreadPoolWrapper.h b/knowhere/index/VecIndexThreadPoolWrapper.h index faa78e3a6..d0783c5ea 100644 --- a/knowhere/index/VecIndexThreadPoolWrapper.h +++ b/knowhere/index/VecIndexThreadPoolWrapper.h @@ -11,6 +11,7 @@ #pragma once +#include #include #include @@ -63,6 +64,11 @@ class VecIndexThreadPoolWrapper : public VecIndex { return index_->GetVectorById(dataset, config); } + bool + HasRawData(const std::string& metric_type) const override { + return index_->HasRawData(metric_type); + } + DatasetPtr Query(const DatasetPtr& dataset, const Config& config, const faiss::BitsetView bitset) override { return thread_pool_->push([&]() { return this->index_->Query(dataset, config, bitset); }).get(); diff --git a/knowhere/index/vector_index/IndexAnnoy.h b/knowhere/index/vector_index/IndexAnnoy.h index 7ef3abb00..81d79c53e 100644 --- a/knowhere/index/vector_index/IndexAnnoy.h +++ b/knowhere/index/vector_index/IndexAnnoy.h @@ -46,6 +46,11 @@ class IndexAnnoy : public VecIndex { DatasetPtr GetVectorById(const DatasetPtr&, const Config&) override; + bool + HasRawData(const std::string& /*metric_type*/) const override { + return true; + } + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexBinaryIDMAP.h b/knowhere/index/vector_index/IndexBinaryIDMAP.h index 0dd97fb04..bd4835405 100644 --- a/knowhere/index/vector_index/IndexBinaryIDMAP.h +++ b/knowhere/index/vector_index/IndexBinaryIDMAP.h @@ -14,6 +14,7 @@ #include #include #include +#include #include "knowhere/index/VecIndex.h" #include "knowhere/index/vector_index/FaissBaseBinaryIndex.h" @@ -45,6 +46,11 @@ class BinaryIDMAP : public VecIndex, public FaissBaseBinaryIndex { DatasetPtr GetVectorById(const DatasetPtr&, const Config&) override; + bool + HasRawData(const std::string& /*metric_type*/) const override { + return true; + } + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexBinaryIVF.h b/knowhere/index/vector_index/IndexBinaryIVF.h index f68f4634f..2cb091eaa 100644 --- a/knowhere/index/vector_index/IndexBinaryIVF.h +++ b/knowhere/index/vector_index/IndexBinaryIVF.h @@ -14,6 +14,7 @@ #include #include #include +#include #include @@ -50,6 +51,11 @@ class BinaryIVF : public VecIndex, public FaissBaseBinaryIndex { DatasetPtr GetVectorById(const DatasetPtr&, const Config&) override; + bool + HasRawData(const std::string& /*metric_type*/) const override { + return true; + } + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexDiskANN.h b/knowhere/index/vector_index/IndexDiskANN.h index 6f35791ff..1bd716609 100644 --- a/knowhere/index/vector_index/IndexDiskANN.h +++ b/knowhere/index/vector_index/IndexDiskANN.h @@ -70,6 +70,11 @@ class IndexDiskANN : public VecIndex { KNOWHERE_THROW_MSG("DiskANN doesn't support GetVectorById."); } + bool + HasRawData(const std::string& /*metric_type*/) const override { + return false; + } + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexHNSW.h b/knowhere/index/vector_index/IndexHNSW.h index 9ad3851d6..9e8bb73d3 100644 --- a/knowhere/index/vector_index/IndexHNSW.h +++ b/knowhere/index/vector_index/IndexHNSW.h @@ -13,6 +13,7 @@ #include #include +#include #include "hnswlib/hnswlib/hnswlib.h" #include "knowhere/common/Exception.h" @@ -55,6 +56,11 @@ class IndexHNSW : public VecIndex { DatasetPtr GetVectorById(const DatasetPtr&, const Config&) override; + bool + HasRawData(const std::string& /*metric_type*/) const override { + return true; + } + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexIDMAP.h b/knowhere/index/vector_index/IndexIDMAP.h index 18cca27ce..73d311f17 100644 --- a/knowhere/index/vector_index/IndexIDMAP.h +++ b/knowhere/index/vector_index/IndexIDMAP.h @@ -14,6 +14,7 @@ #include #include #include +#include #include "knowhere/index/VecIndex.h" #include "knowhere/index/vector_index/FaissBaseIndex.h" @@ -45,6 +46,11 @@ class IDMAP : public VecIndex, public FaissBaseIndex { DatasetPtr GetVectorById(const DatasetPtr&, const Config&) override; + bool + HasRawData(const std::string& /*metric_type*/) const override { + return true; + } + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexIVF.h b/knowhere/index/vector_index/IndexIVF.h index 79be00a6c..4a4dc5506 100644 --- a/knowhere/index/vector_index/IndexIVF.h +++ b/knowhere/index/vector_index/IndexIVF.h @@ -14,6 +14,7 @@ #include #include #include +#include #include @@ -50,6 +51,11 @@ class IVF : public VecIndex, public FaissBaseIndex { DatasetPtr GetVectorById(const DatasetPtr&, const Config&) override; + bool + HasRawData(const std::string& /*metric_type*/) const override { + return true; + } + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/knowhere/index/vector_index/IndexIVFPQ.h b/knowhere/index/vector_index/IndexIVFPQ.h index ebe1aaa8d..d61c136c8 100644 --- a/knowhere/index/vector_index/IndexIVFPQ.h +++ b/knowhere/index/vector_index/IndexIVFPQ.h @@ -13,6 +13,7 @@ #include #include +#include #include "knowhere/index/vector_index/IndexIVF.h" @@ -35,6 +36,11 @@ class IVFPQ : public IVF { KNOWHERE_THROW_MSG("GetVectorById not supported yet"); } + bool + HasRawData(const std::string& /*metric_type*/) const override { + return false; + } + void Train(const DatasetPtr&, const Config&) override; diff --git a/knowhere/index/vector_index/IndexIVFSQ.h b/knowhere/index/vector_index/IndexIVFSQ.h index afc89b268..ef4b4f033 100644 --- a/knowhere/index/vector_index/IndexIVFSQ.h +++ b/knowhere/index/vector_index/IndexIVFSQ.h @@ -13,6 +13,7 @@ #include #include +#include #include "knowhere/index/vector_index/IndexIVF.h" @@ -35,6 +36,11 @@ class IVFSQ : public IVF { KNOWHERE_THROW_MSG("GetVectorById not supported yet"); } + bool + HasRawData(const std::string& /*metric_type*/) const override { + return false; + } + void Train(const DatasetPtr&, const Config&) override; diff --git a/knowhere/index/vector_offset_index/IndexIVF_NM.h b/knowhere/index/vector_offset_index/IndexIVF_NM.h index 6c64e0470..3a1b839d9 100644 --- a/knowhere/index/vector_offset_index/IndexIVF_NM.h +++ b/knowhere/index/vector_offset_index/IndexIVF_NM.h @@ -15,6 +15,7 @@ #include #include #include +#include #include @@ -51,6 +52,11 @@ class IVF_NM : public VecIndex, public OffsetBaseIndex { DatasetPtr GetVectorById(const DatasetPtr&, const Config&) override; + bool + HasRawData(const std::string& /*metric_type*/) const override { + return true; + } + DatasetPtr Query(const DatasetPtr&, const Config&, const faiss::BitsetView) override; diff --git a/unittest/AsyncIndex.h b/unittest/AsyncIndex.h index 34a8e85d7..1f1fd064d 100644 --- a/unittest/AsyncIndex.h +++ b/unittest/AsyncIndex.h @@ -104,6 +104,11 @@ class AsyncIndex : public VecIndex { return index_->Load(index_binary); } + bool + HasRawData(const std::string& metric_type) const override { + return false; + } + void BuildAll(const DatasetPtr& dataset_ptr, const Config& config) override { index_->BuildAll(dataset_ptr, config); diff --git a/unittest/test_annoy.cpp b/unittest/test_annoy.cpp index 2fbcd1621..c6ae257dc 100644 --- a/unittest/test_annoy.cpp +++ b/unittest/test_annoy.cpp @@ -57,6 +57,7 @@ TEST_P(AnnoyTest, annoy_basic) { ASSERT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); + ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertVec(result, base_dataset, id_dataset, nq, dim); diff --git a/unittest/test_binaryidmap.cpp b/unittest/test_binaryidmap.cpp index 6711c276a..f280462ba 100644 --- a/unittest/test_binaryidmap.cpp +++ b/unittest/test_binaryidmap.cpp @@ -71,6 +71,7 @@ TEST_P(BinaryIDMAPTest, binaryidmap_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); + ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertBinVec(result, base_dataset, id_dataset, nq, dim); diff --git a/unittest/test_binaryivf.cpp b/unittest/test_binaryivf.cpp index ec47e3e5e..0443a8b59 100644 --- a/unittest/test_binaryivf.cpp +++ b/unittest/test_binaryivf.cpp @@ -68,6 +68,7 @@ TEST_P(BinaryIVFTest, binaryivf_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); + ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertBinVec(result, base_dataset, id_dataset, nq, dim); diff --git a/unittest/test_diskann.cpp b/unittest/test_diskann.cpp index f5e941c83..113c015c6 100644 --- a/unittest/test_diskann.cpp +++ b/unittest/test_diskann.cpp @@ -9,6 +9,7 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License. +#include #include #include @@ -506,6 +507,18 @@ TEST_P(DiskANNTest, knn_search_test) { CheckDistanceError(raw_data_, query_data_, result, metric_, num_queries_, dim_, kK, num_rows_, is_large_dim_); } +TEST_P(DiskANNTest, get_vector_by_id) { + knowhere::Config cfg; + cfg.clear(); + knowhere::DiskANNPrepareConfig::Set(cfg, prep_conf); + EXPECT_TRUE(diskann->Prepare(cfg)); + cfg.clear(); + knowhere::DiskANNQueryConfig::Set(cfg, query_conf); + + ASSERT_FALSE(diskann->HasRawData(metric_)); + ASSERT_ANY_THROW(diskann->GetVectorById(nullptr, cfg)); +} + TEST_P(DiskANNTest, knn_search_with_accelerate_build_test) { if (is_large_dim_) { GTEST_SKIP() << "Skip build accelerate test for large dim."; diff --git a/unittest/test_hnsw.cpp b/unittest/test_hnsw.cpp index aa36edee2..b5bbf5abe 100644 --- a/unittest/test_hnsw.cpp +++ b/unittest/test_hnsw.cpp @@ -77,6 +77,7 @@ TEST_P(HNSWTest, HNSW_basic) { index_->Load(bs); + ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertVec(result, base_dataset, id_dataset, nq, dim); diff --git a/unittest/test_idmap.cpp b/unittest/test_idmap.cpp index 4cc51ac17..5eb401fc9 100644 --- a/unittest/test_idmap.cpp +++ b/unittest/test_idmap.cpp @@ -90,6 +90,7 @@ TEST_P(IDMAPTest, idmap_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); + ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertVec(result, base_dataset, id_dataset, nq, dim); diff --git a/unittest/test_ivf.cpp b/unittest/test_ivf.cpp index 92b2e4f0a..ea2f31963 100644 --- a/unittest/test_ivf.cpp +++ b/unittest/test_ivf.cpp @@ -91,6 +91,7 @@ TEST_P(IVFTest, ivf_basic) { EXPECT_EQ(index_->Dim(), dim); ASSERT_GT(index_->Size(), 0); if (index_mode_ == knowhere::IndexMode::MODE_CPU) { + ASSERT_FALSE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); ASSERT_ANY_THROW(index_->GetVectorById(id_dataset, conf_)); } diff --git a/unittest/test_ivf_nm.cpp b/unittest/test_ivf_nm.cpp index 4fd4dc4cf..94b89fc07 100644 --- a/unittest/test_ivf_nm.cpp +++ b/unittest/test_ivf_nm.cpp @@ -105,6 +105,7 @@ TEST_P(IVFNMTest, ivfnm_basic) { LoadRawData(index_, base_dataset, conf_); if (index_mode_ == knowhere::IndexMode::MODE_CPU) { + ASSERT_TRUE(index_->HasRawData(knowhere::GetMetaMetricType(conf_))); auto result = index_->GetVectorById(id_dataset, conf_); AssertVec(result, base_dataset, id_dataset, nq, dim);