From e83573cf9f5d237aa3188067f34e141a9dcb523f Mon Sep 17 00:00:00 2001 From: Andreas Huber Date: Wed, 5 Feb 2025 03:06:45 -0800 Subject: [PATCH 1/3] feat: add remaining Float16 specializations for simd calculation --- include/svs/core/distance/cosine.h | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/include/svs/core/distance/cosine.h b/include/svs/core/distance/cosine.h index c8e41e7..53c9ec6 100644 --- a/include/svs/core/distance/cosine.h +++ b/include/svs/core/distance/cosine.h @@ -306,9 +306,26 @@ template struct CosineSimilarityImpl { template struct CosineSimilarityImpl { SVS_NOINLINE static float compute(const float* a, const Float16* b, float a_norm, lib::MaybeStatic length) { - auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>(), a, b, length); + auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length); + return sum / (std::sqrt(norm) * a_norm); + } +}; + +template struct CosineSimilarityImpl { + SVS_NOINLINE static float + compute(const Float16* a, const float* b, float a_norm, lib::MaybeStatic length) { + auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length); return sum / (std::sqrt(norm) * a_norm); } }; + +template struct CosineSimilarityImpl { + SVS_NOINLINE static float + compute(const Float16* a, const Float16* b, float a_norm, lib::MaybeStatic length) { + auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length); + return sum / (std::sqrt(norm) * a_norm); + } +}; + #endif } // namespace svs::distance From c5c9bc97c47b16dcb2f59d8ad496ba6c613334ec Mon Sep 17 00:00:00 2001 From: Andreas Huber Date: Fri, 7 Feb 2025 06:58:49 -0800 Subject: [PATCH 2/3] copy/paste bound_with for CS --- include/svs/index/inverted/common.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/include/svs/index/inverted/common.h b/include/svs/index/inverted/common.h index 88c6625..f769ac5 100644 --- a/include/svs/index/inverted/common.h +++ b/include/svs/index/inverted/common.h @@ -45,4 +45,11 @@ template inline T bound_with(T nearest, T epsilon, svs::DistanceIP) return nearest / (1 + epsilon); } +template +inline T bound_with(T nearest, T epsilon, svs::DistanceCosineSimilarity) { + // TODO: This is just copy/paste from DistanceIP - is it correct? + assert(nearest > 0.0f); + return nearest / (1 + epsilon); +} + } // namespace svs::index::inverted From 7bece3ef45393a98a7001b6cc241cf32605e6ce8 Mon Sep 17 00:00:00 2001 From: Andreas Huber Date: Wed, 19 Feb 2025 02:15:34 -0800 Subject: [PATCH 3/3] chore: cleanup TODO comments --- include/svs/index/inverted/common.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/include/svs/index/inverted/common.h b/include/svs/index/inverted/common.h index f769ac5..6bc3395 100644 --- a/include/svs/index/inverted/common.h +++ b/include/svs/index/inverted/common.h @@ -40,14 +40,12 @@ template inline T bound_with(T nearest, T epsilon, svs::DistanceL2) } template inline T bound_with(T nearest, T epsilon, svs::DistanceIP) { - // TODO: What do we do if the best match is simply bad? assert(nearest > 0.0f); return nearest / (1 + epsilon); } template inline T bound_with(T nearest, T epsilon, svs::DistanceCosineSimilarity) { - // TODO: This is just copy/paste from DistanceIP - is it correct? assert(nearest > 0.0f); return nearest / (1 + epsilon); }