Skip to content

Commit

Permalink
fix_hnsw_pq_iterator
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandr Guzhva <[email protected]>
  • Loading branch information
alexanderguzhva committed Nov 7, 2024
1 parent 9e6af18 commit 9de27c4
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 36 deletions.
49 changes: 28 additions & 21 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -551,24 +551,32 @@ class FaissHnswIterator : public IndexIterator {
// wrap a sign, if needed
workspace.qdis = std::unique_ptr<faiss::DistanceComputer>(storage_distance_computer(index_hnsw));

// a tricky point here.
// Basically, if our hnsw index's storage is HasInverseL2Norms, then
// this is a cosine index. But because refine always keeps the original
// data, we need to use a wrapper over a distance computer
const faiss::HasInverseL2Norms* has_l2_norms =
dynamic_cast<const faiss::HasInverseL2Norms*>(index_hnsw->storage);
if (has_l2_norms != nullptr) {
// add a cosine wrapper over it
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(new faiss::WithCosineNormDistanceComputer(
has_l2_norms->get_inverse_l2_norms(), index->d,
std::unique_ptr<faiss::DistanceComputer>(index_refine->refine_index->get_distance_computer())));
if (refine_ratio != 0) {
// the refine is needed

// a tricky point here.
// Basically, if our hnsw index's storage is HasInverseL2Norms, then
// this is a cosine index. But because refine always keeps the original
// data, we need to use a wrapper over a distance computer
const faiss::HasInverseL2Norms* has_l2_norms =
dynamic_cast<const faiss::HasInverseL2Norms*>(index_hnsw->storage);
if (has_l2_norms != nullptr) {
// add a cosine wrapper over it
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(new faiss::WithCosineNormDistanceComputer(
has_l2_norms->get_inverse_l2_norms(), index->d,
std::unique_ptr<faiss::DistanceComputer>(
index_refine->refine_index->get_distance_computer())));
} else {
// use it as is
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(index_refine->refine_index->get_distance_computer());
}
} else {
// use it as is
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(index_refine->refine_index->get_distance_computer());
// the refine is not needed
workspace.qdis_refine = nullptr;
}
} else {
const faiss::IndexHNSW* index_hnsw = dynamic_cast<const faiss::IndexHNSW*>(index.get());
Expand Down Expand Up @@ -1229,11 +1237,10 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
throw;
}

//
const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr);

const float iterator_refine_ratio =
(dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr)
? hnsw_cfg.iterator_refine_ratio.value_or(0.5)
: 0;
should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0;

// create an iterator and initialize it
auto it =
Expand Down
11 changes: 8 additions & 3 deletions src/index/hnsw/faiss_hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,10 @@ class FaissHnswConfig : public BaseHnswConfig {
.for_static();
KNOWHERE_CONFIG_DECLARE_FIELD(refine_k)
.description("refine k")
.allow_empty_without_default()
.set_default(1)
.set_range(1, std::numeric_limits<CFG_FLOAT::value_type>::max())
.for_search();
.for_search()
.for_range_search();
KNOWHERE_CONFIG_DECLARE_FIELD(refine_type)
.description("the type of a refine index")
.allow_empty_without_default()
Expand Down Expand Up @@ -83,7 +84,7 @@ class FaissHnswFlatConfig : public FaissHnswConfig {
// check our parameters
if (param_type == PARAM_TYPE::TRAIN) {
// prohibit refine
if (refine.value_or(false) || refine_type.has_value() || refine_k.has_value()) {
if (refine.value_or(false) || refine_type.has_value()) {
if (err_msg) {
*err_msg = "refine is not supported for this index";
LOG_KNOWHERE_ERROR_ << *err_msg;
Expand Down Expand Up @@ -189,6 +190,8 @@ class FaissHnswPqConfig : public FaissHnswConfig {
}
}
}
default:
break;
}
return Status::success;
}
Expand Down Expand Up @@ -232,6 +235,8 @@ class FaissHnswPrqConfig : public FaissHnswConfig {
}
}
}
default:
break;
}
return Status::success;
}
Expand Down
17 changes: 5 additions & 12 deletions tests/ut/test_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,6 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search using iterator") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
Expand All @@ -209,8 +206,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
Expand Down Expand Up @@ -286,8 +283,6 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
Expand All @@ -297,8 +292,8 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
Expand Down Expand Up @@ -334,8 +329,6 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator insufficient results") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
Expand All @@ -346,7 +339,7 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
Expand Down

0 comments on commit 9de27c4

Please sign in to comment.