threadpool: reduce pause/resume/wakeup overhead in common cases
We now start the threadpool in the paused state only if we have two threadpools (i.e. a separate batch threadpool is in use).
Resume is now implicit (i.e. triggered by submitting new work), which reduces locking and context-switch overhead.
max-krasnyansky committed Aug 11, 2024
1 parent 20db9f4 commit b18719b
Showing 5 changed files with 47 additions and 25 deletions.
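
In effect, one lock/broadcast round trip per decode step goes away: previously llama_swap_threadpools() resumed the selected threadpool explicitly and ggml_graph_compute() then signalled new work separately; now signalling new work also clears the pause flag. A simplified before/after sketch using the names from the hunks below (not literal code from either file):

// before: two lock/broadcast cycles per graph
ggml_resume_threadpool(threadpool);     // lock, pause = false, broadcast
ggml_graph_compute(cgraph, cplan);      // lock, new_work = true, broadcast

// after: one cycle -- the new-work path resumes a paused pool itself
ggml_graph_compute(cgraph, cplan);      // lock, new_work = true,
                                        // __ggml_resume_threadpool() if paused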
1 change: 1 addition & 0 deletions common/common.cpp
@@ -2507,6 +2507,7 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p
tpp.prio = params.priority;
tpp.poll = params.poll;
tpp.strict_cpu = params.strict_cpu;
tpp.paused = false;

return tpp;
}
25 changes: 14 additions & 11 deletions examples/main/main.cpp
@@ -230,17 +230,6 @@ int main(int argc, char ** argv) {
struct ggml_threadpool_params tpp =
ggml_threadpool_params_from_cpu_params(params.cpuparams);

struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
if (!threadpool) {
LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
exit(1);
}

llama_attach_threadpool(ctx, threadpool);
if (ctx_guidance) {
llama_attach_threadpool(ctx_guidance, threadpool);
}

struct ggml_compute_threadpool * threadpool_batch = NULL;
if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) {
threadpool_batch = ggml_create_threadpool(&tpp_batch);
@@ -253,6 +242,20 @@ int main(int argc, char ** argv) {
if (ctx_guidance) {
llama_attach_batch_threadpool(ctx_guidance, threadpool_batch);
}

// Start the non-batch threadpool in the paused state
tpp.paused = true;
}

struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
if (!threadpool) {
LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
exit(1);
}

llama_attach_threadpool(ctx, threadpool);
if (ctx_guidance) {
llama_attach_threadpool(ctx_guidance, threadpool);
}

const int n_ctx_train = llama_n_ctx_train(model);
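
The generation threadpool is now created after the batch-threadpool branch because tpp.paused must be set before ggml_create_threadpool(&tpp) runs: only when a second (batch) pool exists does the generation pool start idle. A condensed sketch of the new setup order, not literal main.cpp code (two_pools stands in for the !ggml_threadpool_params_match(&tpp, &tpp_batch) check above):

if (two_pools) {
    threadpool_batch = ggml_create_threadpool(&tpp_batch); // active: handles prompt processing
    tpp.paused = true;                                      // generation pool starts idle
}
threadpool = ggml_create_threadpool(&tpp);                  // woken implicitly by the first decode graph
llama_attach_threadpool(ctx, threadpool);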
1 change: 1 addition & 0 deletions ggml/include/ggml.h
@@ -631,6 +631,7 @@ extern "C" {
int32_t prio;
bool poll;
bool strict_cpu;
bool paused;
};

struct ggml_compute_threadpool; // forward declaration, see ggml.c
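
For callers that fill ggml_threadpool_params directly instead of going through ggml_threadpool_params_from_cpu_params() (which defaults paused to false, per the common.cpp hunk above), the new field is one more initializer. A hypothetical example; values other than paused are arbitrary, and any field not listed (e.g. the CPU mask) is zero-initialized:

struct ggml_threadpool_params tpp = {
    .mask_specified = false,
    .n_threads      = 8,       // arbitrary
    .prio           = 0,
    .poll           = false,
    .strict_cpu     = false,
    .paused         = true,    // create the pool idle; the first graph compute wakes it
};
struct ggml_compute_threadpool * tp = ggml_create_threadpool(&tpp);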
42 changes: 31 additions & 11 deletions ggml/src/ggml.c
@@ -18964,14 +18964,27 @@ void ggml_release_threadpool(struct ggml_compute_threadpool* threadpool) {
GGML_ALIGNED_FREE(threadpool);
}

#ifndef GGML_USE_OPENMP
// pause/resume must be called under mutex
static void __ggml_pause_threadpool(struct ggml_compute_threadpool * threadpool) {
GGML_PRINT_DEBUG("Pausing threadpool\n");
threadpool->pause = true;
ggml_cond_broadcast(&threadpool->cond);
}

static void __ggml_resume_threadpool(struct ggml_compute_threadpool * threadpool) {
GGML_PRINT_DEBUG("Resuming threadpool\n");
threadpool->pause = false;
ggml_cond_broadcast(&threadpool->cond);
}
#endif

void ggml_pause_threadpool(struct ggml_compute_threadpool * threadpool) {
#ifndef GGML_USE_OPENMP
GGML_ASSERT(!threadpool->disposable);
GGML_PRINT_DEBUG("Pausing threadpool\n");
ggml_mutex_lock(&threadpool->mutex);
if (!threadpool->pause) {
threadpool->pause = true;
ggml_cond_broadcast(&threadpool->cond);
__ggml_pause_threadpool(threadpool);
}
ggml_mutex_unlock(&threadpool->mutex);
#else
@@ -18982,12 +18995,9 @@ void ggml_pause_threadpool(struct ggml_compute_threadpool * threadpool) {
void ggml_resume_threadpool(struct ggml_compute_threadpool * threadpool) {
#ifndef GGML_USE_OPENMP
GGML_ASSERT(!threadpool->disposable);
GGML_PRINT_DEBUG("Resuming threadpool\n");

ggml_mutex_lock(&threadpool->mutex);
if (threadpool->pause) {
threadpool->pause = false;
ggml_cond_broadcast(&threadpool->cond);
__ggml_resume_threadpool(threadpool);
}
ggml_mutex_unlock(&threadpool->mutex);
#else
@@ -19329,7 +19339,7 @@ static struct ggml_compute_threadpool * ggml_create_threadpool_impl(
threadpool->n_barrier_passed = 0;
threadpool->current_chunk = 0;
threadpool->stop = false;
threadpool->pause = disposable ? false : true;
threadpool->pause = disposable ? false : tpp->paused;
threadpool->new_work = false;
threadpool->workers = NULL;
threadpool->n_threads_max = tpp->n_threads;
@@ -19419,9 +19429,10 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
struct ggml_threadpool_params ttp = {
.mask_specified = false,
.n_threads = n_threads,
.prio = 1,
.prio = 0,
.poll = false,
.strict_cpu = false
.strict_cpu = false,
.paused = false
};

threadpool = ggml_create_threadpool_impl(&ttp, true, cgraph, cplan);
@@ -19475,10 +19486,19 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
if (!threadpool->poll) {
ggml_mutex_lock(&threadpool->mutex);
threadpool->new_work = true;
ggml_cond_broadcast(&threadpool->cond);
if (threadpool->pause) {
__ggml_resume_threadpool(threadpool);
} else {
ggml_cond_broadcast(&threadpool->cond);
}
ggml_mutex_unlock(&threadpool->mutex);
} else {
threadpool->new_work = true;
if (threadpool->pause) {
ggml_mutex_lock(&threadpool->mutex);
__ggml_resume_threadpool(threadpool);
ggml_mutex_unlock(&threadpool->mutex);
}
}
}
// this is a work thread too
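
The "must be called under mutex" rule exists because a worker tests the pause flag and then sleeps on the shared condition variable; doing the flag flip and the broadcast inside the same critical section means a worker cannot miss a wakeup between those two steps. An illustrative worker-side wait, not the actual ggml.c worker loop, assuming the ggml_mutex_*/ggml_cond_wait wrappers used elsewhere in this file:

// Illustration only -- the real worker loop also handles polling, new_work and stop.
static void worker_wait_while_paused(struct ggml_compute_threadpool * threadpool) {
    ggml_mutex_lock(&threadpool->mutex);
    while (threadpool->pause) {
        // woken by __ggml_resume_threadpool() broadcasting under the same mutex
        ggml_cond_wait(&threadpool->cond, &threadpool->mutex);
    }
    ggml_mutex_unlock(&threadpool->mutex);
}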
3 changes: 0 additions & 3 deletions src/llama.cpp
@@ -14458,17 +14458,14 @@ static std::pair<int32_t, ggml_compute_threadpool_t> llama_swap_threadpools(
// Switch between the 2 threadpools as needed
if (n_tokens > 1) {
ggml_pause_threadpool(lctx.threadpool);
ggml_resume_threadpool(lctx.threadpool_batch);
threadpool = lctx.threadpool_batch;
n_threads = cparams.n_threads_batch;
} else {
ggml_pause_threadpool(lctx.threadpool_batch);
ggml_resume_threadpool(lctx.threadpool);
threadpool = lctx.threadpool;
n_threads = cparams.n_threads;
}
} else if (lctx.threadpool) {
ggml_resume_threadpool(lctx.threadpool);
threadpool = lctx.threadpool;
n_threads = cparams.n_threads;
}
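
With implicit resume, the swap helper only has to pause the pool that is not about to be used; the selected pool is woken when the next graph is submitted via ggml_graph_compute(). Condensed from the hunk above:

// pick a pool, pause the other one, and let the next graph compute resume the chosen pool
if (n_tokens > 1) {
    ggml_pause_threadpool(lctx.threadpool);        // generation pool sleeps
    threadpool = lctx.threadpool_batch;            // batch pool: woken by new work
    n_threads  = cparams.n_threads_batch;
} else {
    ggml_pause_threadpool(lctx.threadpool_batch);  // batch pool sleeps
    threadpool = lctx.threadpool;
    n_threads  = cparams.n_threads;
}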
