diff --git a/cpp/bench/ann/src/common/thread_pool.hpp b/cpp/bench/ann/src/common/thread_pool.hpp new file mode 100644 index 0000000000..efea938d5b --- /dev/null +++ b/cpp/bench/ann/src/common/thread_pool.hpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +class FixedThreadPool { + public: + FixedThreadPool(int num_threads) + { + if (num_threads < 1) { + throw std::runtime_error("num_threads must >= 1"); + } else if (num_threads == 1) { + return; + } + + tasks_ = new Task_[num_threads]; + + threads_.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + threads_.emplace_back([&, i] { + auto& task = tasks_[i]; + while (true) { + std::unique_lock lock(task.mtx); + task.cv.wait(lock, + [&] { return task.has_task || finished_.load(std::memory_order_relaxed); }); + if (finished_.load(std::memory_order_relaxed)) { break; } + + task.task(); + task.has_task = false; + } + }); + } + } + + ~FixedThreadPool() + { + if (threads_.empty()) { return; } + + finished_.store(true, std::memory_order_relaxed); + for (unsigned i = 0; i < threads_.size(); ++i) { + auto& task = tasks_[i]; + std::lock_guard(task.mtx); + + task.cv.notify_one(); + threads_[i].join(); + } + + delete[] tasks_; + } + + template + void submit(Func f, IdxT len) + { + if (threads_.empty()) { + for (IdxT i = 0; i < len; ++i) { + f(i); + } + return; + } + + const int num_threads = threads_.size(); + // one extra part for competition among threads + const IdxT items_per_thread = len / (num_threads + 1); + std::atomic cnt(items_per_thread * num_threads); + + auto wrapped_f = [&](IdxT start, IdxT end) { + for (IdxT i = start; i < end; ++i) { + f(i); + } + + while (true) { + IdxT i = cnt.fetch_add(1, std::memory_order_relaxed); + if (i >= len) { break; } + f(i); + } + }; + + std::vector> futures; + futures.reserve(num_threads); + for (int i = 0; i < num_threads; ++i) { + IdxT start = i * items_per_thread; + auto& task = tasks_[i]; + { + std::lock_guard lock(task.mtx); + (void)lock; // stop nvcc warning + task.task = std::packaged_task([=] { wrapped_f(start, start + items_per_thread); }); + futures.push_back(task.task.get_future()); + task.has_task = true; + } + task.cv.notify_one(); + } + + for (auto& fut : futures) { + fut.wait(); + } + return; + } + + private: + struct alignas(64) Task_ { + std::mutex mtx; + std::condition_variable cv; + bool has_task = false; + std::packaged_task task; + }; + + Task_* tasks_; + std::vector threads_; + std::atomic finished_{false}; +}; diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp b/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp index 0552e8fa36..f11df605ee 100644 --- a/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp +++ b/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp @@ -76,6 +76,7 @@ void parse_search_param(const nlohmann::json& conf, { param.nprobe = conf.at("nprobe"); if (conf.contains("refine_ratio")) { param.refine_ratio = conf.at("refine_ratio"); } + if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); } } template class Algo> diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h index 3a78ca1724..a703fa9950 100644 --- a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h +++ b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h @@ -16,6 +16,8 @@ #pragma once #include "../common/ann_types.hpp" +#include "../common/thread_pool.hpp" + #include #include @@ -54,6 +56,7 @@ class FaissCpu : public ANN { struct SearchParam : public AnnSearchParam { int nprobe; float refine_ratio = 1.0; + int num_threads = omp_get_num_procs(); }; struct BuildParam { @@ -116,6 +119,9 @@ class FaissCpu : public ANN { faiss::MetricType metric_type_; int nlist_; double training_sample_fraction_; + + int num_threads_; + std::unique_ptr thread_pool_; }; template @@ -160,6 +166,11 @@ void FaissCpu::set_search_param(const AnnSearchParam& param) this->index_refine_ = std::make_unique(this->index_.get()); this->index_refine_.get()->k_factor = search_param.refine_ratio; } + + if (!thread_pool_ || num_threads_ != search_param.num_threads) { + num_threads_ = search_param.num_threads; + thread_pool_ = std::make_unique(num_threads_); + } } template @@ -172,7 +183,13 @@ void FaissCpu::search(const T* queries, { static_assert(sizeof(size_t) == sizeof(faiss::idx_t), "sizes of size_t and faiss::idx_t are different"); - index_->search(batch_size, queries, k, distances, reinterpret_cast(neighbors)); + + thread_pool_->submit( + [&](int i) { + // Use thread pool for batch size = 1. FAISS multi-threads internally for batch size > 1. + index_->search(batch_size, queries, k, distances, reinterpret_cast(neighbors)); + }, + 1); } template @@ -275,7 +292,14 @@ class FaissCpuFlat : public FaissCpu { } // class FaissCpu is more like a IVF class, so need special treating here - void set_search_param(const typename ANN::AnnSearchParam&) override{}; + void set_search_param(const typename ANN::AnnSearchParam& param) override + { + auto search_param = dynamic_cast::SearchParam&>(param); + if (!this->thread_pool_ || this->num_threads_ != search_param.num_threads) { + this->num_threads_ = search_param.num_threads; + this->thread_pool_ = std::make_unique(this->num_threads_); + } + }; void save(const std::string& file) const override { diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h index 4d7b993aa1..df44605493 100644 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h @@ -31,9 +31,8 @@ #include #include -#include - #include "../common/ann_types.hpp" +#include "../common/thread_pool.hpp" #include namespace raft::bench::ann { @@ -53,112 +52,6 @@ struct hnsw_dist_t { using type = int; }; -class FixedThreadPool { - public: - FixedThreadPool(int num_threads) - { - if (num_threads < 1) { - throw std::runtime_error("num_threads must >= 1"); - } else if (num_threads == 1) { - return; - } - - tasks_ = new Task_[num_threads]; - - threads_.reserve(num_threads); - for (int i = 0; i < num_threads; ++i) { - threads_.emplace_back([&, i] { - auto& task = tasks_[i]; - while (true) { - std::unique_lock lock(task.mtx); - task.cv.wait(lock, - [&] { return task.has_task || finished_.load(std::memory_order_relaxed); }); - if (finished_.load(std::memory_order_relaxed)) { break; } - - task.task(); - task.has_task = false; - } - }); - } - } - - ~FixedThreadPool() - { - if (threads_.empty()) { return; } - - finished_.store(true, std::memory_order_relaxed); - for (unsigned i = 0; i < threads_.size(); ++i) { - auto& task = tasks_[i]; - std::lock_guard(task.mtx); - - task.cv.notify_one(); - threads_[i].join(); - } - - delete[] tasks_; - } - - template - void submit(Func f, IdxT len) - { - if (threads_.empty()) { - for (IdxT i = 0; i < len; ++i) { - f(i); - } - return; - } - - const int num_threads = threads_.size(); - // one extra part for competition among threads - const IdxT items_per_thread = len / (num_threads + 1); - std::atomic cnt(items_per_thread * num_threads); - - auto wrapped_f = [&](IdxT start, IdxT end) { - for (IdxT i = start; i < end; ++i) { - f(i); - } - - while (true) { - IdxT i = cnt.fetch_add(1, std::memory_order_relaxed); - if (i >= len) { break; } - f(i); - } - }; - - std::vector> futures; - futures.reserve(num_threads); - for (int i = 0; i < num_threads; ++i) { - IdxT start = i * items_per_thread; - auto& task = tasks_[i]; - { - std::lock_guard lock(task.mtx); - (void)lock; // stop nvcc warning - task.task = std::packaged_task([=] { wrapped_f(start, start + items_per_thread); }); - futures.push_back(task.task.get_future()); - task.has_task = true; - } - task.cv.notify_one(); - } - - for (auto& fut : futures) { - fut.wait(); - } - return; - } - - private: - struct alignas(64) Task_ { - std::mutex mtx; - std::condition_variable cv; - bool has_task = false; - std::packaged_task task; - }; - - Task_* tasks_; - std::vector threads_; - std::atomic finished_{false}; -}; - template class HnswLib : public ANN { public: @@ -281,6 +174,7 @@ void HnswLib::search( { thread_pool_->submit( [&](int i) { + // hnsw can only handle a single vector at a time. get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k); }, batch_size); diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 727a6ed830..19c5151186 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -52,7 +52,7 @@ class RaftCagra : public ANN { using BuildParam = raft::neighbors::cagra::index_params; - RaftCagra(Metric metric, int dim, const BuildParam& param) + RaftCagra(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) : ANN(metric, dim), index_params_(param), dimension_(dim), diff --git a/docs/source/ann_benchmarks_param_tuning.md b/docs/source/ann_benchmarks_param_tuning.md index b70d1d788f..0faaeba59c 100644 --- a/docs/source/ann_benchmarks_param_tuning.md +++ b/docs/source/ann_benchmarks_param_tuning.md @@ -93,11 +93,16 @@ IVF-pq is an inverted-file index, which partitions the vectors into a series of | `numProbes` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. | | `refine_ratio` | `search_params` | N| Positive Number >=0 | 0 | `refine_ratio * k` nearest neighbors are queried from the index initially and an additional refinement step improves recall by selecting only the best `k` neighbors. | -### `faiss_flat` +### `faiss_cpu_flat` Use FAISS flat index on the CPU, which performs an exact search using brute-force and doesn't have any further build or search parameters. -### `faiss_ivf_flat` + +| Parameter | Type | Required | Data Type | Default | Description | +|-----------|----------------|----------|---------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `numThreads` | `search_params` | N | Positive Integer >0 | 1 | Number of threads to use for queries. | + +### `faiss_cpu_ivf_flat` Use FAISS IVF-Flat index on CPU @@ -106,8 +111,9 @@ Use FAISS IVF-Flat index on CPU | `nlists` | `build_param` | Y | Positive Integer >0 | | Number of clusters to partition the vectors into. Larger values will put less points into each cluster but this will impact index build time as more clusters need to be trained. | | `ratio` | `build_param` | N | Positive Integer >0 | 2 | `1/ratio` is the number of training points which should be used to train the clusters. | | `nprobe` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. | +| `numThreads` | `search_params` | N | Positive Integer >0 | 1 | Number of threads to use for queries. | -### `faiss_ivf_pq` +### `faiss_cpu_ivf_pq` Use FAISS IVF-PQ index on CPU @@ -120,6 +126,7 @@ Use FAISS IVF-PQ index on CPU | `bitsPerCode` | `build_param` | N | Positive Integer [4-8] | 8 | Number of bits to use for each code. | | `numProbes` | `search_params` | Y | Positive Integer >0 | | The closest number of clusters to search for each query vector. Larger values will improve recall but will search more points in the index. | | `refine_ratio` | `search_params` | N| Positive Number >=0 | 0 | `refine_ratio * k` nearest neighbors are queried from the index initially and an additional refinement step improves recall by selecting only the best `k` neighbors. | +| `numThreads` | `search_params` | N | Positive Integer >0 | 1 | Number of threads to use for queries. | ## HNSW