Skip to content

Commit

Permalink
fix_hnsw_pq_iterator (#932)
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandr Guzhva <[email protected]>
  • Loading branch information
alexanderguzhva authored Nov 9, 2024
1 parent 9e6af18 commit 27dea7f
Show file tree
Hide file tree
Showing 7 changed files with 116 additions and 50 deletions.
63 changes: 38 additions & 25 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -551,24 +551,32 @@ class FaissHnswIterator : public IndexIterator {
// wrap a sign, if needed
workspace.qdis = std::unique_ptr<faiss::DistanceComputer>(storage_distance_computer(index_hnsw));

// a tricky point here.
// Basically, if out hnsw index's storage is HasInverseL2Norms, then
// this is a cosine index. But because refine always keeps original
// data, then we need to use a wrapper over a distance computer
const faiss::HasInverseL2Norms* has_l2_norms =
dynamic_cast<const faiss::HasInverseL2Norms*>(index_hnsw->storage);
if (has_l2_norms != nullptr) {
// add a cosine wrapper over it
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(new faiss::WithCosineNormDistanceComputer(
has_l2_norms->get_inverse_l2_norms(), index->d,
std::unique_ptr<faiss::DistanceComputer>(index_refine->refine_index->get_distance_computer())));
if (refine_ratio != 0) {
// the refine is needed

// a tricky point here.
// Basically, if out hnsw index's storage is HasInverseL2Norms, then
// this is a cosine index. But because refine always keeps original
// data, then we need to use a wrapper over a distance computer
const faiss::HasInverseL2Norms* has_l2_norms =
dynamic_cast<const faiss::HasInverseL2Norms*>(index_hnsw->storage);
if (has_l2_norms != nullptr) {
// add a cosine wrapper over it
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(new faiss::WithCosineNormDistanceComputer(
has_l2_norms->get_inverse_l2_norms(), index->d,
std::unique_ptr<faiss::DistanceComputer>(
index_refine->refine_index->get_distance_computer())));
} else {
// use it as is
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(index_refine->refine_index->get_distance_computer());
}
} else {
// use it as is
// DO NOT WRAP A SIGN, by design
workspace.qdis_refine =
std::unique_ptr<faiss::DistanceComputer>(index_refine->refine_index->get_distance_computer());
// the refine is not needed
workspace.qdis_refine = nullptr;
}
} else {
const faiss::IndexHNSW* index_hnsw = dynamic_cast<const faiss::IndexHNSW*>(index.get());
Expand Down Expand Up @@ -882,9 +890,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
return expected<DataSetPtr>::Err(Status::invalid_args, "k parameter is missing");
}

// whether a user wants a refine
const bool whether_to_enable_refine = hnsw_cfg.refine_k.has_value();

// set up an index wrapper
auto [index_wrapper, is_refined] =
create_conditional_hnsw_wrapper(index.get(), hnsw_cfg, whether_bf_search.value_or(false));
auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper(
index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine);

if (index_wrapper == nullptr) {
return expected<DataSetPtr>::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW");
Expand Down Expand Up @@ -1011,9 +1022,12 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
return expected<DataSetPtr>::Err(Status::invalid_args, "ef parameter is missing");
}

// whether a user wants a refine
const bool whether_to_enable_refine = true;

// set up an index wrapper
auto [index_wrapper, is_refined] =
create_conditional_hnsw_wrapper(index.get(), hnsw_cfg, whether_bf_search.value_or(false));
auto [index_wrapper, is_refined] = create_conditional_hnsw_wrapper(
index.get(), hnsw_cfg, whether_bf_search.value_or(false), whether_to_enable_refine);

if (index_wrapper == nullptr) {
return expected<DataSetPtr>::Err(Status::invalid_args, "an input index seems to be unrelated to HNSW");
Expand Down Expand Up @@ -1229,11 +1243,10 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode {
throw;
}

//
const bool should_use_refine = (dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr);

const float iterator_refine_ratio =
(dynamic_cast<const faiss::IndexRefine*>(index.get()) != nullptr)
? hnsw_cfg.iterator_refine_ratio.value_or(0.5)
: 0;
should_use_refine ? hnsw_cfg.iterator_refine_ratio.value_or(0.5) : 0;

// create an iterator and initialize it
auto it =
Expand Down
8 changes: 6 additions & 2 deletions src/index/hnsw/faiss_hnsw_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class FaissHnswConfig : public BaseHnswConfig {
.for_static();
KNOWHERE_CONFIG_DECLARE_FIELD(refine_k)
.description("refine k")
.allow_empty_without_default()
.set_default(1)
.set_range(1, std::numeric_limits<CFG_FLOAT::value_type>::max())
.for_search();
KNOWHERE_CONFIG_DECLARE_FIELD(refine_type)
Expand Down Expand Up @@ -83,7 +83,7 @@ class FaissHnswFlatConfig : public FaissHnswConfig {
// check our parameters
if (param_type == PARAM_TYPE::TRAIN) {
// prohibit refine
if (refine.value_or(false) || refine_type.has_value() || refine_k.has_value()) {
if (refine.value_or(false) || refine_type.has_value()) {
if (err_msg) {
*err_msg = "refine is not supported for this index";
LOG_KNOWHERE_ERROR_ << *err_msg;
Expand Down Expand Up @@ -189,6 +189,8 @@ class FaissHnswPqConfig : public FaissHnswConfig {
}
}
}
default:
break;
}
return Status::success;
}
Expand Down Expand Up @@ -232,6 +234,8 @@ class FaissHnswPrqConfig : public FaissHnswConfig {
}
}
}
default:
break;
}
return Status::success;
}
Expand Down
10 changes: 7 additions & 3 deletions src/index/hnsw/impl/IndexConditionalWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,13 @@ WhetherPerformBruteForceRangeSearch(const faiss::Index* index, const FaissHnswCo
return false;
}

// returns nullptr in case of invalid index
// returns nullptr in case of invalid index.
//
// `whether_to_enable_refine` allows to enable the refine for the search if the
// index was trained with the refine.
std::tuple<std::unique_ptr<faiss::Index>, bool>
create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search) {
create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search,
const bool whether_to_enable_refine) {
const bool is_cosine = IsMetricType(hnsw_cfg.metric_type.value(), knowhere::metric::COSINE);

// check if we have a refine available.
Expand Down Expand Up @@ -126,7 +130,7 @@ create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw
}

