Skip to content

Commit

Permalink
Add support of multi vector in jni
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Dec 29, 2023
1 parent 7b47bae commit d6b2116
Show file tree
Hide file tree
Showing 19 changed files with 933 additions and 27 deletions.
23 changes: 20 additions & 3 deletions jni/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,21 @@ if (${CONFIG_FAISS} STREQUAL ON OR ${CONFIG_ALL} STREQUAL ON OR ${CONFIG_TEST} S
set(FAISS_ENABLE_PYTHON OFF)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/external/faiss EXCLUDE_FROM_ALL)

add_library(${TARGET_LIB_FAISS} SHARED ${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp ${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp)
add_library(
${TARGET_LIB_FAISS} SHARED
${CMAKE_CURRENT_SOURCE_DIR}/src/org_opensearch_knn_jni_FaissService.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/faiss_wrapper.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/utils/BitSet.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollector.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/knn_extension/faiss/MultiVectorResultCollectorFactory.cpp)
target_link_libraries(${TARGET_LIB_FAISS} faiss ${TARGET_LIB_COMMON} OpenMP::OpenMP_CXX)
target_include_directories(${TARGET_LIB_FAISS} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include $ENV{JAVA_HOME}/include $ENV{JAVA_HOME}/include/${JVM_OS_TYPE} ${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
target_include_directories(${TARGET_LIB_FAISS} PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/include
${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss
${CMAKE_CURRENT_SOURCE_DIR}/include/knn_extension/faiss/utils
$ENV{JAVA_HOME}/include
$ENV{JAVA_HOME}/include/${JVM_OS_TYPE}
${CMAKE_CURRENT_SOURCE_DIR}/external/faiss)
set_target_properties(${TARGET_LIB_FAISS} PROPERTIES SUFFIX ${LIB_EXT})
set_target_properties(${TARGET_LIB_FAISS} PROPERTIES POSITION_INDEPENDENT_CODE ON)

Expand Down Expand Up @@ -198,7 +210,12 @@ if ("${WIN32}" STREQUAL "")
jni_test
tests/faiss_wrapper_test.cpp
tests/nmslib_wrapper_test.cpp
tests/test_util.cpp)
tests/test_util.cpp
tests/knn_extension/faiss/utils/BitSetTest.cpp
tests/knn_extension/faiss/utils/HeapTest.cpp
tests/knn_extension/faiss/MultiVectorResultCollectorTest.cpp
tests/knn_extension/faiss/MultiVectorResultCollectorFactoryTest.cpp
)

target_link_libraries(
jni_test
Expand Down
5 changes: 2 additions & 3 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#define OPENSEARCH_KNN_FAISS_WRAPPER_H

#include "jni_util.h"

#include <jni.h>

namespace knn_jni {
Expand All @@ -38,13 +37,13 @@ namespace knn_jni {
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ);
jfloatArray queryVectorJ, jint kJ, jintArray parentIdsJ);

// Execute a query against the index located in memory at indexPointerJ along with Filters
//
// Return an array of KNNQueryResults
jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ);
jfloatArray queryVectorJ, jint kJ, jintArray filterIdsJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
Expand Down
41 changes: 41 additions & 0 deletions jni/include/knn_extension/faiss/MultiVectorResultCollector.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <faiss/impl/ResultCollector.h>
#include <faiss/MetricType.h>
#include "knn_extension/faiss/utils/BitSet.h"
#include <unordered_map>

namespace os_faiss {

using idx_t = faiss::idx_t;
/**
* Implementation of ResultCollector to support multi vector
*
* By using parent_bit_set, it convert a doc id to its parent doc id and store the parend doc id
* while collecting search result. Using group_id_to_index, it de-duplicates result from same parent
* doc. Once all results are collected, post_process method is called where it converts parent doc id
* to its original id using group_id_to_id.
*/
struct MultiVectorResultCollector:faiss::ResultCollector {
std::unordered_map<idx_t, idx_t> group_id_to_id;
std::unordered_map<idx_t, size_t> group_id_to_index;
BitSet* parent_bit_set;
// mapping data from Faiss ID to Lucene ID
const std::vector<int64_t>* id_map;
MultiVectorResultCollector(BitSet* parent_bit_set, const std::vector<int64_t>* id_map);
void collect(
int k,
int& nres,
float* bh_val,
int64_t* bh_ids,
float val,
int64_t ids) override;
void post_process(int64_t nres, int64_t* bh_ids) override;
};

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <faiss/impl/ResultCollectorFactory.h>
#include "knn_extension/faiss/utils/BitSet.h"

namespace os_faiss {
/**
* Create MultiVectorResultCollector for single query request
*
* Creating new collector is required because MultiVectorResultCollector has instance variables
* which should be isolated for each query.
*/
struct MultiVectorResultCollectorFactory:faiss::ResultCollectorFactory {
BitSet* parent_bit_set;

MultiVectorResultCollectorFactory(BitSet* parent_bit_set);
faiss::ResultCollector* new_collector() override;
void delete_collector(faiss::ResultCollector* resultCollector) override;
};
}
47 changes: 47 additions & 0 deletions jni/include/knn_extension/faiss/utils/BitSet.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <faiss/MetricType.h>
#include <faiss/impl/platform_macros.h>
#include <limits>

using idx_t = faiss::idx_t;
/**
* This class is used to store parent and child doc id mapping
*
* For example, let's say there are two documents with 3 nested field each. Then, lucene store each nested field as
* individual document with its own doc id. The document ids are assigned as following.
*
* 0, 1, 2, 3(parent doc for 0, 1, 2), 4, 5, 6, 7(parent doc for 4, 5, 6)
*
* Therefore, we can represent the value in BitSet like 10001000 where parent doc id position is set as 1
* and child doc id position is set as 0. Finally, by using nextSetBit method, we can find parent ID of a
* given document ID.
*/
class BitSet {
protected:
const int NO_MORE_DOCS = std::numeric_limits<int>::max();
public:
virtual idx_t nextSetBit(idx_t index) = 0;
virtual ~BitSet() = default;
};


/**
* BitSet implementation by using an array of unit64
*/
class FixedBitSet : public BitSet {
public:
size_t n;
// using uint64_t to leverage function __builtin_ctzll which is defined in faiss/impl/platform_macros.h
uint64_t* bitmap;

public:
FixedBitSet(const int* intArray, const int length);
idx_t nextSetBit(idx_t index) override;
~FixedBitSet();
};
Loading

0 comments on commit d6b2116

Please sign in to comment.