diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index dbbe331..10566a7 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -25,7 +25,7 @@ jobs: - name: Install format dependencies run: | - choco install llvm --version 16.0.1 -y + choco install llvm --version 18.1.3 -y choco install ninja -y pip3 install cmake_format==0.6.11 pyyaml diff --git a/include/thread_pool/thread_pool.h b/include/thread_pool/thread_pool.h index 6a24a9f..8a8fe22 100644 --- a/include/thread_pool/thread_pool.h +++ b/include/thread_pool/thread_pool.h @@ -98,10 +98,7 @@ namespace dp { } ~thread_pool() { - if (completed_tasks_.load(std::memory_order_acquire) > 0) { - // wait for all tasks to finish - threads_done_.acquire(); - } + wait_for_tasks(); // stop all threads for (std::size_t i = 0; i < threads_.size(); ++i) { @@ -207,6 +204,13 @@ namespace dp { [[nodiscard]] auto size() const { return threads_.size(); } + void wait_for_tasks() { + if (completed_tasks_.load(std::memory_order_acquire) > 0) { + // wait for all tasks to finish + threads_done_.acquire(); + } + } + private: template void enqueue_task(Function &&f) { diff --git a/test/source/thread_pool.cpp b/test/source/thread_pool.cpp index 0c3f92d..9fe8f4f 100644 --- a/test/source/thread_pool.cpp +++ b/test/source/thread_pool.cpp @@ -423,3 +423,25 @@ TEST_CASE("Test premature exit") { CHECK_NE(spawn_task_id, task_3_id); CHECK_NE(task_1_id, task_2_id); } + +TEST_CASE("Ensure wait_for_tasks() properly blocks current execution.") { + std::atomic counter = 0; + int total_tasks{}; + constexpr auto thread_count = 4; + + SUBCASE("with tasks") { total_tasks = 30; } + SUBCASE("with no tasks") { total_tasks = 0; } + SUBCASE("with task count less than thread count") { total_tasks = thread_count / 2; } + + dp::thread_pool pool(thread_count); + for (auto i = 0; i < total_tasks; i++) { + auto task = [i, &counter]() { + std::this_thread::sleep_for(std::chrono::milliseconds((i + 1) * 10)); + ++counter; + }; + pool.enqueue_detach(task); + } + pool.wait_for_tasks(); + + CHECK_EQ(counter.load(), total_tasks); +}