Support Async cache making for DiskANN
Signed-off-by: Li Liu <[email protected]>
liliu-z committed Dec 5, 2023
1 parent 645c810 commit 4341cf9
Showing 3 changed files with 271 additions and 131 deletions.
12 changes: 10 additions & 2 deletions src/index/diskann/diskann.cc
@@ -442,11 +442,18 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
                 LOG_KNOWHERE_ERROR_ << "Failed to generate bfs cache for DiskANN.";
                 return Status::diskann_inner_error;
             }
+
+            if (node_list.size() > 0) {
+                if (TryDiskANNCall([&]() { pq_flash_index_->load_cache_list(node_list); }) != Status::success) {
+                    LOG_KNOWHERE_ERROR_ << "Failed to load cache for DiskANN.";
+                    return Status::diskann_inner_error;
+                }
+            }
         } else {
             LOG_KNOWHERE_INFO_ << "Use sample_queries to generate cache list";
             if (TryDiskANNCall([&]() {
-                    pq_flash_index_->generate_cache_list_from_sample_queries(warmup_query_file, 15, 6,
-                                                                             num_nodes_to_cache, node_list);
+                    pq_flash_index_->async_generate_cache_list_from_sample_queries(warmup_query_file, 15, 6,
+                                                                                   num_nodes_to_cache);
                 }) != Status::success) {
                 LOG_KNOWHERE_ERROR_ << "Failed to generate cache from sample queries for DiskANN.";
                 return Status::diskann_inner_error;
@@ -507,6 +514,7 @@ DiskANNIndexNode<T>::Deserialize(const BinarySet& binset, const Config& cfg) {
     }

     is_prepared_.store(true);
+    LOG_KNOWHERE_INFO_ << "End of diskann loading.";
     return Status::success;
 }
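
Both branches in the hunk above funnel the DiskANN call through TryDiskANNCall and map any failure to Status::diskann_inner_error. The helper itself is not part of this commit's diff; the snippet below is only a hedged sketch of a wrapper with that shape, assuming its job is to turn exceptions from the wrapped lambda into a status code (the name TryDiskANNCallSketch and the logging are illustrative):

// Hedged sketch, not code from this commit: run a DiskANN lambda and translate
// any exception into a status the caller can branch on.
#include <exception>
#include <iostream>

enum class Status { success, diskann_inner_error };

template <typename Call>
Status TryDiskANNCallSketch(Call&& diskann_call) {
    try {
        diskann_call();  // e.g. [&]() { pq_flash_index_->load_cache_list(node_list); }
        return Status::success;
    } catch (const std::exception& e) {
        std::cerr << "DiskANN exception: " << e.what() << std::endl;
        return Status::diskann_inner_error;
    }
}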

56 changes: 42 additions & 14 deletions thirdparty/DiskANN/include/diskann/pq_flash_index.h
@@ -3,7 +3,9 @@

 #pragma once
 #include <cassert>
+#include <condition_variable>
 #include <future>
+#include <mutex>
 #include <optional>
 #include <sstream>
 #include <stack>
@@ -32,6 +34,20 @@
 #define FULL_PRECISION_REORDER_MULTIPLIER 3

 namespace diskann {
+  class ThreadSafeStateController {
+   public:
+    enum class Status {
+      NONE,
+      DOING,
+      STOPPING,
+      DONE,
+      KILLED,
+    };
+
+    Status status;
+    std::condition_variable cond;
+    std::mutex status_mtx;
+  };
   template<typename T>
   struct QueryScratch {
     T *coord_scratch = nullptr;  // MUST BE AT LEAST [sizeof(T) * data_dim]
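
The new ThreadSafeStateController pairs a status enum with a mutex and a condition variable, the usual recipe for letting one thread run a long job while another asks it to stop and waits for the hand-off. The sketch below shows one plausible use of such a controller for the async cache task; it is an assumption about intent, not the commit's implementation (that lives in pq_flash_index.cpp, whose diff is not shown on this page), and run_async_task / stop_and_wait are illustrative names:

// Hedged sketch of the DOING -> STOPPING -> DONE/KILLED handshake suggested by the
// enum above. Illustrative only.
#include <condition_variable>
#include <memory>
#include <mutex>
#include <thread>

struct StateController {
    enum class Status { NONE, DOING, STOPPING, DONE, KILLED };
    Status status = Status::NONE;
    std::condition_variable cond;
    std::mutex status_mtx;
};

void run_async_task(std::shared_ptr<StateController> c) {        // background cache builder
    std::thread([c] {
        {
            std::lock_guard<std::mutex> lk(c->status_mtx);
            c->status = StateController::Status::DOING;
        }
        bool stopped = false;
        for (int step = 0; step < 1000 && !stopped; ++step) {    // e.g. one sample query per step
            std::lock_guard<std::mutex> lk(c->status_mtx);
            stopped = (c->status == StateController::Status::STOPPING);
        }
        {
            std::lock_guard<std::mutex> lk(c->status_mtx);
            c->status = stopped ? StateController::Status::KILLED : StateController::Status::DONE;
        }
        c->cond.notify_all();                                    // wake anyone waiting in stop_and_wait()
    }).detach();
}

void stop_and_wait(std::shared_ptr<StateController> c) {         // teardown path
    std::unique_lock<std::mutex> lk(c->status_mtx);
    if (c->status == StateController::Status::DOING) {
        c->status = StateController::Status::STOPPING;           // ask the task to bail out early
    }
    c->cond.wait(lk, [&] {
        return c->status != StateController::Status::DOING &&
               c->status != StateController::Status::STOPPING;
    });
}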
@@ -75,20 +91,20 @@ namespace diskann {
                                uint32_t num_threads, const char *index_prefix);
 #else
     // load compressed data, and obtains the handle to the disk-resident index
-    DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix);
+    DISKANN_DLLEXPORT int load(uint32_t num_threads, const char *index_prefix);
 #endif

     DISKANN_DLLEXPORT void load_cache_list(std::vector<uint32_t> &node_list);

+    // asynchronously collect the access frequency of each node in the graph
 #ifdef EXEC_ENV_OLS
-    DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(
-        MemoryMappedFiles &files, std::string sample_bin, _u64 l_search,
-        _u64 beamwidth, _u64 num_nodes_to_cache,
-        std::vector<uint32_t> &node_list);
+    DISKANN_DLLEXPORT void async_generate_cache_list_from_sample_queries(
+        MemoryMappedFiles files, std::string sample_bin, _u64 l_search,
+        _u64 beamwidth, _u64 num_nodes_to_cache);
 #else
-    DISKANN_DLLEXPORT void generate_cache_list_from_sample_queries(
+    DISKANN_DLLEXPORT void async_generate_cache_list_from_sample_queries(
         std::string sample_bin, _u64 l_search, _u64 beamwidth,
-        _u64 num_nodes_to_cache, std::vector<uint32_t> &node_list);
+        _u64 num_nodes_to_cache);
 #endif

     DISKANN_DLLEXPORT void cache_bfs_levels(_u64 num_nodes_to_cache,
@@ -100,8 +116,7 @@ namespace diskann {
         const bool use_reorder_data = false, QueryStats *stats = nullptr,
         const knowhere::feder::diskann::FederResultUniq &feder = nullptr,
         knowhere::BitsetView bitset_view = nullptr,
-        const float filter_ratio = -1.0f,
-        const bool for_tuning = false);
+        const float filter_ratio = -1.0f, const bool for_tuning = false);

     DISKANN_DLLEXPORT _u32 range_search(
         const T *query1, const double range, const _u64 min_l_search,
@@ -110,8 +125,9 @@ namespace diskann {
         const float l_k_ratio, knowhere::BitsetView bitset_view = nullptr,
         QueryStats *stats = nullptr);

-    DISKANN_DLLEXPORT void get_vector_by_ids(
-        const int64_t *ids, const int64_t n, T *const output_data);
+    DISKANN_DLLEXPORT void get_vector_by_ids(const int64_t *ids,
+                                             const int64_t n,
+                                             T *const output_data);

     std::shared_ptr<AlignedFileReader> reader;

@@ -129,6 +145,9 @@ namespace diskann {

     DISKANN_DLLEXPORT _u64 cal_size();

+    // for async cache making task
+    DISKANN_DLLEXPORT void destroy_cache_async_task();
+
    protected:
     DISKANN_DLLEXPORT void use_medoids_data_as_centroids();
     DISKANN_DLLEXPORT void setup_thread_data(_u64 nthreads);
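
destroy_cache_async_task gives the index a way to shut the background job down. A plausible call site, assumed here rather than taken from this commit, is the destructor or unload path, so the async builder is stopped before the buffers it writes into are released:

// Illustrative only: PQFlashIndexSketch stands in for the real pq_flash_index class.
class PQFlashIndexSketch {
   public:
    ~PQFlashIndexSketch() {
        destroy_cache_async_task();  // stop and wait for the async cache builder first
        // ... then free nhood_cache_buf, coord_cache_buf, PQ data, and so on ...
    }

   private:
    void destroy_cache_async_task() { /* see the handshake sketch above */ }
};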
@@ -205,8 +224,14 @@ namespace diskann {
     _u64 aligned_dim = 0;
     _u64 disk_bytes_per_point = 0;

-    std::string disk_index_file;
-    std::vector<std::pair<_u32, _u32>> node_visit_counter;
+    std::string disk_index_file;
+    std::shared_mutex node_visit_counter_mtx;
+    std::unique_ptr<std::vector<std::pair<_u32, std::atomic<_u32>>>>
+        node_visit_counter_ptr;
+    std::atomic<_u32> search_counter = 0;
+
+    std::shared_ptr<ThreadSafeStateController> state_controller =
+        std::make_shared<ThreadSafeStateController>();

     // PQ data
     // n_chunks = # of chunks ndims is split into
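
The per-node visit counter changes from a plain std::vector<std::pair<_u32, _u32>> to a heap-allocated vector of atomic counters behind a shared_mutex, plus an atomic search_counter. That layout lets many search threads record visits concurrently while the async task owns creation and consumption of the counters. Below is a hedged sketch of that reader/writer split; the function names and the exact locking policy are assumptions, not the commit's code, and the globals only mirror the new members for illustration:

// Illustrative reader/writer split for the new counter members.
#include <atomic>
#include <cstdint>
#include <memory>
#include <shared_mutex>
#include <utility>
#include <vector>

using Counters = std::vector<std::pair<uint32_t, std::atomic<uint32_t>>>;

std::shared_mutex node_visit_counter_mtx;            // mirrors node_visit_counter_mtx
std::unique_ptr<Counters> node_visit_counter_ptr;    // mirrors node_visit_counter_ptr
std::atomic<uint32_t> search_counter{0};             // mirrors search_counter

void record_visit(uint32_t node_id) {                // search hot path: shared lock + relaxed add
    std::shared_lock<std::shared_mutex> lk(node_visit_counter_mtx);
    if (node_visit_counter_ptr && node_id < node_visit_counter_ptr->size()) {
        (*node_visit_counter_ptr)[node_id].second.fetch_add(1, std::memory_order_relaxed);
    }
}

void start_counting(size_t num_nodes) {              // async cache task: exclusive lock to (re)create
    std::unique_lock<std::shared_mutex> lk(node_visit_counter_mtx);
    node_visit_counter_ptr = std::make_unique<Counters>(num_nodes);
    for (size_t i = 0; i < num_nodes; ++i) {
        (*node_visit_counter_ptr)[i].first = static_cast<uint32_t>(i);
    }
    search_counter.store(0, std::memory_order_relaxed);
}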
@@ -255,6 +280,9 @@ namespace diskann {
     // closest centroid as the starting point of search
     float *centroid_data = nullptr;

+    // cache
+    std::shared_mutex cache_mtx;
+
     // nhood_cache
     unsigned *nhood_cache_buf = nullptr;
     tsl::robin_map<_u32, std::pair<_u32, _u32 *>>
@@ -268,7 +296,7 @@ namespace diskann {
     ConcurrentQueue<ThreadData<T>> thread_data;
     _u64 max_nthreads;
     bool load_flag = false;
-    bool count_visited_nodes = false;
+    std::atomic<bool> count_visited_nodes = false;
     bool reorder_data_exists = false;
     _u64 reoreder_data_offset = 0;
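
Finally, count_visited_nodes becomes std::atomic<bool> because the sampling phase flips it on and off while search threads read it for every visited node; a plain bool written and read from different threads would be a data race. A minimal, assumed usage pattern:

// Illustrative only; ties the atomic flag to the counter sketch above.
#include <atomic>
#include <cstdint>

std::atomic<bool> count_visited_nodes{false};

void begin_sampling() { count_visited_nodes.store(true, std::memory_order_release); }
void end_sampling() { count_visited_nodes.store(false, std::memory_order_release); }

void on_node_visit(uint32_t node_id) {   // search hot path
    if (count_visited_nodes.load(std::memory_order_acquire)) {
        // record_visit(node_id);        // see the node_visit_counter sketch above
    }
    (void)node_id;
}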