// check if a user wants a refined result
if (hnsw_cfg.refine_k.has_value()) {
if (whether_to_enable_refine) {
// yes, a user wants to perform a refine

// thus, we need to define a new refine index and pass
Expand Down
8 changes: 6 additions & 2 deletions src/index/hnsw/impl/IndexConditionalWrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,12 @@ std::optional<bool>
WhetherPerformBruteForceRangeSearch(const faiss::Index* index, const FaissHnswConfig& cfg, const BitsetView& bitset);

// first return arg: returns nullptr in case of invalid index
// second return arg: returns whether an index does refine
// second return arg: returns whether an index does the refine
//
// `whether_to_enable_refine` allows to enable the refine for the search if the
// index was trained with the refine.
std::tuple<std::unique_ptr<faiss::Index>, bool>
create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search);
create_conditional_hnsw_wrapper(faiss::Index* index, const FaissHnswConfig& hnsw_cfg, const bool whether_bf_search,
const bool whether_to_enable_refine);

} // namespace knowhere
29 changes: 11 additions & 18 deletions tests/ut/test_iterator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,21 +197,18 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search using iterator") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
Expand Down Expand Up @@ -286,20 +283,18 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
Expand Down Expand Up @@ -334,20 +329,18 @@ TEST_CASE("Test Iterator Mem Index With Float Vector", "[float metrics]") {
SECTION("Test Search with Bitset using iterator insufficient results") {
using std::make_tuple;
auto [name, gen] = GENERATE_REF(table<std::string, std::function<knowhere::Json()>>({
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFFLAT_CC, ivfflatcc_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ8, ivf_base_gen),
make_tuple(knowhere::IndexEnum::INDEX_FAISS_IVFSQ_CC, ivf_sq_cc_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW, hnsw_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_SQ, hnsw_sq_refine_fp16_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_flat_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PQ, hnsw_pq_refine_sq8_gen),
// make_tuple(knowhere::IndexEnum::INDEX_HNSW_PRQ, hnsw_prq_gen),
}));
auto idx = knowhere::IndexFactory::Instance().Create<knowhere::fp32>(name, version).value();
auto cfg_json = gen().dump();
Expand Down
41 changes: 41 additions & 0 deletions thirdparty/faiss/faiss/IndexRefine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,47 @@ void IndexRefine::search(
}
}

void IndexRefine::range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params_in) const
{
const IndexRefineSearchParameters* params = nullptr;
if (params_in) {
params = dynamic_cast<const IndexRefineSearchParameters*>(params_in);
FAISS_THROW_IF_NOT_MSG(
params, "IndexRefine params have incorrect type");
}

SearchParameters* base_index_params =
(params != nullptr) ? params->base_index_params : nullptr;

base_index->range_search(
n, x, radius, result, base_index_params);

#pragma omp parallel if (n > 1)
{
std::unique_ptr<DistanceComputer> dc(
refine_index->get_distance_computer());

#pragma omp for
for (idx_t i = 0; i < n; i++) {
dc->set_query(x + i * d);

// reevaluate distances
const size_t idx_start = result->lims[i];
const size_t idx_end = result->lims[i + 1];

for (size_t j = idx_start; j < idx_end; j++) {
const auto label = result->labels[j];
result->distances[j] = (*dc)(label);
}
}
}
}

void IndexRefine::reconstruct(idx_t key, float* recons) const {
refine_index->reconstruct(key, recons);
}
Expand Down
7 changes: 7 additions & 0 deletions thirdparty/faiss/faiss/IndexRefine.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,13 @@ struct IndexRefine : Index {
idx_t* labels,
const SearchParameters* params = nullptr) const override;

void range_search(
idx_t n,
const float* x,
float radius,
RangeSearchResult* result,
const SearchParameters* params = nullptr) const override;

// reconstruct is routed to the refine_index
void reconstruct(idx_t key, float* recons) const override;

Expand Down

0 comments on commit 27dea7f

Please sign in to comment.