forked from Tyler-Hardin/thread_pool
-
Notifications
You must be signed in to change notification settings - Fork 0
/
priority_thread_pool.cpp
149 lines (129 loc) · 3.55 KB
/
priority_thread_pool.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include "priority_thread_pool.hpp"
#include <unistd.h>
#include <cassert>
#include <cstdlib>
#include <map>
#include <utility>
using namespace std;
// Byte size of the heap-allocated stack each task's work context runs on.
// NOTE(review): 8 KiB is small for arbitrary work functions; deep call chains
// or large locals inside work() could overflow it — confirm this suffices.
constexpr std::size_t STACK_SIZE = 8 * 1024;

namespace {
// Worker thread id -> the task that thread is currently running. Entries are
// added/removed by priority_thread_pool::handle_task() and looked up by
// priority_task::_run() and priority_thread_pool::yield().
map<std::thread::id, std::shared_ptr<priority_task>> cur_tasks;
// Guards every access to cur_tasks.
mutex cur_tasks_mutex;
}
priority_task::~priority_task() {
    // A task may only be destroyed once its work function has returned;
    // _run() sets done after work() completes.
    assert(done);

    // Drop every pointer to the stack as it is released so stale uses fail
    // fast. NOTE(review): setcontext() resumes from the saved registers, not
    // from uc_stack.ss_sp, so this does not guard a stray resume of the
    // context.
    work_context.uc_stack.ss_sp = nullptr;
    free(work_stack);
    work_stack = nullptr;
}
bool priority_task::operator<(const priority_task& other) const {
    // Ordering depends on the priority field alone; no other task state
    // takes part in the comparison.
    const bool ranks_lower = this->priority < other.priority;
    return ranks_lower;
}
/**
* Actually runs the task, in a forked context.
*/
void priority_task::_run(void) {
shared_ptr<priority_task> t;
{
lock_guard<mutex> lk(cur_tasks_mutex);
auto it = cur_tasks.find(this_thread::get_id());
assert(it != cur_tasks.end());
t = it->second;
}
t->work();
t->done = true;
}
/**
 * Starts or resumes the task's work context on the calling thread and
 * reports whether the task has finished.
 *
 * Control comes back to the point just after getcontext(&pause_context)
 * either when the work function returns (through work_context.uc_link) or
 * when the task calls pause(). The started/done/paused flags tell those
 * cases apart.
 *
 * @return true once the work function has returned; false if the task
 *         paused and must be re-queued to be resumed later.
 */
bool priority_task::run() {
    paused = false;
    // This is where we'll resume when the work function returns (uc_link)
    // or when pause() calls setcontext(&pause_context).
    getcontext(&pause_context);
    if(!started) {
        // Create the context which will execute the function.
        getcontext(&work_context);
        work_stack = malloc(STACK_SIZE);
        if(work_stack == nullptr) {
            // Running a context on a null stack would crash later with a far
            // less useful error; fail fast instead.
            std::abort();
        }
        work_context.uc_stack.ss_size = STACK_SIZE;
        work_context.uc_stack.ss_sp = work_stack;
        work_context.uc_stack.ss_flags = 0;
        work_context.uc_link = &pause_context;
        makecontext(&work_context, &_run, 0);
        started = true;
        setcontext(&work_context);
        // setcontext() only returns on failure. Without this the function
        // would fall off the end without returning a value (undefined
        // behaviour).
        std::abort();
    }
    // done becomes true after _run() has finished the work function.
    if(done) {
        return true;
    }
    // paused is still false on a fresh call to run() (a resume): jump back
    // into the work context where pause() saved it.
    if(!paused) {
        setcontext(&work_context);
        // setcontext() only returns on failure.
        std::abort();
    }
    // Not done and paused: pause() jumped back here. Returning false tells
    // priority_thread_pool::handle_task that the task needs to be added
    // back to be resumed later.
    return false;
}
/**
 * Suspends the running task by saving the current point of the work context
 * and jumping back into priority_task::run(), which then returns false.
 *
 * Intended to be called from inside the task's work context
 * (priority_thread_pool::yield() calls it). A later call to run() resets
 * paused to false and does setcontext(&work_context). Execution then
 * resumes just after the getcontext() below, paused is false, and pause()
 * returns to the work function. That later run() may happen on a
 * different worker thread.
 */
void priority_task::pause() {
paused = true;
// We will resume here when task::run is called a second time.
getcontext(&work_context);
// paused will be false if we're being resumed by a later call to
// task::run, which resets it before doing setcontext(&work_context).
if(paused) {
// First pass through: jump into task::run() to return to the scheduler.
setcontext(&pause_context);
}
// else return back to running work context.
}
priority_thread_pool::priority_thread_pool(unsigned int n)
    : base_thread_pool(n)
{
    // Release init_mutex once this derived object is fully constructed.
    // NOTE(review): presumably base_thread_pool leaves it locked so worker
    // threads do not start handling tasks before now — confirm.
    this->init_mutex.unlock();
}
priority_thread_pool::~priority_thread_pool()
{
    // Call base_thread_pool::wait() before this object's members go away.
    // NOTE(review): presumably wait() blocks until queued tasks have been
    // handled — confirm in base_thread_pool.
    this->wait();
}
/**
* Yields the task the current thread is running.
*/
void priority_thread_pool::yield() {
cur_tasks_mutex.lock();
auto it = cur_tasks.find(std::this_thread::get_id());
assert(it != cur_tasks.end());
auto task = it->second;
cur_tasks_mutex.unlock();
task->pause();
}
/**
 * Removes and returns the task at tasks.top(), under task_mutex.
 *
 * @return the removed task, or an empty optional when `tasks` is empty.
 */
optional<shared_ptr<priority_task>> priority_thread_pool::get_task() {
    lock_guard<mutex> guard(task_mutex);
    if(tasks.empty()) {
        return {};
    }
    shared_ptr<priority_task> next = tasks.top();
    tasks.pop();
    return next;
}
/**
 * Runs (or resumes) one task on the calling worker thread.
 *
 * While t->run() executes, the task is registered in cur_tasks under this
 * thread's id, so priority_task::_run() and yield() can find it. If run()
 * reports the task is not finished (it yielded), the task is pushed back
 * onto `tasks` to be resumed later, possibly by another worker.
 *
 * @param t the task to run; ownership is shared with the queue if it is
 *          re-queued.
 */
void priority_thread_pool::handle_task(shared_ptr<priority_task> t) {
    const auto id = this_thread::get_id();
    {
        lock_guard<mutex> lk(cur_tasks_mutex);
        // Insert outside assert(): with NDEBUG the whole assert expression is
        // discarded, which previously left the task unregistered and made
        // _run()/yield() dereference cur_tasks.end().
        const bool inserted = cur_tasks.emplace(id, t).second;
        assert(inserted && "worker thread is already running a task");
        (void)inserted;
    }
    const bool finished = t->run();
    {
        lock_guard<mutex> lk(cur_tasks_mutex);
        cur_tasks.erase(id);
    }
    // finished is true when the task is done executing. If it's not, add it
    // back to the heap and resume it later.
    if(!finished) {
        lock_guard<mutex> lk(task_mutex);
        tasks.emplace(std::move(t));
    }
}