diff --git a/include/thread_pool/thread_pool.h b/include/thread_pool/thread_pool.h index 8ff6e3e..48b00e6 100644 --- a/include/thread_pool/thread_pool.h +++ b/include/thread_pool/thread_pool.h @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include @@ -47,7 +46,13 @@ namespace dp { try { threads_.emplace_back([&, id = current_id, init](const std::stop_token &stop_tok) { - init(id); + // invoke the init function on the thread + try { + std::invoke(init, id); + } catch (...) { + // suppress exceptions + } + do { // wait until signaled tasks_[id].signal.acquire(); @@ -55,12 +60,15 @@ namespace dp { do { // invoke the task while (auto task = tasks_[id].tasks.pop_front()) { - try { - unassigned_tasks_.fetch_sub(1, std::memory_order_release); - std::invoke(std::move(task.value())); - completed_tasks_.fetch_sub(1, std::memory_order_release); - } catch (...) { - } + // decrement the unassigned tasks as the task is now going + // to be executed + unassigned_tasks_.fetch_sub(1, std::memory_order_release); + // invoke the task + std::invoke(std::move(task.value())); + // the above task can push more work onto the pool, so we + // only decrement the in flights once the task has been + // executed because now it's no longer "in flight" + in_flight_tasks_.fetch_sub(1, std::memory_order_release); } // try to steal a task @@ -70,7 +78,7 @@ namespace dp { // steal a task unassigned_tasks_.fetch_sub(1, std::memory_order_release); std::invoke(std::move(task.value())); - completed_tasks_.fetch_sub(1, std::memory_order_release); + in_flight_tasks_.fetch_sub(1, std::memory_order_release); // stop stealing once we have invoked a stolen task break; } @@ -82,8 +90,9 @@ namespace dp { priority_queue_.rotate_to_front(id); // check if all tasks are completed and release the barrier (binary // semaphore) - if (completed_tasks_.load(std::memory_order_acquire) == 0) { - threads_done_.release(); + if (in_flight_tasks_.load(std::memory_order_acquire) == 0) { + 
threads_complete_signal_.store(true, std::memory_order_release); + threads_complete_signal_.notify_one(); } } while (!stop_tok.stop_requested()); @@ -214,6 +223,11 @@ namespace dp { })); } + /** + * @brief Returns the number of threads in the pool. + * + * @return std::size_t The number of threads in the pool. + */ [[nodiscard]] auto size() const { return threads_.size(); } /** @@ -221,9 +235,9 @@ namespace dp { * @details This function will block until all tasks have been completed. */ void wait_for_tasks() { - if (completed_tasks_.load(std::memory_order_acquire) > 0) { + if (in_flight_tasks_.load(std::memory_order_acquire) > 0) { // wait for all tasks to finish - threads_done_.acquire(); + threads_complete_signal_.wait(false); } } @@ -235,9 +249,19 @@ namespace dp { // would only be a problem if there are zero threads return; } + // get the index auto i = *(i_opt); - unassigned_tasks_.fetch_add(1, std::memory_order_relaxed); - completed_tasks_.fetch_add(1, std::memory_order_relaxed); + + // increment the unassigned tasks and in flight tasks + unassigned_tasks_.fetch_add(1, std::memory_order_release); + const auto prev_in_flight = in_flight_tasks_.fetch_add(1, std::memory_order_release); + + // reset the in flight signal if the list was previously empty + if (prev_in_flight == 0) { + threads_complete_signal_.store(false, std::memory_order_release); + } + + // assign work tasks_[i].tasks.push_back(std::forward(f)); tasks_[i].signal.release(); } @@ -250,8 +274,9 @@ namespace dp { std::vector threads_; std::deque tasks_; dp::thread_safe_queue priority_queue_; - std::atomic_int_fast64_t unassigned_tasks_{}, completed_tasks_{}; - std::binary_semaphore threads_done_{0}; + // guarantee these get zero-initialized + std::atomic_int_fast64_t unassigned_tasks_{0}, in_flight_tasks_{0}; + std::atomic_bool threads_complete_signal_{false}; }; /** diff --git a/test/source/thread_pool.cpp b/test/source/thread_pool.cpp index aeb6b0b..dc16faa 100644 --- a/test/source/thread_pool.cpp 
+++ b/test/source/thread_pool.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -469,6 +470,86 @@ TEST_CASE("Ensure wait_for_tasks() properly blocks current execution.") { CHECK_EQ(counter.load(), total_tasks); } +TEST_CASE("Ensure wait_for_tasks() properly waits for tasks to fully complete") { + class counter_wrapper { + public: + std::atomic_int counter = 0; + + void increment_counter() { counter.fetch_add(1, std::memory_order_release); } + }; + + dp::thread_pool local_pool{}; + constexpr auto task_count = 10; + std::array counts{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; + for (size_t i = 0; i < task_count; i++) { + counter_wrapper cnt_wrp{}; + + for (size_t var1 = 0; var1 < 17; var1++) { + for (int var2 = 0; var2 < 12; var2++) { + local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); }); + } + } + local_pool.wait_for_tasks(); + // std::cout << cnt_wrp.counter << std::endl; + counts[i] = cnt_wrp.counter.load(std::memory_order_acquire); + } + + auto all_correct_count = + std::ranges::all_of(counts, [](int count) { return count == 17 * 12; }); + const auto sum = std::accumulate(counts.begin(), counts.end(), 0); + CHECK_EQ(sum, 17 * 12 * task_count); + CHECK(all_correct_count); +} + +TEST_CASE("Ensure wait_for_tasks() can be called multiple times on the same pool") { + class counter_wrapper { + public: + std::atomic_int counter = 0; + + void increment_counter() { counter.fetch_add(1, std::memory_order_release); } + }; + + dp::thread_pool local_pool{}; + constexpr auto task_count = 10; + std::array counts{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; + for (size_t i = 0; i < task_count; i++) { + counter_wrapper cnt_wrp{}; + + for (size_t var1 = 0; var1 < 16; var1++) { + for (int var2 = 0; var2 < 13; var2++) { + local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); }); + } + } + local_pool.wait_for_tasks(); + // std::cout << cnt_wrp.counter << std::endl; + counts[i] = cnt_wrp.counter.load(std::memory_order_acquire); + } + + auto 
all_correct_count = + std::ranges::all_of(counts, [](int count) { return count == 16 * 13; }); + auto sum = std::accumulate(counts.begin(), counts.end(), 0); + CHECK_EQ(sum, 16 * 13 * task_count); + CHECK(all_correct_count); + + for (size_t i = 0; i < task_count; i++) { + counter_wrapper cnt_wrp{}; + + for (size_t var1 = 0; var1 < 17; var1++) { + for (int var2 = 0; var2 < 12; var2++) { + local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); }); + } + } + local_pool.wait_for_tasks(); + // std::cout << cnt_wrp.counter << std::endl; + counts[i] = cnt_wrp.counter.load(std::memory_order_acquire); + } + + all_correct_count = std::ranges::all_of(counts, [](int count) { return count == 17 * 12; }); + sum = std::accumulate(counts.begin(), counts.end(), 0); + CHECK_EQ(sum, 17 * 12 * task_count); + CHECK(all_correct_count); +} + TEST_CASE("Initialization function is called") { std::atomic_int counter = 0; {