Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent Premature Thread Exiting #61

Merged
merged 3 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion include/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ namespace dp {
try {
pending_tasks_.fetch_sub(1, std::memory_order_release);
std::invoke(std::move(task.value()));
total_tasks_.fetch_sub(1, std::memory_order_release);
} catch (...) {
}
}
Expand All @@ -63,6 +64,7 @@ namespace dp {
// steal a task
pending_tasks_.fetch_sub(1, std::memory_order_release);
std::invoke(std::move(task.value()));
total_tasks_.fetch_sub(1, std::memory_order_release);
// stop stealing once we have invoked a stolen task
break;
}
Expand All @@ -72,6 +74,10 @@ namespace dp {

priority_queue_.rotate_to_front(id);

if (total_tasks_.load(std::memory_order_acquire) == 0) {
threads_done_.release();
}

} while (!stop_tok.stop_requested());
});
// increment the thread id
Expand All @@ -90,6 +96,11 @@ namespace dp {
}

~thread_pool() {
if (total_tasks_.load(std::memory_order_acquire) > 0) {
// wait for all tasks to finish
threads_done_.acquire();
}

// stop all threads
for (std::size_t i = 0; i < threads_.size(); ++i) {
threads_[i].request_stop();
Expand Down Expand Up @@ -204,6 +215,7 @@ namespace dp {
}
auto i = *(i_opt);
pending_tasks_.fetch_add(1, std::memory_order_relaxed);
total_tasks_.fetch_add(1, std::memory_order_relaxed);
tasks_[i].tasks.push_back(std::forward<Function>(f));
tasks_[i].signal.release();
}
Expand All @@ -216,7 +228,8 @@ namespace dp {
std::vector<ThreadType> threads_;
std::deque<task_item> tasks_;
dp::thread_safe_queue<std::size_t> priority_queue_;
std::atomic_int_fast64_t pending_tasks_{};
std::atomic_int_fast64_t pending_tasks_{}, total_tasks_{};
std::binary_semaphore threads_done_{0};
};

/**
Expand Down
35 changes: 35 additions & 0 deletions test/source/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,38 @@ TEST_CASE("Recursive parallel sort") {

CHECK(std::ranges::is_sorted(data));
}


TEST_CASE("Test premature exit") {
// two threads in pool, thread1, thread2
// first, push task_1
// task_1 pushes task_2 and sleeps, so both threads are busy and no tasks are in queue
// thread1 - task1, thread2 - task2
// task_1 finishes, no tasks in queue, but task_2 is still running --> thread1 must not exit
// task_2 pushes another task (end_task) and sleeps for 5s before finishing the task_2
// So the first thread, thread1 should execute the end_task
// but if the thread1 prematurely exits, than the end_task will be executed by the thread2

std::thread::id id_task_1, id_end;
{
dp::thread_pool<> testPool(2);

auto end = [&id_end]() { id_end = std::this_thread::get_id(); };

auto task_2 = [&testPool, end]() {
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
testPool.enqueue_detach(end);
std::this_thread::sleep_for(std::chrono::milliseconds(5000));
};

auto task_1 = [&testPool, &id_task_1, task_2]() {
id_task_1 = std::this_thread::get_id();
testPool.enqueue_detach(task_2);
std::this_thread::sleep_for(std::chrono::milliseconds(500));
};

testPool.enqueue_detach(task_1);
}

CHECK_EQ(id_task_1, id_end);
}
Loading