bench: create fresh threadpool for each test
For benchmarking it is better to start a fresh pool for each test, with the exact number of threads
that test needs. Larger pools are suboptimal (they cause extra load, etc.).
max-krasnyansky authored and fmz committed Aug 7, 2024
1 parent b32512a commit 8ecdd36
Showing 1 changed file with 18 additions and 24 deletions.
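At a glance, the change replaces one long-lived pool (sized via the now-removed -mt/--max-threads flag) with a create/attach/release cycle per test. The condensed sketch below shows that lifecycle using only the ggml/llama calls that appear in the diff; the warmup and timed runs are elided and the failure logging is abbreviated, so read it as an illustration of the pattern rather than the exact file contents.

// Condensed per-test threadpool lifecycle (sketch). test, params_instances,
// lmodel and ctx come from the surrounding llama-bench.cpp code; every
// ggml/llama call used here appears verbatim in the diff below.
for (const auto & inst : params_instances) {
    test t(inst, lmodel, ctx);

    struct ggml_threadpool_params tpp;
    tpp.n_threads      = t.n_threads;                  // exactly what this test needs
    tpp.mask_specified = params.cpuparams.mask_valid;  // optional CPU affinity mask
    tpp.strict_cpu     = params.cpuparams.strict_cpu;
    tpp.prio           = params.cpuparams.priority;
    tpp.poll           = params.cpuparams.poll;
    std::memcpy(&tpp.cpumask[0], &params.cpuparams.cpumask[0], GGML_MAX_N_THREADS);

    struct ggml_compute_threadpool * threadpool = ggml_create_threadpool(&tpp);
    if (!threadpool) {
        exit(1);  // the real code logs the failing n_threads via LOG_TEE first
    }

    llama_attach_threadpool(ctx, threadpool);
    // ... warmup run and timed test runs elided ...
    ggml_release_threadpool(threadpool);  // pool never outlives its test
}

Compared with the previous code, which built a single pool from params.cpuparams.n_threads before the loop, each test now gets exactly t.n_threads workers, and no oversized pool sits idle between tests, which is the extra load the commit message calls out.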
42 changes: 18 additions & 24 deletions examples/llama-bench/llama-bench.cpp
@@ -291,7 +291,6 @@ static void print_usage(int /* argc */, char ** argv) {
printf(" -fa, --flash-attn <0|1> (default: %s)\n", join(cmd_params_defaults.flash_attn, ",").c_str());
printf(" -mmp, --mmap <0|1> (default: %s)\n", join(cmd_params_defaults.use_mmap, ",").c_str());
printf(" --numa <distribute|isolate|numactl> (default: disabled)\n");
printf(" -mt, --max-threads <n> (default: %d)\n", cmd_params_defaults.cpuparams.n_threads);
printf(" -C, --cpu-mask <hex> (default: 0x0)\n");
printf(" --cpu-strict <0|1> (default: %d)\n", cmd_params_defaults.cpuparams.strict_cpu);
printf(" --priority <0|1|2|3> (default: %d)\n", cmd_params_defaults.cpuparams.priority);
@@ -499,12 +498,6 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
else { invalid_param = true; break; }
}
} else if (arg == "-mt" || arg == "--max-threads") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.cpuparams.n_threads = std::stoi(argv[i]);
} else if (arg == "-C" || arg == "--cpu-mask") {
if (++i >= argc) {
invalid_param = true;
@@ -1435,21 +1428,6 @@ int main(int argc, char ** argv) {

     postprocess_cpu_params(params.cpuparams);
 
-    struct ggml_threadpool_params tpp;
-    tpp.n_threads = params.cpuparams.n_threads;
-    tpp.mask_specified = params.cpuparams.mask_valid;
-    tpp.strict_cpu = params.cpuparams.strict_cpu;
-    tpp.prio = params.cpuparams.priority;
-    tpp.poll = params.cpuparams.poll;
-
-    std::memcpy(&tpp.cpumask[0], &params.cpuparams.cpumask[0], GGML_MAX_N_THREADS);
-
-    struct ggml_compute_threadpool* threadpool = ggml_create_threadpool(&tpp);
-    if (!threadpool) {
-        LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
-        exit(1);
-    }
-
     for (const auto & inst : params_instances) {
         // keep the same model between tests when possible
         if (!lmodel || !prev_inst || !inst.equal_mparams(*prev_inst)) {
@@ -1475,6 +1453,22 @@ int main(int argc, char ** argv) {
         test t(inst, lmodel, ctx);
 
         llama_kv_cache_clear(ctx);
+
+        struct ggml_threadpool_params tpp;
+        tpp.n_threads = t.n_threads;
+        tpp.mask_specified = params.cpuparams.mask_valid;
+        tpp.strict_cpu = params.cpuparams.strict_cpu;
+        tpp.prio = params.cpuparams.priority;
+        tpp.poll = params.cpuparams.poll;
+
+        std::memcpy(&tpp.cpumask[0], &params.cpuparams.cpumask[0], GGML_MAX_N_THREADS);
+
+        struct ggml_compute_threadpool* threadpool = ggml_create_threadpool(&tpp);
+        if (!threadpool) {
+            LOG_TEE("%s: threadpool create failed : n_threads %d\n", __func__, tpp.n_threads);
+            exit(1);
+        }
+
         llama_attach_threadpool(ctx, threadpool);
 
         // warmup run
@@ -1515,9 +1509,9 @@ int main(int argc, char ** argv) {
         llama_print_timings(ctx);
 
         llama_free(ctx);
-    }
-
-    ggml_release_threadpool(threadpool);
+
+        ggml_release_threadpool(threadpool);
+    }
 
     llama_free_model(lmodel);
 
