Skip to content

Commit

Permalink
Fix TypedIndex rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
markkohdev committed Aug 19, 2024
1 parent 0532250 commit 598e9e6
Showing 1 changed file with 4 additions and 29 deletions.
33 changes: 4 additions & 29 deletions cpp/src/TypedIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,29 +53,6 @@ template <> const std::string storageDataTypeName<int8_t>() { return "Float8"; }
template <> const std::string storageDataTypeName<float>() { return "Float32"; }
template <> const std::string storageDataTypeName<E4M3>() { return "E4M3"; }

template <typename dist_t, typename data_t>
dist_t ensureNotNegative(dist_t distance, hnswlib::labeltype label) {
if constexpr (std::is_same_v<data_t, E4M3>) {
// Allow for a very slight negative distance if using E4M3
if (distance < 0 && distance >= -0.14) {
return 0;
}
}

if (distance < 0) {
if (distance >= -0.00001) {
return 0;
}

throw std::runtime_error(
"Potential candidate (with label '" + std::to_string(label) +
"') had negative distance " + std::to_string(distance) +
". This may indicate a corrupted index file.");
}

return distance;
}

/**
* A C++ wrapper class for a typed HNSW index.
*
Expand Down Expand Up @@ -402,7 +379,7 @@ class TypedIndex : public Index {
floatToDataType<data_t, scalefactor>(&inputArray[startIndex],
&convertedArray[startIndex],
actualDimensions);
size_t id = ids.size() ? ids.at(row) : (currentLabel + row);
size_t id = ids.size() ? ids.at(row) : (currentLabel.fetch_add(1));
try {
algorithmImpl->addPoint(convertedArray.data() + startIndex, id);
} catch (IndexFullError &e) {
Expand Down Expand Up @@ -438,7 +415,7 @@ class TypedIndex : public Index {
normalizeVector<dist_t, data_t, scalefactor>(
&inputArray[startIndex], &normalizedArray[startIndex],
actualDimensions);
size_t id = ids.size() ? ids.at(row) : (currentLabel + row);
size_t id = ids.size() ? ids.at(row) : (currentLabel.fetch_add(1));

try {
algorithmImpl->addPoint(normalizedArray.data() + startIndex, id);
Expand Down Expand Up @@ -629,8 +606,7 @@ class TypedIndex : public Index {
dist_t distance = result_tuple.first;
hnswlib::labeltype label = result_tuple.second;

distancePointer[row * k + i] =
ensureNotNegative<dist_t, data_t>(distance, label);
distancePointer[row * k + i] = distance;
labelPointer[row * k + i] = label;
result.pop();
}
Expand Down Expand Up @@ -704,8 +680,7 @@ class TypedIndex : public Index {
for (int i = k - 1; i >= 0; i--) {
auto &result_tuple = result.top();

distancePointer[i] = ensureNotNegative<dist_t, data_t>(
result_tuple.first, result_tuple.second);
distancePointer[i] = result_tuple.first;
labelPointer[i] = result_tuple.second;
result.pop();
}
Expand Down

0 comments on commit 598e9e6

Please sign in to comment.