diff --git a/.gitignore b/.gitignore index 8ff170e28..2fa389c97 100644 --- a/.gitignore +++ b/.gitignore @@ -69,4 +69,4 @@ CMakeUserPresets.json # python .eggs -*.egg-info \ No newline at end of file +*.egg-info diff --git a/python/knowhere/knowhere.i b/python/knowhere/knowhere.i index 57820a68a..158f56e39 100644 --- a/python/knowhere/knowhere.i +++ b/python/knowhere/knowhere.i @@ -362,45 +362,48 @@ Dump(knowhere::BinarySetPtr binset, const std::string& file_name) { auto binary_map = binset -> binary_map_; std::ofstream outfile; outfile.open(file_name, std::ios::out | std::ios::trunc); - for (auto it = binary_map.begin(); it != binary_map.end(); ++it) { - // serialization: name_length(size_t); name(char[]); binset_size(size_t); binset(uint8[]); - auto name = it->first; - outfile << name.size(); - outfile << name; - auto value = it->second; - outfile << value->size; - outfile.write(reinterpret_cast(value->data.get()), value->size); + if (outfile.good()) { + for (auto it = binary_map.begin(); it != binary_map.end(); ++it) { + // serialization: name_length(size_t); name(char[]); binset_size(size_t); binset(uint8[]); + auto name = it->first; + uint64_t name_len = name.size(); + outfile << name_len; + outfile << name; + auto value = it->second; + outfile << value->size; + outfile.write(reinterpret_cast(value->data.get()), value->size); + } + // end with 0 + outfile << 0; + outfile.flush(); } - // end with 0 - outfile << 0; - outfile.close(); } void Load(knowhere::BinarySetPtr binset, const std::string& file_name) { std::ifstream infile; infile.open(file_name, std::ios::in); - size_t name_len; - char flag; - while (true) { - // serialization: name_length(size_t); name(char[]); binset_size(size_t); binset(uint8[]); - infile >> name_len; - if (name_len == 0) break; - - auto _name = new char[name_len]; - infile.read(_name, name_len); - std::string name(_name, name_len); - - int64_t size; - infile >> size; - - auto data = new uint8_t[size]; - infile.read(reinterpret_cast(data), size); - std::shared_ptr data_ptr(data); - - binset->Append(name, data_ptr, size); + if (infile.good()) { + uint64_t name_len; + while (true) { + // deserialization: name_length(size_t); name(char[]); binset_size(size_t); binset(uint8[]); + infile >> name_len; + if (name_len == 0) break; + + auto _name = new char[name_len]; + infile.read(_name, name_len); + std::string name(_name, name_len); + + int64_t size; + infile >> size; + if (size > 0) { + auto data = new uint8_t[size]; + std::shared_ptr data_ptr(data); + infile.read(reinterpret_cast(data_ptr.get()), size); + binset->Append(name, data_ptr, size); + } + } } - infile.close(); } %} diff --git a/tests/python/test_index_load_and_save.py b/tests/python/test_index_load_and_save.py index f0f03c319..bc6e18568 100644 --- a/tests/python/test_index_load_and_save.py +++ b/tests/python/test_index_load_and_save.py @@ -22,7 +22,7 @@ def test_save_and_load(gen_data, faiss_ans, recall, error, name, config): print(name, config) version = knowhere.GetCurrentVersion() build_idx = knowhere.CreateIndex(name, version) - xb, xq = gen_data(10000, 100, 256) + xb, xq = gen_data(10_000, 100, 256) # build, serialize and dump build_idx.Build( @@ -39,7 +39,7 @@ def test_save_and_load(gen_data, faiss_ans, recall, error, name, config): search_idx = knowhere.CreateIndex(name, version) search_idx.Deserialize(new_binset) - # teset the loaded index + # test the loaded index ans, _ = search_idx.Search( knowhere.ArrayToDataSet(xq), json.dumps(config), knowhere.GetNullBitSetView() )