Skip to content

Commit

Permalink
use faiss hnsw implementation to replace hnswlib
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang.li <[email protected]>
  • Loading branch information
foxspy committed Nov 7, 2024
1 parent 4eeca80 commit 6339e5f
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 203 deletions.
11 changes: 4 additions & 7 deletions include/knowhere/comp/index_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,11 @@ constexpr const char* INDEX_GPU_IVFPQ = "GPU_IVF_PQ";
constexpr const char* INDEX_GPU_CAGRA = "GPU_CAGRA";

constexpr const char* INDEX_HNSW = "HNSW";
constexpr const char* INDEX_HNSW_SQ8 = "HNSW_SQ8";
constexpr const char* INDEX_HNSW_SQ8_REFINE = "HNSW_SQ8_REFINE";
constexpr const char* INDEX_DISKANN = "DISKANN";
constexpr const char* INDEX_HNSW_SQ = "HNSW_SQ";
constexpr const char* INDEX_HNSW_PQ = "HNSW_PQ";
constexpr const char* INDEX_HNSW_PRQ = "HNSW_PRQ";

constexpr const char* INDEX_FAISS_HNSW_FLAT = "FAISS_HNSW_FLAT";
constexpr const char* INDEX_FAISS_HNSW_SQ = "FAISS_HNSW_SQ";
constexpr const char* INDEX_FAISS_HNSW_PQ = "FAISS_HNSW_PQ";
constexpr const char* INDEX_FAISS_HNSW_PRQ = "FAISS_HNSW_PRQ";
constexpr const char* INDEX_DISKANN = "DISKANN";

constexpr const char* INDEX_SPARSE_INVERTED_INDEX = "SPARSE_INVERTED_INDEX";
constexpr const char* INDEX_SPARSE_WAND = "SPARSE_WAND";
Expand Down
42 changes: 12 additions & 30 deletions include/knowhere/index/index_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,17 @@ static std::set<std::pair<std::string, VecType>> legal_knowhere_index = {
{IndexEnum::INDEX_HNSW, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_SQ8, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_SQ, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_SQ8_REFINE, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_PQ, VecType::VECTOR_BFLOAT16},

// faiss hnsw
{IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_HNSW_FLAT, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_HNSW_SQ, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_HNSW_PQ, VecType::VECTOR_BFLOAT16},

{IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_FAISS_HNSW_PRQ, VecType::VECTOR_BFLOAT16},
{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_FLOAT},
{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_FLOAT16},
{IndexEnum::INDEX_HNSW_PRQ, VecType::VECTOR_BFLOAT16},

// diskann
{IndexEnum::INDEX_DISKANN, VecType::VECTOR_FLOAT},
Expand All @@ -112,14 +99,9 @@ static std::set<std::string> legal_support_mmap_knowhere_index = {

// hnsw
IndexEnum::INDEX_HNSW,
IndexEnum::INDEX_HNSW_SQ8,
IndexEnum::INDEX_HNSW_SQ8_REFINE,

// faiss hnsw
IndexEnum::INDEX_FAISS_HNSW_FLAT,
IndexEnum::INDEX_FAISS_HNSW_SQ,
IndexEnum::INDEX_FAISS_HNSW_PQ,
IndexEnum::INDEX_FAISS_HNSW_PRQ,
IndexEnum::INDEX_HNSW_SQ,
IndexEnum::INDEX_HNSW_PQ,
IndexEnum::INDEX_HNSW_PRQ,

// sparse index
IndexEnum::INDEX_SPARSE_INVERTED_INDEX,
Expand Down
26 changes: 16 additions & 10 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1279,7 +1279,7 @@ class BaseFaissRegularIndexHNSWFlatNode : public BaseFaissRegularIndexHNSWNode {

std::string
Type() const override {
return knowhere::IndexEnum::INDEX_FAISS_HNSW_FLAT;
return knowhere::IndexEnum::INDEX_HNSW;
}

protected:
Expand Down Expand Up @@ -1789,7 +1789,7 @@ class BaseFaissRegularIndexHNSWSQNode : public BaseFaissRegularIndexHNSWNode {

std::string
Type() const override {
return knowhere::IndexEnum::INDEX_FAISS_HNSW_SQ;
return knowhere::IndexEnum::INDEX_HNSW_SQ;
}

protected:
Expand Down Expand Up @@ -1904,7 +1904,7 @@ class BaseFaissRegularIndexHNSWPQNode : public BaseFaissRegularIndexHNSWNode {

std::string
Type() const override {
return knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ;
return knowhere::IndexEnum::INDEX_HNSW_PQ;
}

protected:
Expand Down Expand Up @@ -2103,7 +2103,7 @@ class BaseFaissRegularIndexHNSWPRQNode : public BaseFaissRegularIndexHNSWNode {

std::string
Type() const override {
return knowhere::IndexEnum::INDEX_FAISS_HNSW_PRQ;
return knowhere::IndexEnum::INDEX_HNSW_PRQ;
}

protected:
Expand Down Expand Up @@ -2288,15 +2288,21 @@ class BaseFaissRegularIndexHNSWPRQNodeTemplate : public BaseFaissRegularIndexHNS
}
};

//
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FAISS_HNSW_FLAT,
// MV is only for compatibility
#ifdef KNOWHERE_WITH_CARDINAL
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_DEPRECATED,
BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback,
knowhere::feature::MMAP | knowhere::feature::MV)
#else
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW, BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback,
knowhere::feature::MMAP | knowhere::feature::MV)
#endif

KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate,
knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FAISS_HNSW_SQ, BaseFaissRegularIndexHNSWSQNodeTemplate,
knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FAISS_HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate,
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PQ, BaseFaissRegularIndexHNSWPQNodeTemplate,
knowhere::feature::MMAP)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(FAISS_HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate,
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_PRQ, BaseFaissRegularIndexHNSWPRQNodeTemplate,
knowhere::feature::MMAP)

} // namespace knowhere
9 changes: 1 addition & 8 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,7 @@ template class HnswIndexNode<knowhere::fp32, hnswlib::QuantType::None>;
template class HnswIndexNode<knowhere::fp16, hnswlib::QuantType::None>;
template class HnswIndexNode<knowhere::bf16, hnswlib::QuantType::None>;

