Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Isolate faiss_hnsw and hnsw by index version #952

Merged
merged 1 commit into from
Nov 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 48 additions & 77 deletions src/index/hnsw/faiss_hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1381,20 +1381,33 @@ class BaseFaissRegularIndexHNSWFlatNodeTemplate : public BaseFaissRegularIndexHN
// but a deserialization may override its search behavior.
// It is a concrete implementation's responsibility to initialize BaseIndex and
// FallbackSearchIndex properly.
class IndexNodeWithSearchFallback : public IndexNode {
class HNSWIndexNodeWithFallback : public IndexNode {
public:
IndexNodeWithSearchFallback(const int32_t& version, const Object& object) {
use_base_index = true;
HNSWIndexNodeWithFallback(const int32_t& version, const Object& object) {
constexpr int faiss_hnsw_support_version = 6;
if (version >= faiss_hnsw_support_version) {
use_base_index = true;
} else {
use_base_index = false;
}
}

Status
Train(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
return base_index->Train(dataset, cfg);
if (use_base_index) {
return base_index->Train(dataset, cfg);
} else {
return fallback_search_index->Train(dataset, cfg);
}
}

Status
Add(const DataSetPtr dataset, std::shared_ptr<Config> cfg) override {
return base_index->Add(dataset, cfg);
if (use_base_index) {
return base_index->Add(dataset, cfg);
} else {
return fallback_search_index->Add(dataset, cfg);
}
}

expected<DataSetPtr>
Expand All @@ -1408,7 +1421,29 @@ class IndexNodeWithSearchFallback : public IndexNode {

Status
Serialize(BinarySet& binset) const override {
return base_index->Serialize(binset);
if (use_base_index) {
return base_index->Serialize(binset);
} else {
return fallback_search_index->Serialize(binset);
}
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
if (use_base_index) {
return base_index->Deserialize(binset, config);
} else {
return fallback_search_index->Deserialize(binset, config);
}
}

Status
DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> config) override {
if (use_base_index) {
return base_index->DeserializeFromFile(filename, config);
} else {
return fallback_search_index->DeserializeFromFile(filename, config);
}
}

int64_t
Expand Down Expand Up @@ -1440,7 +1475,11 @@ class IndexNodeWithSearchFallback : public IndexNode {

std::string
Type() const override {
return base_index->Type();
if (use_base_index) {
return base_index->Type();
} else {
return fallback_search_index->Type();
}
}

bool
Expand Down Expand Up @@ -1494,79 +1533,11 @@ class IndexNodeWithSearchFallback : public IndexNode {
std::unique_ptr<IndexNode> fallback_search_index;
};

class BaseFaissRegular2HnswlibIndexNode : public IndexNodeWithSearchFallback {
public:
BaseFaissRegular2HnswlibIndexNode(const int32_t& version, const Object& object)
: IndexNodeWithSearchFallback(version, object) {
}

Status
Deserialize(const BinarySet& binset, std::shared_ptr<Config> config) override {
// is the name for a base index?
BinaryPtr binary = binset.GetByName(base_index->Type());
if (binary != nullptr) {
auto base_status = base_index->Deserialize(binset, config);
if (base_status == Status::success) {
// switch to a base index
use_base_index = true;
}

if (base_status != Status::invalid_serialized_index_type) {
return base_status;
}

// we go ahead if base_index returned Status::invalid_serialized_index_type
}

// ok, try to deserialize as a fallback one
BinaryPtr binary_fallback = binset.GetByName(fallback_search_index->Type());
if (binary_fallback != nullptr) {
LOG_KNOWHERE_INFO_ << "The provided data does not look like a FAISS index. Falling back to hnswlib index.";
auto fallback_status = fallback_search_index->Deserialize(binset, config);
if (fallback_status == Status::success) {
// switch to a fallback index
use_base_index = false;
}

return fallback_status;
}

// unknown index
LOG_KNOWHERE_ERROR_ << "Invalid binary set.";
return Status::invalid_binary_set;
};

Status
DeserializeFromFile(const std::string& filename, std::shared_ptr<Config> config) override {
auto base_status = base_index->DeserializeFromFile(filename, config);
if (base_status == Status::success) {
// switch to a base index
use_base_index = true;
}

if (base_status != Status::invalid_serialized_index_type) {
return base_status;
}

// we go ahead if base_index returned Status::invalid_serialized_index_type

// ok, try to deserialize as a fallback one
LOG_KNOWHERE_INFO_ << "The provided data does not look like a FAISS index. Falling back to hnswlib index.";
auto fallback_status = fallback_search_index->DeserializeFromFile(filename, config);
if (fallback_status == Status::success) {
// switch to a fallback index
use_base_index = false;
}

return fallback_status;
}
};

template <typename DataType>
class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public BaseFaissRegular2HnswlibIndexNode {
class BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback : public HNSWIndexNodeWithFallback {
public:
BaseFaissRegularIndexHNSWFlatNodeTemplateWithSearchFallback(const int32_t& version, const Object& object)
: BaseFaissRegular2HnswlibIndexNode(version, object) {
: HNSWIndexNodeWithFallback(version, object) {
// initialize underlying nodes
base_index = std::make_unique<BaseFaissRegularIndexHNSWFlatNodeTemplate<DataType>>(version, object);
fallback_search_index = std::make_unique<HnswIndexNode<DataType, hnswlib::QuantType::None>>(version, object);
Expand Down
Loading