Skip to content

Commit

Permalink
[pyknowhere] support dump and load index file (#145)
Browse files Browse the repository at this point in the history
review



fix pre-commit

Signed-off-by: min.tian <[email protected]>
  • Loading branch information
alwayslove2013 authored Oct 17, 2023
1 parent 67bb622 commit 95be2bb
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 6 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,7 @@ wheelhouse/*
*.bin

CMakeUserPresets.json

# python
.eggs
*.egg-info
11 changes: 11 additions & 0 deletions python/knowhere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from .swigknowhere import GetBinarySet, GetNullDataSet, GetNullBitSetView
import numpy as np


def CreateIndex(name, version):
    """Return ``swigknowhere.IndexWrap(name, version)``.

    Both arguments are passed through unchanged to the SWIG-generated
    ``IndexWrap`` constructor.
    """
    index = swigknowhere.IndexWrap(name, version)
    return index


def GetCurrentVersion():
    """Return whatever ``swigknowhere.CurrentVersion()`` reports."""
    current = swigknowhere.CurrentVersion()
    return current

Expand All @@ -14,6 +16,14 @@ def CreateBitSet(bits_num):
return swigknowhere.BitSet(bits_num)


def Load(binset, file_name):
    """Forward ``binset`` and ``file_name`` to ``swigknowhere.Load``.

    The result of the underlying call is returned as-is.
    """
    result = swigknowhere.Load(binset, file_name)
    return result


def Dump(binset, file_name):
    """Forward ``binset`` and ``file_name`` to ``swigknowhere.Dump``.

    The result of the underlying call is returned as-is.
    """
    result = swigknowhere.Dump(binset, file_name)
    return result


def ArrayToDataSet(arr):
if arr.ndim == 1:
return swigknowhere.Array2DataSetIds(arr)
Expand Down Expand Up @@ -70,6 +80,7 @@ def RangeSearchDataSetToArray(ans):

return dis_list, ids_list


def GetVectorDataSetToArray(ans):
dim = swigknowhere.DataSet_Dim(ans)
rows = swigknowhere.DataSet_Rows(ans)
Expand Down
52 changes: 52 additions & 0 deletions python/knowhere/knowhere.i
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ typedef uint64_t size_t;
#include <knowhere/version.h>
#include <knowhere/utils.h>
#include <knowhere/comp/local_file_manager.h>
#include <fstream>
#include <string>
using namespace knowhere;
%}

Expand Down Expand Up @@ -354,4 +356,54 @@ DumpRangeResultDis(knowhere::DataSetPtr result, float* dis, int len) {
}
}

// Dump / Load: persist a BinarySet to a single file and read it back.
//
// On-disk layout (every integer is a raw, fixed-width field in host byte order):
//
//   for each entry:  name_len (uint64_t) | name bytes | data_size (int64_t) | data bytes
//   terminator:      name_len == 0
//
// The length fields are written as raw bytes rather than with operator<<:
// a text number placed directly in front of a name or payload that starts
// with an ASCII digit cannot be split apart again by the reader. Both
// functions also open the file in binary mode so payload bytes are never
// subject to newline translation.
//
// Neither function reports errors to the caller (they return void). A file
// that cannot be opened is ignored, and Load stops at the first truncated or
// implausible record, keeping the entries it has already appended.

// Writes every entry of `binset` to `file_name`, replacing any existing file.
// Entries with an empty name are skipped because a zero name length is the
// end-of-file marker; entries with a null Binary are skipped as well.
void
Dump(knowhere::BinarySetPtr binset, const std::string& file_name) {
    if (binset == nullptr) {
        return;
    }
    std::ofstream outfile(file_name, std::ios::out | std::ios::binary | std::ios::trunc);
    if (!outfile.good()) {
        return;
    }

    // Iterate the map in place; no copy of the set or of its map is needed.
    for (const auto& entry : binset->binary_map_) {
        const std::string& name = entry.first;
        const auto& value = entry.second;
        if (value == nullptr || name.empty()) {
            continue;
        }

        const uint64_t name_len = name.size();
        const int64_t data_size = value->size;
        outfile.write(reinterpret_cast<const char*>(&name_len), sizeof(name_len));
        outfile.write(name.data(), static_cast<std::streamsize>(name_len));
        outfile.write(reinterpret_cast<const char*>(&data_size), sizeof(data_size));
        if (data_size > 0) {
            outfile.write(reinterpret_cast<const char*>(value->data.get()),
                          static_cast<std::streamsize>(data_size));
        }
    }

    // end with 0
    const uint64_t terminator = 0;
    outfile.write(reinterpret_cast<const char*>(&terminator), sizeof(terminator));
    outfile.flush();
}

// Reads the entries stored in `file_name` (as written by Dump above) and
// appends each one with a positive size to `binset`.
void
Load(knowhere::BinarySetPtr binset, const std::string& file_name) {
    if (binset == nullptr) {
        return;
    }
    std::ifstream infile(file_name, std::ios::in | std::ios::binary);
    if (!infile.good()) {
        return;
    }

    // Remember the file size so a corrupt length field can be rejected before
    // it is used as an allocation size.
    infile.seekg(0, std::ios::end);
    const std::streamoff file_size = infile.tellg();
    infile.seekg(0, std::ios::beg);
    if (file_size < 0 || !infile.good()) {
        return;
    }
    const auto fits_in_rest_of_file = [&infile, file_size](uint64_t len) {
        const std::streamoff pos = infile.tellg();
        return pos >= 0 && pos <= file_size && len <= static_cast<uint64_t>(file_size - pos);
    };

    while (true) {
        uint64_t name_len = 0;
        if (!infile.read(reinterpret_cast<char*>(&name_len), sizeof(name_len)) || name_len == 0) {
            break;  // terminator, EOF or read error
        }
        if (!fits_in_rest_of_file(name_len)) {
            break;
        }
        // std::string owns the buffer, so nothing leaks on any exit path.
        std::string name(name_len, '\0');
        if (!infile.read(&name[0], static_cast<std::streamsize>(name_len))) {
            break;
        }

        int64_t size = 0;
        if (!infile.read(reinterpret_cast<char*>(&size), sizeof(size)) || size < 0) {
            break;
        }
        if (size == 0) {
            continue;  // empty payloads are not appended
        }
        if (!fits_in_rest_of_file(static_cast<uint64_t>(size))) {
            break;
        }
        std::shared_ptr<uint8_t[]> data_ptr(new uint8_t[size]);
        if (!infile.read(reinterpret_cast<char*>(data_ptr.get()), static_cast<std::streamsize>(size))) {
            break;
        }
        binset->Append(name, data_ptr, size);
    }
}

%}
23 changes: 18 additions & 5 deletions tests/python/test_index_load_and_save.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import knowhere
import json
import pytest
import os

test_data = [
(
Expand All @@ -12,28 +13,40 @@
},
),
]
# Scratch file that test_save_and_load writes via knowhere.Dump, reads back via
# knowhere.Load, and removes once the test is done.
index_file = "test_index_load_and_save.index"


@pytest.mark.parametrize("name,config", test_data)
def test_save_and_load(gen_data, faiss_ans, recall, error, name, config):
    """Round-trip an index through a file on disk and verify it still searches.

    The index is built and serialized into a binary set, the binary set is
    written with ``knowhere.Dump`` and read back with ``knowhere.Load`` into a
    fresh binary set, and a second index is deserialized from that copy. The
    reloaded index's search results are then compared with the ``faiss_ans``
    reference for both ids (recall) and distances (error).
    """
    # simple load and save does not work for ivf nm
    print(name, config)
    version = knowhere.GetCurrentVersion()
    build_idx = knowhere.CreateIndex(name, version)
    xb, xq = gen_data(10_000, 100, 256)

    # build and serialize
    build_idx.Build(
        knowhere.ArrayToDataSet(xb),
        json.dumps(config),
    )
    binset = knowhere.GetBinarySet()
    build_idx.Serialize(binset)

    try:
        # dump to disk
        knowhere.Dump(binset, index_file)

        # load and deserialize
        new_binset = knowhere.GetBinarySet()
        knowhere.Load(new_binset, index_file)
        search_idx = knowhere.CreateIndex(name, version)
        search_idx.Deserialize(new_binset)

        # test the loaded index
        ans, _ = search_idx.Search(
            knowhere.ArrayToDataSet(xq), json.dumps(config), knowhere.GetNullBitSetView()
        )
        k_dis, k_ids = knowhere.DataSetToArray(ans)
        f_dis, f_ids = faiss_ans(xb, xq, config["metric_type"], config["k"])
        assert recall(f_ids, k_ids) >= 0.99
        # Compare the reference distances with the distances returned by the
        # reloaded index; comparing f_dis with itself can never fail.
        assert error(f_dis, k_dis) <= 0.01
    finally:
        # delete the index_file even when a step above fails, so a failed run
        # does not leave the scratch file behind
        if os.path.exists(index_file):
            os.remove(index_file)
2 changes: 1 addition & 1 deletion tests/python/test_index_with_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"metric_type": "L2",
"nlist": 1024,
"nprobe": 1024,
"ssize": 48
"ssize": 48,
},
),
(
Expand Down

0 comments on commit 95be2bb

Please sign in to comment.