From 49882ec3687ead487265bcd23bfbd9bc25556669 Mon Sep 17 00:00:00 2001 From: Cai Yudong Date: Tue, 24 Sep 2024 11:29:07 +0800 Subject: [PATCH] Set scoped omp to fix IVF index build degrade (#863) Signed-off-by: Cai Yudong --- benchmark/hdf5/benchmark_binary.cpp | 2 + benchmark/hdf5/benchmark_binary_range.cpp | 2 + benchmark/hdf5/benchmark_float.cpp | 2 + benchmark/hdf5/benchmark_float_bitset.cpp | 2 + benchmark/hdf5/benchmark_float_qps.cpp | 5 +- benchmark/hdf5/benchmark_float_range.cpp | 2 + .../hdf5/benchmark_float_range_bitset.cpp | 2 + benchmark/hdf5/benchmark_knowhere.h | 3 ++ benchmark/hdf5/ref_logs/Makefile | 1 + include/knowhere/comp/thread_pool.h | 29 +++++++---- src/common/comp/brute_force.cc | 8 +-- src/common/thread/thread.cc | 15 ++++-- src/index/flat/flat.cc | 4 +- src/index/hnsw/faiss_hnsw.cc | 52 +++++++++---------- src/index/ivf/ivf.cc | 17 +++--- tests/ut/test_utils.cc | 28 ++++++---- 16 files changed, 108 insertions(+), 66 deletions(-) diff --git a/benchmark/hdf5/benchmark_binary.cpp b/benchmark/hdf5/benchmark_binary.cpp index 0b926aa75..8ee4ce787 100644 --- a/benchmark/hdf5/benchmark_binary.cpp +++ b/benchmark/hdf5/benchmark_binary.cpp @@ -105,6 +105,8 @@ class Benchmark_binary : public Benchmark_knowhere, public ::testing::Test { cfg_[knowhere::meta::METRIC_TYPE] = metric_type_; knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num); + knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num); printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold()); } diff --git a/benchmark/hdf5/benchmark_binary_range.cpp b/benchmark/hdf5/benchmark_binary_range.cpp index aa9b0348c..f2c97bce5 100644 --- a/benchmark/hdf5/benchmark_binary_range.cpp +++ b/benchmark/hdf5/benchmark_binary_range.cpp @@ -115,6 +115,8 @@ class Benchmark_binary_range : public Benchmark_knowhere, public ::testing::Test cfg_[knowhere::meta::METRIC_TYPE] = metric_type_; cfg_[knowhere::meta::RADIUS] = *gt_radius_; knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num); + knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num); printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold()); } diff --git a/benchmark/hdf5/benchmark_float.cpp b/benchmark/hdf5/benchmark_float.cpp index 34fee4a3b..52e2690c2 100644 --- a/benchmark/hdf5/benchmark_float.cpp +++ b/benchmark/hdf5/benchmark_float.cpp @@ -141,6 +141,8 @@ class Benchmark_float : public Benchmark_knowhere, public ::testing::Test { cfg_[knowhere::meta::METRIC_TYPE] = metric_type_; knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num); + knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num); printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold()); } diff --git a/benchmark/hdf5/benchmark_float_bitset.cpp b/benchmark/hdf5/benchmark_float_bitset.cpp index 58ece6fd4..0ab64f2ab 100644 --- a/benchmark/hdf5/benchmark_float_bitset.cpp +++ b/benchmark/hdf5/benchmark_float_bitset.cpp @@ -133,6 +133,8 @@ class Benchmark_float_bitset : public Benchmark_knowhere, public ::testing::Test cfg_[knowhere::meta::METRIC_TYPE] = metric_type_; knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num); + knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num); printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold()); create_golden_index(cfg_); diff --git a/benchmark/hdf5/benchmark_float_qps.cpp b/benchmark/hdf5/benchmark_float_qps.cpp index e0fb39697..fa3300db4 100644 --- a/benchmark/hdf5/benchmark_float_qps.cpp +++ b/benchmark/hdf5/benchmark_float_qps.cpp @@ -277,7 +277,10 @@ class Benchmark_float_qps : public Benchmark_knowhere, public ::testing::Test { load_hdf5_data(); cfg_[knowhere::meta::METRIC_TYPE] = metric_type_; - knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AUTO); + knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num); + knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num); + printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold()); #ifdef KNOWHERE_WITH_GPU knowhere::KnowhereConfig::InitGPUResource(GPU_DEVICE_ID, 2); cfg_[knowhere::meta::DEVICE_ID] = GPU_DEVICE_ID; diff --git a/benchmark/hdf5/benchmark_float_range.cpp b/benchmark/hdf5/benchmark_float_range.cpp index c716a8af7..1ec62a333 100644 --- a/benchmark/hdf5/benchmark_float_range.cpp +++ b/benchmark/hdf5/benchmark_float_range.cpp @@ -153,6 +153,8 @@ class Benchmark_float_range : public Benchmark_knowhere, public ::testing::Test cfg_[knowhere::meta::METRIC_TYPE] = metric_type_; cfg_[knowhere::meta::RADIUS] = *gt_radius_; knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num); + knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num); printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold()); } diff --git a/benchmark/hdf5/benchmark_float_range_bitset.cpp b/benchmark/hdf5/benchmark_float_range_bitset.cpp index 693303792..6a6216ec9 100644 --- a/benchmark/hdf5/benchmark_float_range_bitset.cpp +++ b/benchmark/hdf5/benchmark_float_range_bitset.cpp @@ -140,6 +140,8 @@ class Benchmark_float_range_bitset : public Benchmark_knowhere, public ::testing cfg_[knowhere::meta::METRIC_TYPE] = metric_type_; cfg_[knowhere::meta::RADIUS] = *gt_radius_; knowhere::KnowhereConfig::SetSimdType(knowhere::KnowhereConfig::SimdType::AVX2); + knowhere::KnowhereConfig::SetBuildThreadPoolSize(default_build_thread_num); + knowhere::KnowhereConfig::SetSearchThreadPoolSize(default_search_thread_num); printf("faiss::distance_compute_blas_threshold: %ld\n", knowhere::KnowhereConfig::GetBlasThreshold()); create_golden_index(cfg_); diff --git a/benchmark/hdf5/benchmark_knowhere.h b/benchmark/hdf5/benchmark_knowhere.h index 1ebaef45a..adb17fa72 100644 --- a/benchmark/hdf5/benchmark_knowhere.h +++ b/benchmark/hdf5/benchmark_knowhere.h @@ -27,6 +27,9 @@ #include "knowhere/index/index_factory.h" #include "knowhere/version.h" +static const size_t default_build_thread_num = 8; +static const size_t default_search_thread_num = 8; + namespace fs = std::filesystem; std::string kDir = fs::current_path().string() + "/diskann_test"; std::string kRawDataPath = kDir + "/raw_data"; diff --git a/benchmark/hdf5/ref_logs/Makefile b/benchmark/hdf5/ref_logs/Makefile index 3258317f2..1b94f4d78 100644 --- a/benchmark/hdf5/ref_logs/Makefile +++ b/benchmark/hdf5/ref_logs/Makefile @@ -24,6 +24,7 @@ test_binary_range_hnsw: # Test Knowhere float index test_float: test_float_idmap test_float_ivf_flat test_float_ivf_sq8 test_float_ivf_pq test_float_hnsw test_float_diskann test_float_gpu: test_float_ivf_flat test_float_ivf_pq +test_float_ivf: test_float_ivf_flat test_float_ivf_pq test_float_idmap: ./benchmark_float --gtest_filter="Benchmark_float.TEST_IDMAP" | tee test_float_idmap.log diff --git a/include/knowhere/comp/thread_pool.h b/include/knowhere/comp/thread_pool.h index 65aab700e..6b2e0f9e5 100644 --- a/include/knowhere/comp/thread_pool.h +++ b/include/knowhere/comp/thread_pool.h @@ -245,26 +245,22 @@ class ThreadPool { return search_pool_; } - class ScopedOmpSetter { + class ScopedBuildOmpSetter { int omp_before; #ifdef OPENBLAS_OS_LINUX int blas_thread_before; #endif public: - explicit ScopedOmpSetter(int num_threads = 1) { - if (num_threads <= 0) { - return; - } - + explicit ScopedBuildOmpSetter(int num_threads = 0) { omp_before = (build_pool_ ? build_pool_->size() : omp_get_max_threads()); #ifdef OPENBLAS_OS_LINUX + // to avoid thread spawn when IVF_PQ build blas_thread_before = openblas_get_num_threads(); - openblas_set_num_threads(num_threads); + openblas_set_num_threads(1); #endif - - omp_set_num_threads(num_threads); + omp_set_num_threads(num_threads <= 0 ? omp_before : num_threads); } - ~ScopedOmpSetter() { + ~ScopedBuildOmpSetter() { omp_set_num_threads(omp_before); #ifdef OPENBLAS_OS_LINUX openblas_set_num_threads(blas_thread_before); @@ -272,6 +268,19 @@ class ThreadPool { } }; + class ScopedSearchOmpSetter { + int omp_before; + + public: + explicit ScopedSearchOmpSetter(int num_threads = 1) { + omp_before = (search_pool_ ? search_pool_->size() : omp_get_max_threads()); + omp_set_num_threads(num_threads <= 0 ? omp_before : num_threads); + } + ~ScopedSearchOmpSetter() { + omp_set_num_threads(omp_before); + } + }; + private: folly::CPUThreadPoolExecutor pool_; diff --git a/src/common/comp/brute_force.cc b/src/common/comp/brute_force.cc index 9cdd27b5b..647792620 100644 --- a/src/common/comp/brute_force.cc +++ b/src/common/comp/brute_force.cc @@ -117,7 +117,7 @@ BruteForce::Search(const DataSetPtr base_dataset, const DataSetPtr query_dataset futs.reserve(nq); for (int i = 0; i < nq; ++i) { futs.emplace_back(pool->push([&, index = i, labels_ptr = labels.get(), distances_ptr = distances.get()] { - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); auto cur_labels = labels_ptr + topk * index; auto cur_distances = distances_ptr + topk * index; @@ -244,7 +244,7 @@ BruteForce::SearchWithBuf(const DataSetPtr base_dataset, const DataSetPtr query_ futs.reserve(nq); for (int i = 0; i < nq; ++i) { futs.emplace_back(pool->push([&, index = i] { - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); auto cur_labels = labels + topk * index; auto cur_distances = distances + topk * index; @@ -420,7 +420,7 @@ BruteForce::RangeSearch(const DataSetPtr base_dataset, const DataSetPtr query_da return Status::success; } // else not sparse: - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); faiss::RangeSearchResult res(1); BitsetViewIDSelector bw_idselector(bitset); @@ -667,7 +667,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da for (int i = 0; i < nq; ++i) { futs.emplace_back(pool->push([&, index = i] { - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); BitsetViewIDSelector bw_idselector(bitset); faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector; diff --git a/src/common/thread/thread.cc b/src/common/thread/thread.cc index 0e41e039a..fa1709d96 100644 --- a/src/common/thread/thread.cc +++ b/src/common/thread/thread.cc @@ -29,7 +29,7 @@ ExecOverSearchThreadPool(std::vector>& tasks) { futures.reserve(tasks.size()); for (auto&& t : tasks) { futures.emplace_back(pool->push([&t]() { - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); t(); })); } @@ -44,7 +44,7 @@ ExecOverBuildThreadPool(std::vector>& tasks) { futures.reserve(tasks.size()); for (auto&& t : tasks) { futures.emplace_back(pool->push([&t]() { - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedBuildOmpSetter setter(1); t(); })); } @@ -72,9 +72,14 @@ GetBuildThreadPoolSize() { return ThreadPool::GetGlobalBuildThreadPool()->size(); } -std::unique_ptr -CreateScopeOmpSetter(int num_threads) { - return std::make_unique(num_threads); +std::unique_ptr +CreateScopeBuildOmpSetter(int num_threads) { + return std::make_unique(num_threads); +} + +std::unique_ptr +CreateScopeSearchOmpSetter(int num_threads) { + return std::make_unique(num_threads); } } // namespace knowhere diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index ac81815da..f1e3ecde6 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -91,7 +91,7 @@ class FlatIndexNode : public IndexNode { futs.reserve(nq); for (int i = 0; i < nq; ++i) { futs.emplace_back(search_pool_->push([&, index = i] { - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); auto cur_ids = ids + k * index; auto cur_dis = distances + k * index; @@ -167,7 +167,7 @@ class FlatIndexNode : public IndexNode { futs.reserve(nq); for (int i = 0; i < nq; ++i) { futs.emplace_back(search_pool_->push([&, index = i] { - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); faiss::RangeSearchResult res(1); BitsetViewIDSelector bw_idselector(bitset); diff --git a/src/index/hnsw/faiss_hnsw.cc b/src/index/hnsw/faiss_hnsw.cc index 2a1cfb80b..f31cf260a 100644 --- a/src/index/hnsw/faiss_hnsw.cc +++ b/src/index/hnsw/faiss_hnsw.cc @@ -80,19 +80,18 @@ class BaseFaissIndexNode : public IndexNode { // use build_pool_ to make sure the OMP threads spawned by index_->train etc // can inherit the low nice value of threads in build_pool_. - auto tryObj = build_pool - ->push([&] { - std::unique_ptr setter; - if (base_cfg.num_build_thread.has_value()) { - setter = - std::make_unique(base_cfg.num_build_thread.value()); - } else { - setter = std::make_unique(); - } - - return TrainInternal(dataset, *cfg); - }) - .getTry(); + auto tryObj = + build_pool + ->push([&] { + std::unique_ptr setter; + if (base_cfg.num_build_thread.has_value()) { + setter = std::make_unique(base_cfg.num_build_thread.value()); + } else { + setter = std::make_unique(); + } + return TrainInternal(dataset, *cfg); + }) + .getTry(); if (!tryObj.hasValue()) { LOG_KNOWHERE_WARNING_ << "faiss internal error: " << tryObj.exception().what(); @@ -108,19 +107,18 @@ class BaseFaissIndexNode : public IndexNode { // use build_pool_ to make sure the OMP threads spawned by index_->train etc // can inherit the low nice value of threads in build_pool_. - auto tryObj = build_pool - ->push([&] { - std::unique_ptr setter; - if (base_cfg.num_build_thread.has_value()) { - setter = - std::make_unique(base_cfg.num_build_thread.value()); - } else { - setter = std::make_unique(); - } - - return AddInternal(dataset, *cfg); - }) - .getTry(); + auto tryObj = + build_pool + ->push([&] { + std::unique_ptr setter; + if (base_cfg.num_build_thread.has_value()) { + setter = std::make_unique(base_cfg.num_build_thread.value()); + } else { + setter = std::make_unique(); + } + return AddInternal(dataset, *cfg); + }) + .getTry(); if (!tryObj.hasValue()) { LOG_KNOWHERE_WARNING_ << "faiss internal error: " << tryObj.exception().what(); @@ -869,7 +867,7 @@ class BaseFaissRegularIndexHNSWNode : public BaseFaissRegularIndexNode { for (int64_t i = 0; i < rows; ++i) { futs.emplace_back(search_pool->push([&, idx = i] { // 1 thread per element - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); // set up a query // const float* cur_query = (const float*)data + idx * dim; diff --git a/src/index/ivf/ivf.cc b/src/index/ivf/ivf.cc index b008d75b2..132511d6c 100644 --- a/src/index/ivf/ivf.cc +++ b/src/index/ivf/ivf.cc @@ -414,11 +414,11 @@ template Status IvfIndexNode::TrainInternal(const DataSetPtr dataset, std::shared_ptr cfg) { const BaseConfig& base_cfg = static_cast(*cfg); - std::unique_ptr setter; + std::unique_ptr setter; if (base_cfg.num_build_thread.has_value()) { - setter = std::make_unique(base_cfg.num_build_thread.value()); + setter = std::make_unique(base_cfg.num_build_thread.value()); } else { - setter = std::make_unique(); + setter = std::make_unique(); } bool is_cosine = IsMetricType(base_cfg.metric_type.value(), knowhere::metric::COSINE); @@ -627,11 +627,12 @@ IvfIndexNode::Add(const DataSetPtr dataset, std::shared_ptr // can inherit the low nice value of threads in build_pool_. auto tryObj = build_pool_ ->push([&] { - std::unique_ptr setter; + std::unique_ptr setter; if (base_cfg.num_build_thread.has_value()) { - setter = std::make_unique(base_cfg.num_build_thread.value()); + setter = + std::make_unique(base_cfg.num_build_thread.value()); } else { - setter = std::make_unique(); + setter = std::make_unique(); } if constexpr (std::is_same::value) { index_->add(rows, (const uint8_t*)data); @@ -677,7 +678,7 @@ IvfIndexNode::Search(const DataSetPtr dataset, std::unique_ futs.reserve(rows); for (int i = 0; i < rows; ++i) { futs.emplace_back(search_pool_->push([&, index = i] { - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); auto offset = k * index; std::unique_ptr copied_query = nullptr; @@ -802,7 +803,7 @@ IvfIndexNode::RangeSearch(const DataSetPtr dataset, std::un futs.reserve(nq); for (int i = 0; i < nq; ++i) { futs.emplace_back(search_pool_->push([&, index = i] { - ThreadPool::ScopedOmpSetter setter(1); + ThreadPool::ScopedSearchOmpSetter setter(1); faiss::RangeSearchResult res(1); std::unique_ptr copied_query = nullptr; diff --git a/tests/ut/test_utils.cc b/tests/ut/test_utils.cc index 0259e4121..939a31d2a 100644 --- a/tests/ut/test_utils.cc +++ b/tests/ut/test_utils.cc @@ -204,18 +204,28 @@ TEST_CASE("Test ThreadPool") { } } - SECTION("ScopedOmpSetter") { - int prev_num_threads = omp_get_max_threads(); + SECTION("ScopedBuildOmpSetter") { + int prev_num_threads = knowhere::ThreadPool::GetGlobalBuildThreadPoolSize(); { int target_num_threads = (prev_num_threads / 2) > 0 ? (prev_num_threads / 2) : 1; - knowhere::ThreadPool::ScopedOmpSetter setter(target_num_threads); - auto thread_num = omp_get_max_threads(); - REQUIRE(thread_num == target_num_threads); -#ifdef OPENBLAS_OS_LINUX - auto openblas_thread_num = openblas_get_num_threads(); - REQUIRE(openblas_thread_num == target_num_threads); -#endif + knowhere::ThreadPool::ScopedBuildOmpSetter setter(target_num_threads); + auto thread_num_1 = omp_get_max_threads(); + REQUIRE(thread_num_1 == target_num_threads); + } + auto thread_num_2 = omp_get_max_threads(); + REQUIRE(thread_num_2 == prev_num_threads); + } + + SECTION("ScopedSearchOmpSetter") { + int prev_num_threads = knowhere::ThreadPool::GetGlobalSearchThreadPoolSize(); + { + int target_num_threads = (prev_num_threads / 2) > 0 ? (prev_num_threads / 2) : 1; + knowhere::ThreadPool::ScopedSearchOmpSetter setter(target_num_threads); + auto thread_num_1 = omp_get_max_threads(); + REQUIRE(thread_num_1 == target_num_threads); } + auto thread_num_2 = omp_get_max_threads(); + REQUIRE(thread_num_2 == prev_num_threads); } }