diff --git a/include/thread_pool/thread_pool.h b/include/thread_pool/thread_pool.h index 08a4735..f6225b2 100644 --- a/include/thread_pool/thread_pool.h +++ b/include/thread_pool/thread_pool.h @@ -52,6 +52,7 @@ namespace dp { try { pending_tasks_.fetch_sub(1, std::memory_order_release); std::invoke(std::move(task.value())); + total_tasks_.fetch_sub(1, std::memory_order_release); } catch (...) { } } @@ -63,6 +64,7 @@ namespace dp { // steal a task pending_tasks_.fetch_sub(1, std::memory_order_release); std::invoke(std::move(task.value())); + total_tasks_.fetch_sub(1, std::memory_order_release); // stop stealing once we have invoked a stolen task break; } @@ -72,6 +74,10 @@ namespace dp { priority_queue_.rotate_to_front(id); + if (total_tasks_.load(std::memory_order_acquire) == 0) { + threads_done_.release(); + } + } while (!stop_tok.stop_requested()); }); // increment the thread id @@ -90,6 +96,11 @@ namespace dp { } ~thread_pool() { + if (total_tasks_.load(std::memory_order_acquire) > 0) { + // wait for all tasks to finish + threads_done_.acquire(); + } + // stop all threads for (std::size_t i = 0; i < threads_.size(); ++i) { threads_[i].request_stop(); @@ -204,6 +215,7 @@ namespace dp { } auto i = *(i_opt); pending_tasks_.fetch_add(1, std::memory_order_relaxed); + total_tasks_.fetch_add(1, std::memory_order_relaxed); tasks_[i].tasks.push_back(std::forward(f)); tasks_[i].signal.release(); } @@ -216,7 +228,8 @@ namespace dp { std::vector threads_; std::deque tasks_; dp::thread_safe_queue priority_queue_; - std::atomic_int_fast64_t pending_tasks_{}; + std::atomic_int_fast64_t pending_tasks_{}, total_tasks_{}; + std::binary_semaphore threads_done_{0}; }; /** diff --git a/test/source/thread_pool.cpp b/test/source/thread_pool.cpp index 38a31a6..4484331 100644 --- a/test/source/thread_pool.cpp +++ b/test/source/thread_pool.cpp @@ -346,3 +346,38 @@ TEST_CASE("Recursive parallel sort") { CHECK(std::ranges::is_sorted(data)); } + + +TEST_CASE("Test premature exit") { + // two threads in pool, thread1, thread2 + // first, push task_1 + // task_1 pushes task_2 and sleeps, so both threads are busy and no tasks are in queue + // thread1 - task1, thread2 - task2 + // task_1 finishes, no tasks in queue, but task_2 is still running --> thread1 must not exit + // task_2 pushes another task (end_task) and sleeps for 5s before finishing the task_2 + // So the first thread, thread1 should execute the end_task + // but if the thread1 prematurely exits, than the end_task will be executed by the thread2 + + std::thread::id id_task_1, id_end; + { + dp::thread_pool<> testPool(2); + + auto end = [&id_end]() { id_end = std::this_thread::get_id(); }; + + auto task_2 = [&testPool, end]() { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + testPool.enqueue_detach(end); + std::this_thread::sleep_for(std::chrono::milliseconds(5000)); + }; + + auto task_1 = [&testPool, &id_task_1, task_2]() { + id_task_1 = std::this_thread::get_id(); + testPool.enqueue_detach(task_2); + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + }; + + testPool.enqueue_detach(task_1); + } + + CHECK_EQ(id_task_1, id_end); +} \ No newline at end of file