Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Fix CAGRA graph optimization bug #565

Open
wants to merge 22 commits into
base: branch-25.02
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 121 additions & 36 deletions cpp/src/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include <omp.h>
#include <sys/time.h>

#include <cassert>
#include <climits>
#include <iostream>
#include <memory>
Expand Down Expand Up @@ -994,14 +993,18 @@ void mst_optimization(raft::resources const& res,
total_incoming_edges += incoming_num_edges_ptr[i];
}

bool check_num_mst_edges = true;
#pragma omp parallel for
for (uint64_t i = 0; i < graph_size; i++) {
if (outgoing_num_edges_ptr[i] < outgoing_max_edges_ptr[i]) continue;
if (outgoing_num_edges_ptr[i] + incoming_num_edges_ptr[i] == mst_graph_degree) continue;
assert(outgoing_num_edges_ptr[i] + incoming_num_edges_ptr[i] < mst_graph_degree);
if (outgoing_num_edges_ptr[i] + incoming_num_edges_ptr[i] > mst_graph_degree) {
check_num_mst_edges = false;
}
outgoing_max_edges_ptr[i] += 1;
incoming_max_edges_ptr[i] = mst_graph_degree - outgoing_max_edges_ptr[i];
}
RAFT_EXPECTS(check_num_mst_edges, "Some nodes have too many MST graph edges.");
}

