From b18719b3027d0bd72d45419e9599c6460ef18908 Mon Sep 17 00:00:00 2001 From: Max Krasnyansky Date: Sat, 10 Aug 2024 16:12:06 -0700 Subject: [PATCH] threadpool: reduce pause/resume/wakeup overhead in common cases We now start threadpool in paused state only if we have two. The resume is now implicit (ie new work) which allows for reduced locking and context-switch overhead. --- common/common.cpp | 1 + examples/main/main.cpp | 25 ++++++++++++++----------- ggml/include/ggml.h | 1 + ggml/src/ggml.c | 42 +++++++++++++++++++++++++++++++----------- src/llama.cpp | 3 --- 5 files changed, 47 insertions(+), 25 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 5abddaefa6381..df6e1624ef7b5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2507,6 +2507,7 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p tpp.prio = params.priority; tpp.poll = params.poll; tpp.strict_cpu = params.strict_cpu; + tpp.paused = false; return tpp; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 56e8730593115..e7d42e9cec731 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -230,17 +230,6 @@ int main(int argc, char ** argv) { struct ggml_threadpool_params tpp = ggml_threadpool_params_from_cpu_params(params.cpuparams); - struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp); - if (!threadpool) { - LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); - exit(1); - } - - llama_attach_threadpool(ctx, threadpool); - if (ctx_guidance) { - llama_attach_threadpool(ctx_guidance, threadpool); - } - struct ggml_compute_threadpool * threadpool_batch = NULL; if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) { threadpool_batch = ggml_create_threadpool(&tpp_batch); @@ -253,6 +242,20 @@ int main(int argc, char ** argv) { if (ctx_guidance) { llama_attach_batch_threadpool(ctx_guidance, threadpool_batch); } + + // Start the non-batch threadpool in the paused state + tpp.paused = true; + } + + struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp); + if (!threadpool) { + LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); + exit(1); + } + + llama_attach_threadpool(ctx, threadpool); + if (ctx_guidance) { + llama_attach_threadpool(ctx_guidance, threadpool); } const int n_ctx_train = llama_n_ctx_train(model); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index af74231565786..923182d9d9710 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -631,6 +631,7 @@ extern "C" { int32_t prio; bool poll; bool strict_cpu; + bool paused; }; struct ggml_compute_threadpool; // forward declaration, see ggml.c diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 81c7a33878253..308e569856c70 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -18964,14 +18964,27 @@ void ggml_release_threadpool(struct ggml_compute_threadpool* threadpool) { GGML_ALIGNED_FREE(threadpool); } +#ifndef GGML_USE_OPENMP +// pause/resume must be called under mutex +static void __ggml_pause_threadpool(struct ggml_compute_threadpool * threadpool) { + GGML_PRINT_DEBUG("Pausing threadpool\n"); + threadpool->pause = true; + ggml_cond_broadcast(&threadpool->cond); +} + +static void __ggml_resume_threadpool(struct ggml_compute_threadpool * threadpool) { + GGML_PRINT_DEBUG("Resuming threadpool\n"); + threadpool->pause = false; + ggml_cond_broadcast(&threadpool->cond); +} +#endif + void ggml_pause_threadpool(struct ggml_compute_threadpool * threadpool) { #ifndef GGML_USE_OPENMP GGML_ASSERT(!threadpool->disposable); - GGML_PRINT_DEBUG("Pausing threadpool\n"); ggml_mutex_lock(&threadpool->mutex); if (!threadpool->pause) { - threadpool->pause = true; - ggml_cond_broadcast(&threadpool->cond); + __ggml_pause_threadpool(threadpool); } ggml_mutex_unlock(&threadpool->mutex); #else @@ -18982,12 +18995,9 @@ void ggml_pause_threadpool(struct ggml_compute_threadpool * threadpool) { void ggml_resume_threadpool(struct ggml_compute_threadpool * threadpool) { #ifndef GGML_USE_OPENMP GGML_ASSERT(!threadpool->disposable); - GGML_PRINT_DEBUG("Resuming threadpool\n"); - ggml_mutex_lock(&threadpool->mutex); if (threadpool->pause) { - threadpool->pause = false; - ggml_cond_broadcast(&threadpool->cond); + __ggml_resume_threadpool(threadpool); } ggml_mutex_unlock(&threadpool->mutex); #else @@ -19329,7 +19339,7 @@ static struct ggml_compute_threadpool * ggml_create_threadpool_impl( threadpool->n_barrier_passed = 0; threadpool->current_chunk = 0; threadpool->stop = false; - threadpool->pause = disposable ? false : true; + threadpool->pause = disposable ? false : tpp->paused; threadpool->new_work = false; threadpool->workers = NULL; threadpool->n_threads_max = tpp->n_threads; @@ -19419,9 +19429,10 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl struct ggml_threadpool_params ttp = { .mask_specified = false, .n_threads = n_threads, - .prio = 1, + .prio = 0, .poll = false, - .strict_cpu = false + .strict_cpu = false, + .paused = false }; threadpool = ggml_create_threadpool_impl(&ttp, true, cgraph, cplan); @@ -19475,10 +19486,19 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl if (!threadpool->poll) { ggml_mutex_lock(&threadpool->mutex); threadpool->new_work = true; - ggml_cond_broadcast(&threadpool->cond); + if (threadpool->pause) { + __ggml_resume_threadpool(threadpool); + } else { + ggml_cond_broadcast(&threadpool->cond); + } ggml_mutex_unlock(&threadpool->mutex); } else { threadpool->new_work = true; + if (threadpool->pause) { + ggml_mutex_lock(&threadpool->mutex); + __ggml_resume_threadpool(threadpool); + ggml_mutex_unlock(&threadpool->mutex); + } } } // this is a work thread too diff --git a/src/llama.cpp b/src/llama.cpp index 6123510c93b77..70194b54b6db0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -14458,17 +14458,14 @@ static std::pair llama_swap_threadpools( // Switch between the 2 threadpools as needed if (n_tokens > 1) { ggml_pause_threadpool(lctx.threadpool); - ggml_resume_threadpool(lctx.threadpool_batch); threadpool = lctx.threadpool_batch; n_threads = cparams.n_threads_batch; } else { ggml_pause_threadpool(lctx.threadpool_batch); - ggml_resume_threadpool(lctx.threadpool); threadpool = lctx.threadpool; n_threads = cparams.n_threads; } } else if (lctx.threadpool) { - ggml_resume_threadpool(lctx.threadpool); threadpool = lctx.threadpool; n_threads = cparams.n_threads; }