diff --git a/include/svs/core/distance/cosine.h b/include/svs/core/distance/cosine.h index c8e41e7..53c9ec6 100644 --- a/include/svs/core/distance/cosine.h +++ b/include/svs/core/distance/cosine.h @@ -306,9 +306,26 @@ template struct CosineSimilarityImpl { template struct CosineSimilarityImpl { SVS_NOINLINE static float compute(const float* a, const Float16* b, float a_norm, lib::MaybeStatic length) { - auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>(), a, b, length); + auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length); + return sum / (std::sqrt(norm) * a_norm); + } +}; + +template struct CosineSimilarityImpl { + SVS_NOINLINE static float + compute(const Float16* a, const float* b, float a_norm, lib::MaybeStatic length) { + auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length); return sum / (std::sqrt(norm) * a_norm); } }; + +template struct CosineSimilarityImpl { + SVS_NOINLINE static float + compute(const Float16* a, const Float16* b, float a_norm, lib::MaybeStatic length) { + auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length); + return sum / (std::sqrt(norm) * a_norm); + } +}; + #endif } // namespace svs::distance diff --git a/include/svs/index/inverted/common.h b/include/svs/index/inverted/common.h index 88c6625..6bc3395 100644 --- a/include/svs/index/inverted/common.h +++ b/include/svs/index/inverted/common.h @@ -40,7 +40,12 @@ template inline T bound_with(T nearest, T epsilon, svs::DistanceL2) } template inline T bound_with(T nearest, T epsilon, svs::DistanceIP) { - // TODO: What do we do if the best match is simply bad? + assert(nearest > 0.0f); + return nearest / (1 + epsilon); +} + +template +inline T bound_with(T nearest, T epsilon, svs::DistanceCosineSimilarity) { assert(nearest > 0.0f); return nearest / (1 + epsilon); }