From 95be2bb703eebb7faad2ea63f395b248f87872b9 Mon Sep 17 00:00:00 2001 From: Min Tian Date: Mon, 16 Oct 2023 22:09:58 -0500 Subject: [PATCH] [pyknowhere] support dump and load index file (#145) review fix pre-commit Signed-off-by: min.tian --- .gitignore | 4 ++ python/knowhere/__init__.py | 11 +++++ python/knowhere/knowhere.i | 52 ++++++++++++++++++++++++ tests/python/test_index_load_and_save.py | 23 ++++++++--- tests/python/test_index_with_random.py | 2 +- 5 files changed, 86 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index adf2c2de5..8133e6e70 100644 --- a/.gitignore +++ b/.gitignore @@ -66,3 +66,7 @@ wheelhouse/* *.bin CMakeUserPresets.json + +# python +.eggs +*.egg-info diff --git a/python/knowhere/__init__.py b/python/knowhere/__init__.py index 05ad4ed0e..3b194a0c8 100644 --- a/python/knowhere/__init__.py +++ b/python/knowhere/__init__.py @@ -3,9 +3,11 @@ from .swigknowhere import GetBinarySet, GetNullDataSet, GetNullBitSetView import numpy as np + def CreateIndex(name, version): return swigknowhere.IndexWrap(name, version) + def GetCurrentVersion(): return swigknowhere.CurrentVersion() @@ -14,6 +16,14 @@ def CreateBitSet(bits_num): return swigknowhere.BitSet(bits_num) +def Load(binset, file_name): + return swigknowhere.Load(binset, file_name) + + +def Dump(binset, file_name): + return swigknowhere.Dump(binset, file_name) + + def ArrayToDataSet(arr): if arr.ndim == 1: return swigknowhere.Array2DataSetIds(arr) @@ -70,6 +80,7 @@ def RangeSearchDataSetToArray(ans): return dis_list, ids_list + def GetVectorDataSetToArray(ans): dim = swigknowhere.DataSet_Dim(ans) rows = swigknowhere.DataSet_Rows(ans) diff --git a/python/knowhere/knowhere.i b/python/knowhere/knowhere.i index 1c8a4dfcf..ee47cf328 100644 --- a/python/knowhere/knowhere.i +++ b/python/knowhere/knowhere.i @@ -31,6 +31,8 @@ typedef uint64_t size_t; #include #include #include +#include +#include using namespace knowhere; %} @@ -354,4 +356,54 @@ DumpRangeResultDis(knowhere::DataSetPtr result, float* dis, int len) { } } +void +Dump(knowhere::BinarySetPtr binset, const std::string& file_name) { + auto binary_set = *binset; + auto binary_map = binset -> binary_map_; + std::ofstream outfile; + outfile.open(file_name, std::ios::out | std::ios::trunc); + if (outfile.good()) { + for (auto it = binary_map.begin(); it != binary_map.end(); ++it) { + // serialization: name_length(size_t); name(char[]); binset_size(size_t); binset(uint8[]); + auto name = it->first; + uint64_t name_len = name.size(); + outfile << name_len; + outfile << name; + auto value = it->second; + outfile << value->size; + outfile.write(reinterpret_cast(value->data.get()), value->size); + } + // end with 0 + outfile << 0; + outfile.flush(); + } +} + +void +Load(knowhere::BinarySetPtr binset, const std::string& file_name) { + std::ifstream infile; + infile.open(file_name, std::ios::in); + if (infile.good()) { + uint64_t name_len; + while (true) { + // deserialization: name_length(size_t); name(char[]); binset_size(size_t); binset(uint8[]); + infile >> name_len; + if (name_len == 0) break; + + auto _name = new char[name_len]; + infile.read(_name, name_len); + std::string name(_name, name_len); + + int64_t size; + infile >> size; + if (size > 0) { + auto data = new uint8_t[size]; + std::shared_ptr data_ptr(data); + infile.read(reinterpret_cast(data_ptr.get()), size); + binset->Append(name, data_ptr, size); + } + } + } +} + %} diff --git a/tests/python/test_index_load_and_save.py b/tests/python/test_index_load_and_save.py index 9c6a56860..bc6e18568 100644 --- a/tests/python/test_index_load_and_save.py +++ b/tests/python/test_index_load_and_save.py @@ -1,6 +1,7 @@ import knowhere import json import pytest +import os test_data = [ ( @@ -12,28 +13,40 @@ }, ), ] +index_file = "test_index_load_and_save.index" + + @pytest.mark.parametrize("name,config", test_data) def test_save_and_load(gen_data, faiss_ans, recall, error, name, config): # simple load and save not work for ivf nm print(name, config) version = knowhere.GetCurrentVersion() build_idx = knowhere.CreateIndex(name, version) - xb, xq = gen_data(10000, 100, 256) + xb, xq = gen_data(10_000, 100, 256) + # build, serialize and dump build_idx.Build( knowhere.ArrayToDataSet(xb), json.dumps(config), ) binset = knowhere.GetBinarySet() build_idx.Serialize(binset) + knowhere.Dump(binset, index_file) + + # load and deserialize + new_binset = knowhere.GetBinarySet() + knowhere.Load(new_binset, index_file) search_idx = knowhere.CreateIndex(name, version) - search_idx.Deserialize(binset) + search_idx.Deserialize(new_binset) + + # test the loaded index ans, _ = search_idx.Search( - knowhere.ArrayToDataSet(xq), - json.dumps(config), - knowhere.GetNullBitSetView() + knowhere.ArrayToDataSet(xq), json.dumps(config), knowhere.GetNullBitSetView() ) k_dis, k_ids = knowhere.DataSetToArray(ans) f_dis, f_ids = faiss_ans(xb, xq, config["metric_type"], config["k"]) assert recall(f_ids, k_ids) >= 0.99 assert error(f_dis, f_dis) <= 0.01 + + # delete the index_file + os.remove(index_file) diff --git a/tests/python/test_index_with_random.py b/tests/python/test_index_with_random.py index e4492a8b3..13211ba9d 100644 --- a/tests/python/test_index_with_random.py +++ b/tests/python/test_index_with_random.py @@ -29,7 +29,7 @@ "metric_type": "L2", "nlist": 1024, "nprobe": 1024, - "ssize": 48 + "ssize": 48, }, ), (