diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 7dfba102..8cfa469a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -3,15 +3,15 @@ name: HNSW CI on: [push, pull_request] jobs: - test: + test_python: runs-on: ${{matrix.os}} strategy: matrix: os: [ubuntu-latest, windows-latest] - python-version: ['3.6', '3.7', '3.8', '3.9'] + python-version: ["3.7", "3.8", "3.9", "3.10"] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} @@ -19,4 +19,58 @@ jobs: run: python -m pip install . - name: Test - run: python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py" + timeout-minutes: 15 + run: | + python -m unittest discover -v --start-directory examples/python --pattern "example*.py" + python -m unittest discover -v --start-directory tests/python --pattern "bindings_test*.py" + + test_cpp: + runs-on: ${{matrix.os}} + strategy: + matrix: + os: [ubuntu-latest, windows-latest] + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + + - name: Build + run: | + mkdir build + cd build + cmake .. + if [ "$RUNNER_OS" == "Linux" ]; then + make + elif [ "$RUNNER_OS" == "Windows" ]; then + cmake --build ./ --config Release + fi + shell: bash + + - name: Prepare test data + run: | + pip install numpy + cd tests/cpp/ + python update_gen_data.py + shell: bash + + - name: Test + timeout-minutes: 15 + run: | + cd build + if [ "$RUNNER_OS" == "Windows" ]; then + cp ./Release/* ./ + fi + ./example_search + ./example_filter + ./example_replace_deleted + ./example_mt_search + ./example_mt_filter + ./example_mt_replace_deleted + ./searchKnnCloserFirst_test + ./searchKnnWithFilter_test + ./multiThreadLoad_test + ./multiThread_replace_test + ./test_updates + ./test_updates update + shell: bash diff --git a/.gitignore b/.gitignore index dab30385..48f74604 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ hnswlib.cpython*.so var/ .idea/ .vscode/ - +.vs/ +**.DS_Store diff --git a/ALGO_PARAMS.md b/ALGO_PARAMS.md index b0a6b7ad..0d5133f3 100644 --- a/ALGO_PARAMS.md +++ b/ALGO_PARAMS.md @@ -27,5 +27,5 @@ ef_construction leads to longer construction, but better index quality. At some not improve the quality of the index. One way to check if the selection of ef_construction was ok is to measure a recall for M nearest neighbor search when ```ef``` =```ef_construction```: if the recall is lower than 0.9, than there is room for improvement. -* ```num_elements``` - defines the maximum number of elements in the index. The index can be extened by saving/loading(load_index +* ```num_elements``` - defines the maximum number of elements in the index. The index can be extended by saving/loading (load_index function has a parameter which defines the new maximum number of elements). diff --git a/CMakeLists.txt b/CMakeLists.txt index e2f3d716..7cebe600 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,12 +16,41 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME) SET( CMAKE_CXX_FLAGS "-Ofast -lrt -DNDEBUG -std=c++11 -DHAVE_CXX0X -openmp -march=native -fpic -w -fopenmp -ftree-vectorize" ) endif() - add_executable(test_updates examples/updates_test.cpp) + # examples + add_executable(example_search examples/cpp/example_search.cpp) + target_link_libraries(example_search hnswlib) + + add_executable(example_filter examples/cpp/example_filter.cpp) + target_link_libraries(example_filter hnswlib) + + add_executable(example_replace_deleted examples/cpp/example_replace_deleted.cpp) + target_link_libraries(example_replace_deleted hnswlib) + + add_executable(example_mt_search examples/cpp/example_mt_search.cpp) + target_link_libraries(example_mt_search hnswlib) + + add_executable(example_mt_filter examples/cpp/example_mt_filter.cpp) + target_link_libraries(example_mt_filter hnswlib) + + add_executable(example_mt_replace_deleted examples/cpp/example_mt_replace_deleted.cpp) + target_link_libraries(example_mt_replace_deleted hnswlib) + + # tests + add_executable(test_updates tests/cpp/updates_test.cpp) target_link_libraries(test_updates hnswlib) - add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp) + add_executable(searchKnnCloserFirst_test tests/cpp/searchKnnCloserFirst_test.cpp) target_link_libraries(searchKnnCloserFirst_test hnswlib) - add_executable(main main.cpp sift_1b.cpp) + add_executable(searchKnnWithFilter_test tests/cpp/searchKnnWithFilter_test.cpp) + target_link_libraries(searchKnnWithFilter_test hnswlib) + + add_executable(multiThreadLoad_test tests/cpp/multiThreadLoad_test.cpp) + target_link_libraries(multiThreadLoad_test hnswlib) + + add_executable(multiThread_replace_test tests/cpp/multiThread_replace_test.cpp) + target_link_libraries(multiThread_replace_test hnswlib) + + add_executable(main tests/cpp/main.cpp tests/cpp/sift_1b.cpp) target_link_libraries(main hnswlib) endif() diff --git a/Makefile b/Makefile index b5e8fda9..0de9c765 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ dist: python3 -m build --sdist test: - python3 -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py" + python3 -m unittest discover --start-directory tests/python --pattern "bindings_test*.py" clean: rm -rf *.egg-info build dist tmp var tests/__pycache__ hnswlib.cpython*.so diff --git a/README.md b/README.md index 9bcb6775..3ed466a7 100644 --- a/README.md +++ b/README.md @@ -1,34 +1,22 @@ # Hnswlib - fast approximate nearest neighbor search -Header-only C++ HNSW implementation with python bindings. +Header-only C++ HNSW implementation with python bindings, insertions and updates. **NEWS:** +**version 0.7.0** -**version 0.6.2** - -* Fixed a bug in saving of large pickles. The pickles with > 4GB could have been corrupted. Thanks Kai Wohlfahrt for reporting. -* Thanks to ([@GuyAv46](https://github.com/GuyAv46)) hnswlib inner product now is more consitent accross architectures (SSE, AVX, etc). -* - -**version 0.6.1** - -* Thanks to ([@tony-kuo](https://github.com/tony-kuo)) hnswlib AVX512 and AVX builds are not backwards-compatible with older SSE and non-AVX512 architectures. -* Thanks to ([@psobot](https://github.com/psobot)) there is now a sencible message instead of segfault when passing a scalar to get_items. -* Thanks to ([@urigoren](https://github.com/urigoren)) hnswlib has a lazy index creation python wrapper. - -**version 0.6.0** -* Thanks to ([@dyashuni](https://github.com/dyashuni)) hnswlib now uses github actions for CI, there is a search speedup in some scenarios with deletions. `unmark_deleted(label)` is now also a part of the python interface (note now it throws an exception for double deletions). -* Thanks to ([@slice4e](https://github.com/slice4e)) we now support AVX512; thanks to ([@LTLA](https://github.com/LTLA)) the cmake interface for the lib is now updated. -* Thanks to ([@alonre24](https://github.com/alonre24)) we now have a python bindings for brute-force (and examples for recall tuning: [TESTING_RECALL.md](TESTING_RECALL.md). -* Thanks to ([@dorosy-yeong](https://github.com/dorosy-yeong)) there is a bug fixed in the handling large quantities of deleted elements and large K. - - +* Added support to filtering (#402, #430) by [@kishorenc](https://github.com/kishorenc) +* Added python interface for filtering (though note its performance is limited by GIL) (#417) by [@gtsoukas](https://github.com/gtsoukas) +* Added support for replacing the elements that were marked as delete with newly inserted elements (to control the size of the index, #418) by [@dyashuni](https://github.com/dyashuni) +* Fixed data races/deadlocks in updates/insertion, added stress test for multithreaded operation (#418) by [@dyashuni](https://github.com/dyashuni) +* Documentation, tests, exception handling, refactoring (#375, #379, #380, #395, #396, #401, #406, #404, #409, #410, #416, #415, #431, #432, #433) by [@jlmelville](https://github.com/jlmelville), [@dyashuni](https://github.com/dyashuni), [@kishorenc](https://github.com/kishorenc), [@korzhenevski](https://github.com/korzhenevski), [@yoshoku](https://github.com/yoshoku), [@jianshu93](https://github.com/jianshu93), [@PLNech](https://github.com/PLNech) +* global linkages (#383) by [@MasterAler](https://github.com/MasterAler), USE_SSE usage in MSVC (#408) by [@alxvth](https://github.com/alxvth) ### Highlights: 1) Lightweight, header-only, no dependencies other than C++ 11 -2) Interfaces for C++, Java, Python and R (https://github.com/jlmelville/rcpphnsw). -3) Has full support for incremental index construction. Has support for element deletions +2) Interfaces for C++, Python, external support for Java and R (https://github.com/jlmelville/rcpphnsw). +3) Has full support for incremental index construction and updating the elements. Has support for element deletions (by marking them in index). Index is picklable. 4) Can work with custom user defined distances (C++). 5) Significantly less memory footprint and faster build time compared to current nmslib's implementation. @@ -50,23 +38,26 @@ Note that inner product is not an actual metric. An element can be closer to som For other spaces use the nmslib library https://github.com/nmslib/nmslib. -#### Short API description +#### API description * `hnswlib.Index(space, dim)` creates a non-initialized index an HNSW in space `space` with integer dimension `dim`. `hnswlib.Index` methods: -* `init_index(max_elements, M = 16, ef_construction = 200, random_seed = 100)` initializes the index from with no elements. +* `init_index(max_elements, M = 16, ef_construction = 200, random_seed = 100, allow_replace_deleted = False)` initializes the index from with no elements. * `max_elements` defines the maximum number of elements that can be stored in the structure(can be increased/shrunk). * `ef_construction` defines a construction time/accuracy trade-off (see [ALGO_PARAMS.md](ALGO_PARAMS.md)). * `M` defines tha maximum number of outgoing connections in the graph ([ALGO_PARAMS.md](ALGO_PARAMS.md)). + * `allow_replace_deleted` enables replacing of deleted elements with new added ones. -* `add_items(data, ids, num_threads = -1)` - inserts the `data`(numpy array of vectors, shape:`N*dim`) into the structure. +* `add_items(data, ids, num_threads = -1, replace_deleted = False)` - inserts the `data`(numpy array of vectors, shape:`N*dim`) into the structure. * `num_threads` sets the number of cpu threads to use (-1 means use default). * `ids` are optional N-size numpy array of integer labels for all elements in `data`. - If index already has the elements with the same labels, their features will be updated. Note that update procedure is slower than insertion of a new element, but more memory- and query-efficient. + * `replace_deleted` replaces deleted elements. Note it allows to save memory. + - to use it `init_index` should be called with `allow_replace_deleted=True` * Thread-safe with other `add_items` calls, but not with `knn_query`. * `mark_deleted(label)` - marks the element as deleted, so it will be omitted from search results. Throws an exception if it is already deleted. -* + * `unmark_deleted(label)` - unmarks the element as deleted, so it will be not be omitted from search results. * `resize_index(new_size)` - changes the maximum capacity of the index. Not thread safe with `add_items` and `knn_query`. @@ -74,13 +65,15 @@ For other spaces use the nmslib library https://github.com/nmslib/nmslib. * `set_ef(ef)` - sets the query time accuracy/speed trade-off, defined by the `ef` parameter ( [ALGO_PARAMS.md](ALGO_PARAMS.md)). Note that the parameter is currently not saved along with the index, so you need to set it manually after loading. -* `knn_query(data, k = 1, num_threads = -1)` make a batch query for `k` closest elements for each element of the +* `knn_query(data, k = 1, num_threads = -1, filter = None)` make a batch query for `k` closest elements for each element of the * `data` (shape:`N*dim`). Returns a numpy array of (shape:`N*k`). * `num_threads` sets the number of cpu threads to use (-1 means use default). + * `filter` filters elements by its labels, returns elements with allowed ids. Note that search with a filter works slow in python in multithreaded mode. It is recommended to set `num_threads=1` * Thread-safe with other `knn_query` calls, but not with `add_items`. -* `load_index(path_to_index, max_elements = 0)` loads the index from persistence to the uninitialized index. +* `load_index(path_to_index, max_elements = 0, allow_replace_deleted = False)` loads the index from persistence to the uninitialized index. * `max_elements`(optional) resets the maximum number of elements in the structure. + * `allow_replace_deleted` specifies whether the index being loaded has enabled replacing of deleted elements. * `save_index(path_to_index)` saves the index from persistence. @@ -118,6 +111,12 @@ Properties of `hnswlib.Index` that support reading and writing: #### Python bindings examples +[See more examples here](examples/python/EXAMPLES.md): +* Creating index, inserting elements, searching, serialization/deserialization +* Filtering during the search with a boolean function +* Deleting the elements and reusing the memory of the deleted elements for newly added elements + +An example of creating index, inserting elements, searching and pickle serialization: ```python import hnswlib import numpy as np @@ -142,7 +141,7 @@ p.add_items(data, ids) # Controlling the recall by setting ef: p.set_ef(50) # ef should always be > k -# Query dataset, k - number of closest elements (returns 2 numpy arrays) +# Query dataset, k - number of the closest elements (returns 2 numpy arrays) labels, distances = p.knn_query(data, k = 1) # Index objects support pickling @@ -155,7 +154,6 @@ print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}") print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") - ``` An example with updates after serialization/deserialization: @@ -196,7 +194,6 @@ p.set_ef(10) # By default using all available cores p.set_num_threads(4) - print("Adding first batch of %d elements" % (len(data1))) p.add_items(data1) @@ -226,6 +223,14 @@ labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") ``` +#### C++ examples +[See examples here](examples/cpp/EXAMPLES.md): +* creating index, inserting elements, searching, serialization/deserialization +* filtering during the search with a boolean function +* deleting the elements and reusing the memory of the deleted elements for newly added elements +* multithreaded usage + + ### Bindings installation You can install from sources: @@ -245,9 +250,9 @@ Contributions are highly welcome! Please make pull requests against the `develop` branch. -When making changes please run tests (and please add a test to `python_bindings/tests` in case there is new functionality): +When making changes please run tests (and please add a test to `tests/python` in case there is new functionality): ```bash -python -m unittest discover --start-directory python_bindings/tests --pattern "*_test*.py +python -m unittest discover --start-directory tests/python --pattern "bindings_test*.py" ``` @@ -259,20 +264,23 @@ https://github.com/facebookresearch/faiss ["Revisiting the Inverted Indices for Billion-Scale Approximate Nearest Neighbors"](https://arxiv.org/abs/1802.02422) (current state-of-the-art in compressed indexes, C++): https://github.com/dbaranchuk/ivf-hnsw +* Amazon PECOS https://github.com/amzn/pecos * TOROS N2 (python, C++): https://github.com/kakao/n2 * Online HNSW (C++): https://github.com/andrusha97/online-hnsw) * Go implementation: https://github.com/Bithack/go-hnsw * Python implementation (as a part of the clustering code by by Matteo Dell'Amico): https://github.com/matteodellamico/flexible-clustering +* Julia implmentation https://github.com/JuliaNeighbors/HNSW.jl * Java implementation: https://github.com/jelmerk/hnswlib * Java bindings using Java Native Access: https://github.com/stepstone-tech/hnswlib-jna -* .Net implementation: https://github.com/microsoft/HNSW.Net +* .Net implementation: https://github.com/curiosity-ai/hnsw-sharp * CUDA implementation: https://github.com/js1010/cuhnsw - +* Rust implementation https://github.com/rust-cv/hnsw +* Rust implementation for memory and thread safety purposes and There is A Trait to enable the user to implement its own distances. It takes as data slices of types T satisfying T:Serialize+Clone+Send+Sync.: https://github.com/jean-pierreBoth/hnswlib-rs ### 200M SIFT test reproduction To download and extract the bigann dataset (from root directory): ```bash -python3 download_bigann.py +python tests/cpp/download_bigann.py ``` To compile: ```bash @@ -292,7 +300,7 @@ The size of the BigANN subset (in millions) is controlled by the variable **subs ### Updates test To generate testing data (from root directory): ```bash -cd examples +cd tests/cpp python update_gen_data.py ``` To compile (from root directory): diff --git a/TESTING_RECALL.md b/TESTING_RECALL.md index 23a6f654..29136ec8 100644 --- a/TESTING_RECALL.md +++ b/TESTING_RECALL.md @@ -27,7 +27,7 @@ max_elements defines the maximum number of elements that can be stored in the st ### measuring recall example -``` +```python import hnswlib import numpy as np diff --git a/examples/cpp/EXAMPLES.md b/examples/cpp/EXAMPLES.md new file mode 100644 index 00000000..3af603d4 --- /dev/null +++ b/examples/cpp/EXAMPLES.md @@ -0,0 +1,185 @@ +# C++ examples + +Creating index, inserting elements, searching and serialization +```cpp +#include "../../hnswlib/hnswlib.h" + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Query the elements for themselves and measure recall + float correct = 0; + for (int i = 0; i < max_elements; i++) { + std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); + hnswlib::labeltype label = result.top().second; + if (label == i) correct++; + } + float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + + // Serialize index + std::string hnsw_path = "hnsw.bin"; + alg_hnsw->saveIndex(hnsw_path); + delete alg_hnsw; + + // Deserialize index and check recall + alg_hnsw = new hnswlib::HierarchicalNSW(&space, hnsw_path); + correct = 0; + for (int i = 0; i < max_elements; i++) { + std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); + hnswlib::labeltype label = result.top().second; + if (label == i) correct++; + } + recall = (float)correct / max_elements; + std::cout << "Recall of deserialized index: " << recall << "\n"; + + delete[] data; + delete alg_hnsw; + return 0; +} +``` + +An example of filtering with a boolean function during the search: +```cpp +#include "../../hnswlib/hnswlib.h" + + +// Filter that allows labels divisible by divisor +class PickDivisibleIds: public hnswlib::BaseFilterFunctor { +unsigned int divisor = 1; + public: + PickDivisibleIds(unsigned int divisor): divisor(divisor) { + assert(divisor != 0); + } + bool operator()(hnswlib::labeltype label_id) { + return label_id % divisor == 0; + } +}; + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Create filter that allows only even labels + PickDivisibleIds pickIdsDivisibleByTwo(2); + + // Query the elements for themselves with filter and check returned labels + int k = 10; + for (int i = 0; i < max_elements; i++) { + std::vector> result = alg_hnsw->searchKnnCloserFirst(data + i * dim, k, &pickIdsDivisibleByTwo); + for (auto item: result) { + if (item.second % 2 == 1) std::cout << "Error: found odd label\n"; + } + } + + delete[] data; + delete alg_hnsw; + return 0; +} +``` + +An example with reusing the memory of the deleted elements when new elements are being added (via `allow_replace_deleted` flag): +```cpp +#include "../../hnswlib/hnswlib.h" + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, 100, true); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Mark first half of elements as deleted + int num_deleted = max_elements / 2; + for (int i = 0; i < num_deleted; i++) { + alg_hnsw->markDelete(i); + } + + float* add_data = new float[dim * num_deleted]; + for (int i = 0; i < dim * num_deleted; i++) { + add_data[i] = distrib_real(rng); + } + + // Replace deleted data with new elements + // Maximum number of elements is reached therefore we cannot add new items, + // but we can replace the deleted ones by using replace_deleted=true + for (int i = 0; i < num_deleted; i++) { + int label = max_elements + i; + alg_hnsw->addPoint(add_data + i * dim, label, true); + } + + delete[] data; + delete[] add_data; + delete alg_hnsw; + return 0; +} +``` + +Multithreaded examples: +* Creating index, inserting elements, searching [example_mt_search.cpp](example_mt_search.cpp) +* Filtering during the search with a boolean function [example_mt_filter.cpp](example_mt_filter.cpp) +* Reusing the memory of the deleted elements when new elements are being added [example_mt_replace_deleted.cpp](example_mt_replace_deleted.cpp) \ No newline at end of file diff --git a/examples/cpp/example_filter.cpp b/examples/cpp/example_filter.cpp new file mode 100644 index 00000000..dc978c57 --- /dev/null +++ b/examples/cpp/example_filter.cpp @@ -0,0 +1,57 @@ +#include "../../hnswlib/hnswlib.h" + + +// Filter that allows labels divisible by divisor +class PickDivisibleIds: public hnswlib::BaseFilterFunctor { +unsigned int divisor = 1; + public: + PickDivisibleIds(unsigned int divisor): divisor(divisor) { + assert(divisor != 0); + } + bool operator()(hnswlib::labeltype label_id) { + return label_id % divisor == 0; + } +}; + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Create filter that allows only even labels + PickDivisibleIds pickIdsDivisibleByTwo(2); + + // Query the elements for themselves with filter and check returned labels + int k = 10; + for (int i = 0; i < max_elements; i++) { + std::vector> result = alg_hnsw->searchKnnCloserFirst(data + i * dim, k, &pickIdsDivisibleByTwo); + for (auto item: result) { + if (item.second % 2 == 1) std::cout << "Error: found odd label\n"; + } + } + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_mt_filter.cpp b/examples/cpp/example_mt_filter.cpp new file mode 100644 index 00000000..b39de4c3 --- /dev/null +++ b/examples/cpp/example_mt_filter.cpp @@ -0,0 +1,124 @@ +#include "../../hnswlib/hnswlib.h" +#include + + +// Multithreaded executor +// The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) +// An alternative is using #pragme omp parallel for or any other C++ threading +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). + */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + + +// Filter that allows labels divisible by divisor +class PickDivisibleIds: public hnswlib::BaseFilterFunctor { +unsigned int divisor = 1; + public: + PickDivisibleIds(unsigned int divisor): divisor(divisor) { + assert(divisor != 0); + } + bool operator()(hnswlib::labeltype label_id) { + return label_id % divisor == 0; + } +}; + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + int num_threads = 20; // Number of threads for operations with index + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->addPoint((void*)(data + dim * row), row); + }); + + // Create filter that allows only even labels + PickDivisibleIds pickIdsDivisibleByTwo(2); + + // Query the elements for themselves with filter and check returned labels + int k = 10; + std::vector neighbors(max_elements * k); + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, k, &pickIdsDivisibleByTwo); + for (int i = 0; i < k; i++) { + hnswlib::labeltype label = result.top().second; + result.pop(); + neighbors[row * k + i] = label; + } + }); + + for (hnswlib::labeltype label: neighbors) { + if (label % 2 == 1) std::cout << "Error: found odd label\n"; + } + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_mt_replace_deleted.cpp b/examples/cpp/example_mt_replace_deleted.cpp new file mode 100644 index 00000000..40a94ce7 --- /dev/null +++ b/examples/cpp/example_mt_replace_deleted.cpp @@ -0,0 +1,114 @@ +#include "../../hnswlib/hnswlib.h" +#include + + +// Multithreaded executor +// The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) +// An alternative is using #pragme omp parallel for or any other C++ threading +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). + */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + int num_threads = 20; // Number of threads for operations with index + + // Initing index with allow_replace_deleted=true + int seed = 100; + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, seed, true); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->addPoint((void*)(data + dim * row), row); + }); + + // Mark first half of elements as deleted + int num_deleted = max_elements / 2; + ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->markDelete(row); + }); + + // Generate additional random data + float* add_data = new float[dim * num_deleted]; + for (int i = 0; i < dim * num_deleted; i++) { + add_data[i] = distrib_real(rng); + } + + // Replace deleted data with new elements + // Maximum number of elements is reached therefore we cannot add new items, + // but we can replace the deleted ones by using replace_deleted=true + ParallelFor(0, num_deleted, num_threads, [&](size_t row, size_t threadId) { + hnswlib::labeltype label = max_elements + row; + alg_hnsw->addPoint((void*)(add_data + dim * row), label, true); + }); + + delete[] data; + delete[] add_data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_mt_search.cpp b/examples/cpp/example_mt_search.cpp new file mode 100644 index 00000000..e315b9ff --- /dev/null +++ b/examples/cpp/example_mt_search.cpp @@ -0,0 +1,107 @@ +#include "../../hnswlib/hnswlib.h" +#include + + +// Multithreaded executor +// The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) +// An alternative is using #pragme omp parallel for or any other C++ threading +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). + */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + int num_threads = 20; // Number of threads for operations with index + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->addPoint((void*)(data + dim * row), row); + }); + + // Query the elements for themselves and measure recall + std::vector neighbors(max_elements); + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, 1); + hnswlib::labeltype label = result.top().second; + neighbors[row] = label; + }); + float correct = 0; + for (int i = 0; i < max_elements; i++) { + hnswlib::labeltype label = neighbors[i]; + if (label == i) correct++; + } + float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_replace_deleted.cpp b/examples/cpp/example_replace_deleted.cpp new file mode 100644 index 00000000..64c995bb --- /dev/null +++ b/examples/cpp/example_replace_deleted.cpp @@ -0,0 +1,54 @@ +#include "../../hnswlib/hnswlib.h" + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index with allow_replace_deleted=true + int seed = 100; + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction, seed, true); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Mark first half of elements as deleted + int num_deleted = max_elements / 2; + for (int i = 0; i < num_deleted; i++) { + alg_hnsw->markDelete(i); + } + + // Generate additional random data + float* add_data = new float[dim * num_deleted]; + for (int i = 0; i < dim * num_deleted; i++) { + add_data[i] = distrib_real(rng); + } + + // Replace deleted data with new elements + // Maximum number of elements is reached therefore we cannot add new items, + // but we can replace the deleted ones by using replace_deleted=true + for (int i = 0; i < num_deleted; i++) { + hnswlib::labeltype label = max_elements + i; + alg_hnsw->addPoint(add_data + i * dim, label, true); + } + + delete[] data; + delete[] add_data; + delete alg_hnsw; + return 0; +} diff --git a/examples/cpp/example_search.cpp b/examples/cpp/example_search.cpp new file mode 100644 index 00000000..2c28738f --- /dev/null +++ b/examples/cpp/example_search.cpp @@ -0,0 +1,58 @@ +#include "../../hnswlib/hnswlib.h" + + +int main() { + int dim = 16; // Dimension of the elements + int max_elements = 10000; // Maximum number of elements, should be known beforehand + int M = 16; // Tightly connected with internal dimensionality of the data + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + + // Initing index + hnswlib::L2Space space(dim); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, M, ef_construction); + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data = new float[dim * max_elements]; + for (int i = 0; i < dim * max_elements; i++) { + data[i] = distrib_real(rng); + } + + // Add data to index + for (int i = 0; i < max_elements; i++) { + alg_hnsw->addPoint(data + i * dim, i); + } + + // Query the elements for themselves and measure recall + float correct = 0; + for (int i = 0; i < max_elements; i++) { + std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); + hnswlib::labeltype label = result.top().second; + if (label == i) correct++; + } + float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + + // Serialize index + std::string hnsw_path = "hnsw.bin"; + alg_hnsw->saveIndex(hnsw_path); + delete alg_hnsw; + + // Deserialize index and check recall + alg_hnsw = new hnswlib::HierarchicalNSW(&space, hnsw_path); + correct = 0; + for (int i = 0; i < max_elements; i++) { + std::priority_queue> result = alg_hnsw->searchKnn(data + i * dim, 1); + hnswlib::labeltype label = result.top().second; + if (label == i) correct++; + } + recall = (float)correct / max_elements; + std::cout << "Recall of deserialized index: " << recall << "\n"; + + delete[] data; + delete alg_hnsw; + return 0; +} diff --git a/examples/git_tester.py b/examples/git_tester.py deleted file mode 100644 index aaf70c82..00000000 --- a/examples/git_tester.py +++ /dev/null @@ -1,34 +0,0 @@ -from pydriller import Repository -import os -import datetime -os.system("cp examples/speedtest.py examples/speedtest2.py") # the file has to be outside of git -for idx, commit in enumerate(Repository('.', from_tag="v0.6.0").traverse_commits()): - name=commit.msg.replace('\n', ' ').replace('\r', ' ') - print(idx, commit.hash, name) - - - -for commit in Repository('.', from_tag="v0.6.0").traverse_commits(): - - name=commit.msg.replace('\n', ' ').replace('\r', ' ') - print(commit.hash, name) - - os.system(f"git checkout {commit.hash}; rm -rf build; ") - print("\n\n--------------------\n\n") - ret=os.system("python -m pip install .") - print(ret) - - if ret != 0: - print ("build failed!!!!") - print ("build failed!!!!") - print ("build failed!!!!") - print ("build failed!!!!") - continue - - os.system(f'python examples/speedtest2.py -n "{name}" -d 4 -t 1') - os.system(f'python examples/speedtest2.py -n "{name}" -d 64 -t 1') - os.system(f'python examples/speedtest2.py -n "{name}" -d 128 -t 1') - os.system(f'python examples/speedtest2.py -n "{name}" -d 4 -t 24') - os.system(f'python examples/speedtest2.py -n "{name}" -d 128 -t 24') - - diff --git a/examples/python/EXAMPLES.md b/examples/python/EXAMPLES.md new file mode 100644 index 00000000..6c1b20e4 --- /dev/null +++ b/examples/python/EXAMPLES.md @@ -0,0 +1,207 @@ +# Python bindings examples + +Creating index, inserting elements, searching and pickle serialization: +```python +import hnswlib +import numpy as np +import pickle + +dim = 128 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) +ids = np.arange(num_elements) + +# Declaring index +p = hnswlib.Index(space = 'l2', dim = dim) # possible options are l2, cosine or ip + +# Initializing index - the maximum number of elements should be known beforehand +p.init_index(max_elements = num_elements, ef_construction = 200, M = 16) + +# Element insertion (can be called several times): +p.add_items(data, ids) + +# Controlling the recall by setting ef: +p.set_ef(50) # ef should always be > k + +# Query dataset, k - number of the closest elements (returns 2 numpy arrays) +labels, distances = p.knn_query(data, k = 1) + +# Index objects support pickling +# WARNING: serialization via pickle.dumps(p) or p.__getstate__() is NOT thread-safe with p.add_items method! +# Note: ef parameter is included in serialization; random number generator is initialized with random_seed on Index load +p_copy = pickle.loads(pickle.dumps(p)) # creates a copy of index p using pickle round-trip + +### Index parameters are exposed as class properties: +print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim}") +print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") +print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}") +print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") +``` + +An example with updates after serialization/deserialization: +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) + +# We split the data in two batches: +data1 = data[:num_elements // 2] +data2 = data[num_elements // 2:] + +# Declaring index +p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initializing index +# max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded +# during insertion of an element. +# The capacity can be increased by saving/loading the index, see below. +# +# ef_construction - controls index search speed/build speed tradeoff +# +# M - is tightly connected with internal dimensionality of the data. Strongly affects memory consumption (~M) +# Higher M leads to higher accuracy/run_time at fixed ef/efConstruction + +p.init_index(max_elements=num_elements//2, ef_construction=100, M=16) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +p.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +p.set_num_threads(4) + +print("Adding first batch of %d elements" % (len(data1))) +p.add_items(data1) + +# Query the elements for themselves and measure recall: +labels, distances = p.knn_query(data1, k=1) +print("Recall for the first batch:", np.mean(labels.reshape(-1) == np.arange(len(data1))), "\n") + +# Serializing and deleting the index: +index_path='first_half.bin' +print("Saving index to '%s'" % index_path) +p.save_index("first_half.bin") +del p + +# Re-initializing, loading the index +p = hnswlib.Index(space='l2', dim=dim) # the space can be changed - keeps the data, alters the distance function. + +print("\nLoading index from 'first_half.bin'\n") + +# Increase the total capacity (max_elements), so that it will handle the new data +p.load_index("first_half.bin", max_elements = num_elements) + +print("Adding the second batch of %d elements" % (len(data2))) +p.add_items(data2) + +# Query the elements for themselves and measure recall: +labels, distances = p.knn_query(data, k=1) +print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") +``` + +An example with a symbolic filter `filter_function` during the search: +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +print("Adding %d elements" % (len(data))) +# Added elements will have consecutive ids +hnsw_index.add_items(data, ids=np.arange(num_elements)) + +print("Querying only even elements") +# Define filter function that allows only even ids +filter_function = lambda idx: idx%2 == 0 +# Query the elements for themselves and search only for even elements: +# Warning: search with python filter works slow in multithreaded mode, therefore we set num_threads=1 +labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function) +# labels contain only elements with even id +``` + +An example with reusing the memory of the deleted elements when new elements are being added (via `allow_replace_deleted` flag): +```python +import hnswlib +import numpy as np + +dim = 16 +num_elements = 1_000 +max_num_elements = 2 * num_elements + +# Generating sample data +labels1 = np.arange(0, num_elements) +data1 = np.float32(np.random.random((num_elements, dim))) # batch 1 +labels2 = np.arange(num_elements, 2 * num_elements) +data2 = np.float32(np.random.random((num_elements, dim))) # batch 2 +labels3 = np.arange(2 * num_elements, 3 * num_elements) +data3 = np.float32(np.random.random((num_elements, dim))) # batch 3 + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +# Enable replacing of deleted elements +hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +# Add batch 1 and 2 data +hnsw_index.add_items(data1, labels1) +hnsw_index.add_items(data2, labels2) # Note: maximum number of elements is reached + +# Delete data of batch 2 +for label in labels2: + hnsw_index.mark_deleted(label) + +# Replace deleted elements +# Maximum number of elements is reached therefore we cannot add new items, +# but we can replace the deleted ones by using replace_deleted=True +hnsw_index.add_items(data3, labels3, replace_deleted=True) +# hnsw_index contains the data of batch 1 and batch 3 only +``` \ No newline at end of file diff --git a/examples/example.py b/examples/python/example.py similarity index 94% rename from examples/example.py rename to examples/python/example.py index a08955a1..3d6d7477 100644 --- a/examples/example.py +++ b/examples/python/example.py @@ -1,6 +1,12 @@ +import os import hnswlib import numpy as np + +""" +Example of index building, search and serialization/deserialization +""" + dim = 16 num_elements = 10000 @@ -34,7 +40,6 @@ # By default using all available cores p.set_num_threads(4) - print("Adding first batch of %d elements" % (len(data1))) p.add_items(data1) @@ -62,3 +67,5 @@ # Query the elements for themselves and measure recall: labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") + +os.remove("first_half.bin") diff --git a/examples/python/example_filter.py b/examples/python/example_filter.py new file mode 100644 index 00000000..add22a3d --- /dev/null +++ b/examples/python/example_filter.py @@ -0,0 +1,46 @@ +import hnswlib +import numpy as np + + +""" +Example of filtering elements when searching +""" + +dim = 16 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +print("Adding %d elements" % (len(data))) +# Added elements will have consecutive ids +hnsw_index.add_items(data, ids=np.arange(num_elements)) + +print("Querying only even elements") +# Define filter function that allows only even ids +filter_function = lambda idx: idx%2 == 0 +# Query the elements for themselves and search only for even elements: +# Warning: search with a filter works slow in python in multithreaded mode, therefore we set num_threads=1 +labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function) +# labels contain only elements with even id diff --git a/examples/python/example_replace_deleted.py b/examples/python/example_replace_deleted.py new file mode 100644 index 00000000..3c0b62e7 --- /dev/null +++ b/examples/python/example_replace_deleted.py @@ -0,0 +1,55 @@ +import hnswlib +import numpy as np + + +""" +Example of replacing deleted elements with new ones +""" + +dim = 16 +num_elements = 1_000 +max_num_elements = 2 * num_elements + +# Generating sample data +labels1 = np.arange(0, num_elements) +data1 = np.float32(np.random.random((num_elements, dim))) # batch 1 +labels2 = np.arange(num_elements, 2 * num_elements) +data2 = np.float32(np.random.random((num_elements, dim))) # batch 2 +labels3 = np.arange(2 * num_elements, 3 * num_elements) +data3 = np.float32(np.random.random((num_elements, dim))) # batch 3 + +# Declaring index +hnsw_index = hnswlib.Index(space='l2', dim=dim) + +# Initiating index +# max_elements - the maximum number of elements, should be known beforehand +# (probably will be made optional in the future) +# +# ef_construction - controls index search speed/build speed tradeoff +# M - is tightly connected with internal dimensionality of the data +# strongly affects the memory consumption + +# Enable replacing of deleted elements +hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +hnsw_index.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +hnsw_index.set_num_threads(4) + +# Add batch 1 and 2 data +hnsw_index.add_items(data1, labels1) +hnsw_index.add_items(data2, labels2) # Note: maximum number of elements is reached + +# Delete data of batch 2 +for label in labels2: + hnsw_index.mark_deleted(label) + +# Replace deleted elements +# Maximum number of elements is reached therefore we cannot add new items, +# but we can replace the deleted ones by using replace_deleted=True +hnsw_index.add_items(data3, labels3, replace_deleted=True) +# hnsw_index contains the data of batch 1 and batch 3 only diff --git a/examples/python/example_search.py b/examples/python/example_search.py new file mode 100644 index 00000000..4581843b --- /dev/null +++ b/examples/python/example_search.py @@ -0,0 +1,41 @@ +import hnswlib +import numpy as np +import pickle + + +""" +Example of search +""" + +dim = 128 +num_elements = 10000 + +# Generating sample data +data = np.float32(np.random.random((num_elements, dim))) +ids = np.arange(num_elements) + +# Declaring index +p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# Initializing index - the maximum number of elements should be known beforehand +p.init_index(max_elements=num_elements, ef_construction=200, M=16) + +# Element insertion (can be called several times): +p.add_items(data, ids) + +# Controlling the recall by setting ef: +p.set_ef(50) # ef should always be > k + +# Query dataset, k - number of the closest elements (returns 2 numpy arrays) +labels, distances = p.knn_query(data, k=1) + +# Index objects support pickling +# WARNING: serialization via pickle.dumps(p) or p.__getstate__() is NOT thread-safe with p.add_items method! +# Note: ef parameter is included in serialization; random number generator is initialized with random_seed on Index load +p_copy = pickle.loads(pickle.dumps(p)) # creates a copy of index p using pickle round-trip + +### Index parameters are exposed as class properties: +print(f"Parameters passed to constructor: space={p_copy.space}, dim={p_copy.dim}") +print(f"Index construction: M={p_copy.M}, ef_construction={p_copy.ef_construction}") +print(f"Index size is {p_copy.element_count} and index capacity is {p_copy.max_elements}") +print(f"Search speed/quality trade-off parameter: ef={p_copy.ef}") diff --git a/examples/example_old.py b/examples/python/example_serialization.py similarity index 59% rename from examples/example_old.py rename to examples/python/example_serialization.py index 6654a027..76ca1436 100644 --- a/examples/example_old.py +++ b/examples/python/example_serialization.py @@ -1,34 +1,45 @@ +import os + import hnswlib import numpy as np + +""" +Example of serialization/deserialization +""" + dim = 16 num_elements = 10000 # Generating sample data data = np.float32(np.random.random((num_elements, dim))) +# We split the data in two batches: +data1 = data[:num_elements // 2] +data2 = data[num_elements // 2:] + # Declaring index p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip -# Initing index -# max_elements - the maximum number of elements, should be known beforehand -# (probably will be made optional in the future) +# Initializing index +# max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded +# during insertion of an element. +# The capacity can be increased by saving/loading the index, see below. # # ef_construction - controls index search speed/build speed tradeoff -# M - is tightly connected with internal dimensionality of the data -# stronlgy affects the memory consumption +# +# M - is tightly connected with internal dimensionality of the data. Strongly affects memory consumption (~M) +# Higher M leads to higher accuracy/run_time at fixed ef/efConstruction -p.init_index(max_elements=num_elements, ef_construction=100, M=16) +p.init_index(max_elements=num_elements//2, ef_construction=100, M=16) # Controlling the recall by setting ef: # higher ef leads to better accuracy, but slower search p.set_ef(10) -p.set_num_threads(4) # by default using all available cores - -# We split the data in two batches: -data1 = data[:num_elements // 2] -data2 = data[num_elements // 2:] +# Set number of threads used during batch search/construction +# By default using all available cores +p.set_num_threads(4) print("Adding first batch of %d elements" % (len(data1))) p.add_items(data1) @@ -43,11 +54,13 @@ p.save_index("first_half.bin") del p -# Reiniting, loading the index -p = hnswlib.Index(space='l2', dim=dim) # you can change the sa +# Re-initializing, loading the index +p = hnswlib.Index(space='l2', dim=dim) # the space can be changed - keeps the data, alters the distance function. print("\nLoading index from 'first_half.bin'\n") -p.load_index("first_half.bin") + +# Increase the total capacity (max_elements), so that it will handle the new data +p.load_index("first_half.bin", max_elements = num_elements) print("Adding the second batch of %d elements" % (len(data2))) p.add_items(data2) @@ -55,3 +68,5 @@ # Query the elements for themselves and measure recall: labels, distances = p.knn_query(data, k=1) print("Recall for two batches:", np.mean(labels.reshape(-1) == np.arange(len(data))), "\n") + +os.remove("first_half.bin") diff --git a/examples/pyw_hnswlib.py b/examples/python/pyw_hnswlib.py similarity index 95% rename from examples/pyw_hnswlib.py rename to examples/python/pyw_hnswlib.py index aeb93f10..0ccfbc5e 100644 --- a/examples/pyw_hnswlib.py +++ b/examples/python/pyw_hnswlib.py @@ -4,6 +4,10 @@ import pickle +""" +Example of python wrapper for hnswlib that supports python objects as ids +""" + class Index(): def __init__(self, space, dim): self.index = hnswlib.Index(space, dim) diff --git a/examples/speedtest.py b/examples/speedtest.py deleted file mode 100644 index cf8e6085..00000000 --- a/examples/speedtest.py +++ /dev/null @@ -1,62 +0,0 @@ -import hnswlib -import numpy as np -import os.path -import time -import argparse - -# Use nargs to specify how many arguments an option should take. -ap = argparse.ArgumentParser() -ap.add_argument('-d') -ap.add_argument('-n') -ap.add_argument('-t') -args = ap.parse_args() -dim = int(args.d) -name = args.n -threads=int(args.t) -num_elements = 1000000 * 4//dim - -# Generating sample data -np.random.seed(1) -data = np.float32(np.random.random((num_elements, dim))) - - -index_path=f'speed_index{dim}.bin' -# Declaring index -p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip - -if not os.path.isfile(index_path) : - - p.init_index(max_elements=num_elements, ef_construction=100, M=16) - - # Controlling the recall by setting ef: - # higher ef leads to better accuracy, but slower search - p.set_ef(10) - - # Set number of threads used during batch search/construction - # By default using all available cores - p.set_num_threads(12) - - p.add_items(data) - - # Serializing and deleting the index: - - print("Saving index to '%s'" % index_path) - p.save_index(index_path) -p.set_num_threads(threads) -times=[] -time.sleep(10) -p.set_ef(100) -for _ in range(3): - p.load_index(index_path) - for _ in range(10): - t0=time.time() - labels, distances = p.knn_query(data, k=1) - tt=time.time()-t0 - times.append(tt) - print(f"{tt} seconds") -str_out=f"mean time:{np.mean(times)}, median time:{np.median(times)}, std time {np.std(times)} {name}" -print(str_out) -with open (f"log_{dim}_t{threads}.txt","a") as f: - f.write(str_out+"\n") - f.flush() - diff --git a/hnswlib/bruteforce.h b/hnswlib/bruteforce.h index 24260400..30b33ae9 100644 --- a/hnswlib/bruteforce.h +++ b/hnswlib/bruteforce.h @@ -3,150 +3,165 @@ #include #include #include +#include namespace hnswlib { - template - class BruteforceSearch : public AlgorithmInterface { - public: - BruteforceSearch(SpaceInterface *s) { - - } - BruteforceSearch(SpaceInterface *s, const std::string &location) { - loadIndex(location, s); - } - - BruteforceSearch(SpaceInterface *s, size_t maxElements) { - maxelements_ = maxElements; - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - size_per_element_ = data_size_ + sizeof(labeltype); - data_ = (char *) malloc(maxElements * size_per_element_); - if (data_ == nullptr) - std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); - cur_element_count = 0; - } - - ~BruteforceSearch() { - free(data_); - } - - char *data_; - size_t maxelements_; - size_t cur_element_count; - size_t size_per_element_; - - size_t data_size_; - DISTFUNC fstdistfunc_; - void *dist_func_param_; - std::mutex index_lock; - - std::unordered_map dict_external_to_internal; - - void addPoint(const void *datapoint, labeltype label) { - - int idx; - { - std::unique_lock lock(index_lock); - - - - auto search=dict_external_to_internal.find(label); - if (search != dict_external_to_internal.end()) { - idx=search->second; - } - else{ - if (cur_element_count >= maxelements_) { - throw std::runtime_error("The number of elements exceeds the specified limit\n"); - } - idx=cur_element_count; - dict_external_to_internal[label] = idx; - cur_element_count++; +template +class BruteforceSearch : public AlgorithmInterface { + public: + char *data_; + size_t maxelements_; + size_t cur_element_count; + size_t size_per_element_; + + size_t data_size_; + DISTFUNC fstdistfunc_; + void *dist_func_param_; + std::mutex index_lock; + + std::unordered_map dict_external_to_internal; + + + BruteforceSearch(SpaceInterface *s) + : data_(nullptr), + maxelements_(0), + cur_element_count(0), + size_per_element_(0), + data_size_(0), + dist_func_param_(nullptr) { + } + + + BruteforceSearch(SpaceInterface *s, const std::string &location) + : data_(nullptr), + maxelements_(0), + cur_element_count(0), + size_per_element_(0), + data_size_(0), + dist_func_param_(nullptr) { + loadIndex(location, s); + } + + + BruteforceSearch(SpaceInterface *s, size_t maxElements) { + maxelements_ = maxElements; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxElements * size_per_element_); + if (data_ == nullptr) + throw std::runtime_error("Not enough memory: BruteforceSearch failed to allocate data"); + cur_element_count = 0; + } + + + ~BruteforceSearch() { + free(data_); + } + + + void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) { + int idx; + { + std::unique_lock lock(index_lock); + + auto search = dict_external_to_internal.find(label); + if (search != dict_external_to_internal.end()) { + idx = search->second; + } else { + if (cur_element_count >= maxelements_) { + throw std::runtime_error("The number of elements exceeds the specified limit\n"); } + idx = cur_element_count; + dict_external_to_internal[label] = idx; + cur_element_count++; } - memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); - memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); - - - - - }; - - void removePoint(labeltype cur_external) { - size_t cur_c=dict_external_to_internal[cur_external]; - - dict_external_to_internal.erase(cur_external); - - labeltype label=*((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); - dict_external_to_internal[label]=cur_c; - memcpy(data_ + size_per_element_ * cur_c, - data_ + size_per_element_ * (cur_element_count-1), - data_size_+sizeof(labeltype)); - cur_element_count--; - } - - - std::priority_queue> - searchKnn(const void *query_data, size_t k) const { - std::priority_queue> topResults; - if (cur_element_count == 0) return topResults; - for (int i = 0; i < k; i++) { - dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + - data_size_)))); + memcpy(data_ + size_per_element_ * idx + data_size_, &label, sizeof(labeltype)); + memcpy(data_ + size_per_element_ * idx, datapoint, data_size_); + } + + + void removePoint(labeltype cur_external) { + size_t cur_c = dict_external_to_internal[cur_external]; + + dict_external_to_internal.erase(cur_external); + + labeltype label = *((labeltype*)(data_ + size_per_element_ * (cur_element_count-1) + data_size_)); + dict_external_to_internal[label] = cur_c; + memcpy(data_ + size_per_element_ * cur_c, + data_ + size_per_element_ * (cur_element_count-1), + data_size_+sizeof(labeltype)); + cur_element_count--; + } + + + std::priority_queue> + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + assert(k <= cur_element_count); + std::priority_queue> topResults; + if (cur_element_count == 0) return topResults; + for (int i = 0; i < k; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); + if ((!isIdAllowed) || (*isIdAllowed)(label)) { + topResults.push(std::pair(dist, label)); } - dist_t lastdist = topResults.top().first; - for (int i = k; i < cur_element_count; i++) { - dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); - if (dist <= lastdist) { - topResults.push(std::pair(dist, *((labeltype *) (data_ + size_per_element_ * i + - data_size_)))); - if (topResults.size() > k) - topResults.pop(); - lastdist = topResults.top().first; + } + dist_t lastdist = topResults.empty() ? std::numeric_limits::max() : topResults.top().first; + for (int i = k; i < cur_element_count; i++) { + dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); + if (dist <= lastdist) { + labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); + if ((!isIdAllowed) || (*isIdAllowed)(label)) { + topResults.push(std::pair(dist, label)); } + if (topResults.size() > k) + topResults.pop(); + if (!topResults.empty()) { + lastdist = topResults.top().first; + } } - return topResults; - }; - - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; - - writeBinaryPOD(output, maxelements_); - writeBinaryPOD(output, size_per_element_); - writeBinaryPOD(output, cur_element_count); + } + return topResults; + } - output.write(data_, maxelements_ * size_per_element_); - output.close(); - } + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; - void loadIndex(const std::string &location, SpaceInterface *s) { + writeBinaryPOD(output, maxelements_); + writeBinaryPOD(output, size_per_element_); + writeBinaryPOD(output, cur_element_count); + output.write(data_, maxelements_ * size_per_element_); - std::ifstream input(location, std::ios::binary); - std::streampos position; + output.close(); + } - readBinaryPOD(input, maxelements_); - readBinaryPOD(input, size_per_element_); - readBinaryPOD(input, cur_element_count); - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - size_per_element_ = data_size_ + sizeof(labeltype); - data_ = (char *) malloc(maxelements_ * size_per_element_); - if (data_ == nullptr) - std::runtime_error("Not enough memory: loadIndex failed to allocate data"); + void loadIndex(const std::string &location, SpaceInterface *s) { + std::ifstream input(location, std::ios::binary); + std::streampos position; - input.read(data_, maxelements_ * size_per_element_); + readBinaryPOD(input, maxelements_); + readBinaryPOD(input, size_per_element_); + readBinaryPOD(input, cur_element_count); - input.close(); + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + size_per_element_ = data_size_ + sizeof(labeltype); + data_ = (char *) malloc(maxelements_ * size_per_element_); + if (data_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate data"); - } + input.read(data_, maxelements_ * size_per_element_); - }; -} + input.close(); + } +}; +} // namespace hnswlib diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index e95e0b52..bef00170 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -10,1197 +10,1262 @@ #include namespace hnswlib { - typedef unsigned int tableint; - typedef unsigned int linklistsizeint; - - template - class HierarchicalNSW : public AlgorithmInterface { - public: - static const tableint max_update_element_locks = 65536; - HierarchicalNSW(SpaceInterface *s) { - } +typedef unsigned int tableint; +typedef unsigned int linklistsizeint; + +template +class HierarchicalNSW : public AlgorithmInterface { + public: + static const tableint MAX_LABEL_OPERATION_LOCKS = 65536; + static const unsigned char DELETE_MARK = 0x01; + + size_t max_elements_{0}; + mutable std::atomic cur_element_count{0}; // current number of elements + size_t size_data_per_element_{0}; + size_t size_links_per_element_{0}; + mutable std::atomic num_deleted_{0}; // number of deleted elements + size_t M_{0}; + size_t maxM_{0}; + size_t maxM0_{0}; + size_t ef_construction_{0}; + size_t ef_{ 0 }; + + double mult_{0.0}, revSize_{0.0}; + int maxlevel_{0}; + + VisitedListPool *visited_list_pool_{nullptr}; + + // Locks operations with element by label value + mutable std::vector label_op_locks_; + + std::mutex global; + std::vector link_list_locks_; + + tableint enterpoint_node_{0}; + + size_t size_links_level0_{0}; + size_t offsetData_{0}, offsetLevel0_{0}, label_offset_{ 0 }; + + char *data_level0_memory_{nullptr}; + char **linkLists_{nullptr}; + std::vector element_levels_; // keeps level of each element + + size_t data_size_{0}; + + DISTFUNC fstdistfunc_; + void *dist_func_param_{nullptr}; + + mutable std::mutex label_lookup_lock; // lock for label_lookup_ + std::unordered_map label_lookup_; + + std::default_random_engine level_generator_; + std::default_random_engine update_probability_generator_; + + mutable std::atomic metric_distance_computations{0}; + mutable std::atomic metric_hops{0}; + + bool allow_replace_deleted_ = false; // flag to replace deleted elements (marked as deleted) during insertions + + std::mutex deleted_elements_lock; // lock for deleted_elements + std::unordered_set deleted_elements; // contains internal ids of deleted elements + + + HierarchicalNSW(SpaceInterface *s) { + } + + + HierarchicalNSW( + SpaceInterface *s, + const std::string &location, + bool nmslib = false, + size_t max_elements = 0, + bool allow_replace_deleted = false) + : allow_replace_deleted_(allow_replace_deleted) { + loadIndex(location, s, max_elements); + } + + + HierarchicalNSW( + SpaceInterface *s, + size_t max_elements, + size_t M = 16, + size_t ef_construction = 200, + size_t random_seed = 100, + bool allow_replace_deleted = false) + : link_list_locks_(max_elements), + label_op_locks_(MAX_LABEL_OPERATION_LOCKS), + element_levels_(max_elements), + allow_replace_deleted_(allow_replace_deleted) { + max_elements_ = max_elements; + num_deleted_ = 0; + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + M_ = M; + maxM_ = M_; + maxM0_ = M_ * 2; + ef_construction_ = std::max(ef_construction, M_); + ef_ = 10; + + level_generator_.seed(random_seed); + update_probability_generator_.seed(random_seed + 1); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); + offsetData_ = size_links_level0_; + label_offset_ = size_links_level0_ + data_size_; + offsetLevel0_ = 0; + + data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory"); + + cur_element_count = 0; + + visited_list_pool_ = new VisitedListPool(1, max_elements); - HierarchicalNSW(SpaceInterface *s, const std::string &location, bool nmslib = false, size_t max_elements=0) { - loadIndex(location, s, max_elements); - } + // initializations for special treatment of the first node + enterpoint_node_ = -1; + maxlevel_ = -1; - HierarchicalNSW(SpaceInterface *s, size_t max_elements, size_t M = 16, size_t ef_construction = 200, size_t random_seed = 100) : - link_list_locks_(max_elements), link_list_update_locks_(max_update_element_locks), element_levels_(max_elements) { - max_elements_ = max_elements; - - num_deleted_ = 0; - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); - M_ = M; - maxM_ = M_; - maxM0_ = M_ * 2; - ef_construction_ = std::max(ef_construction,M_); - ef_ = 10; - - level_generator_.seed(random_seed); - update_probability_generator_.seed(random_seed + 1); - - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - size_data_per_element_ = size_links_level0_ + data_size_ + sizeof(labeltype); - offsetData_ = size_links_level0_; - label_offset_ = size_links_level0_ + data_size_; - offsetLevel0_ = 0; - - data_level0_memory_ = (char *) malloc(max_elements_ * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory"); - - cur_element_count = 0; - - visited_list_pool_ = new VisitedListPool(1, max_elements); - - //initializations for special treatment of the first node - enterpoint_node_ = -1; - maxlevel_ = -1; - - linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - mult_ = 1 / log(1.0 * M_); - revSize_ = 1.0 / mult_; - } + linkLists_ = (char **) malloc(sizeof(void *) * max_elements_); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: HierarchicalNSW failed to allocate linklists"); + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + mult_ = 1 / log(1.0 * M_); + revSize_ = 1.0 / mult_; + } - struct CompareByFirst { - constexpr bool operator()(std::pair const &a, - std::pair const &b) const noexcept { - return a.first < b.first; - } - }; - ~HierarchicalNSW() { + ~HierarchicalNSW() { + free(data_level0_memory_); + for (tableint i = 0; i < cur_element_count; i++) { + if (element_levels_[i] > 0) + free(linkLists_[i]); + } + free(linkLists_); + delete visited_list_pool_; + } + - free(data_level0_memory_); - for (tableint i = 0; i < cur_element_count; i++) { - if (element_levels_[i] > 0) - free(linkLists_[i]); - } - free(linkLists_); - delete visited_list_pool_; + struct CompareByFirst { + constexpr bool operator()(std::pair const& a, + std::pair const& b) const noexcept { + return a.first < b.first; } + }; - size_t max_elements_; - size_t cur_element_count; - size_t size_data_per_element_; - size_t size_links_per_element_; - size_t num_deleted_; - size_t M_; - size_t maxM_; - size_t maxM0_; - size_t ef_construction_; + void setEf(size_t ef) { + ef_ = ef; + } - double mult_, revSize_; - int maxlevel_; + inline std::mutex& getLabelOpMutex(labeltype label) const { + // calculate hash + size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1); + return label_op_locks_[lock_id]; + } - VisitedListPool *visited_list_pool_; - std::mutex cur_element_count_guard_; - std::vector link_list_locks_; + inline labeltype getExternalLabel(tableint internal_id) const { + labeltype return_label; + memcpy(&return_label, (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); + return return_label; + } - // Locks to prevent race condition during update/insert of an element at same time. - // Note: Locks for additions can also be used to prevent this race condition if the querying of KNN is not exposed along with update/inserts i.e multithread insert/update/query in parallel. - std::vector link_list_update_locks_; - tableint enterpoint_node_; - size_t size_links_level0_; - size_t offsetData_, offsetLevel0_; + inline void setExternalLabel(tableint internal_id, labeltype label) const { + memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); + } - char *data_level0_memory_; - char **linkLists_; - std::vector element_levels_; - size_t data_size_; + inline labeltype *getExternalLabeLp(tableint internal_id) const { + return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); + } - size_t label_offset_; - DISTFUNC fstdistfunc_; - void *dist_func_param_; - std::unordered_map label_lookup_; - std::default_random_engine level_generator_; - std::default_random_engine update_probability_generator_; + inline char *getDataByInternalId(tableint internal_id) const { + return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); + } - inline labeltype getExternalLabel(tableint internal_id) const { - labeltype return_label; - memcpy(&return_label,(data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), sizeof(labeltype)); - return return_label; - } - inline void setExternalLabel(tableint internal_id, labeltype label) const { - memcpy((data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_), &label, sizeof(labeltype)); - } + int getRandomLevel(double reverse_size) { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -log(distribution(level_generator_)) * reverse_size; + return (int) r; + } - inline labeltype *getExternalLabeLp(tableint internal_id) const { - return (labeltype *) (data_level0_memory_ + internal_id * size_data_per_element_ + label_offset_); - } + size_t getMaxElements() { + return max_elements_; + } - inline char *getDataByInternalId(tableint internal_id) const { - return (data_level0_memory_ + internal_id * size_data_per_element_ + offsetData_); - } + size_t getCurrentElementCount() { + return cur_element_count; + } - int getRandomLevel(double reverse_size) { - std::uniform_real_distribution distribution(0.0, 1.0); - double r = -log(distribution(level_generator_)) * reverse_size; - return (int) r; - } + size_t getDeletedCount() { + return num_deleted_; + } + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayer(tableint ep_id, const void *data_point, int layer) { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; - std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayer(tableint ep_id, const void *data_point, int layer) { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidateSet; - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - std::priority_queue, std::vector>, CompareByFirst> candidateSet; + dist_t lowerBound; + if (!isMarkedDeleted(ep_id)) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + top_candidates.emplace(dist, ep_id); + lowerBound = dist; + candidateSet.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidateSet.emplace(-lowerBound, ep_id); + } + visited_array[ep_id] = visited_array_tag; - dist_t lowerBound; - if (!isMarkedDeleted(ep_id)) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); - top_candidates.emplace(dist, ep_id); - lowerBound = dist; - candidateSet.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits::max(); - candidateSet.emplace(-lowerBound, ep_id); + while (!candidateSet.empty()) { + std::pair curr_el_pair = candidateSet.top(); + if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) { + break; } - visited_array[ep_id] = visited_array_tag; + candidateSet.pop(); - while (!candidateSet.empty()) { - std::pair curr_el_pair = candidateSet.top(); - if ((-curr_el_pair.first) > lowerBound && top_candidates.size() == ef_construction_) { - break; - } - candidateSet.pop(); + tableint curNodeNum = curr_el_pair.second; - tableint curNodeNum = curr_el_pair.second; + std::unique_lock lock(link_list_locks_[curNodeNum]); - std::unique_lock lock(link_list_locks_[curNodeNum]); - - int *data;// = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); - if (layer == 0) { - data = (int*)get_linklist0(curNodeNum); - } else { - data = (int*)get_linklist(curNodeNum, layer); + int *data; // = (int *)(linkList0_ + curNodeNum * size_links_per_element0_); + if (layer == 0) { + data = (int*)get_linklist0(curNodeNum); + } else { + data = (int*)get_linklist(curNodeNum, layer); // data = (int *) (linkLists_[curNodeNum] + (layer - 1) * size_links_per_element_); - } - size_t size = getListCount((linklistsizeint*)data); - tableint *datal = (tableint *) (data + 1); + } + size_t size = getListCount((linklistsizeint*)data); + tableint *datal = (tableint *) (data + 1); #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + 1)), _MM_HINT_T0); #endif - for (size_t j = 0; j < size; j++) { - tableint candidate_id = *(datal + j); + for (size_t j = 0; j < size; j++) { + tableint candidate_id = *(datal + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); - _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(datal + j + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + j + 1)), _MM_HINT_T0); #endif - if (visited_array[candidate_id] == visited_array_tag) continue; - visited_array[candidate_id] = visited_array_tag; - char *currObj1 = (getDataByInternalId(candidate_id)); + if (visited_array[candidate_id] == visited_array_tag) continue; + visited_array[candidate_id] = visited_array_tag; + char *currObj1 = (getDataByInternalId(candidate_id)); - dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { - candidateSet.emplace(-dist1, candidate_id); + dist_t dist1 = fstdistfunc_(data_point, currObj1, dist_func_param_); + if (top_candidates.size() < ef_construction_ || lowerBound > dist1) { + candidateSet.emplace(-dist1, candidate_id); #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(candidateSet.top().second), _MM_HINT_T0); #endif - if (!isMarkedDeleted(candidate_id)) - top_candidates.emplace(dist1, candidate_id); + if (!isMarkedDeleted(candidate_id)) + top_candidates.emplace(dist1, candidate_id); - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; } } - visited_list_pool_->releaseVisitedList(vl); - - return top_candidates; + } + visited_list_pool_->releaseVisitedList(vl); + + return top_candidates; + } + + + template + std::priority_queue, std::vector>, CompareByFirst> + searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, BaseFilterFunctor* isIdAllowed = nullptr) const { + VisitedList *vl = visited_list_pool_->getFreeVisitedList(); + vl_type *visited_array = vl->mass; + vl_type visited_array_tag = vl->curV; + + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + std::priority_queue, std::vector>, CompareByFirst> candidate_set; + + dist_t lowerBound; + if ((!has_deletions || !isMarkedDeleted(ep_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id)))) { + dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); + lowerBound = dist; + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + } else { + lowerBound = std::numeric_limits::max(); + candidate_set.emplace(-lowerBound, ep_id); } - mutable std::atomic metric_distance_computations; - mutable std::atomic metric_hops; - - template - std::priority_queue, std::vector>, CompareByFirst> - searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { - VisitedList *vl = visited_list_pool_->getFreeVisitedList(); - vl_type *visited_array = vl->mass; - vl_type visited_array_tag = vl->curV; - - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - std::priority_queue, std::vector>, CompareByFirst> candidate_set; - - dist_t lowerBound; - if (!has_deletions || !isMarkedDeleted(ep_id)) { - dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); - lowerBound = dist; - top_candidates.emplace(dist, ep_id); - candidate_set.emplace(-dist, ep_id); - } else { - lowerBound = std::numeric_limits::max(); - candidate_set.emplace(-lowerBound, ep_id); - } - - visited_array[ep_id] = visited_array_tag; - - while (!candidate_set.empty()) { + visited_array[ep_id] = visited_array_tag; - std::pair current_node_pair = candidate_set.top(); + while (!candidate_set.empty()) { + std::pair current_node_pair = candidate_set.top(); - if ((-current_node_pair.first) > lowerBound && (top_candidates.size() == ef || has_deletions == false)) { - break; - } - candidate_set.pop(); + if ((-current_node_pair.first) > lowerBound && + (top_candidates.size() == ef || (!isIdAllowed && !has_deletions))) { + break; + } + candidate_set.pop(); - tableint current_node_id = current_node_pair.second; - int *data = (int *) get_linklist0(current_node_id); - size_t size = getListCount((linklistsizeint*)data); + tableint current_node_id = current_node_pair.second; + int *data = (int *) get_linklist0(current_node_id); + size_t size = getListCount((linklistsizeint*)data); // bool cur_node_deleted = isMarkedDeleted(current_node_id); - if(collect_metrics){ - metric_hops++; - metric_distance_computations+=size; - } + if (collect_metrics) { + metric_hops++; + metric_distance_computations+=size; + } #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); - _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); - _mm_prefetch((char *) (data + 2), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1)), _MM_HINT_T0); + _mm_prefetch((char *) (visited_array + *(data + 1) + 64), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + 1)) * size_data_per_element_ + offsetData_, _MM_HINT_T0); + _mm_prefetch((char *) (data + 2), _MM_HINT_T0); #endif - for (size_t j = 1; j <= size; j++) { - int candidate_id = *(data + j); + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); // if (candidate_id == 0) continue; #ifdef USE_SSE - _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); - _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, - _MM_HINT_T0);//////////// + _mm_prefetch((char *) (visited_array + *(data + j + 1)), _MM_HINT_T0); + _mm_prefetch(data_level0_memory_ + (*(data + j + 1)) * size_data_per_element_ + offsetData_, + _MM_HINT_T0); //////////// #endif - if (!(visited_array[candidate_id] == visited_array_tag)) { - - visited_array[candidate_id] = visited_array_tag; + if (!(visited_array[candidate_id] == visited_array_tag)) { + visited_array[candidate_id] = visited_array_tag; - char *currObj1 = (getDataByInternalId(candidate_id)); - dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); + char *currObj1 = (getDataByInternalId(candidate_id)); + dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); - if (top_candidates.size() < ef || lowerBound > dist) { - candidate_set.emplace(-dist, candidate_id); + if (top_candidates.size() < ef || lowerBound > dist) { + candidate_set.emplace(-dist, candidate_id); #ifdef USE_SSE - _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + - offsetLevel0_,/////////// - _MM_HINT_T0);//////////////////////// + _mm_prefetch(data_level0_memory_ + candidate_set.top().second * size_data_per_element_ + + offsetLevel0_, /////////// + _MM_HINT_T0); //////////////////////// #endif - if (!has_deletions || !isMarkedDeleted(candidate_id)) - top_candidates.emplace(dist, candidate_id); + if ((!has_deletions || !isMarkedDeleted(candidate_id)) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(candidate_id)))) + top_candidates.emplace(dist, candidate_id); - if (top_candidates.size() > ef) - top_candidates.pop(); + if (top_candidates.size() > ef) + top_candidates.pop(); - if (!top_candidates.empty()) - lowerBound = top_candidates.top().first; - } + if (!top_candidates.empty()) + lowerBound = top_candidates.top().first; } } } - - visited_list_pool_->releaseVisitedList(vl); - return top_candidates; } - void getNeighborsByHeuristic2( - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - const size_t M) { - if (top_candidates.size() < M) { - return; - } + visited_list_pool_->releaseVisitedList(vl); + return top_candidates; + } - std::priority_queue> queue_closest; - std::vector> return_list; - while (top_candidates.size() > 0) { - queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); - top_candidates.pop(); - } - while (queue_closest.size()) { - if (return_list.size() >= M) + void getNeighborsByHeuristic2( + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + const size_t M) { + if (top_candidates.size() < M) { + return; + } + + std::priority_queue> queue_closest; + std::vector> return_list; + while (top_candidates.size() > 0) { + queue_closest.emplace(-top_candidates.top().first, top_candidates.top().second); + top_candidates.pop(); + } + + while (queue_closest.size()) { + if (return_list.size() >= M) + break; + std::pair curent_pair = queue_closest.top(); + dist_t dist_to_query = -curent_pair.first; + queue_closest.pop(); + bool good = true; + + for (std::pair second_pair : return_list) { + dist_t curdist = + fstdistfunc_(getDataByInternalId(second_pair.second), + getDataByInternalId(curent_pair.second), + dist_func_param_); + if (curdist < dist_to_query) { + good = false; break; - std::pair curent_pair = queue_closest.top(); - dist_t dist_to_query = -curent_pair.first; - queue_closest.pop(); - bool good = true; - - for (std::pair second_pair : return_list) { - dist_t curdist = - fstdistfunc_(getDataByInternalId(second_pair.second), - getDataByInternalId(curent_pair.second), - dist_func_param_);; - if (curdist < dist_to_query) { - good = false; - break; - } - } - if (good) { - return_list.push_back(curent_pair); } } - - for (std::pair curent_pair : return_list) { - top_candidates.emplace(-curent_pair.first, curent_pair.second); + if (good) { + return_list.push_back(curent_pair); } } + for (std::pair curent_pair : return_list) { + top_candidates.emplace(-curent_pair.first, curent_pair.second); + } + } - linklistsizeint *get_linklist0(tableint internal_id) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); - }; - - linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { - return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); - }; - - linklistsizeint *get_linklist(tableint internal_id, int level) const { - return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); - }; - linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { - return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); - }; + linklistsizeint *get_linklist0(tableint internal_id) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } - tableint mutuallyConnectNewElement(const void *data_point, tableint cur_c, - std::priority_queue, std::vector>, CompareByFirst> &top_candidates, - int level, bool isUpdate) { - size_t Mcurmax = level ? maxM_ : maxM0_; - getNeighborsByHeuristic2(top_candidates, M_); - if (top_candidates.size() > M_) - throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); - std::vector selectedNeighbors; - selectedNeighbors.reserve(M_); - while (top_candidates.size() > 0) { - selectedNeighbors.push_back(top_candidates.top().second); - top_candidates.pop(); - } + linklistsizeint *get_linklist0(tableint internal_id, char *data_level0_memory_) const { + return (linklistsizeint *) (data_level0_memory_ + internal_id * size_data_per_element_ + offsetLevel0_); + } - tableint next_closest_entry_point = selectedNeighbors.back(); - { - linklistsizeint *ll_cur; - if (level == 0) - ll_cur = get_linklist0(cur_c); - else - ll_cur = get_linklist(cur_c, level); + linklistsizeint *get_linklist(tableint internal_id, int level) const { + return (linklistsizeint *) (linkLists_[internal_id] + (level - 1) * size_links_per_element_); + } - if (*ll_cur && !isUpdate) { - throw std::runtime_error("The newly inserted element should have blank link list"); - } - setListCount(ll_cur,selectedNeighbors.size()); - tableint *data = (tableint *) (ll_cur + 1); - for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { - if (data[idx] && !isUpdate) - throw std::runtime_error("Possible memory corruption"); - if (level > element_levels_[selectedNeighbors[idx]]) - throw std::runtime_error("Trying to make a link on a non-existent level"); - data[idx] = selectedNeighbors[idx]; + linklistsizeint *get_linklist_at_level(tableint internal_id, int level) const { + return level == 0 ? get_linklist0(internal_id) : get_linklist(internal_id, level); + } - } - } - for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + tableint mutuallyConnectNewElement( + const void *data_point, + tableint cur_c, + std::priority_queue, std::vector>, CompareByFirst> &top_candidates, + int level, + bool isUpdate) { + size_t Mcurmax = level ? maxM_ : maxM0_; + getNeighborsByHeuristic2(top_candidates, M_); + if (top_candidates.size() > M_) + throw std::runtime_error("Should be not be more than M_ candidates returned by the heuristic"); - std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); + std::vector selectedNeighbors; + selectedNeighbors.reserve(M_); + while (top_candidates.size() > 0) { + selectedNeighbors.push_back(top_candidates.top().second); + top_candidates.pop(); + } - linklistsizeint *ll_other; - if (level == 0) - ll_other = get_linklist0(selectedNeighbors[idx]); - else - ll_other = get_linklist(selectedNeighbors[idx], level); + tableint next_closest_entry_point = selectedNeighbors.back(); - size_t sz_link_list_other = getListCount(ll_other); + { + // lock only during the update + // because during the addition the lock for cur_c is already acquired + std::unique_lock lock(link_list_locks_[cur_c], std::defer_lock); + if (isUpdate) { + lock.lock(); + } + linklistsizeint *ll_cur; + if (level == 0) + ll_cur = get_linklist0(cur_c); + else + ll_cur = get_linklist(cur_c, level); - if (sz_link_list_other > Mcurmax) - throw std::runtime_error("Bad value of sz_link_list_other"); - if (selectedNeighbors[idx] == cur_c) - throw std::runtime_error("Trying to connect an element to itself"); + if (*ll_cur && !isUpdate) { + throw std::runtime_error("The newly inserted element should have blank link list"); + } + setListCount(ll_cur, selectedNeighbors.size()); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + if (data[idx] && !isUpdate) + throw std::runtime_error("Possible memory corruption"); if (level > element_levels_[selectedNeighbors[idx]]) throw std::runtime_error("Trying to make a link on a non-existent level"); - tableint *data = (tableint *) (ll_other + 1); + data[idx] = selectedNeighbors[idx]; + } + } - bool is_cur_c_present = false; - if (isUpdate) { - for (size_t j = 0; j < sz_link_list_other; j++) { - if (data[j] == cur_c) { - is_cur_c_present = true; - break; - } - } - } + for (size_t idx = 0; idx < selectedNeighbors.size(); idx++) { + std::unique_lock lock(link_list_locks_[selectedNeighbors[idx]]); - // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. - if (!is_cur_c_present) { - if (sz_link_list_other < Mcurmax) { - data[sz_link_list_other] = cur_c; - setListCount(ll_other, sz_link_list_other + 1); - } else { - // finding the "weakest" element to replace it with the new one - dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_); - // Heuristic: - std::priority_queue, std::vector>, CompareByFirst> candidates; - candidates.emplace(d_max, cur_c); - - for (size_t j = 0; j < sz_link_list_other; j++) { - candidates.emplace( - fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), - dist_func_param_), data[j]); - } + linklistsizeint *ll_other; + if (level == 0) + ll_other = get_linklist0(selectedNeighbors[idx]); + else + ll_other = get_linklist(selectedNeighbors[idx], level); - getNeighborsByHeuristic2(candidates, Mcurmax); + size_t sz_link_list_other = getListCount(ll_other); - int indx = 0; - while (candidates.size() > 0) { - data[indx] = candidates.top().second; - candidates.pop(); - indx++; - } + if (sz_link_list_other > Mcurmax) + throw std::runtime_error("Bad value of sz_link_list_other"); + if (selectedNeighbors[idx] == cur_c) + throw std::runtime_error("Trying to connect an element to itself"); + if (level > element_levels_[selectedNeighbors[idx]]) + throw std::runtime_error("Trying to make a link on a non-existent level"); - setListCount(ll_other, indx); - // Nearest K: - /*int indx = -1; - for (int j = 0; j < sz_link_list_other; j++) { - dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); - if (d > d_max) { - indx = j; - d_max = d; - } - } - if (indx >= 0) { - data[indx] = cur_c; - } */ + tableint *data = (tableint *) (ll_other + 1); + + bool is_cur_c_present = false; + if (isUpdate) { + for (size_t j = 0; j < sz_link_list_other; j++) { + if (data[j] == cur_c) { + is_cur_c_present = true; + break; } } } - return next_closest_entry_point; - } - - std::mutex global; - size_t ef_; - - void setEf(size_t ef) { - ef_ = ef; - } + // If cur_c is already present in the neighboring connections of `selectedNeighbors[idx]` then no need to modify any connections or run the heuristics. + if (!is_cur_c_present) { + if (sz_link_list_other < Mcurmax) { + data[sz_link_list_other] = cur_c; + setListCount(ll_other, sz_link_list_other + 1); + } else { + // finding the "weakest" element to replace it with the new one + dist_t d_max = fstdistfunc_(getDataByInternalId(cur_c), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_); + // Heuristic: + std::priority_queue, std::vector>, CompareByFirst> candidates; + candidates.emplace(d_max, cur_c); + for (size_t j = 0; j < sz_link_list_other; j++) { + candidates.emplace( + fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(selectedNeighbors[idx]), + dist_func_param_), data[j]); + } - std::priority_queue> searchKnnInternal(void *query_data, int k) { - std::priority_queue> top_candidates; - if (cur_element_count == 0) return top_candidates; - tableint currObj = enterpoint_node_; - dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + getNeighborsByHeuristic2(candidates, Mcurmax); - for (size_t level = maxlevel_; level > 0; level--) { - bool changed = true; - while (changed) { - changed = false; - int *data; - data = (int *) get_linklist(currObj,level); - int size = getListCount(data); - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + int indx = 0; + while (candidates.size() > 0) { + data[indx] = candidates.top().second; + candidates.pop(); + indx++; + } - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; + setListCount(ll_other, indx); + // Nearest K: + /*int indx = -1; + for (int j = 0; j < sz_link_list_other; j++) { + dist_t d = fstdistfunc_(getDataByInternalId(data[j]), getDataByInternalId(rez[idx]), dist_func_param_); + if (d > d_max) { + indx = j; + d_max = d; } } + if (indx >= 0) { + data[indx] = cur_c; + } */ } } + } - if (num_deleted_) { - std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_); - top_candidates.swap(top_candidates1); - } - else{ - std::priority_queue> top_candidates1=searchBaseLayerST(currObj, query_data, - ef_); - top_candidates.swap(top_candidates1); - } + return next_closest_entry_point; + } - while (top_candidates.size() > k) { - top_candidates.pop(); - } - return top_candidates; - }; - void resizeIndex(size_t new_max_elements){ - if (new_max_elements(new_max_elements).swap(link_list_locks_); - element_levels_.resize(new_max_elements); + // Reallocate base layer + char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); + if (data_level0_memory_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); + data_level0_memory_ = data_level0_memory_new; - std::vector(new_max_elements).swap(link_list_locks_); + // Reallocate all other layers + char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); + if (linkLists_new == nullptr) + throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); + linkLists_ = linkLists_new; - // Reallocate base layer - char * data_level0_memory_new = (char *) realloc(data_level0_memory_, new_max_elements * size_data_per_element_); - if (data_level0_memory_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate base layer"); - data_level0_memory_ = data_level0_memory_new; + max_elements_ = new_max_elements; + } - // Reallocate all other layers - char ** linkLists_new = (char **) realloc(linkLists_, sizeof(void *) * new_max_elements); - if (linkLists_new == nullptr) - throw std::runtime_error("Not enough memory: resizeIndex failed to allocate other layers"); - linkLists_ = linkLists_new; - max_elements_ = new_max_elements; - } + void saveIndex(const std::string &location) { + std::ofstream output(location, std::ios::binary); + std::streampos position; - void saveIndex(const std::string &location) { - std::ofstream output(location, std::ios::binary); - std::streampos position; - - writeBinaryPOD(output, offsetLevel0_); - writeBinaryPOD(output, max_elements_); - writeBinaryPOD(output, cur_element_count); - writeBinaryPOD(output, size_data_per_element_); - writeBinaryPOD(output, label_offset_); - writeBinaryPOD(output, offsetData_); - writeBinaryPOD(output, maxlevel_); - writeBinaryPOD(output, enterpoint_node_); - writeBinaryPOD(output, maxM_); - - writeBinaryPOD(output, maxM0_); - writeBinaryPOD(output, M_); - writeBinaryPOD(output, mult_); - writeBinaryPOD(output, ef_construction_); - - output.write(data_level0_memory_, cur_element_count * size_data_per_element_); - - for (size_t i = 0; i < cur_element_count; i++) { - unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; - writeBinaryPOD(output, linkListSize); - if (linkListSize) - output.write(linkLists_[i], linkListSize); - } - output.close(); - } + writeBinaryPOD(output, offsetLevel0_); + writeBinaryPOD(output, max_elements_); + writeBinaryPOD(output, cur_element_count); + writeBinaryPOD(output, size_data_per_element_); + writeBinaryPOD(output, label_offset_); + writeBinaryPOD(output, offsetData_); + writeBinaryPOD(output, maxlevel_); + writeBinaryPOD(output, enterpoint_node_); + writeBinaryPOD(output, maxM_); - void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i=0) { - std::ifstream input(location, std::ios::binary); + writeBinaryPOD(output, maxM0_); + writeBinaryPOD(output, M_); + writeBinaryPOD(output, mult_); + writeBinaryPOD(output, ef_construction_); - if (!input.is_open()) - throw std::runtime_error("Cannot open file"); + output.write(data_level0_memory_, cur_element_count * size_data_per_element_); - // get file size: - input.seekg(0,input.end); - std::streampos total_filesize=input.tellg(); - input.seekg(0,input.beg); + for (size_t i = 0; i < cur_element_count; i++) { + unsigned int linkListSize = element_levels_[i] > 0 ? size_links_per_element_ * element_levels_[i] : 0; + writeBinaryPOD(output, linkListSize); + if (linkListSize) + output.write(linkLists_[i], linkListSize); + } + output.close(); + } + + + void loadIndex(const std::string &location, SpaceInterface *s, size_t max_elements_i = 0) { + std::ifstream input(location, std::ios::binary); + + if (!input.is_open()) + throw std::runtime_error("Cannot open file"); + + // get file size: + input.seekg(0, input.end); + std::streampos total_filesize = input.tellg(); + input.seekg(0, input.beg); + + readBinaryPOD(input, offsetLevel0_); + readBinaryPOD(input, max_elements_); + readBinaryPOD(input, cur_element_count); + + size_t max_elements = max_elements_i; + if (max_elements < cur_element_count) + max_elements = max_elements_; + max_elements_ = max_elements; + readBinaryPOD(input, size_data_per_element_); + readBinaryPOD(input, label_offset_); + readBinaryPOD(input, offsetData_); + readBinaryPOD(input, maxlevel_); + readBinaryPOD(input, enterpoint_node_); + + readBinaryPOD(input, maxM_); + readBinaryPOD(input, maxM0_); + readBinaryPOD(input, M_); + readBinaryPOD(input, mult_); + readBinaryPOD(input, ef_construction_); + + data_size_ = s->get_data_size(); + fstdistfunc_ = s->get_dist_func(); + dist_func_param_ = s->get_dist_func_param(); + + auto pos = input.tellg(); + + /// Optional - check if index is ok: + input.seekg(cur_element_count * size_data_per_element_, input.cur); + for (size_t i = 0; i < cur_element_count; i++) { + if (input.tellg() < 0 || input.tellg() >= total_filesize) { + throw std::runtime_error("Index seems to be corrupted or unsupported"); + } - readBinaryPOD(input, offsetLevel0_); - readBinaryPOD(input, max_elements_); - readBinaryPOD(input, cur_element_count); + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize != 0) { + input.seekg(linkListSize, input.cur); + } + } - size_t max_elements = max_elements_i; - if(max_elements < cur_element_count) - max_elements = max_elements_; - max_elements_ = max_elements; - readBinaryPOD(input, size_data_per_element_); - readBinaryPOD(input, label_offset_); - readBinaryPOD(input, offsetData_); - readBinaryPOD(input, maxlevel_); - readBinaryPOD(input, enterpoint_node_); + // throw exception if it either corrupted or old index + if (input.tellg() != total_filesize) + throw std::runtime_error("Index seems to be corrupted or unsupported"); + + input.clear(); + /// Optional check end + + input.seekg(pos, input.beg); + + data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); + if (data_level0_memory_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); + input.read(data_level0_memory_, cur_element_count * size_data_per_element_); + + size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); + + size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); + std::vector(max_elements).swap(link_list_locks_); + std::vector(MAX_LABEL_OPERATION_LOCKS).swap(label_op_locks_); + + visited_list_pool_ = new VisitedListPool(1, max_elements); + + linkLists_ = (char **) malloc(sizeof(void *) * max_elements); + if (linkLists_ == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); + element_levels_ = std::vector(max_elements); + revSize_ = 1.0 / mult_; + ef_ = 10; + for (size_t i = 0; i < cur_element_count; i++) { + label_lookup_[getExternalLabel(i)] = i; + unsigned int linkListSize; + readBinaryPOD(input, linkListSize); + if (linkListSize == 0) { + element_levels_[i] = 0; + linkLists_[i] = nullptr; + } else { + element_levels_[i] = linkListSize / size_links_per_element_; + linkLists_[i] = (char *) malloc(linkListSize); + if (linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + input.read(linkLists_[i], linkListSize); + } + } - readBinaryPOD(input, maxM_); - readBinaryPOD(input, maxM0_); - readBinaryPOD(input, M_); - readBinaryPOD(input, mult_); - readBinaryPOD(input, ef_construction_); + for (size_t i = 0; i < cur_element_count; i++) { + if (isMarkedDeleted(i)) { + num_deleted_ += 1; + if (allow_replace_deleted_) deleted_elements.insert(i); + } + } + input.close(); - data_size_ = s->get_data_size(); - fstdistfunc_ = s->get_dist_func(); - dist_func_param_ = s->get_dist_func_param(); + return; + } - auto pos=input.tellg(); + template + std::vector getDataByLabel(labeltype label) const { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); + + char* data_ptrv = getDataByInternalId(internalId); + size_t dim = *((size_t *) dist_func_param_); + std::vector data; + data_t* data_ptr = (data_t*) data_ptrv; + for (int i = 0; i < dim; i++) { + data.push_back(*data_ptr); + data_ptr += 1; + } + return data; + } - /// Optional - check if index is ok: - input.seekg(cur_element_count * size_data_per_element_,input.cur); - for (size_t i = 0; i < cur_element_count; i++) { - if(input.tellg() < 0 || input.tellg()>=total_filesize){ - throw std::runtime_error("Index seems to be corrupted or unsupported"); - } + /* + * Marks an element with the given label deleted, does NOT really change the current graph. + */ + void markDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize != 0) { - input.seekg(linkListSize,input.cur); - } + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); + + markDeletedInternal(internalId); + } + + + /* + * Uses the last 16 bits of the memory for the linked list size to store the mark, + * whereas maxM0_ has to be limited to the lower 16 bits, however, still large enough in almost all cases. + */ + void markDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (!isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; + *ll_cur |= DELETE_MARK; + num_deleted_ += 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.insert(internalId); } + } else { + throw std::runtime_error("The requested to delete element is already deleted"); + } + } + + + /* + * Removes the deleted mark of the node, does NOT really change the current graph. + * + * Note: the method is not safe to use when replacement of deleted elements is enabled, + * because elements marked as deleted can be completely removed by addPoint + */ + void unmarkDelete(labeltype label) { + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search == label_lookup_.end()) { + throw std::runtime_error("Label not found"); + } + tableint internalId = search->second; + lock_table.unlock(); + + unmarkDeletedInternal(internalId); + } + + + + /* + * Remove the deleted mark of the node. + */ + void unmarkDeletedInternal(tableint internalId) { + assert(internalId < cur_element_count); + if (isMarkedDeleted(internalId)) { + unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId)) + 2; + *ll_cur &= ~DELETE_MARK; + num_deleted_ -= 1; + if (allow_replace_deleted_) { + std::unique_lock lock_deleted_elements(deleted_elements_lock); + deleted_elements.erase(internalId); + } + } else { + throw std::runtime_error("The requested to undelete element is not deleted"); + } + } - // throw exception if it either corrupted or old index - if(input.tellg()!=total_filesize) - throw std::runtime_error("Index seems to be corrupted or unsupported"); - - input.clear(); - - /// Optional check end - - input.seekg(pos,input.beg); - - data_level0_memory_ = (char *) malloc(max_elements * size_data_per_element_); - if (data_level0_memory_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate level0"); - input.read(data_level0_memory_, cur_element_count * size_data_per_element_); - - size_links_per_element_ = maxM_ * sizeof(tableint) + sizeof(linklistsizeint); - size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint); - std::vector(max_elements).swap(link_list_locks_); - std::vector(max_update_element_locks).swap(link_list_update_locks_); + /* + * Checks the first 16 bits of the memory to see if the element is marked deleted. + */ + bool isMarkedDeleted(tableint internalId) const { + unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId)) + 2; + return *ll_cur & DELETE_MARK; + } - visited_list_pool_ = new VisitedListPool(1, max_elements); - linkLists_ = (char **) malloc(sizeof(void *) * max_elements); - if (linkLists_ == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklists"); - element_levels_ = std::vector(max_elements); - revSize_ = 1.0 / mult_; - ef_ = 10; - for (size_t i = 0; i < cur_element_count; i++) { - label_lookup_[getExternalLabel(i)]=i; - unsigned int linkListSize; - readBinaryPOD(input, linkListSize); - if (linkListSize == 0) { - element_levels_[i] = 0; + unsigned short int getListCount(linklistsizeint * ptr) const { + return *((unsigned short int *)ptr); + } - linkLists_[i] = nullptr; - } else { - element_levels_[i] = linkListSize / size_links_per_element_; - linkLists_[i] = (char *) malloc(linkListSize); - if (linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); - input.read(linkLists_[i], linkListSize); - } - } - for (size_t i = 0; i < cur_element_count; i++) { - if(isMarkedDeleted(i)) - num_deleted_ += 1; - } + void setListCount(linklistsizeint * ptr, unsigned short int size) const { + *((unsigned short int*)(ptr))=*((unsigned short int *)&size); + } - input.close(); - return; + /* + * Adds point. Updates the point if it is already in the index. + * If replacement of deleted elements is enabled: replaces previously deleted point if any, updating it with new point + */ + void addPoint(const void *data_point, labeltype label, bool replace_deleted = false) { + if ((allow_replace_deleted_ == false) && (replace_deleted == true)) { + throw std::runtime_error("Replacement of deleted elements is disabled in constructor"); } - template - std::vector getDataByLabel(labeltype label) const - { - tableint label_c; - auto search = label_lookup_.find(label); - if (search == label_lookup_.end() || isMarkedDeleted(search->second)) { - throw std::runtime_error("Label not found"); - } - label_c = search->second; - - char* data_ptrv = getDataByInternalId(label_c); - size_t dim = *((size_t *) dist_func_param_); - std::vector data; - data_t* data_ptr = (data_t*) data_ptrv; - for (int i = 0; i < dim; i++) { - data.push_back(*data_ptr); - data_ptr += 1; - } - return data; + // lock all operations with element by label + std::unique_lock lock_label(getLabelOpMutex(label)); + if (!replace_deleted) { + addPoint(data_point, label, -1); + return; } - - static const unsigned char DELETE_MARK = 0x01; - // static const unsigned char REUSE_MARK = 0x10; - /** - * Marks an element with the given label deleted, does NOT really change the current graph. - * @param label - */ - void markDelete(labeltype label) - { - auto search = label_lookup_.find(label); - if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); - } - tableint internalId = search->second; - markDeletedInternal(internalId); + // check if there is vacant place + tableint internal_id_replaced; + std::unique_lock lock_deleted_elements(deleted_elements_lock); + bool is_vacant_place = !deleted_elements.empty(); + if (is_vacant_place) { + internal_id_replaced = *deleted_elements.begin(); + deleted_elements.erase(internal_id_replaced); } - - /** - * Uses the first 8 bits of the memory for the linked list to store the mark, - * whereas maxM0_ has to be limited to the lower 24 bits, however, still large enough in almost all cases. - * @param internalId - */ - void markDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (!isMarkedDeleted(internalId)) - { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur |= DELETE_MARK; - num_deleted_ += 1; - } - else - { - throw std::runtime_error("The requested to delete element is already deleted"); - } + lock_deleted_elements.unlock(); + + // if there is no vacant place then add or update point + // else add point to vacant place + if (!is_vacant_place) { + addPoint(data_point, label, -1); + } else { + // we assume that there are no concurrent operations on deleted element + labeltype label_replaced = getExternalLabel(internal_id_replaced); + setExternalLabel(internal_id_replaced, label); + + std::unique_lock lock_table(label_lookup_lock); + label_lookup_.erase(label_replaced); + label_lookup_[label] = internal_id_replaced; + lock_table.unlock(); + + unmarkDeletedInternal(internal_id_replaced); + updatePoint(data_point, internal_id_replaced, 1.0); } + } - /** - * Remove the deleted mark of the node, does NOT really change the current graph. - * @param label - */ - void unmarkDelete(labeltype label) - { - auto search = label_lookup_.find(label); - if (search == label_lookup_.end()) { - throw std::runtime_error("Label not found"); - } - tableint internalId = search->second; - unmarkDeletedInternal(internalId); - } - /** - * Remove the deleted mark of the node. - * @param internalId - */ - void unmarkDeletedInternal(tableint internalId) { - assert(internalId < cur_element_count); - if (isMarkedDeleted(internalId)) - { - unsigned char *ll_cur = ((unsigned char *)get_linklist0(internalId))+2; - *ll_cur &= ~DELETE_MARK; - num_deleted_ -= 1; - } - else - { - throw std::runtime_error("The requested to undelete element is not deleted"); - } - } + void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { + // update the feature vector associated with existing point with new vector + memcpy(getDataByInternalId(internalId), dataPoint, data_size_); - /** - * Checks the first 8 bits of the memory to see if the element is marked deleted. - * @param internalId - * @return - */ - bool isMarkedDeleted(tableint internalId) const { - unsigned char *ll_cur = ((unsigned char*)get_linklist0(internalId))+2; - return *ll_cur & DELETE_MARK; - } + int maxLevelCopy = maxlevel_; + tableint entryPointCopy = enterpoint_node_; + // If point to be updated is entry point and graph just contains single element then just return. + if (entryPointCopy == internalId && cur_element_count == 1) + return; - unsigned short int getListCount(linklistsizeint * ptr) const { - return *((unsigned short int *)ptr); - } + int elemLevel = element_levels_[internalId]; + std::uniform_real_distribution distribution(0.0, 1.0); + for (int layer = 0; layer <= elemLevel; layer++) { + std::unordered_set sCand; + std::unordered_set sNeigh; + std::vector listOneHop = getConnectionsWithLock(internalId, layer); + if (listOneHop.size() == 0) + continue; - void setListCount(linklistsizeint * ptr, unsigned short int size) const { - *((unsigned short int*)(ptr))=*((unsigned short int *)&size); - } + sCand.insert(internalId); - void addPoint(const void *data_point, labeltype label) { - addPoint(data_point, label,-1); - } + for (auto&& elOneHop : listOneHop) { + sCand.insert(elOneHop); - void updatePoint(const void *dataPoint, tableint internalId, float updateNeighborProbability) { - // update the feature vector associated with existing point with new vector - memcpy(getDataByInternalId(internalId), dataPoint, data_size_); - - int maxLevelCopy = maxlevel_; - tableint entryPointCopy = enterpoint_node_; - // If point to be updated is entry point and graph just contains single element then just return. - if (entryPointCopy == internalId && cur_element_count == 1) - return; - - int elemLevel = element_levels_[internalId]; - std::uniform_real_distribution distribution(0.0, 1.0); - for (int layer = 0; layer <= elemLevel; layer++) { - std::unordered_set sCand; - std::unordered_set sNeigh; - std::vector listOneHop = getConnectionsWithLock(internalId, layer); - if (listOneHop.size() == 0) + if (distribution(update_probability_generator_) > updateNeighborProbability) continue; - sCand.insert(internalId); - - for (auto&& elOneHop : listOneHop) { - sCand.insert(elOneHop); - - if (distribution(update_probability_generator_) > updateNeighborProbability) - continue; - - sNeigh.insert(elOneHop); + sNeigh.insert(elOneHop); - std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); - for (auto&& elTwoHop : listTwoHop) { - sCand.insert(elTwoHop); - } + std::vector listTwoHop = getConnectionsWithLock(elOneHop, layer); + for (auto&& elTwoHop : listTwoHop) { + sCand.insert(elTwoHop); } + } - for (auto&& neigh : sNeigh) { - // if (neigh == internalId) - // continue; + for (auto&& neigh : sNeigh) { + // if (neigh == internalId) + // continue; - std::priority_queue, std::vector>, CompareByFirst> candidates; - size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 - size_t elementsToKeep = std::min(ef_construction_, size); - for (auto&& cand : sCand) { - if (cand == neigh) - continue; - - dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); - if (candidates.size() < elementsToKeep) { + std::priority_queue, std::vector>, CompareByFirst> candidates; + size_t size = sCand.find(neigh) == sCand.end() ? sCand.size() : sCand.size() - 1; // sCand guaranteed to have size >= 1 + size_t elementsToKeep = std::min(ef_construction_, size); + for (auto&& cand : sCand) { + if (cand == neigh) + continue; + + dist_t distance = fstdistfunc_(getDataByInternalId(neigh), getDataByInternalId(cand), dist_func_param_); + if (candidates.size() < elementsToKeep) { + candidates.emplace(distance, cand); + } else { + if (distance < candidates.top().first) { + candidates.pop(); candidates.emplace(distance, cand); - } else { - if (distance < candidates.top().first) { - candidates.pop(); - candidates.emplace(distance, cand); - } } } + } - // Retrieve neighbours using heuristic and set connections. - getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_); - - { - std::unique_lock lock(link_list_locks_[neigh]); - linklistsizeint *ll_cur; - ll_cur = get_linklist_at_level(neigh, layer); - size_t candSize = candidates.size(); - setListCount(ll_cur, candSize); - tableint *data = (tableint *) (ll_cur + 1); - for (size_t idx = 0; idx < candSize; idx++) { - data[idx] = candidates.top().second; - candidates.pop(); - } + // Retrieve neighbours using heuristic and set connections. + getNeighborsByHeuristic2(candidates, layer == 0 ? maxM0_ : maxM_); + + { + std::unique_lock lock(link_list_locks_[neigh]); + linklistsizeint *ll_cur; + ll_cur = get_linklist_at_level(neigh, layer); + size_t candSize = candidates.size(); + setListCount(ll_cur, candSize); + tableint *data = (tableint *) (ll_cur + 1); + for (size_t idx = 0; idx < candSize; idx++) { + data[idx] = candidates.top().second; + candidates.pop(); } } } + } - repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); - }; + repairConnectionsForUpdate(dataPoint, entryPointCopy, internalId, elemLevel, maxLevelCopy); + } - void repairConnectionsForUpdate(const void *dataPoint, tableint entryPointInternalId, tableint dataPointInternalId, int dataPointLevel, int maxLevel) { - tableint currObj = entryPointInternalId; - if (dataPointLevel < maxLevel) { - dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); - for (int level = maxLevel; level > dataPointLevel; level--) { - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - std::unique_lock lock(link_list_locks_[currObj]); - data = get_linklist_at_level(currObj,level); - int size = getListCount(data); - tableint *datal = (tableint *) (data + 1); + + void repairConnectionsForUpdate( + const void *dataPoint, + tableint entryPointInternalId, + tableint dataPointInternalId, + int dataPointLevel, + int maxLevel) { + tableint currObj = entryPointInternalId; + if (dataPointLevel < maxLevel) { + dist_t curdist = fstdistfunc_(dataPoint, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxLevel; level > dataPointLevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist_at_level(currObj, level); + int size = getListCount(data); + tableint *datal = (tableint *) (data + 1); #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*datal), _MM_HINT_T0); #endif - for (int i = 0; i < size; i++) { + for (int i = 0; i < size; i++) { #ifdef USE_SSE - _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); + _mm_prefetch(getDataByInternalId(*(datal + i + 1)), _MM_HINT_T0); #endif - tableint cand = datal[i]; - dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } + tableint cand = datal[i]; + dist_t d = fstdistfunc_(dataPoint, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; } } } } + } - if (dataPointLevel > maxLevel) - throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); + if (dataPointLevel > maxLevel) + throw std::runtime_error("Level of item to be updated cannot be bigger than max level"); - for (int level = dataPointLevel; level >= 0; level--) { - std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( - currObj, dataPoint, level); + for (int level = dataPointLevel; level >= 0; level--) { + std::priority_queue, std::vector>, CompareByFirst> topCandidates = searchBaseLayer( + currObj, dataPoint, level); - std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; - while (topCandidates.size() > 0) { - if (topCandidates.top().second != dataPointInternalId) - filteredTopCandidates.push(topCandidates.top()); + std::priority_queue, std::vector>, CompareByFirst> filteredTopCandidates; + while (topCandidates.size() > 0) { + if (topCandidates.top().second != dataPointInternalId) + filteredTopCandidates.push(topCandidates.top()); - topCandidates.pop(); - } - - // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. - // To prevent self loops, the `topCandidates` is filtered and thus can be empty. - if (filteredTopCandidates.size() > 0) { - bool epDeleted = isMarkedDeleted(entryPointInternalId); - if (epDeleted) { - filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); - if (filteredTopCandidates.size() > ef_construction_) - filteredTopCandidates.pop(); - } + topCandidates.pop(); + } - currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); + // Since element_levels_ is being used to get `dataPointLevel`, there could be cases where `topCandidates` could just contains entry point itself. + // To prevent self loops, the `topCandidates` is filtered and thus can be empty. + if (filteredTopCandidates.size() > 0) { + bool epDeleted = isMarkedDeleted(entryPointInternalId); + if (epDeleted) { + filteredTopCandidates.emplace(fstdistfunc_(dataPoint, getDataByInternalId(entryPointInternalId), dist_func_param_), entryPointInternalId); + if (filteredTopCandidates.size() > ef_construction_) + filteredTopCandidates.pop(); } + + currObj = mutuallyConnectNewElement(dataPoint, dataPointInternalId, filteredTopCandidates, level, true); } } + } + + + std::vector getConnectionsWithLock(tableint internalId, int level) { + std::unique_lock lock(link_list_locks_[internalId]); + unsigned int *data = get_linklist_at_level(internalId, level); + int size = getListCount(data); + std::vector result(size); + tableint *ll = (tableint *) (data + 1); + memcpy(result.data(), ll, size * sizeof(tableint)); + return result; + } - std::vector getConnectionsWithLock(tableint internalId, int level) { - std::unique_lock lock(link_list_locks_[internalId]); - unsigned int *data = get_linklist_at_level(internalId, level); - int size = getListCount(data); - std::vector result(size); - tableint *ll = (tableint *) (data + 1); - memcpy(result.data(), ll,size * sizeof(tableint)); - return result; - }; - - tableint addPoint(const void *data_point, labeltype label, int level) { - - tableint cur_c = 0; - { - // Checking if the element with the same label already exists - // if so, updating it *instead* of creating a new element. - std::unique_lock templock_curr(cur_element_count_guard_); - auto search = label_lookup_.find(label); - if (search != label_lookup_.end()) { - tableint existingInternalId = search->second; - templock_curr.unlock(); - - std::unique_lock lock_el_update(link_list_update_locks_[(existingInternalId & (max_update_element_locks - 1))]); + tableint addPoint(const void *data_point, labeltype label, int level) { + tableint cur_c = 0; + { + // Checking if the element with the same label already exists + // if so, updating it *instead* of creating a new element. + std::unique_lock lock_table(label_lookup_lock); + auto search = label_lookup_.find(label); + if (search != label_lookup_.end()) { + tableint existingInternalId = search->second; + if (allow_replace_deleted_) { if (isMarkedDeleted(existingInternalId)) { - unmarkDeletedInternal(existingInternalId); + throw std::runtime_error("Can't use addPoint to update deleted elements if replacement of deleted elements is enabled."); } - updatePoint(data_point, existingInternalId, 1.0); - - return existingInternalId; } + lock_table.unlock(); - if (cur_element_count >= max_elements_) { - throw std::runtime_error("The number of elements exceeds the specified limit"); - }; + if (isMarkedDeleted(existingInternalId)) { + unmarkDeletedInternal(existingInternalId); + } + updatePoint(data_point, existingInternalId, 1.0); - cur_c = cur_element_count; - cur_element_count++; - label_lookup_[label] = cur_c; + return existingInternalId; } - // Take update lock to prevent race conditions on an element with insertion/update at the same time. - std::unique_lock lock_el_update(link_list_update_locks_[(cur_c & (max_update_element_locks - 1))]); - std::unique_lock lock_el(link_list_locks_[cur_c]); - int curlevel = getRandomLevel(mult_); - if (level > 0) - curlevel = level; + if (cur_element_count >= max_elements_) { + throw std::runtime_error("The number of elements exceeds the specified limit"); + } - element_levels_[cur_c] = curlevel; + cur_c = cur_element_count; + cur_element_count++; + label_lookup_[label] = cur_c; + } + std::unique_lock lock_el(link_list_locks_[cur_c]); + int curlevel = getRandomLevel(mult_); + if (level > 0) + curlevel = level; - std::unique_lock templock(global); - int maxlevelcopy = maxlevel_; - if (curlevel <= maxlevelcopy) - templock.unlock(); - tableint currObj = enterpoint_node_; - tableint enterpoint_copy = enterpoint_node_; + element_levels_[cur_c] = curlevel; + std::unique_lock templock(global); + int maxlevelcopy = maxlevel_; + if (curlevel <= maxlevelcopy) + templock.unlock(); + tableint currObj = enterpoint_node_; + tableint enterpoint_copy = enterpoint_node_; - memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); + memset(data_level0_memory_ + cur_c * size_data_per_element_ + offsetLevel0_, 0, size_data_per_element_); - // Initialisation of the data and label - memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); - memcpy(getDataByInternalId(cur_c), data_point, data_size_); + // Initialisation of the data and label + memcpy(getExternalLabeLp(cur_c), &label, sizeof(labeltype)); + memcpy(getDataByInternalId(cur_c), data_point, data_size_); + if (curlevel) { + linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); + if (linkLists_[cur_c] == nullptr) + throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); + memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); + } - if (curlevel) { - linkLists_[cur_c] = (char *) malloc(size_links_per_element_ * curlevel + 1); - if (linkLists_[cur_c] == nullptr) - throw std::runtime_error("Not enough memory: addPoint failed to allocate linklist"); - memset(linkLists_[cur_c], 0, size_links_per_element_ * curlevel + 1); - } + if ((signed)currObj != -1) { + if (curlevel < maxlevelcopy) { + dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); + for (int level = maxlevelcopy; level > curlevel; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; + std::unique_lock lock(link_list_locks_[currObj]); + data = get_linklist(currObj, level); + int size = getListCount(data); - if ((signed)currObj != -1) { - - if (curlevel < maxlevelcopy) { - - dist_t curdist = fstdistfunc_(data_point, getDataByInternalId(currObj), dist_func_param_); - for (int level = maxlevelcopy; level > curlevel; level--) { - - - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; - std::unique_lock lock(link_list_locks_[currObj]); - data = get_linklist(currObj,level); - int size = getListCount(data); - - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(data_point, getDataByInternalId(cand), dist_func_param_); + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; } } } } + } - bool epDeleted = isMarkedDeleted(enterpoint_copy); - for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { - if (level > maxlevelcopy || level < 0) // possible? - throw std::runtime_error("Level error"); - - std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( - currObj, data_point, level); - if (epDeleted) { - top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); - if (top_candidates.size() > ef_construction_) - top_candidates.pop(); - } - currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); + bool epDeleted = isMarkedDeleted(enterpoint_copy); + for (int level = std::min(curlevel, maxlevelcopy); level >= 0; level--) { + if (level > maxlevelcopy || level < 0) // possible? + throw std::runtime_error("Level error"); + + std::priority_queue, std::vector>, CompareByFirst> top_candidates = searchBaseLayer( + currObj, data_point, level); + if (epDeleted) { + top_candidates.emplace(fstdistfunc_(data_point, getDataByInternalId(enterpoint_copy), dist_func_param_), enterpoint_copy); + if (top_candidates.size() > ef_construction_) + top_candidates.pop(); } - - - } else { - // Do nothing for the first element - enterpoint_node_ = 0; - maxlevel_ = curlevel; - + currObj = mutuallyConnectNewElement(data_point, cur_c, top_candidates, level, false); } + } else { + // Do nothing for the first element + enterpoint_node_ = 0; + maxlevel_ = curlevel; + } - //Releasing lock for the maximum level - if (curlevel > maxlevelcopy) { - enterpoint_node_ = cur_c; - maxlevel_ = curlevel; - } - return cur_c; - }; + // Releasing lock for the maximum level + if (curlevel > maxlevelcopy) { + enterpoint_node_ = cur_c; + maxlevel_ = curlevel; + } + return cur_c; + } - std::priority_queue> - searchKnn(const void *query_data, size_t k) const { - std::priority_queue> result; - if (cur_element_count == 0) return result; - tableint currObj = enterpoint_node_; - dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); + std::priority_queue> + searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + std::priority_queue> result; + if (cur_element_count == 0) return result; - for (int level = maxlevel_; level > 0; level--) { - bool changed = true; - while (changed) { - changed = false; - unsigned int *data; + tableint currObj = enterpoint_node_; + dist_t curdist = fstdistfunc_(query_data, getDataByInternalId(enterpoint_node_), dist_func_param_); - data = (unsigned int *) get_linklist(currObj, level); - int size = getListCount(data); - metric_hops++; - metric_distance_computations+=size; + for (int level = maxlevel_; level > 0; level--) { + bool changed = true; + while (changed) { + changed = false; + unsigned int *data; - tableint *datal = (tableint *) (data + 1); - for (int i = 0; i < size; i++) { - tableint cand = datal[i]; - if (cand < 0 || cand > max_elements_) - throw std::runtime_error("cand error"); - dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + data = (unsigned int *) get_linklist(currObj, level); + int size = getListCount(data); + metric_hops++; + metric_distance_computations+=size; - if (d < curdist) { - curdist = d; - currObj = cand; - changed = true; - } + tableint *datal = (tableint *) (data + 1); + for (int i = 0; i < size; i++) { + tableint cand = datal[i]; + if (cand < 0 || cand > max_elements_) + throw std::runtime_error("cand error"); + dist_t d = fstdistfunc_(query_data, getDataByInternalId(cand), dist_func_param_); + + if (d < curdist) { + curdist = d; + currObj = cand; + changed = true; } } } + } - std::priority_queue, std::vector>, CompareByFirst> top_candidates; - if (num_deleted_) { - top_candidates=searchBaseLayerST( - currObj, query_data, std::max(ef_, k)); - } - else{ - top_candidates=searchBaseLayerST( - currObj, query_data, std::max(ef_, k)); - } + std::priority_queue, std::vector>, CompareByFirst> top_candidates; + if (num_deleted_) { + top_candidates = searchBaseLayerST( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } else { + top_candidates = searchBaseLayerST( + currObj, query_data, std::max(ef_, k), isIdAllowed); + } - while (top_candidates.size() > k) { - top_candidates.pop(); - } - while (top_candidates.size() > 0) { - std::pair rez = top_candidates.top(); - result.push(std::pair(rez.first, getExternalLabel(rez.second))); - top_candidates.pop(); - } - return result; - }; - - void checkIntegrity(){ - int connections_checked=0; - std::vector inbound_connections_num(cur_element_count,0); - for(int i = 0;i < cur_element_count; i++){ - for(int l = 0;l <= element_levels_[i]; l++){ - linklistsizeint *ll_cur = get_linklist_at_level(i,l); - int size = getListCount(ll_cur); - tableint *data = (tableint *) (ll_cur + 1); - std::unordered_set s; - for (int j=0; j 0); - assert(data[j] < cur_element_count); - assert (data[j] != i); - inbound_connections_num[data[j]]++; - s.insert(data[j]); - connections_checked++; + while (top_candidates.size() > k) { + top_candidates.pop(); + } + while (top_candidates.size() > 0) { + std::pair rez = top_candidates.top(); + result.push(std::pair(rez.first, getExternalLabel(rez.second))); + top_candidates.pop(); + } + return result; + } - } - assert(s.size() == size); + + void checkIntegrity() { + int connections_checked = 0; + std::vector inbound_connections_num(cur_element_count, 0); + for (int i = 0; i < cur_element_count; i++) { + for (int l = 0; l <= element_levels_[i]; l++) { + linklistsizeint *ll_cur = get_linklist_at_level(i, l); + int size = getListCount(ll_cur); + tableint *data = (tableint *) (ll_cur + 1); + std::unordered_set s; + for (int j = 0; j < size; j++) { + assert(data[j] > 0); + assert(data[j] < cur_element_count); + assert(data[j] != i); + inbound_connections_num[data[j]]++; + s.insert(data[j]); + connections_checked++; } + assert(s.size() == size); } - if(cur_element_count > 1){ - int min1=inbound_connections_num[0], max1=inbound_connections_num[0]; - for(int i=0; i < cur_element_count; i++){ - assert(inbound_connections_num[i] > 0); - min1=std::min(inbound_connections_num[i],min1); - max1=std::max(inbound_connections_num[i],max1); - } - std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; + } + if (cur_element_count > 1) { + int min1 = inbound_connections_num[0], max1 = inbound_connections_num[0]; + for (int i=0; i < cur_element_count; i++) { + assert(inbound_connections_num[i] > 0); + min1 = std::min(inbound_connections_num[i], min1); + max1 = std::max(inbound_connections_num[i], max1); } - std::cout << "integrity ok, checked " << connections_checked << " connections\n"; - + std::cout << "Min inbound: " << min1 << ", Max inbound:" << max1 << "\n"; } - - }; - -} + std::cout << "integrity ok, checked " << connections_checked << " connections\n"; + } +}; +} // namespace hnswlib diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 58eb7607..fb7118fa 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -1,6 +1,6 @@ #pragma once #ifndef NO_MANUAL_VECTORIZATION -#ifdef __SSE__ +#if (defined(__SSE__) || _M_IX86_FP > 0 || defined(_M_AMD64) || defined(_M_X64)) #define USE_SSE #ifdef __AVX__ #define USE_AVX @@ -15,21 +15,20 @@ #ifdef _MSC_VER #include #include -#include "cpu_x86.h" -void cpu_x86::cpuid(int32_t out[4], int32_t eax, int32_t ecx) { +void cpuid(int32_t out[4], int32_t eax, int32_t ecx) { __cpuidex(out, eax, ecx); } -__int64 xgetbv(unsigned int x) { +static __int64 xgetbv(unsigned int x) { return _xgetbv(x); } #else #include #include #include -void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) { +static void cpuid(int32_t cpuInfo[4], int32_t eax, int32_t ecx) { __cpuid_count(eax, ecx, cpuInfo[0], cpuInfo[1], cpuInfo[2], cpuInfo[3]); } -uint64_t xgetbv(unsigned int index) { +static uint64_t xgetbv(unsigned int index) { uint32_t eax, edx; __asm__ __volatile__("xgetbv" : "=a"(eax), "=d"(edx) : "c"(index)); return ((uint64_t)edx << 32) | eax; @@ -51,7 +50,7 @@ uint64_t xgetbv(unsigned int index) { // Adapted from https://github.com/Mysticial/FeatureDetector #define _XCR_XFEATURE_ENABLED_MASK 0 -bool AVXCapable() { +static bool AVXCapable() { int cpuInfo[4]; // CPU support @@ -78,7 +77,7 @@ bool AVXCapable() { return HW_AVX && avxSupported; } -bool AVX512Capable() { +static bool AVX512Capable() { if (!AVXCapable()) return false; int cpuInfo[4]; @@ -88,7 +87,7 @@ bool AVX512Capable() { int nIds = cpuInfo[0]; bool HW_AVX512F = false; - if (nIds >= 0x00000007) { // AVX512 Foundation + if (nIds >= 0x00000007) { // AVX512 Foundation cpuid(cpuInfo, 0x00000007, 0); HW_AVX512F = (cpuInfo[1] & ((int)1 << 16)) != 0; } @@ -114,77 +113,85 @@ bool AVX512Capable() { #include namespace hnswlib { - typedef size_t labeltype; - - template - class pairGreater { - public: - bool operator()(const T& p1, const T& p2) { - return p1.first > p2.first; - } - }; - - template - static void writeBinaryPOD(std::ostream &out, const T &podRef) { - out.write((char *) &podRef, sizeof(T)); +typedef size_t labeltype; + +// This can be extended to store state for filtering (e.g. from a std::set) +class BaseFilterFunctor { + public: + virtual bool operator()(hnswlib::labeltype id) { return true; } +}; + +template +class pairGreater { + public: + bool operator()(const T& p1, const T& p2) { + return p1.first > p2.first; } +}; - template - static void readBinaryPOD(std::istream &in, T &podRef) { - in.read((char *) &podRef, sizeof(T)); - } +template +static void writeBinaryPOD(std::ostream &out, const T &podRef) { + out.write((char *) &podRef, sizeof(T)); +} - template - using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); +template +static void readBinaryPOD(std::istream &in, T &podRef) { + in.read((char *) &podRef, sizeof(T)); +} +template +using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); - template - class SpaceInterface { - public: - //virtual void search(void *); - virtual size_t get_data_size() = 0; +template +class SpaceInterface { + public: + // virtual void search(void *); + virtual size_t get_data_size() = 0; - virtual DISTFUNC get_dist_func() = 0; + virtual DISTFUNC get_dist_func() = 0; - virtual void *get_dist_func_param() = 0; + virtual void *get_dist_func_param() = 0; - virtual ~SpaceInterface() {} - }; + virtual ~SpaceInterface() {} +}; - template - class AlgorithmInterface { - public: - virtual void addPoint(const void *datapoint, labeltype label)=0; - virtual std::priority_queue> searchKnn(const void *, size_t) const = 0; +template +class AlgorithmInterface { + public: + virtual void addPoint(const void *datapoint, labeltype label, bool replace_deleted = false) = 0; - // Return k nearest neighbor in the order of closer fist - virtual std::vector> - searchKnnCloserFirst(const void* query_data, size_t k) const; + virtual std::priority_queue> + searchKnn(const void*, size_t, BaseFilterFunctor* isIdAllowed = nullptr) const = 0; - virtual void saveIndex(const std::string &location)=0; - virtual ~AlgorithmInterface(){ - } - }; - - template - std::vector> - AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k) const { - std::vector> result; - - // here searchKnn returns the result in the order of further first - auto ret = searchKnn(query_data, k); - { - size_t sz = ret.size(); - result.resize(sz); - while (!ret.empty()) { - result[--sz] = ret.top(); - ret.pop(); - } - } + // Return k nearest neighbor in the order of closer fist + virtual std::vector> + searchKnnCloserFirst(const void* query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const; - return result; + virtual void saveIndex(const std::string &location) = 0; + virtual ~AlgorithmInterface(){ } +}; + +template +std::vector> +AlgorithmInterface::searchKnnCloserFirst(const void* query_data, size_t k, + BaseFilterFunctor* isIdAllowed) const { + std::vector> result; + + // here searchKnn returns the result in the order of further first + auto ret = searchKnn(query_data, k, isIdAllowed); + { + size_t sz = ret.size(); + result.resize(sz); + while (!ret.empty()) { + result[--sz] = ret.top(); + ret.pop(); + } + } + + return result; } +} // namespace hnswlib #include "space_l2.h" #include "space_ip.h" diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index b4266f78..2b1c359e 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -3,374 +3,373 @@ namespace hnswlib { - static float - InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - float res = 0; - for (unsigned i = 0; i < qty; i++) { - res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; - } - return res; - +static float +InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + float res = 0; + for (unsigned i = 0; i < qty; i++) { + res += ((float *) pVect1)[i] * ((float *) pVect2)[i]; } + return res; +} - static float - InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) { - return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); - } +static float +InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr) { + return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); +} #if defined(USE_AVX) // Favor using AVX if available. - static float - InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - size_t qty4 = qty / 4; - - const float *pEnd1 = pVect1 + 16 * qty16; - const float *pEnd2 = pVect1 + 4 * qty4; - - __m256 sum256 = _mm256_set1_ps(0); - - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - - __m256 v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - __m256 v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - } +static float +InnerProductSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m256 sum256 = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } - __m128 v1, v2; - __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); + __m128 v1, v2; + __m128 sum_prod = _mm_add_ps(_mm256_extractf128_ps(sum256, 0), _mm256_extractf128_ps(sum256, 1)); - while (pVect1 < pEnd2) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } - - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];; - return sum; - } - - static float - InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr); + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); } + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + return sum; +} + +static float +InnerProductDistanceSIMD4ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD4ExtAVX(pVect1v, pVect2v, qty_ptr); +} + #endif #if defined(USE_SSE) - static float - InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - size_t qty4 = qty / 4; - - const float *pEnd1 = pVect1 + 16 * qty16; - const float *pEnd2 = pVect1 + 4 * qty4; - - __m128 v1, v2; - __m128 sum_prod = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } +static float +InnerProductSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } - while (pVect1 < pEnd2) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } + while (pVect1 < pEnd2) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + } - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - return sum; - } + return sum; +} - static float - InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr); - } +static float +InnerProductDistanceSIMD4ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD4ExtSSE(pVect1v, pVect2v, qty_ptr); +} #endif #if defined(USE_AVX512) - static float - InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN64 TmpRes[16]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); +static float +InnerProductSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN64 TmpRes[16]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty / 16; + size_t qty16 = qty / 16; - const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd1 = pVect1 + 16 * qty16; - __m512 sum512 = _mm512_set1_ps(0); + __m512 sum512 = _mm512_set1_ps(0); - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - __m512 v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; - __m512 v2 = _mm512_loadu_ps(pVect2); - pVect2 += 16; - sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2)); - } + __m512 v1 = _mm512_loadu_ps(pVect1); + pVect1 += 16; + __m512 v2 = _mm512_loadu_ps(pVect2); + pVect2 += 16; + sum512 = _mm512_add_ps(sum512, _mm512_mul_ps(v1, v2)); + } - _mm512_store_ps(TmpRes, sum512); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15]; + _mm512_store_ps(TmpRes, sum512); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + TmpRes[13] + TmpRes[14] + TmpRes[15]; - return sum; - } + return sum; +} - static float - InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr); - } +static float +InnerProductDistanceSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtAVX512(pVect1v, pVect2v, qty_ptr); +} #endif #if defined(USE_AVX) - static float - InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); +static float +InnerProductSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty / 16; + size_t qty16 = qty / 16; - const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd1 = pVect1 + 16 * qty16; - __m256 sum256 = _mm256_set1_ps(0); + __m256 sum256 = _mm256_set1_ps(0); - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - __m256 v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - __m256 v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + __m256 v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + __m256 v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); - } + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + sum256 = _mm256_add_ps(sum256, _mm256_mul_ps(v1, v2)); + } - _mm256_store_ps(TmpRes, sum256); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; + _mm256_store_ps(TmpRes, sum256); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; - return sum; - } + return sum; +} - static float - InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr); - } +static float +InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtAVX(pVect1v, pVect2v, qty_ptr); +} #endif #if defined(USE_SSE) - static float - InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - size_t qty16 = qty / 16; - - const float *pEnd1 = pVect1 + 16 * qty16; - - __m128 v1, v2; - __m128 sum_prod = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); - } - _mm_store_ps(TmpRes, sum_prod); - float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - - return sum; +static float +InnerProductSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + size_t qty16 = qty / 16; + + const float *pEnd1 = pVect1 + 16 * qty16; + + __m128 v1, v2; + __m128 sum_prod = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + sum_prod = _mm_add_ps(sum_prod, _mm_mul_ps(v1, v2)); } + _mm_store_ps(TmpRes, sum_prod); + float sum = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; - static float - InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr); - } + return sum; +} + +static float +InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + return 1.0f - InnerProductSIMD16ExtSSE(pVect1v, pVect2v, qty_ptr); +} #endif #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) - DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; - DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; - DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; - DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; - - static float - InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty >> 4 << 4; - float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); - float *pVect1 = (float *) pVect1v + qty16; - float *pVect2 = (float *) pVect2v + qty16; - - size_t qty_left = qty - qty16; - float res_tail = InnerProduct(pVect1, pVect2, &qty_left); - return 1.0f - (res + res_tail); - } +static DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; +static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; +static DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; +static DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; + +static float +InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = InnerProductSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + return 1.0f - (res + res_tail); +} - static float - InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty4 = qty >> 2 << 2; +static float +InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; - float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); - size_t qty_left = qty - qty4; + float res = InnerProductSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; - float *pVect1 = (float *) pVect1v + qty4; - float *pVect2 = (float *) pVect2v + qty4; - float res_tail = InnerProduct(pVect1, pVect2, &qty_left); + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = InnerProduct(pVect1, pVect2, &qty_left); - return 1.0f - (res + res_tail); - } + return 1.0f - (res + res_tail); +} #endif - class InnerProductSpace : public SpaceInterface { - - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - public: - InnerProductSpace(size_t dim) { - fstdistfunc_ = InnerProductDistance; - #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) - #if defined(USE_AVX512) - if (AVX512Capable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; - } else if (AVXCapable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; - } - #elif defined(USE_AVX) - if (AVXCapable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; - } - #endif - #if defined(USE_AVX) - if (AVXCapable()) { - InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; - InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; - } - #endif - - if (dim % 16 == 0) - fstdistfunc_ = InnerProductDistanceSIMD16Ext; - else if (dim % 4 == 0) - fstdistfunc_ = InnerProductDistanceSIMD4Ext; - else if (dim > 16) - fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; - else if (dim > 4) - fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +class InnerProductSpace : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + InnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProductDistance; +#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; + } else if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #elif defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } #endif - dim_ = dim; - data_size_ = dim * sizeof(float); + #if defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; + InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; } + #endif - size_t get_data_size() { - return data_size_; - } + if (dim % 16 == 0) + fstdistfunc_ = InnerProductDistanceSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductDistanceSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } - DISTFUNC get_dist_func() { - return fstdistfunc_; - } + size_t get_data_size() { + return data_size_; + } - void *get_dist_func_param() { - return &dim_; - } + DISTFUNC get_dist_func() { + return fstdistfunc_; + } - ~InnerProductSpace() {} - }; + void *get_dist_func_param() { + return &dim_; + } -} +~InnerProductSpace() {} +}; + +} // namespace hnswlib diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h index 44135370..834d19f7 100644 --- a/hnswlib/space_l2.h +++ b/hnswlib/space_l2.h @@ -3,328 +3,322 @@ namespace hnswlib { - static float - L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - - float res = 0; - for (size_t i = 0; i < qty; i++) { - float t = *pVect1 - *pVect2; - pVect1++; - pVect2++; - res += t * t; - } - return (res); +static float +L2Sqr(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + + float res = 0; + for (size_t i = 0; i < qty; i++) { + float t = *pVect1 - *pVect2; + pVect1++; + pVect2++; + res += t * t; } + return (res); +} #if defined(USE_AVX512) - // Favor using AVX512 if available. - static float - L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - float PORTABLE_ALIGN64 TmpRes[16]; - size_t qty16 = qty >> 4; - - const float *pEnd1 = pVect1 + (qty16 << 4); - - __m512 diff, v1, v2; - __m512 sum = _mm512_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm512_loadu_ps(pVect1); - pVect1 += 16; - v2 = _mm512_loadu_ps(pVect2); - pVect2 += 16; - diff = _mm512_sub_ps(v1, v2); - // sum = _mm512_fmadd_ps(diff, diff, sum); - sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); - } +// Favor using AVX512 if available. +static float +L2SqrSIMD16ExtAVX512(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN64 TmpRes[16]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m512 diff, v1, v2; + __m512 sum = _mm512_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm512_loadu_ps(pVect1); + pVect1 += 16; + v2 = _mm512_loadu_ps(pVect2); + pVect2 += 16; + diff = _mm512_sub_ps(v1, v2); + // sum = _mm512_fmadd_ps(diff, diff, sum); + sum = _mm512_add_ps(sum, _mm512_mul_ps(diff, diff)); + } - _mm512_store_ps(TmpRes, sum); - float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + - TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + - TmpRes[13] + TmpRes[14] + TmpRes[15]; + _mm512_store_ps(TmpRes, sum); + float res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + + TmpRes[7] + TmpRes[8] + TmpRes[9] + TmpRes[10] + TmpRes[11] + TmpRes[12] + + TmpRes[13] + TmpRes[14] + TmpRes[15]; - return (res); + return (res); } #endif #if defined(USE_AVX) - // Favor using AVX if available. - static float - L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - float PORTABLE_ALIGN32 TmpRes[8]; - size_t qty16 = qty >> 4; - - const float *pEnd1 = pVect1 + (qty16 << 4); - - __m256 diff, v1, v2; - __m256 sum = _mm256_set1_ps(0); - - while (pVect1 < pEnd1) { - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - diff = _mm256_sub_ps(v1, v2); - sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); - - v1 = _mm256_loadu_ps(pVect1); - pVect1 += 8; - v2 = _mm256_loadu_ps(pVect2); - pVect2 += 8; - diff = _mm256_sub_ps(v1, v2); - sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); - } - - _mm256_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; +// Favor using AVX if available. +static float +L2SqrSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m256 diff, v1, v2; + __m256 sum = _mm256_set1_ps(0); + + while (pVect1 < pEnd1) { + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); + + v1 = _mm256_loadu_ps(pVect1); + pVect1 += 8; + v2 = _mm256_loadu_ps(pVect2); + pVect2 += 8; + diff = _mm256_sub_ps(v1, v2); + sum = _mm256_add_ps(sum, _mm256_mul_ps(diff, diff)); } + _mm256_store_ps(TmpRes, sum); + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3] + TmpRes[4] + TmpRes[5] + TmpRes[6] + TmpRes[7]; +} + #endif #if defined(USE_SSE) - static float - L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); - float PORTABLE_ALIGN32 TmpRes[8]; - size_t qty16 = qty >> 4; - - const float *pEnd1 = pVect1 + (qty16 << 4); - - __m128 diff, v1, v2; - __m128 sum = _mm_set1_ps(0); - - while (pVect1 < pEnd1) { - //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - } - - _mm_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; +static float +L2SqrSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); + float PORTABLE_ALIGN32 TmpRes[8]; + size_t qty16 = qty >> 4; + + const float *pEnd1 = pVect1 + (qty16 << 4); + + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); + + while (pVect1 < pEnd1) { + //_mm_prefetch((char*)(pVect2 + 16), _MM_HINT_T0); + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); + + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); } + + _mm_store_ps(TmpRes, sum); + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; +} #endif #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) - DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; - - static float - L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty16 = qty >> 4 << 4; - float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); - float *pVect1 = (float *) pVect1v + qty16; - float *pVect2 = (float *) pVect2v + qty16; - - size_t qty_left = qty - qty16; - float res_tail = L2Sqr(pVect1, pVect2, &qty_left); - return (res + res_tail); - } +static DISTFUNC L2SqrSIMD16Ext = L2SqrSIMD16ExtSSE; + +static float +L2SqrSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty16 = qty >> 4 << 4; + float res = L2SqrSIMD16Ext(pVect1v, pVect2v, &qty16); + float *pVect1 = (float *) pVect1v + qty16; + float *pVect2 = (float *) pVect2v + qty16; + + size_t qty_left = qty - qty16; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + return (res + res_tail); +} #endif #if defined(USE_SSE) - static float - L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - float PORTABLE_ALIGN32 TmpRes[8]; - float *pVect1 = (float *) pVect1v; - float *pVect2 = (float *) pVect2v; - size_t qty = *((size_t *) qty_ptr); +static float +L2SqrSIMD4Ext(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + float PORTABLE_ALIGN32 TmpRes[8]; + float *pVect1 = (float *) pVect1v; + float *pVect2 = (float *) pVect2v; + size_t qty = *((size_t *) qty_ptr); - size_t qty4 = qty >> 2; + size_t qty4 = qty >> 2; - const float *pEnd1 = pVect1 + (qty4 << 2); + const float *pEnd1 = pVect1 + (qty4 << 2); - __m128 diff, v1, v2; - __m128 sum = _mm_set1_ps(0); + __m128 diff, v1, v2; + __m128 sum = _mm_set1_ps(0); - while (pVect1 < pEnd1) { - v1 = _mm_loadu_ps(pVect1); - pVect1 += 4; - v2 = _mm_loadu_ps(pVect2); - pVect2 += 4; - diff = _mm_sub_ps(v1, v2); - sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); - } - _mm_store_ps(TmpRes, sum); - return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; + while (pVect1 < pEnd1) { + v1 = _mm_loadu_ps(pVect1); + pVect1 += 4; + v2 = _mm_loadu_ps(pVect2); + pVect2 += 4; + diff = _mm_sub_ps(v1, v2); + sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff)); } + _mm_store_ps(TmpRes, sum); + return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; +} - static float - L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { - size_t qty = *((size_t *) qty_ptr); - size_t qty4 = qty >> 2 << 2; +static float +L2SqrSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + size_t qty4 = qty >> 2 << 2; - float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); - size_t qty_left = qty - qty4; + float res = L2SqrSIMD4Ext(pVect1v, pVect2v, &qty4); + size_t qty_left = qty - qty4; - float *pVect1 = (float *) pVect1v + qty4; - float *pVect2 = (float *) pVect2v + qty4; - float res_tail = L2Sqr(pVect1, pVect2, &qty_left); + float *pVect1 = (float *) pVect1v + qty4; + float *pVect2 = (float *) pVect2v + qty4; + float res_tail = L2Sqr(pVect1, pVect2, &qty_left); - return (res + res_tail); - } + return (res + res_tail); +} #endif - class L2Space : public SpaceInterface { - - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - public: - L2Space(size_t dim) { - fstdistfunc_ = L2Sqr; - #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) - #if defined(USE_AVX512) - if (AVX512Capable()) - L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; - else if (AVXCapable()) - L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; - #elif defined(USE_AVX) - if (AVXCapable()) - L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; - #endif - - if (dim % 16 == 0) - fstdistfunc_ = L2SqrSIMD16Ext; - else if (dim % 4 == 0) - fstdistfunc_ = L2SqrSIMD4Ext; - else if (dim > 16) - fstdistfunc_ = L2SqrSIMD16ExtResiduals; - else if (dim > 4) - fstdistfunc_ = L2SqrSIMD4ExtResiduals; +class L2Space : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + L2Space(size_t dim) { + fstdistfunc_ = L2Sqr; +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX512; + else if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; + #elif defined(USE_AVX) + if (AVXCapable()) + L2SqrSIMD16Ext = L2SqrSIMD16ExtAVX; #endif - dim_ = dim; - data_size_ = dim * sizeof(float); - } - size_t get_data_size() { - return data_size_; - } + if (dim % 16 == 0) + fstdistfunc_ = L2SqrSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = L2SqrSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = L2SqrSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = L2SqrSIMD4ExtResiduals; +#endif + dim_ = dim; + data_size_ = dim * sizeof(float); + } - DISTFUNC get_dist_func() { - return fstdistfunc_; - } + size_t get_data_size() { + return data_size_; + } - void *get_dist_func_param() { - return &dim_; - } + DISTFUNC get_dist_func() { + return fstdistfunc_; + } - ~L2Space() {} - }; - - static int - L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { - - size_t qty = *((size_t *) qty_ptr); - int res = 0; - unsigned char *a = (unsigned char *) pVect1; - unsigned char *b = (unsigned char *) pVect2; - - qty = qty >> 2; - for (size_t i = 0; i < qty; i++) { - - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - } - return (res); + void *get_dist_func_param() { + return &dim_; } - static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { - size_t qty = *((size_t*)qty_ptr); - int res = 0; - unsigned char* a = (unsigned char*)pVect1; - unsigned char* b = (unsigned char*)pVect2; - - for(size_t i = 0; i < qty; i++) - { - res += ((*a) - (*b)) * ((*a) - (*b)); - a++; - b++; - } - return (res); + ~L2Space() {} +}; + +static int +L2SqrI4x(const void *__restrict pVect1, const void *__restrict pVect2, const void *__restrict qty_ptr) { + size_t qty = *((size_t *) qty_ptr); + int res = 0; + unsigned char *a = (unsigned char *) pVect1; + unsigned char *b = (unsigned char *) pVect2; + + qty = qty >> 2; + for (size_t i = 0; i < qty; i++) { + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; } + return (res); +} - class L2SpaceI : public SpaceInterface { - - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - public: - L2SpaceI(size_t dim) { - if(dim % 4 == 0) { - fstdistfunc_ = L2SqrI4x; - } - else { - fstdistfunc_ = L2SqrI; - } - dim_ = dim; - data_size_ = dim * sizeof(unsigned char); - } +static int L2SqrI(const void* __restrict pVect1, const void* __restrict pVect2, const void* __restrict qty_ptr) { + size_t qty = *((size_t*)qty_ptr); + int res = 0; + unsigned char* a = (unsigned char*)pVect1; + unsigned char* b = (unsigned char*)pVect2; - size_t get_data_size() { - return data_size_; - } + for (size_t i = 0; i < qty; i++) { + res += ((*a) - (*b)) * ((*a) - (*b)); + a++; + b++; + } + return (res); +} - DISTFUNC get_dist_func() { - return fstdistfunc_; +class L2SpaceI : public SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + + public: + L2SpaceI(size_t dim) { + if (dim % 4 == 0) { + fstdistfunc_ = L2SqrI4x; + } else { + fstdistfunc_ = L2SqrI; } + dim_ = dim; + data_size_ = dim * sizeof(unsigned char); + } - void *get_dist_func_param() { - return &dim_; - } + size_t get_data_size() { + return data_size_; + } - ~L2SpaceI() {} - }; + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + void *get_dist_func_param() { + return &dim_; + } -} + ~L2SpaceI() {} +}; +} // namespace hnswlib diff --git a/hnswlib/visited_list_pool.h b/hnswlib/visited_list_pool.h index 5e1a4a58..2e201ec4 100644 --- a/hnswlib/visited_list_pool.h +++ b/hnswlib/visited_list_pool.h @@ -5,75 +5,74 @@ #include namespace hnswlib { - typedef unsigned short int vl_type; +typedef unsigned short int vl_type; - class VisitedList { - public: - vl_type curV; - vl_type *mass; - unsigned int numelements; +class VisitedList { + public: + vl_type curV; + vl_type *mass; + unsigned int numelements; - VisitedList(int numelements1) { - curV = -1; - numelements = numelements1; - mass = new vl_type[numelements]; - } + VisitedList(int numelements1) { + curV = -1; + numelements = numelements1; + mass = new vl_type[numelements]; + } - void reset() { + void reset() { + curV++; + if (curV == 0) { + memset(mass, 0, sizeof(vl_type) * numelements); curV++; - if (curV == 0) { - memset(mass, 0, sizeof(vl_type) * numelements); - curV++; - } - }; + } + } - ~VisitedList() { delete[] mass; } - }; + ~VisitedList() { delete[] mass; } +}; /////////////////////////////////////////////////////////// // // Class for multi-threaded pool-management of VisitedLists // ///////////////////////////////////////////////////////// - class VisitedListPool { - std::deque pool; - std::mutex poolguard; - int numelements; - - public: - VisitedListPool(int initmaxpools, int numelements1) { - numelements = numelements1; - for (int i = 0; i < initmaxpools; i++) - pool.push_front(new VisitedList(numelements)); - } +class VisitedListPool { + std::deque pool; + std::mutex poolguard; + int numelements; - VisitedList *getFreeVisitedList() { - VisitedList *rez; - { - std::unique_lock lock(poolguard); - if (pool.size() > 0) { - rez = pool.front(); - pool.pop_front(); - } else { - rez = new VisitedList(numelements); - } - } - rez->reset(); - return rez; - }; + public: + VisitedListPool(int initmaxpools, int numelements1) { + numelements = numelements1; + for (int i = 0; i < initmaxpools; i++) + pool.push_front(new VisitedList(numelements)); + } - void releaseVisitedList(VisitedList *vl) { + VisitedList *getFreeVisitedList() { + VisitedList *rez; + { std::unique_lock lock(poolguard); - pool.push_front(vl); - }; - - ~VisitedListPool() { - while (pool.size()) { - VisitedList *rez = pool.front(); + if (pool.size() > 0) { + rez = pool.front(); pool.pop_front(); - delete rez; + } else { + rez = new VisitedList(numelements); } - }; - }; -} + } + rez->reset(); + return rez; + } + void releaseVisitedList(VisitedList *vl) { + std::unique_lock lock(poolguard); + pool.push_front(vl); + } + + ~VisitedListPool() { + while (pool.size()) { + VisitedList *rez = pool.front(); + pool.pop_front(); + delete rez; + } + } +}; +} // namespace hnswlib diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index 12f38e2e..5153bb58 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -9,7 +10,7 @@ #include namespace py = pybind11; -using namespace pybind11::literals; // needed to bring in _a literal +using namespace pybind11::literals; // needed to bring in _a literal /* * replacement for the openmp '#pragma omp parallel for' directive @@ -42,7 +43,7 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn while (true) { size_t id = current.fetch_add(1); - if ((id >= end)) { + if (id >= end) { break; } @@ -74,187 +75,232 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn inline void assert_true(bool expr, const std::string & msg) { - if (expr == false) - throw std::runtime_error("Unpickle Error: "+msg); + if (expr == false) throw std::runtime_error("Unpickle Error: " + msg); return; } -template -class Index { -public: - Index(const std::string &space_name, const int dim) : - space_name(space_name), dim(dim) { - normalize=false; - if(space_name=="l2") { - l2space = new hnswlib::L2Space(dim); - } - else if(space_name=="ip") { - l2space = new hnswlib::InnerProductSpace(dim); - } - else if(space_name=="cosine") { - l2space = new hnswlib::InnerProductSpace(dim); - normalize=true; +class CustomFilterFunctor: public hnswlib::BaseFilterFunctor { + std::function filter; + + public: + explicit CustomFilterFunctor(const std::function& f) { + filter = f; + } + + bool operator()(hnswlib::labeltype id) { + return filter(id); + } +}; + + +inline void get_input_array_shapes(const py::buffer_info& buffer, size_t* rows, size_t* features) { + if (buffer.ndim != 2 && buffer.ndim != 1) { + char msg[256]; + snprintf(msg, sizeof(msg), + "Input vector data wrong shape. Number of dimensions %d. Data must be a 1D or 2D array.", + buffer.ndim); + throw std::runtime_error(msg); + } + if (buffer.ndim == 2) { + *rows = buffer.shape[0]; + *features = buffer.shape[1]; } else { - throw new std::runtime_error("Space name must be one of l2, ip, or cosine."); + *rows = 1; + *features = buffer.shape[0]; } - appr_alg = NULL; - ep_added = true; - index_inited = false; - num_threads_default = std::thread::hardware_concurrency(); +} - default_ef=10; - } - static const int ser_version = 1; // serialization version +inline std::vector get_input_ids_and_check_shapes(const py::object& ids_, size_t feature_rows) { + std::vector ids; + if (!ids_.is_none()) { + py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); + auto ids_numpy = items.request(); + // check shapes + if (!((ids_numpy.ndim == 1 && ids_numpy.shape[0] == feature_rows) || + (ids_numpy.ndim == 0 && feature_rows == 1))) { + char msg[256]; + snprintf(msg, sizeof(msg), + "The input label shape %d does not match the input data vector shape %d", + ids_numpy.ndim, feature_rows); + throw std::runtime_error(msg); + } + // extract data + if (ids_numpy.ndim == 1) { + std::vector ids1(ids_numpy.shape[0]); + for (size_t i = 0; i < ids1.size(); i++) { + ids1[i] = items.data()[i]; + } + ids.swap(ids1); + } else if (ids_numpy.ndim == 0) { + ids.push_back(*items.data()); + } + } - std::string space_name; - int dim; - size_t seed; - size_t default_ef; + return ids; +} - bool index_inited; - bool ep_added; - bool normalize; - int num_threads_default; - hnswlib::labeltype cur_l; - hnswlib::HierarchicalNSW *appr_alg; - hnswlib::SpaceInterface *l2space; - ~Index() { - delete l2space; - if (appr_alg) - delete appr_alg; - } +template +class Index { + public: + static const int ser_version = 1; // serialization version + + std::string space_name; + int dim; + size_t seed; + size_t default_ef; + + bool index_inited; + bool ep_added; + bool normalize; + int num_threads_default; + hnswlib::labeltype cur_l; + hnswlib::HierarchicalNSW* appr_alg; + hnswlib::SpaceInterface* l2space; + + + Index(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) { + normalize = false; + if (space_name == "l2") { + l2space = new hnswlib::L2Space(dim); + } else if (space_name == "ip") { + l2space = new hnswlib::InnerProductSpace(dim); + } else if (space_name == "cosine") { + l2space = new hnswlib::InnerProductSpace(dim); + normalize = true; + } else { + throw std::runtime_error("Space name must be one of l2, ip, or cosine."); + } + appr_alg = NULL; + ep_added = true; + index_inited = false; + num_threads_default = std::thread::hardware_concurrency(); + + default_ef = 10; + } + + + ~Index() { + delete l2space; + if (appr_alg) + delete appr_alg; + } + - void init_new_index(const size_t maxElements, const size_t M, const size_t efConstruction, const size_t random_seed) { + void init_new_index( + size_t maxElements, + size_t M, + size_t efConstruction, + size_t random_seed, + bool allow_replace_deleted) { if (appr_alg) { - throw new std::runtime_error("The index is already initiated."); + throw std::runtime_error("The index is already initiated."); } cur_l = 0; - appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed); + appr_alg = new hnswlib::HierarchicalNSW(l2space, maxElements, M, efConstruction, random_seed, allow_replace_deleted); index_inited = true; ep_added = false; appr_alg->ef_ = default_ef; - seed=random_seed; + seed = random_seed; } + void set_ef(size_t ef) { - default_ef=ef; + default_ef = ef; if (appr_alg) - appr_alg->ef_ = ef; + appr_alg->ef_ = ef; } + void set_num_threads(int num_threads) { this->num_threads_default = num_threads; } + void saveIndex(const std::string &path_to_index) { appr_alg->saveIndex(path_to_index); } - void loadIndex(const std::string &path_to_index, size_t max_elements) { + + void loadIndex(const std::string &path_to_index, size_t max_elements, bool allow_replace_deleted) { if (appr_alg) { - std::cerr<<"Warning: Calling load_index for an already inited index. Old index is being deallocated."; + std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." << std::endl; delete appr_alg; } - appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements); + appr_alg = new hnswlib::HierarchicalNSW(l2space, path_to_index, false, max_elements, allow_replace_deleted); cur_l = appr_alg->cur_element_count; index_inited = true; } - void normalize_vector(float *data, float *norm_array){ - float norm=0.0f; - for(int i=0;i items(input); auto buffer = items.request(); if (num_threads <= 0) num_threads = num_threads_default; size_t rows, features; - - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } - else{ - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); if (features != dim) - throw std::runtime_error("wrong dimensionality of the vectors"); - - // avoid using threads when the number of searches is small: - if(rows<=num_threads*4){ - num_threads=1; - } + throw std::runtime_error("Wrong dimensionality of the vectors"); - std::vector ids; - - if (!ids_.is_none()) { - py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); - auto ids_numpy = items.request(); - if(ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) { - std::vector ids1(ids_numpy.shape[0]); - for (size_t i = 0; i < ids1.size(); i++) { - ids1[i] = items.data()[i]; - } - ids.swap(ids1); - } - else if(ids_numpy.ndim == 0 && rows == 1) { - ids.push_back(*items.data()); - } - else - throw std::runtime_error("wrong dimensionality of the labels"); + // avoid using threads when the number of additions is small: + if (rows <= num_threads * 4) { + num_threads = 1; } + std::vector ids = get_input_ids_and_check_shapes(ids_, rows); { - - int start = 0; - if (!ep_added) { - size_t id = ids.size() ? ids.at(0) : (cur_l); - float *vector_data = (float *) items.data(0); - std::vector norm_array(dim); - if(normalize){ - normalize_vector(vector_data, norm_array.data()); - vector_data = norm_array.data(); + int start = 0; + if (!ep_added) { + size_t id = ids.size() ? ids.at(0) : (cur_l); + float* vector_data = (float*)items.data(0); + std::vector norm_array(dim); + if (normalize) { + normalize_vector(vector_data, norm_array.data()); + vector_data = norm_array.data(); + } + appr_alg->addPoint((void*)vector_data, (size_t)id, replace_deleted); + start = 1; + ep_added = true; } - appr_alg->addPoint((void *) vector_data, (size_t) id); - start = 1; - ep_added = true; - } py::gil_scoped_release l; - if(normalize==false) { + if (normalize == false) { ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { - size_t id = ids.size() ? ids.at(row) : (cur_l+row); - appr_alg->addPoint((void *) items.data(row), (size_t) id); - }); - } else{ + size_t id = ids.size() ? ids.at(row) : (cur_l + row); + appr_alg->addPoint((void*)items.data(row), (size_t)id, replace_deleted); + }); + } else { std::vector norm_array(num_threads * dim); ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { // normalize vector: size_t start_idx = threadId * dim; - normalize_vector((float *) items.data(row), (norm_array.data()+start_idx)); + normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); - size_t id = ids.size() ? ids.at(row) : (cur_l+row); - appr_alg->addPoint((void *) (norm_array.data()+start_idx), (size_t) id); - }); - }; - cur_l+=rows; + size_t id = ids.size() ? ids.at(row) : (cur_l + row); + appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id, replace_deleted); + }); + } + cur_l += rows; } } + std::vector> getDataReturnList(py::object ids_ = py::none()) { std::vector ids; if (!ids_.is_none()) { @@ -262,13 +308,13 @@ class Index { auto ids_numpy = items.request(); if (ids_numpy.ndim == 0) { - throw std::invalid_argument("get_items accepts a list of indices and returns a list of vectors"); + throw std::invalid_argument("get_items accepts a list of indices and returns a list of vectors"); } else { - std::vector ids1(ids_numpy.shape[0]); - for (size_t i = 0; i < ids1.size(); i++) { - ids1[i] = items.data()[i]; - } - ids.swap(ids1); + std::vector ids1(ids_numpy.shape[0]); + for (size_t i = 0; i < ids1.size(); i++) { + ids1[i] = items.data()[i]; + } + ids.swap(ids1); } } @@ -279,10 +325,11 @@ class Index { return data; } + std::vector getIdsList() { std::vector ids; - for(auto kv : appr_alg->label_lookup_) { + for (auto kv : appr_alg->label_lookup_) { ids.push_back(kv.first); } return ids; @@ -290,133 +337,132 @@ class Index { py::dict getAnnData() const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */ - std::unique_lock templock(appr_alg->global); + std::unique_lock templock(appr_alg->global); - size_t level0_npy_size = appr_alg->cur_element_count * appr_alg->size_data_per_element_; - size_t link_npy_size = 0; - std::vector link_npy_offsets(appr_alg->cur_element_count); + size_t level0_npy_size = appr_alg->cur_element_count * appr_alg->size_data_per_element_; + size_t link_npy_size = 0; + std::vector link_npy_offsets(appr_alg->cur_element_count); - for (size_t i = 0; i < appr_alg->cur_element_count; i++){ - size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; - link_npy_offsets[i]=link_npy_size; - if (linkListSize) - link_npy_size += linkListSize; - } + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + link_npy_offsets[i] = link_npy_size; + if (linkListSize) + link_npy_size += linkListSize; + } - char* data_level0_npy = (char *) malloc(level0_npy_size); - char* link_list_npy = (char *) malloc(link_npy_size); - int* element_levels_npy = (int *) malloc(appr_alg->element_levels_.size()*sizeof(int)); + char* data_level0_npy = (char*)malloc(level0_npy_size); + char* link_list_npy = (char*)malloc(link_npy_size); + int* element_levels_npy = (int*)malloc(appr_alg->element_levels_.size() * sizeof(int)); - hnswlib::labeltype* label_lookup_key_npy = (hnswlib::labeltype *) malloc(appr_alg->label_lookup_.size()*sizeof(hnswlib::labeltype)); - hnswlib::tableint* label_lookup_val_npy = (hnswlib::tableint *) malloc(appr_alg->label_lookup_.size()*sizeof(hnswlib::tableint)); + hnswlib::labeltype* label_lookup_key_npy = (hnswlib::labeltype*)malloc(appr_alg->label_lookup_.size() * sizeof(hnswlib::labeltype)); + hnswlib::tableint* label_lookup_val_npy = (hnswlib::tableint*)malloc(appr_alg->label_lookup_.size() * sizeof(hnswlib::tableint)); - memset(label_lookup_key_npy, -1, appr_alg->label_lookup_.size()*sizeof(hnswlib::labeltype)); - memset(label_lookup_val_npy, -1, appr_alg->label_lookup_.size()*sizeof(hnswlib::tableint)); + memset(label_lookup_key_npy, -1, appr_alg->label_lookup_.size() * sizeof(hnswlib::labeltype)); + memset(label_lookup_val_npy, -1, appr_alg->label_lookup_.size() * sizeof(hnswlib::tableint)); - size_t idx=0; - for ( auto it = appr_alg->label_lookup_.begin(); it != appr_alg->label_lookup_.end(); ++it ){ - label_lookup_key_npy[idx]= it->first; - label_lookup_val_npy[idx]= it->second; - idx++; - } + size_t idx = 0; + for (auto it = appr_alg->label_lookup_.begin(); it != appr_alg->label_lookup_.end(); ++it) { + label_lookup_key_npy[idx] = it->first; + label_lookup_val_npy[idx] = it->second; + idx++; + } - memset(link_list_npy, 0, link_npy_size); + memset(link_list_npy, 0, link_npy_size); - memcpy(data_level0_npy, appr_alg->data_level0_memory_, level0_npy_size); - memcpy(element_levels_npy, appr_alg->element_levels_.data(), appr_alg->element_levels_.size() * sizeof(int)); + memcpy(data_level0_npy, appr_alg->data_level0_memory_, level0_npy_size); + memcpy(element_levels_npy, appr_alg->element_levels_.data(), appr_alg->element_levels_.size() * sizeof(int)); - for (size_t i = 0; i < appr_alg->cur_element_count; i++){ - size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; - if (linkListSize){ - memcpy(link_list_npy+link_npy_offsets[i], appr_alg->linkLists_[i], linkListSize); + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + if (linkListSize) { + memcpy(link_list_npy + link_npy_offsets[i], appr_alg->linkLists_[i], linkListSize); + } } - } - py::capsule free_when_done_l0(data_level0_npy, [](void *f) { - delete[] f; - }); - py::capsule free_when_done_lvl(element_levels_npy, [](void *f) { - delete[] f; - }); - py::capsule free_when_done_lb(label_lookup_key_npy, [](void *f) { - delete[] f; - }); - py::capsule free_when_done_id(label_lookup_val_npy, [](void *f) { - delete[] f; - }); - py::capsule free_when_done_ll(link_list_npy, [](void *f) { - delete[] f; - }); - - /* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */ - /* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */ - - return py::dict( - "offset_level0"_a=appr_alg->offsetLevel0_, - "max_elements"_a=appr_alg->max_elements_, - "cur_element_count"_a=appr_alg->cur_element_count, - "size_data_per_element"_a=appr_alg->size_data_per_element_, - "label_offset"_a=appr_alg->label_offset_, - "offset_data"_a=appr_alg->offsetData_, - "max_level"_a=appr_alg->maxlevel_, - "enterpoint_node"_a=appr_alg->enterpoint_node_, - "max_M"_a=appr_alg->maxM_, - "max_M0"_a=appr_alg->maxM0_, - "M"_a=appr_alg->M_, - "mult"_a=appr_alg->mult_, - "ef_construction"_a=appr_alg->ef_construction_, - "ef"_a=appr_alg->ef_, - "has_deletions"_a=(bool)appr_alg->num_deleted_, - "size_links_per_element"_a=appr_alg->size_links_per_element_, - - "label_lookup_external"_a=py::array_t( - {appr_alg->label_lookup_.size()}, // shape - {sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double - label_lookup_key_npy, // the data pointer - free_when_done_lb), - - "label_lookup_internal"_a=py::array_t( - {appr_alg->label_lookup_.size()}, // shape - {sizeof(hnswlib::tableint)}, // C-style contiguous strides for double - label_lookup_val_npy, // the data pointer - free_when_done_id), - - "element_levels"_a=py::array_t( - {appr_alg->element_levels_.size()}, // shape - {sizeof(int)}, // C-style contiguous strides for double - element_levels_npy, // the data pointer - free_when_done_lvl), - - // linkLists_,element_levels_,data_level0_memory_ - "data_level0"_a=py::array_t( - {level0_npy_size}, // shape - {sizeof(char)}, // C-style contiguous strides for double - data_level0_npy, // the data pointer - free_when_done_l0), - - "link_lists"_a=py::array_t( - {link_npy_size}, // shape - {sizeof(char)}, // C-style contiguous strides for double - link_list_npy, // the data pointer - free_when_done_ll) - ); + py::capsule free_when_done_l0(data_level0_npy, [](void* f) { + delete[] f; + }); + py::capsule free_when_done_lvl(element_levels_npy, [](void* f) { + delete[] f; + }); + py::capsule free_when_done_lb(label_lookup_key_npy, [](void* f) { + delete[] f; + }); + py::capsule free_when_done_id(label_lookup_val_npy, [](void* f) { + delete[] f; + }); + py::capsule free_when_done_ll(link_list_npy, [](void* f) { + delete[] f; + }); + + /* TODO: serialize state of random generators appr_alg->level_generator_ and appr_alg->update_probability_generator_ */ + /* for full reproducibility / to avoid re-initializing generators inside Index::createFromParams */ + + return py::dict( + "offset_level0"_a = appr_alg->offsetLevel0_, + "max_elements"_a = appr_alg->max_elements_, + "cur_element_count"_a = (size_t)appr_alg->cur_element_count, + "size_data_per_element"_a = appr_alg->size_data_per_element_, + "label_offset"_a = appr_alg->label_offset_, + "offset_data"_a = appr_alg->offsetData_, + "max_level"_a = appr_alg->maxlevel_, + "enterpoint_node"_a = appr_alg->enterpoint_node_, + "max_M"_a = appr_alg->maxM_, + "max_M0"_a = appr_alg->maxM0_, + "M"_a = appr_alg->M_, + "mult"_a = appr_alg->mult_, + "ef_construction"_a = appr_alg->ef_construction_, + "ef"_a = appr_alg->ef_, + "has_deletions"_a = (bool)appr_alg->num_deleted_, + "size_links_per_element"_a = appr_alg->size_links_per_element_, + "allow_replace_deleted"_a = appr_alg->allow_replace_deleted_, + + "label_lookup_external"_a = py::array_t( + { appr_alg->label_lookup_.size() }, // shape + { sizeof(hnswlib::labeltype) }, // C-style contiguous strides for each index + label_lookup_key_npy, // the data pointer + free_when_done_lb), + + "label_lookup_internal"_a = py::array_t( + { appr_alg->label_lookup_.size() }, // shape + { sizeof(hnswlib::tableint) }, // C-style contiguous strides for each index + label_lookup_val_npy, // the data pointer + free_when_done_id), + + "element_levels"_a = py::array_t( + { appr_alg->element_levels_.size() }, // shape + { sizeof(int) }, // C-style contiguous strides for each index + element_levels_npy, // the data pointer + free_when_done_lvl), + + // linkLists_,element_levels_,data_level0_memory_ + "data_level0"_a = py::array_t( + { level0_npy_size }, // shape + { sizeof(char) }, // C-style contiguous strides for each index + data_level0_npy, // the data pointer + free_when_done_l0), + + "link_lists"_a = py::array_t( + { link_npy_size }, // shape + { sizeof(char) }, // C-style contiguous strides for each index + link_list_npy, // the data pointer + free_when_done_ll)); } py::dict getIndexParams() const { /* WARNING: Index::getAnnData is not thread-safe with Index::addItems */ auto params = py::dict( - "ser_version"_a=py::int_(Index::ser_version), //serialization version - "space"_a=space_name, - "dim"_a=dim, - "index_inited"_a=index_inited, - "ep_added"_a=ep_added, - "normalize"_a=normalize, - "num_threads"_a=num_threads_default, - "seed"_a=seed - ); - - if(index_inited == false) - return py::dict( **params, "ef"_a=default_ef); + "ser_version"_a = py::int_(Index::ser_version), // serialization version + "space"_a = space_name, + "dim"_a = dim, + "index_inited"_a = index_inited, + "ep_added"_a = ep_added, + "normalize"_a = normalize, + "num_threads"_a = num_threads_default, + "seed"_a = seed); + + if (index_inited == false) + return py::dict(**params, "ef"_a = default_ef); auto ann_params = getAnnData(); @@ -424,125 +470,142 @@ class Index { } - static Index * createFromParams(const py::dict d) { - // check serialization version - assert_true(((int)py::int_(Index::ser_version)) >= d["ser_version"].cast(), "Invalid serialization version!"); + static Index* createFromParams(const py::dict d) { + // check serialization version + assert_true(((int)py::int_(Index::ser_version)) >= d["ser_version"].cast(), "Invalid serialization version!"); - auto space_name_=d["space"].cast(); - auto dim_=d["dim"].cast(); - auto index_inited_=d["index_inited"].cast(); + auto space_name_ = d["space"].cast(); + auto dim_ = d["dim"].cast(); + auto index_inited_ = d["index_inited"].cast(); - Index *new_index = new Index(space_name_, dim_); + Index* new_index = new Index(space_name_, dim_); - /* TODO: deserialize state of random generators into new_index->level_generator_ and new_index->update_probability_generator_ */ - /* for full reproducibility / state of generators is serialized inside Index::getIndexParams */ - new_index->seed = d["seed"].cast(); + /* TODO: deserialize state of random generators into new_index->level_generator_ and new_index->update_probability_generator_ */ + /* for full reproducibility / state of generators is serialized inside Index::getIndexParams */ + new_index->seed = d["seed"].cast(); - if (index_inited_){ - new_index->appr_alg = new hnswlib::HierarchicalNSW(new_index->l2space, d["max_elements"].cast(), d["M"].cast(), d["ef_construction"].cast(), new_index->seed); - new_index->cur_l = d["cur_element_count"].cast(); - } + if (index_inited_) { + new_index->appr_alg = new hnswlib::HierarchicalNSW( + new_index->l2space, + d["max_elements"].cast(), + d["M"].cast(), + d["ef_construction"].cast(), + new_index->seed); + new_index->cur_l = d["cur_element_count"].cast(); + } - new_index->index_inited = index_inited_; - new_index->ep_added=d["ep_added"].cast(); - new_index->num_threads_default=d["num_threads"].cast(); - new_index->default_ef=d["ef"].cast(); + new_index->index_inited = index_inited_; + new_index->ep_added = d["ep_added"].cast(); + new_index->num_threads_default = d["num_threads"].cast(); + new_index->default_ef = d["ef"].cast(); - if (index_inited_) - new_index->setAnnData(d); + if (index_inited_) + new_index->setAnnData(d); - return new_index; + return new_index; } + static Index * createFromIndex(const Index & index) { return createFromParams(index.getIndexParams()); } + void setAnnData(const py::dict d) { /* WARNING: Index::setAnnData is not thread-safe with Index::addItems */ - std::unique_lock templock(appr_alg->global); + std::unique_lock templock(appr_alg->global); - assert_true(appr_alg->offsetLevel0_ == d["offset_level0"].cast(), "Invalid value of offsetLevel0_ "); - assert_true(appr_alg->max_elements_ == d["max_elements"].cast(), "Invalid value of max_elements_ "); + assert_true(appr_alg->offsetLevel0_ == d["offset_level0"].cast(), "Invalid value of offsetLevel0_ "); + assert_true(appr_alg->max_elements_ == d["max_elements"].cast(), "Invalid value of max_elements_ "); - appr_alg->cur_element_count = d["cur_element_count"].cast(); + appr_alg->cur_element_count = d["cur_element_count"].cast(); - assert_true(appr_alg->size_data_per_element_ == d["size_data_per_element"].cast(), "Invalid value of size_data_per_element_ "); - assert_true(appr_alg->label_offset_ == d["label_offset"].cast(), "Invalid value of label_offset_ "); - assert_true(appr_alg->offsetData_ == d["offset_data"].cast(), "Invalid value of offsetData_ "); + assert_true(appr_alg->size_data_per_element_ == d["size_data_per_element"].cast(), "Invalid value of size_data_per_element_ "); + assert_true(appr_alg->label_offset_ == d["label_offset"].cast(), "Invalid value of label_offset_ "); + assert_true(appr_alg->offsetData_ == d["offset_data"].cast(), "Invalid value of offsetData_ "); - appr_alg->maxlevel_ = d["max_level"].cast(); - appr_alg->enterpoint_node_ = d["enterpoint_node"].cast(); + appr_alg->maxlevel_ = d["max_level"].cast(); + appr_alg->enterpoint_node_ = d["enterpoint_node"].cast(); - assert_true(appr_alg->maxM_ == d["max_M"].cast(), "Invalid value of maxM_ "); - assert_true(appr_alg->maxM0_ == d["max_M0"].cast(), "Invalid value of maxM0_ "); - assert_true(appr_alg->M_ == d["M"].cast(), "Invalid value of M_ "); - assert_true(appr_alg->mult_ == d["mult"].cast(), "Invalid value of mult_ "); - assert_true(appr_alg->ef_construction_ == d["ef_construction"].cast(), "Invalid value of ef_construction_ "); + assert_true(appr_alg->maxM_ == d["max_M"].cast(), "Invalid value of maxM_ "); + assert_true(appr_alg->maxM0_ == d["max_M0"].cast(), "Invalid value of maxM0_ "); + assert_true(appr_alg->M_ == d["M"].cast(), "Invalid value of M_ "); + assert_true(appr_alg->mult_ == d["mult"].cast(), "Invalid value of mult_ "); + assert_true(appr_alg->ef_construction_ == d["ef_construction"].cast(), "Invalid value of ef_construction_ "); - appr_alg->ef_ = d["ef"].cast(); + appr_alg->ef_ = d["ef"].cast(); - assert_true(appr_alg->size_links_per_element_ == d["size_links_per_element"].cast(), "Invalid value of size_links_per_element_ "); + assert_true(appr_alg->size_links_per_element_ == d["size_links_per_element"].cast(), "Invalid value of size_links_per_element_ "); - auto label_lookup_key_npy = d["label_lookup_external"].cast >(); - auto label_lookup_val_npy = d["label_lookup_internal"].cast >(); - auto element_levels_npy = d["element_levels"].cast >(); - auto data_level0_npy = d["data_level0"].cast >(); - auto link_list_npy = d["link_lists"].cast >(); + auto label_lookup_key_npy = d["label_lookup_external"].cast >(); + auto label_lookup_val_npy = d["label_lookup_internal"].cast >(); + auto element_levels_npy = d["element_levels"].cast >(); + auto data_level0_npy = d["data_level0"].cast >(); + auto link_list_npy = d["link_lists"].cast >(); - for (size_t i = 0; i < appr_alg->cur_element_count; i++){ - if (label_lookup_val_npy.data()[i] < 0){ - throw std::runtime_error("internal id cannot be negative!"); - } - else{ - appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i])); + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + if (label_lookup_val_npy.data()[i] < 0) { + throw std::runtime_error("Internal id cannot be negative!"); + } else { + appr_alg->label_lookup_.insert(std::make_pair(label_lookup_key_npy.data()[i], label_lookup_val_npy.data()[i])); + } } - } - - memcpy(appr_alg->element_levels_.data(), element_levels_npy.data(), element_levels_npy.nbytes()); - size_t link_npy_size = 0; - std::vector link_npy_offsets(appr_alg->cur_element_count); + memcpy(appr_alg->element_levels_.data(), element_levels_npy.data(), element_levels_npy.nbytes()); - for (size_t i = 0; i < appr_alg->cur_element_count; i++){ - size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; - link_npy_offsets[i]=link_npy_size; - if (linkListSize) - link_npy_size += linkListSize; - } + size_t link_npy_size = 0; + std::vector link_npy_offsets(appr_alg->cur_element_count); - memcpy(appr_alg->data_level0_memory_, data_level0_npy.data(), data_level0_npy.nbytes()); + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + link_npy_offsets[i] = link_npy_size; + if (linkListSize) + link_npy_size += linkListSize; + } - for (size_t i = 0; i < appr_alg->max_elements_; i++) { - size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; - if (linkListSize == 0) { - appr_alg->linkLists_[i] = nullptr; - } else { - appr_alg->linkLists_[i] = (char *) malloc(linkListSize); - if (appr_alg->linkLists_[i] == nullptr) - throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); + memcpy(appr_alg->data_level0_memory_, data_level0_npy.data(), data_level0_npy.nbytes()); - memcpy(appr_alg->linkLists_[i], link_list_npy.data()+link_npy_offsets[i], linkListSize); + for (size_t i = 0; i < appr_alg->max_elements_; i++) { + size_t linkListSize = appr_alg->element_levels_[i] > 0 ? appr_alg->size_links_per_element_ * appr_alg->element_levels_[i] : 0; + if (linkListSize == 0) { + appr_alg->linkLists_[i] = nullptr; + } else { + appr_alg->linkLists_[i] = (char*)malloc(linkListSize); + if (appr_alg->linkLists_[i] == nullptr) + throw std::runtime_error("Not enough memory: loadIndex failed to allocate linklist"); - } - } + memcpy(appr_alg->linkLists_[i], link_list_npy.data() + link_npy_offsets[i], linkListSize); + } + } - // set num_deleted - appr_alg->num_deleted_ = 0; - bool has_deletions = d["has_deletions"].cast(); - if (has_deletions) - { - for (size_t i = 0; i < appr_alg->cur_element_count; i++) { - if(appr_alg->isMarkedDeleted(i)) - appr_alg->num_deleted_ += 1; + // process deleted elements + bool allow_replace_deleted = false; + if (d.contains("allow_replace_deleted")) { + allow_replace_deleted = d["allow_replace_deleted"].cast(); } - } -} + appr_alg->allow_replace_deleted_= allow_replace_deleted; + + appr_alg->num_deleted_ = 0; + bool has_deletions = d["has_deletions"].cast(); + if (has_deletions) { + for (size_t i = 0; i < appr_alg->cur_element_count; i++) { + if (appr_alg->isMarkedDeleted(i)) { + appr_alg->num_deleted_ += 1; + if (allow_replace_deleted) appr_alg->deleted_elements.insert(i); + } + } + } + } + - py::object knnQuery_return_numpy(py::object input, size_t k = 1, int num_threads = -1) { + py::object knnQuery_return_numpy( + py::object input, + size_t k = 1, + int num_threads = -1, + const std::function& filter = nullptr) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); - hnswlib::labeltype *data_numpy_l; - dist_t *data_numpy_d; + hnswlib::labeltype* data_numpy_l; + dist_t* data_numpy_d; size_t rows, features; if (num_threads <= 0) @@ -550,140 +613,134 @@ class Index { { py::gil_scoped_release l; - - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } - else{ - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); // avoid using threads when the number of searches is small: - - if(rows<=num_threads*4){ - num_threads=1; + if (rows <= num_threads * 4) { + num_threads = 1; } data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; - if(normalize==false) { + // Warning: search with a filter works slow in python in multithreaded mode. For best performance set num_threads=1 + CustomFilterFunctor idFilter(filter); + CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; + + if (normalize == false) { ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { - std::priority_queue> result = appr_alg->searchKnn( - (void *) items.data(row), k); - if (result.size() != k) - throw std::runtime_error( - "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); - for (int i = k - 1; i >= 0; i--) { - auto &result_tuple = result.top(); - data_numpy_d[row * k + i] = result_tuple.first; - data_numpy_l[row * k + i] = result_tuple.second; - result.pop(); - } - } - ); - } - else{ - std::vector norm_array(num_threads*features); + std::priority_queue> result = appr_alg->searchKnn( + (void*)items.data(row), k, p_idFilter); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); + } else { + std::vector norm_array(num_threads * features); ParallelFor(0, rows, num_threads, [&](size_t row, size_t threadId) { - float *data= (float *) items.data(row); - - size_t start_idx = threadId * dim; - normalize_vector((float *) items.data(row), (norm_array.data()+start_idx)); - - std::priority_queue> result = appr_alg->searchKnn( - (void *) (norm_array.data()+start_idx), k); - if (result.size() != k) - throw std::runtime_error( - "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); - for (int i = k - 1; i >= 0; i--) { - auto &result_tuple = result.top(); - data_numpy_d[row * k + i] = result_tuple.first; - data_numpy_l[row * k + i] = result_tuple.second; - result.pop(); - } - } - ); + float* data = (float*)items.data(row); + + size_t start_idx = threadId * dim; + normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); + + std::priority_queue> result = appr_alg->searchKnn( + (void*)(norm_array.data() + start_idx), k, p_idFilter); + if (result.size() != k) + throw std::runtime_error( + "Cannot return the results in a contigious 2D array. Probably ef or M is too small"); + for (int i = k - 1; i >= 0; i--) { + auto& result_tuple = result.top(); + data_numpy_d[row * k + i] = result_tuple.first; + data_numpy_l[row * k + i] = result_tuple.second; + result.pop(); + } + }); } } - py::capsule free_when_done_l(data_numpy_l, [](void *f) { + py::capsule free_when_done_l(data_numpy_l, [](void* f) { delete[] f; - }); - py::capsule free_when_done_d(data_numpy_d, [](void *f) { + }); + py::capsule free_when_done_d(data_numpy_d, [](void* f) { delete[] f; - }); + }); return py::make_tuple( - py::array_t( - {rows, k}, // shape - {k * sizeof(hnswlib::labeltype), - sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double - data_numpy_l, // the data pointer - free_when_done_l), - py::array_t( - {rows, k}, // shape - {k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for double - data_numpy_d, // the data pointer - free_when_done_d)); - + py::array_t( + { rows, k }, // shape + { k * sizeof(hnswlib::labeltype), + sizeof(hnswlib::labeltype) }, // C-style contiguous strides for each index + data_numpy_l, // the data pointer + free_when_done_l), + py::array_t( + { rows, k }, // shape + { k * sizeof(dist_t), sizeof(dist_t) }, // C-style contiguous strides for each index + data_numpy_d, // the data pointer + free_when_done_d)); } + void markDeleted(size_t label) { appr_alg->markDelete(label); } + void unmarkDeleted(size_t label) { appr_alg->unmarkDelete(label); } + void resizeIndex(size_t new_size) { appr_alg->resizeIndex(new_size); } + size_t getMaxElements() const { return appr_alg->max_elements_; } + size_t getCurrentCount() const { return appr_alg->cur_element_count; } }; -template +template class BFIndex { -public: - BFIndex(const std::string &space_name, const int dim) : - space_name(space_name), dim(dim) { - normalize=false; - if(space_name=="l2") { + public: + static const int ser_version = 1; // serialization version + + std::string space_name; + int dim; + bool index_inited; + bool normalize; + + hnswlib::labeltype cur_l; + hnswlib::BruteforceSearch* alg; + hnswlib::SpaceInterface* space; + + + BFIndex(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) { + normalize = false; + if (space_name == "l2") { space = new hnswlib::L2Space(dim); - } - else if(space_name=="ip") { + } else if (space_name == "ip") { space = new hnswlib::InnerProductSpace(dim); - } - else if(space_name=="cosine") { + } else if (space_name == "cosine") { space = new hnswlib::InnerProductSpace(dim); - normalize=true; + normalize = true; } else { - throw new std::runtime_error("Space name must be one of l2, ip, or cosine."); + throw std::runtime_error("Space name must be one of l2, ip, or cosine."); } alg = NULL; index_inited = false; } - static const int ser_version = 1; // serialization version - - std::string space_name; - int dim; - bool index_inited; - bool normalize; - - hnswlib::labeltype cur_l; - hnswlib::BruteforceSearch *alg; - hnswlib::SpaceInterface *space; ~BFIndex() { delete space; @@ -691,59 +748,39 @@ class BFIndex { delete alg; } + void init_new_index(const size_t maxElements) { if (alg) { - throw new std::runtime_error("The index is already initiated."); + throw std::runtime_error("The index is already initiated."); } cur_l = 0; alg = new hnswlib::BruteforceSearch(space, maxElements); index_inited = true; } - void normalize_vector(float *data, float *norm_array){ - float norm=0.0f; - for(int i=0;i items(input); auto buffer = items.request(); size_t rows, features; - - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); if (features != dim) - throw std::runtime_error("wrong dimensionality of the vectors"); + throw std::runtime_error("Wrong dimensionality of the vectors"); - std::vector ids; + std::vector ids = get_input_ids_and_check_shapes(ids_, rows); - if (!ids_.is_none()) { - py::array_t < size_t, py::array::c_style | py::array::forcecast > items(ids_); - auto ids_numpy = items.request(); - if (ids_numpy.ndim == 1 && ids_numpy.shape[0] == rows) { - std::vector ids1(ids_numpy.shape[0]); - for (size_t i = 0; i < ids1.size(); i++) { - ids1[i] = items.data()[i]; - } - ids.swap(ids1); - } else if (ids_numpy.ndim == 0 && rows == 1) { - ids.push_back(*items.data()); - } else - throw std::runtime_error("wrong dimensionality of the labels"); - } { - for (size_t row = 0; row < rows; row++) { size_t id = ids.size() ? ids.at(row) : cur_l + row; if (!normalize) { @@ -758,17 +795,20 @@ class BFIndex { } } + void deleteVector(size_t label) { alg->removePoint(label); } + void saveIndex(const std::string &path_to_index) { alg->saveIndex(path_to_index); } + void loadIndex(const std::string &path_to_index, size_t max_elements) { if (alg) { - std::cerr<<"Warning: Calling load_index for an already inited index. Old index is being deallocated."; + std::cerr << "Warning: Calling load_index for an already inited index. Old index is being deallocated." << std::endl; delete alg; } alg = new hnswlib::BruteforceSearch(space, path_to_index); @@ -776,8 +816,11 @@ class BFIndex { index_inited = true; } - py::object knnQuery_return_numpy(py::object input, size_t k = 1) { + py::object knnQuery_return_numpy( + py::object input, + size_t k = 1, + const std::function& filter = nullptr) { py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); auto buffer = items.request(); hnswlib::labeltype *data_numpy_l; @@ -786,21 +829,17 @@ class BFIndex { { py::gil_scoped_release l; - if (buffer.ndim != 2 && buffer.ndim != 1) throw std::runtime_error("data must be a 1d/2d array"); - if (buffer.ndim == 2) { - rows = buffer.shape[0]; - features = buffer.shape[1]; - } else { - rows = 1; - features = buffer.shape[0]; - } + get_input_array_shapes(buffer, &rows, &features); data_numpy_l = new hnswlib::labeltype[rows * k]; data_numpy_d = new dist_t[rows * k]; + CustomFilterFunctor idFilter(filter); + CustomFilterFunctor* p_idFilter = filter ? &idFilter : nullptr; + for (size_t row = 0; row < rows; row++) { std::priority_queue> result = alg->searchKnn( - (void *) items.data(row), k); + (void *) items.data(row), k, p_idFilter); for (int i = k - 1; i >= 0; i--) { auto &result_tuple = result.top(); data_numpy_d[row * k + i] = result_tuple.first; @@ -820,21 +859,20 @@ class BFIndex { return py::make_tuple( py::array_t( - {rows, k}, // shape - {k * sizeof(hnswlib::labeltype), - sizeof(hnswlib::labeltype)}, // C-style contiguous strides for double - data_numpy_l, // the data pointer + { rows, k }, // shape + { k * sizeof(hnswlib::labeltype), + sizeof(hnswlib::labeltype)}, // C-style contiguous strides for each index + data_numpy_l, // the data pointer free_when_done_l), py::array_t( - {rows, k}, // shape - {k * sizeof(dist_t), sizeof(dist_t)}, // C-style contiguous strides for double - data_numpy_d, // the data pointer + { rows, k }, // shape + { k * sizeof(dist_t), sizeof(dist_t) }, // C-style contiguous strides for each index + data_numpy_d, // the data pointer free_when_done_d)); - } - }; + PYBIND11_PLUGIN(hnswlib) { py::module m("hnswlib"); @@ -843,15 +881,35 @@ PYBIND11_PLUGIN(hnswlib) { /* WARNING: Index::createFromIndex is not thread-safe with Index::addItems */ .def(py::init(&Index::createFromIndex), py::arg("index")) .def(py::init(), py::arg("space"), py::arg("dim")) - .def("init_index", &Index::init_new_index, py::arg("max_elements"), py::arg("M")=16, py::arg("ef_construction")=200, py::arg("random_seed")=100) - .def("knn_query", &Index::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1, py::arg("num_threads")=-1) - .def("add_items", &Index::addItems, py::arg("data"), py::arg("ids") = py::none(), py::arg("num_threads")=-1) + .def("init_index", + &Index::init_new_index, + py::arg("max_elements"), + py::arg("M") = 16, + py::arg("ef_construction") = 200, + py::arg("random_seed") = 100, + py::arg("allow_replace_deleted") = false) + .def("knn_query", + &Index::knnQuery_return_numpy, + py::arg("data"), + py::arg("k") = 1, + py::arg("num_threads") = -1, + py::arg("filter") = py::none()) + .def("add_items", + &Index::addItems, + py::arg("data"), + py::arg("ids") = py::none(), + py::arg("num_threads") = -1, + py::arg("replace_deleted") = false) .def("get_items", &Index::getDataReturnList, py::arg("ids") = py::none()) .def("get_ids_list", &Index::getIdsList) .def("set_ef", &Index::set_ef, py::arg("ef")) .def("set_num_threads", &Index::set_num_threads, py::arg("num_threads")) .def("save_index", &Index::saveIndex, py::arg("path_to_index")) - .def("load_index", &Index::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0) + .def("load_index", + &Index::loadIndex, + py::arg("path_to_index"), + py::arg("max_elements") = 0, + py::arg("allow_replace_deleted") = false) .def("mark_deleted", &Index::markDeleted, py::arg("label")) .def("unmark_deleted", &Index::unmarkDeleted, py::arg("label")) .def("resize_index", &Index::resizeIndex, py::arg("new_size")) @@ -865,7 +923,7 @@ PYBIND11_PLUGIN(hnswlib) { return index.index_inited ? index.appr_alg->ef_ : index.default_ef; }, [](Index & index, const size_t ef_) { - index.default_ef=ef_; + index.default_ef = ef_; if (index.appr_alg) index.appr_alg->ef_ = ef_; }) @@ -873,7 +931,7 @@ PYBIND11_PLUGIN(hnswlib) { return index.index_inited ? index.appr_alg->max_elements_ : 0; }) .def_property_readonly("element_count", [](const Index & index) { - return index.index_inited ? index.appr_alg->cur_element_count : 0; + return index.index_inited ? (size_t)index.appr_alg->cur_element_count : 0; }) .def_property_readonly("ef_construction", [](const Index & index) { return index.index_inited ? index.appr_alg->ef_construction_ : 0; @@ -883,16 +941,14 @@ PYBIND11_PLUGIN(hnswlib) { }) .def(py::pickle( - [](const Index &ind) { // __getstate__ + [](const Index &ind) { // __getstate__ return py::make_tuple(ind.getIndexParams()); /* Return dict (wrapped in a tuple) that fully encodes state of the Index object */ }, - [](py::tuple t) { // __setstate__ + [](py::tuple t) { // __setstate__ if (t.size() != 1) throw std::runtime_error("Invalid state!"); - return Index::createFromParams(t[0].cast()); - } - )) + })) .def("__repr__", [](const Index &a) { return ""; @@ -901,11 +957,11 @@ PYBIND11_PLUGIN(hnswlib) { py::class_>(m, "BFIndex") .def(py::init(), py::arg("space"), py::arg("dim")) .def("init_index", &BFIndex::init_new_index, py::arg("max_elements")) - .def("knn_query", &BFIndex::knnQuery_return_numpy, py::arg("data"), py::arg("k")=1) + .def("knn_query", &BFIndex::knnQuery_return_numpy, py::arg("data"), py::arg("k") = 1, py::arg("filter") = py::none()) .def("add_items", &BFIndex::addItems, py::arg("data"), py::arg("ids") = py::none()) .def("delete_vector", &BFIndex::deleteVector, py::arg("label")) .def("save_index", &BFIndex::saveIndex, py::arg("path_to_index")) - .def("load_index", &BFIndex::loadIndex, py::arg("path_to_index"), py::arg("max_elements")=0) + .def("load_index", &BFIndex::loadIndex, py::arg("path_to_index"), py::arg("max_elements") = 0) .def("__repr__", [](const BFIndex &a) { return ""; }); diff --git a/python_bindings/tests/__init__.py b/python_bindings/tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/python_bindings/tests/bindings_test_recall.py b/python_bindings/tests/bindings_test_recall.py deleted file mode 100644 index 3742fcdd..00000000 --- a/python_bindings/tests/bindings_test_recall.py +++ /dev/null @@ -1,88 +0,0 @@ -import hnswlib -import numpy as np - -dim = 32 -num_elements = 100000 -k = 10 -nun_queries = 10 - -# Generating sample data -data = np.float32(np.random.random((num_elements, dim))) - -# Declaring index -hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip -bf_index = hnswlib.BFIndex(space='l2', dim=dim) - -# Initing both hnsw and brute force indices -# max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded -# during insertion of an element. -# The capacity can be increased by saving/loading the index, see below. -# -# hnsw construction params: -# ef_construction - controls index search speed/build speed tradeoff -# -# M - is tightly connected with internal dimensionality of the data. Strongly affects the memory consumption (~M) -# Higher M leads to higher accuracy/run_time at fixed ef/efConstruction - -hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16) -bf_index.init_index(max_elements=num_elements) - -# Controlling the recall for hnsw by setting ef: -# higher ef leads to better accuracy, but slower search -hnsw_index.set_ef(200) - -# Set number of threads used during batch search/construction in hnsw -# By default using all available cores -hnsw_index.set_num_threads(1) - -print("Adding batch of %d elements" % (len(data))) -hnsw_index.add_items(data) -bf_index.add_items(data) - -print("Indices built") - -# Generating query data -query_data = np.float32(np.random.random((nun_queries, dim))) - -# Query the elements and measure recall: -labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k) -labels_bf, distances_bf = bf_index.knn_query(query_data, k) - -# Measure recall -correct = 0 -for i in range(nun_queries): - for label in labels_hnsw[i]: - for correct_label in labels_bf[i]: - if label == correct_label: - correct += 1 - break - -print("recall is :", float(correct)/(k*nun_queries)) - -# test serializing the brute force index -index_path = 'bf_index.bin' -print("Saving index to '%s'" % index_path) -bf_index.save_index(index_path) -del bf_index - -# Re-initiating, loading the index -bf_index = hnswlib.BFIndex(space='l2', dim=dim) - -print("\nLoading index from '%s'\n" % index_path) -bf_index.load_index(index_path) - -# Query the brute force index again to verify that we get the same results -labels_bf, distances_bf = bf_index.knn_query(query_data, k) - -# Measure recall -correct = 0 -for i in range(nun_queries): - for label in labels_hnsw[i]: - for correct_label in labels_bf[i]: - if label == correct_label: - correct += 1 - break - -print("recall after reloading is :", float(correct)/(k*nun_queries)) - - diff --git a/setup.py b/setup.py index 161886fd..0126585e 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -__version__ = '0.6.1' +__version__ = '0.7.0' include_dirs = [ diff --git a/download_bigann.py b/tests/cpp/download_bigann.py similarity index 100% rename from download_bigann.py rename to tests/cpp/download_bigann.py diff --git a/main.cpp b/tests/cpp/main.cpp similarity index 97% rename from main.cpp rename to tests/cpp/main.cpp index 6c8acc9b..bf0fc2bf 100644 --- a/main.cpp +++ b/tests/cpp/main.cpp @@ -5,4 +5,4 @@ int main() { sift_test1B(); return 0; -}; \ No newline at end of file +} diff --git a/tests/cpp/multiThreadLoad_test.cpp b/tests/cpp/multiThreadLoad_test.cpp new file mode 100644 index 00000000..4d2b4aa2 --- /dev/null +++ b/tests/cpp/multiThreadLoad_test.cpp @@ -0,0 +1,140 @@ +#include "../../hnswlib/hnswlib.h" +#include +#include + + +int main() { + std::cout << "Running multithread load test" << std::endl; + int d = 16; + int max_elements = 1000; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + + hnswlib::L2Space space(d); + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * max_elements); + + std::cout << "Building index" << std::endl; + int num_threads = 40; + int num_labels = 10; + + int num_iterations = 10; + int start_label = 0; + + // run threads that will add elements to the index + // about 7 threads (the number depends on num_threads and num_labels) + // will add/update element with the same label simultaneously + while (true) { + // add elements by batches + std::uniform_int_distribution<> distrib_int(start_label, start_label + num_labels - 1); + std::vector threads; + for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + threads.push_back( + std::thread( + [&] { + for (int iter = 0; iter < num_iterations; iter++) { + std::vector data(d); + hnswlib::labeltype label = distrib_int(rng); + for (int i = 0; i < d; i++) { + data[i] = distrib_real(rng); + } + alg_hnsw->addPoint(data.data(), label); + } + } + ) + ); + } + for (auto &thread : threads) { + thread.join(); + } + if (alg_hnsw->cur_element_count > max_elements - num_labels) { + break; + } + start_label += num_labels; + } + + // insert remaining elements if needed + for (hnswlib::labeltype label = 0; label < max_elements; label++) { + auto search = alg_hnsw->label_lookup_.find(label); + if (search == alg_hnsw->label_lookup_.end()) { + std::cout << "Adding " << label << std::endl; + std::vector data(d); + for (int i = 0; i < d; i++) { + data[i] = distrib_real(rng); + } + alg_hnsw->addPoint(data.data(), label); + } + } + + std::cout << "Index is created" << std::endl; + + bool stop_threads = false; + std::vector threads; + + // create threads that will do markDeleted and unmarkDeleted of random elements + // each thread works with specific range of labels + std::cout << "Starting markDeleted and unmarkDeleted threads" << std::endl; + num_threads = 20; + int chunk_size = max_elements / num_threads; + for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + threads.push_back( + std::thread( + [&, thread_id] { + std::uniform_int_distribution<> distrib_int(0, chunk_size - 1); + int start_id = thread_id * chunk_size; + std::vector marked_deleted(chunk_size); + while (!stop_threads) { + int id = distrib_int(rng); + hnswlib::labeltype label = start_id + id; + if (marked_deleted[id]) { + alg_hnsw->unmarkDelete(label); + marked_deleted[id] = false; + } else { + alg_hnsw->markDelete(label); + marked_deleted[id] = true; + } + } + } + ) + ); + } + + // create threads that will add and update random elements + std::cout << "Starting add and update elements threads" << std::endl; + num_threads = 20; + std::uniform_int_distribution<> distrib_int_add(max_elements, 2 * max_elements - 1); + for (size_t thread_id = 0; thread_id < num_threads; thread_id++) { + threads.push_back( + std::thread( + [&] { + std::vector data(d); + while (!stop_threads) { + hnswlib::labeltype label = distrib_int_add(rng); + for (int i = 0; i < d; i++) { + data[i] = distrib_real(rng); + } + alg_hnsw->addPoint(data.data(), label); + std::vector data = alg_hnsw->getDataByLabel(label); + float max_val = *max_element(data.begin(), data.end()); + // never happens but prevents compiler from deleting unused code + if (max_val > 10) { + throw std::runtime_error("Unexpected value in data"); + } + } + } + ) + ); + } + + std::cout << "Sleep and continue operations with index" << std::endl; + int sleep_ms = 60 * 1000; + std::this_thread::sleep_for(std::chrono::milliseconds(sleep_ms)); + stop_threads = true; + for (auto &thread : threads) { + thread.join(); + } + + std::cout << "Finish" << std::endl; + return 0; +} diff --git a/tests/cpp/multiThread_replace_test.cpp b/tests/cpp/multiThread_replace_test.cpp new file mode 100644 index 00000000..203cdb0d --- /dev/null +++ b/tests/cpp/multiThread_replace_test.cpp @@ -0,0 +1,121 @@ +#include "../../hnswlib/hnswlib.h" +#include +#include + + +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). + */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + + +int main() { + std::cout << "Running multithread load test" << std::endl; + int d = 16; + int num_elements = 1000; + int max_elements = 2 * num_elements; + int num_threads = 50; + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + + hnswlib::L2Space space(d); + + // generate batch1 and batch2 data + float* batch1 = new float[d * max_elements]; + for (int i = 0; i < d * max_elements; i++) { + batch1[i] = distrib_real(rng); + } + float* batch2 = new float[d * num_elements]; + for (int i = 0; i < d * num_elements; i++) { + batch2[i] = distrib_real(rng); + } + + // generate random labels to delete them from index + std::vector rand_labels(max_elements); + for (int i = 0; i < max_elements; i++) { + rand_labels[i] = i; + } + std::shuffle(rand_labels.begin(), rand_labels.end(), rng); + + int iter = 0; + while (iter < 200) { + hnswlib::HierarchicalNSW* alg_hnsw = new hnswlib::HierarchicalNSW(&space, max_elements, 16, 200, 123, true); + + // add batch1 data + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw->addPoint((void*)(batch1 + d * row), row); + }); + + // delete half random elements of batch1 data + for (int i = 0; i < num_elements; i++) { + alg_hnsw->markDelete(rand_labels[i]); + } + + // replace deleted elements with batch2 data + ParallelFor(0, num_elements, num_threads, [&](size_t row, size_t threadId) { + int label = rand_labels[row] + max_elements; + alg_hnsw->addPoint((void*)(batch2 + d * row), label, true); + }); + + iter += 1; + + delete alg_hnsw; + } + + std::cout << "Finish" << std::endl; + + delete[] batch1; + delete[] batch2; + return 0; +} diff --git a/examples/searchKnnCloserFirst_test.cpp b/tests/cpp/searchKnnCloserFirst_test.cpp similarity index 96% rename from examples/searchKnnCloserFirst_test.cpp rename to tests/cpp/searchKnnCloserFirst_test.cpp index cc1392c8..9583fe22 100644 --- a/examples/searchKnnCloserFirst_test.cpp +++ b/tests/cpp/searchKnnCloserFirst_test.cpp @@ -3,15 +3,14 @@ // >>> searchKnnCloserFirst(const void* query_data, size_t k) const; // of class AlgorithmInterface -#include "../hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include #include #include -namespace -{ +namespace { using idx_t = hnswlib::labeltype; @@ -20,7 +19,7 @@ void test() { idx_t n = 100; idx_t nq = 10; size_t k = 10; - + std::vector data(n * d); std::vector query(nq * d); @@ -34,7 +33,6 @@ void test() { for (idx_t i = 0; i < nq * d; ++i) { query[i] = distrib(rng); } - hnswlib::L2Space space(d); hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); @@ -68,12 +66,12 @@ void test() { gd.pop(); } } - + delete alg_brute; delete alg_hnsw; } -} // namespace +} // namespace int main() { std::cout << "Testing ..." << std::endl; diff --git a/tests/cpp/searchKnnWithFilter_test.cpp b/tests/cpp/searchKnnWithFilter_test.cpp new file mode 100644 index 00000000..0557b7e4 --- /dev/null +++ b/tests/cpp/searchKnnWithFilter_test.cpp @@ -0,0 +1,179 @@ +// This is a test file for testing the filtering feature + +#include "../../hnswlib/hnswlib.h" + +#include + +#include +#include + +namespace { + +using idx_t = hnswlib::labeltype; + +class PickDivisibleIds: public hnswlib::BaseFilterFunctor { +unsigned int divisor = 1; + public: + PickDivisibleIds(unsigned int divisor): divisor(divisor) { + assert(divisor != 0); + } + bool operator()(idx_t label_id) { + return label_id % divisor == 0; + } +}; + +class PickNothing: public hnswlib::BaseFilterFunctor { + public: + bool operator()(idx_t label_id) { + return false; + } +}; + +void test_some_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t div_num, size_t label_id_start) { + int d = 4; + idx_t n = 100; + idx_t nq = 10; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; ++i) { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) { + query[i] = distrib(rng); + } + + hnswlib::L2Space space(d); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + + for (size_t i = 0; i < n; ++i) { + // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs + alg_brute->addPoint(data.data() + d * i, label_id_start + i); + alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); + } + + // test searchKnnCloserFirst of BruteforceSearch with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_brute->searchKnn(p, k, &filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + assert((gd.top().second % div_num) == 0); + gd.pop(); + } + } + + // test searchKnnCloserFirst of hnsw with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_hnsw->searchKnn(p, k, &filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func); + assert(gd.size() == res.size()); + size_t t = gd.size(); + while (!gd.empty()) { + assert(gd.top() == res[--t]); + assert((gd.top().second % div_num) == 0); + gd.pop(); + } + } + + delete alg_brute; + delete alg_hnsw; +} + +void test_none_filtering(hnswlib::BaseFilterFunctor& filter_func, size_t label_id_start) { + int d = 4; + idx_t n = 100; + idx_t nq = 10; + size_t k = 10; + + std::vector data(n * d); + std::vector query(nq * d); + + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib; + + for (idx_t i = 0; i < n * d; ++i) { + data[i] = distrib(rng); + } + for (idx_t i = 0; i < nq * d; ++i) { + query[i] = distrib(rng); + } + + hnswlib::L2Space space(d); + hnswlib::AlgorithmInterface* alg_brute = new hnswlib::BruteforceSearch(&space, 2 * n); + hnswlib::AlgorithmInterface* alg_hnsw = new hnswlib::HierarchicalNSW(&space, 2 * n); + + for (size_t i = 0; i < n; ++i) { + // `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs + alg_brute->addPoint(data.data() + d * i, label_id_start + i); + alg_hnsw->addPoint(data.data() + d * i, label_id_start + i); + } + + // test searchKnnCloserFirst of BruteforceSearch with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_brute->searchKnn(p, k, &filter_func); + auto res = alg_brute->searchKnnCloserFirst(p, k, &filter_func); + assert(gd.size() == res.size()); + assert(0 == gd.size()); + } + + // test searchKnnCloserFirst of hnsw with filtering + for (size_t j = 0; j < nq; ++j) { + const void* p = query.data() + j * d; + auto gd = alg_hnsw->searchKnn(p, k, &filter_func); + auto res = alg_hnsw->searchKnnCloserFirst(p, k, &filter_func); + assert(gd.size() == res.size()); + assert(0 == gd.size()); + } + + delete alg_brute; + delete alg_hnsw; +} + +} // namespace + +class CustomFilterFunctor: public hnswlib::BaseFilterFunctor { + std::unordered_set allowed_values; + + public: + explicit CustomFilterFunctor(const std::unordered_set& values) : allowed_values(values) {} + + bool operator()(idx_t id) { + return allowed_values.count(id) != 0; + } +}; + +int main() { + std::cout << "Testing ..." << std::endl; + + // some of the elements are filtered + PickDivisibleIds pickIdsDivisibleByThree(3); + test_some_filtering(pickIdsDivisibleByThree, 3, 17); + PickDivisibleIds pickIdsDivisibleBySeven(7); + test_some_filtering(pickIdsDivisibleBySeven, 7, 17); + + // all of the elements are filtered + PickNothing pickNothing; + test_none_filtering(pickNothing, 17); + + // functor style which can capture context + CustomFilterFunctor pickIdsDivisibleByThirteen({26, 39, 52, 65}); + test_some_filtering(pickIdsDivisibleByThirteen, 13, 21); + + std::cout << "Test ok" << std::endl; + + return 0; +} diff --git a/sift_1b.cpp b/tests/cpp/sift_1b.cpp similarity index 90% rename from sift_1b.cpp rename to tests/cpp/sift_1b.cpp index 2739490c..43777ff6 100644 --- a/sift_1b.cpp +++ b/tests/cpp/sift_1b.cpp @@ -2,7 +2,7 @@ #include #include #include -#include "hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include @@ -12,7 +12,7 @@ using namespace hnswlib; class StopW { std::chrono::steady_clock::time_point time_begin; -public: + public: StopW() { time_begin = std::chrono::steady_clock::now(); } @@ -25,7 +25,6 @@ class StopW { void reset() { time_begin = std::chrono::steady_clock::now(); } - }; @@ -80,8 +79,7 @@ static size_t getPeakRSS() { int fd = -1; if ((fd = open("/proc/self/psinfo", O_RDONLY)) == -1) return (size_t)0L; /* Can't open? */ - if (read(fd, &psinfo, sizeof(psinfo)) != sizeof(psinfo)) - { + if (read(fd, &psinfo, sizeof(psinfo)) != sizeof(psinfo)) { close(fd); return (size_t)0L; /* Can't read? */ } @@ -146,10 +144,16 @@ static size_t getCurrentRSS() { static void -get_gt(unsigned int *massQA, unsigned char *massQ, unsigned char *mass, size_t vecsize, size_t qsize, L2SpaceI &l2space, - size_t vecdim, vector>> &answers, size_t k) { - - +get_gt( + unsigned int *massQA, + unsigned char *massQ, + unsigned char *mass, + size_t vecsize, + size_t qsize, + L2SpaceI &l2space, + size_t vecdim, + vector>> &answers, + size_t k) { (vector>>(qsize)).swap(answers); DISTFUNC fstdistfunc_ = l2space.get_dist_func(); cout << qsize << "\n"; @@ -161,43 +165,50 @@ get_gt(unsigned int *massQA, unsigned char *massQ, unsigned char *mass, size_t v } static float -test_approx(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW &appr_alg, size_t vecdim, - vector>> &answers, size_t k) { +test_approx( + unsigned char *massQ, + size_t vecsize, + size_t qsize, + HierarchicalNSW &appr_alg, + size_t vecdim, + vector>> &answers, + size_t k) { size_t correct = 0; size_t total = 0; - //uncomment to test in parallel mode: + // uncomment to test in parallel mode: //#pragma omp parallel for for (int i = 0; i < qsize; i++) { - std::priority_queue> result = appr_alg.searchKnn(massQ + vecdim * i, k); std::priority_queue> gt(answers[i]); unordered_set g; total += gt.size(); while (gt.size()) { - - g.insert(gt.top().second); gt.pop(); } while (result.size()) { if (g.find(result.top().second) != g.end()) { - correct++; } else { } result.pop(); } - } return 1.0f * correct / total; } static void -test_vs_recall(unsigned char *massQ, size_t vecsize, size_t qsize, HierarchicalNSW &appr_alg, size_t vecdim, - vector>> &answers, size_t k) { - vector efs;// = { 10,10,10,10,10 }; +test_vs_recall( + unsigned char *massQ, + size_t vecsize, + size_t qsize, + HierarchicalNSW &appr_alg, + size_t vecdim, + vector>> &answers, + size_t k) { + vector efs; // = { 10,10,10,10,10 }; for (int i = k; i < 30; i++) { efs.push_back(i); } @@ -229,12 +240,9 @@ inline bool exists_test(const std::string &name) { void sift_test1B() { - - - int subset_size_milllions = 200; - int efConstruction = 40; - int M = 16; - + int subset_size_milllions = 200; + int efConstruction = 40; + int M = 16; size_t vecsize = subset_size_milllions * 1000000; @@ -248,7 +256,6 @@ void sift_test1B() { sprintf(path_gt, "../bigann/gnd/idx_%dM.ivecs", subset_size_milllions); - unsigned char *massb = new unsigned char[vecdim]; cout << "Loading GT:\n"; @@ -264,7 +271,7 @@ void sift_test1B() { } } inputGT.close(); - + cout << "Loading queries:\n"; unsigned char *massQ = new unsigned char[qsize * vecdim]; ifstream inputQ(path_q, ios::binary); @@ -280,7 +287,6 @@ void sift_test1B() { for (int j = 0; j < vecdim; j++) { massQ[i * vecdim + j] = massb[j]; } - } inputQ.close(); @@ -299,7 +305,6 @@ void sift_test1B() { cout << "Building index:\n"; appr_alg = new HierarchicalNSW(&l2space, vecsize, M, efConstruction); - input.read((char *) &in, 4); if (in != 128) { cout << "file error"; @@ -319,10 +324,9 @@ void sift_test1B() { #pragma omp parallel for for (int i = 1; i < vecsize; i++) { unsigned char mass[128]; - int j2=0; + int j2 = 0; #pragma omp critical { - input.read((char *) &in, 4); if (in != 128) { cout << "file error"; @@ -333,7 +337,7 @@ void sift_test1B() { mass[j] = massb[j]; } j1++; - j2=j1; + j2 = j1; if (j1 % report_every == 0) { cout << j1 / (0.01 * vecsize) << " %, " << report_every / (1000.0 * 1e-6 * stopw.getElapsedTimeMicro()) << " kips " << " Mem: " @@ -342,8 +346,6 @@ void sift_test1B() { } } appr_alg->addPoint((void *) (mass), (size_t) j2); - - } input.close(); cout << "Build time:" << 1e-6 * stopw_full.getElapsedTimeMicro() << " seconds\n"; @@ -360,6 +362,4 @@ void sift_test1B() { test_vs_recall(massQ, vecsize, qsize, *appr_alg, vecdim, answers, k); cout << "Actual memory usage: " << getCurrentRSS() / 1000000 << " Mb \n"; return; - - } diff --git a/sift_test.cpp b/tests/cpp/sift_test.cpp similarity index 88% rename from sift_test.cpp rename to tests/cpp/sift_test.cpp index c6718f50..decdf605 100644 --- a/sift_test.cpp +++ b/tests/cpp/sift_test.cpp @@ -2,7 +2,7 @@ #include #include #include -#include "hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include @@ -22,7 +22,7 @@ static void readBinaryPOD(istream& in, T& podRef) { }*/ class StopW { std::chrono::steady_clock::time_point time_begin; -public: + public: StopW() { time_begin = std::chrono::steady_clock::now(); } @@ -35,11 +35,17 @@ class StopW { void reset() { time_begin = std::chrono::steady_clock::now(); } - }; -void get_gt(float *mass, float *massQ, size_t vecsize, size_t qsize, L2Space &l2space, size_t vecdim, - vector>> &answers, size_t k) { +void get_gt( + float *mass, + float *massQ, + size_t vecsize, + size_t qsize, + L2Space &l2space, + size_t vecdim, + vector>> &answers, + size_t k) { BruteforceSearch bs(&l2space, vecsize); for (int i = 0; i < vecsize; i++) { bs.addPoint((void *) (mass + vecdim * i), (size_t) i); @@ -53,9 +59,16 @@ void get_gt(float *mass, float *massQ, size_t vecsize, size_t qsize, L2Space &l2 } void -get_gt(unsigned int *massQA, float *massQ, float *mass, size_t vecsize, size_t qsize, L2Space &l2space, size_t vecdim, - vector>> &answers, size_t k) { - +get_gt( + unsigned int *massQA, + float *massQ, + float *mass, + size_t vecsize, + size_t qsize, + L2Space &l2space, + size_t vecdim, + vector>> &answers, + size_t k) { //answers.swap(vector>>(qsize)); (vector>>(qsize)).swap(answers); DISTFUNC fstdistfunc_ = l2space.get_dist_func(); @@ -69,13 +82,18 @@ get_gt(unsigned int *massQA, float *massQ, float *mass, size_t vecsize, size_t q } } -float test_approx(float *massQ, size_t vecsize, size_t qsize, HierarchicalNSW &appr_alg, size_t vecdim, - vector>> &answers, size_t k) { +float test_approx( + float *massQ, + size_t vecsize, + size_t qsize, + HierarchicalNSW &appr_alg, + size_t vecdim, + vector>> &answers, + size_t k) { size_t correct = 0; size_t total = 0; //#pragma omp parallel for for (int i = 0; i < qsize; i++) { - std::priority_queue> result = appr_alg.searchKnn(massQ + vecdim * i, 10); std::priority_queue> gt(answers[i]); unordered_set g; @@ -93,8 +111,14 @@ float test_approx(float *massQ, size_t vecsize, size_t qsize, HierarchicalNSW &appr_alg, size_t vecdim, - vector>> &answers, size_t k) { +void test_vs_recall( + float *massQ, + size_t vecsize, + size_t qsize, + HierarchicalNSW &appr_alg, + size_t vecdim, + vector>> &answers, + size_t k) { //vector efs = { 1,2,3,4,6,8,12,16,24,32,64,128,256,320 };// = ; { 23 }; vector efs; for (int i = 10; i < 30; i++) { @@ -121,7 +145,7 @@ void test_vs_recall(float *massQ, size_t vecsize, size_t qsize, HierarchicalNSW< } //void get_knn_quality(unsigned int *massA,size_t vecsize, size_t maxn, HierarchicalNSW &appr_alg) { // size_t total = 0; -// size_t correct = 0; +// size_t correct = 0; // for (int i = 0; i < vecsize; i++) { // int *data = (int *)(appr_alg.linkList0_ + i * appr_alg.size_links_per_element0_); // //cout << "numconn:" << *data<<"\n"; @@ -186,7 +210,7 @@ void sift_test() { //#define LOAD_I #ifdef LOAD_I - HierarchicalNSW appr_alg(&l2space, "hnswlib_sift",false); + HierarchicalNSW appr_alg(&l2space, "hnswlib_sift", false); //HierarchicalNSW appr_alg(&l2space, "D:/stuff/hnsw_lib/nmslib/similarity_search/release/temp",true); //HierarchicalNSW appr_alg(&l2space, "/mnt/d/stuff/hnsw_lib/nmslib/similarity_search/release/temp", true); @@ -243,7 +267,7 @@ void sift_test() { // // cout << appr_alg.maxlevel_ << "\n"; // //CHECK: -// //for (size_t io = 0; io < vecsize; io++) { +// //for (size_t io = 0; io < vecsize; io++) { // // if (appr_alg.getExternalLabel(io) != io) // // throw new exception("bad!"); // //} @@ -252,22 +276,22 @@ void sift_test() { // for (int i = 0; i < vecsize; i++) { // int *data = (int *)(appr_alg.linkList0_ + i * appr_alg.size_links_per_element0_); // //cout << "numconn:" << *data<<"\n"; -// tableint *datal = (tableint *)(data + 1); +// tableint *datal = (tableint *)(data + 1); // // std::priority_queue< std::pair< float, tableint >> rez; // unordered_set g; // for (int j = 0; j < *data; j++) { // g.insert(datal[j]); // } -// appr_alg.setEf(400); +// appr_alg.setEf(400); // std::priority_queue< std::pair< float, tableint >> closest_elements = appr_alg.searchKnnInternal(appr_alg.getDataByInternalId(i), 17); -// while (closest_elements.size() > 0) { +// while (closest_elements.size() > 0) { // if (closest_elements.top().second != i) { // g.insert(closest_elements.top().second); // } // closest_elements.pop(); // } -// +// // for (tableint l : g) { // float other = fstdistfunc_(appr_alg.getDataByInternalId(l), appr_alg.getDataByInternalId(i), l2space.get_dist_func_param()); // rez.emplace(other, l); @@ -285,18 +309,18 @@ void sift_test() { // } // // } -// +// // //get_knn_quality(massA, vecsize, maxn, appr_alg); // test_vs_recall( massQ, vecsize, qsize, appr_alg, vecdim, answers, k); // /*test_vs_recall( massQ, vecsize, qsize, appr_alg, vecdim, answers, k); // test_vs_recall( massQ, vecsize, qsize, appr_alg, vecdim, answers, k); // test_vs_recall( massQ, vecsize, qsize, appr_alg, vecdim, answers, k);*/ // -// +// // // // // /*for(int i=0;i<1000;i++) // cout << mass[i] << "\n";*/ // //("11", std::ios::binary); -} \ No newline at end of file +} diff --git a/examples/update_gen_data.py b/tests/cpp/update_gen_data.py similarity index 100% rename from examples/update_gen_data.py rename to tests/cpp/update_gen_data.py diff --git a/examples/updates_test.cpp b/tests/cpp/updates_test.cpp similarity index 78% rename from examples/updates_test.cpp rename to tests/cpp/updates_test.cpp index c8775877..52e1fa14 100644 --- a/examples/updates_test.cpp +++ b/tests/cpp/updates_test.cpp @@ -1,27 +1,26 @@ -#include "../hnswlib/hnswlib.h" +#include "../../hnswlib/hnswlib.h" #include -class StopW -{ + + +class StopW { std::chrono::steady_clock::time_point time_begin; -public: - StopW() - { + public: + StopW() { time_begin = std::chrono::steady_clock::now(); } - float getElapsedTimeMicro() - { + float getElapsedTimeMicro() { std::chrono::steady_clock::time_point time_end = std::chrono::steady_clock::now(); return (std::chrono::duration_cast(time_end - time_begin).count()); } - void reset() - { + void reset() { time_begin = std::chrono::steady_clock::now(); } }; + /* * replacement for the openmp '#pragma omp parallel for' directive * only handles a subset of functionality (no reductions etc) @@ -81,22 +80,18 @@ inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn std::rethrow_exception(lastException); } } - - } template -std::vector load_batch(std::string path, int size) -{ +std::vector load_batch(std::string path, int size) { std::cout << "Loading " << path << "..."; // float or int32 (python) assert(sizeof(datatype) == 4); std::ifstream file; - file.open(path); - if (!file.is_open()) - { + file.open(path, std::ios::binary); + if (!file.is_open()) { std::cout << "Cannot open " << path << "\n"; exit(1); } @@ -107,29 +102,21 @@ std::vector load_batch(std::string path, int size) return batch; } + template static float test_approx(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, - std::vector> &answers, size_t K) -{ + std::vector> &answers, size_t K) { size_t correct = 0; size_t total = 0; - //uncomment to test in parallel mode: - - - for (int i = 0; i < qsize; i++) - { + for (int i = 0; i < qsize; i++) { std::priority_queue> result = appr_alg.searchKnn((char *)(queries.data() + vecdim * i), K); total += K; - while (result.size()) - { - if (answers[i].find(result.top().second) != answers[i].end()) - { + while (result.size()) { + if (answers[i].find(result.top().second) != answers[i].end()) { correct++; - } - else - { + } else { } result.pop(); } @@ -137,30 +124,34 @@ test_approx(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW< return 1.0f * correct / total; } + static void -test_vs_recall(std::vector &queries, size_t qsize, hnswlib::HierarchicalNSW &appr_alg, size_t vecdim, - std::vector> &answers, size_t k) -{ +test_vs_recall( + std::vector &queries, + size_t qsize, + hnswlib::HierarchicalNSW &appr_alg, + size_t vecdim, + std::vector> &answers, + size_t k) { + std::vector efs = {1}; - for (int i = k; i < 30; i++) - { + for (int i = k; i < 30; i++) { efs.push_back(i); } - for (int i = 30; i < 400; i+=10) - { + for (int i = 30; i < 400; i+=10) { efs.push_back(i); } - for (int i = 1000; i < 100000; i += 5000) - { + for (int i = 1000; i < 100000; i += 5000) { efs.push_back(i); } std::cout << "ef\trecall\ttime\thops\tdistcomp\n"; - for (size_t ef : efs) - { + + bool test_passed = false; + for (size_t ef : efs) { appr_alg.setEf(ef); - appr_alg.metric_hops=0; - appr_alg.metric_distance_computations=0; + appr_alg.metric_hops = 0; + appr_alg.metric_distance_computations = 0; StopW stopw = StopW(); float recall = test_approx(queries, qsize, appr_alg, vecdim, answers, k); @@ -168,45 +159,41 @@ test_vs_recall(std::vector &queries, size_t qsize, hnswlib::HierarchicalN float distance_comp_per_query = appr_alg.metric_distance_computations / (1.0f * qsize); float hops_per_query = appr_alg.metric_hops / (1.0f * qsize); - std::cout << ef << "\t" << recall << "\t" << time_us_per_query << "us \t"< 0.99) - { - std::cout << "Recall is over 0.99! "< 0.99) { + test_passed = true; + std::cout << "Recall is over 0.99! " << recall << "\t" << time_us_per_query << "us \t" << hops_per_query << "\t" << distance_comp_per_query << "\n"; break; } } + if (!test_passed) { + std::cerr << "Test failed\n"; + exit(1); + } } -int main(int argc, char **argv) -{ +int main(int argc, char **argv) { int M = 16; int efConstruction = 200; int num_threads = std::thread::hardware_concurrency(); - - bool update = false; - if (argc == 2) - { - if (std::string(argv[1]) == "update") - { + if (argc == 2) { + if (std::string(argv[1]) == "update") { update = true; std::cout << "Updates are on\n"; - } - else { - std::cout<<"Usage ./test_updates [update]\n"; + } else { + std::cout << "Usage ./test_updates [update]\n"; exit(1); } - } - else if (argc>2){ - std::cout<<"Usage ./test_updates [update]\n"; + } else if (argc > 2) { + std::cout << "Usage ./test_updates [update]\n"; exit(1); } - std::string path = "../examples/data/"; - + std::string path = "../tests/cpp/data/"; int N; int dummy_data_multiplier; @@ -216,8 +203,7 @@ int main(int argc, char **argv) { std::ifstream configfile; configfile.open(path + "/config.txt"); - if (!configfile.is_open()) - { + if (!configfile.is_open()) { std::cout << "Cannot open config.txt\n"; return 1; } @@ -237,11 +223,9 @@ int main(int argc, char **argv) StopW stopw = StopW(); - if (update) - { + if (update) { std::cout << "Update iteration 0\n"; - ParallelFor(1, N, num_threads, [&](size_t i, size_t threadId) { appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); }); @@ -252,14 +236,13 @@ int main(int argc, char **argv) }); appr_alg.checkIntegrity(); - for (int b = 1; b < dummy_data_multiplier; b++) - { + for (int b = 1; b < dummy_data_multiplier; b++) { std::cout << "Update iteration " << b << "\n"; char cpath[1024]; sprintf(cpath, "batch_dummy_%02d.bin", b); std::vector dummy_batchb = load_batch(path + cpath, N * d); - - ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) { + + ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) { appr_alg.addPoint((void *)(dummy_batch.data() + i * d), i); }); appr_alg.checkIntegrity(); @@ -268,31 +251,28 @@ int main(int argc, char **argv) std::cout << "Inserting final elements\n"; std::vector final_batch = load_batch(path + "batch_final.bin", N * d); - + stopw.reset(); ParallelFor(0, N, num_threads, [&](size_t i, size_t threadId) { appr_alg.addPoint((void *)(final_batch.data() + i * d), i); }); - std::cout<<"Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n"; + std::cout << "Finished. Time taken:" << stopw.getElapsedTimeMicro()*1e-6 << " s\n"; std::cout << "Running tests\n"; std::vector queries_batch = load_batch(path + "queries.bin", N_queries * d); std::vector gt = load_batch(path + "gt.bin", N_queries * K); std::vector> answers(N_queries); - for (int i = 0; i < N_queries; i++) - { - for (int j = 0; j < K; j++) - { + for (int i = 0; i < N_queries; i++) { + for (int j = 0; j < K; j++) { answers[i].insert(gt[i * K + j]); } } - for (int i = 0; i < 3; i++) - { + for (int i = 0; i < 3; i++) { std::cout << "Test iteration " << i << "\n"; test_vs_recall(queries_batch, N_queries, appr_alg, d, answers, K); } return 0; -}; \ No newline at end of file +} diff --git a/python_bindings/tests/bindings_test.py b/tests/python/bindings_test.py similarity index 100% rename from python_bindings/tests/bindings_test.py rename to tests/python/bindings_test.py diff --git a/tests/python/bindings_test_filter.py b/tests/python/bindings_test_filter.py new file mode 100644 index 00000000..480c8dcd --- /dev/null +++ b/tests/python/bindings_test_filter.py @@ -0,0 +1,57 @@ +import os +import unittest + +import numpy as np + +import hnswlib + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + + dim = 16 + num_elements = 10000 + + # Generating sample data + data = np.float32(np.random.random((num_elements, dim))) + + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + + # Initiating index + # max_elements - the maximum number of elements, should be known beforehand + # (probably will be made optional in the future) + # + # ef_construction - controls index search speed/build speed tradeoff + # M - is tightly connected with internal dimensionality of the data + # strongly affects the memory consumption + + hnsw_index.init_index(max_elements=num_elements, ef_construction=100, M=16) + bf_index.init_index(max_elements=num_elements) + + # Controlling the recall by setting ef: + # higher ef leads to better accuracy, but slower search + hnsw_index.set_ef(10) + + hnsw_index.set_num_threads(4) # by default using all available cores + + print("Adding %d elements" % (len(data))) + hnsw_index.add_items(data) + bf_index.add_items(data) + + # Query the elements for themselves and measure recall: + labels, distances = hnsw_index.knn_query(data, k=1) + self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), 1.0, 3) + + print("Querying only even elements") + # Query the even elements for themselves and measure recall: + filter_function = lambda id: id%2 == 0 + # Warning: search with a filter works slow in python in multithreaded mode, therefore we set num_threads=1 + labels, distances = hnsw_index.knn_query(data, k=1, num_threads=1, filter=filter_function) + self.assertAlmostEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5, 3) + # Verify that there are only even elements: + self.assertTrue(np.max(np.mod(labels, 2)) == 0) + + labels, distances = bf_index.knn_query(data, k=1, filter=filter_function) + self.assertEqual(np.mean(labels.reshape(-1) == np.arange(len(data))), .5) diff --git a/python_bindings/tests/bindings_test_getdata.py b/tests/python/bindings_test_getdata.py similarity index 100% rename from python_bindings/tests/bindings_test_getdata.py rename to tests/python/bindings_test_getdata.py diff --git a/python_bindings/tests/bindings_test_labels.py b/tests/python/bindings_test_labels.py similarity index 87% rename from python_bindings/tests/bindings_test_labels.py rename to tests/python/bindings_test_labels.py index 2b091371..524a24d5 100644 --- a/python_bindings/tests/bindings_test_labels.py +++ b/tests/python/bindings_test_labels.py @@ -95,19 +95,20 @@ def testRandomSelf(self): # Delete data1 labels1_deleted, _ = p.knn_query(data1, k=1) - - for l in labels1_deleted: - p.mark_deleted(l[0]) + # delete probable duplicates from nearest neighbors + labels1_deleted_no_dup = set(labels1_deleted.flatten()) + for l in labels1_deleted_no_dup: + p.mark_deleted(l) labels2, _ = p.knn_query(data2, k=1) items = p.get_items(labels2) diff_with_gt_labels = np.mean(np.abs(data2-items)) - self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) # console + self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) labels1_after, _ = p.knn_query(data1, k=1) for la in labels1_after: - for lb in labels1_deleted: - if la[0] == lb[0]: - self.assertTrue(False) + if la[0] in labels1_deleted_no_dup: + print(f"Found deleted label {la[0]} during knn search") + self.assertTrue(False) print("All the data in data1 are removed") # Checking saving/loading index with elements marked as deleted @@ -119,13 +120,13 @@ def testRandomSelf(self): labels1_after, _ = p.knn_query(data1, k=1) for la in labels1_after: - for lb in labels1_deleted: - if la[0] == lb[0]: - self.assertTrue(False) + if la[0] in labels1_deleted_no_dup: + print(f"Found deleted label {la[0]} during knn search after index loading") + self.assertTrue(False) # Unmark deleted data - for l in labels1_deleted: - p.unmark_deleted(l[0]) + for l in labels1_deleted_no_dup: + p.unmark_deleted(l) labels_restored, _ = p.knn_query(data1, k=1) self.assertAlmostEqual(np.mean(labels_restored.reshape(-1) == np.arange(len(data1))), 1.0, 3) print("All the data in data1 are restored") diff --git a/python_bindings/tests/bindings_test_metadata.py b/tests/python/bindings_test_metadata.py similarity index 100% rename from python_bindings/tests/bindings_test_metadata.py rename to tests/python/bindings_test_metadata.py diff --git a/python_bindings/tests/bindings_test_pickle.py b/tests/python/bindings_test_pickle.py similarity index 100% rename from python_bindings/tests/bindings_test_pickle.py rename to tests/python/bindings_test_pickle.py diff --git a/tests/python/bindings_test_recall.py b/tests/python/bindings_test_recall.py new file mode 100644 index 00000000..2190ba45 --- /dev/null +++ b/tests/python/bindings_test_recall.py @@ -0,0 +1,100 @@ +import os +import hnswlib +import numpy as np +import unittest + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + dim = 32 + num_elements = 100000 + k = 10 + num_queries = 20 + + recall_threshold = 0.95 + + # Generating sample data + data = np.float32(np.random.random((num_elements, dim))) + + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + + # Initing both hnsw and brute force indices + # max_elements - the maximum number of elements (capacity). Will throw an exception if exceeded + # during insertion of an element. + # The capacity can be increased by saving/loading the index, see below. + # + # hnsw construction params: + # ef_construction - controls index search speed/build speed tradeoff + # + # M - is tightly connected with internal dimensionality of the data. Strongly affects the memory consumption (~M) + # Higher M leads to higher accuracy/run_time at fixed ef/efConstruction + + hnsw_index.init_index(max_elements=num_elements, ef_construction=200, M=16) + bf_index.init_index(max_elements=num_elements) + + # Controlling the recall for hnsw by setting ef: + # higher ef leads to better accuracy, but slower search + hnsw_index.set_ef(200) + + # Set number of threads used during batch search/construction in hnsw + # By default using all available cores + hnsw_index.set_num_threads(4) + + print("Adding batch of %d elements" % (len(data))) + hnsw_index.add_items(data) + bf_index.add_items(data) + + print("Indices built") + + # Generating query data + query_data = np.float32(np.random.random((num_queries, dim))) + + # Query the elements and measure recall: + labels_hnsw, distances_hnsw = hnsw_index.knn_query(query_data, k) + labels_bf, distances_bf = bf_index.knn_query(query_data, k) + + # Measure recall + correct = 0 + for i in range(num_queries): + for label in labels_hnsw[i]: + for correct_label in labels_bf[i]: + if label == correct_label: + correct += 1 + break + + recall_before = float(correct) / (k*num_queries) + print("recall is :", recall_before) + self.assertGreater(recall_before, recall_threshold) + + # test serializing the brute force index + index_path = 'bf_index.bin' + print("Saving index to '%s'" % index_path) + bf_index.save_index(index_path) + del bf_index + + # Re-initiating, loading the index + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + + print("\nLoading index from '%s'\n" % index_path) + bf_index.load_index(index_path) + + # Query the brute force index again to verify that we get the same results + labels_bf, distances_bf = bf_index.knn_query(query_data, k) + + # Measure recall + correct = 0 + for i in range(num_queries): + for label in labels_hnsw[i]: + for correct_label in labels_bf[i]: + if label == correct_label: + correct += 1 + break + + recall_after = float(correct) / (k*num_queries) + print("recall after reloading is :", recall_after) + + self.assertEqual(recall_before, recall_after) + + os.remove(index_path) diff --git a/tests/python/bindings_test_replace.py b/tests/python/bindings_test_replace.py new file mode 100644 index 00000000..80003a3a --- /dev/null +++ b/tests/python/bindings_test_replace.py @@ -0,0 +1,245 @@ +import os +import pickle +import unittest + +import numpy as np + +import hnswlib + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + """ + Tests if replace of deleted elements works correctly + Tests serialization of the index with replaced elements + """ + dim = 16 + num_elements = 5000 + max_num_elements = 2 * num_elements + + recall_threshold = 0.98 + + # Generating sample data + print("Generating data") + # batch 1 + first_id = 0 + last_id = num_elements + labels1 = np.arange(first_id, last_id) + data1 = np.float32(np.random.random((num_elements, dim))) + # batch 2 + first_id += num_elements + last_id += num_elements + labels2 = np.arange(first_id, last_id) + data2 = np.float32(np.random.random((num_elements, dim))) + # batch 3 + first_id += num_elements + last_id += num_elements + labels3 = np.arange(first_id, last_id) + data3 = np.float32(np.random.random((num_elements, dim))) + # batch 4 + first_id += num_elements + last_id += num_elements + labels4 = np.arange(first_id, last_id) + data4 = np.float32(np.random.random((num_elements, dim))) + + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) + hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + + hnsw_index.set_ef(100) + hnsw_index.set_num_threads(4) + + # Add batch 1 and 2 + print("Adding batch 1") + hnsw_index.add_items(data1, labels1) + print("Adding batch 2") + hnsw_index.add_items(data2, labels2) # maximum number of elements is reached + + # Delete nearest neighbors of batch 2 + print("Deleting neighbors of batch 2") + labels2_deleted, _ = hnsw_index.knn_query(data2, k=1) + # delete probable duplicates from nearest neighbors + labels2_deleted_no_dup = set(labels2_deleted.flatten()) + num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup) + for l in labels2_deleted_no_dup: + hnsw_index.mark_deleted(l) + labels1_found, _ = hnsw_index.knn_query(data1, k=1) + items = hnsw_index.get_items(labels1_found) + diff_with_gt_labels = np.mean(np.abs(data1 - items)) + self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) + + labels2_after, _ = hnsw_index.knn_query(data2, k=1) + for la in labels2_after: + if la[0] in labels2_deleted_no_dup: + print(f"Found deleted label {la[0]} during knn search") + self.assertTrue(False) + print("All the neighbors of data2 are removed") + + # Replace deleted elements + print("Inserting batch 3 by replacing deleted elements") + # Maximum number of elements is reached therefore we cannot add new items + # but we can replace the deleted ones + # Note: there may be less than num_elements elements. + # As we could delete less than num_elements because of duplicates + labels3_tr = labels3[0:labels3.shape[0] - num_duplicates] + data3_tr = data3[0:data3.shape[0] - num_duplicates] + hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True) + + # After replacing, all labels should be retrievable + print("Checking that remaining labels are in index") + # Get remaining data from batch 1 and batch 2 after deletion of elements + remaining_labels = (set(labels1) | set(labels2)) - labels2_deleted_no_dup + remaining_labels_list = list(remaining_labels) + comb_data = np.concatenate((data1, data2), axis=0) + remaining_data = comb_data[remaining_labels_list] + + returned_items = hnsw_index.get_items(remaining_labels_list) + self.assertSequenceEqual(remaining_data.tolist(), returned_items) + + returned_items = hnsw_index.get_items(labels3_tr) + self.assertSequenceEqual(data3_tr.tolist(), returned_items) + + # Check index serialization + # Delete batch 3 + print("Deleting batch 3") + for l in labels3_tr: + hnsw_index.mark_deleted(l) + + # Save index + index_path = "index.bin" + print(f"Saving index to {index_path}") + hnsw_index.save_index(index_path) + del hnsw_index + + # Reinit and load the index + hnsw_index = hnswlib.Index(space='l2', dim=dim) # the space can be changed - keeps the data, alters the distance function. + hnsw_index.set_num_threads(4) + print(f"Loading index from {index_path}") + hnsw_index.load_index(index_path, max_elements=max_num_elements, allow_replace_deleted=True) + + # Insert batch 4 + print("Inserting batch 4 by replacing deleted elements") + labels4_tr = labels4[0:labels4.shape[0] - num_duplicates] + data4_tr = data4[0:data4.shape[0] - num_duplicates] + hnsw_index.add_items(data4_tr, labels4_tr, replace_deleted=True) + + # Check recall + print("Checking recall") + labels_found, _ = hnsw_index.knn_query(data4_tr, k=1) + recall = np.mean(labels_found.reshape(-1) == labels4_tr) + print(f"Recall for the 4 batch: {recall}") + self.assertGreater(recall, recall_threshold) + + # Delete batch 4 + print("Deleting batch 4") + for l in labels4_tr: + hnsw_index.mark_deleted(l) + + print("Testing pickle serialization") + hnsw_index_pckl = pickle.loads(pickle.dumps(hnsw_index)) + del hnsw_index + # Insert batch 3 + print("Inserting batch 3 by replacing deleted elements") + hnsw_index_pckl.add_items(data3_tr, labels3_tr, replace_deleted=True) + + # Check recall + print("Checking recall") + labels_found, _ = hnsw_index_pckl.knn_query(data3_tr, k=1) + recall = np.mean(labels_found.reshape(-1) == labels3_tr) + print(f"Recall for the 3 batch: {recall}") + self.assertGreater(recall, recall_threshold) + + os.remove(index_path) + + + def test_recall_degradation(self): + """ + Compares recall of the index with replaced elements and without + Measures recall degradation + """ + dim = 16 + num_elements = 10_000 + max_num_elements = 2 * num_elements + query_size = 1_000 + k = 100 + + recall_threshold = 0.98 + max_recall_diff = 0.02 + + # Generating sample data + print("Generating data") + # batch 1 + first_id = 0 + last_id = num_elements + labels1 = np.arange(first_id, last_id) + data1 = np.float32(np.random.random((num_elements, dim))) + # batch 2 + first_id += num_elements + last_id += num_elements + labels2 = np.arange(first_id, last_id) + data2 = np.float32(np.random.random((num_elements, dim))) + # batch 3 + first_id += num_elements + last_id += num_elements + labels3 = np.arange(first_id, last_id) + data3 = np.float32(np.random.random((num_elements, dim))) + # query to test recall + query_data = np.float32(np.random.random((query_size, dim))) + + # Declaring index + hnsw_index_no_replace = hnswlib.Index(space='l2', dim=dim) + hnsw_index_no_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=False) + hnsw_index_with_replace = hnswlib.Index(space='l2', dim=dim) + hnsw_index_with_replace.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + + bf_index = hnswlib.BFIndex(space='l2', dim=dim) + bf_index.init_index(max_elements=max_num_elements) + + hnsw_index_no_replace.set_ef(100) + hnsw_index_no_replace.set_num_threads(50) + hnsw_index_with_replace.set_ef(100) + hnsw_index_with_replace.set_num_threads(50) + + # Add data + print("Adding data") + hnsw_index_with_replace.add_items(data1, labels1) + hnsw_index_with_replace.add_items(data2, labels2) # maximum number of elements is reached + bf_index.add_items(data1, labels1) + bf_index.add_items(data3, labels3) # maximum number of elements is reached + + for l in labels2: + hnsw_index_with_replace.mark_deleted(l) + hnsw_index_with_replace.add_items(data3, labels3, replace_deleted=True) + + hnsw_index_no_replace.add_items(data1, labels1) + hnsw_index_no_replace.add_items(data3, labels3) # maximum number of elements is reached + + # Query the elements and measure recall: + labels_hnsw_with_replace, _ = hnsw_index_with_replace.knn_query(query_data, k) + labels_hnsw_no_replace, _ = hnsw_index_no_replace.knn_query(query_data, k) + labels_bf, distances_bf = bf_index.knn_query(query_data, k) + + # Measure recall + correct_with_replace = 0 + correct_no_replace = 0 + for i in range(query_size): + for label in labels_hnsw_with_replace[i]: + for correct_label in labels_bf[i]: + if label == correct_label: + correct_with_replace += 1 + break + for label in labels_hnsw_no_replace[i]: + for correct_label in labels_bf[i]: + if label == correct_label: + correct_no_replace += 1 + break + + recall_with_replace = float(correct_with_replace) / (k*query_size) + recall_no_replace = float(correct_no_replace) / (k*query_size) + print("recall with replace:", recall_with_replace) + print("recall without replace:", recall_no_replace) + + recall_diff = abs(recall_with_replace - recall_with_replace) + + self.assertGreater(recall_no_replace, recall_threshold) + self.assertLess(recall_diff, max_recall_diff) diff --git a/python_bindings/tests/bindings_test_resize.py b/tests/python/bindings_test_resize.py similarity index 100% rename from python_bindings/tests/bindings_test_resize.py rename to tests/python/bindings_test_resize.py diff --git a/python_bindings/tests/bindings_test_spaces.py b/tests/python/bindings_test_spaces.py similarity index 100% rename from python_bindings/tests/bindings_test_spaces.py rename to tests/python/bindings_test_spaces.py diff --git a/tests/python/bindings_test_stress_mt_replace.py b/tests/python/bindings_test_stress_mt_replace.py new file mode 100644 index 00000000..8cd3e9bc --- /dev/null +++ b/tests/python/bindings_test_stress_mt_replace.py @@ -0,0 +1,68 @@ +import unittest + +import numpy as np + +import hnswlib + + +class RandomSelfTestCase(unittest.TestCase): + def testRandomSelf(self): + dim = 16 + num_elements = 1_000 + max_num_elements = 2 * num_elements + + # Generating sample data + # batch 1 + first_id = 0 + last_id = num_elements + labels1 = np.arange(first_id, last_id) + data1 = np.float32(np.random.random((num_elements, dim))) + # batch 2 + first_id += num_elements + last_id += num_elements + labels2 = np.arange(first_id, last_id) + data2 = np.float32(np.random.random((num_elements, dim))) + # batch 3 + first_id += num_elements + last_id += num_elements + labels3 = np.arange(first_id, last_id) + data3 = np.float32(np.random.random((num_elements, dim))) + + for _ in range(100): + # Declaring index + hnsw_index = hnswlib.Index(space='l2', dim=dim) + hnsw_index.init_index(max_elements=max_num_elements, ef_construction=200, M=16, allow_replace_deleted=True) + + hnsw_index.set_ef(100) + hnsw_index.set_num_threads(50) + + # Add batch 1 and 2 + hnsw_index.add_items(data1, labels1) + hnsw_index.add_items(data2, labels2) # maximum number of elements is reached + + # Delete nearest neighbors of batch 2 + labels2_deleted, _ = hnsw_index.knn_query(data2, k=1) + labels2_deleted_flat = labels2_deleted.flatten() + # delete probable duplicates from nearest neighbors + labels2_deleted_no_dup = set(labels2_deleted_flat) + for l in labels2_deleted_no_dup: + hnsw_index.mark_deleted(l) + labels1_found, _ = hnsw_index.knn_query(data1, k=1) + items = hnsw_index.get_items(labels1_found) + diff_with_gt_labels = np.mean(np.abs(data1 - items)) + self.assertAlmostEqual(diff_with_gt_labels, 0, delta=1e-3) + + labels2_after, _ = hnsw_index.knn_query(data2, k=1) + labels2_after_flat = labels2_after.flatten() + common = np.intersect1d(labels2_after_flat, labels2_deleted_flat) + self.assertTrue(common.size == 0) + + # Replace deleted elements + # Maximum number of elements is reached therefore we cannot add new items + # but we can replace the deleted ones + # Note: there may be less than num_elements elements. + # As we could delete less than num_elements because of duplicates + num_duplicates = len(labels2_deleted) - len(labels2_deleted_no_dup) + labels3_tr = labels3[0:labels3.shape[0] - num_duplicates] + data3_tr = data3[0:data3.shape[0] - num_duplicates] + hnsw_index.add_items(data3_tr, labels3_tr, replace_deleted=True) diff --git a/tests/python/git_tester.py b/tests/python/git_tester.py new file mode 100644 index 00000000..1f9c2ba7 --- /dev/null +++ b/tests/python/git_tester.py @@ -0,0 +1,52 @@ +import os +import shutil + +from sys import platform +from pydriller import Repository + + +speedtest_src_path = os.path.join("tests", "python", "speedtest.py") +speedtest_copy_path = os.path.join("tests", "python", "speedtest2.py") +shutil.copyfile(speedtest_src_path, speedtest_copy_path) # the file has to be outside of git + +commits = list(Repository('.', from_tag="v0.6.2").traverse_commits()) +print("Found commits:") +for idx, commit in enumerate(commits): + name = commit.msg.replace('\n', ' ').replace('\r', ' ') + print(idx, commit.hash, name) + +for commit in commits: + name = commit.msg.replace('\n', ' ').replace('\r', ' ').replace(",", ";") + print("\nProcessing", commit.hash, name) + + if os.path.exists("build"): + shutil.rmtree("build") + os.system(f"git checkout {commit.hash}") + + # Checking we have actually switched the branch: + current_commit=list(Repository('.').traverse_commits())[-1] + if current_commit.hash != commit.hash: + print("git checkout failed!!!!") + print("git checkout failed!!!!") + print("git checkout failed!!!!") + print("git checkout failed!!!!") + continue + + print("\n\n--------------------\n\n") + ret = os.system("python -m pip install .") + print("Install result:", ret) + + if ret != 0: + print("build failed!!!!") + print("build failed!!!!") + print("build failed!!!!") + print("build failed!!!!") + continue + + # os.system(f'python {speedtest_copy_path} -n "{hash[:4]}_{name}" -d 32 -t 1') + os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 16 -t 1') + os.system(f'python {speedtest_copy_path} -n "{commit.hash[:4]}_{name}" -d 16 -t 64') + # os.system(f'python {speedtest_copy_path} -n "{name}" -d 64 -t 1') + # os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 1') + # os.system(f'python {speedtest_copy_path} -n "{name}" -d 4 -t 24') + # os.system(f'python {speedtest_copy_path} -n "{name}" -d 128 -t 24') diff --git a/tests/python/speedtest.py b/tests/python/speedtest.py new file mode 100644 index 00000000..8d16cfc3 --- /dev/null +++ b/tests/python/speedtest.py @@ -0,0 +1,65 @@ +import hnswlib +import numpy as np +import os.path +import time +import argparse + +# Use nargs to specify how many arguments an option should take. +ap = argparse.ArgumentParser() +ap.add_argument('-d') +ap.add_argument('-n') +ap.add_argument('-t') +args = ap.parse_args() +dim = int(args.d) +name = args.n +threads=int(args.t) +num_elements = 400000 + +# Generating sample data +np.random.seed(1) +data = np.float32(np.random.random((num_elements, dim))) + + +# index_path=f'speed_index{dim}.bin' +# Declaring index +p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip + +# if not os.path.isfile(index_path) : + +p.init_index(max_elements=num_elements, ef_construction=60, M=16) + +# Controlling the recall by setting ef: +# higher ef leads to better accuracy, but slower search +p.set_ef(10) + +# Set number of threads used during batch search/construction +# By default using all available cores +p.set_num_threads(64) +t0=time.time() +p.add_items(data) +construction_time=time.time()-t0 +# Serializing and deleting the index: + +# print("Saving index to '%s'" % index_path) +# p.save_index(index_path) +p.set_num_threads(threads) +times=[] +time.sleep(1) +p.set_ef(15) +for _ in range(1): + # p.load_index(index_path) + for _ in range(3): + t0=time.time() + qdata=data[:5000*threads] + labels, distances = p.knn_query(qdata, k=1) + tt=time.time()-t0 + times.append(tt) + recall=np.sum(labels.reshape(-1)==np.arange(len(qdata)))/len(qdata) + print(f"{tt} seconds, recall= {recall}") + +str_out=f"{np.mean(times)}, {np.median(times)}, {np.std(times)}, {construction_time}, {recall}, {name}" +print(str_out) +with open (f"log2_{dim}_t{threads}.txt","a") as f: + f.write(str_out+"\n") + f.flush() +