Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix WaitForTasks #67

Closed
35 changes: 24 additions & 11 deletions include/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include <atomic>
#include <barrier>
#include <concepts>
#include <deque>
#include <functional>
Expand Down Expand Up @@ -56,9 +55,15 @@ namespace dp {
// invoke the task
while (auto task = tasks_[id].tasks.pop_front()) {
try {
// decrement the unassigned tasks as the task is now going
// to be executed
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
// invoke the task
std::invoke(std::move(task.value()));
completed_tasks_.fetch_sub(1, std::memory_order_release);
// the above task can push more work onto the pool, so we
// only decrement the in-flight count once the task has been
// executed, because only then is it no longer "in flight"
in_flight_tasks_.fetch_sub(1, std::memory_order_release);
} catch (...) {
}
}
Expand All @@ -70,7 +75,7 @@ namespace dp {
// steal a task
unassigned_tasks_.fetch_sub(1, std::memory_order_release);
std::invoke(std::move(task.value()));
completed_tasks_.fetch_sub(1, std::memory_order_release);
in_flight_tasks_.fetch_sub(1, std::memory_order_release);
// stop stealing once we have invoked a stolen task
break;
}
Expand All @@ -82,8 +87,9 @@ namespace dp {
priority_queue_.rotate_to_front(id);
// check if all tasks are completed and, if so, signal any thread
// that is blocked in wait_for_tasks()
if (completed_tasks_.load(std::memory_order_acquire) == 0) {
threads_done_.release();
if (in_flight_tasks_.load(std::memory_order_acquire) == 0) {
threads_complete_signal_ = true;
threads_complete_signal_.notify_one();
}

} while (!stop_tok.stop_requested());
Expand Down Expand Up @@ -215,9 +221,9 @@ namespace dp {
* @details This function will block until all tasks have been completed.
*/
void wait_for_tasks() {
if (completed_tasks_.load(std::memory_order_acquire) > 0) {
if (in_flight_tasks_.load(std::memory_order_acquire) > 0) {
// wait for all tasks to finish
threads_done_.acquire();
threads_complete_signal_.wait(false);
}
}

Expand All @@ -230,8 +236,14 @@ namespace dp {
return;
}
auto i = *(i_opt);
unassigned_tasks_.fetch_add(1, std::memory_order_relaxed);
completed_tasks_.fetch_add(1, std::memory_order_relaxed);
unassigned_tasks_.fetch_add(1, std::memory_order_release);
const auto prev_in_flight = in_flight_tasks_.fetch_add(1, std::memory_order_release);

// reset the in flight signal if the list was previously empty
if (prev_in_flight == 0) {
threads_complete_signal_.store(false, std::memory_order_release);
}

tasks_[i].tasks.push_back(std::forward<Function>(f));
tasks_[i].signal.release();
}
Expand All @@ -244,8 +256,9 @@ namespace dp {
std::vector<ThreadType> threads_;
std::deque<task_item> tasks_;
dp::thread_safe_queue<std::size_t> priority_queue_;
std::atomic_int_fast64_t unassigned_tasks_{}, completed_tasks_{};
std::binary_semaphore threads_done_{0};
// guarantee these get zero-initialized
std::atomic_int_fast64_t unassigned_tasks_{0}, in_flight_tasks_{0};
std::atomic_bool threads_complete_signal_{false};
};

/**
Expand Down
80 changes: 80 additions & 0 deletions test/source/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,86 @@ TEST_CASE("Ensure wait_for_tasks() properly blocks current execution.") {
CHECK_EQ(counter.load(), total_tasks);
}

TEST_CASE("Ensure wait_for_tasks() properly waits for tasks to fully complete") {
    // Minimal shared tally that the pooled tasks bump concurrently.
    struct tally {
        std::atomic_int hits{0};

        void bump() { hits.fetch_add(1, std::memory_order_release); }
    };

    constexpr auto round_count = 10;
    // Each round enqueues 17 * 12 independent increment tasks.
    static constexpr int jobs_per_round = 17 * 12;

    dp::thread_pool local_pool{};
    std::array<int, round_count> observed{};  // value-initialized to all zeros

    for (int& slot : observed) {
        // The tally only lives for this round and the tasks capture it by
        // reference, so the value read after wait_for_tasks() reflects how many
        // tasks had actually finished by the time the wait returned.
        tally round_tally{};

        for (int job = 0; job < jobs_per_round; ++job) {
            local_pool.enqueue_detach([&round_tally]() { round_tally.bump(); });
        }

        local_pool.wait_for_tasks();
        slot = round_tally.hits.load(std::memory_order_acquire);
    }

    // Every round must have observed all of its increments, and so must the total.
    const auto is_full_round = [](int count) { return count == jobs_per_round; };
    const bool every_round_full = std::ranges::all_of(observed, is_full_round);
    const auto total = std::accumulate(observed.begin(), observed.end(), 0);

    CHECK_EQ(total, jobs_per_round * round_count);
    CHECK(every_round_full);
}

TEST_CASE("Ensure wait_for_tasks() can be called multiple times on the same pool") {
    // Minimal shared tally that the pooled tasks bump concurrently.
    struct tally {
        std::atomic_int hits{0};

        void bump() { hits.fetch_add(1, std::memory_order_release); }
    };

    constexpr auto round_count = 10;
    // Each round enqueues 17 * 12 independent increment tasks.
    static constexpr int jobs_per_round = 17 * 12;
    // The whole enqueue/wait/verify sequence is repeated on the same pool.
    constexpr int pass_count = 2;

    dp::thread_pool local_pool{};
    std::array<int, round_count> observed{};  // value-initialized to all zeros
    const auto is_full_round = [](int count) { return count == jobs_per_round; };

    for (int pass = 0; pass < pass_count; ++pass) {
        for (int& slot : observed) {
            // The tally only lives for this round and the tasks capture it by
            // reference, so the value read after wait_for_tasks() reflects how
            // many tasks had actually finished by the time the wait returned.
            tally round_tally{};

            for (int job = 0; job < jobs_per_round; ++job) {
                local_pool.enqueue_detach([&round_tally]() { round_tally.bump(); });
            }

            local_pool.wait_for_tasks();
            // Results of a later pass overwrite those of the earlier one.
            slot = round_tally.hits.load(std::memory_order_acquire);
        }

        // Verify this pass before reusing the pool for the next one.
        const bool every_round_full = std::ranges::all_of(observed, is_full_round);
        const auto total = std::accumulate(observed.begin(), observed.end(), 0);

        CHECK_EQ(total, jobs_per_round * round_count);
        CHECK(every_round_full);
    }
}

TEST_CASE("Initialization function is called") {
std::atomic_int counter = 0;
{
Expand Down