Skip to content

Commit

Permalink
optimize config check error msg (#940)
Browse files Browse the repository at this point in the history
Signed-off-by: xianliang.li <[email protected]>
  • Loading branch information
foxspy authored Nov 14, 2024
1 parent 5935c1f commit 87747be
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 161 deletions.
66 changes: 26 additions & 40 deletions include/knowhere/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -344,12 +344,6 @@ class Config {

static Status
Load(Config& cfg, const Json& json, PARAM_TYPE type, std::string* const err_msg = nullptr) {
auto show_err_msg = [&](std::string& msg) {
LOG_KNOWHERE_ERROR_ << msg;
if (err_msg) {
*err_msg = msg;
}
};
for (const auto& it : cfg.__DICT__) {
const auto& var = it.second;

Expand All @@ -363,8 +357,7 @@ class Config {
continue;
}
std::string msg = "param '" + it.first + "' not exist in json";
show_err_msg(msg);
return Status::invalid_param_in_json;
return HandleError(err_msg, msg, Status::invalid_param_in_json);
} else {
*ptr->val = ptr->default_val;
continue;
Expand All @@ -373,16 +366,14 @@ class Config {
if (!json[it.first].is_number_integer()) {
std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) +
") should be integer";
show_err_msg(msg);
return Status::type_conflict_in_json;
return HandleError(err_msg, msg, Status::type_conflict_in_json);
}
if (ptr->range.has_value()) {
if (json[it.first].get<int64_t>() > std::numeric_limits<CFG_INT::value_type>::max()) {
std::string msg = "Arithmetic overflow: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should not bigger than " +
std::to_string(std::numeric_limits<CFG_INT::value_type>::max());
show_err_msg(msg);
return Status::arithmetic_overflow;
return HandleError(err_msg, msg, Status::arithmetic_overflow);
}
CFG_INT::value_type v = json[it.first];
auto range_val = ptr->range.value();
Expand All @@ -391,8 +382,7 @@ class Config {
} else {
std::string msg = "Out of range in json: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should be in range " + range_val.to_string();
show_err_msg(msg);
return Status::out_of_range_in_json;
return HandleError(err_msg, msg, Status::out_of_range_in_json);
}
} else {
*ptr->val = json[it.first];
Expand All @@ -409,8 +399,7 @@ class Config {
continue;
}
std::string msg = "param '" + it.first + "' not exist in json";
show_err_msg(msg);
return Status::invalid_param_in_json;
return HandleError(err_msg, msg, Status::invalid_param_in_json);
} else {
*ptr->val = ptr->default_val;
continue;
Expand All @@ -419,16 +408,14 @@ class Config {
if (!json[it.first].is_number_integer()) {
std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) +
") should be long integer";
show_err_msg(msg);
return Status::type_conflict_in_json;
return HandleError(err_msg, msg, Status::type_conflict_in_json);
}
if (ptr->range.has_value()) {
if (json[it.first].get<int64_t>() > std::numeric_limits<CFG_INT64::value_type>::max()) {
std::string msg = "Arithmetic overflow: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should not bigger than " +
std::to_string(std::numeric_limits<CFG_INT64::value_type>::max());
show_err_msg(msg);
return Status::arithmetic_overflow;
return HandleError(err_msg, msg, Status::arithmetic_overflow);
}
CFG_INT64::value_type v = json[it.first];
auto range_val = ptr->range.value();
Expand All @@ -437,8 +424,7 @@ class Config {
} else {
std::string msg = "Out of range in json: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should be in range " + range_val.to_string();
show_err_msg(msg);
return Status::out_of_range_in_json;
return HandleError(err_msg, msg, Status::out_of_range_in_json);
}
} else {
*ptr->val = json[it.first];
Expand All @@ -455,8 +441,7 @@ class Config {
continue;
}
std::string msg = "param '" + it.first + "' not exist in json";
show_err_msg(msg);
return Status::invalid_param_in_json;
return HandleError(err_msg, msg, Status::invalid_param_in_json);
} else {
*ptr->val = ptr->default_val;
continue;
Expand All @@ -465,16 +450,14 @@ class Config {
if (!json[it.first].is_number()) {
std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) +
") should be a number";
show_err_msg(msg);
return Status::type_conflict_in_json;
return HandleError(err_msg, msg, Status::type_conflict_in_json);
}
if (ptr->range.has_value()) {
if (json[it.first].get<double>() > std::numeric_limits<CFG_FLOAT::value_type>::max()) {
std::string msg = "Arithmetic overflow: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should not bigger than " +
std::to_string(std::numeric_limits<CFG_FLOAT::value_type>::max());
show_err_msg(msg);
return Status::arithmetic_overflow;
return HandleError(err_msg, msg, Status::arithmetic_overflow);
}
CFG_FLOAT::value_type v = json[it.first];
auto range_val = ptr->range.value();
Expand All @@ -483,8 +466,7 @@ class Config {
} else {
std::string msg = "Out of range in json: param '" + it.first + "' (" +
to_string(json[it.first]) + ") should be in range " + range_val.to_string();
show_err_msg(msg);
return Status::out_of_range_in_json;
return HandleError(err_msg, msg, Status::out_of_range_in_json);
}
} else {
*ptr->val = json[it.first];
Expand All @@ -501,8 +483,7 @@ class Config {
continue;
}
std::string msg = "param [" + it.first + "] not exist in json";
show_err_msg(msg);
return Status::invalid_param_in_json;
return HandleError(err_msg, msg, Status::invalid_param_in_json);
} else {
*ptr->val = ptr->default_val;
continue;
Expand All @@ -511,8 +492,7 @@ class Config {
if (!json[it.first].is_string()) {
std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) +
") should be a string";
show_err_msg(msg);
return Status::type_conflict_in_json;
return HandleError(err_msg, msg, Status::type_conflict_in_json);
}
*ptr->val = json[it.first];
}
Expand All @@ -527,8 +507,7 @@ class Config {
continue;
}
std::string msg = "param '" + it.first + "' not exist in json";
show_err_msg(msg);
return Status::invalid_param_in_json;
return HandleError(err_msg, msg, Status::invalid_param_in_json);
} else {
*ptr->val = ptr->default_val;
continue;
Expand All @@ -537,8 +516,7 @@ class Config {
if (!json[it.first].is_boolean()) {
std::string msg = "Type conflict in json: param '" + it.first + "' (" + to_string(json[it.first]) +
") should be a boolean";
show_err_msg(msg);
return Status::type_conflict_in_json;
return HandleError(err_msg, msg, Status::type_conflict_in_json);
}
*ptr->val = json[it.first];
}
Expand All @@ -554,8 +532,7 @@ class Config {
continue;
}
std::string msg = "param '" + it.first + "' not exist in json";
show_err_msg(msg);
return Status::invalid_param_in_json;
return HandleError(err_msg, msg, Status::invalid_param_in_json);
} else {
*ptr->val = ptr->default_val;
continue;
Expand Down Expand Up @@ -584,6 +561,15 @@ class Config {
CheckAndAdjust(PARAM_TYPE param_type, std::string* const err_msg) {
return Status::success;
}

static knowhere::Status
HandleError(std::string* error_msg, const std::string& msg, const knowhere::Status& status) {
if (error_msg) {
*error_msg = msg;
}
LOG_KNOWHERE_ERROR_ << msg;
return status;
}
};

#define KNOHWERE_DECLARE_CONFIG(CONFIG) CONFIG()
Expand Down
61 changes: 31 additions & 30 deletions src/common/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,18 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg
}
if (v < std::numeric_limits<CFG_INT::value_type>::min() ||
v > std::numeric_limits<CFG_INT::value_type>::max()) {
if (err_msg) {
*err_msg =
"integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
}
return knowhere::Status::invalid_value_in_json;
std::string msg =
"integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
return HandleError(err_msg, msg, Status::invalid_value_in_json);
}
json[key_str] = static_cast<CFG_INT::value_type>(v);
} catch (const std::out_of_range&) {
if (err_msg) {
*err_msg = "integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
}
return knowhere::Status::invalid_value_in_json;
std::string msg =
"integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
return HandleError(err_msg, msg, Status::invalid_value_in_json);
} catch (const std::invalid_argument&) {
KNOWHERE_THROW_MSG("invalid integer value, key: '" + key_str + "', value: '" + value_str + "'");
std::string msg = "invalid integer value, key: '" + key_str + "', value: '" + value_str + "'";
return HandleError(err_msg, msg, Status::invalid_value_in_json);
}
}
if (std::get_if<Entry<CFG_INT64>>(&var)) {
Expand All @@ -125,29 +123,36 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg
}
if (v < std::numeric_limits<CFG_INT64::value_type>::min() ||
v > std::numeric_limits<CFG_INT64::value_type>::max()) {
if (err_msg) {
*err_msg = "long integer value out of range, key: '" + key_str + "', value: '" +
value_str + "'";
}
return knowhere::Status::invalid_value_in_json;
std::string msg =
"long integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
return HandleError(err_msg, msg, Status::invalid_value_in_json);
}
json[key_str] = static_cast<CFG_INT64::value_type>(v);
} catch (const std::out_of_range&) {
if (err_msg) {
*err_msg =
"long integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
}
return knowhere::Status::invalid_value_in_json;
std::string msg =
"long integer value out of range, key: '" + key_str + "', value: '" + value_str + "'";
return HandleError(err_msg, msg, Status::invalid_value_in_json);
} catch (const std::invalid_argument&) {
KNOWHERE_THROW_MSG("invalid long integer value, key: '" + key_str + "', value: '" + value_str +
"'");
std::string msg =
"invalid long integer value, key: '" + key_str + "', value: '" + value_str + "'";
return HandleError(err_msg, msg, Status::invalid_value_in_json);
}
}
if (std::get_if<Entry<CFG_FLOAT>>(&var)) {
CFG_FLOAT::value_type v = std::stof(json[it.first].get<std::string>().c_str());
json[it.first] = v;
auto key_str = it.first;
auto value_str = json[key_str].get<std::string>();
try {
CFG_FLOAT::value_type v = std::stof(json[it.first].get<std::string>().c_str());
json[it.first] = v;
} catch (const std::out_of_range&) {
std::string msg =
"float value out of range, key: '" + key_str + "', value: '" + value_str + "'";
return HandleError(err_msg, msg, Status::invalid_value_in_json);
} catch (const std::invalid_argument&) {
std::string msg = "invalid float value, key: '" + key_str + "', value: '" + value_str + "'";
return HandleError(err_msg, msg, Status::invalid_value_in_json);
}
}

if (std::get_if<Entry<CFG_BOOL>>(&var)) {
if (json[it.first] == "true") {
json[it.first] = true;
Expand All @@ -159,11 +164,7 @@ Config::FormatAndCheck(const Config& cfg, Json& json, std::string* const err_msg
}
}
} catch (std::exception& e) {
LOG_KNOWHERE_ERROR_ << e.what();
if (err_msg) {
*err_msg = e.what();
}
return Status::invalid_value_in_json;
return HandleError(err_msg, e.what(), Status::invalid_value_in_json);
}
return Status::success;
}
Expand Down
9 changes: 3 additions & 6 deletions src/index/diskann/diskann_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,9 @@ class DiskANNConfig : public BaseConfig {
if (!search_list_size.has_value()) {
search_list_size = std::max(k.value(), kSearchListSizeMinValue);
} else if (k.value() > search_list_size.value()) {
if (err_msg) {
*err_msg = "search_list_size(" + std::to_string(search_list_size.value()) +
") should be larger than k(" + std::to_string(k.value()) + ")";
}
LOG_KNOWHERE_ERROR_ << *err_msg;
return Status::out_of_range_in_json;
std::string msg = "search_list_size(" + std::to_string(search_list_size.value()) +
") should be larger than k(" + std::to_string(k.value()) + ")";
return HandleError(err_msg, msg, Status::out_of_range_in_json);
}
break;
}
Expand Down
6 changes: 2 additions & 4 deletions src/index/gpu_raft/gpu_raft_brute_force_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ struct GpuRaftBruteForceConfig : public BaseConfig {
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
}
return Status::invalid_metric_type;
std::string msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
return HandleError(err_msg, msg, Status::invalid_metric_type);
}
}
return Status::success;
Expand Down
13 changes: 4 additions & 9 deletions src/index/gpu_raft/gpu_raft_cagra_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,8 @@ struct GpuRaftCagraConfig : public BaseConfig {
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP COSINE]";
}
return Status::invalid_metric_type;
std::string msg = "metric type " + metric + " not found or not supported, supported: [L2 IP COSINE]";
return HandleError(err_msg, msg, Status::invalid_metric_type);
}
}

Expand All @@ -140,11 +138,8 @@ struct GpuRaftCagraConfig : public BaseConfig {

if (search_width.has_value()) {
if (std::max(itopk_size.value(), kAlignFactor * search_width.value()) < k.value()) {
if (err_msg) {
*err_msg = "max((itopk_size + 31)// 32, search_width) * 32< topk";
LOG_KNOWHERE_ERROR_ << *err_msg;
}
return Status::out_of_range_in_json;
std::string msg = "max((itopk_size + 31)// 32, search_width) * 32< topk";
return HandleError(err_msg, msg, Status::out_of_range_in_json);
}
} else {
search_width = std::max((k.value() - 1) / kAlignFactor + 1, kSearchWidth);
Expand Down
Loading

0 comments on commit 87747be

Please sign in to comment.