// 6. Show stats
Expand All @@ -1018,8 +1021,9 @@ void mst_optimization(raft::resources const& res,
}
RAFT_LOG_DEBUG("%s", msg.c_str());
}
assert(num_clusters > 0);
assert(total_outgoing_edges == total_incoming_edges);
RAFT_EXPECTS(num_clusters > 0, "No clusters could not be created in MST optimization.");
RAFT_EXPECTS(total_outgoing_edges == total_incoming_edges,
"The numbers of incoming and outcoming edges are mismatch.");
if (num_clusters == 1) { break; }
num_clusters_pre = num_clusters;
}
Expand Down Expand Up @@ -1084,16 +1088,19 @@ void optimize(
auto output_graph_ptr = new_graph.data_handle();

// MST optimization
auto mst_graph = raft::make_host_matrix<IdxT, int64_t, raft::row_major>(0, 0);
auto mst_graph_num_edges = raft::make_host_vector<uint32_t, int64_t>(graph_size);
auto mst_graph_num_edges_ptr = mst_graph_num_edges.data_handle();
#pragma omp parallel for
for (uint64_t i = 0; i < graph_size; i++) {
mst_graph_num_edges_ptr[i] = 0;
}
if (guarantee_connectivity) {
mst_graph =
raft::make_host_matrix<IdxT, int64_t, raft::row_major>(graph_size, output_graph_degree);
RAFT_LOG_INFO("MST optimization is used to guarantee graph connectivity.");
constexpr bool use_gpu = true;
mst_optimization(res, knn_graph, new_graph, mst_graph_num_edges.view(), use_gpu);
mst_optimization(res, knn_graph, mst_graph.view(), mst_graph_num_edges.view(), use_gpu);

for (uint64_t i = 0; i < graph_size; i++) {
if (i < 8 || i >= graph_size - 8) {
Expand All @@ -1102,7 +1109,6 @@ void optimize(
}
}

auto pruned_graph = raft::make_host_matrix<uint32_t, int64_t>(graph_size, output_graph_degree);
{
//
// Prune kNN graph
Expand Down Expand Up @@ -1191,13 +1197,14 @@ void optimize(
const auto num_full = host_stats.data_handle()[1];

// Create pruned kNN graph
bool invalid_neighbor_list = false;
#pragma omp parallel for
for (uint64_t i = 0; i < graph_size; i++) {
// Find the `output_graph_degree` smallest detourable count nodes by checking the detourable
// count of the neighbors while increasing the target detourable count from zero.
uint64_t pk = 0;
uint32_t num_detour = 0;
while (pk < output_graph_degree) {
for (uint32_t l = 0; l < input_graph_degree && pk < output_graph_degree; l++) {
uint32_t next_num_detour = std::numeric_limits<uint32_t>::max();
for (uint64_t k = 0; k < input_graph_degree; k++) {
const auto num_detour_k = detour_count.data_handle()[k + (input_graph_degree * i)];
Expand All @@ -1208,23 +1215,45 @@ void optimize(

// Store the neighbor index if its detourable count is equal to `num_detour`.
if (num_detour_k != num_detour) { continue; }
output_graph_ptr[pk + (output_graph_degree * i)] =
input_graph_ptr[k + (input_graph_degree * i)];
pk += 1;

// Check duplication and append
const auto candidate_node = input_graph_ptr[k + (input_graph_degree * i)];
bool dup = false;
for (uint32_t dk = 0; dk < pk; dk++) {
if (candidate_node == output_graph_ptr[i * output_graph_degree + dk]) {
dup = true;
break;
}
}
if (!dup && candidate_node < graph_size) {
output_graph_ptr[i * output_graph_degree + pk] = candidate_node;
pk += 1;
}
if (pk >= output_graph_degree) break;
}
if (pk >= output_graph_degree) break;

assert(next_num_detour != std::numeric_limits<uint32_t>::max());
if (next_num_detour == std::numeric_limits<uint32_t>::max()) {
// There are not enough valid edges in the initial kNN graph. Break the loop here and catch
// the error at the next validation (pk != output_graph_degree).
break;
}
num_detour = next_num_detour;
}
RAFT_EXPECTS(
pk == output_graph_degree,
"Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for "
"node %lu in the rank-based node reranking process",
output_graph_degree,
i);
if (pk != output_graph_degree) {
RAFT_LOG_DEBUG(
"Couldn't find the output_graph_degree (%lu) smallest detourable count nodes for "
"node %lu in the rank-based node reranking process",
output_graph_degree,
i);
invalid_neighbor_list = true;
}
}
RAFT_EXPECTS(
!invalid_neighbor_list,
"Could not generate an intermediate CAGRA graph because the initial kNN graph contains too "
"many invalid or duplicated neighbor nodes. This error can occur, for example, if too many "
"overflows occur during the norm computation between the dataset vectors.");

const double time_prune_end = cur_time();
RAFT_LOG_DEBUG(
Expand Down Expand Up @@ -1311,37 +1340,51 @@ void optimize(
//
const double time_replace_start = cur_time();

bool check_num_protected_edges = true;
#pragma omp parallel for
for (uint64_t i = 0; i < graph_size; i++) {
auto my_fwd_graph = pruned_graph.data_handle() + (output_graph_degree * i);
auto my_rev_graph = rev_graph.data_handle() + (output_graph_degree * i);
auto my_out_graph = output_graph_ptr + (output_graph_degree * i);
uint32_t kf = 0;
uint32_t k = mst_graph_num_edges_ptr[i];

const auto num_protected_edges = std::max<uint64_t>(k, output_graph_degree / 2);
assert(num_protected_edges <= output_graph_degree);
if (num_protected_edges == output_graph_degree) continue;
// If guarantee_connectivity == true, use a temporary list to merge the neighbor lists of the
// graphs.
std::vector<IdxT> temp_output_neighbor_list;
if (guarantee_connectivity) {
temp_output_neighbor_list.resize(output_graph_degree);
my_out_graph = temp_output_neighbor_list.data();
const auto mst_graph_num_edges = mst_graph_num_edges_ptr[i];

// Set MST graph edges
for (uint32_t j = 0; j < mst_graph_num_edges; j++) {
my_out_graph[j] = mst_graph(i, j);
}

// Append edges from the pruned graph to output graph
while (k < output_graph_degree && kf < output_graph_degree) {
if (my_fwd_graph[kf] < graph_size) {
auto flag_match = false;
for (uint32_t kk = 0; kk < k; kk++) {
if (my_out_graph[kk] == my_fwd_graph[kf]) {
flag_match = true;
// Set pruned graph edges
for (uint32_t pruned_j = 0, output_j = mst_graph_num_edges;
(pruned_j < output_graph_degree) && (output_j < output_graph_degree);
pruned_j++) {
const auto v = output_graph_ptr[output_graph_degree * i + pruned_j];

// duplication check
bool dup = false;
for (uint32_t m = 0; m < output_j; m++) {
if (v == my_out_graph[m]) {
dup = true;
break;
}
}
if (!flag_match) {
my_out_graph[k] = my_fwd_graph[kf];
k += 1;

if (!dup) {
my_out_graph[output_j] = v;
output_j++;
}
}
kf += 1;
}
assert(k == output_graph_degree);
assert(kf <= output_graph_degree);

const auto num_protected_edges =
std::max<uint64_t>(mst_graph_num_edges_ptr[i], output_graph_degree / 2);
if (num_protected_edges > output_graph_degree) { check_num_protected_edges = false; }
if (num_protected_edges == output_graph_degree) continue;

// Replace some edges of the output graph with edges of the reverse graph.
auto kr = std::min<uint32_t>(rev_graph_count.data_handle()[i], output_graph_degree);
Expand All @@ -1358,7 +1401,19 @@ void optimize(
my_out_graph[num_protected_edges] = my_rev_graph[kr];
}
}

// If guarantee_connectivity == true, copy the output neighbor list from the temporary list to
// the output list. If false, the copy is not needed because my_out_graph is a pointer to the
// output buffer.
if (guarantee_connectivity) {
for (uint32_t j = 0; j < output_graph_degree; j++) {
output_graph_ptr[(output_graph_degree * i) + j] = my_out_graph[j];
}
}
}
RAFT_EXPECTS(check_num_protected_edges,
"Failed to merge the MST, pruned, and reverse edge graphs. Some nodes have too "
"many MST optimization edges.");

const double time_replace_end = cur_time();
RAFT_LOG_DEBUG("# Replacing edges time: %.1lf sec", time_replace_end - time_replace_start);
Expand Down Expand Up @@ -1419,6 +1474,36 @@ void optimize(
(double)sum_hist / graph_size);
}
}

// Validate the generated graph: count neighbor indices that fall outside [0, graph_size) and
// neighbors that appear more than once within the same row, then check both counts with
// RAFT_EXPECTS.
{
  uint64_t num_dup = 0;
  uint64_t num_oor = 0;
#pragma omp parallel for reduction(+ : num_dup) reduction(+ : num_oor)
  for (uint64_t i = 0; i < graph_size; i++) {
    const auto my_out_graph = output_graph_ptr + (output_graph_degree * i);
    for (uint32_t j = 0; j < output_graph_degree; j++) {
      const auto neighbor_a = my_out_graph[j];

      // Out-of-range check. Valid node ids are 0 .. graph_size - 1, so an index equal to
      // graph_size is invalid as well; hence `>=` rather than `>`.
      if (neighbor_a >= graph_size) {
        num_oor++;
        continue;
      }

      // Duplication check. Only later entries of the row are compared, so each duplicated pair
      // is counted once.
      for (uint32_t k = j + 1; k < output_graph_degree; k++) {
        const auto neighbor_b = my_out_graph[k];
        if (neighbor_a == neighbor_b) { num_dup++; }
      }
    }
  }
  RAFT_EXPECTS(
    num_dup == 0, "%lu duplicated node(s) are found in the generated CAGRA graph", num_dup);
  RAFT_EXPECTS(num_oor == 0,
               "%lu out-of-range index node(s) are found in the generated CAGRA graph",
               num_oor);
}
}

} // namespace cuvs::neighbors::cagra::detail::graph
10 changes: 8 additions & 2 deletions cpp/test/neighbors/ann_cagra/bug_extreme_inputs_oob.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,14 @@ class cagra_extreme_inputs_oob_test : public ::testing::Test {
ix_ps.graph_degree = 64;
ix_ps.intermediate_graph_degree = 128;

[[maybe_unused]] auto ix = cagra::build(res, ix_ps, raft::make_const_mdspan(dataset->view()));
raft::resource::sync_stream(res);
try {
[[maybe_unused]] auto ix = cagra::build(res, ix_ps, raft::make_const_mdspan(dataset->view()));
raft::resource::sync_stream(res);
} catch (const std::exception&) {
SUCCEED();
return;
}
FAIL();
}

void SetUp() override
Expand Down
Loading