#ifdef KNOWHERE_WITH_CARDINAL
KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(HNSW_DEPRECATED, HnswIndexNode,
KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(HNSWLIB_DEPRECATED, HnswIndexNode,
knowhere::feature::MMAP | knowhere::feature::MV)
#else
KNOWHERE_SIMPLE_REGISTER_DENSE_ALL_GLOBAL(HNSW, HnswIndexNode, knowhere::feature::MMAP | knowhere::feature::MV)
#endif

KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ8, HnswIndexNode, knowhere::feature::MMAP, QuantType::SQ8)
KNOWHERE_SIMPLE_REGISTER_DENSE_FLOAT_ALL_GLOBAL(HNSW_SQ8_REFINE, HnswIndexNode, knowhere::feature::MMAP,
QuantType::SQ8Refine)
} // namespace knowhere
9 changes: 1 addition & 8 deletions src/index/hnsw/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,14 +565,7 @@ class HnswIndexNode : public IndexNode {

std::string
Type() const override {
if constexpr (quant_type == QuantType::SQ8) {
return knowhere::IndexEnum::INDEX_HNSW_SQ8;
} else if constexpr (quant_type == QuantType::SQ8Refine) {
return knowhere::IndexEnum::INDEX_HNSW_SQ8_REFINE;

} else {
return knowhere::IndexEnum::INDEX_HNSW;
}
return knowhere::IndexEnum::INDEX_HNSW;
}

~HnswIndexNode() override {
Expand Down
11 changes: 4 additions & 7 deletions src/index/hnsw/impl/IndexBruteForceWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,10 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss:
// set up a filter
faiss::IDSelector* sel = (params == nullptr) ? nullptr : params->sel;

// sel is assumed to be non-null
if (sel == nullptr) {
throw;
}

// try knowhere-specific filter
const knowhere::BitsetViewIDSelector* __restrict bw_idselector =
dynamic_cast<const knowhere::BitsetViewIDSelector*>(sel);

BitsetViewIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view);

if (is_similarity_metric(index->metric_type)) {
using C = faiss::CMin<float, idx_t>;

Expand All @@ -88,6 +81,8 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss:
faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer, faiss::IDSelectorAll>(
index->ntotal, *dis, sel_all, k, local_distances, local_ids);
} else {
BitsetViewIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view);

faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer,
BitsetViewIDSelectorWrapper>(
index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids);
Expand All @@ -100,6 +95,8 @@ IndexBruteForceWrapper::search(faiss::idx_t n, const float* __restrict x, faiss:
faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer, faiss::IDSelectorAll>(
index->ntotal, *dis, sel_all, k, local_distances, local_ids);
} else {
BitsetViewIDSelectorWrapper bw_idselector_w(bw_idselector->bitset_view);

faiss::cppcontrib::knowhere::brute_force_search_impl<C, faiss::DistanceComputer,
BitsetViewIDSelectorWrapper>(
index->ntotal, *dis, bw_idselector_w, k, local_distances, local_ids);
Expand Down
5 changes: 2 additions & 3 deletions src/index/index_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,8 @@ IndexStaticFaced<DataType>::HasRawData(const IndexType& indexType, const IndexVe
}

static std::set<knowhere::IndexType> has_raw_data_index_set = {
IndexEnum::INDEX_FAISS_BIN_IDMAP, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, IndexEnum::INDEX_FAISS_IVFFLAT,
IndexEnum::INDEX_FAISS_IVFFLAT_CC, IndexEnum::INDEX_HNSW_SQ8_REFINE, IndexEnum::INDEX_SPARSE_INVERTED_INDEX,
IndexEnum::INDEX_SPARSE_WAND};
IndexEnum::INDEX_FAISS_BIN_IDMAP, IndexEnum::INDEX_FAISS_BIN_IVFFLAT, IndexEnum::INDEX_FAISS_IVFFLAT,
IndexEnum::INDEX_FAISS_IVFFLAT_CC, IndexEnum::INDEX_SPARSE_INVERTED_INDEX, IndexEnum::INDEX_SPARSE_WAND};

static std::set<knowhere::IndexType> has_raw_data_index_alias_set = {"IVFBIN", "BINFLAT", "IVFFLAT", "IVFFLATCC"};

Expand Down
18 changes: 9 additions & 9 deletions tests/ut/test_faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") {
// is parallelized on its own

SECTION("FLAT") {
const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_FLAT;
const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW;
const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP;

for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) {
Expand Down Expand Up @@ -539,7 +539,7 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") {
}

SECTION("SQ") {
const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_SQ;
const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW_SQ;
const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP;

for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) {
Expand Down Expand Up @@ -714,7 +714,7 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") {
}

SECTION("PQ") {
const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ;
const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW_PQ;
const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP;

for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) {
Expand Down Expand Up @@ -881,7 +881,7 @@ TEST_CASE("Search for FAISS HNSW Indices", "Benchmark and validation") {
}

SECTION("PRQ") {
const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PRQ;
const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW_PRQ;
const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP;

for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) {
Expand Down Expand Up @@ -1150,7 +1150,7 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra
// is parallelized on its own

SECTION("FLAT") {
const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_FLAT;
const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW;
// const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW;
const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP;

Expand Down Expand Up @@ -1246,7 +1246,7 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra
}

SECTION("SQ") {
const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_SQ;
const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW_SQ;
const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP;

for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) {
Expand Down Expand Up @@ -1435,7 +1435,7 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra
}

SECTION("PQ") {
const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PQ;
const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW_PQ;
const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP;

for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) {
Expand Down Expand Up @@ -1619,7 +1619,7 @@ TEST_CASE("RangeSearch for FAISS HNSW Indices", "Benchmark and validation for Ra
}

SECTION("PRQ") {
const std::string& index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_PRQ;
const std::string& index_type = knowhere::IndexEnum::INDEX_HNSW_PRQ;
const std::string& golden_index_type = knowhere::IndexEnum::INDEX_FAISS_IDMAP;

for (size_t distance_type = 0; distance_type < DISTANCE_TYPES.size(); distance_type++) {
Expand Down Expand Up @@ -1834,7 +1834,7 @@ TEST_CASE("hnswlib to FAISS HNSW for HNSW_FLAT", "Check search fallback") {

//
const std::string hnswlib_index_type = knowhere::IndexEnum::INDEX_HNSW;
const std::string faiss_index_type = knowhere::IndexEnum::INDEX_FAISS_HNSW_FLAT;
const std::string faiss_index_type = knowhere::IndexEnum::INDEX_HNSW;

//
const auto dim = DIM;
Expand Down
47 changes: 24 additions & 23 deletions tests/ut/test_feder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,29 +175,30 @@ TEST_CASE("Test Feder", "[feder]") {
const knowhere::Json conf = base_gen();
auto gt = knowhere::BruteForce::Search<knowhere::fp32>(train_ds, query_ds, conf, nullptr);

SECTION("Test HNSW Feder") {
auto name = knowhere::IndexEnum::INDEX_HNSW;
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
REQUIRE(idx.Type() == name);

auto json = hnsw_gen();
auto res = idx.Build(train_ds, json);
REQUIRE(res == knowhere::Status::success);

auto res1 = idx.GetIndexMeta(json);
REQUIRE(res1.has_value());
CheckHnswMeta(res1.value(), nb, json);

auto res2 = idx.Search(query_ds, json, nullptr);
REQUIRE(res2.has_value());
CheckHnswVisitInfo(res2.value(), nb);

json[knowhere::meta::RADIUS] = 160000;
json[knowhere::meta::RANGE_FILTER] = 0;
auto res3 = idx.RangeSearch(query_ds, json, nullptr);
REQUIRE(res3.has_value());
CheckHnswVisitInfo(res3.value(), nb);
}
// Feder is deprecated, and faiss hnsw does not implement it
// SECTION("Test HNSW Feder") {
// auto name = knowhere::IndexEnum::INDEX_HNSW;
// auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
// REQUIRE(idx.Type() == name);
//
// auto json = hnsw_gen();
// auto res = idx.Build(train_ds, json);
// REQUIRE(res == knowhere::Status::success);
//
// auto res1 = idx.GetIndexMeta(json);
// REQUIRE(res1.has_value());
// CheckHnswMeta(res1.value(), nb, json);
//
// auto res2 = idx.Search(query_ds, json, nullptr);
// REQUIRE(res2.has_value());
// CheckHnswVisitInfo(res2.value(), nb);
//
// json[knowhere::meta::RADIUS] = 160000;
// json[knowhere::meta::RANGE_FILTER] = 0;
// auto res3 = idx.RangeSearch(query_ds, json, nullptr);
// REQUIRE(res3.has_value());
// CheckHnswVisitInfo(res3.value(), nb);
// }

SECTION("Test IVF_FLAT Feder") {
auto name = knowhere::IndexEnum::INDEX_FAISS_IVFFLAT;
Expand Down
2 changes: 2 additions & 0 deletions tests/ut/test_get_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,9 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IDMAP, bin_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_BIN_IVFFLAT, bin_ivfflat_gen),
#ifdef KNOWHERE_WITH_CARDINAL
make_tuple(knowhere::IndexEnum::INDEX_HNSW, bin_hnsw_gen),
#endif
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::bin1>(name, version).value();
auto cfg_json = gen().dump();
Expand Down
Loading

0 comments on commit 6339e5f

Please sign in to comment.