From 6eae8239b9766f0485d6f15a44882abb51712341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20H=C3=A9vr?= Date: Sat, 13 Apr 2024 20:12:34 +0200 Subject: [PATCH 1/3] Add premature exit test --- test/source/thread_pool.cpp | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/test/source/thread_pool.cpp b/test/source/thread_pool.cpp index 38a31a6..0a99ab5 100644 --- a/test/source/thread_pool.cpp +++ b/test/source/thread_pool.cpp @@ -346,3 +346,38 @@ TEST_CASE("Recursive parallel sort") { CHECK(std::ranges::is_sorted(data)); } + + +TEST_CASE("Test premature exit") { + // two threads in pool, thread1, thread2 + // first, push task_1 + // task_1 pushes task_2 and sleeps, so both threads are busy and no tasks are in queue + // thread1 - task1, thread2 - task2 + // task_1 finishes, no tasks in queue, but task_2 is still running --> thread1 must not exit + // task_2 pushes another task (end_task) and sleeps for 5s before finishing the task_2 + // So the first thread, thread1 should execute the end_task + // but if the thread1 prematurely exits, than the end_task will be executed by the thread2 + + std::thread::id id_task_1, id_end; + { + dp::thread_pool<> testPool(2); + + auto end = [&]() { id_end = std::this_thread::get_id(); }; + + auto task_2 = [&]() { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + testPool.enqueue_detach(end); + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + }; + + auto task_1 = [&]() { + id_task_1 = std::this_thread::get_id(); + testPool.enqueue_detach(task_2); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + }; + + testPool.enqueue_detach(task_1); + } + + CHECK_EQ(id_task_1, id_end); +} \ No newline at end of file From 65e05d52d5cc618a4bb86a9a933db46a2642f38e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20H=C3=A9vr?= Date: Sat, 13 Apr 2024 20:13:13 +0200 Subject: [PATCH 2/3] Avoid premature thread exiting by monitoring if some tasks are still being processed --- include/thread_pool/thread_pool.h | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/include/thread_pool/thread_pool.h b/include/thread_pool/thread_pool.h index 08a4735..f6225b2 100644 --- a/include/thread_pool/thread_pool.h +++ b/include/thread_pool/thread_pool.h @@ -52,6 +52,7 @@ namespace dp { try { pending_tasks_.fetch_sub(1, std::memory_order_release); std::invoke(std::move(task.value())); + total_tasks_.fetch_sub(1, std::memory_order_release); } catch (...) { } } @@ -63,6 +64,7 @@ namespace dp { // steal a task pending_tasks_.fetch_sub(1, std::memory_order_release); std::invoke(std::move(task.value())); + total_tasks_.fetch_sub(1, std::memory_order_release); // stop stealing once we have invoked a stolen task break; } @@ -72,6 +74,10 @@ namespace dp { priority_queue_.rotate_to_front(id); + if (total_tasks_.load(std::memory_order_acquire) == 0) { + threads_done_.release(); + } + } while (!stop_tok.stop_requested()); }); // increment the thread id @@ -90,6 +96,11 @@ namespace dp { } ~thread_pool() { + if (total_tasks_.load(std::memory_order_acquire) > 0) { + // wait for all tasks to finish + threads_done_.acquire(); + } + // stop all threads for (std::size_t i = 0; i < threads_.size(); ++i) { threads_[i].request_stop(); @@ -204,6 +215,7 @@ namespace dp { } auto i = *(i_opt); pending_tasks_.fetch_add(1, std::memory_order_relaxed); + total_tasks_.fetch_add(1, std::memory_order_relaxed); tasks_[i].tasks.push_back(std::forward(f)); tasks_[i].signal.release(); } @@ -216,7 +228,8 @@ namespace dp { std::vector threads_; std::deque tasks_; dp::thread_safe_queue priority_queue_; - std::atomic_int_fast64_t pending_tasks_{}; + std::atomic_int_fast64_t pending_tasks_{}, total_tasks_{}; + std::binary_semaphore threads_done_{0}; }; /** From 8814340588830c6145e3436af744972f96764ede Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20H=C3=A9vr?= Date: Thu, 18 Apr 2024 12:19:15 +0200 Subject: [PATCH 3/3] Fix the stack-use-after-scope --- test/source/thread_pool.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/source/thread_pool.cpp b/test/source/thread_pool.cpp index 0a99ab5..4484331 100644 --- a/test/source/thread_pool.cpp +++ b/test/source/thread_pool.cpp @@ -362,15 +362,15 @@ TEST_CASE("Test premature exit") { { dp::thread_pool<> testPool(2); - auto end = [&]() { id_end = std::this_thread::get_id(); }; + auto end = [&id_end]() { id_end = std::this_thread::get_id(); }; - auto task_2 = [&]() { + auto task_2 = [&testPool, end]() { std::this_thread::sleep_for(std::chrono::milliseconds(1000)); testPool.enqueue_detach(end); std::this_thread::sleep_for(std::chrono::milliseconds(5000)); }; - auto task_1 = [&]() { + auto task_1 = [&testPool, &id_task_1, task_2]() { id_task_1 = std::this_thread::get_id(); testPool.enqueue_detach(task_2); std::this_thread::sleep_for(std::chrono::milliseconds(500));