diff --git a/test/source/thread_pool.cpp b/test/source/thread_pool.cpp index 23bdb1b..fc67998 100644 --- a/test/source/thread_pool.cpp +++ b/test/source/thread_pool.cpp @@ -477,6 +477,55 @@ TEST_CASE("Ensure wait_for_tasks() properly waits for tasks to fully complete") CHECK(all_correct_count); } +TEST_CASE("Ensure wait_for_tasks() can be called multiple times on the same pool") { + class counter_wrapper { + public: + std::atomic_int counter = 0; + + void increment_counter() { counter.fetch_add(1, std::memory_order_release); } + }; + + dp::thread_pool local_pool{}; + constexpr auto task_count = 10; + std::array counts{{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}; + for (size_t i = 0; i < task_count; i++) { + counter_wrapper cnt_wrp{}; + + for (size_t var1 = 0; var1 < 17; var1++) { + for (int var2 = 0; var2 < 12; var2++) { + local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); }); + } + } + local_pool.wait_for_tasks(); + // std::cout << cnt_wrp.counter << std::endl; + counts[i] = cnt_wrp.counter.load(std::memory_order_acquire); + } + + auto all_correct_count = + std::ranges::all_of(counts, [](int count) { return count == 17 * 12; }); + auto sum = std::accumulate(counts.begin(), counts.end(), 0); + CHECK_EQ(sum, 17 * 12 * task_count); + CHECK(all_correct_count); + + for (size_t i = 0; i < task_count; i++) { + counter_wrapper cnt_wrp{}; + + for (size_t var1 = 0; var1 < 17; var1++) { + for (int var2 = 0; var2 < 12; var2++) { + local_pool.enqueue_detach([&cnt_wrp]() { cnt_wrp.increment_counter(); }); + } + } + local_pool.wait_for_tasks(); + // std::cout << cnt_wrp.counter << std::endl; + counts[i] = cnt_wrp.counter.load(std::memory_order_acquire); + } + + all_correct_count = std::ranges::all_of(counts, [](int count) { return count == 17 * 12; }); + sum = std::accumulate(counts.begin(), counts.end(), 0); + CHECK_EQ(sum, 17 * 12 * task_count); + CHECK(all_correct_count); +} + TEST_CASE("Initialization function is called") { std::atomic_int counter = 0; {