diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 5a929ceddafbe3..1009ac57b7be20 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -291,7 +291,6 @@ static void print_usage(int /* argc */, char ** argv) { printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str()); printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str()); printf(" --numa (default: disabled)\n"); - printf(" -mt, --max-threads (default: %d)\n", cmd_params_defaults.cpuparams.n_threads); printf(" -C, --cpu-mask (default: 0x0)\n"); printf(" --cpu-strict <0|1> (default: %d)\n", cmd_params_defaults.cpuparams.strict_cpu); printf(" --priority <0|1|2|3> (default: %d)\n", cmd_params_defaults.cpuparams.priority); @@ -499,12 +498,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; } else { invalid_param = true; break; } } - } else if (arg == "-mt" || arg == "--max-threads") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.cpuparams.n_threads = std::stoi(argv[i]); } else if (arg == "-C" || arg == "--cpu-mask") { if (++i >= argc) { invalid_param = true; @@ -1435,21 +1428,6 @@ int main(int argc, char ** argv) { postprocess_cpu_params(params.cpuparams); - struct ggml_threadpool_params tpp; - tpp.n_threads = params.cpuparams.n_threads; - tpp.mask_specified = params.cpuparams.mask_valid; - tpp.strict_cpu = params.cpuparams.strict_cpu; - tpp.prio = params.cpuparams.priority; - tpp.poll = params.cpuparams.poll; - - std::memcpy(&tpp.cpumask[0], &params.cpuparams.cpumask[0], GGML_MAX_N_THREADS); - - struct ggml_compute_threadpool* threadpool = ggml_create_threadpool(&tpp); - if (!threadpool) { - LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); - exit(1); - } - for (const auto & inst : params_instances) { // keep the 
same model between tests when possible if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) { @@ -1475,6 +1453,22 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); llama_kv_cache_clear(ctx); + + struct ggml_threadpool_params tpp; + tpp.n_threads = t.n_threads; + tpp.mask_specified = params.cpuparams.mask_valid; + tpp.strict_cpu = params.cpuparams.strict_cpu; + tpp.prio = params.cpuparams.priority; + tpp.poll = params.cpuparams.poll; + + std::memcpy(&tpp.cpumask[0], &params.cpuparams.cpumask[0], GGML_MAX_N_THREADS); + + struct ggml_compute_threadpool* threadpool = ggml_create_threadpool(&tpp); + if (!threadpool) { + LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads); + exit(1); + } + llama_attach_threadpool(ctx, threadpool); // warmup run @@ -1515,9 +1509,9 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); llama_free(ctx); - } - ggml_release_threadpool(threadpool); + ggml_release_threadpool(threadpool); + } llama_free_model(lmodel);