
Commit c935820
add read/write lock in cache map
Signed-off-by: cqy123456 <[email protected]>
cqy123456 committed Dec 3, 2023
1 parent d63c403 commit c935820
Showing 2 changed files with 35 additions and 12 deletions.
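
For context, the commit guards the two node caches (nhood_cache and coord_cache) with a std::shared_mutex each: lookups on the search path take a shared (read) lock, while the cache-generation path takes an exclusive (write) lock before inserting. The sketch below shows that pattern in isolation; the names (NodeCache, lookup, insert) are made up for this example, and std::unordered_map stands in for tsl::robin_map so the snippet is self-contained.

#include <cstdint>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <unordered_map>

template <typename Value>
class NodeCache {
 public:
  // Readers take a shared lock, so concurrent searches can probe the cache in parallel.
  std::optional<Value> lookup(uint32_t id) const {
    std::shared_lock<std::shared_mutex> lock(mtx_);
    auto it = map_.find(id);
    if (it == map_.end()) {
      return std::nullopt;
    }
    return it->second;
  }

  // The cache-building path takes an exclusive lock while inserting.
  void insert(uint32_t id, Value v) {
    std::unique_lock<std::shared_mutex> lock(mtx_);
    map_.emplace(id, std::move(v));
  }

 private:
  mutable std::shared_mutex mtx_;
  std::unordered_map<uint32_t, Value> map_;
};

A plain std::mutex would also be correct; std::shared_mutex is the better fit here because cache lookups during search vastly outnumber inserts, and shared locking lets those lookups proceed concurrently.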
13 changes: 8 additions & 5 deletions thirdparty/DiskANN/include/pq_flash_index.h
@@ -7,6 +7,8 @@
#include <sstream>
#include <stack>
#include <string>
+#include <mutex>
+#include <shared_mutex>
#include "knowhere/common/lru_cache.h"
#include "tsl/robin_map.h"
#include "tsl/robin_set.h"
@@ -232,12 +234,13 @@ namespace diskann {
// nhood_cache
unsigned * nhood_cache_buf = nullptr;
tsl::robin_map<_u32, std::pair<_u32, _u32 *>> nhood_cache; // <id, <neighbors_num, neighbors>>

+mutable std::shared_mutex nhood_mtx;
// coord_cache
-T * coord_cache_buf = nullptr;
-tsl::robin_map<_u32, T *> coord_cache;
-Semaphore semaph;
-std::atomic<bool> async_generate_cache = false;
+T * coord_cache_buf = nullptr;
+tsl::robin_map<_u32, T *> coord_cache;
+Semaphore semaph;
+std::atomic<bool> async_generate_cache = false;
+mutable std::shared_mutex coord_mtx;

// thread-specific scratch
ConcurrentQueue<ThreadData<T>> thread_data;
34 changes: 27 additions & 7 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
@@ -93,6 +93,24 @@ namespace {
1.1982775974217197,
0.5);
};

+// Insert under an exclusive (write) lock.
+template<typename Value>
+inline void
+insert_cache_safety(tsl::robin_map<_u32, Value>& cache, std::pair<_u32, Value>&& entry, std::shared_mutex& mtx) {
+  std::unique_lock<std::shared_mutex> lock(mtx);
+  cache.insert(std::move(entry));
+}
+
+// Look up under a shared (read) lock; operator[] is not usable on a const map,
+// so the lookup goes through find() and returns the value by copy.
+template<typename Value>
+inline std::optional<Value>
+find_in_cache_safety(const tsl::robin_map<_u32, Value>& cache, _u32 id, std::shared_mutex& mtx) {
+  std::shared_lock<std::shared_mutex> lock(mtx);
+  auto iter = cache.find(id);
+  if (iter == cache.end()) {
+    return std::nullopt;
+  }
+  return iter->second;
+}
} // namespace

namespace diskann {
@@ -274,7 +292,7 @@
T * node_coords = OFFSET_TO_NODE_COORDS(node_buf);
T * cached_coords = coord_cache_buf + node_idx * aligned_dim;
memcpy(cached_coords, node_coords, disk_bytes_per_point);
-coord_cache.insert(std::make_pair(nhood.first, cached_coords));
+insert_cache_safety(coord_cache, std::make_pair(nhood.first, cached_coords), coord_mtx);

// insert node nhood into nhood_cache
unsigned *node_nhood = OFFSET_TO_NODE_NHOOD(node_buf);
@@ -285,7 +303,7 @@
cnhood.first = nnbrs;
cnhood.second = nhood_cache_buf + node_idx * (max_degree + 1);
memcpy(cnhood.second, nbrs, nnbrs * sizeof(unsigned));
-nhood_cache.insert(std::make_pair(nhood.first, cnhood));
+insert_cache_safety(nhood_cache, std::make_pair(nhood.first, cnhood), nhood_mtx);
aligned_free(nhood.second);
node_idx++;
}
@@ -934,8 +952,9 @@
const auto [dist, id] = opt.value();

// check if in cache
-if (coord_cache.find(id) != coord_cache.end()) {
-float dist = dist_cmp(query, coord_cache.at(id), (size_t) aligned_dim);
+auto cache_v = find_in_cache_safety(coord_cache, id, coord_mtx);
+if (cache_v.has_value()) {
+float dist = dist_cmp(query, cache_v.value(), (size_t) aligned_dim);
max_heap.Push(dist, id);
continue;
}
@@ -1164,10 +1183,10 @@
num_seen < beam_width) {
if (retset[marker].flag) {
num_seen++;
-auto iter = nhood_cache.find(retset[marker].id);
-if (iter != nhood_cache.end()) {
+auto cache_v = find_in_cache_safety(nhood_cache, retset[marker].id, nhood_mtx);
+if (cache_v.has_value()) {
cached_nhoods.push_back(
-std::make_pair(retset[marker].id, iter->second));
+std::make_pair(retset[marker].id, cache_v.value()));
if (stats != nullptr) {
stats->n_cache_hits++;
}
@@ -1225,6 +1244,7 @@

// process cached nhoods
for (auto &cached_nhood : cached_nhoods) {
+// No locked lookup needed here: any id present in nhood_cache is guaranteed to also be in coord_cache.
auto global_cache_iter = coord_cache.find(cached_nhood.first);
T * node_fp_coords_copy = global_cache_iter->second;
float cur_expanded_dist;

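As a usage illustration of what the new helpers enable (not code from this commit): one thread can keep populating a cache while other threads probe it, as long as the writer holds the lock exclusively and readers hold it shared. The names and types below are simplified stand-ins (std::unordered_map and a float payload instead of tsl::robin_map and node buffers).

#include <cstdint>
#include <iostream>
#include <mutex>
#include <optional>
#include <shared_mutex>
#include <thread>
#include <unordered_map>

static std::unordered_map<uint32_t, float> coord_cache_demo;
static std::shared_mutex coord_mtx_demo;

// Writer side: mirrors insert_cache_safety (exclusive lock around the insert).
static void generate_cache() {
  for (uint32_t id = 0; id < 10000; ++id) {
    std::unique_lock<std::shared_mutex> lock(coord_mtx_demo);
    coord_cache_demo.emplace(id, static_cast<float>(id));
  }
}

// Reader side: mirrors find_in_cache_safety (shared lock around the lookup).
static std::optional<float> probe_cache(uint32_t id) {
  std::shared_lock<std::shared_mutex> lock(coord_mtx_demo);
  auto it = coord_cache_demo.find(id);
  if (it == coord_cache_demo.end()) {
    return std::nullopt;
  }
  return it->second;
}

int main() {
  std::thread builder(generate_cache);
  size_t hits = 0;
  for (uint32_t id = 0; id < 10000; ++id) {
    if (probe_cache(id).has_value()) {
      ++hits;  // hit count depends on how far the builder thread has progressed
    }
  }
  builder.join();
  std::cout << "cache hits observed while building: " << hits << "\n";
  return 0;
}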