Skip to content

Commit

Permalink
feat: add remaining Float16 specializations for simd calculation (#78)
Browse files Browse the repository at this point in the history
This PR adds the remaining `Float16` specializations for
`CosineSimilarityImpl`, resulting in SIMD ops being used instead of the
generic reference implementation in some of the test cases.
  • Loading branch information
ahuber21 authored Feb 20, 2025
1 parent a88c8e6 commit 46a8a59
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
19 changes: 18 additions & 1 deletion include/svs/core/distance/cosine.h
Original file line number Diff line number Diff line change
Expand Up @@ -306,9 +306,26 @@ template <size_t N> struct CosineSimilarityImpl<N, float, int8_t> {
template <size_t N> struct CosineSimilarityImpl<N, float, Float16> {
SVS_NOINLINE static float
compute(const float* a, const Float16* b, float a_norm, lib::MaybeStatic<N> length) {
auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>(), a, b, length);
auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length);
return sum / (std::sqrt(norm) * a_norm);
}
};

template <size_t N> struct CosineSimilarityImpl<N, Float16, float> {
SVS_NOINLINE static float
compute(const Float16* a, const float* b, float a_norm, lib::MaybeStatic<N> length) {
auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length);
return sum / (std::sqrt(norm) * a_norm);
}
};

template <size_t N> struct CosineSimilarityImpl<N, Float16, Float16> {
SVS_NOINLINE static float
compute(const Float16* a, const Float16* b, float a_norm, lib::MaybeStatic<N> length) {
auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length);
return sum / (std::sqrt(norm) * a_norm);
}
};

#endif
} // namespace svs::distance
7 changes: 6 additions & 1 deletion include/svs/index/inverted/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,12 @@ template <typename T> inline T bound_with(T nearest, T epsilon, svs::DistanceL2)
}

template <typename T> inline T bound_with(T nearest, T epsilon, svs::DistanceIP) {
// TODO: What do we do if the best match is simply bad?
assert(nearest > 0.0f);
return nearest / (1 + epsilon);
}

template <typename T>
inline T bound_with(T nearest, T epsilon, svs::DistanceCosineSimilarity) {
assert(nearest > 0.0f);
return nearest / (1 + epsilon);
}
Expand Down

0 comments on commit 46a8a59

Please sign in to comment.