Fix issues with wait_for_tasks() #68

Merged · 15 commits · Jul 5, 2024
Changes from all commits
59 changes: 42 additions & 17 deletions include/thread_pool/thread_pool.h
@@ -1,7 +1,6 @@
 #pragma once
 
 #include <atomic>
-#include <barrier>
 #include <concepts>
 #include <deque>
 #include <functional>
@@ -47,20 +46,29 @@ namespace dp {
                 try {
                     threads_.emplace_back([&, id = current_id,
                                            init](const std::stop_token &stop_tok) {
-                        init(id);
+                        // invoke the init function on the thread
+                        try {
+                            std::invoke(init, id);
+                        } catch (...) {
+                            // suppress exceptions
+                        }
 
                         do {
                             // wait until signaled
                             tasks_[id].signal.acquire();
 
                             do {
                                 // invoke the task
                                 while (auto task = tasks_[id].tasks.pop_front()) {
-                                    try {
-                                        unassigned_tasks_.fetch_sub(1, std::memory_order_release);
-                                        std::invoke(std::move(task.value()));
-                                        completed_tasks_.fetch_sub(1, std::memory_order_release);
-                                    } catch (...) {
-                                    }
+                                    // decrement the unassigned tasks as the task is now going
+                                    // to be executed
+                                    unassigned_tasks_.fetch_sub(1, std::memory_order_release);
+                                    // invoke the task
+                                    std::invoke(std::move(task.value()));
+                                    // the above task can push more work onto the pool, so we
+                                    // only decrement the in flights once the task has been
+                                    // executed because now it's no longer "in flight"
+                                    in_flight_tasks_.fetch_sub(1, std::memory_order_release);
                                 }
 
                                 // try to steal a task
@@ -70,7 +78,7 @@
                                     // steal a task
                                     unassigned_tasks_.fetch_sub(1, std::memory_order_release);
                                     std::invoke(std::move(task.value()));
-                                    completed_tasks_.fetch_sub(1, std::memory_order_release);
+                                    in_flight_tasks_.fetch_sub(1, std::memory_order_release);
                                     // stop stealing once we have invoked a stolen task
                                     break;
                                 }
@@ -82,8 +90,9 @@
                             priority_queue_.rotate_to_front(id);
                             // check if all tasks are completed and release the barrier (binary
                             // semaphore)
-                            if (completed_tasks_.load(std::memory_order_acquire) == 0) {
-                                threads_done_.release();
+                            if (in_flight_tasks_.load(std::memory_order_acquire) == 0) {
+                                threads_complete_signal_.store(true, std::memory_order_release);
+                                threads_complete_signal_.notify_one();
                             }
 
                         } while (!stop_tok.stop_requested());
@@ -214,16 +223,21 @@
                 }));
             }
 
+            /**
+             * @brief Returns the number of threads in the pool.
+             *
+             * @return std::size_t The number of threads in the pool.
+             */
             [[nodiscard]] auto size() const { return threads_.size(); }
 
             /**
              * @brief Wait for all tasks to finish.
              * @details This function will block until all tasks have been completed.
              */
             void wait_for_tasks() {
-                if (completed_tasks_.load(std::memory_order_acquire) > 0) {
+                if (in_flight_tasks_.load(std::memory_order_acquire) > 0) {
                     // wait for all tasks to finish
-                    threads_done_.acquire();
+                    threads_complete_signal_.wait(false);
                 }
             }
 
@@ -235,9 +249,19 @@
                 // would only be a problem if there are zero threads
                 return;
             }
+            // get the index
             auto i = *(i_opt);
-            unassigned_tasks_.fetch_add(1, std::memory_order_relaxed);
-            completed_tasks_.fetch_add(1, std::memory_order_relaxed);
+
+            // increment the unassigned tasks and in flight tasks
+            unassigned_tasks_.fetch_add(1, std::memory_order_release);
+            const auto prev_in_flight = in_flight_tasks_.fetch_add(1, std::memory_order_release);
+
+            // reset the in flight signal if the list was previously empty
+            if (prev_in_flight == 0) {
+                threads_complete_signal_.store(false, std::memory_order_release);
+            }
 
+            // assign work
             tasks_[i].tasks.push_back(std::forward<Function>(f));
             tasks_[i].signal.release();
         }
@@ -250,8 +274,9 @@
             std::vector<ThreadType> threads_;
             std::deque<task_item> tasks_;
             dp::thread_safe_queue<std::size_t> priority_queue_;
-            std::atomic_int_fast64_t unassigned_tasks_{}, completed_tasks_{};
-            std::binary_semaphore threads_done_{0};
+            // guarantee these get zero-initialized
+            std::atomic_int_fast64_t unassigned_tasks_{0}, in_flight_tasks_{0};
+            std::atomic_bool threads_complete_signal_{false};
         };
 
         /**
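The core of the fix is visible in the hunks above: the one-shot std::binary_semaphore is replaced by an std::atomic_bool that enqueue_task() resets and that a worker sets and notifies once in_flight_tasks_ drops to zero, so wait_for_tasks() blocks via C++20 atomic wait/notify. Below is a minimal standalone sketch of that handshake with illustrative names only — none of it is dp::thread_pool API, and it checks the count on each decrement rather than after draining a queue as the pool does:

    #include <atomic>
    #include <cstdint>
    #include <thread>

    int main() {
        std::atomic<std::int_fast64_t> in_flight{3};  // tasks not yet finished
        std::atomic_bool done{false};                 // completion signal

        std::thread worker([&] {
            for (int i = 0; i < 3; ++i) {
                // ... run one task ...
                // fetch_sub returns the previous value; 1 means this was the last task
                if (in_flight.fetch_sub(1, std::memory_order_release) == 1) {
                    done.store(true, std::memory_order_release);
                    done.notify_one();  // wake the thread blocked in wait()
                }
            }
        });

        done.wait(false);  // blocks while done == false, like wait_for_tasks()
        worker.join();
    }

Unlike the binary semaphore, the flag can be stored back to false on the next enqueue and waited on again, which is what makes repeated wait_for_tasks() calls on the same pool work.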
81 changes: 81 additions & 0 deletions test/source/thread_pool.cpp
@@ -4,6 +4,7 @@
 #include <thread_pool/version.h>
 
 #include <algorithm>
+#include <array>
 #include <iostream>
 #include <numeric>
 #include <random>
@@ -469,6 +470,86 @@ TEST_CASE("Ensure wait_for_tasks() properly blocks current execution.") {
     CHECK_EQ(counter.load(), total_tasks);
 }
 
TEST_CASE("Ensure wait_for_tasks() properly waits for tasks to fully complete") {
class counter_wrapper {
public:
std::atomic_int counter = 0;

void increment_counter() { counter.fetch_add(1, std::memory_order_release); }
};

dp::thread_pool local_pool{};
constexpr auto task_count = 10;
std::array<int, task_count> counts{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
for (size_t i = 0; i < task_count; i++) {
counter_wrapper cnt_wrp{};

for (size_t var1 = 0; var1 < 17; var1++) {
for (int var2 = 0; var2 < 12; var2++) {
local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); });
}
}
local_pool.wait_for_tasks();
// std::cout << cnt_wrp.counter << std::endl;
counts[i] = cnt_wrp.counter.load(std::memory_order_acquire);
}

auto all_correct_count =
std::ranges::all_of(counts, [](int count) { return count == 17 * 12; });
const auto sum = std::accumulate(counts.begin(), counts.end(), 0);
CHECK_EQ(sum, 17 * 12 * task_count);
CHECK(all_correct_count);
}

TEST_CASE("Ensure wait_for_tasks() can be called multiple times on the same pool") {
class counter_wrapper {
public:
std::atomic_int counter = 0;

void increment_counter() { counter.fetch_add(1, std::memory_order_release); }
};

dp::thread_pool local_pool{};
constexpr auto task_count = 10;
std::array<int, task_count> counts{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}};
for (size_t i = 0; i < task_count; i++) {
counter_wrapper cnt_wrp{};

for (size_t var1 = 0; var1 < 16; var1++) {
for (int var2 = 0; var2 < 13; var2++) {
local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); });
}
}
local_pool.wait_for_tasks();
// std::cout << cnt_wrp.counter << std::endl;
counts[i] = cnt_wrp.counter.load(std::memory_order_acquire);
}

auto all_correct_count =
std::ranges::all_of(counts, [](int count) { return count == 16 * 13; });
auto sum = std::accumulate(counts.begin(), counts.end(), 0);
CHECK_EQ(sum, 16 * 13 * task_count);
CHECK(all_correct_count);

for (size_t i = 0; i < task_count; i++) {
counter_wrapper cnt_wrp{};

for (size_t var1 = 0; var1 < 17; var1++) {
for (int var2 = 0; var2 < 12; var2++) {
local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); });
}
}
local_pool.wait_for_tasks();
// std::cout << cnt_wrp.counter << std::endl;
counts[i] = cnt_wrp.counter.load(std::memory_order_acquire);
}

all_correct_count = std::ranges::all_of(counts, [](int count) { return count == 17 * 12; });
sum = std::accumulate(counts.begin(), counts.end(), 0);
CHECK_EQ(sum, 17 * 12 * task_count);
CHECK(all_correct_count);
}

TEST_CASE("Initialization function is called") {
std::atomic_int counter = 0;
{
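Taken together, the new tests exercise the pattern this PR is meant to guarantee: enqueue a batch, wait, observe every side effect, then repeat on the same pool. A condensed usage sketch along the same lines — the pool calls match the API shown in the diff, while the batch size and counter are illustrative:

    #include <atomic>
    #include <thread_pool/thread_pool.h>

    int main() {
        dp::thread_pool pool{};
        std::atomic_int counter{0};

        for (int round = 0; round < 2; ++round) {
            counter.store(0);
            for (int i = 0; i < 100; ++i) {
                pool.enqueue_detach([&counter]() { counter.fetch_add(1); });
            }
            // with this PR, wait_for_tasks() returns only after all 100 tasks
            // have finished, and can safely be called again next round
            pool.wait_for_tasks();
            // counter.load() == 100 here
        }
    }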