From 46a8a59080a6fd3bee42b795c455582938f8e65b Mon Sep 17 00:00:00 2001
From: Andreas Huber <9201869+ahuber21@users.noreply.github.com>
Date: Thu, 20 Feb 2025 09:10:34 +0100
Subject: [PATCH] feat: add remaining Float16 specializations for simd calculation (#78)

This PR adds the remaining `Float16` specializations for `CosineSimilarityImpl`, resulting in SIMD ops being used instead of the generic reference implementation in some of the test cases.
---
 include/svs/core/distance/cosine.h  | 19 ++++++++++++++++++-
 include/svs/index/inverted/common.h |  7 ++++++-
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/include/svs/core/distance/cosine.h b/include/svs/core/distance/cosine.h
index c8e41e7..53c9ec6 100644
--- a/include/svs/core/distance/cosine.h
+++ b/include/svs/core/distance/cosine.h
@@ -306,9 +306,26 @@ template struct CosineSimilarityImpl {
 template <size_t N> struct CosineSimilarityImpl<N, float, Float16> {
     SVS_NOINLINE static float
     compute(const float* a, const Float16* b, float a_norm, lib::MaybeStatic<N> length) {
-        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>(), a, b, length);
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    }
+};
+
+template <size_t N> struct CosineSimilarityImpl<N, Float16, float> {
+    SVS_NOINLINE static float
+    compute(const Float16* a, const float* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length);
         return sum / (std::sqrt(norm) * a_norm);
     }
 };
+
+template <size_t N> struct CosineSimilarityImpl<N, Float16, Float16> {
+    SVS_NOINLINE static float
+    compute(const Float16* a, const Float16* b, float a_norm, lib::MaybeStatic<N> length) {
+        auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length);
+        return sum / (std::sqrt(norm) * a_norm);
+    }
+};
+
 #endif
 } // namespace svs::distance
diff --git a/include/svs/index/inverted/common.h b/include/svs/index/inverted/common.h
index 88c6625..6bc3395 100644
--- a/include/svs/index/inverted/common.h
+++ b/include/svs/index/inverted/common.h
@@ -40,7 +40,12 @@ template <typename T> inline T bound_with(T nearest, T epsilon, svs::DistanceL2)
 }
 
 template <typename T> inline T bound_with(T nearest, T epsilon, svs::DistanceIP) {
-    // TODO: What do we do if the best match is simply bad?
+    assert(nearest > 0.0f);
+    return nearest / (1 + epsilon);
+}
+
+template <typename T>
+inline T bound_with(T nearest, T epsilon, svs::DistanceCosineSimilarity) {
     assert(nearest > 0.0f);
     return nearest / (1 + epsilon);
 }