Skip to content

Commit

Permalink
Set norm to 1.0 for all-0 vectors
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <[email protected]>
  • Loading branch information
cydrain committed Aug 30, 2024
1 parent 613f06f commit 83a7cea
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 10 deletions.
5 changes: 5 additions & 0 deletions thirdparty/DiskANN/include/diskann/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,11 @@ namespace diskann {

for (auto& norm : norms) {
norm = std::sqrt(norm);
// handle all-0 vectors
if (norm == 0.0) {
LOG(INFO) << "all-0 vector, set norm to 1.0";
norm = 1.0;
}
}

in_reader.seekg(2 * sizeof(_u32), std::ios::beg);
Expand Down
1 change: 1 addition & 0 deletions thirdparty/DiskANN/src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2990,6 +2990,7 @@ namespace diskann {
for (unsigned i = 0; i < _nd; i++) {
char *cur_node_offset = _opt_graph + i * _node_size;
float cur_norm = norm_l2sqr(_data + i * _aligned_dim, _aligned_dim);
cur_norm = (cur_norm == 0.0 ? 1.0 : cur_norm);
std::memcpy(cur_node_offset, &cur_norm, sizeof(float));
std::memcpy(cur_node_offset + sizeof(float), _data + i * _aligned_dim,
_data_len - sizeof(float));
Expand Down
1 change: 1 addition & 0 deletions thirdparty/faiss/faiss/IVFlib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ void ivf_residual_add_from_flat_codes(
// ok
index->rq.decode(tmp_code.data(), tmp.data(), 1);
float norm = fvec_norm_L2sqr(tmp.data(), rq.d);
norm = (norm == 0.0 ? 1.0 : norm);
wr.write(rq.encode_norm(norm), rq.norm_bits);

// add code to the inverted list
Expand Down
2 changes: 2 additions & 0 deletions thirdparty/faiss/faiss/IndexAdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ struct AQDistanceComputerLUT : FlatCodesDistanceComputer {
bias = 0;
} else {
bias = fvec_norm_L2sqr(x, d);
bias = (bias == 0.0 ? 1.0 : bias);
}
}

Expand Down Expand Up @@ -174,6 +175,7 @@ void search_with_LUT(
if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to
// add ||x||^2
bias = fvec_norm_L2sqr(xq + q * d, d);
bias = (bias == 0.0 ? 1.0 : bias);
}
for (size_t i = 0; i < ntotal; i++) {
float dis = aq.compute_1_distance_LUT<is_IP, st>(
Expand Down
1 change: 1 addition & 0 deletions thirdparty/faiss/faiss/IndexFlat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ struct FlatL2WithNormsDis : FlatCodesDistanceComputer {
void set_query(const float* x) override {
q = x;
query_l2norm = fvec_norm_L2sqr(q, d);
query_l2norm = (query_l2norm == 0.0 ? 1.0 : query_l2norm);
}

// compute four distances
Expand Down
1 change: 1 addition & 0 deletions thirdparty/faiss/faiss/IndexIVFAdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ struct AQInvertedListScannerLUT : AQInvertedListScanner {
AQInvertedListScanner::set_query(query_vector);
if (!is_IP && !ia.by_residual) {
distance_bias = fvec_norm_L2sqr(query_vector, ia.d);
distance_bias = (distance_bias == 0.0 ? 1.0 : distance_bias);
}
}

Expand Down
7 changes: 4 additions & 3 deletions thirdparty/faiss/faiss/IndexIVFPQ.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,10 @@ void initialize_IVFPQ_precomputed_table(
// squared norms of the PQ centroids
std::vector<float> r_norms(pq.M * pq.ksub, NAN);
for (int m = 0; m < pq.M; m++)
for (int j = 0; j < pq.ksub; j++)
r_norms[m * pq.ksub + j] =
fvec_norm_L2sqr(pq.get_centroids(m, j), pq.dsub);
for (int j = 0; j < pq.ksub; j++) {
float norm = fvec_norm_L2sqr(pq.get_centroids(m, j), pq.dsub);
r_norms[m * pq.ksub + j] = (norm == 0.0 ? 1.0 : norm);
}

if (use_precomputed_table == 1) {
precomputed_table.resize(nlist * pq.M * pq.ksub);
Expand Down
3 changes: 2 additions & 1 deletion thirdparty/faiss/faiss/impl/AdditiveQuantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ void AdditiveQuantizer::compute_centroid_norms(float* norms) const {
#pragma omp for
for (int64_t i = 0; i < ntotal; i++) {
decode_64bit(i, tmp.data());
norms[i] = fvec_norm_L2sqr(tmp.data(), d);
float norm = fvec_norm_L2sqr(tmp.data(), d);
norms[i] = (norm == 0.0 ? 1.0 : norm);
}
}
}
Expand Down
17 changes: 11 additions & 6 deletions thirdparty/faiss/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ void fvec_norms_L2(
size_t nx) {
#pragma omp parallel for if (nx > 10000)
for (int64_t i = 0; i < nx; i++) {
nr[i] = sqrtf(fvec_norm_L2sqr(x + i * d, d));
auto norm = fvec_norm_L2sqr(x + i * d, d);
norm = (norm == 0.0 ? 1.0 : norm);
nr[i] = sqrtf(norm);
}
}

Expand All @@ -75,8 +77,10 @@ void fvec_norms_L2sqr(
size_t d,
size_t nx) {
#pragma omp parallel for if (nx > 10000)
for (int64_t i = 0; i < nx; i++)
nr[i] = fvec_norm_L2sqr(x + i * d, d);
for (int64_t i = 0; i < nx; i++) {
float norm = fvec_norm_L2sqr(x + i * d, d);
nr[i] = (norm == 0.0 ? 1.0 : norm);
}
}

// The following is a workaround to a problem
Expand Down Expand Up @@ -460,11 +464,11 @@ void exhaustive_cosine_seq_impl(

// the lambda that applies a filtered element.
auto apply = [&resi, y, y_norms, d](const float ip, const idx_t j) {
const float norm =
float norm =
(y_norms != nullptr) ?
y_norms[j] :
sqrtf(fvec_norm_L2sqr(y + j * d, d));

norm = (norm == 0.0 ? 1.0 : norm);
resi.add_result(ip / norm, j);
};

Expand Down Expand Up @@ -1448,10 +1452,11 @@ void knn_cosine_by_idx(
break;
}
float ip = fvec_inner_product(x_, y + d * idsi[j], d);
const float norm =
float norm =
(y_norms != nullptr) ?
y_norms[idsi[j]] :
sqrtf(fvec_norm_L2sqr(y + d * idsi[j], d));
norm = (norm == 0.0 ? 1.0 : norm);
ip /= norm;

if (ip > simi[0]) {
Expand Down

0 comments on commit 83a7cea

Please sign in to comment.