diff --git a/common/common.cpp b/common/common.cpp index 5abddaefa6381..df6e1624ef7b5 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2507,6 +2507,7 @@ struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_p tpp.prio = params.priority; tpp.poll = params.poll; tpp.strict_cpu = params.strict_cpu; + tpp.paused = false; return tpp; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 56e8730593115..e7d42e9cec731 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -230,17 +230,6 @@ int main(int argc, char ** argv) { struct ggml_threadpool_params tpp = ggml_threadpool_params_from_cpu_params(params.cpuparams); - struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp); - if (!threadpool) { - LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); - exit(1); - } - - llama_attach_threadpool(ctx, threadpool); - if (ctx_guidance) { - llama_attach_threadpool(ctx_guidance, threadpool); - } - struct ggml_compute_threadpool * threadpool_batch = NULL; if (!ggml_threadpool_params_match(&tpp, &tpp_batch)) { threadpool_batch = ggml_create_threadpool(&tpp_batch); @@ -253,6 +242,20 @@ int main(int argc, char ** argv) { if (ctx_guidance) { llama_attach_batch_threadpool(ctx_guidance, threadpool_batch); } + + // Start the non-batch threadpool in the paused state + tpp.paused = true; + } + + struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp); + if (!threadpool) { + LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); + exit(1); + } + + llama_attach_threadpool(ctx, threadpool); + if (ctx_guidance) { + llama_attach_threadpool(ctx_guidance, threadpool); } const int n_ctx_train = llama_n_ctx_train(model); diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index af74231565786..923182d9d9710 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -631,6 +631,7 @@ extern "C" { int32_t prio; bool poll; bool strict_cpu; + 
bool paused; }; struct ggml_compute_threadpool; // forward declaration, see ggml.c diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 81c7a33878253..308e569856c70 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -18964,14 +18964,27 @@ void ggml_release_threadpool(struct ggml_compute_threadpool* threadpool) { GGML_ALIGNED_FREE(threadpool); } +#ifndef GGML_USE_OPENMP +// the *_locked helpers below must be called with threadpool->mutex held +static void ggml_threadpool_pause_locked(struct ggml_compute_threadpool * threadpool) { + GGML_PRINT_DEBUG("Pausing threadpool\n"); + threadpool->pause = true; + ggml_cond_broadcast(&threadpool->cond); +} + +static void ggml_threadpool_resume_locked(struct ggml_compute_threadpool * threadpool) { + GGML_PRINT_DEBUG("Resuming threadpool\n"); + threadpool->pause = false; + ggml_cond_broadcast(&threadpool->cond); +} +#endif + void ggml_pause_threadpool(struct ggml_compute_threadpool * threadpool) { #ifndef GGML_USE_OPENMP GGML_ASSERT(!threadpool->disposable); - GGML_PRINT_DEBUG("Pausing threadpool\n"); ggml_mutex_lock(&threadpool->mutex); if (!threadpool->pause) { - threadpool->pause = true; - ggml_cond_broadcast(&threadpool->cond); + ggml_threadpool_pause_locked(threadpool); } ggml_mutex_unlock(&threadpool->mutex); #else @@ -18982,12 +18995,9 @@ void ggml_pause_threadpool(struct ggml_compute_threadpool * threadpool) { void ggml_resume_threadpool(struct ggml_compute_threadpool * threadpool) { #ifndef GGML_USE_OPENMP GGML_ASSERT(!threadpool->disposable); - GGML_PRINT_DEBUG("Resuming threadpool\n"); - ggml_mutex_lock(&threadpool->mutex); if (threadpool->pause) { - threadpool->pause = false; - ggml_cond_broadcast(&threadpool->cond); + ggml_threadpool_resume_locked(threadpool); } ggml_mutex_unlock(&threadpool->mutex); #else @@ -19329,7 +19339,7 @@ static struct ggml_compute_threadpool * ggml_create_threadpool_impl( threadpool->n_barrier_passed = 0; threadpool->current_chunk = 0; threadpool->stop = false; - threadpool->pause = disposable ? 
false : true; + threadpool->pause = disposable ? false : tpp->paused; threadpool->new_work = false; threadpool->workers = NULL; threadpool->n_threads_max = tpp->n_threads; @@ -19419,9 +19429,10 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl struct ggml_threadpool_params ttp = { .mask_specified = false, .n_threads = n_threads, - .prio = 1, + .prio = 0, .poll = false, - .strict_cpu = false + .strict_cpu = false, + .paused = false }; threadpool = ggml_create_threadpool_impl(&ttp, true, cgraph, cplan); @@ -19475,10 +19486,19 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl if (!threadpool->poll) { ggml_mutex_lock(&threadpool->mutex); threadpool->new_work = true; - ggml_cond_broadcast(&threadpool->cond); + if (threadpool->pause) { + ggml_threadpool_resume_locked(threadpool); + } else { + ggml_cond_broadcast(&threadpool->cond); + } ggml_mutex_unlock(&threadpool->mutex); } else { threadpool->new_work = true; + if (threadpool->pause) { + ggml_mutex_lock(&threadpool->mutex); + ggml_threadpool_resume_locked(threadpool); + ggml_mutex_unlock(&threadpool->mutex); + } } } // this is a work thread too diff --git a/src/llama.cpp b/src/llama.cpp index 6123510c93b77..70194b54b6db0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -14458,17 +14458,14 @@ static std::pair llama_swap_threadpools( // Switch between the 2 threadpools as needed if (n_tokens > 1) { ggml_pause_threadpool(lctx.threadpool); - ggml_resume_threadpool(lctx.threadpool_batch); threadpool = lctx.threadpool_batch; n_threads = cparams.n_threads_batch; } else { ggml_pause_threadpool(lctx.threadpool_batch); - ggml_resume_threadpool(lctx.threadpool); threadpool = lctx.threadpool; n_threads = cparams.n_threads; } } else if (lctx.threadpool) { - ggml_resume_threadpool(lctx.threadpool); threadpool = lctx.threadpool; n_threads = cparams.n_threads; }