diff --git a/include/knowhere/config.h b/include/knowhere/config.h index 94575ba08..11560368d 100644 --- a/include/knowhere/config.h +++ b/include/knowhere/config.h @@ -476,7 +476,11 @@ class Config { } } - return Status::success; + if (!err_msg) { + std::string tem_msg; + return cfg.CheckAndAdjust(type, &tem_msg); + } + return cfg.CheckAndAdjust(type, err_msg); } virtual ~Config() { @@ -485,6 +489,12 @@ class Config { using VarEntry = std::variant, Entry, Entry, Entry, Entry>; std::unordered_map __DICT__; + + protected: + inline virtual Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* const err_msg) { + return Status::success; + } }; #define KNOHWERE_DECLARE_CONFIG(CONFIG) CONFIG() @@ -557,26 +567,6 @@ class BaseConfig : public Config { .for_deserialize_from_file(); KNOWHERE_CONFIG_DECLARE_FIELD(for_tuning).set_default(false).description("for tuning").for_search(); } - - virtual Status - CheckAndAdjustForSearch(std::string* err_msg) { - return Status::success; - } - - virtual Status - CheckAndAdjustForRangeSearch(std::string* err_msg) { - return Status::success; - } - - virtual Status - CheckAndAdjustForIterator() { - return Status::success; - } - - virtual inline Status - CheckAndAdjustForBuild() { - return Status::success; - } }; } // namespace knowhere diff --git a/src/common/index.cc b/src/common/index.cc index 3d372ffe2..2c7de7e8a 100644 --- a/src/common/index.cc +++ b/src/common/index.cc @@ -37,7 +37,6 @@ inline Status Index::Build(const DataSet& dataset, const Json& json) { auto cfg = this->node->CreateConfig(); RETURN_IF_ERROR(LoadConfig(cfg.get(), json, knowhere::TRAIN, "Build")); - RETURN_IF_ERROR(cfg->CheckAndAdjustForBuild()); #ifdef NOT_COMPILE_FOR_SWIG TimeRecorder rc("Build index", 2); @@ -77,10 +76,6 @@ Index::Search(const DataSet& dataset, const Json& json, const BitsetView& bit if (load_status != Status::success) { return expected::Err(load_status, msg); } - const Status search_status = cfg->CheckAndAdjustForSearch(&msg); - if (search_status != Status::success) { - return expected::Err(search_status, msg); - } #ifdef NOT_COMPILE_FOR_SWIG TimeRecorder rc("Search"); @@ -105,10 +100,6 @@ Index::AnnIterator(const DataSet& dataset, const Json& json, const BitsetView if (status != Status::success) { return expected>>::Err(status, msg); } - status = cfg->CheckAndAdjustForIterator(); - if (status != Status::success) { - return expected>>::Err(status, "invalid params for iterator"); - } #ifdef NOT_COMPILE_FOR_SWIG // note that this time includes only the initial search phase of iterator. @@ -133,10 +124,6 @@ Index::RangeSearch(const DataSet& dataset, const Json& json, const BitsetView if (status != Status::success) { return expected::Err(status, std::move(msg)); } - status = cfg->CheckAndAdjustForRangeSearch(&msg); - if (status != Status::success) { - return expected::Err(status, std::move(msg)); - } #ifdef NOT_COMPILE_FOR_SWIG TimeRecorder rc("Range Search"); diff --git a/src/index/diskann/diskann_config.h b/src/index/diskann/diskann_config.h index 94638ce1b..23ef6b025 100644 --- a/src/index/diskann/diskann_config.h +++ b/src/index/diskann/diskann_config.h @@ -149,24 +149,28 @@ class DiskANNConfig : public BaseConfig { .for_search(); } - inline Status - CheckAndAdjustForSearch(std::string* err_msg) override { - if (!search_list_size.has_value()) { - search_list_size = std::max(k.value(), kSearchListSizeMinValue); - } else if (k.value() > search_list_size.value()) { - *err_msg = "search_list_size(" + std::to_string(search_list_size.value()) + ") should be larger than k(" + - std::to_string(k.value()) + ")"; - LOG_KNOWHERE_ERROR_ << *err_msg; - return Status::out_of_range_in_json; - } - - return Status::success; - } - - inline Status - CheckAndAdjustForBuild() override { - if (!search_list_size.has_value()) { - search_list_size = kDefaultSearchListSizeForBuild; + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + switch (param_type) { + case PARAM_TYPE::TRAIN: { + if (!search_list_size.has_value()) { + search_list_size = kDefaultSearchListSizeForBuild; + } + break; + } + case PARAM_TYPE::SEARCH: { + if (!search_list_size.has_value()) { + search_list_size = std::max(k.value(), kSearchListSizeMinValue); + } else if (k.value() > search_list_size.value()) { + *err_msg = "search_list_size(" + std::to_string(search_list_size.value()) + + ") should be larger than k(" + std::to_string(k.value()) + ")"; + LOG_KNOWHERE_ERROR_ << *err_msg; + return Status::out_of_range_in_json; + } + break; + } + default: + break; } return Status::success; } diff --git a/src/index/hnsw/hnsw_config.h b/src/index/hnsw/hnsw_config.h index aa06295ed..98d02b0ad 100644 --- a/src/index/hnsw/hnsw_config.h +++ b/src/index/hnsw/hnsw_config.h @@ -57,25 +57,29 @@ class HnswConfig : public BaseConfig { .for_feder(); } - inline Status - CheckAndAdjustForSearch(std::string* err_msg) override { - if (!ef.has_value()) { - ef = std::max(k.value(), kEfMinValue); - } else if (k.value() > ef.value()) { - *err_msg = - "ef(" + std::to_string(ef.value()) + ") should be larger than k(" + std::to_string(k.value()) + ")"; - LOG_KNOWHERE_ERROR_ << *err_msg; - return Status::out_of_range_in_json; - } - - return Status::success; - } - - inline Status - CheckAndAdjustForRangeSearch(std::string* err_msg) override { - if (!ef.has_value()) { - // if ef is not set by user, set it to default - ef = kDefaultRangeSearchEf; + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + switch (param_type) { + case PARAM_TYPE::SEARCH: { + if (!ef.has_value()) { + ef = std::max(k.value(), kEfMinValue); + } else if (k.value() > ef.value()) { + *err_msg = "ef(" + std::to_string(ef.value()) + ") should be larger than k(" + + std::to_string(k.value()) + ")"; + LOG_KNOWHERE_ERROR_ << *err_msg; + return Status::out_of_range_in_json; + } + break; + } + case PARAM_TYPE::RANGE_SEARCH: { + if (!ef.has_value()) { + // if ef is not set by user, set it to default + ef = kDefaultRangeSearchEf; + } + break; + } + default: + break; } return Status::success; } diff --git a/src/index/ivf/ivf_config.h b/src/index/ivf/ivf_config.h index fabd1d6a2..82da2cd54 100644 --- a/src/index/ivf/ivf_config.h +++ b/src/index/ivf/ivf_config.h @@ -79,41 +79,46 @@ class ScannConfig : public IvfFlatConfig { .for_train(); } - inline Status - CheckAndAdjustForSearch(std::string* err_msg) override { - if (!faiss::support_pq_fast_scan) { - *err_msg = "SCANN index is not supported on the current CPU model, avx2 support is needed for x86 arch."; - LOG_KNOWHERE_ERROR_ << *err_msg; - return Status::invalid_instruction_set; - } - if (!reorder_k.has_value()) { - reorder_k = k.value(); - } else if (reorder_k.value() < k.value()) { - *err_msg = "reorder_k(" + std::to_string(reorder_k.value()) + ") should be larger than k(" + - std::to_string(k.value()) + ")"; - LOG_KNOWHERE_ERROR_ << *err_msg; - return Status::out_of_range_in_json; - } - - return Status::success; - } - - inline Status - CheckAndAdjustForRangeSearch(std::string* err_msg) override { - if (!faiss::support_pq_fast_scan) { - *err_msg = "SCANN index is not supported on the current CPU model, avx2 support is needed for x86 arch."; - LOG_KNOWHERE_ERROR_ << *err_msg; - return Status::invalid_instruction_set; - } - return Status::success; - } - - inline Status - CheckAndAdjustForBuild() override { - if (!faiss::support_pq_fast_scan) { - LOG_KNOWHERE_ERROR_ - << "SCANN index is not supported on the current CPU model, avx2 support is needed for x86 arch."; - return Status::invalid_instruction_set; + Status + CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override { + switch (param_type) { + case PARAM_TYPE::TRAIN: { + if (!faiss::support_pq_fast_scan) { + LOG_KNOWHERE_ERROR_ << "SCANN index is not supported on the current CPU model, avx2 support is " + "needed for x86 arch."; + return Status::invalid_instruction_set; + } + break; + } + case PARAM_TYPE::SEARCH: { + if (!faiss::support_pq_fast_scan) { + LOG_KNOWHERE_ERROR_ << "SCANN index is not supported on the current CPU model, avx2 support is " + "needed for x86 arch."; + return Status::invalid_instruction_set; + } + if (!reorder_k.has_value()) { + reorder_k = k.value(); + } else if (reorder_k.value() < k.value()) { + if (!err_msg) { + err_msg = new std::string(); + } + *err_msg = "reorder_k(" + std::to_string(reorder_k.value()) + ") should be larger than k(" + + std::to_string(k.value()) + ")"; + LOG_KNOWHERE_ERROR_ << *err_msg; + return Status::out_of_range_in_json; + } + break; + } + case PARAM_TYPE::RANGE_SEARCH: { + if (!faiss::support_pq_fast_scan) { + LOG_KNOWHERE_ERROR_ << "SCANN index is not supported on the current CPU model, avx2 support is " + "needed for x86 arch."; + return Status::invalid_instruction_set; + } + break; + } + default: + break; } return Status::success; } diff --git a/tests/ut/test_config.cc b/tests/ut/test_config.cc index c6b9181d3..e7dfe70ef 100644 --- a/tests/ut/test_config.cc +++ b/tests/ut/test_config.cc @@ -175,7 +175,6 @@ TEST_CASE("Test config json parse", "[config]") { invalid_value_json = json; invalid_value_json["ef"] = 99; s = knowhere::Config::Load(wrong_cfg, invalid_value_json, knowhere::SEARCH); - s = wrong_cfg.CheckAndAdjustForSearch(&err_msg); CHECK(s == knowhere::Status::out_of_range_in_json); } @@ -189,7 +188,6 @@ TEST_CASE("Test config json parse", "[config]") { { knowhere::HnswConfig search_cfg; s = knowhere::Config::Load(search_cfg, json, knowhere::SEARCH); - s = search_cfg.CheckAndAdjustForSearch(&err_msg); CHECK(s == knowhere::Status::success); } @@ -198,7 +196,6 @@ TEST_CASE("Test config json parse", "[config]") { auto search_json = json; search_json.erase("ef"); s = knowhere::Config::Load(search_cfg, search_json, knowhere::SEARCH); - s = search_cfg.CheckAndAdjustForSearch(&err_msg); CHECK(s == knowhere::Status::success); CHECK_EQ(100, search_cfg.ef.value()); } @@ -209,7 +206,6 @@ TEST_CASE("Test config json parse", "[config]") { search_json.erase("ef"); search_json["k"] = 10; s = knowhere::Config::Load(search_cfg, search_json, knowhere::SEARCH); - s = search_cfg.CheckAndAdjustForSearch(&err_msg); CHECK(s == knowhere::Status::success); CHECK_EQ(16, search_cfg.ef.value()); } @@ -244,8 +240,6 @@ TEST_CASE("Test config json parse", "[config]") { knowhere::DiskANNConfig train_cfg; s = knowhere::Config::Load(train_cfg, json, knowhere::TRAIN); CHECK(s == knowhere::Status::success); - s = train_cfg.CheckAndAdjustForBuild(); - CHECK(s == knowhere::Status::success); CHECK_EQ(128, train_cfg.search_list_size.value()); CHECK_EQ("L2", train_cfg.metric_type.value()); } @@ -254,8 +248,6 @@ TEST_CASE("Test config json parse", "[config]") { knowhere::DiskANNConfig search_cfg; s = knowhere::Config::Load(search_cfg, json, knowhere::SEARCH); CHECK(s == knowhere::Status::success); - s = search_cfg.CheckAndAdjustForSearch(&err_msg); - CHECK(s == knowhere::Status::success); CHECK_EQ("L2", search_cfg.metric_type.value()); CHECK_EQ(100, search_cfg.k.value()); CHECK_EQ(100, search_cfg.search_list_size.value()); @@ -267,8 +259,6 @@ TEST_CASE("Test config json parse", "[config]") { search_json["k"] = 2; s = knowhere::Config::Load(search_cfg, search_json, knowhere::SEARCH); CHECK(s == knowhere::Status::success); - s = search_cfg.CheckAndAdjustForSearch(&err_msg); - CHECK(s == knowhere::Status::success); CHECK_EQ(16, search_cfg.search_list_size.value()); } @@ -277,7 +267,6 @@ TEST_CASE("Test config json parse", "[config]") { auto search_json = json; search_json["search_list_size"] = 99; s = knowhere::Config::Load(search_cfg, search_json, knowhere::SEARCH); - s = search_cfg.CheckAndAdjustForSearch(&err_msg); CHECK(s == knowhere::Status::out_of_range_in_json); }