Skip to content

Commit

Permalink
feat: add wait_for_tasks() (#62)
Browse files Browse the repository at this point in the history
Added `wait_for_tasks()` feature, building on top of the work done in #61. This simply refactors some of that code into a public method that users can call to block the current thread and wait for all tasks to complete.
  • Loading branch information
DeveloperPaul123 authored Apr 25, 2024
1 parent 97a1329 commit 9e94e28
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/style.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:

- name: Install format dependencies
run: |
choco install llvm --version 16.0.1 -y
choco install llvm --version 18.1.3 -y
choco install ninja -y
pip3 install cmake_format==0.6.11 pyyaml
Expand Down
12 changes: 8 additions & 4 deletions include/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,7 @@ namespace dp {
}

~thread_pool() {
if (completed_tasks_.load(std::memory_order_acquire) > 0) {
// wait for all tasks to finish
threads_done_.acquire();
}
wait_for_tasks();

// stop all threads
for (std::size_t i = 0; i < threads_.size(); ++i) {
Expand Down Expand Up @@ -207,6 +204,13 @@ namespace dp {

[[nodiscard]] auto size() const { return threads_.size(); }

void wait_for_tasks() {
if (completed_tasks_.load(std::memory_order_acquire) > 0) {
// wait for all tasks to finish
threads_done_.acquire();
}
}

private:
template <typename Function>
void enqueue_task(Function &&f) {
Expand Down
22 changes: 22 additions & 0 deletions test/source/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,25 @@ TEST_CASE("Test premature exit") {
CHECK_NE(spawn_task_id, task_3_id);
CHECK_NE(task_1_id, task_2_id);
}

TEST_CASE("Ensure wait_for_tasks() properly blocks current execution.") {
std::atomic counter = 0;
int total_tasks{};
constexpr auto thread_count = 4;

SUBCASE("with tasks") { total_tasks = 30; }
SUBCASE("with no tasks") { total_tasks = 0; }
SUBCASE("with task count less than thread count") { total_tasks = thread_count / 2; }

dp::thread_pool pool(thread_count);
for (auto i = 0; i < total_tasks; i++) {
auto task = [i, &counter]() {
std::this_thread::sleep_for(std::chrono::milliseconds((i + 1) * 10));
++counter;
};
pool.enqueue_detach(task);
}
pool.wait_for_tasks();

CHECK_EQ(counter.load(), total_tasks);
}

0 comments on commit 9e94e28

Please sign in to comment.