
Commit

common : support for lifecycle scripts
sasha0552 authored Jul 25, 2024
1 parent ed67bcb commit e8f1bd8
Showing 4 changed files with 63 additions and 0 deletions.
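
The new flags each take the path of an executable script: --on-start runs once at program startup, --on-inference-start runs just before inference begins, and --on-inference-end runs once inference has finished. A possible invocation (the binary name and script paths here are illustrative, not part of the commit) could look like:

    llama-cli -m model.gguf -p "Hello" \
        --on-start ./scripts/on_start.sh \
        --on-inference-start ./scripts/before_inference.sh \
        --on-inference-end ./scripts/after_inference.sh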
28 changes: 28 additions & 0 deletions common/common.cpp
@@ -331,6 +331,21 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
}
return true;
}
if (arg == "--on-start") {
CHECK_ARG
params.on_start = argv[i];
return true;
}
if (arg == "--on-inference-start") {
CHECK_ARG
params.on_inference_start = argv[i];
return true;
}
if (arg == "--on-inference-end") {
CHECK_ARG
params.on_inference_end = argv[i];
return true;
}
if (arg == "-p" || arg == "--prompt") {
CHECK_ARG
params.prompt = argv[i];
@@ -1403,6 +1418,11 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed });
options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.n_threads });
options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" });
options.push_back({ "*", " --on-start SCRIPT", "call the specified script at application startup" });
options.push_back({ "*", " --on-inference-start SCRIPT",
"call the specified script before starting the inference" });
options.push_back({ "*", " --on-inference-end SCRIPT",
"call the specified script when the inference is complete" });
options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" });
options.push_back({ "speculative", "-tbd, --threads-batch-draft N",
"number of threads to use during batch and prompt processing (default: same as --threads-draft)" });
@@ -3223,3 +3243,11 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false");
}

void script_execute(const std::string & script) {
int result = std::system(script.c_str());

if (result != 0) {
fprintf(stderr, "%s: error: unable to execute script '%s'. exit code: %d\n", __func__, script.c_str(), result);
}
}
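
One caveat worth noting: std::system() does not return the script's exit code directly. On POSIX it returns -1 if the child process could not be created, and otherwise a wait status that has to be decoded, so the value printed as "exit code" above is really the raw status. A minimal sketch of decoding it on POSIX follows; the helper name script_execute_checked is hypothetical and not part of this commit:

    #include <cstdio>
    #include <cstdlib>
    #include <string>
    #include <sys/wait.h> // POSIX-only: WIFEXITED / WEXITSTATUS

    // Hypothetical variant of script_execute() that decodes the wait status
    // returned by std::system() before reporting a failure.
    static void script_execute_checked(const std::string & script) {
        int status = std::system(script.c_str());
        if (status == -1) {
            fprintf(stderr, "failed to launch script '%s'\n", script.c_str());
        } else if (WIFEXITED(status) && WEXITSTATUS(status) != 0) {
            fprintf(stderr, "script '%s' exited with code %d\n", script.c_str(), WEXITSTATUS(status));
        }
    }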
11 changes: 11 additions & 0 deletions common/common.h
@@ -61,6 +61,11 @@ enum dimre_method {
struct gpt_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed

// lifecycle scripts
std::string on_start = ""; // script that will be called on application start
std::string on_inference_start = ""; // script that will be called when inference starts
std::string on_inference_end = ""; // script that will be called when inference ends

int32_t n_threads = cpu_get_num_math();
int32_t n_threads_draft = -1;
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
@@ -455,3 +460,9 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha
void yaml_dump_non_result_info(
FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);

//
// Script utils
//

void script_execute(const std::string & script);
12 changes: 12 additions & 0 deletions examples/main/main.cpp
@@ -137,6 +137,10 @@ int main(int argc, char ** argv) {
return 1;
}

if (!params.on_start.empty()) {
script_execute(params.on_start);
}

llama_sampling_params & sparams = params.sparams;

#ifndef LOG_DISABLE_LOGS
@@ -534,6 +538,10 @@ int main(int argc, char ** argv) {
exit(1);
}

if (!params.on_inference_start.empty()) {
script_execute(params.on_inference_start);
}

if (llama_model_has_encoder(model)) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();
@@ -971,6 +979,10 @@ int main(int argc, char ** argv) {
}
}

if (!params.on_inference_end.empty()) {
script_execute(params.on_inference_end);
}

if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
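Also note that script_execute() blocks until the script exits, so a long-running --on-inference-start hook delays generation here and, in the server below, request processing. If fire-and-forget behavior were preferred, one POSIX-only option would be to let the shell background the command — a sketch under that assumption, not part of this commit:

    #include <cstdlib>
    #include <string>

    // Hypothetical non-blocking variant: the trailing '&' makes /bin/sh put the
    // script in the background, so std::system() returns without waiting for it.
    static void script_execute_async(const std::string & script) {
        std::system((script + " &").c_str());
    }
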
12 changes: 12 additions & 0 deletions examples/server/server.cpp
@@ -1106,6 +1106,10 @@ struct server_context {
{"id_task", slot.id_task},
});

if (!params.on_inference_start.empty()) {
script_execute(params.on_inference_start);
}

return true;
}

@@ -1913,6 +1917,10 @@ struct server_context {
kv_cache_clear();
}

if (!params.on_inference_end.empty()) {
script_execute(params.on_inference_end);
}

return;
}
}
@@ -2496,6 +2504,10 @@ int main(int argc, char ** argv) {
return 1;
}

if (!params.on_start.empty()) {
script_execute(params.on_start);
}

// TODO: not great to use extern vars
server_log_json = params.log_json;
server_verbose = params.verbosity > 0;
