diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3c6e1390e..84bd4df87 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -21,27 +21,14 @@ exclude: '^thirdparty'
 fail_fast: True
 repos:
-  # prepare-clang-tidy and clang-tidy run only in local machine, not on CI
-  - repo: local
-    hooks:
-      - id: prepare-clang-tidy
-        name: Prepare clang-tidy
-        entry: bash -c 'scripts/prepare_clang_tidy.sh'
-        language: system
-        types: [shell]
-        always_run: true
-        stages: [pre-push]
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.3.0
+    rev: v4.5.0
     hooks:
       - id: check-added-large-files
       - id: check-merge-conflict
       - id: end-of-file-fixer
       - id: trailing-whitespace
   - repo: https://github.com/pocc/pre-commit-hooks
-    rev: master
+    rev: v1.3.5
     hooks:
       - id: clang-format
         args: [-style=file]
-      - id: clang-tidy
-        args: ["--config-file=.clang-tidy", "-p=build/compile_commands.json"]
-        stages: [pre-push]
diff --git a/cmake/libs/libfaiss.cmake b/cmake/libs/libfaiss.cmake
index a78f2685f..9187f3d52 100644
--- a/cmake/libs/libfaiss.cmake
+++ b/cmake/libs/libfaiss.cmake
@@ -31,7 +31,7 @@ if(__X86_64)
 endif()
 
 if(__AARCH64)
-  set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc)
+  set(UTILS_SRC src/simd/hook.cc src/simd/distances_ref.cc src/simd/distances_neon.cc)
   add_library(knowhere_utils STATIC ${UTILS_SRC})
   target_link_libraries(knowhere_utils PUBLIC glog::glog)
 endif()
diff --git a/src/simd/distances_neon.cc b/src/simd/distances_neon.cc
new file mode 100644
index 000000000..cb304927f
--- /dev/null
+++ b/src/simd/distances_neon.cc
@@ -0,0 +1,274 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#if defined(__ARM_NEON)
+// NOTE(review): fast-math implies finite-math-only while the argmin kernel below uses INFINITY as a sentinel;
+// confirm the targeted compilers keep those comparisons intact.
+#pragma GCC optimize("O3,fast-math,inline")
+#include "distances_neon.h"
+
+#include <arm_neon.h>
+#include <math.h>
+
+namespace faiss {
+namespace {
+
+// Element-wise building blocks plugged into accumulate().
+struct Sum {
+    float32x4_t
+    operator()(float32x4_t a, float32x4_t b) const {
+        return vaddq_f32(a, b);
+    }
+};
+
+struct Max {
+    float32x4_t
+    operator()(float32x4_t a, float32x4_t b) const {
+        return vmaxq_f32(a, b);
+    }
+};
+
+struct Product {
+    float32x4_t
+    operator()(float32x4_t a, float32x4_t b) const {
+        return vmulq_f32(a, b);
+    }
+};
+
+struct SquaredDiff {
+    float32x4_t
+    operator()(float32x4_t a, float32x4_t b) const {
+        const float32x4_t diff = vsubq_f32(a, b);
+        return vmulq_f32(diff, diff);
+    }
+};
+
+struct AbsDiff {
+    float32x4_t
+    operator()(float32x4_t a, float32x4_t b) const {
+        return vabsq_f32(vsubq_f32(a, b));
+    }
+};
+
+// Loads the trailing `rem` (0..3) floats at `p` into a zero-filled vector: the first remaining float goes to lane
+// rem - 1 and the last one to lane 0. Unused lanes stay 0, so they add nothing to a sum and never exceed a maximum
+// of absolute values.
+inline float32x4_t
+load_tail(const float* p, size_t rem) {
+    float32x4_t v = vdupq_n_f32(0.0f);
+    if (rem >= 3) {
+        v = vld1q_lane_f32(p++, v, 2);
+    }
+    if (rem >= 2) {
+        v = vld1q_lane_f32(p++, v, 1);
+    }
+    if (rem >= 1) {
+        v = vld1q_lane_f32(p, v, 0);
+    }
+    return v;
+}
+
+// Shared kernel of the distance functions. `term` turns four (x, y) pairs into four partial results and `combine`
+// folds partial results into the accumulator. Blocks of 16, 8 and 4 floats are consumed first, the last 0..3
+// floats go through load_tail(). Returns the four per-lane accumulators; the caller reduces them horizontally.
+template <typename Term, typename Combine>
+inline float32x4_t
+accumulate(const float* x, const float* y, size_t d, Term term, Combine combine) {
+    float32x4_t acc = vdupq_n_f32(0.0f);
+    for (; d >= 16; d -= 16, x += 16, y += 16) {
+        const float32x4x4_t a = vld1q_f32_x4(x);
+        const float32x4x4_t b = vld1q_f32_x4(y);
+        const float32x4_t lo = combine(term(a.val[0], b.val[0]), term(a.val[1], b.val[1]));
+        const float32x4_t hi = combine(term(a.val[2], b.val[2]), term(a.val[3], b.val[3]));
+        acc = combine(acc, combine(lo, hi));
+    }
+    if (d >= 8) {
+        const float32x4x2_t a = vld1q_f32_x2(x);
+        const float32x4x2_t b = vld1q_f32_x2(y);
+        acc = combine(acc, combine(term(a.val[0], b.val[0]), term(a.val[1], b.val[1])));
+        d -= 8;
+        x += 8;
+        y += 8;
+    }
+    if (d >= 4) {
+        acc = combine(acc, term(vld1q_f32(x), vld1q_f32(y)));
+        d -= 4;
+        x += 4;
+        y += 4;
+    }
+    return combine(acc, term(load_tail(x, d), load_tail(y, d)));
+}
+
+// Returns a + bf * b for four lanes.
+inline float32x4_t
+madd4(float32x4_t a, float32x4_t b, float bf) {
+    return vaddq_f32(vmulq_n_f32(b, bf), a);
+}
+
+}  // namespace
+
+// Inner product of x and y over d floats.
+float
+fvec_inner_product_neon(const float* x, const float* y, size_t d) {
+    return vaddvq_f32(accumulate(x, y, d, Product{}, Sum{}));
+}
+
+// Squared L2 distance between x and y over d floats.
+float
+fvec_L2sqr_neon(const float* x, const float* y, size_t d) {
+    return vaddvq_f32(accumulate(x, y, d, SquaredDiff{}, Sum{}));
+}
+
+// L1 distance (sum of absolute differences) between x and y over d floats.
+float
+fvec_L1_neon(const float* x, const float* y, size_t d) {
+    return vaddvq_f32(accumulate(x, y, d, AbsDiff{}, Sum{}));
+}
+
+// Linf distance (largest absolute difference) between x and y over d floats.
+float
+fvec_Linf_neon(const float* x, const float* y, size_t d) {
+    return vmaxvq_f32(accumulate(x, y, d, AbsDiff{}, Max{}));
+}
+
+// Squared L2 norm of x over d floats.
+float
+fvec_norm_L2sqr_neon(const float* x, size_t d) {
+    return vaddvq_f32(accumulate(x, x, d, Product{}, Sum{}));
+}
+
+// dis[i] = squared L2 distance between x and the i-th of ny contiguous d-float vectors starting at y.
+void
+fvec_L2sqr_ny_neon(float* dis, const float* x, const float* y, size_t d, size_t ny) {
+    for (size_t i = 0; i < ny; i++) {
+        dis[i] = fvec_L2sqr_neon(x, y, d);
+        y += d;
+    }
+}
+
+// ip[i] = inner product of x and the i-th of ny contiguous d-float vectors starting at y.
+void
+fvec_inner_products_ny_neon(float* ip, const float* x, const float* y, size_t d, size_t ny) {
+    for (size_t i = 0; i < ny; i++) {
+        ip[i] = fvec_inner_product_neon(x, y, d);
+        y += d;
+    }
+}
+
+// c[i] = a[i] + bf * b[i] for i in [0, n).
+void
+fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c) {
+    for (; n >= 16; n -= 16, a += 16, b += 16, c += 16) {
+        const float32x4x4_t va = vld1q_f32_x4(a);
+        const float32x4x4_t vb = vld1q_f32_x4(b);
+        float32x4x4_t vc;
+        vc.val[0] = madd4(va.val[0], vb.val[0], bf);
+        vc.val[1] = madd4(va.val[1], vb.val[1], bf);
+        vc.val[2] = madd4(va.val[2], vb.val[2], bf);
+        vc.val[3] = madd4(va.val[3], vb.val[3], bf);
+        vst1q_f32_x4(c, vc);
+    }
+    if (n >= 8) {
+        const float32x4x2_t va = vld1q_f32_x2(a);
+        const float32x4x2_t vb = vld1q_f32_x2(b);
+        float32x4x2_t vc;
+        vc.val[0] = madd4(va.val[0], vb.val[0], bf);
+        vc.val[1] = madd4(va.val[1], vb.val[1], bf);
+        vst1q_f32_x2(c, vc);
+        n -= 8;
+        a += 8;
+        b += 8;
+        c += 8;
+    }
+    if (n >= 4) {
+        vst1q_f32(c, madd4(vld1q_f32(a), vld1q_f32(b), bf));
+        n -= 4;
+        a += 4;
+        b += 4;
+        c += 4;
+    }
+    // At most three elements are left; plain scalar code replaces the per-lane loads and stores.
+    for (size_t i = 0; i < n; ++i) {
+        c[i] = a[i] + bf * b[i];
+    }
+}
+
+// Computes c[i] = a[i] + bf * b[i] for i in [0, n) and returns the index of the smallest c[i]; equal minima
+// resolve to the lowest index. Returns 0 without touching any buffer when n == 0, and 0 when no c[i] is below
+// +infinity.
+int
+fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, float* c) {
+    if (n == 0) {
+        return 0;
+    }
+
+    // Lane j tracks the smallest value, and its index, among the vectorised elements whose index is j modulo 4.
+    const uint32x4_t lane_ids = {0, 1, 2, 3};
+    uint32x4_t best_idx = vdupq_n_u32(0);
+    float32x4_t best_val = vdupq_n_f32(INFINITY);
+    const auto track = [&](float32x4_t vc, size_t at) {
+        vst1q_f32(c + at, vc);
+        // Strictly-less keeps the earliest index of a lane on ties.
+        const uint32x4_t is_less = vcltq_f32(vc, best_val);
+        const uint32x4_t idx = vaddq_u32(lane_ids, vdupq_n_u32(static_cast<uint32_t>(at)));
+        best_idx = vbslq_u32(is_less, idx, best_idx);
+        best_val = vbslq_f32(is_less, vc, best_val);
+    };
+
+    size_t i = 0;
+    for (; n - i >= 16; i += 16) {
+        const float32x4x4_t va = vld1q_f32_x4(a + i);
+        const float32x4x4_t vb = vld1q_f32_x4(b + i);
+        track(madd4(va.val[0], vb.val[0], bf), i);
+        track(madd4(va.val[1], vb.val[1], bf), i + 4);
+        track(madd4(va.val[2], vb.val[2], bf), i + 8);
+        track(madd4(va.val[3], vb.val[3], bf), i + 12);
+    }
+    if (n - i >= 8) {
+        const float32x4x2_t va = vld1q_f32_x2(a + i);
+        const float32x4x2_t vb = vld1q_f32_x2(b + i);
+        track(madd4(va.val[0], vb.val[0], bf), i);
+        track(madd4(va.val[1], vb.val[1], bf), i + 4);
+        i += 8;
+    }
+    if (n - i >= 4) {
+        track(madd4(vld1q_f32(a + i), vld1q_f32(b + i), bf), i);
+        i += 4;
+    }
+    const size_t tail_begin = i;
+    for (; i < n; ++i) {
+        c[i] = a[i] + bf * b[i];
+    }
+
+    // Candidates are the four per-lane winners plus the scalar tail. A lane that never saw a value below
+    // +infinity still holds index 0, which is a valid element because n > 0.
+    uint32_t lane_winner[4];
+    vst1q_u32(lane_winner, best_idx);
+    uint32_t best = 0;
+    float best_value = INFINITY;
+    const auto consider = [&](uint32_t idx) {
+        const float v = c[idx];
+        if (v < best_value || (v == best_value && idx < best)) {
+            best = idx;
+            best_value = v;
+        }
+    };
+    for (uint32_t idx : lane_winner) {
+        consider(idx);
+    }
+    for (size_t j = tail_begin; j < n; ++j) {
+        consider(static_cast<uint32_t>(j));
+    }
+    return static_cast<int>(best);
+}
+
+}  // namespace faiss
+#endif
diff --git a/src/simd/distances_neon.h b/src/simd/distances_neon.h
new file mode 100644
index 000000000..fdfc79ad9
--- /dev/null
+++ b/src/simd/distances_neon.h
@@ -0,0 +1,57 @@
+// Copyright (C) 2019-2023 Zilliz. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software distributed under the License
+// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+// or implied. See the License for the specific language governing permissions and limitations under the License.
+
+#ifndef DISTANCES_NEON_H
+#define DISTANCES_NEON_H
+
+#include <cstddef>
+
+namespace faiss {
+
+/// Squared L2 distance between two vectors
+float
+fvec_L2sqr_neon(const float* x, const float* y, size_t d);
+
+/// inner product
+float
+fvec_inner_product_neon(const float* x, const float* y, size_t d);
+
+/// L1 distance
+float
+fvec_L1_neon(const float* x, const float* y, size_t d);
+
+/// infinity distance
+float
+fvec_Linf_neon(const float* x, const float* y, size_t d);
+
+/// squared norm of a vector
+float
+fvec_norm_L2sqr_neon(const float* x, size_t d);
+
+/// compute ny square L2 distance between x and a set of contiguous y vectors
+void
+fvec_L2sqr_ny_neon(float* dis, const float* x, const float* y, size_t d, size_t ny);
+
+/// compute the inner products between x and a set of ny contiguous y vectors
+void
+fvec_inner_products_ny_neon(float* ip, const float* x, const float* y, size_t d, size_t ny);
+
+/// compute c[i] = a[i] + bf * b[i] for i in [0, n)
+void
+fvec_madd_neon(size_t n, const float* a, float bf, const float* b, float* c);
+
+/// same as fvec_madd_neon, additionally returns the index of the smallest c[i] (lowest index on ties, 0 if n == 0)
+int
+fvec_madd_and_argmin_neon(size_t n, const float* a, float bf, const float* b, float* c);
+
+}  // namespace faiss
+
+#endif /* DISTANCES_NEON_H */
diff --git a/src/simd/hook.cc b/src/simd/hook.cc
index 15d7d5aa4..7239d9909 100644
--- a/src/simd/hook.cc
+++ b/src/simd/hook.cc
@@ -18,6 +18,10 @@
 
 #include "faiss/FaissHook.h"
 
+#if defined(__ARM_NEON)
+#include "distances_neon.h"
+#endif
+
 #if defined(__x86_64__)
 #include "distances_avx.h"
 #include "distances_avx512.h"
@@ -124,6 +128,22 @@ fvec_hook(std::string& simd_type) {
         simd_type = "GENERIC";
     }
 #endif
+
+#if defined(__ARM_NEON)
+    fvec_inner_product = fvec_inner_product_neon;
+    fvec_L2sqr = fvec_L2sqr_neon;
+    fvec_L1 = fvec_L1_neon;
+    fvec_Linf = fvec_Linf_neon;
+
+    fvec_norm_L2sqr = fvec_norm_L2sqr_neon;
+    fvec_L2sqr_ny = fvec_L2sqr_ny_neon;
+    fvec_inner_products_ny = fvec_inner_products_ny_neon;
+    fvec_madd = fvec_madd_neon;
+    fvec_madd_and_argmin = fvec_madd_and_argmin_neon;
+
+    simd_type = "NEON";
+
+#endif
 }
 
 static int init_hook_ = []() {