From ab545c83809c6f975287a61d598909d1cf8aa012 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 5 Aug 2024 10:08:25 +0300 Subject: [PATCH 01/47] llama : add llama_sampling API + move grammar in libllama ggml-ci --- Makefile | 6 - common/CMakeLists.txt | 2 - common/common.cpp | 109 +-- common/common.h | 6 +- common/grammar-parser.cpp | 539 ------------ common/grammar-parser.h | 29 - common/sampling.cpp | 548 ++++-------- common/sampling.h | 199 ++--- examples/batched-bench/batched-bench.cpp | 2 +- examples/batched.swift/Sources/main.swift | 46 +- examples/batched/batched.cpp | 39 +- examples/embedding/embedding.cpp | 10 +- examples/eval-callback/eval-callback.cpp | 4 +- examples/gbnf-validator/gbnf-validator.cpp | 43 +- examples/gritlm/gritlm.cpp | 52 +- examples/imatrix/imatrix.cpp | 2 +- examples/infill/infill.cpp | 46 +- examples/llama-bench/llama-bench.cpp | 2 +- .../llama/src/main/cpp/llama-android.cpp | 20 +- .../llama.cpp.swift/LibLlama.swift | 16 +- examples/llava/llava-cli.cpp | 18 +- examples/llava/minicpmv-cli.cpp | 30 +- examples/lookahead/lookahead.cpp | 17 +- examples/lookup/lookup.cpp | 12 +- examples/main/main.cpp | 144 +-- examples/parallel/parallel.cpp | 19 +- examples/passkey/passkey.cpp | 22 +- examples/perplexity/perplexity.cpp | 10 +- examples/quantize-stats/quantize-stats.cpp | 5 +- examples/retrieval/retrieval.cpp | 2 +- examples/save-load-state/save-load-state.cpp | 58 +- examples/server/README.md | 6 +- examples/server/server.cpp | 161 +--- examples/simple/simple.cpp | 19 +- examples/speculative/speculative.cpp | 75 +- include/llama.h | 409 ++++----- src/llama-grammar.cpp | 825 +++++++++++++++--- src/llama-grammar.h | 132 ++- src/llama-impl.h | 116 ++- src/llama-sampling.cpp | 318 +++---- src/llama-sampling.h | 115 ++- src/llama-vocab.h | 5 +- src/llama.cpp | 590 +++++++++---- tests/test-grammar-integration.cpp | 43 +- tests/test-grammar-parser.cpp | 10 +- tests/test-json-schema-to-grammar.cpp | 10 +- tests/test-llama-grammar.cpp | 17 +- tests/test-sampling.cpp | 59 +- 48 files changed, 2429 insertions(+), 2538 deletions(-) delete mode 100644 common/grammar-parser.cpp delete mode 100644 common/grammar-parser.h diff --git a/Makefile b/Makefile index 332496cfc39c1..89287831ff31f 100644 --- a/Makefile +++ b/Makefile @@ -927,7 +927,6 @@ OBJ_COMMON = \ common/ngram-cache.o \ common/sampling.o \ common/train.o \ - common/grammar-parser.o \ common/build-info.o \ common/json-schema-to-grammar.o @@ -1167,11 +1166,6 @@ common/console.o: \ common/console.h $(CXX) $(CXXFLAGS) -c $< -o $@ -common/grammar-parser.o: \ - common/grammar-parser.cpp \ - common/grammar-parser.h - $(CXX) $(CXXFLAGS) -c $< -o $@ - common/json-schema-to-grammar.o: \ common/json-schema-to-grammar.cpp \ common/json-schema-to-grammar.h diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 761971d6881f3..2c72793b89dbe 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -58,8 +58,6 @@ add_library(${TARGET} STATIC sampling.cpp console.h console.cpp - grammar-parser.h - grammar-parser.cpp json.hpp json-schema-to-grammar.cpp train.h diff --git a/common/common.cpp b/common/common.cpp index de2a177c165b4..23d171a4d7b96 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -353,16 +353,15 @@ void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model) } bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { - bool invalid_param = false; - std::string arg; - const std::string arg_prefix = "--"; - llama_sampling_params & sparams = 
params.sparams; - for (int i = 1; i < argc; i++) { - arg = argv[i]; + const std::string arg_prefix = "--"; + + std::string arg = argv[i]; if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { std::replace(arg.begin(), arg.end(), '_', '-'); } + + bool invalid_param = false; if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) { throw std::invalid_argument("error: unknown argument: " + arg); } @@ -386,11 +385,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { get_env("HF_TOKEN", params.hf_token); } + auto & sparams = params.sparams; + if (params.escape) { string_process_escapes(params.prompt); string_process_escapes(params.input_prefix); string_process_escapes(params.input_suffix); - string_process_escapes(sparams.cfg_negative_prompt); for (auto & antiprompt : params.antiprompt) { string_process_escapes(antiprompt); } @@ -401,6 +401,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.kv_overrides.back().key[0] = 0; } + if (sparams.seed == LLAMA_DEFAULT_SEED) { + sparams.seed = time(NULL); + } + return true; } @@ -526,12 +530,10 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_params & params, int & i, bool & invalid_param) { const char split_delim = ','; - llama_sampling_params & sparams = params.sparams; + auto & sparams = params.sparams; if (arg == "-s" || arg == "--seed") { CHECK_ARG - // TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context. - params.seed = std::stoul(argv[i]); sparams.seed = std::stoul(argv[i]); return true; } @@ -842,12 +844,12 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa if (arg == "--samplers") { CHECK_ARG const auto sampler_names = string_split(argv[i], ';'); - sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, true); + sparams.samplers = llama_sampling_types_from_names(sampler_names, true); return true; } if (arg == "--sampling-seq") { CHECK_ARG - sparams.samplers_sequence = llama_sampling_types_from_chars(argv[i]); + sparams.samplers = llama_sampling_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { @@ -873,7 +875,7 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } if (arg == "--typical") { CHECK_ARG - sparams.typical_p = std::stof(argv[i]); + sparams.typ_p = std::stof(argv[i]); return true; } if (arg == "--repeat-last-n") { @@ -922,30 +924,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa sparams.mirostat_tau = std::stof(argv[i]); return true; } - if (arg == "--cfg-negative-prompt") { - CHECK_ARG - sparams.cfg_negative_prompt = argv[i]; - return true; - } - if (arg == "--cfg-negative-prompt-file") { - CHECK_ARG - std::ifstream file(argv[i]); - if (!file) { - fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); - invalid_param = true; - return true; - } - std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(sparams.cfg_negative_prompt)); - if (!sparams.cfg_negative_prompt.empty() && sparams.cfg_negative_prompt.back() == '\n') { - sparams.cfg_negative_prompt.pop_back(); - } - return true; - } - if (arg == "--cfg-scale") { - CHECK_ARG - sparams.cfg_scale = std::stof(argv[i]); - return true; - } if (arg == "-b" || arg == "--batch-size") { CHECK_ARG params.n_batch = std::stoi(argv[i]); @@ -1355,7 +1333,7 @@ bool gpt_params_find_arg(int argc, 
char ** argv, const std::string & arg, gpt_pa return true; } if (arg == "--ignore-eos") { - params.ignore_eos = true; + sparams.ignore_eos = true; return true; } if (arg == "--penalize-nl") { @@ -1370,7 +1348,8 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa std::string value_str; try { if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-')) { - sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); + const float bias = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); + sparams.logit_bias.push_back({key, bias}); } else { throw std::exception(); @@ -1725,13 +1704,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa #endif void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { - const llama_sampling_params & sparams = params.sparams; + const auto & sparams = params.sparams; std::string sampler_type_chars; std::string sampler_type_names; - for (const auto sampler_type : sparams.samplers_sequence) { - sampler_type_chars += static_cast(sampler_type); - sampler_type_names += llama_sampling_type_to_str(sampler_type) + ";"; + for (const auto & sampler : sparams.samplers) { + sampler_type_chars += llama_sampling_type_to_chr(sampler); + sampler_type_names += llama_sampling_type_to_str(sampler) + ";"; } sampler_type_names.pop_back(); @@ -1766,7 +1745,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" }); options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" }); options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" }); - options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed }); options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.cpuparams.n_threads }); options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" }); options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" }); @@ -1846,18 +1824,19 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param " --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" }); options.push_back({ "sampling" }); + options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed }); options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n" "(default: %s)", sampler_type_names.c_str() }); options.push_back({ "*", " --sampling-seq SEQUENCE", "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() }); options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" }); options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? 
"true" : "false" }); - options.push_back({ "*", " --temp N", "temperature (default: %.1f)", (double)sparams.temp }); + options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp }); options.push_back({ "*", " --top-k N", "top-k sampling (default: %d, 0 = disabled)", sparams.top_k }); - options.push_back({ "*", " --top-p N", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p }); - options.push_back({ "*", " --min-p N", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p }); - options.push_back({ "*", " --tfs N", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z }); - options.push_back({ "*", " --typical N", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typical_p }); + options.push_back({ "*", " --top-p P", "top-p sampling (default: %.1f, 1.0 = disabled)", (double)sparams.top_p }); + options.push_back({ "*", " --min-p P", "min-p sampling (default: %.1f, 0.0 = disabled)", (double)sparams.min_p }); + options.push_back({ "*", " --tfs P", "tail free sampling, parameter z (default: %.1f, 1.0 = disabled)", (double)sparams.tfs_z }); + options.push_back({ "*", " --typical P", "locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)sparams.typ_p }); options.push_back({ "*", " --repeat-last-n N", "last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)", sparams.penalty_last_n }); options.push_back({ "*", " --repeat-penalty N", "penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)", (double)sparams.penalty_repeat }); options.push_back({ "*", " --presence-penalty N", "repeat alpha presence penalty (default: %.1f, 0.0 = disabled)", (double)sparams.penalty_present }); @@ -1872,11 +1851,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "*", " -l TOKEN_ID(+/-)BIAS", "modifies the likelihood of token appearing in the completion,\n" "i.e. 
`--logit-bias 15043+1` to increase likelihood of token ' Hello',\n" "or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'" }); - options.push_back({ "main", " --cfg-negative-prompt PROMPT", - "negative prompt to use for guidance (default: '%s')", sparams.cfg_negative_prompt.c_str() }); - options.push_back({ "main", " --cfg-negative-prompt-file FNAME", - "negative prompt file to use for guidance" }); - options.push_back({ "main", " --cfg-scale N", "strength of guidance (default: %.1f, 1.0 = disable)", (double)sparams.cfg_scale }); options.push_back({ "main", " --chat-template JINJA_TEMPLATE", "set custom jinja chat template (default: template taken from model's metadata)\n" "if suffix/prefix are specified, template will be disabled\n" @@ -2528,8 +2502,9 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { llama_lora_adapters_apply(lctx, iparams.lora_adapters); } - if (params.ignore_eos) { - params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + if (params.sparams.ignore_eos && llama_token_eos(model) == -1) { + fprintf(stderr, "%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); + params.sparams.ignore_eos = false; } if (params.warmup) { @@ -2558,7 +2533,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { } llama_kv_cache_clear(lctx); llama_synchronize(lctx); - llama_reset_timings(lctx); + llama_reset_timings(lctx, nullptr); } iparams.model = model; @@ -2637,7 +2612,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_threads = params.cpuparams.n_threads; cparams.n_threads_batch = params.cpuparams_batch.n_threads == -1 ? params.cpuparams.n_threads : params.cpuparams_batch.n_threads; - cparams.seed = params.seed; cparams.logits_all = params.logits_all; cparams.embeddings = params.embedding; cparams.rope_scaling_type = params.rope_scaling_type; @@ -3523,7 +3497,7 @@ void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const cha void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const llama_context * lctx, const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - const llama_sampling_params & sparams = params.sparams; + const auto & sparams = params.sparams; fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); @@ -3574,8 +3548,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - yaml_dump_string_multiline(stream, "cfg_negative_prompt", sparams.cfg_negative_prompt.c_str()); - fprintf(stream, "cfg_scale: %f # default: 1.0\n", sparams.cfg_scale); fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); @@ -3586,10 +3558,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? 
"true" : "false"); fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - - const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx))); - const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY; - fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false"); + fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false"); yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); @@ -3600,11 +3569,8 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); fprintf(stream, "logit_bias:\n"); - for (std::pair lb : sparams.logit_bias) { - if (ignore_eos && lb.first == logit_bias_eos->first) { - continue; - } - fprintf(stream, " %d: %f", lb.first, lb.second); + for (const auto & logit_bias : sparams.logit_bias) { + fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias); } fprintf(stream, "lora:\n"); @@ -3657,7 +3623,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); @@ -3671,7 +3636,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); - fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p); + fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p); fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? 
"true" : "false"); } diff --git a/common/common.h b/common/common.h index 795ff44054d40..1c4eae34a8390 100644 --- a/common/common.h +++ b/common/common.h @@ -77,8 +77,6 @@ struct cpu_params { }; struct gpt_params { - uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed - int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 0; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) @@ -120,8 +118,7 @@ struct gpt_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings - // // sampling parameters - struct llama_sampling_params sparams; + struct gpt_sampling_params sparams; std::string model = ""; // model path std::string model_draft = ""; // draft model for speculative decoding @@ -185,7 +182,6 @@ struct gpt_params { bool flash_attn = false; // flash attention bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix - bool ignore_eos = false; // ignore generated EOS tokens bool logits_all = false; // return logits for all tokens in the batch bool use_mmap = true; // use mmap for faster loads bool use_mlock = false; // use mlock to keep model in memory diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp deleted file mode 100644 index 438452eab570f..0000000000000 --- a/common/grammar-parser.cpp +++ /dev/null @@ -1,539 +0,0 @@ -#include "grammar-parser.h" -#include -#include -#include -#include -#include -#include - -namespace grammar_parser { - // NOTE: assumes valid utf8 (but checks for overrun) - // copied from llama.cpp - static std::pair decode_utf8(const char * src) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t first_byte = static_cast(*src); - uint8_t highbits = first_byte >> 4; - int len = lookup[highbits]; - uint8_t mask = (1 << (8 - len)) - 1; - uint32_t value = first_byte & mask; - const char * end = src + len; // may overrun! 
-        const char * pos = src + 1;
-        for ( ; pos < end && *pos; pos++) {
-            value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F);
-        }
-        return std::make_pair(value, pos);
-    }
-
-    static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) {
-        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
-        auto result = state.symbol_ids.emplace(std::string(src, len), next_id);
-        return result.first->second;
-    }
-
-    static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) {
-        uint32_t next_id = static_cast<uint32_t>(state.symbol_ids.size());
-        state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id;
-        return next_id;
-    }
-
-    static void add_rule(
-            parse_state & state,
-            uint32_t rule_id,
-            const std::vector<llama_grammar_element> & rule) {
-        if (state.rules.size() <= rule_id) {
-            state.rules.resize(rule_id + 1);
-        }
-        state.rules[rule_id] = rule;
-    }
-
-    static bool is_digit_char(char c) {
-        return '0' <= c && c <= '9';
-    }
-
-    static bool is_word_char(char c) {
-        return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c);
-    }
-
-    static std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
-        const char * pos = src;
-        const char * end = src + size;
-        uint32_t value = 0;
-        for ( ; pos < end && *pos; pos++) {
-            value <<= 4;
-            char c = *pos;
-            if ('a' <= c && c <= 'f') {
-                value += c - 'a' + 10;
-            } else if ('A' <= c && c <= 'F') {
-                value += c - 'A' + 10;
-            } else if ('0' <= c && c <= '9') {
-                value += c - '0';
-            } else {
-                break;
-            }
-        }
-        if (pos != end) {
-            throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src);
-        }
-        return std::make_pair(value, pos);
-    }
-
-    static const char * parse_space(const char * src, bool newline_ok) {
-        const char * pos = src;
-        while (*pos == ' ' || *pos == '\t' || *pos == '#' ||
-                (newline_ok && (*pos == '\r' || *pos == '\n'))) {
-            if (*pos == '#') {
-                while (*pos && *pos != '\r' && *pos != '\n') {
-                    pos++;
-                }
-            } else {
-                pos++;
-            }
-        }
-        return pos;
-    }
-
-    static const char * parse_name(const char * src) {
-        const char * pos = src;
-        while (is_word_char(*pos)) {
-            pos++;
-        }
-        if (pos == src) {
-            throw std::runtime_error(std::string("expecting name at ") + src);
-        }
-        return pos;
-    }
-
-    static const char * parse_int(const char * src) {
-        const char * pos = src;
-        while (is_digit_char(*pos)) {
-            pos++;
-        }
-        if (pos == src) {
-            throw std::runtime_error(std::string("expecting integer at ") + src);
-        }
-        return pos;
-    }
-
-    static std::pair<uint32_t, const char *> parse_char(const char * src) {
-        if (*src == '\\') {
-            switch (src[1]) {
-                case 'x': return parse_hex(src + 2, 2);
-                case 'u': return parse_hex(src + 2, 4);
-                case 'U': return parse_hex(src + 2, 8);
-                case 't': return std::make_pair('\t', src + 2);
-                case 'r': return std::make_pair('\r', src + 2);
-                case 'n': return std::make_pair('\n', src + 2);
-                case '\\':
-                case '"':
-                case '[':
-                case ']':
-                    return std::make_pair(src[1], src + 2);
-                default:
-                    throw std::runtime_error(std::string("unknown escape at ") + src);
-            }
-        } else if (*src) {
-            return decode_utf8(src);
-        }
-        throw std::runtime_error("unexpected end of input");
-    }
-
-    const char * parse_alternates(
-        parse_state & state,
-        const char * src,
-        const std::string & rule_name,
-        uint32_t rule_id,
-        bool is_nested);
-
-    static const char * parse_sequence(
-        parse_state & state,
-        const char * src,
-        const std::string & rule_name,
-        std::vector<llama_grammar_element> & out_elements,
-        bool is_nested) {
-        size_t last_sym_start = out_elements.size();
-        const char * pos = src;
-
-        auto handle_repetitions = [&](int
min_times, int max_times) { - - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); - } - - // apply transformation to previous symbol (last_sym_start to end) according to - // the following rewrite rules: - // S{m,n} --> S S S (m times) S'(n-m) - // S'(x) ::= S S'(x-1) | - // (... n-m definitions of these S' rules ...) - // S'(1) ::= S | - // S{m,} --> S S S (m times) S' - // S' ::= S S' | - // S* --> S{0,} - // --> S' ::= S S' | - // S+ --> S{1,} - // --> S S' - // S' ::= S S' | - // S? --> S{0,1} - // --> S' - // S' ::= S | - - std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); - if (min_times == 0) { - out_elements.resize(last_sym_start); - } else { - // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { - out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); - } - } - - uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; - - std::vector rec_rule(previous_elements); - for (int i = 0; i < n_opt; i++) { - rec_rule.resize(previous_elements.size()); - uint32_t rec_rule_id = generate_symbol_id(state, rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); - } - rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - rec_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rec_rule_id, rec_rule); - last_rec_rule_id = rec_rule_id; - } - if (n_opt > 0) { - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); - } - }; - - while (*pos) { - if (*pos == '"') { // literal string - pos++; - last_sym_start = out_elements.size(); - while (*pos != '"') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '[') { // char range(s) - pos++; - enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - if (*pos == '^') { - pos++; - start_type = LLAMA_GRETYPE_CHAR_NOT; - } - last_sym_start = out_elements.size(); - while (*pos != ']') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - enum llama_gretype type = last_sym_start < out_elements.size() - ? 
LLAMA_GRETYPE_CHAR_ALT - : start_type; - - out_elements.push_back({type, char_pair.first}); - if (pos[0] == '-' && pos[1] != ']') { - if (!pos[1]) { - throw std::runtime_error("unexpected end of input"); - } - auto endchar_pair = parse_char(pos + 1); - pos = endchar_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); - } - } - pos = parse_space(pos + 1, is_nested); - } else if (is_word_char(*pos)) { // rule reference - const char * name_end = parse_name(pos); - uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); - pos = parse_space(name_end, is_nested); - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - } else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule - pos = parse_space(pos + 1, true); - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); - last_sym_start = out_elements.size(); - // output reference to synthesized rule - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - if (*pos != ')') { - throw std::runtime_error(std::string("expecting ')' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '.') { // any char - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, -1); - } else if (*pos == '+') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(1, -1); - } else if (*pos == '?') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, 1); - } else if (*pos == '{') { - pos = parse_space(pos + 1, is_nested); - - if (!is_digit_char(*pos)) { - throw std::runtime_error(std::string("expecting an int at ") + pos); - } - const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - - int max_times = -1; - - if (*pos == '}') { - max_times = min_times; - pos = parse_space(pos + 1, is_nested); - } else if (*pos == ',') { - pos = parse_space(pos + 1, is_nested); - - if (is_digit_char(*pos)) { - const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - } - - if (*pos != '}') { - throw std::runtime_error(std::string("expecting '}' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else { - throw std::runtime_error(std::string("expecting ',' at ") + pos); - } - handle_repetitions(min_times, max_times); - } else { - break; - } - } - return pos; - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested) { - std::vector rule; - const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); - while (*pos == '|') { - rule.push_back({LLAMA_GRETYPE_ALT, 0}); - pos = parse_space(pos + 1, true); - pos = parse_sequence(state, pos, rule_name, rule, is_nested); - } - rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rule_id, rule); - return pos; - } - - static const char * parse_rule(parse_state & state, const char * src) { - const char * name_end = parse_name(src); - const char * pos = parse_space(name_end, false); - size_t name_len = name_end - src; - uint32_t rule_id = get_symbol_id(state, src, name_len); - const std::string name(src, name_len); - - if (!(pos[0] == ':' 
&& pos[1] == ':' && pos[2] == '=')) {
-            throw std::runtime_error(std::string("expecting ::= at ") + pos);
-        }
-        pos = parse_space(pos + 3, true);
-
-        pos = parse_alternates(state, pos, name, rule_id, false);
-
-        if (*pos == '\r') {
-            pos += pos[1] == '\n' ? 2 : 1;
-        } else if (*pos == '\n') {
-            pos++;
-        } else if (*pos) {
-            throw std::runtime_error(std::string("expecting newline or end at ") + pos);
-        }
-        return parse_space(pos, true);
-    }
-
-    parse_state parse(const char * src) {
-        try {
-            parse_state state;
-            const char * pos = parse_space(src, true);
-            while (*pos) {
-                pos = parse_rule(state, pos);
-            }
-            // Validate the state to ensure that all rules are defined
-            for (const auto & rule : state.rules) {
-                if (rule.empty()) {
-                    throw std::runtime_error("Undefined rule");
-                }
-                for (const auto & elem : rule) {
-                    if (elem.type == LLAMA_GRETYPE_RULE_REF) {
-                        // Ensure that the rule at that location exists
-                        if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) {
-                            // Get the name of the rule that is missing
-                            for (const auto & kv : state.symbol_ids) {
-                                if (kv.second == elem.value) {
-                                    throw std::runtime_error("Undefined rule identifier '" + kv.first + "'");
-                                }
-                            }
-                        }
-                    }
-                }
-            }
-            return state;
-        } catch (const std::exception & err) {
-            fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what());
-            return parse_state();
-        }
-    }
-
-    static void print_grammar_char(FILE * file, uint32_t c) {
-        if (0x20 <= c && c <= 0x7f) {
-            fprintf(file, "%c", static_cast<char>(c));
-        } else {
-            // cop out of encoding UTF-8
-            fprintf(file, "<U+%04X>", c);
-        }
-    }
-
-    static bool is_char_element(llama_grammar_element elem) {
-        switch (elem.type) {
-            case LLAMA_GRETYPE_CHAR: return true;
-            case LLAMA_GRETYPE_CHAR_NOT: return true;
-            case LLAMA_GRETYPE_CHAR_ALT: return true;
-            case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true;
-            case LLAMA_GRETYPE_CHAR_ANY: return true;
-            default: return false;
-        }
-    }
-
-    static void print_rule_binary(FILE * file, const std::vector<llama_grammar_element> & rule) {
-        for (auto elem : rule) {
-            switch (elem.type) {
-                case LLAMA_GRETYPE_END: fprintf(file, "END"); break;
-                case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break;
-                case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break;
-                case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break;
-                case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break;
-                case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break;
-                case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break;
-                case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break;
-            }
-            switch (elem.type) {
-                case LLAMA_GRETYPE_END:
-                case LLAMA_GRETYPE_ALT:
-                case LLAMA_GRETYPE_RULE_REF:
-                    fprintf(file, "(%u) ", elem.value);
-                    break;
-                case LLAMA_GRETYPE_CHAR:
-                case LLAMA_GRETYPE_CHAR_NOT:
-                case LLAMA_GRETYPE_CHAR_RNG_UPPER:
-                case LLAMA_GRETYPE_CHAR_ALT:
-                case LLAMA_GRETYPE_CHAR_ANY:
-                    fprintf(file, "(\"");
-                    print_grammar_char(file, elem.value);
-                    fprintf(file, "\") ");
-                    break;
-            }
-        }
-        fprintf(file, "\n");
-    }
-
-    static void print_rule(
-        FILE * file,
-        uint32_t rule_id,
-        const std::vector<llama_grammar_element> & rule,
-        const std::map<uint32_t, std::string> & symbol_id_names) {
-        if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
-            throw std::runtime_error(
-                "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
-        }
-        fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
-        for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
-            llama_grammar_element elem = rule[i];
-            switch (elem.type) {
-                case LLAMA_GRETYPE_END:
-                    throw
std::runtime_error( - "unexpected end of rule: " + std::to_string(rule_id) + "," + - std::to_string(i)); - case LLAMA_GRETYPE_ALT: - fprintf(file, "| "); - break; - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); - break; - case LLAMA_GRETYPE_CHAR: - fprintf(file, "["); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_NOT: - fprintf(file, "[^"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - fprintf(file, "-"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ALT: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "."); - break; - } - if (is_char_element(elem)) { - switch (rule[i + 1].type) { - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ANY: - break; - default: - fprintf(file, "] "); - } - } - } - fprintf(file, "\n"); - } - - void print_grammar(FILE * file, const parse_state & state) { - try { - std::map symbol_id_names; - for (const auto & kv : state.symbol_ids) { - symbol_id_names[kv.second] = kv.first; - } - for (size_t i = 0, end = state.rules.size(); i < end; i++) { - // fprintf(file, "%zu: ", i); - // print_rule_binary(file, state.rules[i]); - print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); - // fprintf(file, "\n"); - } - } catch (const std::exception & err) { - fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); - } - } - - std::vector parse_state::c_rules() { - std::vector ret; - ret.reserve(rules.size()); - for (const auto & rule : rules) { - ret.push_back(rule.data()); - } - return ret; - } -} diff --git a/common/grammar-parser.h b/common/grammar-parser.h deleted file mode 100644 index 9037d72728a42..0000000000000 --- a/common/grammar-parser.h +++ /dev/null @@ -1,29 +0,0 @@ -// Implements a parser for an extended Backus-Naur form (BNF), producing the -// binary context-free grammar format specified by llama.h. Supports character -// ranges, grouping, and repetition operators. 
As an example, a grammar for -// arithmetic might look like: -// -// root ::= expr -// expr ::= term ([-+*/] term)* -// term ::= num | "(" space expr ")" space -// num ::= [0-9]+ space -// space ::= [ \t\n]* - -#pragma once -#include "llama.h" -#include -#include -#include -#include - -namespace grammar_parser { - struct parse_state { - std::map symbol_ids; - std::vector> rules; - - std::vector c_rules(); - }; - - parse_state parse(const char * src); - void print_grammar(FILE * file, const parse_state & state); -} diff --git a/common/sampling.cpp b/common/sampling.cpp index 079e405168dff..96cfbe0ef5b45 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,460 +1,222 @@ -#define LLAMA_API_INTERNAL #include "sampling.h" -#include -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { - struct llama_sampling_context * result = new llama_sampling_context(); +#include "common.h" - result->params = params; - result->grammar = nullptr; - - // if there is a grammar, parse it - if (!params.grammar.empty()) { - result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); - - // will be empty (default) if there are parse errors - if (result->parsed_grammar.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); - delete result; - return nullptr; - } +std::string gpt_sampling_params::print_all() const { + char result[1024]; - // Ensure that there is a "root" node. - if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); - delete result; - return nullptr; - } + snprintf(result, sizeof(result), + "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" + "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" + "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", + penalty_last_n, penalty_repeat, penalty_freq, penalty_present, + top_k, tfs_z, top_p, min_p, typ_p, temp, + mirostat, mirostat_eta, mirostat_tau); - std::vector grammar_rules(result->parsed_grammar.c_rules()); + return std::string(result); +} - struct llama_grammar * grammar = llama_grammar_init( - grammar_rules.data(), - grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); - if (grammar == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); +std::string gpt_sampling_params::print_samplers() const { + std::string result = "CFG -> Penalties "; + if (mirostat == 0) { + for (const auto & sampler : samplers) { + const auto name = llama_sampling_type_to_str(sampler); + if (!name.empty()) { + result += "-> " + name + " "; + } } - result->grammar = grammar; + } else { + result += "-> mirostat "; } - result->prev.resize(params.n_prev); - - result->n_valid = 0; + return result; +} - llama_sampling_set_rng_seed(result, params.seed); +struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) { + llama_sampling_params lparams = llama_sampling_default_params(); + + lparams.seed = params.seed; + lparams.n_prev = params.n_prev; + lparams.n_probs = params.n_probs; + lparams.min_keep = params.min_keep; + lparams.top_k = params.top_k; + lparams.top_p = params.top_p; + lparams.min_p = params.min_p; + lparams.tfs_z = params.tfs_z; + lparams.typ_p = params.typ_p; + lparams.temp = params.temp; + lparams.dynatemp_range = params.dynatemp_range; + lparams.dynatemp_exponent = 
params.dynatemp_exponent; + lparams.penalty_last_n = params.penalty_last_n; + lparams.penalty_repeat = params.penalty_repeat; + lparams.penalty_freq = params.penalty_freq; + lparams.penalty_present = params.penalty_present; + lparams.mirostat = params.mirostat; + lparams.mirostat_tau = params.mirostat_tau; + lparams.mirostat_eta = params.mirostat_eta; + lparams.penalize_nl = params.penalize_nl; + lparams.ignore_eos = params.ignore_eos; + + lparams.n_samplers = params.samplers.size(); + for (int i = 0; i < lparams.n_samplers; i++) { + lparams.samplers[i] = params.samplers[i]; + } + + struct llama_sampling * result = llama_sampling_init(model, lparams); + + llama_sampling_set_grammar (result, params.grammar.c_str(), "root"); + llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data()); return result; } -void llama_sampling_free(struct llama_sampling_context * ctx) { - if (ctx->grammar != NULL) { - llama_grammar_free(ctx->grammar); +void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst) { + if (dst) { + llama_sampling_free(dst); } - delete ctx; + dst = llama_sampling_cp(src); } -void llama_sampling_reset(llama_sampling_context * ctx) { - if (ctx->grammar != NULL) { - llama_grammar_free(ctx->grammar); - ctx->grammar = NULL; - } - - if (!ctx->parsed_grammar.rules.empty()) { - std::vector grammar_rules(ctx->parsed_grammar.c_rules()); +llama_token llama_sampling_sample( + struct llama_sampling * smpl, + struct llama_context * ctx, + int idx) { + llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); - struct llama_grammar * grammar = llama_grammar_init( - grammar_rules.data(), - grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); - if (grammar == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); - } - ctx->grammar = grammar; - } + // first, sample the token without any grammar constraints + const llama_token id = llama_sampling_sample(smpl, nullptr); - std::fill(ctx->prev.begin(), ctx->prev.end(), 0); - ctx->cur.clear(); - ctx->n_valid = 0; -} + // create an array with a single token data element for the sampled id + llama_token_data single_token_data = { id, 1.0f, 0.0f }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; -void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = std::random_device{}(); - } - ctx->rng.seed(seed); -} + llama_sampling_grammar(smpl, &single_token_data_array); -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { - if (dst->grammar) { - llama_grammar_free(dst->grammar); - dst->grammar = nullptr; + // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY + const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + if (is_valid) { + return id; } - if (src->grammar) { - dst->grammar = llama_grammar_copy(src->grammar); - } + // if the token is not valid, sample again, after applying the grammar constraints + llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); - dst->prev = src->prev; -} + llama_sampling_grammar(smpl, nullptr); -llama_token llama_sampling_last(llama_sampling_context * ctx) { - return ctx->prev.back(); + return llama_sampling_sample(smpl, nullptr); } -std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) { - const int size = ctx_sampling->prev.size(); +std::string 
llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) { + n = std::min(n, llama_sampling_n_prev(smpl)); - n = std::min(n, size); + if (n <= 0) { + return ""; + } std::string result; + result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab - for (int i = size - n; i < size; i++) { - result += llama_token_to_piece(ctx_main, ctx_sampling->prev[i]); - } + for (int i = n - 1; i >= 0; i--) { + const llama_token id = llama_sampling_prev(smpl, i); - return result; -} + GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); -std::string llama_sampling_print(const llama_sampling_params & params) { - char result[1024]; - - snprintf(result, sizeof(result), - "\trepeat_last_n = %d, repeat_penalty = %.3f, frequency_penalty = %.3f, presence_penalty = %.3f\n" - "\ttop_k = %d, tfs_z = %.3f, top_p = %.3f, min_p = %.3f, typical_p = %.3f, temp = %.3f\n" - "\tmirostat = %d, mirostat_lr = %.3f, mirostat_ent = %.3f", - params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present, - params.top_k, params.tfs_z, params.top_p, params.min_p, params.typical_p, params.temp, - params.mirostat, params.mirostat_eta, params.mirostat_tau); + result += llama_token_to_piece(ctx_main, id); + } - return std::string(result); + return result; } -std::string llama_sampling_order_print(const llama_sampling_params & params) { - std::string result = "CFG -> Penalties "; - if (params.mirostat == 0) { - for (auto sampler_type : params.samplers_sequence) { - const auto sampler_type_name = llama_sampling_type_to_str(sampler_type); - if (!sampler_type_name.empty()) { - result += "-> " + sampler_type_name + " "; - } - } - } else { - result += "-> mirostat "; +char llama_sampling_type_to_chr(llama_sampler_type sampler) { + switch (sampler) { + case LLAMA_SAMPLER_TYPE_TOP_K: return 'k'; + case LLAMA_SAMPLER_TYPE_TFS_Z: return 'f'; + case LLAMA_SAMPLER_TYPE_TYPICAL_P: return 'y'; + case LLAMA_SAMPLER_TYPE_TOP_P: return 'p'; + case LLAMA_SAMPLER_TYPE_MIN_P: return 'm'; + case LLAMA_SAMPLER_TYPE_TEMPERATURE: return 't'; + default : return '?'; } - - return result; } -std::string llama_sampling_type_to_str(llama_sampler_type sampler_type) { - switch (sampler_type) { - case llama_sampler_type::TOP_K: return "top_k"; - case llama_sampler_type::TFS_Z: return "tfs_z"; - case llama_sampler_type::TYPICAL_P: return "typical_p"; - case llama_sampler_type::TOP_P: return "top_p"; - case llama_sampler_type::MIN_P: return "min_p"; - case llama_sampler_type::TEMPERATURE: return "temperature"; +std::string llama_sampling_type_to_str(llama_sampler_type sampler) { + switch (sampler) { + case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k"; + case LLAMA_SAMPLER_TYPE_TFS_Z: return "tfs_z"; + case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; + case LLAMA_SAMPLER_TYPE_TOP_P: return "top_p"; + case LLAMA_SAMPLER_TYPE_MIN_P: return "min_p"; + case LLAMA_SAMPLER_TYPE_TEMPERATURE: return "temperature"; default : return ""; } } std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { std::unordered_map sampler_canonical_name_map { - {"top_k", llama_sampler_type::TOP_K}, - {"top_p", llama_sampler_type::TOP_P}, - {"typical_p", llama_sampler_type::TYPICAL_P}, - {"min_p", llama_sampler_type::MIN_P}, - {"tfs_z", llama_sampler_type::TFS_Z}, - {"temperature", llama_sampler_type::TEMPERATURE} + { "top_k", LLAMA_SAMPLER_TYPE_TOP_K }, + { "top_p", LLAMA_SAMPLER_TYPE_TOP_P }, + { "typ_p", 
LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "min_p", LLAMA_SAMPLER_TYPE_MIN_P }, + { "tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z }, + { "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE }, }; // since samplers names are written multiple ways // make it ready for both system names and input names std::unordered_map sampler_alt_name_map { - {"top-k", llama_sampler_type::TOP_K}, - {"top-p", llama_sampler_type::TOP_P}, - {"nucleus", llama_sampler_type::TOP_P}, - {"typical-p", llama_sampler_type::TYPICAL_P}, - {"typical", llama_sampler_type::TYPICAL_P}, - {"min-p", llama_sampler_type::MIN_P}, - {"tfs-z", llama_sampler_type::TFS_Z}, - {"tfs", llama_sampler_type::TFS_Z}, - {"temp", llama_sampler_type::TEMPERATURE} + { "top-k", LLAMA_SAMPLER_TYPE_TOP_K }, + { "top-p", LLAMA_SAMPLER_TYPE_TOP_P }, + { "nucleus", LLAMA_SAMPLER_TYPE_TOP_P }, + { "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "typical", LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "typ-p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "typ", LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { "min-p", LLAMA_SAMPLER_TYPE_MIN_P }, + { "tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z }, + { "tfs", LLAMA_SAMPLER_TYPE_TFS_Z }, + { "temp", LLAMA_SAMPLER_TYPE_TEMPERATURE }, }; - std::vector sampler_types; - sampler_types.reserve(names.size()); - for (const auto & name : names) - { - auto sampler_item = sampler_canonical_name_map.find(name); - if (sampler_item != sampler_canonical_name_map.end()) - { - sampler_types.push_back(sampler_item->second); - } - else - { - if (allow_alt_names) - { - sampler_item = sampler_alt_name_map.find(name); - if (sampler_item != sampler_alt_name_map.end()) - { - sampler_types.push_back(sampler_item->second); + std::vector samplers; + samplers.reserve(names.size()); + + for (const auto & name : names) { + auto sampler = sampler_canonical_name_map.find(name); + if (sampler != sampler_canonical_name_map.end()) { + samplers.push_back(sampler->second); + } else { + if (allow_alt_names) { + sampler = sampler_alt_name_map.find(name); + if (sampler != sampler_alt_name_map.end()) { + samplers.push_back(sampler->second); } } } } - return sampler_types; + + return samplers; } -std::vector llama_sampling_types_from_chars(const std::string & names_string) { +std::vector llama_sampling_types_from_chars(const std::string & chars) { std::unordered_map sampler_name_map { - {'k', llama_sampler_type::TOP_K}, - {'p', llama_sampler_type::TOP_P}, - {'y', llama_sampler_type::TYPICAL_P}, - {'m', llama_sampler_type::MIN_P}, - {'f', llama_sampler_type::TFS_Z}, - {'t', llama_sampler_type::TEMPERATURE} + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TYPICAL_P), LLAMA_SAMPLER_TYPE_TYPICAL_P }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_P), LLAMA_SAMPLER_TYPE_TOP_P }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_MIN_P), LLAMA_SAMPLER_TYPE_MIN_P }, + { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE } }; - std::vector sampler_types; - sampler_types.reserve(names_string.size()); - for (const auto & c : names_string) { - const auto sampler_item = sampler_name_map.find(c); - if (sampler_item != sampler_name_map.end()) { - sampler_types.push_back(sampler_item->second); - } - } - return sampler_types; -} - -// no reasons to expose this function in header -static void sampler_queue( - struct llama_context * ctx_main, - const llama_sampling_params & params, - llama_token_data_array & 
cur_p, - size_t min_keep) { - const float temp = params.temp; - const float dynatemp_range = params.dynatemp_range; - const float dynatemp_exponent = params.dynatemp_exponent; - const int32_t top_k = params.top_k; - const float top_p = params.top_p; - const float min_p = params.min_p; - const float tfs_z = params.tfs_z; - const float typical_p = params.typical_p; - const std::vector & samplers_sequence = params.samplers_sequence; - - for (auto sampler_type : samplers_sequence) { - switch (sampler_type) { - case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; - case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; - case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; - case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; - case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; - case llama_sampler_type::TEMPERATURE: - if (dynatemp_range > 0) { - float dynatemp_min = std::max(0.0f, temp - dynatemp_range); - float dynatemp_max = std::max(0.0f, temp + dynatemp_range); - llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); - } else { - llama_sample_temp(ctx_main, &cur_p, temp); - } - break; - default : break; - } - } -} - -static llama_token llama_sampling_sample_impl( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx, - bool is_resampling) { - const llama_sampling_params & params = ctx_sampling->params; - - const float temp = params.temp; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; - - std::vector original_logits; - auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); - if (ctx_sampling->grammar != NULL && !is_resampling) { - GGML_ASSERT(!original_logits.empty()); - } - llama_token id = 0; - - if (temp < 0.0) { - // greedy sampling, with probs - llama_sample_softmax(ctx_main, &cur_p); - id = cur_p.data[0].id; - } else if (temp == 0.0) { - // greedy sampling, no probs - id = llama_sample_token_greedy(ctx_main, &cur_p); - } else { - if (mirostat == 1) { - const int mirostat_m = 100; - llama_sample_temp(ctx_main, &cur_p, temp); - id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); - } else if (mirostat == 2) { - llama_sample_temp(ctx_main, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); - } else { - // temperature sampling - size_t min_keep = std::max(1, params.min_keep); - - sampler_queue(ctx_main, params, cur_p, min_keep); + std::vector samplers; + samplers.reserve(chars.size()); - id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); - - //{ - // const int n_top = 10; - // LOG("top %d candidates:\n", n_top); - - // for (int i = 0; i < n_top; i++) { - // const llama_token id = cur_p.data[i].id; - // (void)id; // To avoid a warning that id is unused when logging is disabled. 
- // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p); - // } - //} - - //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str()); + for (const auto & c : chars) { + const auto sampler = sampler_name_map.find(c); + if (sampler != sampler_name_map.end()) { + samplers.push_back(sampler->second); } } - if (ctx_sampling->grammar != NULL && !is_resampling) { - // Get a pointer to the logits - float * logits = llama_get_logits_ith(ctx_main, idx); - - // Create an array with a single token data element for the sampled id - llama_token_data single_token_data = {id, logits[id], 0.0f}; - llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; - - // Apply grammar constraints to the single token - llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array); - - // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY - bool is_valid = single_token_data_array.data[0].logit != -INFINITY; - - // If the token is not valid according to the grammar, perform resampling - if (!is_valid) { - LOG("Resampling because token %d: '%s' does not meet grammar rules\n", id, llama_token_to_piece(ctx_main, id).c_str()); - - // Restore logits from the copy - std::copy(original_logits.begin(), original_logits.end(), logits); - - return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ true); - } - } - - ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size; - - return id; -} - -static llama_token_data_array llama_sampling_prepare_impl( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx, - bool apply_grammar, - std::vector * original_logits) { - const llama_sampling_params & params = ctx_sampling->params; - - const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); - - const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n; - const float penalty_repeat = params.penalty_repeat; - const float penalty_freq = params.penalty_freq; - const float penalty_present = params.penalty_present; - - const bool penalize_nl = params.penalize_nl; - - auto & prev = ctx_sampling->prev; - auto & cur = ctx_sampling->cur; - - // Get a pointer to the logits - float * logits = llama_get_logits_ith(ctx_main, idx); - - if (ctx_sampling->grammar != NULL && !apply_grammar) { - GGML_ASSERT(original_logits != NULL); - // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this. - *original_logits = {logits, logits + n_vocab}; - } - - // apply params.logit_bias map - for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { - logits[it->first] += it->second; - } - - if (ctx_cfg) { - float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); - llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); - } - - cur.resize(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } - - llama_token_data_array cur_p = { cur.data(), cur.size(), false }; - - // apply penalties - const auto& penalty_tokens = params.use_penalty_prompt_tokens ? 
params.penalty_prompt_tokens : prev; - const int penalty_tokens_used_size = std::min((int)penalty_tokens.size(), penalty_last_n); - if (penalty_tokens_used_size) { - const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; - - llama_sample_repetition_penalties(ctx_main, &cur_p, - penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, - penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); - - if (!penalize_nl) { - for (size_t idx = 0; idx < cur_p.size; idx++) { - if (cur_p.data[idx].id == llama_token_nl(llama_get_model(ctx_main))) { - cur_p.data[idx].logit = nl_logit; - break; - } - } - } - } - - // apply grammar checks before sampling logic - if (apply_grammar && ctx_sampling->grammar != NULL) { - llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p); - } - - return cur_p; -} - -llama_token llama_sampling_sample( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx) { - // Call the implementation function with is_resampling set to false by default - return llama_sampling_sample_impl(ctx_sampling, ctx_main, ctx_cfg, idx, /* is_resampling= */ false); -} - -llama_token_data_array llama_sampling_prepare( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - const int idx, - bool apply_grammar, - std::vector * original_logits) { - return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits); -} - -void llama_sampling_accept( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - llama_token id, - bool apply_grammar) { - ctx_sampling->prev.erase(ctx_sampling->prev.begin()); - ctx_sampling->prev.push_back(id); - - if (ctx_sampling->grammar != NULL && apply_grammar) { - llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); - } + return samplers; } diff --git a/common/sampling.h b/common/sampling.h index eeaa53b8bcd00..b96bbce1ce869 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -2,159 +2,78 @@ #include "llama.h" -#include "grammar-parser.h" - -#include #include -#include #include -// sampler types -enum class llama_sampler_type : char { - TOP_K = 'k', - TOP_P = 'p', - MIN_P = 'm', - TFS_Z = 'f', - TYPICAL_P = 'y', - TEMPERATURE = 't' -}; - // sampling parameters -typedef struct llama_sampling_params { - int32_t n_prev = 64; // number of previous tokens to remember - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. 
- int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k = 40; // <= 0 to use vocab size - float top_p = 0.95f; // 1.0 = disabled - float min_p = 0.05f; // 0.0 = disabled - float tfs_z = 1.00f; // 1.0 = disabled - float typical_p = 1.00f; // 1.0 = disabled - float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range = 0.00f; // 0.0 = disabled - float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat = 1.00f; // 1.0 = disabled - float penalty_freq = 0.00f; // 0.0 = disabled - float penalty_present = 0.00f; // 0.0 = disabled - int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau = 5.00f; // target entropy - float mirostat_eta = 0.10f; // learning rate - bool penalize_nl = false; // consider newlines as a repeatable token - uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context - - std::vector samplers_sequence = { - llama_sampler_type::TOP_K, - llama_sampler_type::TFS_Z, - llama_sampler_type::TYPICAL_P, - llama_sampler_type::TOP_P, - llama_sampler_type::MIN_P, - llama_sampler_type::TEMPERATURE +typedef struct gpt_sampling_params { + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling + + int32_t n_prev = 64; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k = 40; // <= 0 to use vocab size + float top_p = 0.95f; // 1.0 = disabled + float min_p = 0.05f; // 0.0 = disabled + float tfs_z = 1.00f; // 1.0 = disabled + float typ_p = 1.00f; // typical_p, 1.0 = disabled + float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range = 0.00f; // 0.0 = disabled + float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat = 1.00f; // 1.0 = disabled + float penalty_freq = 0.00f; // 0.0 = disabled + float penalty_present = 0.00f; // 0.0 = disabled + int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau = 5.00f; // target entropy + float mirostat_eta = 0.10f; // learning rate + bool penalize_nl = false; // consider newlines as a repeatable token + bool ignore_eos = false; + + std::vector samplers = { + LLAMA_SAMPLER_TYPE_TOP_K, + LLAMA_SAMPLER_TYPE_TFS_Z, + LLAMA_SAMPLER_TYPE_TYPICAL_P, + LLAMA_SAMPLER_TYPE_TOP_P, + LLAMA_SAMPLER_TYPE_MIN_P, + LLAMA_SAMPLER_TYPE_TEMPERATURE }; - std::string grammar; // optional BNF-like grammar to constrain sampling - - // Classifier-Free Guidance - // https://arxiv.org/abs/2306.17806 - std::string cfg_negative_prompt; // string to help guidance - float cfg_scale = 1.f; // how strong is guidance - - std::unordered_map logit_bias; // logit bias for specific tokens - - std::vector penalty_prompt_tokens; - bool use_penalty_prompt_tokens = false; -} llama_sampling_params; - -// general sampler context -// TODO: move to llama.h -struct llama_sampling_context { - // parameters that will be used for sampling - llama_sampling_params params; - - // mirostat 
sampler state - float mirostat_mu; - - llama_grammar * grammar; - - // internal - grammar_parser::parse_state parsed_grammar; - - // TODO: replace with ring-buffer - std::vector prev; - std::vector cur; - size_t n_valid; // Number of correct top tokens with correct probabilities. - - std::mt19937 rng; -}; + std::string grammar; // optional BNF-like grammar to constrain sampling -#include "common.h" + std::vector logit_bias; // logit biases to apply -// Create a new sampling context instance. -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); + // print the parameters into a string + std::string print_all() const; -void llama_sampling_free(struct llama_sampling_context * ctx); + // print the samplers into a string + std::string print_samplers() const; +} gpt_sampling_params; -// Reset the sampler context -// - clear prev tokens -// - reset grammar -void llama_sampling_reset(llama_sampling_context * ctx); +// overload of llama_sampling_init using gpt_sampling_params +struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params); -// Set the sampler seed -void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); +void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst); -// Copy the sampler context -void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); - -// Get the last sampled token -llama_token llama_sampling_last(llama_sampling_context * ctx); - -// Get a string representation of the last sampled tokens -std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n); - -// Print sampling parameters into a string -std::string llama_sampling_print(const llama_sampling_params & params); +// common sampling implementation: +// +// - set logits +// - apply the configured sampling constraints +// - check if the token fits the grammar (if any) +// - if not: resample by first applying the grammar constraints and then sampling again (slower path) +// +llama_token llama_sampling_sample( + struct llama_sampling * smpl, + struct llama_context * ctx, + int idx); -// Print sampling order into a string -std::string llama_sampling_order_print(const llama_sampling_params & params); +// helpers -std::string llama_sampling_type_to_str(llama_sampler_type sampler_type); +// get a string representation of the last accepted tokens +std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n); -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector llama_sampling_types_from_chars(const std::string & names_string); +char llama_sampling_type_to_chr(enum llama_sampler_type sampler_type); +std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type); -// this is a common sampling function used across the examples for convenience -// it can serve as a starting point for implementing your own sampling function -// Note: When using multiple sequences, it is the caller's responsibility to call -// llama_sampling_reset when a sequence ends -// -// required: -// - ctx_main: context to use for sampling -// - ctx_sampling: sampling-specific context -// -// optional: -// - ctx_cfg: context to use for classifier-free guidance -// - idx: sample from llama_get_logits_ith(ctx, idx) -// -// returns: -// - token: sampled token -// - candidates: vector of candidate tokens -// -llama_token llama_sampling_sample( - struct 
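Note (not part of this patch): a minimal sketch of how an example program might drive the reworked common sampling API declared above. It assumes model, ctx, params, n_past and n_remain are already set up as in the existing examples, relies only on signatures that appear elsewhere in this diff, and omits batching and error handling:

    // build the sampling state from the common parameters (gpt_sampling_params)
    struct llama_sampling * smpl = llama_sampling_init(model, params.sparams);

    while (n_remain > 0) {
        // sample the next token from the logits of the last decoded position
        llama_token id = llama_sampling_sample(smpl, ctx, -1);

        // record the token so that repetition penalties and the grammar (if any) see it
        llama_sampling_accept(smpl, id, /* apply_grammar= */ true);

        if (llama_token_is_eog(model, id)) {
            break;
        }

        printf("%s", llama_token_to_piece(ctx, id).c_str());
        fflush(stdout);

        // feed the sampled token back for the next decoding step
        if (llama_decode(ctx, llama_batch_get_one(&id, 1, n_past, 0))) {
            break;
        }

        n_past   += 1;
        n_remain -= 1;
    }

    llama_print_timings(ctx, smpl);
    llama_sampling_free(smpl);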
llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - int idx = -1); - -// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters. -llama_token_data_array llama_sampling_prepare( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - struct llama_context * ctx_cfg, - int idx = 0, - bool apply_grammar = true, - std::vector * original_logits = nullptr); - -void llama_sampling_accept( - struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, - llama_token id, - bool apply_grammar); +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector llama_sampling_types_from_chars(const std::string & chars); diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 25a950ea59a8c..b02ef74cfcdab 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -210,7 +210,7 @@ int main(int argc, char ** argv) { } } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_batch_free(batch); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 616494d2d841d..81763217a91a8 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo print("Failed to load model") exit(1) } - defer { llama_free_model(model) } @@ -37,7 +36,6 @@ var tokens = tokenize(text: prompt, add_bos: true) let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel) var context_params = llama_context_default_params() -context_params.seed = 1234 context_params.n_ctx = n_kv_req context_params.n_batch = UInt32(max(n_len, n_parallel)) context_params.n_threads = 8 @@ -48,11 +46,24 @@ guard context != nil else { print("Failed to initialize context") exit(1) } - defer { llama_free(context) } +var sparams = llama_sampling_params() +sparams.top_k = 40 +sparams.top_p = 0.9 +sparams.temp = 0.4 + +let smpl = llama_sampling_init(model, sparams) +guard smpl != nil else { + print("Failed to initialize sampling") + exit(1) +} +defer { + llama_sampling_free(smpl) +} + let n_ctx = llama_n_ctx(context) print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n") @@ -125,32 +136,17 @@ while n_cur <= n_len { continue } - var n_vocab = llama_n_vocab(model) var logits = llama_get_logits_ith(context, i_batch[i]) - var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab)) - - for token_id in 0 ..< n_vocab { - candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0)) - } - - var candidates_p: llama_token_data_array = .init( - data: &candidates, - size: candidates.count, - sorted: false - ) - - let top_k: Int32 = 40 - let top_p: Float = 0.9 - let temp: Float = 0.4 + llama_sampling_set_logits(smpl, logits) - llama_sample_top_k(context, &candidates_p, top_k, 1) - llama_sample_top_p(context, &candidates_p, top_p, 1) - llama_sample_temp(context, &candidates_p, temp) + llama_sampling_top_k(smpl, nil) + llama_sampling_top_p(smpl, nil) + llama_sampling_temp (smpl, nil) - let new_token_id = llama_sample_token(context, &candidates_p) + let new_token_id = llama_sampling_sample_dist(smpl, nil) - // const 
llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil); // is it an end of stream? -> mark the stream as finished if llama_token_is_eog(model, new_token_id) || n_cur == n_len { @@ -212,7 +208,7 @@ let t_main_end = ggml_time_us() print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n") -llama_print_timings(context) +llama_print_timings(context, smpl) private func tokenize(text: String, add_bos: Bool) -> [llama_token] { let utf8Count = text.utf8.count diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 53fbfb0a8cf2a..4dfa19ce88af3 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -2,7 +2,6 @@ #include "llama.h" #include -#include #include #include #include @@ -65,6 +64,15 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_new_context_with_model(model, ctx_params); + auto sparams = llama_sampling_default_params(); + + sparams.seed = params.sparams.seed; + sparams.top_k = 40; + sparams.top_p = 0.9f; + sparams.temp = 0.4f; + + llama_sampling * smpl = llama_sampling_init(model, sparams); + if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); return 1; @@ -164,29 +172,17 @@ int main(int argc, char ** argv) { continue; } - auto n_vocab = llama_n_vocab(model); - auto * logits = llama_get_logits_ith(ctx, i_batch[i]); - - std::vector candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + const auto * logits = llama_get_logits_ith(ctx, i_batch[i]); - const int top_k = 40; - const float top_p = 0.9f; - const float temp = 0.4f; + llama_sampling_set_logits(smpl, logits); - llama_sample_top_k(ctx, &candidates_p, top_k, 1); - llama_sample_top_p(ctx, &candidates_p, top_p, 1); - llama_sample_temp (ctx, &candidates_p, temp); + llama_sampling_top_k(smpl, nullptr); + llama_sampling_top_p(smpl, nullptr); + llama_sampling_temp (smpl, nullptr); - const llama_token new_token_id = llama_sample_token(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample_dist(smpl, nullptr); - //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + //const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); // is it an end of generation? 
-> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -244,12 +240,13 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); fprintf(stderr, "\n"); llama_batch_free(batch); + llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index b05aa006e7da5..4b288b46092ef 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -90,13 +90,7 @@ int main(int argc, char ** argv) { print_build_info(); - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - fprintf(stderr, "%s: seed = %u\n", __func__, params.seed); - - std::mt19937 rng(params.seed); + LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); llama_backend_init(); llama_numa_init(params.numa); @@ -314,7 +308,7 @@ int main(int argc, char ** argv) { } // clean up - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_batch_free(batch); llama_free(ctx); llama_free_model(model); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 5e89988e2beda..166ca4b7da6bd 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -151,8 +151,6 @@ int main(int argc, char ** argv) { print_build_info(); - std::mt19937 rng(params.seed); - llama_backend_init(); llama_numa_init(params.numa); @@ -183,7 +181,7 @@ int main(int argc, char ** argv) { return 1; } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); llama_free_model(model); diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 48a705e15cea9..f439c0c5648a8 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -1,9 +1,5 @@ -#define LLAMA_API_INTERNAL - -#include "grammar-parser.h" -#include "ggml.h" -#include "llama.h" #include "unicode.h" +#include "llama-grammar.h" #include #include @@ -12,22 +8,21 @@ #include #include -static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { - auto decoded = decode_utf8(input_str, {}); - const auto & code_points = decoded.first; +static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) { + const auto cpts = unicode_cpts_from_utf8(input_str); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); size_t pos = 0; - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + for (const auto & cpt : cpts) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, prev_stacks, *it, cur_stacks); + cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); if (cur_stacks.empty()) { error_pos = pos; - error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'"; + error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; cur_stacks = prev_stacks; return false; } @@ -85,27 +80,7 @@ int main(int argc, char** argv) { grammar_str = buffer.str(); } - // 
Parse the GBNF grammar - auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); - - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - fprintf(stdout, "%s: failed to parse grammar\n", __func__); - return 1; - } - - // Ensure that there is a "root" node. - if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) { - fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__); - return 1; - } - - std::vector grammar_rules(parsed_grammar.c_rules()); - - // Create the LLAMA grammar - auto grammar = llama_grammar_init( - grammar_rules.data(), - grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); } @@ -122,7 +97,7 @@ int main(int argc, char** argv) { // Validate the input string against the grammar size_t error_pos; std::string error_msg; - bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg); + bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg); if (is_valid) { fprintf(stdout, "Input string is valid according to the grammar.\n"); @@ -131,7 +106,7 @@ int main(int argc, char** argv) { } // Clean up - llama_grammar_free(grammar); + llama_grammar_free_impl(grammar); return 0; } diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2c61c2e1eb3bc..7d2ae77133ec5 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -9,7 +9,7 @@ static std::vector> encode(llama_context * ctx, const std::vector & sentences, const std::string & instruction) { std::vector> result; - const llama_model * mdl = llama_get_model(ctx); + const llama_model * model = llama_get_model(ctx); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); @@ -18,16 +18,16 @@ static std::vector> encode(llama_context * ctx, const std::ve const std::string input_string = instruction + sentences[i]; - std::vector inputs = llama_tokenize(mdl, input_string, true, false); + std::vector inputs = llama_tokenize(model, input_string, true, false); const int32_t n_toks = inputs.size(); // GritLM seems to have EOS = "" // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 - // inputs.push_back(llama_token_eos(mdl)); + // inputs.push_back(llama_token_eos(model)); // we want to ignore instruction tokens for mean pooling - const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size(); + const int32_t n_inst = llama_tokenize(model, instruction, true, false).size(); #ifdef GRIT_DEBUG // debug tokens - should be matching as referenced in the GritLM sample @@ -51,7 +51,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_decode(ctx, batch); // get embedding dimensions - uint64_t n_embd = llama_n_embd(mdl); + uint64_t n_embd = llama_n_embd(model); // allocate embedding output std::vector emb_unorm(n_embd, 0.0f); @@ -92,11 +92,11 @@ static std::vector> encode(llama_context * ctx, const std::ve return result; } -static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) { +static std::string generate(llama_context * ctx, llama_sampling * smpl, const std::string & prompt, bool stream) { std::string result; - const llama_model * mdl = llama_get_model(ctx); - llama_token eos_token = llama_token_eos(mdl); + const llama_model * model = llama_get_model(ctx); + 
llama_token eos_token = llama_token_eos(model); llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, false); @@ -104,28 +104,27 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); - std::vector inputs = llama_tokenize(mdl, prompt, false, true); + std::vector inputs = llama_tokenize(model, prompt, false, true); int32_t i_current_token = 0; while (true) { llama_batch_clear(bat); - auto n_inputs = (int32_t)inputs.size(); - for (int32_t i = 0; i < n_inputs; i++) { - llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); + { + const int32_t n_inputs = inputs.size(); + + for (int32_t i = 0; i < n_inputs; i++) { + llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); + } } inputs.clear(); llama_decode(ctx, bat); - auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); - auto candidates = std::vector(llama_n_vocab(mdl)); - auto n_candidates = (int32_t)candidates.size(); - for (int32_t token = 0; token < n_candidates; token++) { - candidates[token] = llama_token_data{ token, logits[token], 0.0f }; - } - auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false }; + const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); + + llama_sampling_set_logits(smpl, logits); - llama_token token = llama_sample_token_greedy(ctx, &candidates_p); + llama_token token = llama_sampling_sample_greedy(smpl, nullptr); if (token == eos_token) { break; } @@ -167,10 +166,12 @@ int main(int argc, char * argv[]) { llama_backend_init(); - llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); + llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); // create generation context - llama_context * ctx = llama_new_context_with_model(mdl, cparams); + llama_context * ctx = llama_new_context_with_model(model, cparams); + + llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic @@ -191,7 +192,7 @@ int main(int argc, char * argv[]) { const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); - const int n_embd = llama_n_embd(mdl); + const int n_embd = llama_n_embd(model); const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); @@ -208,11 +209,12 @@ int main(int argc, char * argv[]) { // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction { const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. 
Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n"; - std::string response = generate(ctx, prompt, true); + std::string response = generate(ctx, smpl, prompt, true); } + llama_sampling_free(smpl); llama_free(ctx); - llama_free_model(mdl); + llama_free_model(model); llama_backend_free(); return 0; diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 83b85d72b043a..1c7f5350555e9 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -638,7 +638,7 @@ int main(int argc, char ** argv) { g_collector.save_imatrix(); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); llama_free_model(model); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 05700c1d591d9..371232421b71c 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -2,7 +2,6 @@ #include "console.h" #include "llama.h" -#include "grammar-parser.h" #include #include @@ -34,6 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; +static llama_sampling ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -93,7 +93,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx); + llama_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -103,7 +103,6 @@ static void sigint_handler(int signo) { int main(int argc, char ** argv) { gpt_params params; - llama_sampling_params & sparams = params.sparams; g_params = ¶ms; if (!gpt_params_parse(argc, argv, params)) { @@ -111,6 +110,8 @@ int main(int argc, char ** argv) { return 1; } + auto & sparams = params.sparams; + #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("infill", "log")); LOG_TEE("Log start\n"); @@ -156,26 +157,21 @@ int main(int argc, char ** argv) { LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); } - LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); - - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - LOG_TEE("%s: seed = %u\n", __func__, params.seed); + print_build_info(); - std::mt19937 rng(params.seed); + LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); LOG("%s: llama backend init\n", __func__); llama_backend_init(); llama_numa_init(params.numa); - llama_model * model; - llama_context * ctx; + llama_model * model = nullptr; + llama_context * ctx = nullptr; + llama_sampling * smpl = nullptr; g_model = &model; g_ctx = &ctx; + g_smpl = &smpl; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -305,7 +301,7 @@ int main(int argc, char ** argv) { LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); } } - LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); + LOG_TEE("sampling: \n%s\n", sparams.print_all().c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n"); @@ -349,7 +345,7 @@ int main(int argc, char ** argv) { std::vector embd; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + smpl = llama_sampling_init(model, sparams); while (n_remain != 0 || params.interactive) { // predict @@ -421,11 
+417,11 @@ int main(int argc, char ** argv) { embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr); + const llama_token id = llama_sampling_sample(smpl, ctx, -1); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(smpl, id, true); - LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); + // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); embd.push_back(id); @@ -444,7 +440,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false); + llama_sampling_accept(smpl, embd_inp[n_consumed], false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -476,7 +472,7 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // deal with eot token in infill mode - if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){ + if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ if (is_interacting && !params.interactive_first) { // print an eot token printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str()); @@ -542,7 +538,7 @@ int main(int argc, char ** argv) { is_interacting = false; } // deal with end of generation tokens in interactive mode - else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { + else if (llama_token_is_eog(model, llama_sampling_last(smpl))) { LOG("found EOS token\n"); if (params.interactive) { @@ -615,7 +611,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(ctx_sampling); + llama_sampling_reset(smpl); } is_interacting = false; } @@ -638,13 +634,13 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); llama_free_model(model); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_backend_free(); #ifndef LOG_DISABLE_LOGS diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index fe1802b51bdf6..1385634ddcc9c 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1630,7 +1630,7 @@ int main(int argc, char ** argv) { fflush(p_err->fout); } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 2aafe23167557..c33f55f720223 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -120,8 +120,8 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo LOGi("Using %d threads", n_threads); llama_context_params ctx_params = llama_context_default_params(); - ctx_params.seed = 1234; - ctx_params.n_ctx = 2048; + + ctx_params.n_ctx = 2048; ctx_params.n_threads = n_threads; ctx_params.n_threads_batch = n_threads; @@ -380,11 +380,13 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( JNIEnv * env, jobject, jlong context_pointer, + jlong sampling_pointer, 
jlong batch_pointer, jint n_len, jobject intvar_ncur ) { const auto context = reinterpret_cast(context_pointer); + const auto sampling = reinterpret_cast(sampling_pointer); const auto batch = reinterpret_cast(batch_pointer); const auto model = llama_get_model(context); @@ -392,20 +394,12 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); - auto n_vocab = llama_n_vocab(model); - auto logits = llama_get_logits_ith(context, batch->n_tokens - 1); - - std::vector candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } + const auto * logits = llama_get_logits_ith(context, batch->n_tokens - 1); - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + llama_sampling_set_logits(sampling, logits); // sample the most likely token - const auto new_token_id = llama_sample_token_greedy(context, &candidates_p); + const auto new_token_id = llama_sampling_sample_greedy(sampling, nullptr); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 48b7840ae49c3..515170f679f82 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama actor LlamaContext { private var model: OpaquePointer private var context: OpaquePointer + private var sampling: OpaquePointer private var batch: llama_batch private var tokens_list: [llama_token] var is_done: Bool = false @@ -42,9 +43,11 @@ actor LlamaContext { self.tokens_list = [] self.batch = llama_batch_init(512, 0, 1) self.temporary_invalid_cchars = [] + self.sampling = llama_sampling_init(context, llama_sampling_default_params()) } deinit { + llama_sampling_free(sampling) llama_batch_free(batch) llama_free(context) llama_free_model(model) @@ -69,7 +72,6 @@ actor LlamaContext { print("Using \(n_threads) threads") var ctx_params = llama_context_default_params() - ctx_params.seed = 1234 ctx_params.n_ctx = 2048 ctx_params.n_threads = Int32(n_threads) ctx_params.n_threads_batch = Int32(n_threads) @@ -147,17 +149,9 @@ actor LlamaContext { let n_vocab = llama_n_vocab(model) let logits = llama_get_logits_ith(context, batch.n_tokens - 1) - var candidates = Array() - candidates.reserveCapacity(Int(n_vocab)) + llama_sampling_set_logits(sampling, logits); - for token_id in 0..sparams); - if (!ctx_sampling) { + struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); } std::string response = ""; for (int i = 0; i < max_tgt_len; i++) { - const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); + const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); response += tmp; if (strcmp(tmp, "") == 0) break; if (strstr(tmp, "###")) break; // Yi-VL behavior @@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ fflush(stdout); } - llama_sampling_free(ctx_sampling); + 
llama_sampling_free(smpl); printf("\n"); } @@ -310,7 +310,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); @@ -327,7 +327,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index f500ea5b944f4..c041fe530e987 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e LOG_TEE("%s: image token past: %d\n", __func__, n_past); } -static const char * sample(struct llama_sampling_context * ctx_sampling, +static const char * sample(struct llama_sampling * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); - llama_sampling_accept(ctx_sampling, ctx_llama, id, true); + const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1); + llama_sampling_accept(smpl, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = ""; @@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri return ctx_llava; } -static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ +static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ std::string user_prompt = prompt; int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); if (!is_first) { @@ -238,13 +238,13 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla LOG_TEE("\n"); - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); - return ctx_sampling; + struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + return smpl; } -static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){ +static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){ - const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past); + const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); return tmp; } @@ -278,12 +278,12 @@ int main(int argc, char ** argv) { if (!params.prompt.empty()) { LOG_TEE("%s\n", params.prompt.c_str()); LOG_TEE(""); - auto ctx_sampling = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true); + auto smpl = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true); const int max_tgt_len = params.n_predict < 0 ? 
256 : params.n_predict; std::string response = ""; bool have_tmp = false; for (int i = 0; i < max_tgt_len; i++) { - auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + auto tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0){ if(!have_tmp)continue; @@ -296,18 +296,18 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); }else { while (true) { LOG_TEE(""); std::string prompt; std::getline(std::cin, prompt); LOG_TEE(""); - auto ctx_sampling = llama_init(ctx_llava, ¶ms, prompt, n_past, true); + auto smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true); const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; std::string response = ""; for (int i = 0; i < max_tgt_len; i++) { - auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past); + auto tmp = llama_loop(ctx_llava, smpl, n_past); response += tmp; if (strcmp(tmp, "") == 0) break; if (strstr(tmp, "###")) break; // Yi-VL behavior @@ -315,11 +315,11 @@ int main(int argc, char ** argv) { if (strstr(response.c_str(), "")) break; // minicpm-v fflush(stdout); } - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); } } printf("\n"); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr); ctx_llava->model = NULL; llava_free(ctx_llava); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 81cf1629c5b6a..2bd31d00268a2 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -1,7 +1,6 @@ #include "common.h" #include "llama.h" -#include #include #include #include @@ -118,7 +117,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); // verification n-grams std::vector ngrams_cur(G); @@ -159,9 +158,9 @@ int main(int argc, char ** argv) { // sample first token { - id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); + id = llama_sampling_sample(smpl, ctx, 0); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(smpl, id, true); { const std::string token_str = llama_token_to_piece(ctx, id); @@ -284,9 +283,9 @@ int main(int argc, char ** argv) { } // sample the next token - id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch); + id = llama_sampling_sample(smpl, ctx, i_batch); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(smpl, id, true); // print { @@ -361,7 +360,7 @@ int main(int argc, char ** argv) { if (v == 0) { // sample from the last level for (int i = 0; i < W; i++) { - tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i); + tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); } } else { for (int i = 0; i < W; i++) { @@ -468,10 +467,10 @@ int main(int argc, char ** argv) { LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept); - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); llama_kv_cache_view_free(&kvc_view); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_batch_free(batch); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index d53a9828c2ea2..da4d57a518754 100644 --- a/examples/lookup/lookup.cpp +++ 
b/examples/lookup/lookup.cpp @@ -3,13 +3,11 @@ #include "common.h" #include "ngram-cache.h" -#include #include #include #include #include #include -#include int main(int argc, char ** argv){ gpt_params params; @@ -106,7 +104,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); std::vector draft; @@ -130,9 +128,9 @@ int main(int argc, char ** argv){ int i_dft = 0; while (true) { // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft); + llama_token id = llama_sampling_sample(smpl, ctx, i_dft); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(smpl, id, true); const std::string token_str = llama_token_to_piece(ctx, id); @@ -241,9 +239,9 @@ int main(int argc, char ** argv){ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_batch_free(batch_tgt); llama_free(ctx); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index c55efbb66d7c1..296c1c687ad5b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -33,6 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; +static llama_sampling ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -105,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx); + llama_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, std::string role, std::string content) { llama_chat_msg new_msg{role, content}; - auto formatted = llama_chat_format_single( - model, g_params->chat_template, chat_msgs, new_msg, role == "user"); + auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); chat_msgs.push_back({role, content}); LOG("formatted: %s\n", formatted.c_str()); return formatted; @@ -137,7 +137,7 @@ int main(int argc, char ** argv) { return 1; } - llama_sampling_params & sparams = params.sparams; + auto & sparams = params.sparams; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("main", "log")); @@ -183,27 +183,23 @@ int main(int argc, char ** argv) { LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); } - LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); - LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); + print_build_info(); - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - LOG_TEE("%s: seed = %u\n", __func__, params.seed); - - std::mt19937 rng(params.seed); + LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); LOG("%s: llama backend init\n", __func__); llama_backend_init(); llama_numa_init(params.numa); - llama_model * model; - llama_context * ctx; - llama_context * ctx_guidance = NULL; + llama_model * model = nullptr; + llama_context * ctx = nullptr; + llama_sampling * smpl = nullptr; + std::vector 
chat_msgs; + g_model = &model; g_ctx = &ctx; + g_smpl = &smpl; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -211,10 +207,6 @@ int main(int argc, char ** argv) { model = llama_init.model; ctx = llama_init.context; - if (sparams.cfg_scale > 1.f) { - struct llama_context_params lparams = llama_context_params_from_gpt_params(params); - ctx_guidance = llama_new_context_with_model(model, lparams); - } if (model == NULL) { LOG_TEE("%s: error: unable to load model\n", __func__); @@ -251,9 +243,6 @@ int main(int argc, char ** argv) { } llama_attach_threadpool(ctx, threadpool, threadpool_batch); - if (ctx_guidance) { - llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch); - } const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); @@ -337,24 +326,6 @@ int main(int argc, char ** argv) { } // Tokenize negative prompt - std::vector guidance_inp; - int guidance_offset = 0; - int original_prompt_len = 0; - if (ctx_guidance) { - LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt)); - - guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true); - LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str()); - - std::vector original_inp = ::llama_tokenize(ctx, params.prompt, true, true); - LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str()); - - original_prompt_len = original_inp.size(); - guidance_offset = (int)guidance_inp.size() - original_prompt_len; - LOG("original_prompt_len: %s", log_tostr(original_prompt_len)); - LOG("guidance_offset: %s", log_tostr(guidance_offset)); - } - if ((int) embd_inp.size() > n_ctx - 4) { LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); return 1; @@ -421,15 +392,6 @@ int main(int argc, char ** argv) { LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); } - if (ctx_guidance) { - LOG_TEE("\n"); - LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str()); - LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); - for (int i = 0; i < (int) guidance_inp.size(); i++) { - LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str()); - } - } - if (params.n_keep > add_bos) { LOG_TEE("%s: static prompt based on n_keep: '", __func__); for (int i = 0; i < params.n_keep; i++) { @@ -495,8 +457,8 @@ int main(int argc, char ** argv) { } } } - LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); - LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str()); + LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str()); + LOG_TEE("sampling order: \n%s\n", sparams.print_samplers().c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); // group-attention state @@ -543,7 +505,6 @@ int main(int argc, char ** argv) { int n_remain = params.n_predict; int n_consumed = 0; int n_session_consumed = 0; - int n_past_guidance = 0; std::vector input_tokens; g_input_tokens = &input_tokens; std::vector output_tokens; g_output_tokens = &output_tokens; @@ -555,7 +516,6 @@ int main(int argc, char ** argv) { display = params.display_prompt; std::vector embd; - std::vector embd_guidance; // tokenized antiprompts std::vector> antiprompt_ids; @@ -565,8 +525,8 
@@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); - if (!ctx_sampling) { + smpl = llama_sampling_init(model, sparams); + if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); } @@ -612,7 +572,7 @@ int main(int argc, char ** argv) { // if we run out of context: // - take the n_keep first tokens from the original prompt (via n_past) // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches - if (n_past + (int) embd.size() + std::max(0, guidance_offset) >= n_ctx) { + if (n_past + (int) embd.size() >= n_ctx) { if (params.n_predict == -2) { LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); break; @@ -629,11 +589,7 @@ int main(int argc, char ** argv) { n_past -= n_discard; - if (ctx_guidance) { - n_past_guidance -= n_discard; - } - - LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + LOG("after swap: n_past = %d\n", n_past); LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); @@ -686,46 +642,6 @@ int main(int argc, char ** argv) { } } - // evaluate tokens in batches - // embd is typically prepared beforehand to fit within a batch, but not always - if (ctx_guidance) { - int input_size = 0; - llama_token * input_buf = NULL; - - if (n_past_guidance < (int) guidance_inp.size()) { - // Guidance context should have the same data with these modifications: - // - // * Replace the initial prompt - // * Shift everything by guidance_offset - embd_guidance = guidance_inp; - if (embd.begin() + original_prompt_len < embd.end()) { - embd_guidance.insert( - embd_guidance.end(), - embd.begin() + original_prompt_len, - embd.end() - ); - } - - input_buf = embd_guidance.data(); - input_size = embd_guidance.size(); - - LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str()); - } else { - input_buf = embd.data(); - input_size = embd.size(); - } - - for (int i = 0; i < input_size; i += params.n_batch) { - int n_eval = std::min(input_size - i, params.n_batch); - if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) { - LOG_TEE("%s : failed to eval\n", __func__); - return 1; - } - - n_past_guidance += n_eval; - } - } - for (int i = 0; i < (int) embd.size(); i += params.n_batch) { int n_eval = (int) embd.size() - i; if (n_eval > params.n_batch) { @@ -755,7 +671,6 @@ int main(int argc, char ** argv) { } embd.clear(); - embd_guidance.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { // optionally save the session on first sample (for faster prompt loading next time) @@ -766,11 +681,11 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); + const llama_token id = llama_sampling_sample(smpl, ctx, -1); - llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true); + llama_sampling_accept(smpl, id, /* apply_grammar= */ true); - LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); + // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); embd.push_back(id); @@ -789,7 +704,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - 
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false); + llama_sampling_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -832,7 +747,7 @@ int main(int argc, char ** argv) { // check for reverse prompt in the last n_prev tokens if (!params.antiprompt.empty()) { const int n_prev = 32; - const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev); + const std::string last_output = llama_sampling_prev_str(smpl, ctx, n_prev); is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. @@ -854,7 +769,7 @@ int main(int argc, char ** argv) { } // check for reverse prompt using special tokens - llama_token last_token = llama_sampling_last(ctx_sampling); + llama_token last_token = llama_sampling_last(smpl); for (std::vector ids : antiprompt_ids) { if (ids.size() == 1 && last_token == ids[0]) { if (params.interactive) { @@ -871,7 +786,7 @@ int main(int argc, char ** argv) { } // deal with end of generation tokens in interactive mode - if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) { + if (llama_token_is_eog(model, llama_sampling_last(smpl))) { LOG("found an EOG token\n"); if (params.interactive) { @@ -892,7 +807,7 @@ int main(int argc, char ** argv) { // if current token is not EOG, we add it to current assistant message if (params.conversation) { - auto id = llama_sampling_last(ctx_sampling); + auto id = llama_sampling_last(smpl); assistant_ss << llama_token_to_piece(ctx, id, false); } @@ -988,7 +903,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(ctx_sampling); + llama_sampling_reset(smpl); } is_interacting = false; } @@ -1013,14 +928,13 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); - if (ctx_guidance) { llama_free(ctx_guidance); } llama_free(ctx); llama_free_model(model); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); llama_backend_free(); ggml_threadpool_free(threadpool); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 621a1c9590622..7ce982a92e497 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -50,8 +50,8 @@ static std::vector k_prompts = { struct client { ~client() { - if (ctx_sampling) { - llama_sampling_free(ctx_sampling); + if (smpl) { + llama_sampling_free(smpl); } } @@ -72,7 +72,7 @@ struct client { std::string prompt; std::string response; - struct llama_sampling_context * ctx_sampling = nullptr; + struct llama_sampling * smpl = nullptr; }; static void print_date_time() { @@ -161,7 +161,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.ctx_sampling = llama_sampling_init(params.sparams); + client.smpl = llama_sampling_init(model, params.sparams); } std::vector tokens_system; @@ -253,7 +253,7 @@ int main(int argc, char ** argv) { client.prompt = client.input + "\nAssistant:"; client.response = ""; - llama_sampling_reset(client.ctx_sampling); + llama_sampling_reset(client.smpl); // do not prepend BOS because we have a system prompt! 
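Note (not part of this patch): in parallel.cpp the sampling state is per sequence, because the previously accepted tokens and the grammar state must not leak between clients; each client therefore owns its own llama_sampling instance and resets it when a new request starts. A condensed sketch of that per-client lifecycle, using only calls that appear in this diff:

    // one sampling state per client/sequence
    for (size_t i = 0; i < clients.size(); ++i) {
        clients[i].smpl = llama_sampling_init(model, params.sparams);
    }

    // when a client starts a new request: clear its previous tokens and grammar state
    llama_sampling_reset(client.smpl);

    // when the batch position belonging to this client has been decoded
    const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i);
    llama_sampling_accept(client.smpl, id, true);

    // the per-client state is released in the client destructor via llama_sampling_free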
std::vector tokens_prompt; @@ -341,9 +341,9 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i); + const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i); - llama_sampling_accept(client.ctx_sampling, ctx, id, true); + llama_sampling_accept(client.smpl, id, true); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -371,7 +371,7 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); @@ -413,7 +413,8 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); - llama_print_timings(ctx); + // TODO: print sampling/grammar timings for all clients + llama_print_timings(ctx, nullptr); llama_batch_free(batch); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index d03215cd1e0a9..0992ccc3c1808 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -26,8 +26,6 @@ int main(int argc, char ** argv) { return 1; } - srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed); - int n_junk = params.n_junk; int n_keep = params.n_keep; int n_grp = params.grp_attn_n; @@ -80,12 +78,13 @@ int main(int argc, char ** argv) { GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); llama_context * ctx = llama_new_context_with_model(model, ctx_params); - if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); return 1; } + llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + // tokenize the prompt std::vector tokens_list; tokens_list = ::llama_tokenize(ctx, params.prompt, true); @@ -217,20 +216,12 @@ int main(int argc, char ** argv) { while (n_cur <= n_len) { // sample the next token { - auto n_vocab = llama_n_vocab(model); - auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - - std::vector candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } + const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + llama_sampling_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); // is it an end of generation? 
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { @@ -267,12 +258,13 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); fprintf(stderr, "\n"); llama_batch_free(batch); + llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 484dd589109c7..987236ab683de 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -2007,13 +2007,7 @@ int main(int argc, char ** argv) { print_build_info(); - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - fprintf(stderr, "%s: seed = %u\n", __func__, params.seed); - - std::mt19937 rng(params.seed); + LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed); llama_backend_init(); llama_numa_init(params.numa); @@ -2054,7 +2048,7 @@ int main(int argc, char ** argv) { results = perplexity(ctx, params, n_ctx); } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); write_logfile(ctx, params, model, results); llama_free(ctx); diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 68cf8d3595e87..498cbbe3ce1cd 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -1,7 +1,7 @@ -#define LLAMA_API_INTERNAL #include "common.h" #include "ggml.h" #include "llama.h" +#include "llama-impl.h" #include #include @@ -319,8 +319,7 @@ int main(int argc, char ** argv) { } auto cparams = llama_context_default_params(); - cparams.n_ctx = 256; - cparams.seed = 1; + cparams.n_ctx = 256; ctx = llama_new_context_with_model(model, cparams); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index aab9d81058af9..6089344a02bb7 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -294,8 +294,8 @@ int main(int argc, char ** argv) { } // clean up + llama_print_timings(ctx, nullptr); llama_batch_free(query_batch); - llama_print_timings(ctx); llama_free(ctx); llama_free_model(model); llama_backend_free(); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 3ea7c790d2bf7..02f7a93ebac19 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -3,12 +3,12 @@ #include #include -#include int main(int argc, char ** argv) { gpt_params params; params.prompt = "The quick brown fox"; + params.sparams.seed = 1234; if (!gpt_params_parse(argc, argv, params)) { gpt_params_print_usage(argc, argv, params); @@ -38,6 +38,11 @@ int main(int argc, char ** argv) { return 1; } + llama_sampling_params sparams = llama_sampling_default_params(); + sparams.seed = params.sparams.seed; + + llama_sampling * smpl = llama_sampling_init(model, sparams); + // tokenize prompt auto tokens = llama_tokenize(ctx, params.prompt, true); @@ -64,16 +69,11 @@ int main(int argc, char ** argv) { printf("\nfirst run: %s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { - auto * logits = llama_get_logits(ctx); - auto n_vocab = llama_n_vocab(model); + const auto * logits = llama_get_logits(ctx); - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - 
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx, &candidates_p); + llama_sampling_set_logits(smpl, logits); + + auto next_token = llama_sampling_sample_dist(smpl, nullptr); auto next_token_str = llama_token_to_piece(ctx, next_token); printf("%s", next_token_str.c_str()); @@ -96,6 +96,8 @@ int main(int argc, char ** argv) { // make new context auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + llama_sampling * smpl2 = llama_sampling_init(model, sparams); + printf("\nsecond run: %s", params.prompt.c_str()); // load state (rng, logits, embedding and kv_cache) from file @@ -124,15 +126,11 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { - auto * logits = llama_get_logits(ctx2); - auto n_vocab = llama_n_vocab(model); - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx2, &candidates_p); + const auto * logits = llama_get_logits(ctx2); + + llama_sampling_set_logits(smpl2, logits); + + auto next_token = llama_sampling_sample_dist(smpl2, nullptr); auto next_token_str = llama_token_to_piece(ctx2, next_token); printf("%s", next_token_str.c_str()); @@ -157,7 +155,9 @@ int main(int argc, char ** argv) { } // make new context - auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + + llama_sampling * smpl3 = llama_sampling_init(model, sparams); printf("\nsingle seq run: %s", params.prompt.c_str()); @@ -215,15 +215,11 @@ int main(int argc, char ** argv) { // third run with seq 1 instead of 0 for (auto i = 0; i < params.n_predict; i++) { - auto * logits = llama_get_logits(ctx3); - auto n_vocab = llama_n_vocab(model); - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); - } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx3, &candidates_p); + const auto * logits = llama_get_logits(ctx3); + + llama_sampling_set_logits(smpl3, logits); + + auto next_token = llama_sampling_sample_dist(smpl3, nullptr); auto next_token_str = llama_token_to_piece(ctx3, next_token); printf("%s", next_token_str.c_str()); @@ -240,6 +236,10 @@ int main(int argc, char ** argv) { printf("\n"); + llama_sampling_free(smpl); + llama_sampling_free(smpl2); + llama_sampling_free(smpl3); + llama_free(ctx3); llama_free_model(model); diff --git a/examples/server/README.md b/examples/server/README.md index 805e05b4a5114..37024dea0055c 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -470,8 +470,6 @@ node index.js `frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled. - `penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens. 
Default: `null`, which is to use the original `prompt`. - `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0. `mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0` @@ -724,7 +722,6 @@ Example: "stopping_word": "" }, "penalize_nl": true, - "penalty_prompt_tokens": [], "presence_penalty": 0.0, "prompt": "Say hello to llama.cpp", "repeat_last_n": 64, @@ -748,8 +745,7 @@ Example: "tfs_z": 1.0, "top_k": 40, "top_p": 0.949999988079071, - "typical_p": 1.0, - "use_penalty_prompt_tokens": false + "typical_p": 1.0 } ] ``` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index cc65c57ab723c..139e503b9eb29 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3,7 +3,6 @@ #include "common.h" #include "json-schema-to-grammar.h" #include "llama.h" -#include "grammar-parser.h" // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT @@ -169,11 +168,13 @@ struct server_slot { std::string stopping_word; // sampling - llama_token sampled; - struct llama_sampling_params sparams; - llama_sampling_context * ctx_sampling = nullptr; json json_schema; + struct gpt_sampling_params sparams; + + llama_token sampled; + llama_sampling * smpl = nullptr; + int32_t ga_i = 0; // group-attention state int32_t ga_n = 1; // group-attention factor int32_t ga_w = 512; // group-attention width @@ -651,8 +652,8 @@ struct server_context { // Clear any sampling context for (server_slot & slot : slots) { - if (slot.ctx_sampling != nullptr) { - llama_sampling_free(slot.ctx_sampling); + if (slot.smpl != nullptr) { + llama_sampling_free(slot.smpl); } } @@ -883,8 +884,8 @@ struct server_context { bool launch_slot_with_task(server_slot & slot, const server_task & task) { slot_params default_params; // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) - llama_sampling_params default_sparams = params.sparams; - auto & data = task.data; + auto default_sparams = params.sparams; + const auto & data = task.data; if (data.count("__oaicompat") != 0) { slot.oaicompat = true; @@ -901,7 +902,7 @@ struct server_context { slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p); slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p); slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z); - slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p); + slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p); slot.sparams.temp = json_value(data, "temperature", default_sparams.temp); slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range); slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent); @@ -923,7 +924,8 @@ struct server_context { if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) { send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST); return false; - } else if (data.contains("json_schema") && !data.contains("grammar")) { + } + if (data.contains("json_schema") && !data.contains("grammar")) { try { auto schema = json_value(data, "json_schema", json::object()); slot.sparams.grammar = json_schema_to_grammar(schema); @@ -973,56 +975,11 @@ struct server_context { } 
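// Illustrative only, not part of this patch: the "json_schema" / "grammar" check earlier in
// this hunk accepts at most one of the two fields. A hypothetical /completion request body
// that takes the json_schema path could look like:
//
//     { "prompt": "List three colors as JSON",
//       "json_schema": { "type": "array", "items": { "type": "string" } } }
//
// while a request that sets both "json_schema" and "grammar" is rejected with
// ERROR_TYPE_INVALID_REQUEST before the slot's llama_sampling instance is (re)created below.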
} - // penalize user-provided tokens - { - slot.sparams.penalty_prompt_tokens.clear(); - slot.sparams.use_penalty_prompt_tokens = false; - - const auto & penalty_prompt = data.find("penalty_prompt"); - - if (penalty_prompt != data.end()) { - if (penalty_prompt->is_string()) { - const auto penalty_prompt_string = penalty_prompt->get(); - slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false); - - if (slot.params.n_predict > 0) { - slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict); - } - slot.sparams.use_penalty_prompt_tokens = true; - - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); - } - else if (penalty_prompt->is_array()) { - const auto n_tokens = penalty_prompt->size(); - slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict)); - - const int n_vocab = llama_n_vocab(model); - for (const auto & penalty_token : *penalty_prompt) { - if (penalty_token.is_number_integer()) { - const auto tok = penalty_token.get(); - if (tok >= 0 && tok < n_vocab) { - slot.sparams.penalty_prompt_tokens.push_back(tok); - } - } - } - slot.sparams.use_penalty_prompt_tokens = true; - - LOG_VERBOSE("penalty_prompt_tokens", { - {"id_slot", slot.id}, - {"tokens", slot.sparams.penalty_prompt_tokens}, - }); - } - } - } - { slot.sparams.logit_bias.clear(); if (json_value(data, "ignore_eos", false) && has_eos_token) { - slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY; + slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY}); } const auto & logit_bias = data.find("logit_bias"); @@ -1043,12 +1000,12 @@ struct server_context { if (el[0].is_number_integer()) { llama_token tok = el[0].get(); if (tok >= 0 && tok < n_vocab) { - slot.sparams.logit_bias[tok] = bias; + slot.sparams.logit_bias.push_back({tok, bias}); } } else if (el[0].is_string()) { auto toks = llama_tokenize(model, el[0].get(), false); for (auto tok : toks) { - slot.sparams.logit_bias[tok] = bias; + slot.sparams.logit_bias.push_back({tok, bias}); } } } @@ -1070,26 +1027,27 @@ struct server_context { } { - const auto & samplers_sequence = data.find("samplers"); - if (samplers_sequence != data.end() && samplers_sequence->is_array()) { + const auto & samplers = data.find("samplers"); + if (samplers != data.end() && samplers->is_array()) { std::vector sampler_names; - for (const auto & sampler_name : *samplers_sequence) { + for (const auto & sampler_name : *samplers) { if (sampler_name.is_string()) { sampler_names.emplace_back(sampler_name); } } - slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false); + slot.sparams.samplers = llama_sampling_types_from_names(sampler_names, false); } else { - slot.sparams.samplers_sequence = default_sparams.samplers_sequence; + slot.sparams.samplers = default_sparams.samplers; } } { - if (slot.ctx_sampling != nullptr) { - llama_sampling_free(slot.ctx_sampling); + if (slot.smpl != nullptr) { + llama_sampling_free(slot.smpl); } - slot.ctx_sampling = llama_sampling_init(slot.sparams); - if (slot.ctx_sampling == nullptr) { + + slot.smpl = llama_sampling_init(model, slot.sparams); + if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); return false; @@ -1178,11 +1136,6 @@ struct server_context { slot.generated_text += token_str; slot.has_next_token = true; - if 
(slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) { - // we can change penalty_prompt_tokens because it is always created from scratch each request - slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok); - } - // check if there is incomplete UTF-8 character at the end bool incomplete = false; for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) { @@ -1300,13 +1253,10 @@ struct server_context { } json get_formated_generation(const server_slot & slot) const { - const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model)); - const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); - - std::vector samplers_sequence; - samplers_sequence.reserve(slot.sparams.samplers_sequence.size()); - for (const auto & sampler_type : slot.sparams.samplers_sequence) { - samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type)); + std::vector samplers; + samplers.reserve(slot.sparams.samplers.size()); + for (const auto & sampler : slot.sparams.samplers) { + samplers.emplace_back(llama_sampling_type_to_str(sampler)); } return json { @@ -1321,13 +1271,11 @@ struct server_context { {"top_p", slot.sparams.top_p}, {"min_p", slot.sparams.min_p}, {"tfs_z", slot.sparams.tfs_z}, - {"typical_p", slot.sparams.typical_p}, + {"typical_p", slot.sparams.typ_p}, {"repeat_last_n", slot.sparams.penalty_last_n}, {"repeat_penalty", slot.sparams.penalty_repeat}, {"presence_penalty", slot.sparams.penalty_present}, {"frequency_penalty", slot.sparams.penalty_freq}, - {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens}, - {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens}, {"mirostat", slot.sparams.mirostat}, {"mirostat_tau", slot.sparams.mirostat_tau}, {"mirostat_eta", slot.sparams.mirostat_eta}, @@ -1336,13 +1284,13 @@ struct server_context { {"max_tokens", slot.params.n_predict}, // User configured n_predict {"n_keep", slot.params.n_keep}, {"n_discard", slot.params.n_discard}, - {"ignore_eos", ignore_eos}, + {"ignore_eos", slot.sparams.ignore_eos}, {"stream", slot.params.stream}, - {"logit_bias", slot.sparams.logit_bias}, + //{"logit_bias", slot.sparams.logit_bias}, {"n_probs", slot.sparams.n_probs}, {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, - {"samplers", samplers_sequence} + {"samplers", samplers}, }; } @@ -2136,7 +2084,7 @@ struct server_context { GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } - llama_sampling_reset(slot.ctx_sampling); + llama_sampling_reset(slot.smpl); if (!slot.params.cache_prompt) { slot.n_past_se = 0; @@ -2149,7 +2097,7 @@ struct server_context { // push the prompt into the sampling context (do not apply grammar) for (int i = 0; i < slot.n_past; ++i) { - llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + llama_sampling_accept(slot.smpl, slot.cache_tokens[i], false); } } } @@ -2202,7 +2150,7 @@ struct server_context { slot.n_past_se = 0; slot.ga_i = 0; // TODO: is the system prompt ever in the sampling context? 
- llama_sampling_reset(slot.ctx_sampling); + llama_sampling_reset(slot.smpl); } // remove the non-common part from the cache @@ -2384,9 +2332,9 @@ struct server_context { } completion_token_output result; - const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i); + const llama_token id = llama_sampling_sample(slot.smpl, ctx, slot.i_batch - i); - llama_sampling_accept(slot.ctx_sampling, ctx, id, true); + llama_sampling_accept(slot.smpl, id, true); slot.n_decoded += 1; if (slot.n_decoded == 1) { @@ -2395,34 +2343,17 @@ struct server_context { metrics.on_prompt_eval(slot); } - llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false }; result.tok = id; - const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs); - if (n_probs > 0) { - const size_t n_valid = slot.ctx_sampling->n_valid; + const auto * cur_p = llama_sampling_get_candidates(slot.smpl); - // Make sure at least n_probs top tokens are at the front of the vector: - if (slot.sparams.temp == 0.0f && n_probs > n_valid) { - llama_sample_top_k(ctx, &cur_p, n_probs, 0); - } - - if (slot.sparams.temp == 0.0f) { - // With greedy sampling the probabilities have possibly not been calculated. - for (size_t i = 0; i < n_probs; ++i) { - result.probs.push_back({ - cur_p.data[i].id, - i == 0 ? 1.0f : 0.0f - }); - } - } else { - for (size_t i = 0; i < n_probs; ++i) { - result.probs.push_back({ - cur_p.data[i].id, - i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. - }); - } - } + // TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643 + // fix if necessary + for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { + result.probs.push_back({ + cur_p->data[i].id, + i >= cur_p->size ? 0.0f : cur_p->data[i].p, + }); } if (!process_token(result, slot)) { diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 69a92cf7dc0c0..674158b857c4c 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,6 +55,8 @@ int main(int argc, char ** argv) { return 1; } + llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + // tokenize the prompt std::vector tokens_list; @@ -110,20 +112,12 @@ int main(int argc, char ** argv) { while (n_cur <= n_predict) { // sample the next token { - auto n_vocab = llama_n_vocab(model); - auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - - std::vector candidates; - candidates.reserve(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); - } + const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + llama_sampling_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); // is it an end of generation? 
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -160,12 +154,13 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); fprintf(stderr, "\n"); llama_batch_free(batch); + llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 1616edecbbef6..e950733665303 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -21,7 +21,7 @@ struct seq_draft { std::vector tokens; std::vector> dists; - struct llama_sampling_context * ctx_sampling; + struct llama_sampling * smpl; }; int main(int argc, char ** argv) { @@ -37,16 +37,16 @@ int main(int argc, char ** argv) { return 1; } + // for probabilities to be computed even with temp = 0 + params.sparams.n_probs = 16; + // max number of parallel drafting sequences (i.e. tree branches) const int n_seq_dft = params.n_parallel; // probability threshold for splitting a draft branch (only for n_seq_dft > 1) const float p_split = params.p_split; - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - std::default_random_engine rng(params.seed); + std::default_random_engine rng(params.sparams.seed); std::uniform_real_distribution<> u_dist; #ifndef LOG_DISABLE_LOGS @@ -179,19 +179,15 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; - // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + // target model sampling context (reuse the llama_context's sampling instance) + struct llama_sampling * smpl = llama_sampling_init(model_tgt, params.sparams); // draft sequence data std::vector drafts(n_seq_dft); - params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar - if (params.sparams.temp == 0) { - params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model - } - for (int s = 0; s < n_seq_dft; ++s) { - drafts[s].ctx_sampling = llama_sampling_init(params.sparams); + // allocate llama_sampling for each draft sequence + drafts[s].smpl = llama_sampling_init(model_dft, params.sparams); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); @@ -234,9 +230,15 @@ int main(int argc, char ** argv) { if (params.sparams.temp > 0) { // stochastic verification - llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL); - llama_sample_softmax(ctx_tgt, &dist_tgt); - float p_tgt = 0, p_dft = 0; + llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft])); + + auto & dist_tgt = *llama_sampling_get_candidates(smpl); + + llama_sampling_grammar(smpl, &dist_tgt); + llama_sampling_softmax(smpl, &dist_tgt); + + float p_tgt = 0.0f; + float p_dft = 0.0f; // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); @@ -278,7 +280,7 @@ int main(int argc, char ** argv) { accept = true; token_id = drafts[s].tokens[i_dft]; token_str = llama_token_to_piece(ctx_tgt, token_id); - llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + llama_sampling_accept(smpl, token_id, true); LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); break; @@ -332,8 +334,8 
@@ int main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = llama_sample_token(ctx_tgt, &dist_tgt); - llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + token_id = llama_sampling_sample_dist(smpl, &dist_tgt); + llama_sampling_accept(smpl, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } @@ -342,11 +344,11 @@ int main(int argc, char ** argv) { // sample from the target model LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = llama_sampling_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); - llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + llama_sampling_accept(smpl, token_id, true); - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str()); token_str = llama_token_to_piece(ctx_tgt, token_id); @@ -434,7 +436,7 @@ int main(int argc, char ** argv) { break; } - llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling); + llama_sampling_cp(smpl, drafts[0].smpl); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -463,20 +465,20 @@ int main(int argc, char ** argv) { continue; } - llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft); + llama_sampling_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); - const auto & cur_p = drafts[s].ctx_sampling->cur; + const auto * cur_p = llama_sampling_get_candidates(drafts[s].smpl); - for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) { + for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) { LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); + k, s, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); } std::vector sa(1, s); // attempt to split the branch if the probability is high enough for (int f = 1; f < 8; ++f) { - if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { + if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) { LOG("splitting seq %3d into %3d\n", s, n_seq_cur); llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); @@ -503,7 +505,7 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; - llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling); + llama_sampling_cp(drafts[s].smpl, drafts[n_seq_cur].smpl); sa.push_back(n_seq_cur); @@ -515,15 +517,15 @@ int main(int argc, char ** argv) { // add drafted token for each sequence for (int is = 0; is < (int) sa.size(); ++is) { - const llama_token id = cur_p[is].id; + const llama_token id = cur_p->data[is].id; const int s = sa[is]; - llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); + llama_sampling_accept(drafts[s].smpl, id, true); drafts[s].tokens.push_back(id); // save cur_p.data into drafts[s].dists - drafts[s].dists.push_back(cur_p); + drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size}); // add unique drafted tokens to the target batch drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); @@ -594,14 +596,15 @@ int main(int argc, char ** argv) { LOG_TEE("accept = 
%.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ndraft:\n"); - llama_print_timings(ctx_dft); + // TODO: print sampling/grammar timings for all drafts + llama_print_timings(ctx_dft, nullptr); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx_tgt); + llama_print_timings(ctx_tgt, smpl); - llama_sampling_free(ctx_sampling); + llama_sampling_free(smpl); for (int s = 0; s < n_seq_dft; ++s) { - llama_sampling_free(drafts[s].ctx_sampling); + llama_sampling_free(drafts[s].smpl); } llama_batch_free(batch_dft); diff --git a/include/llama.h b/include/llama.h index a495e866d5a1a..099c9e7465b9c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -33,16 +33,21 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF +// TODO: use everywhere in the implementation +#define LLAMA_TOKEN_NULL -1 + #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 8 +#define LLAMA_SESSION_VERSION 9 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 2 +#define LLAMA_MAX_SAMPLERS 16 + #ifdef __cplusplus extern "C" { #endif @@ -53,8 +58,10 @@ extern "C" { // TODO: show sample usage // + // struct llama_vocab; // TODO: add in the future struct llama_model; struct llama_context; + struct llama_sampling; typedef int32_t llama_pos; typedef int32_t llama_token; @@ -201,6 +208,16 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; + enum llama_sampler_type { + LLAMA_SAMPLER_TYPE_NONE = 0, + LLAMA_SAMPLER_TYPE_TOP_K = 1, + LLAMA_SAMPLER_TYPE_TOP_P = 2, + LLAMA_SAMPLER_TYPE_MIN_P = 3, + LLAMA_SAMPLER_TYPE_TFS_Z = 4, + LLAMA_SAMPLER_TYPE_TYPICAL_P = 5, + LLAMA_SAMPLER_TYPE_TEMPERATURE = 6, + }; + typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -208,6 +225,7 @@ extern "C" { } llama_token_data; typedef struct llama_token_data_array { + // TODO: consider SoA llama_token_data * data; size_t size; bool sorted; @@ -302,7 +320,6 @@ extern "C" { // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations // https://github.com/ggerganov/llama.cpp/pull/7544 struct llama_context_params { - uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size @@ -330,7 +347,8 @@ extern "C" { enum ggml_type type_k; // data type for K cache [EXPERIMENTAL] enum ggml_type type_v; // data type for V cache [EXPERIMENTAL] - // Keep the booleans together to avoid misalignment during copy-by-value. + // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. 
+ // TODO: move at the end of the struct bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU @@ -358,53 +376,56 @@ extern "C" { void * kv_overrides; // pointer to vector containing overrides } llama_model_quantize_params; - // grammar types - struct llama_grammar; - - // grammar element type - enum llama_gretype { - // end of rule definition - LLAMA_GRETYPE_END = 0, - - // start of alternate definition for rule - LLAMA_GRETYPE_ALT = 1, - - // non-terminal element: reference to rule - LLAMA_GRETYPE_RULE_REF = 2, - - // terminal element: character (code point) - LLAMA_GRETYPE_CHAR = 3, - - // inverse char(s) ([^a], [^a-b] [^abc]) - LLAMA_GRETYPE_CHAR_NOT = 4, - - // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to - // be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - - // modifies a preceding LLAMA_GRETYPE_CHAR or - // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ALT = 6, - - // any character (.) - LLAMA_GRETYPE_CHAR_ANY = 7, - }; - - typedef struct llama_grammar_element { - enum llama_gretype type; - uint32_t value; // Unicode code point or rule ID - } llama_grammar_element; + typedef struct llama_logit_bias { + llama_token token; + float bias; + } llama_logit_bias; + + // parameters for sampling the logits + typedef struct llama_sampling_params { + uint32_t seed; // the seed used to initialize llama_sampling_context + int32_t n_prev; // number of previous tokens to remember + int32_t n_probs; // if greater than 0, output the probabilities of top n_probs tokens. + int32_t min_keep; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t top_k; // <= 0 to use vocab size + float top_p; // 1.0 = disabled + float min_p; // 0.0 = disabled + float tfs_z; // 1.0 = disabled + float typ_p; // typical_p, 1.0 = disabled + float temp; // <= 0.0 to sample greedily, 0.0 to not output probabilities + float dynatemp_range; // 0.0 = disabled + float dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler + int32_t penalty_last_n; // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat; // 1.0 = disabled + float penalty_freq; // 0.0 = disabled + float penalty_present; // 0.0 = disabled + int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau; // target entropy + float mirostat_eta; // learning rate + + // samplers + int32_t n_samplers; + enum llama_sampler_type samplers[LLAMA_MAX_SAMPLERS]; + + // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. 
+ bool penalize_nl; // consider newlines as a repeatable token + bool ignore_eos; // ignore the end-of-sequence token + } llama_sampling_params; // performance timing information struct llama_timings { double t_start_ms; double t_end_ms; double t_load_ms; - double t_sample_ms; + double t_sampling_ms; + double t_grammar_ms; + double t_accept_ms; double t_p_eval_ms; double t_eval_ms; - int32_t n_sample; + int32_t n_sampling; + int32_t n_grammar; + int32_t n_accept; int32_t n_p_eval; int32_t n_eval; }; @@ -419,8 +440,9 @@ extern "C" { struct llama_lora_adapter; // Helpers for getting default parameters - LLAMA_API struct llama_model_params llama_model_default_params(void); - LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_model_params llama_model_default_params(void); + LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_sampling_params llama_sampling_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); // Initialize the llama + ggml backend @@ -447,6 +469,7 @@ extern "C" { LLAMA_API void llama_free_model(struct llama_model * model); + // TODO: rename to llama_init_from_model LLAMA_API struct llama_context * llama_new_context_with_model( struct llama_model * model, struct llama_context_params params); @@ -462,23 +485,22 @@ extern "C" { LLAMA_API bool llama_supports_mlock (void); LLAMA_API bool llama_supports_gpu_offload(void); - LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); - LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - - LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); - LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_n_layer (const struct llama_model * model); + LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); + + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + // Get the model's RoPE frequency scaling factor LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); @@ -706,7 +728,7 @@ extern "C" { // // Returns the *actual* size in bytes of the state - // (rng, logits, embedding and kv_cache) + // (logits, embedding and kv_cache) // Only use when saving the state, not when restoring it, otherwise the size may be too small. LLAMA_API size_t llama_state_get_size(struct llama_context * ctx); LLAMA_API DEPRECATED(size_t llama_get_state_size(struct llama_context * ctx), @@ -1009,158 +1031,130 @@ extern "C" { int32_t length); // - // Grammar + // Sampling functions // - /// Initialize a llama_grammar. - /// - /// @param rules The rule elements of the grammar to initialize. - /// @param n_rules The number of rules. 
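// Illustrative sketch only (not part of this patch): with the llama_sampling_params struct and
// llama_sampling_default_params() declared above, configuring the new sampler chain could look
// roughly like this (values are hypothetical):
//
//     struct llama_sampling_params sparams = llama_sampling_default_params();
//     sparams.top_k       = 40;
//     sparams.top_p       = 0.95f;
//     sparams.temp        = 0.80f;
//     sparams.n_samplers  = 3;
//     sparams.samplers[0] = LLAMA_SAMPLER_TYPE_TOP_K;
//     sparams.samplers[1] = LLAMA_SAMPLER_TYPE_TOP_P;
//     sparams.samplers[2] = LLAMA_SAMPLER_TYPE_TEMPERATURE;
//
// llama_sampling_sample() (declared further below) would then apply these samplers in the
// configured order before picking a token.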
- /// @param start_rule_index The index of the root rule (the starting point of the grammar). - /// @return The initialized llama_grammar or nullptr if initialization failed. - LLAMA_API struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); + // TODO: llama_model should become llama_vocab + LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params); - LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); - LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); + // Copies the internal state of the sampler (rng, prev, params, grammar, etc.) + LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); - /// @details Apply constraints from grammar - LLAMA_API void llama_grammar_sample( - const struct llama_grammar * grammar, - const struct llama_context * ctx, - llama_token_data_array * candidates); - LLAMA_API DEPRECATED(void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar), - "use llama_grammar_sample instead"); + // - clear prev token + // - reset grammar state + LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl); - /// @details Accepts the sampled token into the grammar - LLAMA_API void llama_grammar_accept_token( - struct llama_grammar * grammar, - struct llama_context * ctx, - llama_token token); + // Sampling parameter mutation + // TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable + LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); + LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - // - // Sampling functions - // + // Set the logits from which to sample. + // This call initializes the internal token candidates array. + // The internal candidates are implicitly used by the sampling API below when no candidates are provided. + LLAMA_API void llama_sampling_set_logits( + struct llama_sampling * smpl, + const float * logits); - // Sets the current rng seed. - LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + /// @details Returns the current candidate tokens. + LLAMA_API llama_token_data_array * llama_sampling_get_candidates( + struct llama_sampling * smpl); - /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present); - - /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 - /// @param logits Logits extracted from the original generation context. - /// @param logits_guidance Logits extracted from a separate context from the same model. 
Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. - /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. - LLAMA_API void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale); + // The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object. + // Each function can accept an array of token candidates. If the candidates are not provided, the internal + // candidates are used. The internal candidates are initialized by llama_sampling_set_logits(). /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API void llama_sample_softmax( - struct llama_context * ctx, - llama_token_data_array * candidates); + LLAMA_API void llama_sampling_softmax( + struct llama_sampling * smpl, + llama_token_data_array * candidates); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k( - struct llama_context * ctx, - llama_token_data_array * candidates, - int32_t k, - size_t min_keep); + LLAMA_API void llama_sampling_top_k( + struct llama_sampling * smpl, + llama_token_data_array * candidates); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); + LLAMA_API void llama_sampling_top_p( + struct llama_sampling * smpl, + llama_token_data_array * candidates); /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API void llama_sample_min_p( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); + LLAMA_API void llama_sampling_min_p( + struct llama_sampling * smpl, + llama_token_data_array * candidates); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free( - struct llama_context * ctx, - llama_token_data_array * candidates, - float z, - size_t min_keep); + LLAMA_API void llama_sampling_tail_free( + struct llama_sampling * smpl, + llama_token_data_array * candidates); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); + LLAMA_API void llama_sampling_typical( + struct llama_sampling * smpl, + llama_token_data_array * candidates); - /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. - LLAMA_API void llama_sample_entropy( - struct llama_context * ctx, - llama_token_data_array * candidates_p, - float min_temp, - float max_temp, - float exponent_val); + /// @details Apply temperature and entropy + LLAMA_API void llama_sampling_temp( + struct llama_sampling * smpl, + llama_token_data_array * candidates); - LLAMA_API void llama_sample_temp( - struct llama_context * ctx, - llama_token_data_array * candidates, - float temp); - - /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. 
- /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. - /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - int32_t m, - float * mu); - - /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. - /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. - /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. - /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat_v2( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - float * mu); + /// @details Apply constraints from grammar + LLAMA_API void llama_sampling_grammar( + struct llama_sampling * smpl, + llama_token_data_array * candidates); - /// @details Selects the token with the highest probability. - /// Does not compute the token probabilities. Use llama_sample_softmax() instead. - LLAMA_API llama_token llama_sample_token_greedy( - struct llama_context * ctx, - llama_token_data_array * candidates); + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + LLAMA_API void llama_sampling_penalties( + struct llama_sampling * smpl, + llama_token_data_array * candidates); - /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx. 
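// Hedged usage sketch, not part of this patch: a minimal per-token loop with the new
// llama_sampling API (assuming a loaded model, a context that has just decoded a batch,
// and default parameters) might look roughly like:
//
//     struct llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params());
//
//     const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
//     llama_sampling_set_logits(smpl, logits);                    // fills the internal candidates
//
//     const llama_token id = llama_sampling_sample(smpl, NULL);   // NULL -> use internal candidates
//     llama_sampling_accept(smpl, id, /* apply_grammar= */ true);
//
//     llama_sampling_free(smpl);
//
// The examples/simple and examples/passkey changes in this patch follow the same pattern, with
// llama_sampling_sample_greedy() in place of llama_sampling_sample().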
- LLAMA_API llama_token llama_sample_token( - struct llama_context * ctx, - llama_token_data_array * candidates); + /// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + LLAMA_API llama_token llama_sampling_sample_mirostat( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Selects the token with the highest probability. + /// Does not compute the token probabilities. Use llama_sampling_softmax() instead. + LLAMA_API llama_token llama_sampling_sample_greedy( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Randomly selects a token from the candidates based on their probability distribution. + LLAMA_API llama_token llama_sampling_sample_dist( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers"). + LLAMA_API llama_token llama_sampling_sample( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Accepts the sampled token into the sampling context. + /// - adds it to "prev" tokens + /// - updates the grammar state (if apply_grammar is true) + LLAMA_API void llama_sampling_accept( + struct llama_sampling * smpl, + llama_token token, + bool apply_grammar); + + /// @details Get the number of accepted tokens so far (max of n_prev) + LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl); + + /// @details Get the ith accepted token + /// @param ith [0, n_prev), ith == 0 is the last accepted token. + /// returns LLAMA_TOKEN_NULL if ith is out of bounds + LLAMA_API llama_token llama_sampling_prev( + const struct llama_sampling * smpl, + int32_t ith); + + /// @details Get the last accepted token + /// Same as llama_sampling_prev(smpl, 0) + /// returns LLAMA_TOKEN_NULL if there are no accepted tokens + LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); // // Model split @@ -1179,8 +1173,8 @@ extern "C" { // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - LLAMA_API void llama_print_timings(struct llama_context * ctx); - LLAMA_API void llama_reset_timings(struct llama_context * ctx); + LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl); + LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl); // Print system information LLAMA_API const char * llama_print_system_info(void); @@ -1195,59 +1189,4 @@ extern "C" { } #endif -// Internal API to be implemented by llama.cpp and used by tests/benchmarks only -#ifdef LLAMA_API_INTERNAL - -#include -#include -#include - -struct ggml_tensor; - -const std::vector> & llama_internal_get_tensor_map( - struct llama_context * ctx -); - -struct llama_partial_utf8 { - uint32_t value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence -}; - -struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; - llama_partial_utf8 partial_utf8; -}; - -using llama_grammar_rule = std::vector< llama_grammar_element>; -using llama_grammar_stack = std::vector; - -using llama_grammar_rules = std::vector; -using llama_grammar_stacks = std::vector; -using llama_grammar_candidates = std::vector; - -const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); - llama_grammar_stacks & llama_grammar_get_stacks( struct 
llama_grammar * grammar); - -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks); - -std::vector llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - const llama_grammar_candidates & candidates); - -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - llama_partial_utf8 partial_start); - -// Randomly selects a token from the candidates based on their probabilities using given std::mt19937. -// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng); - -#endif // LLAMA_API_INTERNAL - #endif // LLAMA_H diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index b123d733100ce..8cd98bae4dba6 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -4,10 +4,29 @@ #include "llama-sampling.h" #include +#include -// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as -// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. -std::pair, llama_partial_utf8> decode_utf8( +// +// helpers +// + +// NOTE: assumes valid utf8 (but checks for overrun) +static std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! 
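// Note: `end` computed above may point past the terminating NUL when fewer bytes remain than
// the lead byte announces; the loop below stops at the first NUL as well as at `end`, so the
// potential overrun is never dereferenced.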
+ const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); +} + +static std::pair, llama_partial_utf8> decode_utf8( const std::string & src, llama_partial_utf8 partial_start) { static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; @@ -40,7 +59,7 @@ std::pair, llama_partial_utf8> decode_utf8( while (*pos != 0) { uint8_t first_byte = static_cast(*pos); uint8_t highbits = first_byte >> 4; - n_remain = lookup[highbits] - 1; + n_remain = lookup[highbits] - 1; if (n_remain < 0) { // invalid sequence, abort @@ -50,7 +69,7 @@ std::pair, llama_partial_utf8> decode_utf8( } uint8_t mask = (1 << (7 - n_remain)) - 1; - value = first_byte & mask; + value = first_byte & mask; ++pos; while (*pos != 0 && n_remain > 0) { @@ -67,12 +86,510 @@ std::pair, llama_partial_utf8> decode_utf8( return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); } -const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { - return grammar->rules; +static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; } -llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { - return grammar->stacks; +static bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); +} + +static std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; + } + } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); +} + +static const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } + } + return pos; +} + +static const char * parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting name at ") + src); + } + return pos; +} + +static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; +} + +static std::pair parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); + } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); +} + +static void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + 
fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } +} + +static bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; + default: return false; + } +} + +static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; + } + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; + } + } + fprintf(file, "\n"); +} + +static void print_rule( + FILE * file, + uint32_t rule_id, + const llama_grammar_rule & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: + break; + default: + fprintf(file, "] "); + } + } + } + fprintf(file, "\n"); +} + +// +// implementation +// + +uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) { + uint32_t next_id = 
static_cast(symbol_ids.size()); + auto result = symbol_ids.emplace(std::string(src, len), next_id); + return result.first->second; +} + +uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) { + uint32_t next_id = static_cast(symbol_ids.size()); + symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; +} + +void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) { + if (rules.size() <= rule_id) { + rules.resize(rule_id + 1); + } + rules[rule_id] = rule; +} + +const char * llama_grammar_parser::parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { + llama_grammar_rule rule; + const char * pos = parse_sequence(src, rule_name, rule, is_nested); + while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); + pos = parse_space(pos + 1, true); + pos = parse_sequence(pos, rule_name, rule, is_nested); + } + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(rule_id, rule); + return pos; +} + +const char * llama_grammar_parser::parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested) { + size_t last_sym_start = rule.size(); + const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == rule.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + if (min_times == 0) { + rule.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + llama_grammar_rule rec_rule(prev_rule); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(prev_rule.size()); + uint32_t rec_rule_id = generate_symbol_id( rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? 
rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule( rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = rule.size(); + while (*pos != '"') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = rule.size(); + while (*pos != ']') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < rule.size() + ? LLAMA_GRETYPE_CHAR_ALT + : start_type; + + rule.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + if (!pos[1]) { + throw std::runtime_error("unexpected end of input"); + } + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(rule_name); + pos = parse_alternates(pos, rule_name, sub_rule_id, true); + last_sym_start = rule.size(); + // output reference to synthesized rule + rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else 
{ + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); + } else { + break; + } + } + return pos; + } + +const char * llama_grammar_parser::parse_rule(const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); + } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); + } + +bool llama_grammar_parser::parse(const char * src) { + try { + const char * pos = parse_space(src, true); + while (*pos) { + pos = parse_rule(pos); + } + // Validate the state to ensure that all rules are defined + for (const auto & rule : rules) { + if (rule.empty()) { + throw std::runtime_error("Undefined rule"); + } + for (const auto & elem : rule) { + if (elem.type == LLAMA_GRETYPE_RULE_REF) { + // Ensure that the rule at that location exists + if (elem.value >= rules.size() || rules[elem.value].empty()) { + // Get the name of the rule that is missing + for (const auto & kv : symbol_ids) { + if (kv.second == elem.value) { + throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); + } + } + } + } + } + } + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + rules.clear(); + return false; + } + + return true; +} + +void llama_grammar_parser::print(FILE * file) { + try { + std::map symbol_id_names; + for (const auto & kv : symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, rules[i]); + print_rule(file, uint32_t(i), rules[i], symbol_id_names); + // fprintf(file, "\n"); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); + } +} + +llama_grammar_stack llama_grammar_parser::c_rules() const { + llama_grammar_stack ret; + ret.reserve(rules.size()); + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; } // returns true iff pos points to the end of one of the definitions of a rule @@ -89,7 +606,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) static std::pair llama_grammar_match_char( const llama_grammar_element * pos, const uint32_t chr) { - bool found = false; bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; @@ -225,16 +741,92 @@ static void llama_grammar_advance_stack( } } -// takes a set of possible pushdown stacks on a grammar, which are required to -// be positioned at a character range (see `llama_grammar_advance_stack`), and -// produces the N possible stacks if the given char is accepted at those -// positions -void llama_grammar_accept( +static llama_grammar_candidates llama_grammar_reject_candidates( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const llama_grammar_candidates & candidates) { + GGML_ASSERT(!stacks.empty()); // REVIEW + + if (candidates.empty()) { + 
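    // Worked example of the rewrite performed by handle_repetitions() in
    // llama_grammar_parser::parse_sequence() above (synthesized rule names are
    // illustrative; the parser numbers them internally):
    //   root ::= "a"{2,4}
    // becomes
    //   root   ::= "a" "a" root_2
    //   root_2 ::= "a" root_1 |
    //   root_1 ::= "a" |
    // i.e. two mandatory "a"s followed by up to two optional ones.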
return {}; + } + + auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + + return rejects; +} + +static bool llama_grammar_detect_left_recursion( + const llama_grammar_rules & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty) { + if ((*rules_in_progress)[rule_index]) { + return true; + } + + (*rules_in_progress)[rule_index] = true; + + const llama_grammar_rule & rule = rules[rule_index]; + + // First check if the rule might produce the empty string. This could be done combined with the second + // step but it's more readable as two steps. + bool at_rule_start = true; + for (size_t i = 0; i < rule.size(); i++) { + if (llama_grammar_is_end_of_sequence(&rule[i])) { + if (at_rule_start) { + (*rules_may_be_empty)[rule_index] = true; + break; + } + at_rule_start = true; + } else { + at_rule_start = false; + } + } + + // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may + // be empty) + bool recurse_into_nonterminal = true; + for (size_t i = 0; i < rule.size(); i++) { + if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { + if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { + return true; + } + if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { + recurse_into_nonterminal = false; + } + } else if (llama_grammar_is_end_of_sequence(&rule[i])) { + recurse_into_nonterminal = true; + } else { + recurse_into_nonterminal = false; + } + } + + (*rules_in_progress)[rule_index] = false; + (*rules_visited)[rule_index] = true; + + return false; +} + +const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { + return grammar->rules; +} + +llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { + return grammar->stacks; +} + +llama_grammar_stacks llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks) { - new_stacks.clear(); + const uint32_t chr) { + llama_grammar_stacks result; + result.reserve(stacks.size()); for (const auto & stack : stacks) { if (stack.empty()) { @@ -250,27 +842,11 @@ void llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + llama_grammar_advance_stack(rules, new_stack, result); } } -} -static llama_grammar_candidates llama_grammar_reject_candidates( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const llama_grammar_candidates & candidates) { - GGML_ASSERT(!stacks.empty()); // REVIEW - - if (candidates.empty()) { - return {}; - } - - auto rejects = llama_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); - - for (size_t i = 1, size = stacks.size(); i < size; ++i) { - rejects = llama_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); - } - return rejects; + return result; } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -328,72 +904,97 @@ llama_grammar_candidates llama_grammar_reject_candidates_for_stack( return rejects; } -static bool llama_grammar_detect_left_recursion( - const llama_grammar_rules & rules, - size_t rule_index, - 
std::vector * rules_visited, - std::vector * rules_in_progress, - std::vector * rules_may_be_empty) { - if ((*rules_in_progress)[rule_index]) { - return true; - } +//////////////////// - (*rules_in_progress)[rule_index] = true; +struct llama_grammar * llama_grammar_init_impl( + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { + const llama_grammar_element * pos; - const llama_grammar_rule & rule = rules[rule_index]; + // copy rule definitions into vectors + llama_grammar_rules vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } - // First check if the rule might produce the empty string. This could be done combined with the second - // step but it's more readable as two steps. - bool at_rule_start = true; - for (size_t i = 0; i < rule.size(); i++) { - if (llama_grammar_is_end_of_sequence(&rule[i])) { - if (at_rule_start) { - (*rules_may_be_empty)[rule_index] = true; - break; - } - at_rule_start = true; - } else { - at_rule_start = false; + // Check for left recursion + std::vector rules_visited(n_rules); + std::vector rules_in_progress(n_rules); + std::vector rules_may_be_empty(n_rules); + for (size_t i = 0; i < n_rules; i++) { + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + return nullptr; } } - // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may - // be empty) - bool recurse_into_nonterminal = true; - for (size_t i = 0; i < rule.size(); i++) { - if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { - if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { - return true; - } - if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { - recurse_into_nonterminal = false; - } - } else if (llama_grammar_is_end_of_sequence(&rule[i])) { - recurse_into_nonterminal = true; + // loop over alternates of start rule to build initial stacks + llama_grammar_stacks stacks; + pos = vec_rules[start_rule_index].data(); + do { + llama_grammar_stack stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; } else { - recurse_into_nonterminal = false; + break; } - } + } while (true); - (*rules_in_progress)[rule_index] = false; - (*rules_visited)[rule_index] = true; - return false; + // Important: vec_rules has to be moved here, not copied, because stacks contains + // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar + // then the pointers would be invalidated when the local vec_rules goes out of scope. 
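    // For reference (grammar text assumed for illustration), the left-recursion check above
    // rejects   root ::= root "a" | "a"   -> llama_grammar_init_impl returns nullptr with the error above
    // but keeps root ::= "a" root | "a"   -> right recursion is fine and matches one or more "a"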
+ return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; } -// -// grammar - external -// +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) { + llama_grammar_parser parser; + + // if there is a grammar, parse it + if (!parser.parse(grammar_str)) { + return nullptr; + } + + // will be empty (default) if there are parse errors + if (parser.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); + return nullptr; + } + + // Ensure that there is a "root" node. + if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { + fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + return nullptr; + } + + std::vector grammar_rules(parser.c_rules()); + + const size_t n_rules = grammar_rules.size(); + const size_t start_rule_index = parser.symbol_ids.at(grammar_root); -struct llama_grammar * llama_grammar_init_impl( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { const llama_grammar_element * pos; // copy rule definitions into vectors llama_grammar_rules vec_rules(n_rules); for (size_t i = 0; i < n_rules; i++) { - for (pos = rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { vec_rules[i].push_back(*pos); } vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); @@ -438,22 +1039,22 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. - return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; + return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { delete grammar; } -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { - llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; +struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar) { + llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { + for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { + if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { result->stacks[is][ie] = &result->rules[ir0][ir1]; } } @@ -464,14 +1065,11 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram return result; } -void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) { - GGML_ASSERT(grammar); - GGML_ASSERT(vocab); - - int64_t t_start_sample_us = ggml_time_us(); +void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * candidates) { + GGML_ASSERT(grammar.vocab != nullptr); bool allow_eog = 
false; - for (const auto & stack : grammar->stacks) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { allow_eog = true; break; @@ -486,33 +1084,31 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string & piece = vocab->cache_token_to_piece.at(id); + const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); - if (llama_token_is_eog_impl(*vocab, id)) { + if (llama_token_is_eog_impl(*grammar.vocab, id)) { if (!allow_eog) { candidates->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { candidates->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); + candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } - const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { candidates->data[reject.index].logit = -INFINITY; } - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { - const int64_t t_start_sample_us = ggml_time_us(); +void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { + GGML_ASSERT(grammar.vocab != nullptr); - if (llama_token_is_eog_impl(*vocab, token)) { - for (const auto & stack : grammar->stacks) { + if (llama_token_is_eog_impl(*grammar.vocab, token)) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { return; } @@ -520,20 +1116,17 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc GGML_ABORT("fatal error"); } - const std::string & piece = vocab->cache_token_to_piece.at(token); + const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; - llama_grammar_stacks tmp_new_stacks; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks); - grammar->stacks = tmp_new_stacks; + llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it); + grammar.stacks = std::move(new_stacks); } - grammar->partial_utf8 = decoded.second; - GGML_ASSERT(!grammar->stacks.empty()); - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; + grammar.partial_utf8 = decoded.second; + GGML_ASSERT(!grammar.stacks.empty()); } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 695ea0632bb84..9b13354f67c74 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -2,11 +2,114 @@ #include "llama-impl.h" +#include + struct llama_vocab; -struct llama_sampling; + +// grammar element type +enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: 
character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) + LLAMA_GRETYPE_CHAR_ANY = 7, +}; + +typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID +} llama_grammar_element; + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +using llama_grammar_rule = std::vector< llama_grammar_element>; +using llama_grammar_stack = std::vector; + +using llama_grammar_rules = std::vector; +using llama_grammar_stacks = std::vector; +using llama_grammar_candidates = std::vector; + +const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); + llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +llama_grammar_stacks llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + uint32_t chr); + +std::vector llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates); + +struct llama_grammar_parser { + std::map symbol_ids; + + llama_grammar_rules rules; + + llama_grammar_stack c_rules() const; + + uint32_t get_symbol_id(const char * src, size_t len); + uint32_t generate_symbol_id(const std::string & base_name); + + void add_rule(uint32_t rule_id, const llama_grammar_rule & rule); + + const char * parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); + + const char * parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested); + + const char * parse_rule(const char * src); + + bool parse(const char * src); + void print(FILE * file); +}; struct llama_grammar { - const llama_grammar_rules rules; + // note: allow null vocab for testing (not great) + const llama_vocab * vocab; + + const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; // buffer for partially generated UTF-8 sequence from accepted tokens @@ -17,23 +120,24 @@ struct llama_grammar { // internal API // +// note: needed for tests (not great) struct llama_grammar * llama_grammar_init_impl( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); + const struct llama_vocab * vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root); void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); +struct 
llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar); -void llama_grammar_sample_impl( - const struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, +// TODO: move the API below as member functions of llama_grammar +void llama_grammar_apply_impl( + const struct llama_grammar & grammar, llama_token_data_array * candidates); -void llama_grammar_accept_token_impl( - struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, +void llama_grammar_accept_impl( + struct llama_grammar & grammar, llama_token token); diff --git a/src/llama-impl.h b/src/llama-impl.h index 9527740961da6..b67f511c08157 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -1,8 +1,11 @@ #pragma once -#define LLAMA_API_INTERNAL #include "llama.h" +#include +#include +#include + #ifdef __GNUC__ #ifdef __MINGW32__ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) @@ -45,3 +48,114 @@ static void replace_all(std::string & s, const std::string & search, const std:: builder.append(s, last_pos, std::string::npos); s = std::move(builder); } + +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +); + +// the ring buffer works similarly to std::deque, but with a fixed capacity +template +struct ring_buffer { + ring_buffer() {} + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T & front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + const T & front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T & back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T & back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T & value) { + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + //T & operator[](size_t i) { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + //const T & at(size_t i) const { + // if (i >= sz) { + // throw std::runtime_error("ring buffer: index out of bounds"); + // } + // return data[(first + i) % capacity]; + //} + + const T & rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector to_vector() const { + std::vector result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + std::vector data; +}; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 8f4841d9daf7b..8abfc3fc6d86a 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1,5 +1,8 @@ #include "llama-sampling.h" 
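// Usage sketch for the fixed-capacity ring_buffer added to llama-impl.h above
// (values are illustrative):
//   ring_buffer<llama_token> prev(3);
//   prev.push_back(1); prev.push_back(2); prev.push_back(3);
//   prev.push_back(4);        // full: the oldest element (1) is dropped
//   prev.to_vector();         // {2, 3, 4}, oldest first
//   prev.rat(0);              // 4 -- rat(i) indexes from the most recent element
//   prev.rat(2);              // 2
//   prev.front();             // 2, the oldest retained element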
+#include "llama-vocab.h" +#include "llama-grammar.h" + #include #include #include @@ -21,18 +24,104 @@ static void llama_log_softmax(float * array, size_t size) { } } -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) { +llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) { +} + +llama_sampling::~llama_sampling() { + if (grammar) { + llama_grammar_free_impl(grammar); + } +} + +struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) { + auto * result = new llama_sampling(vocab); + + result->params = params; + + result->prev = ring_buffer(params.n_prev); + + for (int i = 0; i < params.n_samplers; ++i) { + result->samplers.push_back(params.samplers[i]); + } + + llama_sampling_set_rng_seed_impl(*result, params.seed); + + return result; +} + +void llama_sampling_free_impl(struct llama_sampling * sampling) { + delete sampling; +} + +struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) { + auto * result = new llama_sampling(smpl.vocab); + + result->params = smpl.params; + + result->grammar_str = smpl.grammar_str; + result->grammar_root = smpl.grammar_root; + + result->logit_bias = smpl.logit_bias; + + if (smpl.grammar) { + result->grammar = llama_grammar_cp_impl(*smpl.grammar); + } + + result->rng = smpl.rng; + result->prev = smpl.prev; + + return result; +} + +void llama_sampling_reset_impl(struct llama_sampling & smpl) { + if (smpl.grammar) { + llama_grammar_free_impl(smpl.grammar); + smpl.grammar = nullptr; + } + + if (!smpl.grammar_str.empty()) { + smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data()); + } + + smpl.prev.clear(); +} + +void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { seed = time(NULL); } - smpl->rng.seed(seed); + smpl.rng.seed(seed); } -void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - GGML_ASSERT(candidates->size > 0); +void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) { + if (smpl.grammar) { + llama_grammar_free_impl(smpl.grammar); + smpl.grammar = nullptr; + } - const int64_t t_start_sample_us = ggml_time_us(); + if (grammar_str != nullptr && grammar_str[0] != '\0') { + smpl.grammar_str = grammar_str; + smpl.grammar_root = grammar_root; + + smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root); + } else { + smpl.grammar_str.clear(); + smpl.grammar_root.clear(); + } +} + +void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + smpl.logit_bias.clear(); + smpl.logit_bias.reserve(n_logit_bias); + + for (int32_t i = 0; i < n_logit_bias; ++i) { + smpl.logit_bias.push_back(logit_bias[i]); + } +} + +void llama_sampling_softmax_impl(llama_token_data_array * candidates) { + GGML_ASSERT(candidates->size > 0); // Sort the logits in descending order if (!candidates->sorted) { @@ -44,28 +133,24 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar float max_l = candidates->data[0].logit; float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { float p = expf(candidates->data[i].logit - max_l); candidates->data[i].p = p; cum_sum += p; } + for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].p /= cum_sum; } - - if (smpl) { - smpl->t_sample_us += 
ggml_time_us() - t_start_sample_us; - } } -void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)candidates->size) { // return; // } - const int64_t t_start_sample_us = ggml_time_us(); - if (k <= 0) { k = candidates->size; } @@ -101,10 +186,12 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra int ib = nbuckets - 1; for ( ; ib >= 0; --ib) { nhave += histo[ib]; - if (nhave >= k) break; + if (nhave >= k) { + break; + } } std::vector tmp_tokens(nhave); - auto ptr = tmp_tokens.data(); + auto * ptr = tmp_tokens.data(); std::vector bucket_ptrs; bucket_ptrs.reserve(nbuckets - ib); for (int j = nbuckets - 1; j >= ib; --j) { @@ -133,20 +220,14 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra candidates->sorted = true; } candidates->size = k; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sample_softmax_impl(smpl, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampling_softmax_impl(candidates); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -165,19 +246,13 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra // Resize the output vector to keep only the top-p tokens candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p <= 0.0f || !candidates->size) { return; } - const int64_t t_start_sample_us = ggml_time_us(); - bool min_p_applied = false; // if the candidates aren't sorted, try the unsorted implementation first @@ -226,19 +301,14 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra // Resize the output vector to keep only the matching tokens candidates->size = i; } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampling_softmax_impl(candidates); // Compute the first and second derivatives std::vector first_derivatives(candidates->size - 1); @@ -285,13 +355,9 @@ void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_ // Resize the output vector to keep only the tokens above the tail location candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, 
float p, size_t min_keep) { +void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -299,9 +365,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar } // Compute the softmax of logits and calculate entropy - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampling_softmax_impl(candidates); float entropy = 0.0f; for (size_t i = 0; i < candidates->size; ++i) { @@ -349,15 +413,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); candidates->size = new_candidates.size(); candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { - const int64_t t_start_sample_us = ggml_time_us(); - +void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if(candidates->size <= 1) { return; @@ -366,7 +424,7 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar // Calculate maximum possible entropy float max_entropy = -logf(1.0f / candidates->size); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + llama_sampling_softmax_impl(candidates); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -398,13 +456,15 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar } // Re-compute softmax probabilities after scaling logits with dynamic temperature - double max_l_double = candidates->data[0].logit; + const double max_l_double = candidates->data[0].logit; + double cum_sum_double = 0.0; for (size_t i = 0; i < candidates->size; ++i) { double p = exp(candidates->data[i].logit - max_l_double); candidates->data[i].p = p; // Store the scaled probability cum_sum_double += p; } + for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities } @@ -416,44 +476,24 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f); } #endif - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) { - const int64_t t_start_sample_us = ggml_time_us(); - +void llama_sampling_temp_impl(llama_token_data_array * candidates, float temp) { for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].logit /= temp; } +} - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } +void llama_sampling_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) { + llama_grammar_apply_impl(grammar, candidates); } -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, +void llama_sampling_penalties_impl( llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float 
penalty_freq, - float penalty_present) { - if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { - return; - } - - const int64_t t_start_sample_us = ggml_time_us(); - - // Create a frequency map to count occurrences of each token in last_tokens - std::unordered_map token_count; - for (size_t i = 0; i < penalty_last_n; ++i) { - token_count[last_tokens[i]]++; - } - + const llama_token_cnt & token_count, + float penalty_repeat, + float penalty_freq, + float penalty_present) { // Apply frequency and presence penalties to the candidates for (size_t i = 0; i < candidates->size; ++i) { const auto token_iter = token_count.find(candidates->data[i].id); @@ -475,43 +515,10 @@ void llama_sample_repetition_penalties_impl( } candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, - float * logits, - float * logits_guidance, - float scale) { - GGML_ASSERT(smpl); - - const auto t_start_sample_us = ggml_time_us(); - const auto n_vocab = smpl->n_vocab; - - llama_log_softmax(logits, n_vocab); - llama_log_softmax(logits_guidance, n_vocab); - - for (int i = 0; i < n_vocab; ++i) { - auto & l = logits[i]; - const auto & g = logits_guidance[i]; - - l = scale * (l - g) + g; - } - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - GGML_ASSERT(smpl); - - const int32_t n_vocab = float(smpl->n_vocab); - - int64_t t_start_sample_us = ggml_time_us(); - - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); +llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { + llama_sampling_softmax_impl(candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -527,13 +534,11 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama // Compute k from the estimated s_hat and target surprise value float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); + float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1); - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); + llama_sampling_top_k_impl(candidates, int(k), 1); + llama_token X = llama_sampling_sample_dist_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -543,93 +548,88 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama float e = observed_surprise - tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; + mu = mu - eta * e; - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; return X; } -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { 
- int64_t t_start_sample_us; - t_start_sample_us = ggml_time_us(); - - llama_sample_softmax_impl(smpl, candidates); +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { + llama_sampling_softmax_impl(candidates); // Truncate the words with surprise values greater than mu candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > *mu; + return -log2f(candidate.p) > mu; })); if (candidates->size == 0) { candidates->size = 1; } - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } - // Normalize the probabilities of the remaining words - llama_sample_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(candidates); // Sample the next word X from the remaining words - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); + llama_token X = llama_sampling_sample_dist_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { return candidate.id == X; })); + float observed_surprise = -log2f(candidates->data[X_idx].p); float e = observed_surprise - tau; // Update mu using the learning rate and error - *mu = *mu - eta * e; + mu = mu - eta * e; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } return X; } -llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - const int64_t t_start_sample_us = ggml_time_us(); - +llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidates) { // Find max element auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; }); llama_token result = max_iter->id; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - } + return result; } -llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) { - GGML_ASSERT(smpl); - - const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); +llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { + llama_sampling_softmax_impl(candidates); std::vector probs; probs.reserve(candidates->size); + for (size_t i = 0; i < candidates->size; ++i) { probs.push_back(candidates->data[i].p); } std::discrete_distribution<> dist(probs.begin(), probs.end()); - int idx = dist(rng); + const int idx = dist(rng); llama_token result = candidates->data[idx].id; - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - return result; } -llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng); +void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar) { + smpl.prev.push_back(token); + + if (apply_grammar && smpl.grammar) { + llama_grammar_accept_impl(*smpl.grammar, token); + } +} + +llama_token 
llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith) { + if (ith < 0 || ith >= (int) smpl.prev.size()) { + return LLAMA_TOKEN_NULL; + } + + return smpl.prev.rat(ith); +} + +int llama_sampling_n_prev_impl(const struct llama_sampling & smpl) { + return smpl.prev.size(); } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index f7f8e3ef706bc..c51542259e27d 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -1,56 +1,107 @@ #pragma once -#include "llama-impl.h" +#include "llama-grammar.h" + +#include +#include + +struct llama_vocab; +struct llama_grammar; + +using llama_token_cnt = std::unordered_map; struct llama_sampling { - llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} + llama_sampling(const struct llama_vocab & vocab); + ~llama_sampling(); + + llama_sampling_params params; + + std::string grammar_str; + std::string grammar_root; + + std::vector logit_bias; // logit biases to apply + + // state std::mt19937 rng; - int32_t n_vocab = 0; + const struct llama_vocab & vocab; + + std::vector samplers; + + ring_buffer prev; + + struct llama_grammar * grammar = nullptr; - mutable int64_t t_sample_us = 0; - mutable int32_t n_sample = 0; + // mirostat sampler state + float mirostat_mu; - void reset_timings() const { - t_sample_us = 0; - n_sample = 0; - } + mutable int64_t t_sample_us = 0; + mutable int64_t t_grammar_us = 0; + mutable int64_t t_accept_us = 0; + + mutable int32_t n_sample = 0; + mutable int32_t n_grammar = 0; + mutable int32_t n_accept = 0; + + std::vector cur; + + llama_token_data_array cur_p; }; // // internal API // -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed); +struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params); + +void llama_sampling_free_impl(struct llama_sampling * sampling); + +struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl); -void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp); +void llama_sampling_reset_impl(struct llama_sampling & smpl); -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, +// TODO: move the API below as member functions of llama_sampling +void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed); +void llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root); +void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); + +void llama_sampling_softmax_impl (struct 
llama_token_data_array * candidates); +void llama_sampling_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_sampling_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_min_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_tail_free_impl(struct llama_token_data_array * candidates, float z, size_t min_keep); +void llama_sampling_typical_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_entropy_impl (struct llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_sampling_temp_impl (struct llama_token_data_array * candidates, float temp); +void llama_sampling_grammar_impl (struct llama_token_data_array * candidates, const struct llama_grammar & grammar); + +void llama_sampling_penalties_impl( llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, + const llama_token_cnt & token_count, float penalty_repeat, float penalty_freq, float penalty_present); -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, - float * logits, - float * logits_guidance, - float scale); +/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); + +/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. 
A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); + +llama_token llama_sampling_sample_greedy_impl(struct llama_token_data_array * candidates); +llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); -llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); -llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); -llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); +void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar); +llama_token llama_sampling_prev_impl (const struct llama_sampling & smpl, int ith); +int llama_sampling_n_prev_impl(const struct llama_sampling & smpl); diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 6e8f30be43ba1..dc4b5f12f7860 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -18,6 +18,8 @@ struct llama_vocab { tattr attr; }; + uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -62,8 +64,6 @@ struct llama_vocab { int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; }; -const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx); - // // internal API // @@ -76,6 +76,7 @@ std::vector llama_tokenize_internal( bool add_special, bool parse_special = false); +// TODO: move the API below as member functions of llama_vocab llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch); const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token); diff --git a/src/llama.cpp b/src/llama.cpp index 1a78112a3a84d..258a568421347 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1,6 +1,5 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-grammar.h" #include "llama-sampling.h" #include "unicode.h" @@ -148,6 +147,19 @@ static void zeros(std::ofstream & file, size_t n) { } } +struct time_meas { + time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {} + + ~time_meas() { + t_acc += ggml_time_us() - t_start_us; + } + + const int64_t t_start_us; + + int64_t & t_acc; +}; + + LLAMA_ATTRIBUTE_FORMAT(1, 2) static std::string format(const char * fmt, ...) 
{ va_list ap; @@ -3179,7 +3191,6 @@ struct llama_sbatch { struct llama_context { llama_context(const llama_model & model) : model(model) - , sampling(llama_n_vocab(&model)) , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {} @@ -3196,7 +3207,6 @@ struct llama_context { const struct llama_model & model; struct llama_cparams cparams; - struct llama_sampling sampling; struct llama_sbatch sbatch; struct llama_kv_cache kv_self; struct llama_control_vector cvec; @@ -3217,16 +3227,16 @@ struct llama_context { bool has_evaluated_once = false; - int64_t t_start_us; - int64_t t_load_us; - int64_t t_p_eval_us = 0; - int64_t t_eval_us = 0; + mutable int64_t t_start_us; + mutable int64_t t_load_us; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; - int64_t t_compute_start_us = 0; - int64_t n_queued_tokens = 0; + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; - int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - int32_t n_eval = 0; // number of eval calls + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls // host buffer for the model output (logits and embeddings) ggml_backend_buffer_t buf_output = nullptr; @@ -6251,6 +6261,7 @@ static void llm_load_vocab( const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); + vocab.n_vocab = n_vocab; vocab.id_to_token.resize(n_vocab); for (uint32_t i = 0; i < n_vocab; i++) { @@ -17892,7 +17903,6 @@ struct llama_model_params llama_model_default_params() { struct llama_context_params llama_context_default_params() { struct llama_context_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, @@ -17925,6 +17935,36 @@ struct llama_context_params llama_context_default_params() { return result; } +struct llama_sampling_params llama_sampling_default_params() { + struct llama_sampling_params result = { + /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_prev =*/ 64, + /*.n_probs =*/ 0, + /*.min_keep =*/ 0, + /*.top_k =*/ 40, + /*.top_p =*/ 0.95f, + /*.min_p =*/ 0.05f, + /*.tfs_z =*/ 1.00f, + /*.typ_p =*/ 1.00f, + /*.temp =*/ 0.80f, + /*.dynatemp_range =*/ 0.00f, + /*.dynatemp_exponent =*/ 1.00f, + /*.penalty_last_n =*/ 64, + /*.penalty_repeat =*/ 1.00f, + /*.penalty_freq =*/ 0.00f, + /*.penalty_present =*/ 0.00f, + /*.mirostat =*/ 0, + /*.mirostat_tau =*/ 5.00f, + /*.mirostat_eta =*/ 0.10f, + /*.n_samplers =*/ 3, + /*.samplers =*/ { LLAMA_SAMPLER_TYPE_TEMPERATURE, LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TOP_P, }, + /*.penalize_nl =*/ false, + /*.ignore_eos =*/ false, + }; + + return result; +} + struct llama_model_quantize_params llama_model_quantize_default_params() { struct llama_model_quantize_params result = { /*.nthread =*/ 0, @@ -18178,10 +18218,6 @@ struct llama_context * llama_new_context_with_model( cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; } - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, cparams.n_ctx); LLAMA_LOG_INFO("%s: n_batch = %u\n", __func__, cparams.n_batch); LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); @@ -18192,10 +18228,10 @@ struct llama_context * llama_new_context_with_model( ctx->abort_callback = params.abort_callback; ctx->abort_callback_data = params.abort_callback_data; - ctx->sampling.rng = std::mt19937(params.seed); - ctx->logits_all = params.logits_all; + 
ctx->logits_all = params.logits_all; + // build worst-case graph for encoder if a model contains encoder - ctx->is_encoding = llama_model_has_encoder(model); + ctx->is_encoding = llama_model_has_encoder(model); uint32_t kv_size = cparams.n_ctx; ggml_type type_k = params.type_k; @@ -18473,14 +18509,6 @@ void llama_free(struct llama_context * ctx) { delete ctx; } -const struct llama_model * llama_get_model(const struct llama_context * ctx) { - return &ctx->model; -} - -const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) { - return &ctx->model.vocab; -} - uint32_t llama_n_ctx(const struct llama_context * ctx) { return ctx->cparams.n_ctx; } @@ -18501,6 +18529,30 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { return model->vocab.type; } +int32_t llama_n_vocab(const struct llama_model * model) { + return model->hparams.n_vocab; +} + +int32_t llama_n_ctx_train(const struct llama_model * model) { + return model->hparams.n_ctx_train; +} + +int32_t llama_n_embd(const struct llama_model * model) { + return model->hparams.n_embd; +} + +int32_t llama_n_layer(const struct llama_model * model) { + return model->hparams.n_layer; +} + +const struct llama_model * llama_get_model(const struct llama_context * ctx) { + return &ctx->model; +} + +enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { + return ctx->cparams.pooling_type; +} + enum llama_rope_type llama_rope_type(const struct llama_model * model) { switch (model->arch) { // these models do not use RoPE @@ -18564,26 +18616,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { return LLAMA_ROPE_TYPE_NONE; } -enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { - return ctx->cparams.pooling_type; -} - -int32_t llama_n_vocab(const struct llama_model * model) { - return model->hparams.n_vocab; -} - -int32_t llama_n_ctx_train(const struct llama_model * model) { - return model->hparams.n_ctx_train; -} - -int32_t llama_n_embd(const struct llama_model * model) { - return model->hparams.n_embd; -} - -int32_t llama_n_layer(const struct llama_model * model) { - return model->hparams.n_layer; -} - float llama_rope_freq_scale_train(const struct llama_model * model) { return model->hparams.rope_freq_scale_train; } @@ -19000,14 +19032,14 @@ struct llama_data_write { // TODO: add more model-specific info which should prevent loading the session file if not identical } - void write_rng(const std::mt19937 & rng) { - std::ostringstream rng_ss; - rng_ss << rng; + //void write_rng(const std::mt19937 & rng) { + // std::ostringstream rng_ss; + // rng_ss << rng; - const std::string & rng_str = rng_ss.str(); + // const std::string & rng_str = rng_ss.str(); - write_string(rng_str); - } + // write_string(rng_str); + //} void write_output_ids(struct llama_context * ctx) { llama_output_reorder(ctx); @@ -19227,17 +19259,17 @@ struct llama_data_read { // TODO: add more info which needs to be identical but which is not verified otherwise } - void read_rng(std::mt19937 & rng) { - std::string rng_str; - read_string(rng_str); + //void read_rng(std::mt19937 & rng) { + // std::string rng_str; + // read_string(rng_str); - std::istringstream rng_ss(rng_str); - rng_ss >> rng; + // std::istringstream rng_ss(rng_str); + // rng_ss >> rng; - if (rng_ss.fail()) { - throw std::runtime_error("failed to load RNG state"); - } - } + // if (rng_ss.fail()) { + // throw std::runtime_error("failed to load RNG state"); + // } + //} void read_output_ids(struct llama_context * 
ctx) { std::vector output_pos; @@ -19667,8 +19699,6 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da data_ctx.write_model_info(ctx); - data_ctx.write_rng(ctx->sampling.rng); - // copy outputs data_ctx.write_output_ids(ctx); data_ctx.write_logits(ctx); @@ -19706,9 +19736,6 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da data_ctx.read_model_info(ctx); - // set rng - data_ctx.read_rng(ctx->sampling.rng); - // set outputs data_ctx.read_output_ids(ctx); data_ctx.read_logits(ctx); @@ -20111,8 +20138,9 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG GGML_ABORT("fatal error"); -#endif +#else return nullptr; +#endif } } @@ -20160,8 +20188,9 @@ float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG GGML_ABORT("fatal error"); -#endif +#else return nullptr; +#endif } } @@ -20595,124 +20624,349 @@ int32_t llama_chat_apply_template( } // -// grammar +// sampling // -struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { - return llama_grammar_init_impl(rules, n_rules, start_rule_index); +struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) { + return llama_sampling_init_impl(model->vocab, params); } -void llama_grammar_free(struct llama_grammar * grammar) { - llama_grammar_free_impl(grammar); +void llama_sampling_free(struct llama_sampling * smpl) { + if (smpl == nullptr) { + return; + } + + llama_sampling_free_impl(smpl); } -struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { - return llama_grammar_copy_impl(grammar); +struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) { + return llama_sampling_cp_impl(*smpl); } -void llama_grammar_sample( - const struct llama_grammar * grammar, - const struct llama_context * ctx, - llama_token_data_array * candidates) { - llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates); +void llama_sampling_reset(struct llama_sampling * smpl) { + llama_sampling_reset_impl(*smpl); } -void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar) { - llama_grammar_sample(grammar, ctx, candidates); +void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { + llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root); } -void llama_grammar_accept_token( - struct llama_grammar * grammar, - struct llama_context * ctx, - llama_token token) { - llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token); +void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias); } -// -// sampling -// +void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) { + const int n_vocab = smpl->vocab.n_vocab; + + smpl->cur.resize(n_vocab); -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - llama_set_rng_seed_impl(&ctx->sampling, seed); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + smpl->cur[token_id] = 
llama_token_data{token_id, logits[token_id], 0.0f}; + } + + for (const auto & lb : smpl->logit_bias) { + smpl->cur[lb.token].logit += lb.bias; + } + + if (smpl->params.ignore_eos) { + smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY; + } + + smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; + + // apply penalties + { + const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit; + + llama_sampling_penalties(smpl, &smpl->cur_p); + + if (!smpl->params.penalize_nl) { + for (size_t idx = 0; idx < smpl->cur_p.size; idx++) { + if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) { + smpl->cur_p.data[idx].logit = nl_logit; + break; + } + } + } + } } -void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { - llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates); +llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) { + return &smpl->cur_p; } -void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { - llama_sample_top_k_impl(ctx ? &ctx->sampling : nullptr, candidates, k, min_keep); +void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_softmax_impl(candidates); } -void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep); } -void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep); } -void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { - llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep); +void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep); +} + +void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep); } -void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_typical_impl(ctx ? 
&ctx->sampling : nullptr, candidates, p, min_keep); +void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep); +} + +void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + if (smpl->params.dynatemp_range > 0) { + const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range); + const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range); + + llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent); + } else { + llama_sampling_temp_impl(candidates, smpl->params.temp); + } } -void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { - llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val); +void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_grammar_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + if (smpl->grammar) { + llama_sampling_grammar_impl(candidates, *smpl->grammar); + + smpl->n_grammar++; + } } -void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { - llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp); +void llama_sampling_penalties( + struct llama_sampling * smpl, + llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + const size_t penalty_last_n = std::min(smpl->params.penalty_last_n, smpl->prev.size()); + + const float penalty_repeat = smpl->params.penalty_repeat; + const float penalty_freq = smpl->params.penalty_freq; + const float penalty_present = smpl->params.penalty_present; + + if ((penalty_last_n == 0) || + (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { + return; + } + + // Create a frequency map to count occurrences of each token in last_tokens + // TODO: move to sampling state and avoid reallocation + llama_token_cnt token_count; + for (size_t i = 0; i < penalty_last_n; ++i) { + token_count[smpl->prev.rat(i)]++; + } + + llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present); } -void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { - llama_sample_repetition_penalties_impl(ctx ? 
&ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); +llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + const auto type = smpl->params.mirostat; + + llama_token res; + + if (type == 1) { + res = llama_sampling_sample_mirostat_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + 100, + smpl->vocab.n_vocab, + smpl->mirostat_mu); + } else if (type == 2) { + res = llama_sampling_sample_mirostat_v2_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + smpl->mirostat_mu); + } else { + GGML_ABORT("invalid mirostat type: %d", type); + } + + smpl->n_sample++; + + return res; } -void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale) { - llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale); +llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + auto res = llama_sampling_sample_greedy_impl(candidates); + + smpl->n_sample++; + + return res; } -llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu); +llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); + + smpl->n_sample++; + + return res; +} + +llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + const auto & params = smpl->params; + + const float temp = params.temp; + const int mirostat = params.mirostat; + + auto & cur_p = candidates; + + llama_token res = 0; + + if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) { + // greedy sampling, with probs + llama_sampling_softmax_impl(cur_p); + res = cur_p->data[0].id; + } else if (temp == 0.0f) { + // greedy sampling, no probs + res = llama_sampling_sample_greedy(smpl, cur_p); + } else { + if (mirostat != 0) { + llama_sampling_temp(smpl, cur_p); + res = llama_sampling_sample_mirostat(smpl, cur_p); + } else { + for (const auto & sampler : smpl->samplers) { + switch (sampler) { + case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; + case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; + case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; + case LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; + case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; + case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; + default : break; + } + } + + res = 
llama_sampling_sample_dist(smpl, cur_p); + + //{ + // const int n_top = 10; + // LOG("top %d candidates:\n", n_top); + + // for (int i = 0; i < n_top; i++) { + // const llama_token id = cur_p.data[i].id; + // (void)id; // To avoid a warning that id is unused when logging is disabled. + // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); + // } + //} + + //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); + } + } + + smpl->n_sample++; + + return res; } -llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu); +void llama_sampling_accept( + struct llama_sampling * smpl, + llama_token token, + bool apply_grammar) { + time_meas tm(smpl->t_accept_us); + + llama_sampling_accept_impl(*smpl, token, apply_grammar); + + smpl->n_accept++; } -llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_greedy_impl(ctx ? &ctx->sampling : nullptr, candidates); +int llama_sampling_n_prev(const struct llama_sampling * smpl) { + return llama_sampling_n_prev_impl(*smpl); } -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng); +llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) { + return llama_sampling_prev_impl(*smpl, ith); } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng); +llama_token llama_sampling_last(const struct llama_sampling * smpl) { + return llama_sampling_prev_impl(*smpl, 0); } +// +// model split +// + int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { @@ -20737,30 +20991,32 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int return 0; } -struct llama_timings llama_get_timings(struct llama_context * ctx) { - struct llama_timings result = { - /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, - /*.t_end_ms =*/ 1.00 * ggml_time_ms(), - /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us, - /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, - /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - - /*.n_sample =*/ std::max(1, ctx->sampling.n_sample), - /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), - /*.n_eval =*/ std::max(1, ctx->n_eval), +void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl) { + const llama_timings timings = { + /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, + /*.t_end_ms =*/ 1.00 * ggml_time_ms(), + /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, + /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0), + /*.t_grammar_ms =*/ 1e-3 * (smpl ? smpl->t_grammar_us : 0.0), + /*.t_accept_ms =*/ 1e-3 * (smpl ? smpl->t_accept_us : 0.0), + /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, + /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, + + /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0), + /*.n_grammar =*/ std::max(0, smpl ? 
smpl->n_grammar : 0), + /*.n_accept =*/ std::max(0, smpl ? smpl->n_accept : 0), + /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), + /*.n_eval =*/ std::max(1, ctx->n_eval), }; - return result; -} - -void llama_print_timings(struct llama_context * ctx) { - const llama_timings timings = llama_get_timings(ctx); - LLAMA_LOG_INFO("\n"); LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); - LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); + LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling); + LLAMA_LOG_INFO("%s: grammar time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_grammar_ms, timings.n_grammar, timings.t_grammar_ms / timings.n_grammar, 1e3 / timings.t_grammar_ms * timings.n_grammar); + //LLAMA_LOG_INFO("%s: accept time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + // __func__, timings.t_accept_ms, timings.n_accept, timings.t_accept_ms / timings.n_accept, 1e3 / timings.t_accept_ms * timings.n_accept); LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", @@ -20768,12 +21024,16 @@ void llama_print_timings(struct llama_context * ctx) { LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); } -void llama_reset_timings(struct llama_context * ctx) { +void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl) { ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; - ctx->sampling.reset_timings(); + if (smpl) { + smpl->t_sample_us = smpl->n_sample = 0; + smpl->t_grammar_us = smpl->n_grammar = 0; + smpl->t_accept_us = smpl->n_accept = 0; + } } const char * llama_print_system_info(void) { @@ -20815,21 +21075,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { 1.0e-3 * ctx->t_eval_us / ctx->n_eval); fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); - fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", - 1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample); fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample); fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent 
sampling\n", ctx->sampling.t_sample_us); fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", 1.0e6 * ctx->n_eval / ctx->t_eval_us); fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); - fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", - 1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us); } // For internal test use diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 9c4e7d18e37b2..788b02a6a5cd8 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -2,33 +2,18 @@ #undef NDEBUG #endif -#define LLAMA_API_INTERNAL - -#include "ggml.h" -#include "llama.h" -#include "grammar-parser.h" -#include "json-schema-to-grammar.h" #include "unicode.h" +#include "llama-grammar.h" +#include "json-schema-to-grammar.h" + #include #include #include using json = nlohmann::ordered_json; -static llama_grammar* build_grammar(const std::string & grammar_str) { - auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); - - // Ensure we parsed correctly - assert(!parsed_grammar.rules.empty()); - - // Ensure we have a root node - assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); - - std::vector grammar_rules(parsed_grammar.c_rules()); - llama_grammar* grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); - - return grammar; +static llama_grammar * build_grammar(const std::string & grammar_str) { + return llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root"); } static bool test_build_grammar_fails(const std::string & grammar_str) { @@ -45,17 +30,15 @@ static bool test_build_grammar_fails(const std::string & grammar_str) { } static bool match_string(const std::string & input, llama_grammar * grammar) { - auto decoded = decode_utf8(input, {}); - - const auto & code_points = decoded.first; + const auto cpts = unicode_cpts_from_utf8(input); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); - for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + for (const auto & cpt : cpts) { const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy - llama_grammar_accept(rules, prev_stacks, *it, cur_stacks); + cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); if (cur_stacks.empty()) { // no stacks means that the grammar failed to match at this point @@ -77,7 +60,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, fprintf(stderr, "âš« Testing %s\n%s\n", test_desc.c_str(), grammar_str.c_str()); fflush(stderr); - auto grammar = build_grammar(grammar_str); + auto * grammar = build_grammar(grammar_str); // Save the original grammar stacks so that we can reset after every new string we want to test const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar); @@ -143,7 +126,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, } // Clean up allocated memory - llama_grammar_free(grammar); + llama_grammar_free_impl(grammar); } static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector & passing_strings, const std::vector & failing_strings) { test(test_desc + ". 
Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings); @@ -683,7 +666,8 @@ static void test_failure_missing_root() { term ::= number number ::= [0-9]+)"""; - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + llama_grammar_parser parsed_grammar; + parsed_grammar.parse(grammar_str.c_str()); // Ensure we parsed correctly assert(!parsed_grammar.rules.empty()); @@ -705,7 +689,8 @@ static void test_failure_missing_reference() { fprintf(stderr, " Expected error: "); - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + llama_grammar_parser parsed_grammar; + parsed_grammar.parse(grammar_str.c_str()); // Ensure we did NOT parsed correctly assert(parsed_grammar.rules.empty()); diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 5df5abb25394c..259172d999c78 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -3,7 +3,7 @@ #endif #include "llama.h" -#include "grammar-parser.h" +#include "llama-grammar.h" #include @@ -22,7 +22,8 @@ static const char * type_str(llama_gretype type) { static void verify_parsing(const char *grammar_bytes, const std::vector> expected, const std::vector &expected_rules) { uint32_t index = 0; - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes); + llama_grammar_parser parsed_grammar; + parsed_grammar.parse(grammar_bytes); std::map symbol_names; for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { @@ -129,9 +130,10 @@ static void verify_parsing(const char *grammar_bytes, const std::vector #include #include #include -#include "json-schema-to-grammar.h" -#include "grammar-parser.h" - static std::string trim(const std::string & source) { std::string s(source); s.erase(0,s.find_first_not_of(" \n\r\t")); @@ -40,7 +41,8 @@ struct TestCase { } void verify_expectation_parseable() const { try { - auto state = grammar_parser::parse(expected_grammar.c_str()); + llama_grammar_parser state; + state.parse(expected_grammar.c_str()); if (state.symbol_ids.find("root") == state.symbol_ids.end()) { throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar); } diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 1f3a267b39f9b..6f1374ca8ed58 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -2,16 +2,15 @@ #undef NDEBUG #endif -#define LLAMA_API_INTERNAL #include "llama.h" -#include "grammar-parser.h" +#include "llama-grammar.h" #include #include int main() { - grammar_parser::parse_state parsed_grammar; + llama_grammar_parser parsed_grammar; std::vector> expected = { {"expr", 2}, @@ -117,7 +116,7 @@ int main() llama_grammar * grammar = NULL; std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + grammar = llama_grammar_init_impl(nullptr, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); @@ -174,13 +173,13 @@ int main() }}; auto index = 0; - for (auto stack : llama_grammar_get_stacks(grammar)) + for (const llama_grammar_stack & stack : llama_grammar_get_stacks(grammar)) { // compare stack to expected_stack for (uint32_t i = 0; i < stack.size(); i++) { - auto element = stack[i]; - auto expected_element = expected_stacks[index][i]; + const llama_grammar_element * element 
= stack[i]; + const llama_grammar_element & expected_element = expected_stacks[index][i]; // pretty print error message before asserting if (expected_element.type != element->type || expected_element.value != element->value) @@ -403,6 +402,8 @@ int main() delete[] candidate.code_points; candidate.code_points = nullptr; } - llama_grammar_free(grammar); + + llama_grammar_free_impl(grammar); + return 0; } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 6c2a5db9accf2..f5e32a741b23c 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -1,5 +1,6 @@ #include "ggml.h" #include "llama.h" +#include "llama-sampling.h" #ifdef NDEBUG #undef NDEBUG @@ -20,6 +21,7 @@ static void dump(const llama_token_data_array * candidates) { static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -28,9 +30,9 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -49,9 +52,9 @@ static void test_top_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float z) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -71,7 +75,7 @@ static void test_tfs(const std::vector & probs, const std::vector llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; DUMP(&candidates_p); - llama_sample_tail_free(nullptr, &candidates_p, z, 1); + llama_sampling_tail_free_impl(&candidates_p, z, 1); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -82,6 +86,7 @@ static void test_tfs(const std::vector & probs, const std::vector static void test_min_p(const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -91,9 +96,9 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -112,7 +118,7 @@ static void test_typical(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector & last_tokens, const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence ) { GGML_ASSERT(probs.size() == expected_probs.size()); const size_t n_vocab = probs.size(); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { @@ -135,11 +142,16 @@ static void test_repetition_penalties( candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); } + llama_token_cnt token_count; + for (size_t i = 0; i < last_tokens.size(); i++) { + token_count[last_tokens[i]]++; + } + llama_token_data_array candidates_p = { 
candidates.data(), candidates.size(), false }; - llama_sample_softmax(nullptr, &candidates_p); + llama_sampling_softmax_impl(&candidates_p); DUMP(&candidates_p); - llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); - llama_sample_softmax(nullptr, &candidates_p); + llama_sampling_penalties_impl(&candidates_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); + llama_sampling_softmax_impl(&candidates_p); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -148,8 +160,7 @@ static void test_repetition_penalties( } } -static void test_sampler_queue( - const size_t n_vocab, const std::string samplers_sequence, const int top_k, const float top_p, const float min_p +static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { std::vector candidates; candidates.reserve(n_vocab); @@ -165,16 +176,16 @@ static void test_sampler_queue( for (auto s : samplers_sequence) { switch (s){ - case 'k': llama_sample_top_k (nullptr, &candidates_p, top_k, 1); break; + case 'k': llama_sampling_top_k_impl(&candidates_p, top_k, 1); break; case 'f': GGML_ABORT("tail_free test not implemented"); case 'y': GGML_ABORT("typical test not implemented"); - case 'p': llama_sample_top_p (nullptr, &candidates_p, top_p, 1); break; - case 'm': llama_sample_min_p (nullptr, &candidates_p, min_p, 1); break; + case 'p': llama_sampling_top_p_impl(&candidates_p, top_p, 1); break; + case 'm': llama_sampling_min_p_impl(&candidates_p, min_p, 1); break; case 't': GGML_ABORT("temperature test not implemented"); default : GGML_ABORT("Unknown sampler"); } - llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests + llama_sampling_softmax_impl(&candidates_p); // make sure tokens are sorted for tests const int size = candidates_p.size; @@ -259,13 +270,13 @@ int main(void) { test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); - test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f); + test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); + 
test_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f); test_sampler_queue(10000, "k", 10000, 1.0f, 1.0f); test_sampler_queue(10000, "k", 1, 1.0f, 1.0f); From 86b07ccbb3d93779d47f067664269a429d29d263 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 12:09:08 +0300 Subject: [PATCH 02/47] llama : sketching new sampling API --- include/llama.h | 39 +++++++++++++++++++++++++++++++++++++-- src/llama.cpp | 8 ++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/include/llama.h b/include/llama.h index 099c9e7465b9c..244e0770536f0 100644 --- a/include/llama.h +++ b/include/llama.h @@ -61,6 +61,8 @@ extern "C" { // struct llama_vocab; // TODO: add in the future struct llama_model; struct llama_context; + struct llama_sampler; + struct llama_constraint; struct llama_sampling; typedef int32_t llama_pos; @@ -412,6 +414,11 @@ extern "C" { bool ignore_eos; // ignore the end-of-sequence token } llama_sampling_params; + typedef struct llama_sampler_params { + bool dummy; + // TODO: add type of sampler: greedy, dist, mirostat, etc. + } llama_sampler_params; + // performance timing information struct llama_timings { double t_start_ms; @@ -440,8 +447,10 @@ extern "C" { struct llama_lora_adapter; // Helpers for getting default parameters + // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172) LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_sampler_params llama_sampler_default_params(void); LLAMA_API struct llama_sampling_params llama_sampling_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); @@ -465,7 +474,7 @@ extern "C" { LLAMA_API struct llama_model * llama_load_model_from_file( const char * path_model, - struct llama_model_params params); + struct llama_model_params params); LLAMA_API void llama_free_model(struct llama_model * model); @@ -1031,7 +1040,7 @@ extern "C" { int32_t length); // - // Sampling functions + // Sampling API // // TODO: llama_model should become llama_vocab @@ -1156,6 +1165,32 @@ extern "C" { /// returns LLAMA_TOKEN_NULL if there are no accepted tokens LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); + // + // Sampling v2 API + // + + // samplers + + LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params); + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); + + LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); + + LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i); + + // constraints + + LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep); + // ... 
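// An illustrative usage sketch of the sampler/constraint API declared above (exposition
// only, not part of the patch): `ctx` is assumed to be a valid llama_context, the index
// -1 is assumed to select the most recent set of logits, and the top-k/top-p values are
// arbitrary. Whether the sampler takes ownership of added constraints is still an open
// TODO below, so the sketch frees only the sampler itself.
static llama_token sample_one_token(struct llama_context * ctx) {
    struct llama_sampler * smpl = llama_sampler_init(llama_sampler_default_params());

    // chain two constraints that filter the candidate tokens before sampling
    llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
    llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.95f, 1));

    const llama_token id = llama_sampler_sample(smpl, ctx, -1);
    llama_sampler_accept(smpl, id);

    llama_sampler_free(smpl);
    return id;
}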
+ LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); + + LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); + LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates); + // // Model split // diff --git a/src/llama.cpp b/src/llama.cpp index 258a568421347..d0ca96acaa610 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17935,6 +17935,14 @@ struct llama_context_params llama_context_default_params() { return result; } +struct llama_sampler_params llama_sampler_default_params() { + struct llama_sampler_params result = { + /*.dummy =*/ false, + }; + + return result; +} + struct llama_sampling_params llama_sampling_default_params() { struct llama_sampling_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, From 5116b3681cbab925a7c00903019cf32448ee8e0b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 13:12:50 +0300 Subject: [PATCH 03/47] cont : add llama_constraint_i [no ci] --- include/llama.h | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/include/llama.h b/include/llama.h index 244e0770536f0..baa136537e58d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -62,7 +62,6 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; - struct llama_constraint; struct llama_sampling; typedef int32_t llama_pos; @@ -1169,19 +1168,26 @@ extern "C" { // Sampling v2 API // - // samplers + // constraints - LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params); - LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); - LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); + struct llama_constraint; - LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); + typedef void * llama_constraint_context_t; - LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); - LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i); + struct llama_constraint_i { + void (*accept)(struct llama_constraint * cnstr, llama_token token); + void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); + void (*reset) (struct llama_constraint * cnstr); // e.g. 
for grammar and penalty constraints + void (*free) (struct llama_constraint * cnstr); - // constraints + // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph + //void (*apply_ggml) (struct llama_constraint * cnstr, ...); + }; + + struct llama_constraint { + struct llama_constraint_i * iface; + llama_constraint_context_t ctx; + }; LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep); @@ -1191,6 +1197,20 @@ extern "C" { LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates); + // samplers + + LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params); + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); + + // TODO: should this take ownership so the user does not need to call llama_constraint_free + // or should just make a reference to the constraint so that it can be reused in multiple llama_sampler? + LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); + + LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i); + // // Model split // From cf4dd10ea543b6e679a597add5e6c465fe8c2f98 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 14:33:10 +0300 Subject: [PATCH 04/47] cont : initial implementation sketch [no ci] --- include/llama.h | 17 +++-- src/llama-sampling.cpp | 167 +++++++++++++++++++++++++++++++++++++++++ src/llama-sampling.h | 45 +++++++++++ src/llama.cpp | 66 +++++++++++++++- 4 files changed, 289 insertions(+), 6 deletions(-) diff --git a/include/llama.h b/include/llama.h index baa136537e58d..bd756fc5cbfa6 100644 --- a/include/llama.h +++ b/include/llama.h @@ -62,7 +62,7 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; - struct llama_sampling; + struct llama_sampling; // TODO: remove before merge typedef int32_t llama_pos; typedef int32_t llama_token; @@ -414,7 +414,8 @@ extern "C" { } llama_sampling_params; typedef struct llama_sampler_params { - bool dummy; + uint32_t seed; // the seed used to initialize the rng of the sampler + // TODO: add type of sampler: greedy, dist, mirostat, etc. } llama_sampler_params; @@ -1175,10 +1176,11 @@ extern "C" { typedef void * llama_constraint_context_t; struct llama_constraint_i { - void (*accept)(struct llama_constraint * cnstr, llama_token token); + void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); - void (*reset) (struct llama_constraint * cnstr); // e.g. for grammar and penalty constraints - void (*free) (struct llama_constraint * cnstr); + void (*reset) (struct llama_constraint * cnstr); // e.g. 
for grammar and penalty constraints, can be NULL + void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); + void (*free) (struct llama_constraint * cnstr); // can be NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); @@ -1192,10 +1194,13 @@ extern "C" { LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep); // ... + + // do not call if used with llama_sampler_add_constraint LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates); + LLAMA_API void llama_constraint_reset (struct llama_constraint * cnstr); // samplers @@ -1206,6 +1211,8 @@ extern "C" { // TODO: should this take ownership so the user does not need to call llama_constraint_free // or should just make a reference to the constraint so that it can be reused in multiple llama_sampler? + // + // seems better to take the ownership, otherwise the copying of the sampler will be more complicated LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 8abfc3fc6d86a..91b95d3af0fcd 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -50,6 +50,10 @@ struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & voca } void llama_sampling_free_impl(struct llama_sampling * sampling) { + if (sampling == nullptr) { + return; + } + delete sampling; } @@ -633,3 +637,166 @@ llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith int llama_sampling_n_prev_impl(const struct llama_sampling & smpl) { return smpl.prev.size(); } + +// +// sampling v2 +// + +// constraints + +// top-k + +struct llama_constraint_context_top_k { + int32_t k; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_top_k_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; + llama_sampling_top_k_impl(candidates, ctx->k, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_top_k; + const auto * ctx_src = (const llama_constraint_context_top_k *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_top_k *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + delete (llama_constraint_context_top_k *) cnstr->ctx; + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_top_k_i; + result->ctx = new llama_constraint_context_top_k{k, min_keep}; + + return result; +} + +// top-p + +struct llama_constraint_context_top_p { + float p; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_top_p_i = { + /* .accept = */ nullptr, + /* .apply = 
*/ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_top_p * ctx = (llama_constraint_context_top_p *) cnstr->ctx; + llama_sampling_top_p_impl(candidates, ctx->p, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_top_p; + const auto * ctx_src = (const llama_constraint_context_top_p *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_top_p *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + delete (llama_constraint_context_top_p *) cnstr->ctx; + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_top_p_i; + result->ctx = new llama_constraint_context_top_p{p, min_keep}; + + return result; +} + +void llama_constraint_free_impl(struct llama_constraint * constraint) { + if (constraint->iface->free) { + constraint->iface->free(constraint); + } +} + +void llama_constraint_accept_impl(struct llama_constraint * constraint, llama_token token) { + if (constraint->iface->accept) { + constraint->iface->accept(constraint, token); + } +} + +void llama_constraint_apply_impl(struct llama_constraint * constraint, struct llama_token_data_array * candidates) { + GGML_ASSERT(constraint->iface->apply); + constraint->iface->apply(constraint, candidates); +} + +void llama_constraint_reset_impl(struct llama_constraint * constraint) { + if (constraint->iface->reset) { + constraint->iface->reset(constraint); + } +} + +// samplers + +struct llama_sampler * llama_sampler_init_impl(struct llama_sampler_params params) { + auto * result = new llama_sampler; + + result->params = params; + + result->rng.seed(params.seed); + + return result; +} + +void llama_sampler_free_impl(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; + } + + for (auto * constraint : smpl->constraints) { + llama_constraint_free_impl(constraint); + } + + delete smpl; +} + +struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) { + auto * result = new llama_sampler; + + *result = smpl; + + // copy the constraints objects + result->constraints.clear(); + for (const auto & constraint : smpl.constraints) { + GGML_ASSERT(constraint->iface->copy); + + result->constraints.push_back(new llama_constraint); + result->constraints.back()->iface = constraint->iface; + result->constraints.back()->iface->copy(result->constraints.back(), constraint); + } + + return result; +} + +void llama_sampler_reset_impl(struct llama_sampler & smpl) { + smpl.prev.clear(); + + for (auto * constraint : smpl.constraints) { + llama_constraint_reset_impl(constraint); + } + + // TODO: should we reset the timings? 
+} + +void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { + smpl.constraints.push_back(cnstr); +} + +void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { + smpl.prev.push_back(token); + + for (auto * constraint : smpl.constraints) { + llama_constraint_accept_impl(constraint, token); + } +} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index c51542259e27d..4141c4fa35cb0 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -105,3 +105,48 @@ void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, llama_token llama_sampling_prev_impl (const struct llama_sampling & smpl, int ith); int llama_sampling_n_prev_impl(const struct llama_sampling & smpl); + + +// +// sampling v2 +// + +// constraints + +struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep); + +void llama_constraint_free_impl(struct llama_constraint * constraint); + +void llama_constraint_accept_impl(struct llama_constraint * constraint, llama_token token); +void llama_constraint_apply_impl (struct llama_constraint * constraint, struct llama_token_data_array * candidates); +void llama_constraint_reset_impl (struct llama_constraint * constraint); + +// samplers + +struct llama_sampler { + llama_sampler_params params; + + // state + + std::mt19937 rng; + + // TODO: move to a standalone penalty constraint? + ring_buffer prev; + + std::vector constraints; + + // timing + + mutable int64_t t_sample_us = 0; + + mutable int32_t n_sample = 0; +}; + +struct llama_sampler * llama_sampler_init_impl ( struct llama_sampler_params params); +void llama_sampler_free_impl ( struct llama_sampler * smpl); +struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); +void llama_sampler_reset_impl( struct llama_sampler & smpl); + +void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr); +void llama_sampler_accept_impl (struct llama_sampler & smpl, llama_token token); diff --git a/src/llama.cpp b/src/llama.cpp index d0ca96acaa610..5f06f33adb1e9 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17937,7 +17937,7 @@ struct llama_context_params llama_context_default_params() { struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { - /*.dummy =*/ false, + /*.seed =*/ LLAMA_DEFAULT_SEED, }; return result; @@ -20971,6 +20971,70 @@ llama_token llama_sampling_last(const struct llama_sampling * smpl) { return llama_sampling_prev_impl(*smpl, 0); } +// +// sampling v2 +// + +struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) { + return llama_constraint_init_top_k_impl(k, min_keep); +} + +struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) { + return llama_constraint_init_top_p_impl(p, min_keep); +} + +void llama_constraint_free(struct llama_constraint * cnstr) { + if (cnstr == nullptr) { + return; + } + + llama_constraint_free_impl(cnstr); +} + +void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) { + llama_constraint_accept_impl(cnstr, token); +} + +void llama_constraint_apply(struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_apply_impl(cnstr, candidates); +} + +void llama_constraint_reset(struct llama_constraint * cnstr) { + llama_constraint_reset_impl(cnstr); +} + +struct llama_sampler * 
llama_sampler_init(struct llama_sampler_params params) { + return llama_sampler_init_impl(params); +} + +void llama_sampler_free(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; + } + + llama_sampler_free_impl(smpl); +} + +struct llama_sampler * llama_sampler_cp(const struct llama_sampler * smpl) { + return llama_sampler_cp_impl(*smpl); +} + +void llama_sampler_reset(struct llama_sampler * smpl) { + llama_sampler_reset_impl(*smpl); +} + +void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) { + llama_sampler_add_constraint_impl(*smpl, cnstr); +} + +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + llama_sampler_accept_impl(*smpl, token); +} + +llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) { + GGML_ABORT("not implemented"); +} + // // model split // From 1b07dc51c6df02e6f37adabf7dc6d66e3c59396a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 14:45:07 +0300 Subject: [PATCH 05/47] cont : fixes, naming [no ci] --- src/llama-sampling.cpp | 56 +++++++++++++++++++++++------------------- src/llama-sampling.h | 8 +++--- src/llama.cpp | 8 +++--- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 91b95d3af0fcd..5ccbf029f8926 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -665,7 +665,9 @@ static struct llama_constraint_i llama_constraint_top_k_i = { *ctx_dst = *ctx_src; }, /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_top_k *) cnstr->ctx; + if (cnstr->ctx) { + delete (llama_constraint_context_top_k *) cnstr->ctx; + } delete cnstr; } }; @@ -700,7 +702,9 @@ static struct llama_constraint_i llama_constraint_top_p_i = { *ctx_dst = *ctx_src; }, /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_top_p *) cnstr->ctx; + if (cnstr->ctx) { + delete (llama_constraint_context_top_p *) cnstr->ctx; + } delete cnstr; } }; @@ -714,26 +718,26 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k return result; } -void llama_constraint_free_impl(struct llama_constraint * constraint) { - if (constraint->iface->free) { - constraint->iface->free(constraint); +void llama_constraint_free_impl(struct llama_constraint * cnstr) { + if (cnstr->iface->free && cnstr) { + cnstr->iface->free(cnstr); } } -void llama_constraint_accept_impl(struct llama_constraint * constraint, llama_token token) { - if (constraint->iface->accept) { - constraint->iface->accept(constraint, token); +void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token) { + if (cnstr.iface->accept) { + cnstr.iface->accept(&cnstr, token); } } -void llama_constraint_apply_impl(struct llama_constraint * constraint, struct llama_token_data_array * candidates) { - GGML_ASSERT(constraint->iface->apply); - constraint->iface->apply(constraint, candidates); +void llama_constraint_apply_impl(struct llama_constraint & cnstr, struct llama_token_data_array * candidates) { + GGML_ASSERT(cnstr.iface->apply); + cnstr.iface->apply(&cnstr, candidates); } -void llama_constraint_reset_impl(struct llama_constraint * constraint) { - if (constraint->iface->reset) { - constraint->iface->reset(constraint); +void llama_constraint_reset_impl(struct llama_constraint & cnstr) { + if (cnstr.iface->reset) { + cnstr.iface->reset(&cnstr); } } @@ -754,8 +758,8 @@ void llama_sampler_free_impl(struct llama_sampler * 
smpl) { return; } - for (auto * constraint : smpl->constraints) { - llama_constraint_free_impl(constraint); + for (auto * cnstr : smpl->constraints) { + llama_constraint_free_impl(cnstr); } delete smpl; @@ -768,12 +772,14 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) // copy the constraints objects result->constraints.clear(); - for (const auto & constraint : smpl.constraints) { - GGML_ASSERT(constraint->iface->copy); - + for (const auto & cnstr : smpl.constraints) { result->constraints.push_back(new llama_constraint); - result->constraints.back()->iface = constraint->iface; - result->constraints.back()->iface->copy(result->constraints.back(), constraint); + result->constraints.back()->iface = cnstr->iface; + + if (cnstr->ctx) { + GGML_ASSERT(cnstr->iface->copy); + result->constraints.back()->iface->copy(result->constraints.back(), cnstr); + } } return result; @@ -782,8 +788,8 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) void llama_sampler_reset_impl(struct llama_sampler & smpl) { smpl.prev.clear(); - for (auto * constraint : smpl.constraints) { - llama_constraint_reset_impl(constraint); + for (auto * cnstr : smpl.constraints) { + llama_constraint_reset_impl(*cnstr); } // TODO: should we reset the timings? @@ -796,7 +802,7 @@ void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { smpl.prev.push_back(token); - for (auto * constraint : smpl.constraints) { - llama_constraint_accept_impl(constraint, token); + for (auto * cnstr : smpl.constraints) { + llama_constraint_accept_impl(*cnstr, token); } } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 4141c4fa35cb0..2a98d103eba6d 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -116,11 +116,11 @@ int llama_sampling_n_prev_impl(const struct llama_sampling & smpl); struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep); struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep); -void llama_constraint_free_impl(struct llama_constraint * constraint); +void llama_constraint_free_impl(struct llama_constraint * cnstr); -void llama_constraint_accept_impl(struct llama_constraint * constraint, llama_token token); -void llama_constraint_apply_impl (struct llama_constraint * constraint, struct llama_token_data_array * candidates); -void llama_constraint_reset_impl (struct llama_constraint * constraint); +void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token); +void llama_constraint_apply_impl (struct llama_constraint & cnstr, struct llama_token_data_array * candidates); +void llama_constraint_reset_impl (struct llama_constraint & cnstr); // samplers diff --git a/src/llama.cpp b/src/llama.cpp index 5f06f33adb1e9..4d41c9262ad99 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20992,22 +20992,22 @@ void llama_constraint_free(struct llama_constraint * cnstr) { } void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) { - llama_constraint_accept_impl(cnstr, token); + llama_constraint_accept_impl(*cnstr, token); } void llama_constraint_apply(struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_apply_impl(cnstr, candidates); + llama_constraint_apply_impl(*cnstr, candidates); } void llama_constraint_reset(struct llama_constraint * cnstr) { - llama_constraint_reset_impl(cnstr); + 
llama_constraint_reset_impl(*cnstr); } struct llama_sampler * llama_sampler_init(struct llama_sampler_params params) { return llama_sampler_init_impl(params); } -void llama_sampler_free(struct llama_sampler * smpl) { +void llama_sampler_free(struct llama_sampler * smpl) { if (smpl == nullptr) { return; } From 71293a64564c050c7ab57b8fd1c69a5ed7759a49 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 15:17:02 +0300 Subject: [PATCH 06/47] cont : add rest of the existing samplers [no ci] --- include/llama.h | 21 ++-- src/llama-grammar.cpp | 4 + src/llama-sampling.cpp | 263 +++++++++++++++++++++++++++++++++++++++++ src/llama-sampling.h | 10 +- src/llama.cpp | 24 ++++ 5 files changed, 312 insertions(+), 10 deletions(-) diff --git a/include/llama.h b/include/llama.h index bd756fc5cbfa6..76c1aaf98937f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1176,11 +1176,11 @@ extern "C" { typedef void * llama_constraint_context_t; struct llama_constraint_i { - void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL - void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); - void (*reset) (struct llama_constraint * cnstr); // e.g. for grammar and penalty constraints, can be NULL - void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); - void (*free) (struct llama_constraint * cnstr); // can be NULL + void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL + void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); // required + void (*reset) (struct llama_constraint * cnstr); // can be NULL + void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); // can be NULL if ctx is NULL + void (*free) (struct llama_constraint * cnstr); // can be NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); @@ -1191,9 +1191,14 @@ extern "C" { llama_constraint_context_t ctx; }; - LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep); - // ... 
+ LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); + LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root); // do not call if used with llama_sampler_add_constraint LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 8cd98bae4dba6..092a738aafe6b 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1043,6 +1043,10 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, } void llama_grammar_free_impl(struct llama_grammar * grammar) { + if (grammar == nullptr) { + return; + } + delete grammar; } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 5ccbf029f8926..becc949dca4a1 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -718,6 +718,269 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k return result; } +// min-p + +struct llama_constraint_context_min_p { + float p; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_min_p_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_min_p * ctx = (llama_constraint_context_min_p *) cnstr->ctx; + llama_sampling_min_p_impl(candidates, ctx->p, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_min_p; + const auto * ctx_src = (const llama_constraint_context_min_p *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_min_p *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_min_p *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_min_p_i; + result->ctx = new llama_constraint_context_min_p{p, min_keep}; + + return result; +} + +// tail-free + +struct llama_constraint_context_tail_free { + float z; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_tail_free_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_tail_free * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; + llama_sampling_tail_free_impl(candidates, ctx->z, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_tail_free; + const auto * ctx_src = (const llama_constraint_context_tail_free *) cnstr_src->ctx; + 
auto * ctx_dst = ( llama_constraint_context_tail_free *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_tail_free *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_tail_free_i; + result->ctx = new llama_constraint_context_tail_free{z, min_keep}; + + return result; +} + +// typical + +struct llama_constraint_context_typical { + float p; + size_t min_keep; +}; + +static struct llama_constraint_i llama_constraint_typical_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_typical * ctx = (llama_constraint_context_typical *) cnstr->ctx; + llama_sampling_typical_impl(candidates, ctx->p, ctx->min_keep); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_typical; + const auto * ctx_src = (const llama_constraint_context_typical *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_typical *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_typical *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_typical_i; + result->ctx = new llama_constraint_context_typical{p, min_keep}; + + return result; +} + +// temp + +struct llama_constraint_context_temp { + float temp; +}; + +static struct llama_constraint_i llama_constraint_temp_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_temp * ctx = (llama_constraint_context_temp *) cnstr->ctx; + llama_sampling_temp_impl(candidates, ctx->temp); + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_temp; + const auto * ctx_src = (const llama_constraint_context_temp *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_temp *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_temp *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_temp_impl(float temp) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_temp_i; + result->ctx = new llama_constraint_context_temp{temp}; + + return result; +} + +// temp-ext + +struct llama_constraint_context_temp_ext { + float temp; + float delta; + float exponent; +}; + +static struct llama_constraint_i llama_constraint_temp_ext_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_temp_ext * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; + if (ctx->delta > 0) { + const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); + const float temp_max = ctx->temp + ctx->delta; + + llama_sampling_entropy_impl(candidates, temp_min, 
temp_max, ctx->exponent); + } else { + llama_sampling_temp_impl(candidates, ctx->temp); + } + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_temp_ext; + const auto * ctx_src = (const llama_constraint_context_temp_ext *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_temp_ext *) cnstr->ctx; + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_temp_ext *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_temp_ext_i; + result->ctx = new llama_constraint_context_temp_ext{temp, delta, exponent}; + + return result; +} + +// grammar + +struct llama_constraint_context_grammar { + std::string grammar_str; + std::string grammar_root; + + struct llama_grammar * grammar; +}; + +static struct llama_constraint_i llama_constraint_grammar_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + if (ctx->grammar) { + llama_sampling_grammar_impl(candidates, *ctx->grammar); + } + }, + /* .reset = */ [](struct llama_constraint * cnstr) { + llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + if (ctx->grammar) { + llama_grammar_free_impl(ctx->grammar); + ctx->grammar = nullptr; + } + + if (!ctx->grammar_str.empty()) { + ctx->grammar = llama_grammar_init_impl(nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); + } + }, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_grammar; + const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_grammar *) cnstr->ctx; + + *ctx_dst = *ctx_src; + + if (ctx_src->grammar) { + ctx_dst->grammar = llama_grammar_cp_impl(*ctx_src->grammar); + } else { + ctx_dst->grammar = nullptr; + } + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + { + auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + llama_grammar_free_impl(ctx->grammar); + } + + delete (llama_constraint_context_grammar *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_grammar_i; + result->ctx = new llama_constraint_context_grammar; + + auto * ctx = (llama_constraint_context_grammar *) result->ctx; + + if (grammar_str != nullptr && grammar_str[0] != '\0') { + ctx->grammar = llama_grammar_init_impl(&vocab, grammar_str, grammar_root); + } else { + ctx->grammar = nullptr; + } + + return result; +} + void llama_constraint_free_impl(struct llama_constraint * cnstr) { if (cnstr->iface->free && cnstr) { cnstr->iface->free(cnstr); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 2a98d103eba6d..ed18da10d8a35 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -113,8 +113,14 @@ int llama_sampling_n_prev_impl(const struct llama_sampling & smpl); // 
constraints -struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep); -struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep); +struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_temp_impl (float t); +struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); +struct llama_constraint * llama_constraint_init_grammar_impl (const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); void llama_constraint_free_impl(struct llama_constraint * cnstr); diff --git a/src/llama.cpp b/src/llama.cpp index 4d41c9262ad99..a47ad0103860e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20983,6 +20983,30 @@ struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) return llama_constraint_init_top_p_impl(p, min_keep); } +struct llama_constraint * llama_constraint_init_min_p(float p, int32_t min_keep) { + return llama_constraint_init_min_p_impl(p, min_keep); +} + +struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep) { + return llama_constraint_init_tail_free_impl(z, min_keep); +} + +struct llama_constraint * llama_constraint_init_typical(float p, int32_t min_keep) { + return llama_constraint_init_typical_impl(p, min_keep); +} + +struct llama_constraint * llama_constraint_init_temp(float temp) { + return llama_constraint_init_temp_impl(temp); +} + +struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta, float exponent) { + return llama_constraint_init_temp_ext_impl(temp, delta, exponent); +} + +struct llama_constraint * llama_constraint_init_grammar(struct llama_model * model, const char * grammar_str, const char * grammar_root) { + return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); +} + void llama_constraint_free(struct llama_constraint * cnstr) { if (cnstr == nullptr) { return; From 0daebc6b8d1ca5349b12628f9b1dc85b663700e0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 15:19:32 +0300 Subject: [PATCH 07/47] cont : fix [no ci] --- src/llama-sampling.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index becc949dca4a1..f6328b60132d4 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -973,8 +973,14 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ auto * ctx = (llama_constraint_context_grammar *) result->ctx; if (grammar_str != nullptr && grammar_str[0] != '\0') { + ctx->grammar_str = grammar_str; + ctx->grammar_root = grammar_root; + ctx->grammar = llama_grammar_init_impl(&vocab, grammar_str, grammar_root); } else { + ctx->grammar_str.clear(); + ctx->grammar_root.clear(); + ctx->grammar = nullptr; } From a2ce91cbef4117d0d5d91671db2c6525d80d516c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 16:04:22 +0300 Subject: [PATCH 08/47] cont : add penalties and logit-bias constraints [no ci] --- common/sampling.cpp | 88 ++++++------- common/sampling.h | 35 ++++-- include/llama.h | 95 
++++++++++---- src/llama-sampling.cpp | 275 +++++++++++++++++++++++++++++++++++++---- src/llama-sampling.h | 41 +++++- src/llama.cpp | 139 +++++++++++++++++++-- 6 files changed, 555 insertions(+), 118 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 96cfbe0ef5b45..a98117cf82455 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -128,57 +128,57 @@ std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_m return result; } -char llama_sampling_type_to_chr(llama_sampler_type sampler) { +char llama_sampling_type_to_chr(llama_constraint_type sampler) { switch (sampler) { - case LLAMA_SAMPLER_TYPE_TOP_K: return 'k'; - case LLAMA_SAMPLER_TYPE_TFS_Z: return 'f'; - case LLAMA_SAMPLER_TYPE_TYPICAL_P: return 'y'; - case LLAMA_SAMPLER_TYPE_TOP_P: return 'p'; - case LLAMA_SAMPLER_TYPE_MIN_P: return 'm'; - case LLAMA_SAMPLER_TYPE_TEMPERATURE: return 't'; + case LLAMA_CONSTRAINT_TYPE_TOP_K: return 'k'; + case LLAMA_CONSTRAINT_TYPE_TFS_Z: return 'f'; + case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return 'y'; + case LLAMA_CONSTRAINT_TYPE_TOP_P: return 'p'; + case LLAMA_CONSTRAINT_TYPE_MIN_P: return 'm'; + case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return 't'; default : return '?'; } } -std::string llama_sampling_type_to_str(llama_sampler_type sampler) { +std::string llama_sampling_type_to_str(llama_constraint_type sampler) { switch (sampler) { - case LLAMA_SAMPLER_TYPE_TOP_K: return "top_k"; - case LLAMA_SAMPLER_TYPE_TFS_Z: return "tfs_z"; - case LLAMA_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; - case LLAMA_SAMPLER_TYPE_TOP_P: return "top_p"; - case LLAMA_SAMPLER_TYPE_MIN_P: return "min_p"; - case LLAMA_SAMPLER_TYPE_TEMPERATURE: return "temperature"; + case LLAMA_CONSTRAINT_TYPE_TOP_K: return "top_k"; + case LLAMA_CONSTRAINT_TYPE_TFS_Z: return "tfs_z"; + case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return "typ_p"; + case LLAMA_CONSTRAINT_TYPE_TOP_P: return "top_p"; + case LLAMA_CONSTRAINT_TYPE_MIN_P: return "min_p"; + case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return "temperature"; default : return ""; } } -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { - std::unordered_map sampler_canonical_name_map { - { "top_k", LLAMA_SAMPLER_TYPE_TOP_K }, - { "top_p", LLAMA_SAMPLER_TYPE_TOP_P }, - { "typ_p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "min_p", LLAMA_SAMPLER_TYPE_MIN_P }, - { "tfs_z", LLAMA_SAMPLER_TYPE_TFS_Z }, - { "temperature", LLAMA_SAMPLER_TYPE_TEMPERATURE }, +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map sampler_canonical_name_map { + { "top_k", LLAMA_CONSTRAINT_TYPE_TOP_K }, + { "top_p", LLAMA_CONSTRAINT_TYPE_TOP_P }, + { "typ_p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "min_p", LLAMA_CONSTRAINT_TYPE_MIN_P }, + { "tfs_z", LLAMA_CONSTRAINT_TYPE_TFS_Z }, + { "temperature", LLAMA_CONSTRAINT_TYPE_TEMPERATURE }, }; // since samplers names are written multiple ways // make it ready for both system names and input names - std::unordered_map sampler_alt_name_map { - { "top-k", LLAMA_SAMPLER_TYPE_TOP_K }, - { "top-p", LLAMA_SAMPLER_TYPE_TOP_P }, - { "nucleus", LLAMA_SAMPLER_TYPE_TOP_P }, - { "typical-p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "typical", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "typ-p", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "typ", LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { "min-p", LLAMA_SAMPLER_TYPE_MIN_P }, - { "tfs-z", LLAMA_SAMPLER_TYPE_TFS_Z }, - { "tfs", LLAMA_SAMPLER_TYPE_TFS_Z }, - { "temp", LLAMA_SAMPLER_TYPE_TEMPERATURE }, + std::unordered_map 
sampler_alt_name_map { + { "top-k", LLAMA_CONSTRAINT_TYPE_TOP_K }, + { "top-p", LLAMA_CONSTRAINT_TYPE_TOP_P }, + { "nucleus", LLAMA_CONSTRAINT_TYPE_TOP_P }, + { "typical-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "typical", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "typ-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "typ", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { "min-p", LLAMA_CONSTRAINT_TYPE_MIN_P }, + { "tfs-z", LLAMA_CONSTRAINT_TYPE_TFS_Z }, + { "tfs", LLAMA_CONSTRAINT_TYPE_TFS_Z }, + { "temp", LLAMA_CONSTRAINT_TYPE_TEMPERATURE }, }; - std::vector samplers; + std::vector samplers; samplers.reserve(names.size()); for (const auto & name : names) { @@ -198,17 +198,17 @@ std::vector llama_sampling_types_from_names(const std::vecto return samplers; } -std::vector llama_sampling_types_from_chars(const std::string & chars) { - std::unordered_map sampler_name_map { - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_K), LLAMA_SAMPLER_TYPE_TOP_K }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TFS_Z), LLAMA_SAMPLER_TYPE_TFS_Z }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TYPICAL_P), LLAMA_SAMPLER_TYPE_TYPICAL_P }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TOP_P), LLAMA_SAMPLER_TYPE_TOP_P }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_MIN_P), LLAMA_SAMPLER_TYPE_MIN_P }, - { llama_sampling_type_to_chr(LLAMA_SAMPLER_TYPE_TEMPERATURE), LLAMA_SAMPLER_TYPE_TEMPERATURE } +std::vector llama_sampling_types_from_chars(const std::string & chars) { + std::unordered_map sampler_name_map { + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_K), LLAMA_CONSTRAINT_TYPE_TOP_K }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TFS_Z), LLAMA_CONSTRAINT_TYPE_TFS_Z }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TYPICAL_P), LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_P), LLAMA_CONSTRAINT_TYPE_TOP_P }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_MIN_P), LLAMA_CONSTRAINT_TYPE_MIN_P }, + { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TEMPERATURE), LLAMA_CONSTRAINT_TYPE_TEMPERATURE } }; - std::vector samplers; + std::vector samplers; samplers.reserve(chars.size()); for (const auto & c : chars) { diff --git a/common/sampling.h b/common/sampling.h index b96bbce1ce869..365b7639acf52 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -6,7 +6,7 @@ #include // sampling parameters -typedef struct gpt_sampling_params { +struct gpt_sampling_params { uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling int32_t n_prev = 64; // number of previous tokens to remember @@ -30,13 +30,13 @@ typedef struct gpt_sampling_params { bool penalize_nl = false; // consider newlines as a repeatable token bool ignore_eos = false; - std::vector samplers = { - LLAMA_SAMPLER_TYPE_TOP_K, - LLAMA_SAMPLER_TYPE_TFS_Z, - LLAMA_SAMPLER_TYPE_TYPICAL_P, - LLAMA_SAMPLER_TYPE_TOP_P, - LLAMA_SAMPLER_TYPE_MIN_P, - LLAMA_SAMPLER_TYPE_TEMPERATURE + std::vector samplers = { + LLAMA_CONSTRAINT_TYPE_TOP_K, + LLAMA_CONSTRAINT_TYPE_TFS_Z, + LLAMA_CONSTRAINT_TYPE_TYPICAL_P, + LLAMA_CONSTRAINT_TYPE_TOP_P, + LLAMA_CONSTRAINT_TYPE_MIN_P, + LLAMA_CONSTRAINT_TYPE_TEMPERATURE }; std::string grammar; // optional BNF-like grammar to constrain sampling @@ -48,7 +48,16 @@ typedef struct gpt_sampling_params { // print the samplers into a string std::string print_samplers() const; -} gpt_sampling_params; +}; + +// TODO: implement +struct gpt_sampler { + gpt_sampling_params params; + + struct llama_constraint * grmr = nullptr; + + struct llama_sampler * 
smpl = nullptr; +}; // overload of llama_sampling_init using gpt_sampling_params struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params); @@ -72,8 +81,8 @@ llama_token llama_sampling_sample( // get a string representation of the last accepted tokens std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n); -char llama_sampling_type_to_chr(enum llama_sampler_type sampler_type); -std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type); +char llama_sampling_type_to_chr(enum llama_constraint_type sampler_type); +std::string llama_sampling_type_to_str(enum llama_constraint_type sampler_type); -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector llama_sampling_types_from_chars(const std::string & chars); +std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector llama_sampling_types_from_chars(const std::string & chars); diff --git a/include/llama.h b/include/llama.h index 76c1aaf98937f..7225874f7e878 100644 --- a/include/llama.h +++ b/include/llama.h @@ -46,6 +46,7 @@ #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 2 +// TODO: remove before merge #define LLAMA_MAX_SAMPLERS 16 #ifdef __cplusplus @@ -209,14 +210,15 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; - enum llama_sampler_type { - LLAMA_SAMPLER_TYPE_NONE = 0, - LLAMA_SAMPLER_TYPE_TOP_K = 1, - LLAMA_SAMPLER_TYPE_TOP_P = 2, - LLAMA_SAMPLER_TYPE_MIN_P = 3, - LLAMA_SAMPLER_TYPE_TFS_Z = 4, - LLAMA_SAMPLER_TYPE_TYPICAL_P = 5, - LLAMA_SAMPLER_TYPE_TEMPERATURE = 6, + // TODO: move to common, rename to gpt_constraint_type + enum llama_constraint_type { + LLAMA_CONSTRAINT_TYPE_NONE = 0, + LLAMA_CONSTRAINT_TYPE_TOP_K = 1, + LLAMA_CONSTRAINT_TYPE_TOP_P = 2, + LLAMA_CONSTRAINT_TYPE_MIN_P = 3, + LLAMA_CONSTRAINT_TYPE_TFS_Z = 4, + LLAMA_CONSTRAINT_TYPE_TYPICAL_P = 5, + LLAMA_CONSTRAINT_TYPE_TEMPERATURE = 6, }; typedef struct llama_token_data { @@ -382,6 +384,7 @@ extern "C" { float bias; } llama_logit_bias; + // TODO: remove before merge // parameters for sampling the logits typedef struct llama_sampling_params { uint32_t seed; // the seed used to initialize llama_sampling_context @@ -406,7 +409,7 @@ extern "C" { // samplers int32_t n_samplers; - enum llama_sampler_type samplers[LLAMA_MAX_SAMPLERS]; + enum llama_constraint_type samplers[LLAMA_MAX_SAMPLERS]; // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. bool penalize_nl; // consider newlines as a repeatable token @@ -414,7 +417,11 @@ extern "C" { } llama_sampling_params; typedef struct llama_sampler_params { - uint32_t seed; // the seed used to initialize the rng of the sampler + uint32_t seed; // the seed used to initialize the rng of the sampler + + int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + float mirostat_tau; // target entropy + float mirostat_eta; // learning rate // TODO: add type of sampler: greedy, dist, mirostat, etc. 
} llama_sampler_params; @@ -1176,6 +1183,8 @@ extern "C" { typedef void * llama_constraint_context_t; struct llama_constraint_i { + // TODO: add name API + void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); // required void (*reset) (struct llama_constraint * cnstr); // can be NULL @@ -1184,6 +1193,8 @@ extern "C" { // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); + + // TODO: add API to get timing stats }; struct llama_constraint { @@ -1191,14 +1202,28 @@ extern "C" { llama_constraint_context_t ctx; }; - LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); - LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); - LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root); + LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); + LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root); + + LLAMA_API struct llama_constraint * llama_constraint_init_penalties( + struct llama_model * model, + int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat, // 1.0 = disabled + float penalty_freq, // 0.0 = disabled + float penalty_present, // 0.0 = disabled + bool penalize_nl, // consider newlines as a repeatable token + bool ignore_eos); // ignore the end-of-sequence token + + LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( + struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); // do not call if used with llama_sampler_add_constraint LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); @@ -1209,19 +1234,47 @@ extern "C" { // samplers - LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params); + LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); LLAMA_API void 
llama_sampler_reset( struct llama_sampler * smpl); + LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits); + + LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl); + + // TODO: should this take ownership so the user does not need to call llama_constraint_free // or should just make a reference to the constraint so that it can be reused in multiple llama_sampler? // // seems better to take the ownership, otherwise the copying of the sampler will be more complicated LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); - LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); - LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i); + LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); + LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * candidates); + + LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates); + + /// @details Get the number of accepted tokens so far (max of n_prev) + LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); + + /// @details Get the ith accepted token + /// @param ith [0, n_prev), ith == 0 is the last accepted token. + /// returns LLAMA_TOKEN_NULL if ith is out of bounds + LLAMA_API llama_token llama_sampler_prev( + const struct llama_sampler * smpl, + int32_t ith); + + /// @details Get the last accepted token + /// Same as llama_sampler_prev(smpl, 0) + /// returns LLAMA_TOKEN_NULL if there are no accepted tokens + LLAMA_API llama_token llama_sampler_last(const struct llama_sampler * smpl); + + // TODO: extend in the future + //LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t i); + //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...); // // Model split diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index f6328b60132d4..36bbc0c1bad19 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -676,7 +676,14 @@ struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_top_k_i; - result->ctx = new llama_constraint_context_top_k{k, min_keep}; + result->ctx = new llama_constraint_context_top_k; + + auto * ctx = (llama_constraint_context_top_k *) result->ctx; + + *ctx = { + /*.k =*/ k, + /*.min_keep =*/ min_keep, + }; return result; } @@ -691,7 +698,7 @@ struct llama_constraint_context_top_p { static struct llama_constraint_i llama_constraint_top_p_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_top_p * ctx = (llama_constraint_context_top_p *) cnstr->ctx; + auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx; llama_sampling_top_p_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, @@ -713,7 +720,14 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t 
min_k struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_top_p_i; - result->ctx = new llama_constraint_context_top_p{p, min_keep}; + result->ctx = new llama_constraint_context_top_p; + + auto * ctx = (llama_constraint_context_top_p *) result->ctx; + + *ctx = { + /*.p =*/ p, + /*.min_keep =*/ min_keep, + }; return result; } @@ -728,7 +742,7 @@ struct llama_constraint_context_min_p { static struct llama_constraint_i llama_constraint_min_p_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_min_p * ctx = (llama_constraint_context_min_p *) cnstr->ctx; + auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx; llama_sampling_min_p_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, @@ -750,7 +764,14 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_min_p_i; - result->ctx = new llama_constraint_context_min_p{p, min_keep}; + result->ctx = new llama_constraint_context_min_p; + + auto * ctx = (llama_constraint_context_min_p *) result->ctx; + + *ctx = { + /*.p =*/ p, + /*.min_keep =*/ min_keep, + }; return result; } @@ -765,7 +786,7 @@ struct llama_constraint_context_tail_free { static struct llama_constraint_i llama_constraint_tail_free_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_tail_free * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; + auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; llama_sampling_tail_free_impl(candidates, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, @@ -787,7 +808,14 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_tail_free_i; - result->ctx = new llama_constraint_context_tail_free{z, min_keep}; + result->ctx = new llama_constraint_context_tail_free; + + auto * ctx = (llama_constraint_context_tail_free *) result->ctx; + + *ctx = { + /*.z =*/ z, + /*.min_keep =*/ min_keep, + }; return result; } @@ -802,7 +830,7 @@ struct llama_constraint_context_typical { static struct llama_constraint_i llama_constraint_typical_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_typical * ctx = (llama_constraint_context_typical *) cnstr->ctx; + auto * ctx = (llama_constraint_context_typical *) cnstr->ctx; llama_sampling_typical_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, @@ -824,7 +852,14 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_typical_i; - result->ctx = new llama_constraint_context_typical{p, min_keep}; + result->ctx = new llama_constraint_context_typical; + + auto * ctx = (llama_constraint_context_typical *) result->ctx; + + *ctx = { + /*.p =*/ p, + /*.min_keep =*/ min_keep, + }; return result; } @@ -838,7 +873,7 @@ struct llama_constraint_context_temp { static struct llama_constraint_i llama_constraint_temp_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_temp * ctx = (llama_constraint_context_temp *) 
cnstr->ctx; + auto * ctx = (llama_constraint_context_temp *) cnstr->ctx; llama_sampling_temp_impl(candidates, ctx->temp); }, /* .reset = */ nullptr, @@ -860,7 +895,13 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) { struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_temp_i; - result->ctx = new llama_constraint_context_temp{temp}; + result->ctx = new llama_constraint_context_temp; + + auto * ctx = (llama_constraint_context_temp *) result->ctx; + + *ctx = { + /*.temp =*/ temp, + }; return result; } @@ -876,7 +917,7 @@ struct llama_constraint_context_temp_ext { static struct llama_constraint_i llama_constraint_temp_ext_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_temp_ext * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; + auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; if (ctx->delta > 0) { const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); const float temp_max = ctx->temp + ctx->delta; @@ -905,7 +946,15 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float struct llama_constraint * result = new llama_constraint; result->iface = &llama_constraint_temp_ext_i; - result->ctx = new llama_constraint_context_temp_ext{temp, delta, exponent}; + result->ctx = new llama_constraint_context_temp_ext; + + auto * ctx = (llama_constraint_context_temp_ext *) result->ctx; + + *ctx = { + /*.temp =*/ temp, + /*.delta =*/ delta, + /*.exponent =*/ exponent, + }; return result; } @@ -920,15 +969,20 @@ struct llama_constraint_context_grammar { }; static struct llama_constraint_i llama_constraint_grammar_i = { - /* .accept = */ nullptr, + /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { + auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + if (ctx->grammar) { + llama_grammar_accept_impl(*ctx->grammar, token); + } + }, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { llama_sampling_grammar_impl(candidates, *ctx->grammar); } }, /* .reset = */ [](struct llama_constraint * cnstr) { - llama_constraint_context_grammar * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { llama_grammar_free_impl(ctx->grammar); ctx->grammar = nullptr; @@ -973,20 +1027,173 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ auto * ctx = (llama_constraint_context_grammar *) result->ctx; if (grammar_str != nullptr && grammar_str[0] != '\0') { - ctx->grammar_str = grammar_str; - ctx->grammar_root = grammar_root; - - ctx->grammar = llama_grammar_init_impl(&vocab, grammar_str, grammar_root); + *ctx = { + /*.grammar_str = */ grammar_str, + /*.grammar_root = */ grammar_root, + /*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), + }; } else { - ctx->grammar_str.clear(); - ctx->grammar_root.clear(); + *ctx = { + /*.grammar_str = */ {}, + /*.grammar_root = */ {}, + /*.grammar = */ nullptr, + }; + } + + return result; +} + +// penalties + +struct llama_constraint_context_penalties { + const struct llama_vocab * vocab; + + int32_t penalty_last_n; + float penalty_repeat; + float penalty_freq; + float penalty_present; + + bool 
penalize_nl; + bool ignore_eos; + + ring_buffer prev; +}; + +static struct llama_constraint_i llama_constraint_penalties_i = { + /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { + auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + ctx->prev.push_back(token); + }, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + + GGML_ASSERT(candidates->size == ctx->vocab->n_vocab && candidates->sorted == false && "the 'penalties' constraint must be applied on the full vocabulary"); + + if (ctx->ignore_eos) { + candidates->data[ctx->vocab->special_eos_id].logit = -INFINITY; + } + + if ((ctx->penalty_last_n == 0) || + (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) { + return; + } + + const float nl_logit = !ctx->penalize_nl ? candidates->data[ctx->vocab->linefeed_id].logit : -INFINITY; + + // Create a frequency map to count occurrences of each token in last_tokens + // TODO: optimize this by maintaining the token count in the constraint context + llama_token_cnt token_count; + for (int i = 0; i < ctx->penalty_last_n; ++i) { + token_count[ctx->prev.rat(i)]++; + } + + llama_sampling_penalties_impl(candidates, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); + + if (!ctx->penalize_nl) { + // restore the logit of the newline token if it was penalized + candidates->data[ctx->vocab->linefeed_id].logit = nl_logit; + } + }, + /* .reset = */ [](struct llama_constraint * cnstr) { + auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + ctx->prev.clear(); + }, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_penalties; + const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_penalties *) cnstr->ctx; + + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_penalties *) cnstr->ctx; + } + delete cnstr; + } +}; + +struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { + GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL); + GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL); + + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_penalties_i; + result->ctx = new llama_constraint_context_penalties; + + auto * ctx = (llama_constraint_context_penalties *) result->ctx; + + *ctx = { + /*.vocab = */ &vocab, + /*.penalty_last_n = */ penalty_last_n, + /*.penalty_repeat = */ penalty_repeat, + /*.penalty_freq = */ penalty_freq, + /*.penalty_present = */ penalty_present, + /*.penalize_nl = */ penalize_nl, + /*.ignore_eos = */ ignore_eos, + /*.prev = */ {}, + }; + + return result; +} + +// logit-bias - ctx->grammar = nullptr; +struct llama_constraint_context_logit_bias { + const struct llama_vocab * vocab; + + std::vector logit_bias; +}; + +static struct llama_constraint_i llama_constraint_logit_bias_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx; + + GGML_ASSERT(candidates->size 
== ctx->vocab->n_vocab && candidates->sorted == false && "the 'logit_bias' constraint must be applied on the full vocabulary"); + + for (const auto & lb : ctx->logit_bias) { + candidates->data[lb.token].logit += lb.bias; + } + }, + /* .reset = */ nullptr, + /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { + cnstr->ctx = new llama_constraint_context_logit_bias; + const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr_src->ctx; + auto * ctx_dst = ( llama_constraint_context_logit_bias *) cnstr->ctx; + + *ctx_dst = *ctx_src; + }, + /* .free = */ [](struct llama_constraint * cnstr) { + if (cnstr->ctx) { + delete (llama_constraint_context_logit_bias *) cnstr->ctx; + } + delete cnstr; } +}; + +struct llama_constraint * llama_constraint_init_logit_bias_impl( + const struct llama_vocab & vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + struct llama_constraint * result = new llama_constraint; + + result->iface = &llama_constraint_logit_bias_i; + result->ctx = new llama_constraint_context_logit_bias; + + auto * ctx = (llama_constraint_context_logit_bias *) result->ctx; + + *ctx = { + /*.vocab = */ &vocab, + /*.logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + }; return result; } +//////////////////////////////////////// + void llama_constraint_free_impl(struct llama_constraint * cnstr) { if (cnstr->iface->free && cnstr) { cnstr->iface->free(cnstr); @@ -1012,10 +1219,11 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) { // samplers -struct llama_sampler * llama_sampler_init_impl(struct llama_sampler_params params) { +struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) { auto * result = new llama_sampler; result->params = params; + result->vocab = &vocab; result->rng.seed(params.seed); @@ -1075,3 +1283,22 @@ void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { llama_constraint_accept_impl(*cnstr, token); } } + +void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * candidates) { + for (auto * cnstr : smpl.constraints) { + llama_constraint_apply_impl(*cnstr, candidates); + } +} + +llama_token llama_sampler_prev_impl(const struct llama_sampler & smpl, int ith) { + if (ith < 0 || ith >= (int) smpl.prev.size()) { + return LLAMA_TOKEN_NULL; + } + + return smpl.prev.rat(ith); +} + +int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { + return smpl.prev.size(); +} + diff --git a/src/llama-sampling.h b/src/llama-sampling.h index ed18da10d8a35..7de37c89e5817 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -10,6 +10,7 @@ struct llama_grammar; using llama_token_cnt = std::unordered_map; +// TODO: remove before merge struct llama_sampling { llama_sampling(const struct llama_vocab & vocab); ~llama_sampling(); @@ -27,7 +28,7 @@ struct llama_sampling { const struct llama_vocab & vocab; - std::vector samplers; + std::vector samplers; ring_buffer prev; @@ -120,7 +121,25 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); struct llama_constraint * llama_constraint_init_temp_impl (float t); struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); -struct llama_constraint * llama_constraint_init_grammar_impl (const struct llama_vocab & vocab, const char * grammar_str, const 
char * grammar_root); + +struct llama_constraint * llama_constraint_init_grammar_impl ( + const struct llama_vocab & vocab, + const char * grammar_str, + const char * grammar_root); + +struct llama_constraint * llama_constraint_init_penalties_impl( + const struct llama_vocab & vocab, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos); + + LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias_impl( + const struct llama_vocab & vocab, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); void llama_constraint_free_impl(struct llama_constraint * cnstr); @@ -133,15 +152,22 @@ void llama_constraint_reset_impl (struct llama_constraint & cnstr); struct llama_sampler { llama_sampler_params params; + const struct llama_vocab * vocab; + // state std::mt19937 rng; - // TODO: move to a standalone penalty constraint? + float mirostat_mu; + ring_buffer prev; std::vector constraints; + std::vector cur; + + llama_token_data_array cur_p; + // timing mutable int64_t t_sample_us = 0; @@ -149,10 +175,15 @@ struct llama_sampler { mutable int32_t n_sample = 0; }; -struct llama_sampler * llama_sampler_init_impl ( struct llama_sampler_params params); +struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); void llama_sampler_free_impl ( struct llama_sampler * smpl); struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); void llama_sampler_reset_impl( struct llama_sampler & smpl); void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr); -void llama_sampler_accept_impl (struct llama_sampler & smpl, llama_token token); + +void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token); +void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * candidates); + +llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); +int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); diff --git a/src/llama.cpp b/src/llama.cpp index a47ad0103860e..4060fa1de420d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17938,6 +17938,9 @@ struct llama_context_params llama_context_default_params() { struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.mirostat =*/ 0, + /*.mirostat_tau =*/ 5.00f, + /*.mirostat_eta =*/ 0.10f, }; return result; @@ -17965,7 +17968,7 @@ struct llama_sampling_params llama_sampling_default_params() { /*.mirostat_tau =*/ 5.00f, /*.mirostat_eta =*/ 0.10f, /*.n_samplers =*/ 3, - /*.samplers =*/ { LLAMA_SAMPLER_TYPE_TEMPERATURE, LLAMA_SAMPLER_TYPE_TOP_K, LLAMA_SAMPLER_TYPE_TOP_P, }, + /*.samplers =*/ { LLAMA_CONSTRAINT_TYPE_TEMPERATURE, LLAMA_CONSTRAINT_TYPE_TOP_K, LLAMA_CONSTRAINT_TYPE_TOP_P, }, /*.penalize_nl =*/ false, /*.ignore_eos =*/ false, }; @@ -20916,12 +20919,12 @@ llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data } else { for (const auto & sampler : smpl->samplers) { switch (sampler) { - case LLAMA_SAMPLER_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; - case LLAMA_SAMPLER_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; - case LLAMA_SAMPLER_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; - case 
LLAMA_SAMPLER_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; - case LLAMA_SAMPLER_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; - case LLAMA_SAMPLER_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; + case LLAMA_CONSTRAINT_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; + case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; default : break; } } @@ -21007,6 +21010,24 @@ struct llama_constraint * llama_constraint_init_grammar(struct llama_model * mod return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); } +struct llama_constraint * llama_constraint_init_penalties( + struct llama_model * model, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos) { + return llama_constraint_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); +} + +LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( + struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + return llama_constraint_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); +} + void llama_constraint_free(struct llama_constraint * cnstr) { if (cnstr == nullptr) { return; @@ -21027,8 +21048,8 @@ void llama_constraint_reset(struct llama_constraint * cnstr) { llama_constraint_reset_impl(*cnstr); } -struct llama_sampler * llama_sampler_init(struct llama_sampler_params params) { - return llama_sampler_init_impl(params); +struct llama_sampler * llama_sampler_init(const struct llama_model * model, struct llama_sampler_params params) { + return llama_sampler_init_impl(model->vocab, params); } void llama_sampler_free(struct llama_sampler * smpl) { @@ -21047,6 +21068,22 @@ void llama_sampler_reset(struct llama_sampler * smpl) { llama_sampler_reset_impl(*smpl); } +void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) { + const int n_vocab = smpl->vocab->n_vocab; + + smpl->cur.resize(n_vocab); + + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; +} + +llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl) { + return &smpl->cur_p; +} + void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) { llama_sampler_add_constraint_impl(*smpl, cnstr); } @@ -21055,10 +21092,90 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { llama_sampler_accept_impl(*smpl, token); } -llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) { - GGML_ABORT("not implemented"); +void llama_sampler_apply(struct llama_sampler * smpl, 
llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + llama_sampler_apply_impl(*smpl, candidates); } +llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + const auto type = smpl->params.mirostat; + + llama_token res; + + if (type == 1) { + res = llama_sampling_sample_mirostat_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + 100, + smpl->vocab->n_vocab, + smpl->mirostat_mu); + } else if (type == 2) { + res = llama_sampling_sample_mirostat_v2_impl(candidates, + smpl->rng, + smpl->params.mirostat_tau, + smpl->params.mirostat_eta, + smpl->mirostat_mu); + } else { + GGML_ABORT("invalid mirostat type: %d", type); + } + + smpl->n_sample++; + + return res; +} + +llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + auto res = llama_sampling_sample_greedy_impl(candidates); + + smpl->n_sample++; + + return res; +} + +llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_sample_us); + + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + + auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); + + smpl->n_sample++; + + return res; +} + +int llama_sampler_n_prev(const struct llama_sampler * smpl) { + return llama_sampler_n_prev_impl(*smpl); +} + +llama_token llama_sampler_prev(const struct llama_sampler * smpl, int32_t ith) { + return llama_sampler_prev_impl(*smpl, ith); +} + +llama_token llama_sampler_last(const struct llama_sampler * smpl) { + return llama_sampler_prev_impl(*smpl, 0); +} + +//llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) { +// GGML_ABORT("not implemented"); +//} + // // model split // From 09ceb68caa27ed07245e23d0c3c843ce46927e3a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 3 Sep 2024 17:27:40 +0300 Subject: [PATCH 09/47] cont : add comments [no ci] --- include/llama.h | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/include/llama.h b/include/llama.h index 7225874f7e878..4dd5348a8a903 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1048,6 +1048,7 @@ extern "C" { // // Sampling API + // TODO: remove before merge // // TODO: llama_model should become llama_vocab @@ -1175,6 +1176,23 @@ extern "C" { // // Sampling v2 API // + // - Constraints + // The llama_constraint object works on a set of candidate tokens (llama_token_data_array), by modifying their + // logits and probabilities inplace. The interface is abstracted so that users can implement custom constraints. + // + // - Samplers + // The llama_sampler samples a token based on the candidate token probabilities. Before the actual sampling, the + // sampler can apply a sequence of constraints to the candidate tokens. + // + // The llama_sampler object contains the entire sampling information: + // + // - RNG state (seed and generator) + // - Custom set of constraints (see llama_sampler_add_constraint) + // - Sampling method (greedy, dist, mirostat) + // - Previous tokens + // + // In the future, it will be utilized offload the sampling to the backends (e.g. GPU). 
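+ // Example (illustrative sketch of the intended usage; assumes an already initialized
+ // llama_model * model and llama_context * ctx, and uses the constraint/sampler calls
+ // introduced in this patch series - exact signatures may still change):
+ //
+ //     struct llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());
+ //
+ //     llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
+ //     llama_sampler_add_constraint(smpl, llama_constraint_init_temp_ext(0.8f, 0.0f, 1.0f));
+ //
+ //     // after llama_decode(), feed the logits of the position to sample from:
+ //     llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, -1));
+ //
+ //     const llama_token id = llama_sampler_sample_dist(smpl, nullptr); // nullptr -> use the internal candidates
+ //     llama_sampler_accept(smpl, id);
+ //
+ //     llama_sampler_free(smpl);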
+ // // constraints @@ -1182,6 +1200,7 @@ extern "C" { typedef void * llama_constraint_context_t; + // user code can implement the interface below in order to create custom llama_constraint struct llama_constraint_i { // TODO: add name API @@ -1263,9 +1282,7 @@ extern "C" { /// @details Get the ith accepted token /// @param ith [0, n_prev), ith == 0 is the last accepted token. /// returns LLAMA_TOKEN_NULL if ith is out of bounds - LLAMA_API llama_token llama_sampler_prev( - const struct llama_sampler * smpl, - int32_t ith); + LLAMA_API llama_token llama_sampler_prev(const struct llama_sampler * smpl, int32_t ith); /// @details Get the last accepted token /// Same as llama_sampler_prev(smpl, 0) From 1e8e26c155cbf480509d6efa23f215ab84bce069 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 10:03:14 +0300 Subject: [PATCH 10/47] cont : leaner constraint initialization [no ci] --- src/llama-sampling.cpp | 168 ++++++++++++++++------------------------- 1 file changed, 66 insertions(+), 102 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 36bbc0c1bad19..f21b5fd55b3ae 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -673,16 +673,12 @@ static struct llama_constraint_i llama_constraint_top_k_i = { }; struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) { - struct llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_top_k_i; - result->ctx = new llama_constraint_context_top_k; - - auto * ctx = (llama_constraint_context_top_k *) result->ctx; - - *ctx = { - /*.k =*/ k, - /*.min_keep =*/ min_keep, + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_top_k_i, + /* .ctx = */ new llama_constraint_context_top_k { + /*.k =*/ k, + /*.min_keep =*/ min_keep, + }, }; return result; @@ -717,16 +713,12 @@ static struct llama_constraint_i llama_constraint_top_p_i = { }; struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) { - struct llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_top_p_i; - result->ctx = new llama_constraint_context_top_p; - - auto * ctx = (llama_constraint_context_top_p *) result->ctx; - - *ctx = { - /*.p =*/ p, - /*.min_keep =*/ min_keep, + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_top_p_i, + /* .ctx = */ new llama_constraint_context_top_p { + /*.p =*/ p, + /*.min_keep =*/ min_keep, + }, }; return result; @@ -761,16 +753,12 @@ static struct llama_constraint_i llama_constraint_min_p_i = { }; struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) { - struct llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_min_p_i; - result->ctx = new llama_constraint_context_min_p; - - auto * ctx = (llama_constraint_context_min_p *) result->ctx; - - *ctx = { - /*.p =*/ p, - /*.min_keep =*/ min_keep, + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_min_p_i, + /* .ctx = */ new llama_constraint_context_min_p { + /*.p =*/ p, + /*.min_keep =*/ min_keep, + }, }; return result; @@ -805,16 +793,12 @@ static struct llama_constraint_i llama_constraint_tail_free_i = { }; struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) { - struct llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_tail_free_i; - result->ctx = new 
llama_constraint_context_tail_free; - - auto * ctx = (llama_constraint_context_tail_free *) result->ctx; - - *ctx = { - /*.z =*/ z, - /*.min_keep =*/ min_keep, + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_tail_free_i, + /* .ctx = */ new llama_constraint_context_tail_free { + /*.z =*/ z, + /*.min_keep =*/ min_keep, + }, }; return result; @@ -849,16 +833,12 @@ static struct llama_constraint_i llama_constraint_typical_i = { }; struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) { - struct llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_typical_i; - result->ctx = new llama_constraint_context_typical; - - auto * ctx = (llama_constraint_context_typical *) result->ctx; - - *ctx = { - /*.p =*/ p, - /*.min_keep =*/ min_keep, + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_typical_i, + /* .ctx = */ new llama_constraint_context_typical { + /*.p =*/ p, + /*.min_keep =*/ min_keep, + }, }; return result; @@ -892,15 +872,11 @@ static struct llama_constraint_i llama_constraint_temp_i = { }; struct llama_constraint * llama_constraint_init_temp_impl(float temp) { - struct llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_temp_i; - result->ctx = new llama_constraint_context_temp; - - auto * ctx = (llama_constraint_context_temp *) result->ctx; - - *ctx = { - /*.temp =*/ temp, + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_temp_i, + /* .ctx = */ new llama_constraint_context_temp { + /*.temp =*/ temp, + }, }; return result; @@ -943,17 +919,13 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = { }; struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) { - struct llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_temp_ext_i; - result->ctx = new llama_constraint_context_temp_ext; - - auto * ctx = (llama_constraint_context_temp_ext *) result->ctx; - - *ctx = { - /*.temp =*/ temp, - /*.delta =*/ delta, - /*.exponent =*/ exponent, + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_temp_ext_i, + /* .ctx = */ new llama_constraint_context_temp_ext { + /*.temp =*/ temp, + /*.delta =*/ delta, + /*.exponent =*/ exponent, + }, }; return result; @@ -1019,12 +991,7 @@ static struct llama_constraint_i llama_constraint_grammar_i = { }; struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { - struct llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_grammar_i; - result->ctx = new llama_constraint_context_grammar; - - auto * ctx = (llama_constraint_context_grammar *) result->ctx; + auto * ctx = new llama_constraint_context_grammar; if (grammar_str != nullptr && grammar_str[0] != '\0') { *ctx = { @@ -1040,6 +1007,11 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ }; } + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_grammar_i, + /* .ctx = */ ctx, + }; + return result; } @@ -1117,22 +1089,18 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL); GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL); - struct 
llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_penalties_i; - result->ctx = new llama_constraint_context_penalties; - - auto * ctx = (llama_constraint_context_penalties *) result->ctx; - - *ctx = { - /*.vocab = */ &vocab, - /*.penalty_last_n = */ penalty_last_n, - /*.penalty_repeat = */ penalty_repeat, - /*.penalty_freq = */ penalty_freq, - /*.penalty_present = */ penalty_present, - /*.penalize_nl = */ penalize_nl, - /*.ignore_eos = */ ignore_eos, - /*.prev = */ {}, + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_penalties_i, + /* .ctx = */ new llama_constraint_context_penalties { + /*.vocab =*/ &vocab, + /*.penalty_last_n =*/ penalty_last_n, + /*.penalty_repeat =*/ penalty_repeat, + /*.penalty_freq =*/ penalty_freq, + /*.penalty_present =*/ penalty_present, + /*.penalize_nl =*/ penalize_nl, + /*.ignore_eos =*/ ignore_eos, + /*.prev =*/ {}, + }, }; return result; @@ -1177,16 +1145,12 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl( const struct llama_vocab & vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { - struct llama_constraint * result = new llama_constraint; - - result->iface = &llama_constraint_logit_bias_i; - result->ctx = new llama_constraint_context_logit_bias; - - auto * ctx = (llama_constraint_context_logit_bias *) result->ctx; - - *ctx = { - /*.vocab = */ &vocab, - /*.logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_logit_bias_i, + /* .ctx = */ new llama_constraint_context_logit_bias { + /*.vocab =*/ &vocab, + /*.logit_bias=*/ std::vector(logit_bias, logit_bias + n_logit_bias), + }, }; return result; From 91cbb40b2923558a6e42eebe416e0eae5b79ee89 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 11:21:37 +0300 Subject: [PATCH 11/47] cont : common/sampling use the new API [no ci] --- common/common.cpp | 24 +- common/common.h | 2 +- common/sampling.cpp | 316 +++++++++++++-------- common/sampling.h | 69 +++-- examples/main/main.cpp | 31 ++- include/llama.h | 333 ++++++++++------------ src/llama-sampling.cpp | 319 +++++++-------------- src/llama-sampling.h | 93 ++----- src/llama.cpp | 616 +++++++++++++++++++---------------------- 9 files changed, 833 insertions(+), 970 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 23d171a4d7b96..f7095c7f3c1de 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -841,15 +841,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.defrag_thold = std::stof(argv[i]); return true; } - if (arg == "--samplers") { + if (arg == "--samplers" || arg == "--constraints") { CHECK_ARG - const auto sampler_names = string_split(argv[i], ';'); - sparams.samplers = llama_sampling_types_from_names(sampler_names, true); + const auto constraint_names = string_split(argv[i], ';'); + sparams.constraints = gpt_constraint_types_from_names(constraint_names, true); return true; } if (arg == "--sampling-seq") { CHECK_ARG - sparams.samplers = llama_sampling_types_from_chars(argv[i]); + sparams.constraints = gpt_constraint_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { @@ -1706,13 +1706,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { const auto & sparams = params.sparams; - std::string sampler_type_chars; - std::string 
sampler_type_names; - for (const auto & sampler : sparams.samplers) { - sampler_type_chars += llama_sampling_type_to_chr(sampler); - sampler_type_names += llama_sampling_type_to_str(sampler) + ";"; + std::string constraint_type_chars; + std::string constraint_type_names; + for (const auto & constraint : sparams.constraints) { + constraint_type_chars += gpt_constraint_type_to_chr(constraint); + constraint_type_names += gpt_constraint_type_to_str(constraint) + ";"; } - sampler_type_names.pop_back(); + constraint_type_names.pop_back(); struct option_info { LLAMA_COMMON_ATTRIBUTE_FORMAT(4, 5) @@ -1826,9 +1826,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "sampling" }); options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed }); options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n" - "(default: %s)", sampler_type_names.c_str() }); + "(default: %s)", constraint_type_names.c_str() }); options.push_back({ "*", " --sampling-seq SEQUENCE", - "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() }); + "simplified sequence for samplers that will be used (default: %s)", constraint_type_chars.c_str() }); options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" }); options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" }); options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp }); diff --git a/common/common.h b/common/common.h index 1c4eae34a8390..3a6c8e0b5377a 100644 --- a/common/common.h +++ b/common/common.h @@ -118,7 +118,7 @@ struct gpt_params { enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings - struct gpt_sampling_params sparams; + struct gpt_sampler_params sparams; std::string model = ""; // model path std::string model_draft = ""; // draft model for speculative decoding diff --git a/common/sampling.cpp b/common/sampling.cpp index a98117cf82455..a5e76dfd41e32 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,7 +2,7 @@ #include "common.h" -std::string gpt_sampling_params::print_all() const { +std::string gpt_sampler_params::print_all() const { char result[1024]; snprintf(result, sizeof(result), @@ -16,11 +16,11 @@ std::string gpt_sampling_params::print_all() const { return std::string(result); } -std::string gpt_sampling_params::print_samplers() const { +std::string gpt_sampler_params::print_constraints() const { std::string result = "CFG -> Penalties "; if (mirostat == 0) { - for (const auto & sampler : samplers) { - const auto name = llama_sampling_type_to_str(sampler); + for (const auto & cnstr : constraints) { + const auto name = gpt_constraint_type_to_str(cnstr); if (!name.empty()) { result += "-> " + name + " "; } @@ -32,66 +32,159 @@ std::string gpt_sampling_params::print_samplers() const { return result; } -struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) { - llama_sampling_params lparams = llama_sampling_default_params(); - - lparams.seed = params.seed; - lparams.n_prev = params.n_prev; - lparams.n_probs = params.n_probs; - lparams.min_keep = 
params.min_keep; - lparams.top_k = params.top_k; - lparams.top_p = params.top_p; - lparams.min_p = params.min_p; - lparams.tfs_z = params.tfs_z; - lparams.typ_p = params.typ_p; - lparams.temp = params.temp; - lparams.dynatemp_range = params.dynatemp_range; - lparams.dynatemp_exponent = params.dynatemp_exponent; - lparams.penalty_last_n = params.penalty_last_n; - lparams.penalty_repeat = params.penalty_repeat; - lparams.penalty_freq = params.penalty_freq; - lparams.penalty_present = params.penalty_present; - lparams.mirostat = params.mirostat; - lparams.mirostat_tau = params.mirostat_tau; - lparams.mirostat_eta = params.mirostat_eta; - lparams.penalize_nl = params.penalize_nl; - lparams.ignore_eos = params.ignore_eos; - - lparams.n_samplers = params.samplers.size(); - for (int i = 0; i < lparams.n_samplers; i++) { - lparams.samplers[i] = params.samplers[i]; +struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { + gpt_sampler * result = new gpt_sampler(); + + llama_sampler_params lparams = llama_sampler_default_params(); + + lparams.seed = params.seed; + lparams.mirostat = params.mirostat; + lparams.mirostat_tau = params.mirostat_tau; + lparams.mirostat_eta = params.mirostat_eta; + + result->smpl = llama_sampler_init(model, lparams); + + llama_sampler_add_constraint(result->smpl, llama_constraint_init_logit_bias( + model, + params.logit_bias.size(), + params.logit_bias.data())); + + llama_sampler_add_constraint(result->smpl, llama_constraint_init_penalties( + model, + params.penalty_last_n, + params.penalty_repeat, + params.penalty_freq, + params.penalty_present, + params.penalize_nl, + params.ignore_eos)); + + for (const auto & cnstr : params.constraints) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TOP_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_MIN_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TFS_Z: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + default: + GGML_ASSERT(false && "unknown constraint type"); + } + } + + result->grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"); + + return result; +} + +void gpt_sampler_free(struct gpt_sampler * gsmpl) { + if (gsmpl) { + llama_constraint_free(gsmpl->grmr); + llama_sampler_free(gsmpl->smpl); + + delete gsmpl; } +} - struct llama_sampling * result = llama_sampling_init(model, lparams); +struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl) { + gpt_sampler * result = new gpt_sampler(); - llama_sampling_set_grammar (result, params.grammar.c_str(), "root"); - llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data()); + result->grmr = llama_constraint_cp(gsmpl->grmr); + result->smpl = llama_sampler_cp(gsmpl->smpl); return result; } -void 
llama_sampling_cp(llama_sampling * src, llama_sampling *& dst) { - if (dst) { - llama_sampling_free(dst); +void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar) { + if (apply_grammar) { + llama_constraint_accept(gsmpl->grmr, token); + } + + llama_sampler_accept(gsmpl->smpl, token); +} + +void gpt_sampler_reset (struct gpt_sampler * gsmpl) { + llama_constraint_reset(gsmpl->grmr); + + llama_sampler_reset(gsmpl->smpl); +} + +llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { + return llama_sampler_last(gsmpl->smpl); +} + +static llama_token gpt_sampler_sample( + struct llama_sampler * smpl, + struct llama_token_data_array * cur_p, + float temp, + int mirostat, + int n_probs) { + GGML_ASSERT(cur_p != nullptr && "candidates array must be provided"); + + llama_token res = 0; + + if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) { + // greedy sampling, with probs + res = llama_sampler_sample_greedy(smpl, cur_p, true); + } else if (temp == 0.0f) { + // greedy sampling, no probs + res = llama_sampler_sample_greedy(smpl, cur_p, false); + } else { + llama_sampler_apply(smpl, cur_p); + + if (mirostat != 0) { + res = llama_sampler_sample_mirostat(smpl, cur_p); + } else { + res = llama_sampler_sample_dist(smpl, cur_p); + + //{ + // const int n_top = 10; + // LOG("top %d candidates:\n", n_top); + + // for (int i = 0; i < n_top; i++) { + // const llama_token id = cur_p.data[i].id; + // (void)id; // To avoid a warning that id is unused when logging is disabled. + // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); + // } + //} + + //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); + } } - dst = llama_sampling_cp(src); + return res; } -llama_token llama_sampling_sample( - struct llama_sampling * smpl, +llama_token gpt_sampler_sample( + struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) { - llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + const auto & params = gsmpl->params; + + auto & grmr = gsmpl->grmr; + auto & smpl = gsmpl->smpl; + + llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); // first, sample the token without any grammar constraints - const llama_token id = llama_sampling_sample(smpl, nullptr); + const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs); // create an array with a single token data element for the sampled id llama_token_data single_token_data = { id, 1.0f, 0.0f }; llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; - llama_sampling_grammar(smpl, &single_token_data_array); + llama_constraint_apply(grmr, &single_token_data_array); // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; @@ -100,15 +193,18 @@ llama_token llama_sampling_sample( } // if the token is not valid, sample again, after applying the grammar constraints - llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + auto * cur_p = llama_sampler_get_candidates(smpl); - llama_sampling_grammar(smpl, nullptr); + llama_constraint_apply(grmr, cur_p); - return llama_sampling_sample(smpl, nullptr); + return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); } -std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) { - n 
= std::min(n, llama_sampling_n_prev(smpl)); +std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { + auto & smpl = gsmpl->smpl; + + n = std::min(n, llama_sampler_n_prev(smpl)); if (n <= 0) { return ""; @@ -118,7 +214,7 @@ std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_m result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab for (int i = n - 1; i >= 0; i--) { - const llama_token id = llama_sampling_prev(smpl, i); + const llama_token id = llama_sampler_prev(smpl, i); GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); @@ -128,95 +224,95 @@ std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_m return result; } -char llama_sampling_type_to_chr(llama_constraint_type sampler) { - switch (sampler) { - case LLAMA_CONSTRAINT_TYPE_TOP_K: return 'k'; - case LLAMA_CONSTRAINT_TYPE_TFS_Z: return 'f'; - case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return 'y'; - case LLAMA_CONSTRAINT_TYPE_TOP_P: return 'p'; - case LLAMA_CONSTRAINT_TYPE_MIN_P: return 'm'; - case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return 't'; +char gpt_constraint_type_to_chr(enum gpt_constraint_type cnstr) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: return 'k'; + case GPT_CONSTRAINT_TYPE_TFS_Z: return 'f'; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: return 'y'; + case GPT_CONSTRAINT_TYPE_TOP_P: return 'p'; + case GPT_CONSTRAINT_TYPE_MIN_P: return 'm'; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: return 't'; default : return '?'; } } -std::string llama_sampling_type_to_str(llama_constraint_type sampler) { - switch (sampler) { - case LLAMA_CONSTRAINT_TYPE_TOP_K: return "top_k"; - case LLAMA_CONSTRAINT_TYPE_TFS_Z: return "tfs_z"; - case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: return "typ_p"; - case LLAMA_CONSTRAINT_TYPE_TOP_P: return "top_p"; - case LLAMA_CONSTRAINT_TYPE_MIN_P: return "min_p"; - case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: return "temperature"; +std::string gpt_constraint_type_to_str(enum gpt_constraint_type cnstr) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: return "top_k"; + case GPT_CONSTRAINT_TYPE_TFS_Z: return "tfs_z"; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: return "typ_p"; + case GPT_CONSTRAINT_TYPE_TOP_P: return "top_p"; + case GPT_CONSTRAINT_TYPE_MIN_P: return "min_p"; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: return "temperature"; default : return ""; } } -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names) { - std::unordered_map sampler_canonical_name_map { - { "top_k", LLAMA_CONSTRAINT_TYPE_TOP_K }, - { "top_p", LLAMA_CONSTRAINT_TYPE_TOP_P }, - { "typ_p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "min_p", LLAMA_CONSTRAINT_TYPE_MIN_P }, - { "tfs_z", LLAMA_CONSTRAINT_TYPE_TFS_Z }, - { "temperature", LLAMA_CONSTRAINT_TYPE_TEMPERATURE }, +std::vector gpt_constraint_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map constraint_canonical_name_map { + { "top_k", GPT_CONSTRAINT_TYPE_TOP_K }, + { "top_p", GPT_CONSTRAINT_TYPE_TOP_P }, + { "typ_p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "min_p", GPT_CONSTRAINT_TYPE_MIN_P }, + { "tfs_z", GPT_CONSTRAINT_TYPE_TFS_Z }, + { "temperature", GPT_CONSTRAINT_TYPE_TEMPERATURE }, }; - // since samplers names are written multiple ways + // since constraints names are written multiple ways // make it ready for both system names and input names - std::unordered_map sampler_alt_name_map { - { "top-k", LLAMA_CONSTRAINT_TYPE_TOP_K }, - { 
"top-p", LLAMA_CONSTRAINT_TYPE_TOP_P }, - { "nucleus", LLAMA_CONSTRAINT_TYPE_TOP_P }, - { "typical-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "typical", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "typ-p", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "typ", LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { "min-p", LLAMA_CONSTRAINT_TYPE_MIN_P }, - { "tfs-z", LLAMA_CONSTRAINT_TYPE_TFS_Z }, - { "tfs", LLAMA_CONSTRAINT_TYPE_TFS_Z }, - { "temp", LLAMA_CONSTRAINT_TYPE_TEMPERATURE }, + std::unordered_map constraint_alt_name_map { + { "top-k", GPT_CONSTRAINT_TYPE_TOP_K }, + { "top-p", GPT_CONSTRAINT_TYPE_TOP_P }, + { "nucleus", GPT_CONSTRAINT_TYPE_TOP_P }, + { "typical-p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "typical", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "typ-p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "typ", GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { "min-p", GPT_CONSTRAINT_TYPE_MIN_P }, + { "tfs-z", GPT_CONSTRAINT_TYPE_TFS_Z }, + { "tfs", GPT_CONSTRAINT_TYPE_TFS_Z }, + { "temp", GPT_CONSTRAINT_TYPE_TEMPERATURE }, }; - std::vector samplers; - samplers.reserve(names.size()); + std::vector constraints; + constraints.reserve(names.size()); for (const auto & name : names) { - auto sampler = sampler_canonical_name_map.find(name); - if (sampler != sampler_canonical_name_map.end()) { - samplers.push_back(sampler->second); + auto constraint = constraint_canonical_name_map.find(name); + if (constraint != constraint_canonical_name_map.end()) { + constraints.push_back(constraint->second); } else { if (allow_alt_names) { - sampler = sampler_alt_name_map.find(name); - if (sampler != sampler_alt_name_map.end()) { - samplers.push_back(sampler->second); + constraint = constraint_alt_name_map.find(name); + if (constraint != constraint_alt_name_map.end()) { + constraints.push_back(constraint->second); } } } } - return samplers; + return constraints; } -std::vector llama_sampling_types_from_chars(const std::string & chars) { - std::unordered_map sampler_name_map { - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_K), LLAMA_CONSTRAINT_TYPE_TOP_K }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TFS_Z), LLAMA_CONSTRAINT_TYPE_TFS_Z }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TYPICAL_P), LLAMA_CONSTRAINT_TYPE_TYPICAL_P }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TOP_P), LLAMA_CONSTRAINT_TYPE_TOP_P }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_MIN_P), LLAMA_CONSTRAINT_TYPE_MIN_P }, - { llama_sampling_type_to_chr(LLAMA_CONSTRAINT_TYPE_TEMPERATURE), LLAMA_CONSTRAINT_TYPE_TEMPERATURE } +std::vector gpt_constraint_types_from_chars(const std::string & chars) { + std::unordered_map constraint_name_map { + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TOP_K), GPT_CONSTRAINT_TYPE_TOP_K }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TFS_Z), GPT_CONSTRAINT_TYPE_TFS_Z }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TYPICAL_P), GPT_CONSTRAINT_TYPE_TYPICAL_P }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TOP_P), GPT_CONSTRAINT_TYPE_TOP_P }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_MIN_P), GPT_CONSTRAINT_TYPE_MIN_P }, + { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TEMPERATURE), GPT_CONSTRAINT_TYPE_TEMPERATURE } }; - std::vector samplers; - samplers.reserve(chars.size()); + std::vector constraints; + constraints.reserve(chars.size()); for (const auto & c : chars) { - const auto sampler = sampler_name_map.find(c); - if (sampler != sampler_name_map.end()) { - samplers.push_back(sampler->second); + const auto constraint = constraint_name_map.find(c); + if (constraint != 
constraint_name_map.end()) { + constraints.push_back(constraint->second); } } - return samplers; + return constraints; } diff --git a/common/sampling.h b/common/sampling.h index 365b7639acf52..4efa4a17ce4ae 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -5,13 +5,23 @@ #include #include +enum gpt_constraint_type { + GPT_CONSTRAINT_TYPE_NONE = 0, + GPT_CONSTRAINT_TYPE_TOP_K = 1, + GPT_CONSTRAINT_TYPE_TOP_P = 2, + GPT_CONSTRAINT_TYPE_MIN_P = 3, + GPT_CONSTRAINT_TYPE_TFS_Z = 4, + GPT_CONSTRAINT_TYPE_TYPICAL_P = 5, + GPT_CONSTRAINT_TYPE_TEMPERATURE = 6, +}; + // sampling parameters -struct gpt_sampling_params { - uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling +struct gpt_sampler_params { + uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens + int32_t min_keep = 0; // 0 = disabled, otherwise constraints should return at least min_keep tokens int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float min_p = 0.05f; // 0.0 = disabled @@ -30,13 +40,13 @@ struct gpt_sampling_params { bool penalize_nl = false; // consider newlines as a repeatable token bool ignore_eos = false; - std::vector samplers = { - LLAMA_CONSTRAINT_TYPE_TOP_K, - LLAMA_CONSTRAINT_TYPE_TFS_Z, - LLAMA_CONSTRAINT_TYPE_TYPICAL_P, - LLAMA_CONSTRAINT_TYPE_TOP_P, - LLAMA_CONSTRAINT_TYPE_MIN_P, - LLAMA_CONSTRAINT_TYPE_TEMPERATURE + std::vector constraints = { + GPT_CONSTRAINT_TYPE_TOP_K, + GPT_CONSTRAINT_TYPE_TFS_Z, + GPT_CONSTRAINT_TYPE_TYPICAL_P, + GPT_CONSTRAINT_TYPE_TOP_P, + GPT_CONSTRAINT_TYPE_MIN_P, + GPT_CONSTRAINT_TYPE_TEMPERATURE }; std::string grammar; // optional BNF-like grammar to constrain sampling @@ -46,23 +56,30 @@ struct gpt_sampling_params { // print the parameters into a string std::string print_all() const; - // print the samplers into a string - std::string print_samplers() const; + // print the constraints into a string + std::string print_constraints() const; }; -// TODO: implement struct gpt_sampler { - gpt_sampling_params params; + gpt_sampler_params params; struct llama_constraint * grmr = nullptr; struct llama_sampler * smpl = nullptr; }; -// overload of llama_sampling_init using gpt_sampling_params -struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params); +// llama_sampler API overload + +struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params); + +void gpt_sampler_free(struct gpt_sampler * gsmpl); + +struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl); + +void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); +void gpt_sampler_reset (struct gpt_sampler * gsmpl); -void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst); +llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); // common sampling implementation: // @@ -71,18 +88,18 @@ void llama_sampling_cp(llama_sampling * src, llama_sampling *& dst); // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -llama_token llama_sampling_sample( - struct llama_sampling * smpl, - struct llama_context * ctx, - int idx); +llama_token 
gpt_sampler_sample( + struct gpt_sampler * gsmpl, + struct llama_context * ctx, + int idx); // helpers // get a string representation of the last accepted tokens -std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n); +std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n); -char llama_sampling_type_to_chr(enum llama_constraint_type sampler_type); -std::string llama_sampling_type_to_str(enum llama_constraint_type sampler_type); +char gpt_constraint_type_to_chr(enum gpt_constraint_type cnstr); +std::string gpt_constraint_type_to_str(enum gpt_constraint_type cnstr); -std::vector llama_sampling_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector llama_sampling_types_from_chars(const std::string & chars); +std::vector gpt_constraint_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector gpt_constraint_types_from_chars(const std::string & chars); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 296c1c687ad5b..88202b800762b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -33,7 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; -static llama_sampling ** g_smpl; +static gpt_sampler ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -106,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx, *g_smpl); + llama_print_timings(*g_ctx, (*g_smpl)->smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -193,7 +193,7 @@ int main(int argc, char ** argv) { llama_model * model = nullptr; llama_context * ctx = nullptr; - llama_sampling * smpl = nullptr; + gpt_sampler * smpl = nullptr; std::vector chat_msgs; @@ -458,7 +458,7 @@ int main(int argc, char ** argv) { } } LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str()); - LOG_TEE("sampling order: \n%s\n", sparams.print_samplers().c_str()); + LOG_TEE("sampling constr: \n%s\n", sparams.print_constraints().c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); // group-attention state @@ -525,7 +525,7 @@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - smpl = llama_sampling_init(model, sparams); + smpl = gpt_sampler_init(model, sparams); if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); @@ -681,9 +681,9 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } - const llama_token id = llama_sampling_sample(smpl, ctx, -1); + const llama_token id = gpt_sampler_sample(smpl, ctx, -1); - llama_sampling_accept(smpl, id, /* apply_grammar= */ true); + gpt_sampler_accept(smpl, id, /* apply_grammar= */ true); // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); @@ -704,7 +704,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false); + gpt_sampler_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -747,7 +747,7 @@ int main(int argc, char ** argv) { // check for 
reverse prompt in the last n_prev tokens if (!params.antiprompt.empty()) { const int n_prev = 32; - const std::string last_output = llama_sampling_prev_str(smpl, ctx, n_prev); + const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev); is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. @@ -769,7 +769,7 @@ int main(int argc, char ** argv) { } // check for reverse prompt using special tokens - llama_token last_token = llama_sampling_last(smpl); + llama_token last_token = gpt_sampler_last(smpl); for (std::vector ids : antiprompt_ids) { if (ids.size() == 1 && last_token == ids[0]) { if (params.interactive) { @@ -786,7 +786,7 @@ int main(int argc, char ** argv) { } // deal with end of generation tokens in interactive mode - if (llama_token_is_eog(model, llama_sampling_last(smpl))) { + if (llama_token_is_eog(model, gpt_sampler_last(smpl))) { LOG("found an EOG token\n"); if (params.interactive) { @@ -807,7 +807,7 @@ int main(int argc, char ** argv) { // if current token is not EOG, we add it to current assistant message if (params.conversation) { - auto id = llama_sampling_last(smpl); + const auto id = gpt_sampler_last(smpl); assistant_ss << llama_token_to_piece(ctx, id, false); } @@ -903,7 +903,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(smpl); + gpt_sampler_reset(smpl); } is_interacting = false; } @@ -928,13 +928,14 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - llama_print_timings(ctx, smpl); + llama_print_timings(ctx, smpl->smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + gpt_sampler_free(smpl); + llama_free(ctx); llama_free_model(model); - llama_sampling_free(smpl); llama_backend_free(); ggml_threadpool_free(threadpool); diff --git a/include/llama.h b/include/llama.h index 4dd5348a8a903..920952d68fa94 100644 --- a/include/llama.h +++ b/include/llama.h @@ -63,7 +63,6 @@ extern "C" { struct llama_model; struct llama_context; struct llama_sampler; - struct llama_sampling; // TODO: remove before merge typedef int32_t llama_pos; typedef int32_t llama_token; @@ -210,17 +209,6 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; - // TODO: move to common, rename to gpt_constraint_type - enum llama_constraint_type { - LLAMA_CONSTRAINT_TYPE_NONE = 0, - LLAMA_CONSTRAINT_TYPE_TOP_K = 1, - LLAMA_CONSTRAINT_TYPE_TOP_P = 2, - LLAMA_CONSTRAINT_TYPE_MIN_P = 3, - LLAMA_CONSTRAINT_TYPE_TFS_Z = 4, - LLAMA_CONSTRAINT_TYPE_TYPICAL_P = 5, - LLAMA_CONSTRAINT_TYPE_TEMPERATURE = 6, - }; - typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -384,38 +372,6 @@ extern "C" { float bias; } llama_logit_bias; - // TODO: remove before merge - // parameters for sampling the logits - typedef struct llama_sampling_params { - uint32_t seed; // the seed used to initialize llama_sampling_context - int32_t n_prev; // number of previous tokens to remember - int32_t n_probs; // if greater than 0, output the probabilities of top n_probs tokens. 
- int32_t min_keep; // 0 = disabled, otherwise samplers should return at least min_keep tokens - int32_t top_k; // <= 0 to use vocab size - float top_p; // 1.0 = disabled - float min_p; // 0.0 = disabled - float tfs_z; // 1.0 = disabled - float typ_p; // typical_p, 1.0 = disabled - float temp; // <= 0.0 to sample greedily, 0.0 to not output probabilities - float dynatemp_range; // 0.0 = disabled - float dynatemp_exponent; // controls how entropy maps to temperature in dynamic temperature sampler - int32_t penalty_last_n; // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat; // 1.0 = disabled - float penalty_freq; // 0.0 = disabled - float penalty_present; // 0.0 = disabled - int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau; // target entropy - float mirostat_eta; // learning rate - - // samplers - int32_t n_samplers; - enum llama_constraint_type samplers[LLAMA_MAX_SAMPLERS]; - - // Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value. - bool penalize_nl; // consider newlines as a repeatable token - bool ignore_eos; // ignore the end-of-sequence token - } llama_sampling_params; - typedef struct llama_sampler_params { uint32_t seed; // the seed used to initialize the rng of the sampler @@ -432,14 +388,10 @@ extern "C" { double t_end_ms; double t_load_ms; double t_sampling_ms; - double t_grammar_ms; - double t_accept_ms; double t_p_eval_ms; double t_eval_ms; int32_t n_sampling; - int32_t n_grammar; - int32_t n_accept; int32_t n_p_eval; int32_t n_eval; }; @@ -458,7 +410,6 @@ extern "C" { LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); LLAMA_API struct llama_sampler_params llama_sampler_default_params(void); - LLAMA_API struct llama_sampling_params llama_sampling_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); // Initialize the llama + ggml backend @@ -1052,126 +1003,126 @@ extern "C" { // // TODO: llama_model should become llama_vocab - LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params); - - LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); - - // Copies the internal state of the sampler (rng, prev, params, grammar, etc.) - LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); - - // - clear prev token - // - reset grammar state - LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl); - - // Sampling parameter mutation - // TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable - LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); - LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - - // Set the logits from which to sample. - // This call initializes the internal token candidates array. - // The internal candidates are implicitly used by the sampling API below when no candidates are provided. - LLAMA_API void llama_sampling_set_logits( - struct llama_sampling * smpl, - const float * logits); - - /// @details Returns the current candidate tokens. 
- LLAMA_API llama_token_data_array * llama_sampling_get_candidates( - struct llama_sampling * smpl); - - // The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object. - // Each function can accept an array of token candidates. If the candidates are not provided, the internal - // candidates are used. The internal candidates are initialized by llama_sampling_set_logits(). - - /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API void llama_sampling_softmax( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sampling_top_k( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sampling_top_p( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API void llama_sampling_min_p( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sampling_tail_free( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sampling_typical( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Apply temperature and entropy - LLAMA_API void llama_sampling_temp( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Apply constraints from grammar - LLAMA_API void llama_sampling_grammar( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sampling_penalties( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - LLAMA_API llama_token llama_sampling_sample_mirostat( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Selects the token with the highest probability. - /// Does not compute the token probabilities. Use llama_sampling_softmax() instead. - LLAMA_API llama_token llama_sampling_sample_greedy( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Randomly selects a token from the candidates based on their probability distribution. - LLAMA_API llama_token llama_sampling_sample_dist( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers"). 
- LLAMA_API llama_token llama_sampling_sample( - struct llama_sampling * smpl, - llama_token_data_array * candidates); - - /// @details Accepts the sampled token into the sampling context. - /// - adds it to "prev" tokens - /// - updates the grammar state (if apply_grammar is true) - LLAMA_API void llama_sampling_accept( - struct llama_sampling * smpl, - llama_token token, - bool apply_grammar); - - /// @details Get the number of accepted tokens so far (max of n_prev) - LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl); - - /// @details Get the ith accepted token - /// @param ith [0, n_prev), ith == 0 is the last accepted token. - /// returns LLAMA_TOKEN_NULL if ith is out of bounds - LLAMA_API llama_token llama_sampling_prev( - const struct llama_sampling * smpl, - int32_t ith); - - /// @details Get the last accepted token - /// Same as llama_sampling_prev(smpl, 0) - /// returns LLAMA_TOKEN_NULL if there are no accepted tokens - LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); + //LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params); + + //LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); + + //// Copies the internal state of the sampler (rng, prev, params, grammar, etc.) + //LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); + + //// - clear prev token + //// - reset grammar state + //LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl); + + //// Sampling parameter mutation + //// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable + //LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); + //LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); + + //// Set the logits from which to sample. + //// This call initializes the internal token candidates array. + //// The internal candidates are implicitly used by the sampling API below when no candidates are provided. + //LLAMA_API void llama_sampling_set_logits( + // struct llama_sampling * smpl, + // const float * logits); + + ///// @details Returns the current candidate tokens. + //LLAMA_API llama_token_data_array * llama_sampling_get_candidates( + // struct llama_sampling * smpl); + + //// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object. + //// Each function can accept an array of token candidates. If the candidates are not provided, the internal + //// candidates are used. The internal candidates are initialized by llama_sampling_set_logits(). + + ///// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
+ //LLAMA_API void llama_sampling_softmax( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + //LLAMA_API void llama_sampling_top_k( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + //LLAMA_API void llama_sampling_top_p( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 + //LLAMA_API void llama_sampling_min_p( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. + //LLAMA_API void llama_sampling_tail_free( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + //LLAMA_API void llama_sampling_typical( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Apply temperature and entropy + //LLAMA_API void llama_sampling_temp( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Apply constraints from grammar + //LLAMA_API void llama_sampling_grammar( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + ///// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + //LLAMA_API void llama_sampling_penalties( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + //LLAMA_API llama_token llama_sampling_sample_mirostat( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Selects the token with the highest probability. + ///// Does not compute the token probabilities. Use llama_sampling_softmax() instead. + //LLAMA_API llama_token llama_sampling_sample_greedy( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Randomly selects a token from the candidates based on their probability distribution. + //LLAMA_API llama_token llama_sampling_sample_dist( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers"). + //LLAMA_API llama_token llama_sampling_sample( + // struct llama_sampling * smpl, + // llama_token_data_array * candidates); + + ///// @details Accepts the sampled token into the sampling context. 
+ ///// - adds it to "prev" tokens + ///// - updates the grammar state (if apply_grammar is true) + //LLAMA_API void llama_sampling_accept( + // struct llama_sampling * smpl, + // llama_token token, + // bool apply_grammar); + + ///// @details Get the number of accepted tokens so far (max of n_prev) + //LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl); + + ///// @details Get the ith accepted token + ///// @param ith [0, n_prev), ith == 0 is the last accepted token. + ///// returns LLAMA_TOKEN_NULL if ith is out of bounds + //LLAMA_API llama_token llama_sampling_prev( + // const struct llama_sampling * smpl, + // int32_t ith); + + ///// @details Get the last accepted token + ///// Same as llama_sampling_prev(smpl, 0) + ///// returns LLAMA_TOKEN_NULL if there are no accepted tokens + //LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); // // Sampling v2 API @@ -1204,11 +1155,11 @@ extern "C" { struct llama_constraint_i { // TODO: add name API - void (*accept)(struct llama_constraint * cnstr, llama_token token); // can be NULL - void (*apply) (struct llama_constraint * cnstr, llama_token_data_array * candidates); // required - void (*reset) (struct llama_constraint * cnstr); // can be NULL - void (*copy) (struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src); // can be NULL if ctx is NULL - void (*free) (struct llama_constraint * cnstr); // can be NULL + void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL + void (*apply) ( struct llama_constraint * cnstr, llama_token_data_array * candidates); // required + void (*reset) ( struct llama_constraint * cnstr); // can be NULL + struct llama_constraint * (*copy) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL + void (*free) ( struct llama_constraint * cnstr); // can be NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); @@ -1228,21 +1179,27 @@ extern "C" { LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); - LLAMA_API struct llama_constraint * llama_constraint_init_grammar (struct llama_model * model, const char * grammar_str, const char * grammar_root); + + LLAMA_API struct llama_constraint * llama_constraint_init_grammar( + const struct llama_model * model, + const char * grammar_str, + const char * grammar_root); LLAMA_API struct llama_constraint * llama_constraint_init_penalties( - struct llama_model * model, - int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) - float penalty_repeat, // 1.0 = disabled - float penalty_freq, // 0.0 = disabled - float penalty_present, // 0.0 = disabled - bool penalize_nl, // consider newlines as a repeatable token - bool ignore_eos); // ignore the end-of-sequence token + const struct llama_model * model, + int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) + float penalty_repeat, // 1.0 = disabled + float penalty_freq, // 0.0 = disabled + float penalty_present, // 0.0 = disabled + bool penalize_nl, // consider newlines as a repeatable token + bool ignore_eos); // ignore the end-of-sequence token LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( - struct 
llama_model * model, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias); + const struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias); + + LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr); // do not call if used with llama_sampler_add_constraint LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); @@ -1273,7 +1230,7 @@ extern "C" { LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * candidates); LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * candidates); - LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * candidates, bool probs); LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates); /// @details Get the number of accepted tokens so far (max of n_prev) @@ -1310,8 +1267,8 @@ extern "C" { // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl); - LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl); + LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl); + LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * smpl); // Print system information LLAMA_API const char * llama_print_system_info(void); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index f21b5fd55b3ae..7a1f8a8059022 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -24,107 +24,7 @@ static void llama_log_softmax(float * array, size_t size) { } } -llama_sampling::llama_sampling(const struct llama_vocab & vocab) : vocab(vocab) { -} - -llama_sampling::~llama_sampling() { - if (grammar) { - llama_grammar_free_impl(grammar); - } -} - -struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params) { - auto * result = new llama_sampling(vocab); - - result->params = params; - - result->prev = ring_buffer(params.n_prev); - - for (int i = 0; i < params.n_samplers; ++i) { - result->samplers.push_back(params.samplers[i]); - } - - llama_sampling_set_rng_seed_impl(*result, params.seed); - - return result; -} - -void llama_sampling_free_impl(struct llama_sampling * sampling) { - if (sampling == nullptr) { - return; - } - - delete sampling; -} - -struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) { - auto * result = new llama_sampling(smpl.vocab); - - result->params = smpl.params; - - result->grammar_str = smpl.grammar_str; - result->grammar_root = smpl.grammar_root; - - result->logit_bias = smpl.logit_bias; - - if (smpl.grammar) { - result->grammar = llama_grammar_cp_impl(*smpl.grammar); - } - - result->rng = smpl.rng; - result->prev = smpl.prev; - - return result; -} - -void llama_sampling_reset_impl(struct llama_sampling & smpl) { - if (smpl.grammar) { - llama_grammar_free_impl(smpl.grammar); - smpl.grammar = nullptr; - } - - if (!smpl.grammar_str.empty()) { - smpl.grammar = llama_grammar_init_impl(&smpl.vocab, smpl.grammar_str.data(), smpl.grammar_root.data()); - } - - smpl.prev.clear(); -} - -void 
llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); - } - - smpl.rng.seed(seed); -} - -void llama_sampling_set_grammar_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) { - if (smpl.grammar) { - llama_grammar_free_impl(smpl.grammar); - smpl.grammar = nullptr; - } - - if (grammar_str != nullptr && grammar_str[0] != '\0') { - smpl.grammar_str = grammar_str; - smpl.grammar_root = grammar_root; - - smpl.grammar = llama_grammar_init_impl(&smpl.vocab, grammar_str, grammar_root); - } else { - smpl.grammar_str.clear(); - smpl.grammar_root.clear(); - } -} - -void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { - smpl.logit_bias.clear(); - smpl.logit_bias.reserve(n_logit_bias); - - for (int32_t i = 0; i < n_logit_bias; ++i) { - smpl.logit_bias.push_back(logit_bias[i]); - } -} - -void llama_sampling_softmax_impl(llama_token_data_array * candidates) { +void llama_constraint_softmax_impl(llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); // Sort the logits in descending order @@ -149,7 +49,7 @@ void llama_sampling_softmax_impl(llama_token_data_array * candidates) { } } -void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_constraint_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)candidates->size) { // return; @@ -226,12 +126,12 @@ void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, s candidates->size = k; } -void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_constraint_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -252,7 +152,7 @@ void llama_sampling_top_p_impl(llama_token_data_array * candidates, float p, siz candidates->size = last_idx; } -void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_constraint_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { if (p <= 0.0f || !candidates->size) { return; } @@ -307,12 +207,12 @@ void llama_sampling_min_p_impl(llama_token_data_array * candidates, float p, siz } } -void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_constraint_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); // Compute the first and second derivatives std::vector first_derivatives(candidates->size - 1); @@ -361,7 +261,7 @@ void llama_sampling_tail_free_impl(llama_token_data_array * candidates, float z, candidates->size = last_idx; } -void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_constraint_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 
1.0f) { @@ -369,7 +269,7 @@ void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, s } // Compute the softmax of logits and calculate entropy - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); float entropy = 0.0f; for (size_t i = 0; i < candidates->size; ++i) { @@ -419,7 +319,7 @@ void llama_sampling_typical_impl(llama_token_data_array * candidates, float p, s candidates->sorted = false; } -void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { +void llama_constraint_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if(candidates->size <= 1) { return; @@ -428,7 +328,7 @@ void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_ // Calculate maximum possible entropy float max_entropy = -logf(1.0f / candidates->size); - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -482,17 +382,17 @@ void llama_sampling_entropy_impl(llama_token_data_array * candidates, float min_ #endif } -void llama_sampling_temp_impl(llama_token_data_array * candidates, float temp) { +void llama_constraint_temp_impl(llama_token_data_array * candidates, float temp) { for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].logit /= temp; } } -void llama_sampling_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) { +void llama_constraint_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) { llama_grammar_apply_impl(grammar, candidates); } -void llama_sampling_penalties_impl( +void llama_constraint_penalties_impl( llama_token_data_array * candidates, const llama_token_cnt & token_count, float penalty_repeat, @@ -521,8 +421,8 @@ void llama_sampling_penalties_impl( candidates->sorted = false; } -llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { - llama_sampling_softmax_impl(candidates); +llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { + llama_constraint_softmax_impl(candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -541,8 +441,8 @@ llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sampling_top_k_impl(candidates, int(k), 1); - llama_token X = llama_sampling_sample_dist_impl(candidates, rng); + llama_constraint_top_k_impl(candidates, int(k), 1); + llama_token X = llama_sampler_sample_dist_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -557,8 +457,8 @@ llama_token llama_sampling_sample_mirostat_impl(struct llama_token_data_array * return X; } -llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) 
{ - llama_sampling_softmax_impl(candidates); +llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { + llama_constraint_softmax_impl(candidates); // Truncate the words with surprise values greater than mu candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -570,10 +470,10 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array } // Normalize the probabilities of the remaining words - llama_sampling_softmax_impl(candidates); + llama_constraint_softmax_impl(candidates); // Sample the next word X from the remaining words - llama_token X = llama_sampling_sample_dist_impl(candidates, rng); + llama_token X = llama_sampler_sample_dist_impl(candidates, rng); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -589,8 +489,16 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array return X; } -llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidates) { - // Find max element +llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) { + if (probs) { + // if probs are needed, we apply softmax to get the probabilities + llama_constraint_softmax_impl(candidates); + + // the candidates are sorted, so we can just return the first one + return candidates->data[0].id; + } + + // return the token with the highest logit auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; }); @@ -600,8 +508,8 @@ llama_token llama_sampling_sample_greedy_impl(llama_token_data_array * candidate return result; } -llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { - llama_sampling_softmax_impl(candidates); +llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { + llama_constraint_softmax_impl(candidates); std::vector probs; probs.reserve(candidates->size); @@ -618,26 +526,6 @@ llama_token llama_sampling_sample_dist_impl(struct llama_token_data_array * cand return result; } -void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar) { - smpl.prev.push_back(token); - - if (apply_grammar && smpl.grammar) { - llama_grammar_accept_impl(*smpl.grammar, token); - } -} - -llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith) { - if (ith < 0 || ith >= (int) smpl.prev.size()) { - return LLAMA_TOKEN_NULL; - } - - return smpl.prev.rat(ith); -} - -int llama_sampling_n_prev_impl(const struct llama_sampling & smpl) { - return smpl.prev.size(); -} - // // sampling v2 // @@ -655,14 +543,12 @@ static struct llama_constraint_i llama_constraint_top_k_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; - llama_sampling_top_k_impl(candidates, ctx->k, ctx->min_keep); + llama_constraint_top_k_impl(candidates, ctx->k, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const 
struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_top_k; - const auto * ctx_src = (const llama_constraint_context_top_k *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_top_k *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx; + return llama_constraint_init_top_k_impl(ctx->k, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -695,14 +581,12 @@ static struct llama_constraint_i llama_constraint_top_p_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx; - llama_sampling_top_p_impl(candidates, ctx->p, ctx->min_keep); + llama_constraint_top_p_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_top_p; - const auto * ctx_src = (const llama_constraint_context_top_p *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_top_p *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_top_p *) cnstr->ctx; + return llama_constraint_init_top_p_impl(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -735,14 +619,12 @@ static struct llama_constraint_i llama_constraint_min_p_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx; - llama_sampling_min_p_impl(candidates, ctx->p, ctx->min_keep); + llama_constraint_min_p_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_min_p; - const auto * ctx_src = (const llama_constraint_context_min_p *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_min_p *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_min_p *) cnstr->ctx; + return llama_constraint_init_min_p_impl(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -775,14 +657,12 @@ static struct llama_constraint_i llama_constraint_tail_free_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; - llama_sampling_tail_free_impl(candidates, ctx->z, ctx->min_keep); + llama_constraint_tail_free_impl(candidates, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_tail_free; - const auto * ctx_src = (const llama_constraint_context_tail_free *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_tail_free *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_tail_free *) cnstr->ctx; + return llama_constraint_init_tail_free_impl(ctx->z, ctx->min_keep); }, /* .free = */ [](struct 
llama_constraint * cnstr) { if (cnstr->ctx) { @@ -815,14 +695,12 @@ static struct llama_constraint_i llama_constraint_typical_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_typical *) cnstr->ctx; - llama_sampling_typical_impl(candidates, ctx->p, ctx->min_keep); + llama_constraint_typical_impl(candidates, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_typical; - const auto * ctx_src = (const llama_constraint_context_typical *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_typical *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_typical *) cnstr->ctx; + return llama_constraint_init_typical_impl(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -854,14 +732,12 @@ static struct llama_constraint_i llama_constraint_temp_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_temp *) cnstr->ctx; - llama_sampling_temp_impl(candidates, ctx->temp); + llama_constraint_temp_impl(candidates, ctx->temp); }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_temp; - const auto * ctx_src = (const llama_constraint_context_temp *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_temp *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_temp *) cnstr->ctx; + return llama_constraint_init_temp_impl(ctx->temp); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -898,17 +774,15 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = { const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); const float temp_max = ctx->temp + ctx->delta; - llama_sampling_entropy_impl(candidates, temp_min, temp_max, ctx->exponent); + llama_constraint_entropy_impl(candidates, temp_min, temp_max, ctx->exponent); } else { - llama_sampling_temp_impl(candidates, ctx->temp); + llama_constraint_temp_impl(candidates, ctx->temp); } }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_temp_ext; - const auto * ctx_src = (const llama_constraint_context_temp_ext *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_temp_ext *) cnstr->ctx; - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_temp_ext *) cnstr->ctx; + return llama_constraint_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -950,7 +824,7 @@ static struct llama_constraint_i llama_constraint_grammar_i = { /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { - llama_sampling_grammar_impl(candidates, *ctx->grammar); + llama_constraint_grammar_impl(candidates, *ctx->grammar); } }, /* .reset = */ 
[](struct llama_constraint * cnstr) { @@ -964,18 +838,19 @@ static struct llama_constraint_i llama_constraint_grammar_i = { ctx->grammar = llama_grammar_init_impl(nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); } }, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_grammar; - const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_grammar *) cnstr->ctx; - - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx; + auto * result = llama_constraint_init_grammar_impl(*ctx_src->grammar->vocab, nullptr, nullptr); + auto * ctx_dst = (llama_constraint_context_grammar *) result->ctx; if (ctx_src->grammar) { + ctx_dst->grammar_str = ctx_src->grammar_str; + ctx_dst->grammar_root = ctx_src->grammar_root; + ctx_dst->grammar = llama_grammar_cp_impl(*ctx_src->grammar); - } else { - ctx_dst->grammar = nullptr; } + + return result; }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -1059,7 +934,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = { token_count[ctx->prev.rat(i)]++; } - llama_sampling_penalties_impl(candidates, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); + llama_constraint_penalties_impl(candidates, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); if (!ctx->penalize_nl) { // restore the logit of the newline token if it was penalized @@ -1070,12 +945,21 @@ static struct llama_constraint_i llama_constraint_penalties_i = { auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; ctx->prev.clear(); }, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_penalties; - const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_penalties *) cnstr->ctx; - - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr->ctx; + auto * result = llama_constraint_init_penalties_impl( + *ctx_src->vocab, + ctx_src->penalty_last_n, + ctx_src->penalty_repeat, + ctx_src->penalty_freq, + ctx_src->penalty_present, + ctx_src->penalize_nl, + ctx_src->ignore_eos); + + auto * ctx_dst = (llama_constraint_context_penalties *) result->ctx; + ctx_dst->prev = ctx_src->prev; + + return result; }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -1126,12 +1010,9 @@ static struct llama_constraint_i llama_constraint_logit_bias_i = { } }, /* .reset = */ nullptr, - /* .copy = */ [](struct llama_constraint * cnstr, const struct llama_constraint * cnstr_src) { - cnstr->ctx = new llama_constraint_context_logit_bias; - const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr_src->ctx; - auto * ctx_dst = ( llama_constraint_context_logit_bias *) cnstr->ctx; - - *ctx_dst = *ctx_src; + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr->ctx; + return llama_constraint_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); }, /* .free = */ [](struct llama_constraint * cnstr) { if (cnstr->ctx) { @@ -1158,6 +1039,10 @@ struct llama_constraint 
* llama_constraint_init_logit_bias_impl( //////////////////////////////////////// +struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint & cnstr) { + return cnstr.iface->copy ? cnstr.iface->copy(&cnstr) : nullptr; +} + void llama_constraint_free_impl(struct llama_constraint * cnstr) { if (cnstr->iface->free && cnstr) { cnstr->iface->free(cnstr); @@ -1214,12 +1099,14 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) // copy the constraints objects result->constraints.clear(); for (const auto & cnstr : smpl.constraints) { - result->constraints.push_back(new llama_constraint); - result->constraints.back()->iface = cnstr->iface; - - if (cnstr->ctx) { + if (cnstr->ctx == nullptr) { + result->constraints.push_back(new llama_constraint { + /* .iface = */ cnstr->iface, + /* .ctx = */ nullptr, + }); + } else { GGML_ASSERT(cnstr->iface->copy); - result->constraints.back()->iface->copy(result->constraints.back(), cnstr); + result->constraints.push_back(cnstr->iface->copy(cnstr)); } } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 7de37c89e5817..dd9236392abfd 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -10,74 +10,17 @@ struct llama_grammar; using llama_token_cnt = std::unordered_map; -// TODO: remove before merge -struct llama_sampling { - llama_sampling(const struct llama_vocab & vocab); - ~llama_sampling(); - - llama_sampling_params params; - - std::string grammar_str; - std::string grammar_root; - - std::vector logit_bias; // logit biases to apply - - // state - - std::mt19937 rng; - - const struct llama_vocab & vocab; - - std::vector samplers; - - ring_buffer prev; - - struct llama_grammar * grammar = nullptr; - - // mirostat sampler state - float mirostat_mu; - - mutable int64_t t_sample_us = 0; - mutable int64_t t_grammar_us = 0; - mutable int64_t t_accept_us = 0; - - mutable int32_t n_sample = 0; - mutable int32_t n_grammar = 0; - mutable int32_t n_accept = 0; - - std::vector cur; - - llama_token_data_array cur_p; -}; - -// -// internal API -// - -struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, struct llama_sampling_params params); - -void llama_sampling_free_impl(struct llama_sampling * sampling); - -struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl); - -void llama_sampling_reset_impl(struct llama_sampling & smpl); - -// TODO: move the API below as member functions of llama_sampling -void llama_sampling_set_rng_seed_impl (struct llama_sampling & smpl, uint32_t seed); -void llama_sampling_set_grammar_impl (struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root); -void llama_sampling_set_logit_bias_impl(struct llama_sampling & smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - -void llama_sampling_softmax_impl (struct llama_token_data_array * candidates); -void llama_sampling_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sampling_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sampling_min_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sampling_tail_free_impl(struct llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sampling_typical_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sampling_entropy_impl (struct llama_token_data_array * candidates, float min_temp, float max_temp, float 
exponent_val); -void llama_sampling_temp_impl (struct llama_token_data_array * candidates, float temp); -void llama_sampling_grammar_impl (struct llama_token_data_array * candidates, const struct llama_grammar & grammar); - -void llama_sampling_penalties_impl( +void llama_constraint_softmax_impl (struct llama_token_data_array * candidates); +void llama_constraint_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_constraint_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_constraint_min_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_constraint_tail_free_impl(struct llama_token_data_array * candidates, float z, size_t min_keep); +void llama_constraint_typical_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); +void llama_constraint_entropy_impl (struct llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_constraint_temp_impl (struct llama_token_data_array * candidates, float temp); +void llama_constraint_grammar_impl (struct llama_token_data_array * candidates, const struct llama_grammar & grammar); + +void llama_constraint_penalties_impl( llama_token_data_array * candidates, const llama_token_cnt & token_count, float penalty_repeat, @@ -90,22 +33,18 @@ void llama_sampling_penalties_impl( /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); +llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. 
This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); +llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); -llama_token llama_sampling_sample_greedy_impl(struct llama_token_data_array * candidates); -llama_token llama_sampling_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); +llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs); +llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); -void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar); - -llama_token llama_sampling_prev_impl (const struct llama_sampling & smpl, int ith); -int llama_sampling_n_prev_impl(const struct llama_sampling & smpl); // @@ -141,6 +80,8 @@ struct llama_constraint * llama_constraint_init_penalties_impl( int32_t n_logit_bias, const llama_logit_bias * logit_bias); +struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint & cnstr); + void llama_constraint_free_impl(struct llama_constraint * cnstr); void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token); diff --git a/src/llama.cpp b/src/llama.cpp index 4060fa1de420d..a40fc4c30d097 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17946,36 +17946,6 @@ struct llama_sampler_params llama_sampler_default_params() { return result; } -struct llama_sampling_params llama_sampling_default_params() { - struct llama_sampling_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_prev =*/ 64, - /*.n_probs =*/ 0, - /*.min_keep =*/ 0, - /*.top_k =*/ 40, - /*.top_p =*/ 0.95f, - /*.min_p =*/ 0.05f, - /*.tfs_z =*/ 1.00f, - /*.typ_p =*/ 1.00f, - /*.temp =*/ 0.80f, - /*.dynatemp_range =*/ 0.00f, - /*.dynatemp_exponent =*/ 1.00f, - /*.penalty_last_n =*/ 64, - /*.penalty_repeat =*/ 1.00f, - /*.penalty_freq =*/ 0.00f, - /*.penalty_present =*/ 0.00f, - /*.mirostat =*/ 0, - /*.mirostat_tau =*/ 5.00f, - /*.mirostat_eta =*/ 0.10f, - /*.n_samplers =*/ 3, - /*.samplers =*/ { LLAMA_CONSTRAINT_TYPE_TEMPERATURE, LLAMA_CONSTRAINT_TYPE_TOP_K, LLAMA_CONSTRAINT_TYPE_TOP_P, }, - /*.penalize_nl =*/ false, - /*.ignore_eos =*/ false, - }; - - return result; -} - struct llama_model_quantize_params llama_model_quantize_default_params() { struct llama_model_quantize_params result = { /*.nthread =*/ 0, @@ -20638,341 +20608,341 @@ int32_t llama_chat_apply_template( // sampling // -struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) { - return llama_sampling_init_impl(model->vocab, params); -} - -void llama_sampling_free(struct llama_sampling * smpl) { - if (smpl == nullptr) { - return; - } - - llama_sampling_free_impl(smpl); -} - -struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) { - return llama_sampling_cp_impl(*smpl); -} - -void llama_sampling_reset(struct llama_sampling * smpl) { - llama_sampling_reset_impl(*smpl); -} - -void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { - llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root); -} - -void 
llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { - llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias); -} - -void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) { - const int n_vocab = smpl->vocab.n_vocab; - - smpl->cur.resize(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } - - for (const auto & lb : smpl->logit_bias) { - smpl->cur[lb.token].logit += lb.bias; - } - - if (smpl->params.ignore_eos) { - smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY; - } - - smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; - - // apply penalties - { - const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit; - - llama_sampling_penalties(smpl, &smpl->cur_p); - - if (!smpl->params.penalize_nl) { - for (size_t idx = 0; idx < smpl->cur_p.size; idx++) { - if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) { - smpl->cur_p.data[idx].logit = nl_logit; - break; - } - } - } - } -} - -llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) { - return &smpl->cur_p; -} - -void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) { +// return llama_sampling_init_impl(model->vocab, params); +//} - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +//void llama_sampling_free(struct llama_sampling * smpl) { +// if (smpl == nullptr) { +// return; +// } - llama_sampling_softmax_impl(candidates); -} +// llama_sampling_free_impl(smpl); +//} -void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) { +// return llama_sampling_cp_impl(*smpl); +//} - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +//void llama_sampling_reset(struct llama_sampling * smpl) { +// llama_sampling_reset_impl(*smpl); +//} - llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep); -} +//void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { +// llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root); +//} -void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { +// llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias); +//} - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +//void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) { +// const int n_vocab = smpl->vocab.n_vocab; - llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep); -} +// smpl->cur.resize(n_vocab); -void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +// for (llama_token token_id = 0; token_id < n_vocab; token_id++) { +// smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; +// } - if (candidates == nullptr) { - candidates = &smpl->cur_p; - 
} +// for (const auto & lb : smpl->logit_bias) { +// smpl->cur[lb.token].logit += lb.bias; +// } - llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep); -} +// if (smpl->params.ignore_eos) { +// smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY; +// } -void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +// smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// // apply penalties +// { +// const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit; - llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep); -} +// llama_sampling_penalties(smpl, &smpl->cur_p); -void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +// if (!smpl->params.penalize_nl) { +// for (size_t idx = 0; idx < smpl->cur_p.size; idx++) { +// if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) { +// smpl->cur_p.data[idx].logit = nl_logit; +// break; +// } +// } +// } +// } +//} - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +//llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) { +// return &smpl->cur_p; +//} - llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep); -} +//void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); -void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } - - if (smpl->params.dynatemp_range > 0) { - const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range); - const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range); +// llama_sampling_softmax_impl(candidates); +//} - llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent); - } else { - llama_sampling_temp_impl(candidates, smpl->params.temp); - } -} +//void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); -void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_grammar_us); +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep); +//} - if (smpl->grammar) { - llama_sampling_grammar_impl(candidates, *smpl->grammar); +//void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - smpl->n_grammar++; - } -} +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } -void llama_sampling_penalties( - struct llama_sampling * smpl, - llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +// llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep); +//} - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +//void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas 
tm(smpl->t_sample_us); - const size_t penalty_last_n = std::min(smpl->params.penalty_last_n, smpl->prev.size()); +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - const float penalty_repeat = smpl->params.penalty_repeat; - const float penalty_freq = smpl->params.penalty_freq; - const float penalty_present = smpl->params.penalty_present; +// llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep); +//} - if ((penalty_last_n == 0) || - (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { - return; - } +//void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - // Create a frequency map to count occurrences of each token in last_tokens - // TODO: move to sampling state and avoid reallocation - llama_token_cnt token_count; - for (size_t i = 0; i < penalty_last_n; ++i) { - token_count[smpl->prev.rat(i)]++; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present); -} +// llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep); +//} -llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - const auto type = smpl->params.mirostat; +// llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep); +//} - llama_token res; +//void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - if (type == 1) { - res = llama_sampling_sample_mirostat_impl(candidates, - smpl->rng, - smpl->params.mirostat_tau, - smpl->params.mirostat_eta, - 100, - smpl->vocab.n_vocab, - smpl->mirostat_mu); - } else if (type == 2) { - res = llama_sampling_sample_mirostat_v2_impl(candidates, - smpl->rng, - smpl->params.mirostat_tau, - smpl->params.mirostat_eta, - smpl->mirostat_mu); - } else { - GGML_ABORT("invalid mirostat type: %d", type); - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - smpl->n_sample++; +// if (smpl->params.dynatemp_range > 0) { +// const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range); +// const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range); - return res; -} +// llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent); +// } else { +// llama_sampling_temp_impl(candidates, smpl->params.temp); +// } +//} -llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +//void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_grammar_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - auto res = llama_sampling_sample_greedy_impl(candidates); +// if (smpl->grammar) { +// llama_sampling_grammar_impl(candidates, *smpl->grammar); - smpl->n_sample++; +// smpl->n_grammar++; +// } +//} - return 
res; -} +//void llama_sampling_penalties( +// struct llama_sampling * smpl, +// llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); -llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +// const size_t penalty_last_n = std::min(smpl->params.penalty_last_n, smpl->prev.size()); - auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); +// const float penalty_repeat = smpl->params.penalty_repeat; +// const float penalty_freq = smpl->params.penalty_freq; +// const float penalty_present = smpl->params.penalty_present; - smpl->n_sample++; +// if ((penalty_last_n == 0) || +// (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { +// return; +// } - return res; -} +// // Create a frequency map to count occurrences of each token in last_tokens +// // TODO: move to sampling state and avoid reallocation +// llama_token_cnt token_count; +// for (size_t i = 0; i < penalty_last_n; ++i) { +// token_count[smpl->prev.rat(i)]++; +// } -llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { - time_meas tm(smpl->t_sample_us); +// llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present); +//} - if (candidates == nullptr) { - candidates = &smpl->cur_p; - } +//llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); + +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } + +// const auto type = smpl->params.mirostat; + +// llama_token res; + +// if (type == 1) { +// res = llama_sampling_sample_mirostat_impl(candidates, +// smpl->rng, +// smpl->params.mirostat_tau, +// smpl->params.mirostat_eta, +// 100, +// smpl->vocab.n_vocab, +// smpl->mirostat_mu); +// } else if (type == 2) { +// res = llama_sampling_sample_mirostat_v2_impl(candidates, +// smpl->rng, +// smpl->params.mirostat_tau, +// smpl->params.mirostat_eta, +// smpl->mirostat_mu); +// } else { +// GGML_ABORT("invalid mirostat type: %d", type); +// } + +// smpl->n_sample++; + +// return res; +//} - const auto & params = smpl->params; +//llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - const float temp = params.temp; - const int mirostat = params.mirostat; +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - auto & cur_p = candidates; +// auto res = llama_sampling_sample_greedy_impl(candidates); - llama_token res = 0; +// smpl->n_sample++; - if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) { - // greedy sampling, with probs - llama_sampling_softmax_impl(cur_p); - res = cur_p->data[0].id; - } else if (temp == 0.0f) { - // greedy sampling, no probs - res = llama_sampling_sample_greedy(smpl, cur_p); - } else { - if (mirostat != 0) { - llama_sampling_temp(smpl, cur_p); - res = llama_sampling_sample_mirostat(smpl, cur_p); - } else { - for (const auto & sampler : smpl->samplers) { - switch (sampler) { - case LLAMA_CONSTRAINT_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; - case LLAMA_CONSTRAINT_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; - case 
LLAMA_CONSTRAINT_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; - case LLAMA_CONSTRAINT_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; - case LLAMA_CONSTRAINT_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; - case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; - default : break; - } - } +// return res; +//} - res = llama_sampling_sample_dist(smpl, cur_p); +//llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); - //{ - // const int n_top = 10; - // LOG("top %d candidates:\n", n_top); +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } - // for (int i = 0; i < n_top; i++) { - // const llama_token id = cur_p.data[i].id; - // (void)id; // To avoid a warning that id is unused when logging is disabled. - // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); - // } - //} +// auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); - //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); - } - } +// smpl->n_sample++; - smpl->n_sample++; +// return res; +//} - return res; -} +//llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { +// time_meas tm(smpl->t_sample_us); + +// if (candidates == nullptr) { +// candidates = &smpl->cur_p; +// } + +// const auto & params = smpl->params; + +// const float temp = params.temp; +// const int mirostat = params.mirostat; + +// auto & cur_p = candidates; + +// llama_token res = 0; + +// if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) { +// // greedy sampling, with probs +// llama_sampling_softmax_impl(cur_p); +// res = cur_p->data[0].id; +// } else if (temp == 0.0f) { +// // greedy sampling, no probs +// res = llama_sampling_sample_greedy(smpl, cur_p); +// } else { +// if (mirostat != 0) { +// llama_sampling_temp(smpl, cur_p); +// res = llama_sampling_sample_mirostat(smpl, cur_p); +// } else { +// for (const auto & sampler : smpl->samplers) { +// switch (sampler) { +// case LLAMA_CONSTRAINT_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; +// case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; +// default : break; +// } +// } + +// res = llama_sampling_sample_dist(smpl, cur_p); + +// //{ +// // const int n_top = 10; +// // LOG("top %d candidates:\n", n_top); + +// // for (int i = 0; i < n_top; i++) { +// // const llama_token id = cur_p.data[i].id; +// // (void)id; // To avoid a warning that id is unused when logging is disabled. 
+// // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); +// // } +// //} + +// //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); +// } +// } + +// smpl->n_sample++; + +// return res; +//} -void llama_sampling_accept( - struct llama_sampling * smpl, - llama_token token, - bool apply_grammar) { - time_meas tm(smpl->t_accept_us); +//void llama_sampling_accept( +// struct llama_sampling * smpl, +// llama_token token, +// bool apply_grammar) { +// time_meas tm(smpl->t_accept_us); - llama_sampling_accept_impl(*smpl, token, apply_grammar); +// llama_sampling_accept_impl(*smpl, token, apply_grammar); - smpl->n_accept++; -} +// smpl->n_accept++; +//} -int llama_sampling_n_prev(const struct llama_sampling * smpl) { - return llama_sampling_n_prev_impl(*smpl); -} +//int llama_sampling_n_prev(const struct llama_sampling * smpl) { +// return llama_sampling_n_prev_impl(*smpl); +//} -llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) { - return llama_sampling_prev_impl(*smpl, ith); -} +//llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) { +// return llama_sampling_prev_impl(*smpl, ith); +//} -llama_token llama_sampling_last(const struct llama_sampling * smpl) { - return llama_sampling_prev_impl(*smpl, 0); -} +//llama_token llama_sampling_last(const struct llama_sampling * smpl) { +// return llama_sampling_prev_impl(*smpl, 0); +//} // // sampling v2 @@ -21006,28 +20976,32 @@ struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta return llama_constraint_init_temp_ext_impl(temp, delta, exponent); } -struct llama_constraint * llama_constraint_init_grammar(struct llama_model * model, const char * grammar_str, const char * grammar_root) { +struct llama_constraint * llama_constraint_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); } struct llama_constraint * llama_constraint_init_penalties( - struct llama_model * model, - int32_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos) { + const struct llama_model * model, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos) { return llama_constraint_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); } LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( - struct llama_model * model, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias) { + const struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { return llama_constraint_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); } +struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr) { + return llama_constraint_cp_impl(*cnstr); +} + void llama_constraint_free(struct llama_constraint * cnstr) { if (cnstr == nullptr) { return; @@ -21110,7 +21084,7 @@ llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_tok llama_token res; if (type == 1) { - res = llama_sampling_sample_mirostat_impl(candidates, + res = llama_sampler_sample_mirostat_impl(candidates, smpl->rng, smpl->params.mirostat_tau, smpl->params.mirostat_eta, @@ -21118,7 +21092,7 @@ llama_token 
llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_tok smpl->vocab->n_vocab, smpl->mirostat_mu); } else if (type == 2) { - res = llama_sampling_sample_mirostat_v2_impl(candidates, + res = llama_sampler_sample_mirostat_v2_impl(candidates, smpl->rng, smpl->params.mirostat_tau, smpl->params.mirostat_eta, @@ -21132,14 +21106,14 @@ llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_tok return res; } -llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * candidates) { +llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * candidates, bool probs) { time_meas tm(smpl->t_sample_us); if (candidates == nullptr) { candidates = &smpl->cur_p; } - auto res = llama_sampling_sample_greedy_impl(candidates); + auto res = llama_sampler_sample_greedy_impl(candidates, probs); smpl->n_sample++; @@ -21153,7 +21127,7 @@ llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_d candidates = &smpl->cur_p; } - auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); + auto res = llama_sampler_sample_dist_impl(candidates, smpl->rng); smpl->n_sample++; @@ -21204,20 +21178,16 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int return 0; } -void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl) { +void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl) { const llama_timings timings = { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_end_ms =*/ 1.00 * ggml_time_ms(), /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0), - /*.t_grammar_ms =*/ 1e-3 * (smpl ? smpl->t_grammar_us : 0.0), - /*.t_accept_ms =*/ 1e-3 * (smpl ? smpl->t_accept_us : 0.0), /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0), - /*.n_grammar =*/ std::max(0, smpl ? smpl->n_grammar : 0), - /*.n_accept =*/ std::max(0, smpl ? 
smpl->n_accept : 0), /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), /*.n_eval =*/ std::max(1, ctx->n_eval), }; @@ -21226,10 +21196,6 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smp LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling); - LLAMA_LOG_INFO("%s: grammar time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_grammar_ms, timings.n_grammar, timings.t_grammar_ms / timings.n_grammar, 1e3 / timings.t_grammar_ms * timings.n_grammar); - //LLAMA_LOG_INFO("%s: accept time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - // __func__, timings.t_accept_ms, timings.n_accept, timings.t_accept_ms / timings.n_accept, 1e3 / timings.t_accept_ms * timings.n_accept); LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", @@ -21237,15 +21203,13 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smp LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); } -void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl) { +void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * smpl) { ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; if (smpl) { smpl->t_sample_us = smpl->n_sample = 0; - smpl->t_grammar_us = smpl->n_grammar = 0; - smpl->t_accept_us = smpl->n_accept = 0; } } From 437376e7083403fe4ff779062979f6f90c0dca65 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 11:54:49 +0300 Subject: [PATCH 12/47] cont : add n_prev to llama_sampler_params --- common/sampling.cpp | 6 ++++-- include/llama.h | 2 ++ src/llama-impl.h | 1 - src/llama-sampling.cpp | 39 +++++++++++++++++++++++++++++---------- src/llama-sampling.h | 4 ++-- src/llama.cpp | 1 + 6 files changed, 38 insertions(+), 15 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index a5e76dfd41e32..4dfbe9021ec72 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -38,6 +38,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st llama_sampler_params lparams = llama_sampler_default_params(); lparams.seed = params.seed; + lparams.n_prev = params.n_prev; lparams.mirostat = params.mirostat; lparams.mirostat_tau = params.mirostat_tau; lparams.mirostat_eta = params.mirostat_eta; @@ -177,8 +178,10 @@ llama_token gpt_sampler_sample( llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + auto * cur_p = llama_sampler_get_candidates(smpl); + // first, sample the token without any grammar constraints - const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs); + const llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); // create an array with a single token data element for the sampled id llama_token_data 
single_token_data = { id, 1.0f, 0.0f }; @@ -194,7 +197,6 @@ llama_token gpt_sampler_sample( // if the token is not valid, sample again, after applying the grammar constraints llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); - auto * cur_p = llama_sampler_get_candidates(smpl); llama_constraint_apply(grmr, cur_p); diff --git a/include/llama.h b/include/llama.h index 920952d68fa94..49011b7fe5ec6 100644 --- a/include/llama.h +++ b/include/llama.h @@ -375,6 +375,8 @@ extern "C" { typedef struct llama_sampler_params { uint32_t seed; // the seed used to initialize the rng of the sampler + int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API) + int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau; // target entropy float mirostat_eta; // learning rate diff --git a/src/llama-impl.h b/src/llama-impl.h index b67f511c08157..6d388655d01a8 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -56,7 +56,6 @@ const std::vector> & llama_internal // the ring buffer works similarly to std::deque, but with a fixed capacity template struct ring_buffer { - ring_buffer() {} ring_buffer(size_t cap) : capacity(cap), data(cap) {} T & front() { diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 7a1f8a8059022..abf9d5a8ee789 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -983,7 +983,7 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam /*.penalty_present =*/ penalty_present, /*.penalize_nl =*/ penalize_nl, /*.ignore_eos =*/ ignore_eos, - /*.prev =*/ {}, + /*.prev =*/ ring_buffer(penalty_last_n), }, }; @@ -1069,12 +1069,20 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) { // samplers struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) { - auto * result = new llama_sampler; - - result->params = params; - result->vocab = &vocab; - - result->rng.seed(params.seed); + auto * result = new llama_sampler { + /* .params = */ params, + /* .vocab = */ &vocab, + + /* .rng = */ std::mt19937(params.seed), + + /* .mirostat_mu = */ 0.0f, + /* .prev = */ { (size_t) params.n_prev }, + /* .constraints = */ {}, + /* .cur = */ {}, + /* .cur_p = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + }; return result; } @@ -1092,9 +1100,20 @@ void llama_sampler_free_impl(struct llama_sampler * smpl) { } struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) { - auto * result = new llama_sampler; - - *result = smpl; + auto * result = new llama_sampler { + /* .params = */ smpl.params, + /* .vocab = */ smpl.vocab, + + /* .rng = */ smpl.rng, + + /* .mirostat_mu = */ smpl.mirostat_mu, + /* .prev = */ smpl.prev, + /* .constraints = */ {}, + /* .cur = */ {}, + /* .cur_p = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + }; // copy the constraints objects result->constraints.clear(); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index dd9236392abfd..501f11de8698b 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -111,9 +111,9 @@ struct llama_sampler { // timing - mutable int64_t t_sample_us = 0; + mutable int64_t t_sample_us; - mutable int32_t n_sample = 0; + mutable int32_t n_sample; }; struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); diff --git a/src/llama.cpp b/src/llama.cpp index a40fc4c30d097..903be45750df1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ 
-17938,6 +17938,7 @@ struct llama_context_params llama_context_default_params() { struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_prev =*/ 256, /*.mirostat =*/ 0, /*.mirostat_tau =*/ 5.00f, /*.mirostat_eta =*/ 0.10f, From a0b91214b40da723c1069f1ed3900edd8154318b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 13:54:32 +0300 Subject: [PATCH 13/47] cont : use new API in examples ggml-ci --- common/sampling.cpp | 114 ++++-- common/sampling.h | 31 +- examples/batched/batched.cpp | 23 +- examples/gritlm/gritlm.cpp | 10 +- examples/infill/infill.cpp | 24 +- examples/llava/llava-cli.cpp | 10 +- examples/llava/minicpmv-cli.cpp | 16 +- examples/lookahead/lookahead.cpp | 17 +- examples/lookup/lookup.cpp | 11 +- examples/main/main.cpp | 4 +- examples/parallel/parallel.cpp | 12 +- examples/passkey/passkey.cpp | 9 +- examples/save-load-state/save-load-state.cpp | 26 +- examples/server/server.cpp | 48 +-- examples/simple/simple.cpp | 9 +- examples/speculative/speculative.cpp | 52 +-- include/llama.h | 138 +------- src/llama-sampling.cpp | 211 ++++++------ src/llama-sampling.h | 43 +-- src/llama.cpp | 344 +------------------ tests/test-sampling.cpp | 30 +- 21 files changed, 380 insertions(+), 802 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 4dfbe9021ec72..4e88432245ed6 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,6 +2,16 @@ #include "common.h" +struct gpt_sampler { + gpt_sampler_params params; + + struct llama_constraint * bias; + struct llama_constraint * pnlt; + struct llama_constraint * grmr; + + struct llama_sampler * smpl; +}; + std::string gpt_sampler_params::print_all() const { char result[1024]; @@ -33,8 +43,6 @@ std::string gpt_sampler_params::print_constraints() const { } struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { - gpt_sampler * result = new gpt_sampler(); - llama_sampler_params lparams = llama_sampler_default_params(); lparams.seed = params.seed; @@ -43,21 +51,23 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st lparams.mirostat_tau = params.mirostat_tau; lparams.mirostat_eta = params.mirostat_eta; - result->smpl = llama_sampler_init(model, lparams); - - llama_sampler_add_constraint(result->smpl, llama_constraint_init_logit_bias( - model, - params.logit_bias.size(), - params.logit_bias.data())); - - llama_sampler_add_constraint(result->smpl, llama_constraint_init_penalties( - model, - params.penalty_last_n, - params.penalty_repeat, - params.penalty_freq, - params.penalty_present, - params.penalize_nl, - params.ignore_eos)); + auto * result = new gpt_sampler { + .params = params, + .bias = llama_constraint_init_logit_bias( + model, + params.logit_bias.size(), + params.logit_bias.data()), + .pnlt = llama_constraint_init_penalties( + model, + params.penalty_last_n, + params.penalty_repeat, + params.penalty_freq, + params.penalty_present, + params.penalize_nl, + params.ignore_eos), + .grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"), + .smpl = llama_sampler_init(model, lparams) + }; for (const auto & cnstr : params.constraints) { switch (cnstr) { @@ -84,14 +94,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st } } - result->grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"); - return result; } void gpt_sampler_free(struct gpt_sampler * gsmpl) { if 
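// Illustrative sketch (hypothetical usage, not code from the patch above): how a caller
// could set the new n_prev field before creating a sampler with the API of this branch.
// `model` is assumed to be a loaded llama_model; the concrete values are made up.
#include "llama.h"

static struct llama_sampler * example_init_sampler(struct llama_model * model) {
    struct llama_sampler_params sparams = llama_sampler_default_params();

    sparams.seed   = 1234; // seed for the sampler rng
    sparams.n_prev = 64;   // size of the ring buffer of previously accepted tokens (defaults to 256)

    // the returned sampler must eventually be released with llama_sampler_free()
    return llama_sampler_init(model, sparams);
}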
(gsmpl) { + llama_constraint_free(gsmpl->bias); + llama_constraint_free(gsmpl->pnlt); llama_constraint_free(gsmpl->grmr); + llama_sampler_free(gsmpl->smpl); delete gsmpl; @@ -121,18 +132,28 @@ void gpt_sampler_reset (struct gpt_sampler * gsmpl) { llama_sampler_reset(gsmpl->smpl); } +void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits) { + llama_sampler_set_logits(gsmpl->smpl, logits); +} + +llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { + return llama_sampler_get_candidates(gsmpl->smpl); +} + llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { return llama_sampler_last(gsmpl->smpl); } +void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) { + llama_print_timings(ctx, gsmpl->smpl); +} + static llama_token gpt_sampler_sample( struct llama_sampler * smpl, struct llama_token_data_array * cur_p, float temp, int mirostat, int n_probs) { - GGML_ASSERT(cur_p != nullptr && "candidates array must be provided"); - llama_token res = 0; if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) { @@ -142,6 +163,7 @@ static llama_token gpt_sampler_sample( // greedy sampling, no probs res = llama_sampler_sample_greedy(smpl, cur_p, false); } else { + // apply all sampling constraints and then sample llama_sampler_apply(smpl, cur_p); if (mirostat != 0) { @@ -167,42 +189,62 @@ static llama_token gpt_sampler_sample( return res; } -llama_token gpt_sampler_sample( - struct gpt_sampler * gsmpl, - struct llama_context * ctx, - int idx) { +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) { const auto & params = gsmpl->params; + auto & bias = gsmpl->bias; + auto & pnlt = gsmpl->pnlt; auto & grmr = gsmpl->grmr; auto & smpl = gsmpl->smpl; + auto * cur_p = llama_sampler_get_candidates(smpl); + llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); - auto * cur_p = llama_sampler_get_candidates(smpl); + llama_constraint_apply(bias, cur_p); + llama_constraint_apply(pnlt, cur_p); // first, sample the token without any grammar constraints - const llama_token id = gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); + const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs); - // create an array with a single token data element for the sampled id - llama_token_data single_token_data = { id, 1.0f, 0.0f }; - llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; + // check if it the sampled token fits the grammar + { + llama_token_data single_token_data = { id, 1.0f, 0.0f }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; - llama_constraint_apply(grmr, &single_token_data_array); + llama_constraint_apply(grmr, &single_token_data_array); - // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY - const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; - if (is_valid) { - return id; + // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY + const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; + if (is_valid) { + return id; + } } - // if the token is not valid, sample again, after applying the grammar constraints + // if the token is not valid, sample again, first apply the grammar constraints and then sample llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + llama_constraint_apply(bias, cur_p); + 
llama_constraint_apply(pnlt, cur_p); llama_constraint_apply(grmr, cur_p); return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); } +void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates) { + GGML_ASSERT(candidates != nullptr); + + llama_constraint_apply(gsmpl->grmr, candidates); +} + +llama_token gpt_sampler_sample_dist(struct gpt_sampler * gsmpl, llama_token_data_array * candidates) { + return llama_sampler_sample_dist(gsmpl->smpl, candidates); +} + +llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs) { + return llama_sampler_sample_greedy(gsmpl->smpl, candidates, probs); +} + std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { auto & smpl = gsmpl->smpl; diff --git a/common/sampling.h b/common/sampling.h index 4efa4a17ce4ae..8cb3da762019e 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -60,15 +60,12 @@ struct gpt_sampler_params { std::string print_constraints() const; }; -struct gpt_sampler { - gpt_sampler_params params; - - struct llama_constraint * grmr = nullptr; - - struct llama_sampler * smpl = nullptr; -}; - -// llama_sampler API overload +// gpt_sampler extends llama_sampler with additional functionality: +// +// - grammar support +// - custom sampler logic based on the paramerters +// +struct gpt_sampler; struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params); @@ -79,8 +76,14 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl); void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); void gpt_sampler_reset (struct gpt_sampler * gsmpl); +void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits); + +llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); + llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); +void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); + // common sampling implementation: // // - set logits @@ -88,10 +91,12 @@ llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -llama_token gpt_sampler_sample( - struct gpt_sampler * gsmpl, - struct llama_context * ctx, - int idx); +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx); + +void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates); + +llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * candidates); +llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs); // helpers diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 4dfa19ce88af3..3052b96aeaf65 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -64,14 +64,15 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_new_context_with_model(model, ctx_params); - auto sparams = llama_sampling_default_params(); + auto sparams = llama_sampler_default_params(); sparams.seed = params.sparams.seed; - sparams.top_k = 40; - sparams.top_p = 0.9f; - sparams.temp = 0.4f; - llama_sampling * smpl = llama_sampling_init(model, sparams); + llama_sampler * smpl = llama_sampler_init(model, sparams); + + 
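// Illustrative sketch (hypothetical, not code from the patch): the common gpt_sampler flow
// described in common/sampling.h above - init, sample, accept, print timings, free.
// Assumes a loaded llama_model * model, a llama_context * ctx that has already decoded a
// batch, and gpt_params params.
#include "common.h"
#include "sampling.h"

static void example_generate_one(struct llama_model * model, struct llama_context * ctx, const gpt_params & params) {
    struct gpt_sampler * gsmpl = gpt_sampler_init(model, params.sparams);

    // set the logits, apply the configured constraints, sample and check the grammar (if any)
    const llama_token id = gpt_sampler_sample(gsmpl, ctx, -1);

    // record the token in the previous-token buffer and advance the grammar state
    gpt_sampler_accept(gsmpl, id, /*apply_grammar =*/ true);

    gpt_print_timings(ctx, gsmpl);
    gpt_sampler_free(gsmpl);
}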
llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); + llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_p)); + llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp)); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -174,15 +175,11 @@ int main(int argc, char ** argv) { const auto * logits = llama_get_logits_ith(ctx, i_batch[i]); - llama_sampling_set_logits(smpl, logits); - - llama_sampling_top_k(smpl, nullptr); - llama_sampling_top_p(smpl, nullptr); - llama_sampling_temp (smpl, nullptr); + llama_sampler_set_logits(smpl, logits); - const llama_token new_token_id = llama_sampling_sample_dist(smpl, nullptr); + const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr); - //const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); + //const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr); // is it an end of generation? -> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -246,7 +243,7 @@ int main(int argc, char ** argv) { llama_batch_free(batch); - llama_sampling_free(smpl); + llama_sampler_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 7d2ae77133ec5..978642cc35dfe 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -92,7 +92,7 @@ static std::vector> encode(llama_context * ctx, const std::ve return result; } -static std::string generate(llama_context * ctx, llama_sampling * smpl, const std::string & prompt, bool stream) { +static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) { std::string result; const llama_model * model = llama_get_model(ctx); @@ -122,9 +122,9 @@ static std::string generate(llama_context * ctx, llama_sampling * smpl, const st const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); - llama_sampling_set_logits(smpl, logits); + llama_sampler_set_logits(smpl, logits); - llama_token token = llama_sampling_sample_greedy(smpl, nullptr); + llama_token token = llama_sampler_sample_greedy(smpl, nullptr, false); if (token == eos_token) { break; } @@ -171,7 +171,7 @@ int main(int argc, char * argv[]) { // create generation context llama_context * ctx = llama_new_context_with_model(model, cparams); - llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic @@ -212,7 +212,7 @@ int main(int argc, char * argv[]) { std::string response = generate(ctx, smpl, prompt, true); } - llama_sampling_free(smpl); + llama_sampler_free(smpl); llama_free(ctx); llama_free_model(model); llama_backend_free(); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 371232421b71c..9f9f81a7f44ff 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -33,7 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; -static llama_sampling ** g_smpl; +static gpt_sampler ** g_smpl; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -93,7 +93,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); 
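// Illustrative sketch (hypothetical, mirroring the batched.cpp changes above): the
// lower-level llama_sampler API with an explicit constraint chain. `model`, `ctx` and the
// batch index `i` are assumed to exist; the top-k and temperature values are made up.
#include "llama.h"

static llama_token example_sample_next(struct llama_model * model, struct llama_context * ctx, int32_t i) {
    struct llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());

    // the sampler takes ownership of each constraint and frees it in llama_sampler_free()
    llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
    llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4f));

    // initialize the internal candidates from the logits of the i-th token in the batch
    llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, i));

    // passing nullptr makes the call use the internal candidates
    const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr);

    llama_sampler_free(smpl);

    return new_token_id;
}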
printf("\n"); - llama_print_timings(*g_ctx, *g_smpl); + gpt_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -167,7 +167,7 @@ int main(int argc, char ** argv) { llama_model * model = nullptr; llama_context * ctx = nullptr; - llama_sampling * smpl = nullptr; + gpt_sampler * smpl = nullptr; g_model = &model; g_ctx = &ctx; @@ -345,7 +345,7 @@ int main(int argc, char ** argv) { std::vector embd; - smpl = llama_sampling_init(model, sparams); + smpl = gpt_sampler_init(model, sparams); while (n_remain != 0 || params.interactive) { // predict @@ -417,9 +417,9 @@ int main(int argc, char ** argv) { embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - const llama_token id = llama_sampling_sample(smpl, ctx, -1); + const llama_token id = gpt_sampler_sample(smpl, ctx, -1); - llama_sampling_accept(smpl, id, true); + gpt_sampler_accept(smpl, id, true); // LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str()); @@ -440,7 +440,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(smpl, embd_inp[n_consumed], false); + gpt_sampler_accept(smpl, embd_inp[n_consumed], false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -472,7 +472,7 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // deal with eot token in infill mode - if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ + if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ if (is_interacting && !params.interactive_first) { // print an eot token printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str()); @@ -538,7 +538,7 @@ int main(int argc, char ** argv) { is_interacting = false; } // deal with end of generation tokens in interactive mode - else if (llama_token_is_eog(model, llama_sampling_last(smpl))) { + else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) { LOG("found EOS token\n"); if (params.interactive) { @@ -611,7 +611,7 @@ int main(int argc, char ** argv) { if (n_past > 0) { if (is_interacting) { - llama_sampling_reset(smpl); + gpt_sampler_reset(smpl); } is_interacting = false; } @@ -634,13 +634,13 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_print_timings(ctx, smpl); + gpt_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); llama_free_model(model); - llama_sampling_free(smpl); + gpt_sampler_free(smpl); llama_backend_free(); #ifndef LOG_DISABLE_LOGS diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index d01c5909e6f68..63a75c4a34ca2 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n return true; } -static const char * sample(struct llama_sampling * smpl, +static const char * sample(struct gpt_sampler * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1); - llama_sampling_accept(smpl, id, true); + const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1); + gpt_sampler_accept(smpl, id, true); static std::string ret; if 
(llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = ""; @@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG_TEE("\n"); - struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams); if (!smpl) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); @@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ fflush(stdout); } - llama_sampling_free(smpl); + gpt_sampler_free(smpl); printf("\n"); } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index c041fe530e987..15f258b91169f 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e LOG_TEE("%s: image token past: %d\n", __func__, n_past); } -static const char * sample(struct llama_sampling * smpl, +static const char * sample(struct gpt_sampler * smpl, struct llama_context * ctx_llama, int * n_past) { - const llama_token id = llama_sampling_sample(smpl, ctx_llama, -1); - llama_sampling_accept(smpl, id, true); + const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1); + gpt_sampler_accept(smpl, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = ""; @@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri return ctx_llava; } -static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ +static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){ std::string user_prompt = prompt; int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip); if (!is_first) { @@ -238,11 +238,11 @@ static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_ LOG_TEE("\n"); - struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams); + struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams); return smpl; } -static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){ +static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampler * smpl, int &n_past){ const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past); return tmp; @@ -296,7 +296,7 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_sampling_free(smpl); + gpt_sampler_free(smpl); }else { while (true) { LOG_TEE(""); @@ -315,7 +315,7 @@ int main(int argc, char ** argv) { if (strstr(response.c_str(), "")) break; // minicpm-v fflush(stdout); } - llama_sampling_free(smpl); + gpt_sampler_free(smpl); } } printf("\n"); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 2bd31d00268a2..8b461555bb594 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -117,7 +117,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); + struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams); // verification n-grams std::vector ngrams_cur(G); @@ -158,9 +158,9 @@ int 
main(int argc, char ** argv) { // sample first token { - id = llama_sampling_sample(smpl, ctx, 0); + id = gpt_sampler_sample(smpl, ctx, 0); - llama_sampling_accept(smpl, id, true); + gpt_sampler_accept(smpl, id, true); { const std::string token_str = llama_token_to_piece(ctx, id); @@ -283,9 +283,9 @@ int main(int argc, char ** argv) { } // sample the next token - id = llama_sampling_sample(smpl, ctx, i_batch); + id = gpt_sampler_sample(smpl, ctx, i_batch); - llama_sampling_accept(smpl, id, true); + gpt_sampler_accept(smpl, id, true); // print { @@ -360,7 +360,7 @@ int main(int argc, char ** argv) { if (v == 0) { // sample from the last level for (int i = 0; i < W; i++) { - tokens_j[N - 2][i] = llama_sampling_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); + tokens_j[N - 2][i] = gpt_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i); } } else { for (int i = 0; i < W; i++) { @@ -467,10 +467,11 @@ int main(int argc, char ** argv) { LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept); - llama_print_timings(ctx, smpl); + gpt_print_timings(ctx, smpl); + + gpt_sampler_free(smpl); llama_kv_cache_view_free(&kvc_view); - llama_sampling_free(smpl); llama_batch_free(batch); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index da4d57a518754..da3583f3c00f2 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -104,7 +104,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct llama_sampling * smpl = llama_sampling_init(model, params.sparams); + struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams); std::vector draft; @@ -128,9 +128,9 @@ int main(int argc, char ** argv){ int i_dft = 0; while (true) { // sample from the target model - llama_token id = llama_sampling_sample(smpl, ctx, i_dft); + llama_token id = gpt_sampler_sample(smpl, ctx, i_dft); - llama_sampling_accept(smpl, id, true); + gpt_sampler_accept(smpl, id, true); const std::string token_str = llama_token_to_piece(ctx, id); @@ -239,9 +239,10 @@ int main(int argc, char ** argv){ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx, smpl); + gpt_print_timings(ctx, smpl); + + gpt_sampler_free(smpl); - llama_sampling_free(smpl); llama_batch_free(batch_tgt); llama_free(ctx); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 88202b800762b..1b706efbc2fa6 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -106,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx, (*g_smpl)->smpl); + gpt_print_timings(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -928,7 +928,7 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - llama_print_timings(ctx, smpl->smpl); + gpt_print_timings(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); gpt_sampler_free(smpl); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7ce982a92e497..7422042db268f 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -51,7 +51,7 @@ static std::vector k_prompts = { struct client { ~client() { if (smpl) { - llama_sampling_free(smpl); + gpt_sampler_free(smpl); } } @@ -72,7 +72,7 @@ struct client { std::string prompt; std::string response; - 
struct llama_sampling * smpl = nullptr; + struct gpt_sampler * smpl = nullptr; }; static void print_date_time() { @@ -161,7 +161,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.smpl = llama_sampling_init(model, params.sparams); + client.smpl = gpt_sampler_init(model, params.sparams); } std::vector tokens_system; @@ -253,7 +253,7 @@ int main(int argc, char ** argv) { client.prompt = client.input + "\nAssistant:"; client.response = ""; - llama_sampling_reset(client.smpl); + gpt_sampler_reset(client.smpl); // do not prepend BOS because we have a system prompt! std::vector tokens_prompt; @@ -341,9 +341,9 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d\n", // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch); - const llama_token id = llama_sampling_sample(client.smpl, ctx, client.i_batch - i); + const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i); - llama_sampling_accept(client.smpl, id, true); + gpt_sampler_accept(client.smpl, id, true); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 0992ccc3c1808..b287d8403db7b 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -83,7 +83,7 @@ int main(int argc, char ** argv) { return 1; } - llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); // tokenize the prompt std::vector tokens_list; @@ -218,10 +218,10 @@ int main(int argc, char ** argv) { { const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - llama_sampling_set_logits(smpl, logits); + llama_sampler_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); + const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); // is it an end of generation? 
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { @@ -262,9 +262,10 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); + llama_sampler_free(smpl); + llama_batch_free(batch); - llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 02f7a93ebac19..01f66886e33f8 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -38,10 +38,10 @@ int main(int argc, char ** argv) { return 1; } - llama_sampling_params sparams = llama_sampling_default_params(); + llama_sampler_params sparams = llama_sampler_default_params(); sparams.seed = params.sparams.seed; - llama_sampling * smpl = llama_sampling_init(model, sparams); + llama_sampler * smpl = llama_sampler_init(model, sparams); // tokenize prompt auto tokens = llama_tokenize(ctx, params.prompt, true); @@ -71,9 +71,9 @@ int main(int argc, char ** argv) { for (auto i = 0; i < params.n_predict; i++) { const auto * logits = llama_get_logits(ctx); - llama_sampling_set_logits(smpl, logits); + llama_sampler_set_logits(smpl, logits); - auto next_token = llama_sampling_sample_dist(smpl, nullptr); + auto next_token = llama_sampler_sample_dist(smpl, nullptr); auto next_token_str = llama_token_to_piece(ctx, next_token); printf("%s", next_token_str.c_str()); @@ -96,7 +96,7 @@ int main(int argc, char ** argv) { // make new context auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); - llama_sampling * smpl2 = llama_sampling_init(model, sparams); + llama_sampler * smpl2 = llama_sampler_init(model, sparams); printf("\nsecond run: %s", params.prompt.c_str()); @@ -128,9 +128,9 @@ int main(int argc, char ** argv) { for (auto i = 0; i < params.n_predict; i++) { const auto * logits = llama_get_logits(ctx2); - llama_sampling_set_logits(smpl2, logits); + llama_sampler_set_logits(smpl2, logits); - auto next_token = llama_sampling_sample_dist(smpl2, nullptr); + auto next_token = llama_sampler_sample_dist(smpl2, nullptr); auto next_token_str = llama_token_to_piece(ctx2, next_token); printf("%s", next_token_str.c_str()); @@ -157,7 +157,7 @@ int main(int argc, char ** argv) { // make new context auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); - llama_sampling * smpl3 = llama_sampling_init(model, sparams); + llama_sampler * smpl3 = llama_sampler_init(model, sparams); printf("\nsingle seq run: %s", params.prompt.c_str()); @@ -217,9 +217,9 @@ int main(int argc, char ** argv) { for (auto i = 0; i < params.n_predict; i++) { const auto * logits = llama_get_logits(ctx3); - llama_sampling_set_logits(smpl3, logits); + llama_sampler_set_logits(smpl3, logits); - auto next_token = llama_sampling_sample_dist(smpl3, nullptr); + auto next_token = llama_sampler_sample_dist(smpl3, nullptr); auto next_token_str = llama_token_to_piece(ctx3, next_token); printf("%s", next_token_str.c_str()); @@ -236,9 +236,9 @@ int main(int argc, char ** argv) { printf("\n"); - llama_sampling_free(smpl); - llama_sampling_free(smpl2); - llama_sampling_free(smpl3); + llama_sampler_free(smpl); + llama_sampler_free(smpl2); + llama_sampler_free(smpl3); llama_free(ctx3); llama_free_model(model); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 139e503b9eb29..03e512e0343e4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -170,10 +170,10 @@ struct server_slot { // sampling json 
json_schema; - struct gpt_sampling_params sparams; + struct gpt_sampler_params sparams; + struct gpt_sampler * smpl = nullptr; llama_token sampled; - llama_sampling * smpl = nullptr; int32_t ga_i = 0; // group-attention state int32_t ga_n = 1; // group-attention factor @@ -653,7 +653,7 @@ struct server_context { // Clear any sampling context for (server_slot & slot : slots) { if (slot.smpl != nullptr) { - llama_sampling_free(slot.smpl); + gpt_sampler_free(slot.smpl); } } @@ -1027,26 +1027,26 @@ struct server_context { } { - const auto & samplers = data.find("samplers"); - if (samplers != data.end() && samplers->is_array()) { - std::vector sampler_names; - for (const auto & sampler_name : *samplers) { - if (sampler_name.is_string()) { - sampler_names.emplace_back(sampler_name); + const auto & constraints = data.find("samplers"); + if (constraints != data.end() && constraints->is_array()) { + std::vector constraint_names; + for (const auto & name : *constraints) { + if (name.is_string()) { + constraint_names.emplace_back(name); } } - slot.sparams.samplers = llama_sampling_types_from_names(sampler_names, false); + slot.sparams.constraints = gpt_constraint_types_from_names(constraint_names, false); } else { - slot.sparams.samplers = default_sparams.samplers; + slot.sparams.constraints = default_sparams.constraints; } } { if (slot.smpl != nullptr) { - llama_sampling_free(slot.smpl); + gpt_sampler_free(slot.smpl); } - slot.smpl = llama_sampling_init(model, slot.sparams); + slot.smpl = gpt_sampler_init(model, slot.sparams); if (slot.smpl == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -1253,10 +1253,10 @@ struct server_context { } json get_formated_generation(const server_slot & slot) const { - std::vector samplers; - samplers.reserve(slot.sparams.samplers.size()); - for (const auto & sampler : slot.sparams.samplers) { - samplers.emplace_back(llama_sampling_type_to_str(sampler)); + std::vector constraints; + constraints.reserve(slot.sparams.constraints.size()); + for (const auto & constraint : slot.sparams.constraints) { + constraints.emplace_back(gpt_constraint_type_to_str(constraint)); } return json { @@ -1290,7 +1290,7 @@ struct server_context { {"n_probs", slot.sparams.n_probs}, {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, - {"samplers", samplers}, + {"samplers", constraints}, }; } @@ -2084,7 +2084,7 @@ struct server_context { GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx); } - llama_sampling_reset(slot.smpl); + gpt_sampler_reset(slot.smpl); if (!slot.params.cache_prompt) { slot.n_past_se = 0; @@ -2097,7 +2097,7 @@ struct server_context { // push the prompt into the sampling context (do not apply grammar) for (int i = 0; i < slot.n_past; ++i) { - llama_sampling_accept(slot.smpl, slot.cache_tokens[i], false); + gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false); } } } @@ -2150,7 +2150,7 @@ struct server_context { slot.n_past_se = 0; slot.ga_i = 0; // TODO: is the system prompt ever in the sampling context? 
- llama_sampling_reset(slot.smpl); + gpt_sampler_reset(slot.smpl); } // remove the non-common part from the cache @@ -2332,9 +2332,9 @@ struct server_context { } completion_token_output result; - const llama_token id = llama_sampling_sample(slot.smpl, ctx, slot.i_batch - i); + const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i); - llama_sampling_accept(slot.smpl, id, true); + gpt_sampler_accept(slot.smpl, id, true); slot.n_decoded += 1; if (slot.n_decoded == 1) { @@ -2345,7 +2345,7 @@ struct server_context { result.tok = id; - const auto * cur_p = llama_sampling_get_candidates(slot.smpl); + const auto * cur_p = gpt_sampler_get_candidates(slot.smpl); // TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643 // fix if necessary diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 674158b857c4c..ffaa609cb8b26 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,7 +55,7 @@ int main(int argc, char ** argv) { return 1; } - llama_sampling * smpl = llama_sampling_init(model, llama_sampling_default_params()); + llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); // tokenize the prompt @@ -114,10 +114,10 @@ int main(int argc, char ** argv) { { const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - llama_sampling_set_logits(smpl, logits); + llama_sampler_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nullptr); + const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); // is it an end of generation? if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -159,8 +159,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); llama_batch_free(batch); - - llama_sampling_free(smpl); + llama_sampler_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index e950733665303..bb26f8eb522fa 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -21,7 +21,7 @@ struct seq_draft { std::vector tokens; std::vector> dists; - struct llama_sampling * smpl; + struct gpt_sampler * smpl = nullptr; }; int main(int argc, char ** argv) { @@ -180,14 +180,14 @@ int main(int argc, char ** argv) { bool has_eos = false; // target model sampling context (reuse the llama_context's sampling instance) - struct llama_sampling * smpl = llama_sampling_init(model_tgt, params.sparams); + struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams); // draft sequence data std::vector drafts(n_seq_dft); for (int s = 0; s < n_seq_dft; ++s) { - // allocate llama_sampling for each draft sequence - drafts[s].smpl = llama_sampling_init(model_dft, params.sparams); + // allocate gpt_sampler for each draft sequence + drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); @@ -229,13 +229,14 @@ int main(int argc, char ** argv) { bool accept = false; if (params.sparams.temp > 0) { // stochastic verification + const float * logits = llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); - llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft])); + gpt_sampler_set_logits(smpl, logits); - auto & dist_tgt = *llama_sampling_get_candidates(smpl); + auto & dist_tgt = *gpt_sampler_get_candidates(smpl); - 
llama_sampling_grammar(smpl, &dist_tgt); - llama_sampling_softmax(smpl, &dist_tgt); + gpt_sampler_apply_grammar(smpl, &dist_tgt); + gpt_sampler_sample_greedy(smpl, &dist_tgt, true); // applies softmax float p_tgt = 0.0f; float p_dft = 0.0f; @@ -280,7 +281,7 @@ int main(int argc, char ** argv) { accept = true; token_id = drafts[s].tokens[i_dft]; token_str = llama_token_to_piece(ctx_tgt, token_id); - llama_sampling_accept(smpl, token_id, true); + gpt_sampler_accept(smpl, token_id, true); LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); break; @@ -334,8 +335,8 @@ int main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = llama_sampling_sample_dist(smpl, &dist_tgt); - llama_sampling_accept(smpl, token_id, true); + token_id = gpt_sampler_sample_dist(smpl, &dist_tgt); + gpt_sampler_accept(smpl, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } @@ -344,9 +345,9 @@ int main(int argc, char ** argv) { // sample from the target model LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); - token_id = llama_sampling_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); + token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); - llama_sampling_accept(smpl, token_id, true); + gpt_sampler_accept(smpl, token_id, true); //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str()); @@ -436,7 +437,10 @@ int main(int argc, char ** argv) { break; } - llama_sampling_cp(smpl, drafts[0].smpl); + if (drafts[0].smpl) { + gpt_sampler_free(drafts[0].smpl); + } + drafts[0].smpl = gpt_sampler_cp(smpl); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -465,9 +469,9 @@ int main(int argc, char ** argv) { continue; } - llama_sampling_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); + gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); - const auto * cur_p = llama_sampling_get_candidates(drafts[s].smpl); + const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl); for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) { LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", @@ -505,7 +509,11 @@ int main(int argc, char ** argv) { drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; - llama_sampling_cp(drafts[s].smpl, drafts[n_seq_cur].smpl); + if (drafts[n_seq_cur].smpl) { + gpt_sampler_free(drafts[n_seq_cur].smpl); + } + drafts[n_seq_cur].smpl = gpt_sampler_cp(drafts[s].smpl); + sa.push_back(n_seq_cur); @@ -521,7 +529,7 @@ int main(int argc, char ** argv) { const int s = sa[is]; - llama_sampling_accept(drafts[s].smpl, id, true); + gpt_sampler_accept(drafts[s].smpl, id, true); drafts[s].tokens.push_back(id); // save cur_p.data into drafts[s].dists @@ -597,14 +605,14 @@ int main(int argc, char ** argv) { LOG_TEE("\ndraft:\n"); // TODO: print sampling/grammar timings for all drafts - llama_print_timings(ctx_dft, nullptr); + gpt_print_timings(ctx_dft, nullptr); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx_tgt, smpl); + gpt_print_timings(ctx_tgt, smpl); - llama_sampling_free(smpl); + gpt_sampler_free(smpl); for (int s = 0; s < n_seq_dft; ++s) { - llama_sampling_free(drafts[s].smpl); + gpt_sampler_free(drafts[s].smpl); } llama_batch_free(batch_dft); diff --git a/include/llama.h b/include/llama.h 
index 49011b7fe5ec6..8a02800ce8d5e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -46,9 +46,6 @@ #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ #define LLAMA_STATE_SEQ_VERSION 2 -// TODO: remove before merge -#define LLAMA_MAX_SAMPLERS 16 - #ifdef __cplusplus extern "C" { #endif @@ -1001,133 +998,6 @@ extern "C" { // // Sampling API - // TODO: remove before merge - // - - // TODO: llama_model should become llama_vocab - //LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params); - - //LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); - - //// Copies the internal state of the sampler (rng, prev, params, grammar, etc.) - //LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); - - //// - clear prev token - //// - reset grammar state - //LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl); - - //// Sampling parameter mutation - //// TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable - //LLAMA_API void llama_sampling_set_grammar (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); - //LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - - //// Set the logits from which to sample. - //// This call initializes the internal token candidates array. - //// The internal candidates are implicitly used by the sampling API below when no candidates are provided. - //LLAMA_API void llama_sampling_set_logits( - // struct llama_sampling * smpl, - // const float * logits); - - ///// @details Returns the current candidate tokens. - //LLAMA_API llama_token_data_array * llama_sampling_get_candidates( - // struct llama_sampling * smpl); - - //// The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object. - //// Each function can accept an array of token candidates. If the candidates are not provided, the internal - //// candidates are used. The internal candidates are initialized by llama_sampling_set_logits(). - - ///// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - //LLAMA_API void llama_sampling_softmax( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - //LLAMA_API void llama_sampling_top_k( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - //LLAMA_API void llama_sampling_top_p( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - //LLAMA_API void llama_sampling_min_p( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - //LLAMA_API void llama_sampling_tail_free( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. 
- //LLAMA_API void llama_sampling_typical( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Apply temperature and entropy - //LLAMA_API void llama_sampling_temp( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Apply constraints from grammar - //LLAMA_API void llama_sampling_grammar( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - ///// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - //LLAMA_API void llama_sampling_penalties( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Mirostat algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. - //LLAMA_API llama_token llama_sampling_sample_mirostat( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Selects the token with the highest probability. - ///// Does not compute the token probabilities. Use llama_sampling_softmax() instead. - //LLAMA_API llama_token llama_sampling_sample_greedy( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Randomly selects a token from the candidates based on their probability distribution. - //LLAMA_API llama_token llama_sampling_sample_dist( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers"). - //LLAMA_API llama_token llama_sampling_sample( - // struct llama_sampling * smpl, - // llama_token_data_array * candidates); - - ///// @details Accepts the sampled token into the sampling context. - ///// - adds it to "prev" tokens - ///// - updates the grammar state (if apply_grammar is true) - //LLAMA_API void llama_sampling_accept( - // struct llama_sampling * smpl, - // llama_token token, - // bool apply_grammar); - - ///// @details Get the number of accepted tokens so far (max of n_prev) - //LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl); - - ///// @details Get the ith accepted token - ///// @param ith [0, n_prev), ith == 0 is the last accepted token. 
- ///// returns LLAMA_TOKEN_NULL if ith is out of bounds - //LLAMA_API llama_token llama_sampling_prev( - // const struct llama_sampling * smpl, - // int32_t ith); - - ///// @details Get the last accepted token - ///// Same as llama_sampling_prev(smpl, 0) - ///// returns LLAMA_TOKEN_NULL if there are no accepted tokens - //LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl); - - // - // Sampling v2 API // // - Constraints // The llama_constraint object works on a set of candidate tokens (llama_token_data_array), by modifying their @@ -1203,7 +1073,7 @@ extern "C" { LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr); - // do not call if used with llama_sampler_add_constraint + // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint) LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); @@ -1221,11 +1091,7 @@ extern "C" { LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl); - - // TODO: should this take ownership so the user does not need to call llama_constraint_free - // or should just make a reference to the constraint so that it can be reused in multiple llama_sampler? - // - // seems better to take the ownership, otherwise the copying of the sampler will be more complicated + // important: takes ownership of the constraint object and will free it in llama_sampler_free LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index abf9d5a8ee789..eca2adb2be325 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -421,113 +421,8 @@ void llama_constraint_penalties_impl( candidates->sorted = false; } -llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { - llama_constraint_softmax_impl(candidates); - - // Estimate s_hat using the most probable m tokens - float s_hat = 0.0; - float sum_ti_bi = 0.0; - float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { - float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); - sum_ti_bi += t_i * b_i; - sum_ti_sq += t_i * t_i; - } - s_hat = sum_ti_bi / sum_ti_sq; - - // Compute k from the estimated s_hat and target surprise value - float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); - - // Sample the next word X using top-k sampling - llama_constraint_top_k_impl(candidates, int(k), 1); - llama_token X = llama_sampler_sample_dist_impl(candidates, rng); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - mu = mu - eta * e; - - return X; -} - -llama_token llama_sampler_sample_mirostat_v2_impl(struct 
llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { - llama_constraint_softmax_impl(candidates); - - // Truncate the words with surprise values greater than mu - candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > mu; - })); - - if (candidates->size == 0) { - candidates->size = 1; - } - - // Normalize the probabilities of the remaining words - llama_constraint_softmax_impl(candidates); - - // Sample the next word X from the remaining words - llama_token X = llama_sampler_sample_dist_impl(candidates, rng); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - - float observed_surprise = -log2f(candidates->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - mu = mu - eta * e; - - return X; -} - -llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) { - if (probs) { - // if probs are needed, we apply softmax to get the probabilities - llama_constraint_softmax_impl(candidates); - - // the candidates are sorted, so we can just return the first one - return candidates->data[0].id; - } - - // return the token with the highest logit - auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit < b.logit; - }); - - llama_token result = max_iter->id; - - return result; -} - -llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { - llama_constraint_softmax_impl(candidates); - - std::vector probs; - probs.reserve(candidates->size); - - for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); - } - - std::discrete_distribution<> dist(probs.begin(), probs.end()); - - const int idx = dist(rng); - llama_token result = candidates->data[idx].id; - - return result; -} - // -// sampling v2 +// sampling // // constraints @@ -1172,3 +1067,107 @@ int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { return smpl.prev.size(); } +llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { + llama_constraint_softmax_impl(candidates); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); + + // Sample the next word X using top-k sampling + llama_constraint_top_k_impl(candidates, int(k), 1); + llama_token X = llama_sampler_sample_dist_impl(candidates, rng); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = 
std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + mu = mu - eta * e; + + return X; +} + +llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { + llama_constraint_softmax_impl(candidates); + + // Truncate the words with surprise values greater than mu + candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > mu; + })); + + if (candidates->size == 0) { + candidates->size = 1; + } + + // Normalize the probabilities of the remaining words + llama_constraint_softmax_impl(candidates); + + // Sample the next word X from the remaining words + llama_token X = llama_sampler_sample_dist_impl(candidates, rng); + + // Compute error as the difference between observed surprise and target surprise value + size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + return candidate.id == X; + })); + + float observed_surprise = -log2f(candidates->data[X_idx].p); + float e = observed_surprise - tau; + + // Update mu using the learning rate and error + mu = mu - eta * e; + + return X; +} + +llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) { + if (probs) { + // if probs are needed, we apply softmax to get the probabilities + llama_constraint_softmax_impl(candidates); + + // the candidates are sorted, so we can just return the first one + return candidates->data[0].id; + } + + // return the token with the highest logit + auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + return a.logit < b.logit; + }); + + llama_token result = max_iter->id; + + return result; +} + +llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { + llama_constraint_softmax_impl(candidates); + + std::vector probs; + probs.reserve(candidates->size); + + for (size_t i = 0; i < candidates->size; ++i) { + probs.push_back(candidates->data[i].p); + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + const int idx = dist(rng); + llama_token result = candidates->data[idx].id; + + return result; +} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 501f11de8698b..f60d5b95f86da 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -10,6 +10,7 @@ struct llama_grammar; using llama_token_cnt = std::unordered_map; +// TODO: tmp exposed, until tests start using llama_constraint void llama_constraint_softmax_impl (struct llama_token_data_array * candidates); void llama_constraint_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); void llama_constraint_top_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); @@ -27,30 +28,6 @@ void llama_constraint_penalties_impl( float penalty_freq, float penalty_present); -/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. 
-/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. -/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. -/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); - -/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. -/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. -/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); - -llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs); -llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); - - - -// -// sampling v2 -// - // constraints struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); @@ -128,3 +105,21 @@ void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_d llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); + +/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. 
+/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); + +/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
+llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); + +llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs); +llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); diff --git a/src/llama.cpp b/src/llama.cpp index 903be45750df1..2b54a1ff337ac 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20609,346 +20609,6 @@ int32_t llama_chat_apply_template( // sampling // -//struct llama_sampling * llama_sampling_init(const struct llama_model * model, struct llama_sampling_params params) { -// return llama_sampling_init_impl(model->vocab, params); -//} - -//void llama_sampling_free(struct llama_sampling * smpl) { -// if (smpl == nullptr) { -// return; -// } - -// llama_sampling_free_impl(smpl); -//} - -//struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) { -// return llama_sampling_cp_impl(*smpl); -//} - -//void llama_sampling_reset(struct llama_sampling * smpl) { -// llama_sampling_reset_impl(*smpl); -//} - -//void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { -// llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root); -//} - -//void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { -// llama_sampling_set_logit_bias_impl(*smpl, n_logit_bias, logit_bias); -//} - -//void llama_sampling_set_logits(struct llama_sampling * smpl, const float * logits) { -// const int n_vocab = smpl->vocab.n_vocab; - -// smpl->cur.resize(n_vocab); - -// for (llama_token token_id = 0; token_id < n_vocab; token_id++) { -// smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; -// } - -// for (const auto & lb : smpl->logit_bias) { -// smpl->cur[lb.token].logit += lb.bias; -// } - -// if (smpl->params.ignore_eos) { -// smpl->cur[llama_token_eos_impl(smpl->vocab)].logit = -INFINITY; -// } - -// smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; - -// // apply penalties -// { -// const float nl_logit = smpl->cur[llama_token_nl_impl(smpl->vocab)].logit; - -// llama_sampling_penalties(smpl, &smpl->cur_p); - -// if (!smpl->params.penalize_nl) { -// for (size_t idx = 0; idx < smpl->cur_p.size; idx++) { -// if (smpl->cur_p.data[idx].id == llama_token_nl_impl(smpl->vocab)) { -// smpl->cur_p.data[idx].logit = nl_logit; -// break; -// } -// } -// } -// } -//} - -//llama_token_data_array * llama_sampling_get_candidates(struct llama_sampling * smpl) { -// return &smpl->cur_p; -//} - -//void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// llama_sampling_softmax_impl(candidates); -//} - -//void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep); -//} - -//void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep); -//} - 
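// A minimal usage sketch of how the per-step calls being removed here (top_k, top_p,
// temp, then dist sampling) map onto the constraint-based llama_sampler API declared
// in include/llama.h above. The 40 / 0.9f / 0.4f values and the `model`, `ctx`,
// `i_batch` arguments are illustrative assumptions, not defaults from this patch.
static llama_token example_sample_with_constraints(const llama_model * model, llama_context * ctx, int i_batch) {
    llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());

    // the sampler takes ownership of the constraints and frees them in llama_sampler_free
    llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
    llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9f, 1));
    llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4f));

    llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, i_batch));

    // nullptr -> sample from the internal candidates initialized by set_logits
    const llama_token id = llama_sampler_sample_dist(smpl, nullptr);

    llama_sampler_accept(smpl, id);

    llama_sampler_free(smpl);

    return id;
}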
-//void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep); -//} - -//void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep); -//} - -//void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// llama_sampling_typical_impl(candidates, smpl->params.typ_p, smpl->params.min_keep); -//} - -//void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// if (smpl->params.dynatemp_range > 0) { -// const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range); -// const float dynatemp_max = std::max(0.0f, smpl->params.temp + smpl->params.dynatemp_range); - -// llama_sampling_entropy_impl(candidates, dynatemp_min, dynatemp_max, smpl->params.dynatemp_exponent); -// } else { -// llama_sampling_temp_impl(candidates, smpl->params.temp); -// } -//} - -//void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_grammar_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// if (smpl->grammar) { -// llama_sampling_grammar_impl(candidates, *smpl->grammar); - -// smpl->n_grammar++; -// } -//} - -//void llama_sampling_penalties( -// struct llama_sampling * smpl, -// llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// const size_t penalty_last_n = std::min(smpl->params.penalty_last_n, smpl->prev.size()); - -// const float penalty_repeat = smpl->params.penalty_repeat; -// const float penalty_freq = smpl->params.penalty_freq; -// const float penalty_present = smpl->params.penalty_present; - -// if ((penalty_last_n == 0) || -// (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { -// return; -// } - -// // Create a frequency map to count occurrences of each token in last_tokens -// // TODO: move to sampling state and avoid reallocation -// llama_token_cnt token_count; -// for (size_t i = 0; i < penalty_last_n; ++i) { -// token_count[smpl->prev.rat(i)]++; -// } - -// llama_sampling_penalties_impl(candidates, token_count, penalty_repeat, penalty_freq, penalty_present); -//} - -//llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// const auto type = smpl->params.mirostat; - -// llama_token res; - -// if (type == 1) { -// res = llama_sampling_sample_mirostat_impl(candidates, -// smpl->rng, -// smpl->params.mirostat_tau, -// smpl->params.mirostat_eta, -// 100, -// smpl->vocab.n_vocab, -// smpl->mirostat_mu); -// } else if (type == 2) { -// res = llama_sampling_sample_mirostat_v2_impl(candidates, -// smpl->rng, -// 
smpl->params.mirostat_tau, -// smpl->params.mirostat_eta, -// smpl->mirostat_mu); -// } else { -// GGML_ABORT("invalid mirostat type: %d", type); -// } - -// smpl->n_sample++; - -// return res; -//} - -//llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// auto res = llama_sampling_sample_greedy_impl(candidates); - -// smpl->n_sample++; - -// return res; -//} - -//llama_token llama_sampling_sample_dist(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// auto res = llama_sampling_sample_dist_impl(candidates, smpl->rng); - -// smpl->n_sample++; - -// return res; -//} - -//llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { -// time_meas tm(smpl->t_sample_us); - -// if (candidates == nullptr) { -// candidates = &smpl->cur_p; -// } - -// const auto & params = smpl->params; - -// const float temp = params.temp; -// const int mirostat = params.mirostat; - -// auto & cur_p = candidates; - -// llama_token res = 0; - -// if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) { -// // greedy sampling, with probs -// llama_sampling_softmax_impl(cur_p); -// res = cur_p->data[0].id; -// } else if (temp == 0.0f) { -// // greedy sampling, no probs -// res = llama_sampling_sample_greedy(smpl, cur_p); -// } else { -// if (mirostat != 0) { -// llama_sampling_temp(smpl, cur_p); -// res = llama_sampling_sample_mirostat(smpl, cur_p); -// } else { -// for (const auto & sampler : smpl->samplers) { -// switch (sampler) { -// case LLAMA_CONSTRAINT_TYPE_TOP_K: llama_sampling_top_k_impl (cur_p, smpl->params.top_k, smpl->params.min_keep); break; -// case LLAMA_CONSTRAINT_TYPE_TFS_Z: llama_sampling_tail_free_impl(cur_p, smpl->params.tfs_z, smpl->params.min_keep); break; -// case LLAMA_CONSTRAINT_TYPE_TYPICAL_P: llama_sampling_typical_impl (cur_p, smpl->params.typ_p, smpl->params.min_keep); break; -// case LLAMA_CONSTRAINT_TYPE_TOP_P: llama_sampling_top_p_impl (cur_p, smpl->params.top_p, smpl->params.min_keep); break; -// case LLAMA_CONSTRAINT_TYPE_MIN_P: llama_sampling_min_p_impl (cur_p, smpl->params.min_p, smpl->params.min_keep); break; -// case LLAMA_CONSTRAINT_TYPE_TEMPERATURE: llama_sampling_temp_impl (cur_p, temp); break; -// default : break; -// } -// } - -// res = llama_sampling_sample_dist(smpl, cur_p); - -// //{ -// // const int n_top = 10; -// // LOG("top %d candidates:\n", n_top); - -// // for (int i = 0; i < n_top; i++) { -// // const llama_token id = cur_p.data[i].id; -// // (void)id; // To avoid a warning that id is unused when logging is disabled. 
-// // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); -// // } -// //} - -// //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); -// } -// } - -// smpl->n_sample++; - -// return res; -//} - -//void llama_sampling_accept( -// struct llama_sampling * smpl, -// llama_token token, -// bool apply_grammar) { -// time_meas tm(smpl->t_accept_us); - -// llama_sampling_accept_impl(*smpl, token, apply_grammar); - -// smpl->n_accept++; -//} - -//int llama_sampling_n_prev(const struct llama_sampling * smpl) { -// return llama_sampling_n_prev_impl(*smpl); -//} - -//llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) { -// return llama_sampling_prev_impl(*smpl, ith); -//} - -//llama_token llama_sampling_last(const struct llama_sampling * smpl) { -// return llama_sampling_prev_impl(*smpl, 0); -//} - -// -// sampling v2 -// - struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) { return llama_constraint_init_top_k_impl(k, min_keep); } @@ -21070,6 +20730,10 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * candidates) { time_meas tm(smpl->t_sample_us); + if (candidates == nullptr) { + candidates = &smpl->cur_p; + } + llama_sampler_apply_impl(*smpl, candidates); } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index f5e32a741b23c..16eeaa1c8e01b 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -30,9 +30,9 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; DUMP(&candidates_p); - llama_sampling_tail_free_impl(&candidates_p, z, 1); + llama_constraint_tail_free_impl(&candidates_p, z, 1); DUMP(&candidates_p); GGML_ASSERT(candidates_p.size == expected_probs.size()); @@ -96,9 +96,9 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector Date: Wed, 4 Sep 2024 14:07:26 +0300 Subject: [PATCH 14/47] examples : fix build ggml-ci --- common/sampling.cpp | 10 +++++----- examples/batched.swift/Sources/main.swift | 20 +++++++++---------- examples/batched/batched.cpp | 4 ++-- .../llama/src/main/cpp/llama-android.cpp | 6 +++--- .../llama.cpp.swift/LibLlama.swift | 8 ++++---- 5 files changed, 24 insertions(+), 24 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 4e88432245ed6..b4063fe3168f5 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -52,12 +52,12 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st lparams.mirostat_eta = params.mirostat_eta; auto * result = new gpt_sampler { - .params = params, - .bias = llama_constraint_init_logit_bias( + /* .params = */ params, + /* .bias = */ llama_constraint_init_logit_bias( model, params.logit_bias.size(), params.logit_bias.data()), - .pnlt = llama_constraint_init_penalties( + /* .pnlt = */ llama_constraint_init_penalties( model, params.penalty_last_n, params.penalty_repeat, @@ -65,8 +65,8 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st params.penalty_present, params.penalize_nl, params.ignore_eos), - .grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"), - .smpl = llama_sampler_init(model, lparams) + /* .grmr = */ llama_constraint_init_grammar(model, params.grammar.c_str(), "root"), 
+ /* .smpl = */ llama_sampler_init(model, lparams) }; for (const auto & cnstr : params.constraints) { diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 81763217a91a8..4d73ccd24ebfb 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -50,20 +50,24 @@ defer { llama_free(context) } -var sparams = llama_sampling_params() +var sparams = llama_sampler_params() sparams.top_k = 40 sparams.top_p = 0.9 sparams.temp = 0.4 -let smpl = llama_sampling_init(model, sparams) +let smpl = llama_sampler_init(model, sparams) guard smpl != nil else { print("Failed to initialize sampling") exit(1) } defer { - llama_sampling_free(smpl) + llama_sampler_free(smpl) } +llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1)); +llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9, 1)); +llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4)); + let n_ctx = llama_n_ctx(context) print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n") @@ -138,15 +142,11 @@ while n_cur <= n_len { var logits = llama_get_logits_ith(context, i_batch[i]) - llama_sampling_set_logits(smpl, logits) - - llama_sampling_top_k(smpl, nil) - llama_sampling_top_p(smpl, nil) - llama_sampling_temp (smpl, nil) + llama_sampler_set_logits(smpl, logits) - let new_token_id = llama_sampling_sample_dist(smpl, nil) + let new_token_id = llama_sampler_sample_dist(smpl, nil) - // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil); + // const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nil, false); // is it an end of stream? -> mark the stream as finished if llama_token_is_eog(model, new_token_id) || n_cur == n_len { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 3052b96aeaf65..0f35f6cd58775 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -71,7 +71,7 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_init(model, sparams); llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); - llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_p)); + llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp)); if (ctx == NULL) { @@ -179,7 +179,7 @@ int main(int argc, char ** argv) { const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr); - //const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr); + //const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); // is it an end of generation? 
-> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index c33f55f720223..666e89764834d 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -386,7 +386,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( jobject intvar_ncur ) { const auto context = reinterpret_cast(context_pointer); - const auto sampling = reinterpret_cast(sampling_pointer); + const auto sampling = reinterpret_cast(sampling_pointer); const auto batch = reinterpret_cast(batch_pointer); const auto model = llama_get_model(context); @@ -396,10 +396,10 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( const auto * logits = llama_get_logits_ith(context, batch->n_tokens - 1); - llama_sampling_set_logits(sampling, logits); + llama_sampler_set_logits(sampling, logits); // sample the most likely token - const auto new_token_id = llama_sampling_sample_greedy(sampling, nullptr); + const auto new_token_id = llama_sampler_sample_greedy(sampling, nullptr, false); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 515170f679f82..930336b270f70 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -43,11 +43,11 @@ actor LlamaContext { self.tokens_list = [] self.batch = llama_batch_init(512, 0, 1) self.temporary_invalid_cchars = [] - self.sampling = llama_sampling_init(context, llama_sampling_default_params()) + self.sampling = llama_sampler_init(context, llama_sampler_default_params()) } deinit { - llama_sampling_free(sampling) + llama_sampler_free(sampling) llama_batch_free(batch) llama_free(context) llama_free_model(model) @@ -149,9 +149,9 @@ actor LlamaContext { let n_vocab = llama_n_vocab(model) let logits = llama_get_logits_ith(context, batch.n_tokens - 1) - llama_sampling_set_logits(sampling, logits); + llama_sampler_set_logits(sampling, logits); - new_token_id = llama_sampling_sample_greedy(sampling, nil) + new_token_id = llama_sampler_sample_greedy(sampling, nil, false) if llama_token_is_eog(model, new_token_id) || n_cur == n_len { print("\n") From fdb52aa65704d821f916b34531e8d4cd2fa398eb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 14:17:19 +0300 Subject: [PATCH 15/47] common : fix gpt_sampler_cp ggml-ci --- common/sampling.cpp | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index b4063fe3168f5..123c6b2a710af 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -110,12 +110,13 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { } struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl) { - gpt_sampler * result = new gpt_sampler(); - - result->grmr = llama_constraint_cp(gsmpl->grmr); - result->smpl = llama_sampler_cp(gsmpl->smpl); - - return result; + return new gpt_sampler { + /* .params = */ gsmpl->params, + /* .bias = */ llama_constraint_cp(gsmpl->bias), + /* .pnlt = */ llama_constraint_cp(gsmpl->pnlt), + /* .grmr = */ llama_constraint_cp(gsmpl->grmr), + /* .smpl = */ llama_sampler_cp(gsmpl->smpl) + }; } void gpt_sampler_accept(struct gpt_sampler * gsmpl, 
llama_token token, bool apply_grammar) { @@ -145,7 +146,7 @@ llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { } void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) { - llama_print_timings(ctx, gsmpl->smpl); + llama_print_timings(ctx, gsmpl ? gsmpl->smpl : nullptr); } static llama_token gpt_sampler_sample( From ca5d21c17a6875336cae4a1ae91a2c02d9fc35bf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 14:26:23 +0300 Subject: [PATCH 16/47] grammar : fix reset call ggml-ci --- examples/batched.swift/Sources/main.swift | 3 --- src/llama-sampling.cpp | 13 +++++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 4d73ccd24ebfb..380040e572ecc 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -51,9 +51,6 @@ defer { } var sparams = llama_sampler_params() -sparams.top_k = 40 -sparams.top_p = 0.9 -sparams.temp = 0.4 let smpl = llama_sampler_init(model, sparams) guard smpl != nil else { diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index eca2adb2be325..a134fda95378b 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -724,17 +724,18 @@ static struct llama_constraint_i llama_constraint_grammar_i = { }, /* .reset = */ [](struct llama_constraint * cnstr) { auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; - if (ctx->grammar) { - llama_grammar_free_impl(ctx->grammar); - ctx->grammar = nullptr; + if (!ctx->grammar) { + return; } - if (!ctx->grammar_str.empty()) { - ctx->grammar = llama_grammar_init_impl(nullptr, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); - } + auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str()); + + llama_grammar_free_impl(ctx->grammar); + ctx->grammar = grammar_new; }, /* .copy = */ [](const struct llama_constraint * cnstr) { const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx; + auto * result = llama_constraint_init_grammar_impl(*ctx_src->grammar->vocab, nullptr, nullptr); auto * ctx_dst = (llama_constraint_context_grammar *) result->ctx; From c024fe45b0728a31d2245ac6cf365fe4b0a67293 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 15:01:31 +0300 Subject: [PATCH 17/47] constraint : clean-up and simplify --- common/sampling.cpp | 14 +- common/sampling.h | 6 +- include/llama.h | 21 +- src/llama-grammar.cpp | 16 +- src/llama-grammar.h | 2 +- src/llama-sampling.cpp | 412 ++++++++++++++++++++-------------------- src/llama-sampling.h | 27 +-- src/llama.cpp | 42 ++-- tests/test-sampling.cpp | 174 +++++++++-------- 9 files changed, 357 insertions(+), 357 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 123c6b2a710af..34371bc241167 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -232,18 +232,18 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); } -void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates) { - GGML_ASSERT(candidates != nullptr); +void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { + GGML_ASSERT(cur_p != nullptr); - llama_constraint_apply(gsmpl->grmr, candidates); + llama_constraint_apply(gsmpl->grmr, cur_p); } -llama_token gpt_sampler_sample_dist(struct gpt_sampler 
* gsmpl, llama_token_data_array * candidates) { - return llama_sampler_sample_dist(gsmpl->smpl, candidates); +llama_token gpt_sampler_sample_dist(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { + return llama_sampler_sample_dist(gsmpl->smpl, cur_p); } -llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs) { - return llama_sampler_sample_greedy(gsmpl->smpl, candidates, probs); +llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p, bool probs) { + return llama_sampler_sample_greedy(gsmpl->smpl, cur_p, probs); } std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { diff --git a/common/sampling.h b/common/sampling.h index 8cb3da762019e..a04645a676a44 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -93,10 +93,10 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); // llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx); -void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * candidates); +void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); -llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * candidates); -llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * candidates, bool probs); +llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); +llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p, bool probs); // helpers diff --git a/include/llama.h b/include/llama.h index 8a02800ce8d5e..0f08c44c0825c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1027,11 +1027,11 @@ extern "C" { struct llama_constraint_i { // TODO: add name API - void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL - void (*apply) ( struct llama_constraint * cnstr, llama_token_data_array * candidates); // required - void (*reset) ( struct llama_constraint * cnstr); // can be NULL - struct llama_constraint * (*copy) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL - void (*free) ( struct llama_constraint * cnstr); // can be NULL + void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL + void (*apply) ( struct llama_constraint * cnstr, llama_token_data_array * cur_p); // required + void (*reset) ( struct llama_constraint * cnstr); // can be NULL + struct llama_constraint * (*copy) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL + void (*free) ( struct llama_constraint * cnstr); // can be NULL if ctx is NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); @@ -1044,6 +1044,7 @@ extern "C" { llama_constraint_context_t ctx; }; + LLAMA_API struct llama_constraint * llama_constraint_init_softmax (); LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); @@ -1077,7 +1078,7 @@ extern "C" { LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token 
token); - LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates); + LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * cur_p); LLAMA_API void llama_constraint_reset (struct llama_constraint * cnstr); // samplers @@ -1095,11 +1096,11 @@ extern "C" { LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); - LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * candidates); + LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p); - LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * candidates); - LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * candidates, bool probs); - LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates); + LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p); + LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs); + LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p); /// @details Get the number of accepted tokens so far (max of n_prev) LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 092a738aafe6b..a9813ebbfb228 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1069,7 +1069,7 @@ struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & gramma return result; } -void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * candidates) { +void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * cur_p) { GGML_ASSERT(grammar.vocab != nullptr); bool allow_eog = false; @@ -1081,21 +1081,21 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ } std::vector, llama_partial_utf8>> candidates_decoded; - candidates_decoded.reserve(candidates->size); + candidates_decoded.reserve(cur_p->size); llama_grammar_candidates candidates_grammar; - candidates_grammar.reserve(candidates->size); + candidates_grammar.reserve(cur_p->size); - for (size_t i = 0; i < candidates->size; ++i) { - const llama_token id = candidates->data[i].id; + for (size_t i = 0; i < cur_p->size; ++i) { + const llama_token id = cur_p->data[i].id; const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); if (llama_token_is_eog_impl(*grammar.vocab, id)) { if (!allow_eog) { - candidates->data[i].logit = -INFINITY; + cur_p->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { - candidates->data[i].logit = -INFINITY; + cur_p->data[i].logit = -INFINITY; } else { candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); @@ -1104,7 +1104,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_ const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { - 
candidates->data[reject.index].logit = -INFINITY; + cur_p->data[reject.index].logit = -INFINITY; } } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 9b13354f67c74..6b9a2af8dd725 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -136,7 +136,7 @@ struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & gramma // TODO: move the API below as member functions of llama_grammar void llama_grammar_apply_impl( const struct llama_grammar & grammar, - llama_token_data_array * candidates); + llama_token_data_array * cur_p); void llama_grammar_accept_impl( struct llama_grammar & grammar, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index a134fda95378b..99e0edfd9d19c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -24,51 +24,51 @@ static void llama_log_softmax(float * array, size_t size) { } } -void llama_constraint_softmax_impl(llama_token_data_array * candidates) { - GGML_ASSERT(candidates->size > 0); +static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) { + GGML_ASSERT(cur_p->size > 0); // Sort the logits in descending order - if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + if (!cur_p->sorted) { + std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }); - candidates->sorted = true; + cur_p->sorted = true; } - float max_l = candidates->data[0].logit; + float max_l = cur_p->data[0].logit; float cum_sum = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - float p = expf(candidates->data[i].logit - max_l); - candidates->data[i].p = p; + for (size_t i = 0; i < cur_p->size; ++i) { + float p = expf(cur_p->data[i].logit - max_l); + cur_p->data[i].p = p; cum_sum += p; } - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].p /= cum_sum; + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= cum_sum; } } -void llama_constraint_top_k_impl(llama_token_data_array * candidates, int32_t k, size_t min_keep) { +static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast - // if (k >= (int32_t)candidates->size) { + // if (k >= (int32_t)cur_p->size) { // return; // } if (k <= 0) { - k = candidates->size; + k = cur_p->size; } k = std::max(k, (int) min_keep); - k = std::min(k, (int) candidates->size); + k = std::min(k, (int) cur_p->size); // Sort scores in descending order - if (!candidates->sorted) { + if (!cur_p->sorted) { auto comp = [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }; if (k <= 128) { - std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp); + std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp); } else { constexpr int nbuckets = 128; constexpr float bucket_low = -10.0f; @@ -76,11 +76,11 @@ void llama_constraint_top_k_impl(llama_token_data_array * candidates, int32_t k, constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low); constexpr float bucket_inter = -bucket_low * bucket_scale; - std::vector bucket_idx(candidates->size); + std::vector bucket_idx(cur_p->size); std::vector histo(nbuckets, 0); - for (int i = 0; i < (int)candidates->size; ++i) { - const float val = candidates->data[i].logit; + for (int i = 0; 
i < (int)cur_p->size; ++i) { + const float val = cur_p->data[i].logit; int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low); ib = std::max(0, std::min(nbuckets-1, ib)); bucket_idx[i] = ib; @@ -102,10 +102,10 @@ void llama_constraint_top_k_impl(llama_token_data_array * candidates, int32_t k, bucket_ptrs.push_back(ptr); ptr += histo[j]; } - for (int i = 0; i < (int)candidates->size; ++i) { + for (int i = 0; i < (int)cur_p->size; ++i) { int j = bucket_idx[i]; if (j >= ib) { - *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i]; + *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i]; } } @@ -118,27 +118,27 @@ void llama_constraint_top_k_impl(llama_token_data_array * candidates, int32_t k, } std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp); - std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data)); + std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data)); } - candidates->sorted = true; + cur_p->sorted = true; } - candidates->size = k; + cur_p->size = k; } -void llama_constraint_top_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { +static void llama_constraint_top_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_constraint_softmax_impl(candidates); + llama_constraint_softmax_impl(cur_p); // Compute the cumulative probabilities float cum_sum = 0.0f; - size_t last_idx = candidates->size; + size_t last_idx = cur_p->size; - for (size_t i = 0; i < candidates->size; ++i) { - cum_sum += candidates->data[i].p; + for (size_t i = 0; i < cur_p->size; ++i) { + cum_sum += cur_p->data[i].p; // Check if the running sum is at least p or if we have kept at least min_keep tokens // we set the last index to i+1 to indicate that the current iterate should be included in the set @@ -149,77 +149,77 @@ void llama_constraint_top_p_impl(llama_token_data_array * candidates, float p, s } // Resize the output vector to keep only the top-p tokens - candidates->size = last_idx; + cur_p->size = last_idx; } -void llama_constraint_min_p_impl(llama_token_data_array * candidates, float p, size_t min_keep) { - if (p <= 0.0f || !candidates->size) { +static void llama_constraint_min_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { + if (p <= 0.0f || !cur_p->size) { return; } bool min_p_applied = false; - // if the candidates aren't sorted, try the unsorted implementation first - if (!candidates->sorted) { + // if the cur_p aren't sorted, try the unsorted implementation first + if (!cur_p->sorted) { std::vector filtered_tokens; float max_logit = -FLT_MAX; - for (size_t i = 0; i < candidates->size; ++i) { - max_logit = std::max(max_logit, candidates->data[i].logit); + for (size_t i = 0; i < cur_p->size; ++i) { + max_logit = std::max(max_logit, cur_p->data[i].logit); } const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max - for (size_t i = 0; i < candidates->size; ++i) { - if (candidates->data[i].logit >= min_logit) { - filtered_tokens.push_back(candidates->data[i]); + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit >= min_logit) { + filtered_tokens.push_back(cur_p->data[i]); } } // if we have enough values the operation was a success if (filtered_tokens.size() >= min_keep) { - memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data)); - candidates->size = filtered_tokens.size(); + memcpy(cur_p->data, filtered_tokens.data(), 
filtered_tokens.size()*sizeof(llama_token_data)); + cur_p->size = filtered_tokens.size(); min_p_applied = true; } } - // if the candidates are sorted or the unsorted implementation failed, use this implementation + // if the cur_p are sorted or the unsorted implementation failed, use this implementation if (!min_p_applied) { // Sort the logits in descending order - if (!candidates->sorted) { - std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + if (!cur_p->sorted) { + std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit > b.logit; }); - candidates->sorted = true; + cur_p->sorted = true; } - const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max + const float min_logit = cur_p->data[0].logit + logf(p); // min logit for p_i >= p * p_max size_t i = 1; // first token always matches - for (; i < candidates->size; ++i) { - if (candidates->data[i].logit < min_logit && i >= min_keep) { + for (; i < cur_p->size; ++i) { + if (cur_p->data[i].logit < min_logit && i >= min_keep) { break; // prob too small } } // Resize the output vector to keep only the matching tokens - candidates->size = i; + cur_p->size = i; } } -void llama_constraint_tail_free_impl(llama_token_data_array * candidates, float z, size_t min_keep) { - if (z >= 1.0f || candidates->size <= 2) { +static void llama_constraint_tail_free_impl(llama_token_data_array * cur_p, float z, size_t min_keep) { + if (z >= 1.0f || cur_p->size <= 2) { return; } - llama_constraint_softmax_impl(candidates); + llama_constraint_softmax_impl(cur_p); // Compute the first and second derivatives - std::vector first_derivatives(candidates->size - 1); - std::vector second_derivatives(candidates->size - 2); + std::vector first_derivatives(cur_p->size - 1); + std::vector second_derivatives(cur_p->size - 2); for (size_t i = 0; i < first_derivatives.size(); ++i) { - first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p; + first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p; } for (size_t i = 0; i < second_derivatives.size(); ++i) { second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1]; @@ -246,7 +246,7 @@ void llama_constraint_tail_free_impl(llama_token_data_array * candidates, float } float cum_sum = 0.0f; - size_t last_idx = candidates->size; + size_t last_idx = cur_p->size; for (size_t i = 0; i < second_derivatives.size(); ++i) { cum_sum += second_derivatives[i]; @@ -258,10 +258,10 @@ void llama_constraint_tail_free_impl(llama_token_data_array * candidates, float } // Resize the output vector to keep only the tokens above the tail location - candidates->size = last_idx; + cur_p->size = last_idx; } -void llama_constraint_typical_impl(llama_token_data_array * candidates, float p, size_t min_keep) { +static void llama_constraint_typical_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -269,22 +269,22 @@ void llama_constraint_typical_impl(llama_token_data_array * candidates, float p, } // Compute the softmax of logits and calculate entropy - llama_constraint_softmax_impl(candidates); + llama_constraint_softmax_impl(cur_p); float entropy = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - entropy += -candidates->data[i].p * logf(candidates->data[i].p); + for (size_t 
i = 0; i < cur_p->size; ++i) { + entropy += -cur_p->data[i].p * logf(cur_p->data[i].p); } // Compute the absolute difference between negative log probability and entropy for each candidate std::vector shifted_scores; - for (size_t i = 0; i < candidates->size; ++i) { - float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy); + for (size_t i = 0; i < cur_p->size; ++i) { + float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy); shifted_scores.push_back(shifted_score); } // Sort tokens based on the shifted_scores and their corresponding indices - std::vector indices(candidates->size); + std::vector indices(cur_p->size); std::iota(indices.begin(), indices.end(), 0); std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) { @@ -297,7 +297,7 @@ void llama_constraint_typical_impl(llama_token_data_array * candidates, float p, for (size_t i = 0; i < indices.size(); ++i) { size_t idx = indices[i]; - cum_sum += candidates->data[idx].p; + cum_sum += cur_p->data[idx].p; // Check if the running sum is greater than typical or if we have kept at least min_keep tokens if (cum_sum > p && i >= min_keep - 1) { @@ -307,39 +307,39 @@ void llama_constraint_typical_impl(llama_token_data_array * candidates, float p, } // Resize the output vector to keep only the locally typical tokens - std::vector new_candidates; + std::vector cur_p_new; for (size_t i = 0; i < last_idx; ++i) { size_t idx = indices[i]; - new_candidates.push_back(candidates->data[idx]); + cur_p_new.push_back(cur_p->data[idx]); } - // Replace the data in candidates with the new_candidates data - std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); - candidates->size = new_candidates.size(); - candidates->sorted = false; + // Replace the data in cur_p with the cur_p_new data + std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data); + cur_p->size = cur_p_new.size(); + cur_p->sorted = false; } -void llama_constraint_entropy_impl(llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { +static void llama_constraint_entropy_impl(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates - if(candidates->size <= 1) { + if (cur_p->size <= 1) { return; } // Calculate maximum possible entropy - float max_entropy = -logf(1.0f / candidates->size); + float max_entropy = -logf(1.0f / cur_p->size); - llama_constraint_softmax_impl(candidates); + llama_constraint_softmax_impl(cur_p); // Calculate entropy of the softmax probabilities float entropy = 0.0f; - for (size_t i = 0; i < candidates->size; ++i) { - float prob = candidates->data[i].p; + for (size_t i = 0; i < cur_p->size; ++i) { + float prob = cur_p->data[i].p; if (prob > 0.0f) { // Ensure no log(0) entropy -= prob * logf(prob); } } - // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above) + // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above) float normalized_entropy = entropy / max_entropy; // Map the normalized entropy to the desired temperature range using the power function @@ -355,52 +355,52 @@ void llama_constraint_entropy_impl(llama_token_data_array * candidates, float mi #endif // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].logit /= dyn_temp; + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= dyn_temp; } // Re-compute softmax 
probabilities after scaling logits with dynamic temperature - const double max_l_double = candidates->data[0].logit; + const double max_l_double = cur_p->data[0].logit; double cum_sum_double = 0.0; - for (size_t i = 0; i < candidates->size; ++i) { - double p = exp(candidates->data[i].logit - max_l_double); - candidates->data[i].p = p; // Store the scaled probability + for (size_t i = 0; i < cur_p->size; ++i) { + double p = exp(cur_p->data[i].logit - max_l_double); + cur_p->data[i].p = p; // Store the scaled probability cum_sum_double += p; } - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities } #ifdef DEBUG // Print the updated top 25 probabilities after temperature scaling LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n"); - for (size_t i = 0; i < 25 && i < candidates->size; ++i) { - LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f); + for (size_t i = 0; i < 25 && i < cur_p->size; ++i) { + LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f); } #endif } -void llama_constraint_temp_impl(llama_token_data_array * candidates, float temp) { - for (size_t i = 0; i < candidates->size; ++i) { - candidates->data[i].logit /= temp; +static void llama_constraint_temp_impl(llama_token_data_array * cur_p, float temp) { + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= temp; } } -void llama_constraint_grammar_impl(llama_token_data_array * candidates, const struct llama_grammar & grammar) { - llama_grammar_apply_impl(grammar, candidates); +static void llama_constraint_grammar_impl(llama_token_data_array * cur_p, const struct llama_grammar & grammar) { + llama_grammar_apply_impl(grammar, cur_p); } void llama_constraint_penalties_impl( - llama_token_data_array * candidates, + llama_token_data_array * cur_p, const llama_token_cnt & token_count, float penalty_repeat, float penalty_freq, float penalty_present) { - // Apply frequency and presence penalties to the candidates - for (size_t i = 0; i < candidates->size; ++i) { - const auto token_iter = token_count.find(candidates->data[i].id); + // Apply frequency and presence penalties to the cur_p + for (size_t i = 0; i < cur_p->size; ++i) { + const auto token_iter = token_count.find(cur_p->data[i].id); if (token_iter == token_count.end()) { continue; } @@ -409,23 +409,42 @@ void llama_constraint_penalties_impl( // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // This is common fix for this problem, which is to multiply by the penalty instead of dividing. 
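// The comment above summarizes the penalty rule; what follows is a minimal,
// self-contained sketch of that arithmetic with made-up values (not part of the
// patch): the repeat penalty always pushes a logit toward zero regardless of its
// sign, and the frequency/presence terms are then subtracted based on how often
// the token appeared in the recent window.
#include <cstdio>

int main() {
    const float penalty_repeat  = 1.5f;
    const float penalty_freq    = 0.2f;
    const float penalty_present = 0.1f;

    float logits[2] = { 2.0f, -1.0f }; // one positive and one negative logit
    int   counts[2] = { 3,     3    }; // both tokens occurred 3 times recently

    for (int i = 0; i < 2; ++i) {
        float l = logits[i];
        l  = (l <= 0.0f) ? l * penalty_repeat : l / penalty_repeat; // scale toward zero
        l -= counts[i] * penalty_freq + (counts[i] > 0 ? 1.0f : 0.0f) * penalty_present;
        printf("logit[%d]: %+.3f -> %+.3f\n", i, logits[i], l);
    }
    return 0;
}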
- if (candidates->data[i].logit <= 0) { - candidates->data[i].logit *= penalty_repeat; + if (cur_p->data[i].logit <= 0) { + cur_p->data[i].logit *= penalty_repeat; } else { - candidates->data[i].logit /= penalty_repeat; + cur_p->data[i].logit /= penalty_repeat; } - candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; + cur_p->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present; } - candidates->sorted = false; + cur_p->sorted = false; } // -// sampling +// constraints // -// constraints +// softmax + +static struct llama_constraint_i llama_constraint_softmax_i = { + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) { + llama_constraint_softmax_impl(cur_p); + }, + /* .reset = */ nullptr, + /* .copy = */ nullptr, + /* .free = */ nullptr, +}; + +struct llama_constraint * llama_constraint_init_softmax_impl() { + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_softmax_i, + /* .ctx = */ nullptr, + }; + + return result; +} // top-k @@ -436,9 +455,9 @@ struct llama_constraint_context_top_k { static struct llama_constraint_i llama_constraint_top_k_i = { /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; - llama_constraint_top_k_impl(candidates, ctx->k, ctx->min_keep); + llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep); }, /* .reset = */ nullptr, /* .copy = */ [](const struct llama_constraint * cnstr) { @@ -446,10 +465,7 @@ static struct llama_constraint_i llama_constraint_top_k_i = { return llama_constraint_init_top_k_impl(ctx->k, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - delete (llama_constraint_context_top_k *) cnstr->ctx; - } - delete cnstr; + delete (llama_constraint_context_top_k *) cnstr->ctx; } }; @@ -474,9 +490,9 @@ struct llama_constraint_context_top_p { static struct llama_constraint_i llama_constraint_top_p_i = { /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx; - llama_constraint_top_p_impl(candidates, ctx->p, ctx->min_keep); + llama_constraint_top_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, /* .copy = */ [](const struct llama_constraint * cnstr) { @@ -484,10 +500,7 @@ static struct llama_constraint_i llama_constraint_top_p_i = { return llama_constraint_init_top_p_impl(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - delete (llama_constraint_context_top_p *) cnstr->ctx; - } - delete cnstr; + delete (llama_constraint_context_top_p *) cnstr->ctx; } }; @@ -512,9 +525,9 @@ struct llama_constraint_context_min_p { static struct llama_constraint_i llama_constraint_min_p_i = { /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx; - llama_constraint_min_p_impl(candidates, ctx->p, ctx->min_keep); + llama_constraint_min_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* 
.reset = */ nullptr, /* .copy = */ [](const struct llama_constraint * cnstr) { @@ -522,10 +535,7 @@ static struct llama_constraint_i llama_constraint_min_p_i = { return llama_constraint_init_min_p_impl(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - delete (llama_constraint_context_min_p *) cnstr->ctx; - } - delete cnstr; + delete (llama_constraint_context_min_p *) cnstr->ctx; } }; @@ -550,9 +560,9 @@ struct llama_constraint_context_tail_free { static struct llama_constraint_i llama_constraint_tail_free_i = { /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; - llama_constraint_tail_free_impl(candidates, ctx->z, ctx->min_keep); + llama_constraint_tail_free_impl(cur_p, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, /* .copy = */ [](const struct llama_constraint * cnstr) { @@ -560,10 +570,7 @@ static struct llama_constraint_i llama_constraint_tail_free_i = { return llama_constraint_init_tail_free_impl(ctx->z, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - delete (llama_constraint_context_tail_free *) cnstr->ctx; - } - delete cnstr; + delete (llama_constraint_context_tail_free *) cnstr->ctx; } }; @@ -588,9 +595,9 @@ struct llama_constraint_context_typical { static struct llama_constraint_i llama_constraint_typical_i = { /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_typical *) cnstr->ctx; - llama_constraint_typical_impl(candidates, ctx->p, ctx->min_keep); + llama_constraint_typical_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, /* .copy = */ [](const struct llama_constraint * cnstr) { @@ -598,10 +605,7 @@ static struct llama_constraint_i llama_constraint_typical_i = { return llama_constraint_init_typical_impl(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - delete (llama_constraint_context_typical *) cnstr->ctx; - } - delete cnstr; + delete (llama_constraint_context_typical *) cnstr->ctx; } }; @@ -625,9 +629,9 @@ struct llama_constraint_context_temp { static struct llama_constraint_i llama_constraint_temp_i = { /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_temp *) cnstr->ctx; - llama_constraint_temp_impl(candidates, ctx->temp); + llama_constraint_temp_impl(cur_p, ctx->temp); }, /* .reset = */ nullptr, /* .copy = */ [](const struct llama_constraint * cnstr) { @@ -635,10 +639,7 @@ static struct llama_constraint_i llama_constraint_temp_i = { return llama_constraint_init_temp_impl(ctx->temp); }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - delete (llama_constraint_context_temp *) cnstr->ctx; - } - delete cnstr; + delete (llama_constraint_context_temp *) cnstr->ctx; } }; @@ -663,15 +664,15 @@ struct llama_constraint_context_temp_ext { static struct llama_constraint_i llama_constraint_temp_ext_i = { /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * 
candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; if (ctx->delta > 0) { const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); const float temp_max = ctx->temp + ctx->delta; - llama_constraint_entropy_impl(candidates, temp_min, temp_max, ctx->exponent); + llama_constraint_entropy_impl(cur_p, temp_min, temp_max, ctx->exponent); } else { - llama_constraint_temp_impl(candidates, ctx->temp); + llama_constraint_temp_impl(cur_p, ctx->temp); } }, /* .reset = */ nullptr, @@ -680,10 +681,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = { return llama_constraint_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - delete (llama_constraint_context_temp_ext *) cnstr->ctx; - } - delete cnstr; + delete (llama_constraint_context_temp_ext *) cnstr->ctx; } }; @@ -716,10 +714,10 @@ static struct llama_constraint_i llama_constraint_grammar_i = { llama_grammar_accept_impl(*ctx->grammar, token); } }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { - llama_constraint_grammar_impl(candidates, *ctx->grammar); + llama_constraint_grammar_impl(cur_p, *ctx->grammar); } }, /* .reset = */ [](struct llama_constraint * cnstr) { @@ -749,15 +747,13 @@ static struct llama_constraint_i llama_constraint_grammar_i = { return result; }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - { - auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; - llama_grammar_free_impl(ctx->grammar); - } + auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; - delete (llama_constraint_context_grammar *) cnstr->ctx; + if (ctx->grammar) { + llama_grammar_free_impl(ctx->grammar); } - delete cnstr; + + delete ctx; } }; @@ -807,13 +803,13 @@ static struct llama_constraint_i llama_constraint_penalties_i = { auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; ctx->prev.push_back(token); }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; - GGML_ASSERT(candidates->size == ctx->vocab->n_vocab && candidates->sorted == false && "the 'penalties' constraint must be applied on the full vocabulary"); + GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' constraint must be applied on the full vocabulary"); if (ctx->ignore_eos) { - candidates->data[ctx->vocab->special_eos_id].logit = -INFINITY; + cur_p->data[ctx->vocab->special_eos_id].logit = -INFINITY; } if ((ctx->penalty_last_n == 0) || @@ -821,7 +817,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = { return; } - const float nl_logit = !ctx->penalize_nl ? candidates->data[ctx->vocab->linefeed_id].logit : -INFINITY; + const float nl_logit = !ctx->penalize_nl ? 
cur_p->data[ctx->vocab->linefeed_id].logit : -INFINITY; // Create a frequency map to count occurrences of each token in last_tokens // TODO: optimize this by maintaining the token count in the constraint context @@ -830,11 +826,11 @@ static struct llama_constraint_i llama_constraint_penalties_i = { token_count[ctx->prev.rat(i)]++; } - llama_constraint_penalties_impl(candidates, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); + llama_constraint_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); if (!ctx->penalize_nl) { // restore the logit of the newline token if it was penalized - candidates->data[ctx->vocab->linefeed_id].logit = nl_logit; + cur_p->data[ctx->vocab->linefeed_id].logit = nl_logit; } }, /* .reset = */ [](struct llama_constraint * cnstr) { @@ -858,10 +854,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = { return result; }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - delete (llama_constraint_context_penalties *) cnstr->ctx; - } - delete cnstr; + delete (llama_constraint_context_penalties *) cnstr->ctx; } }; @@ -896,13 +889,13 @@ struct llama_constraint_context_logit_bias { static struct llama_constraint_i llama_constraint_logit_bias_i = { /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * candidates) { + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx; - GGML_ASSERT(candidates->size == ctx->vocab->n_vocab && candidates->sorted == false && "the 'logit_bias' constraint must be applied on the full vocabulary"); + GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' constraint must be applied on the full vocabulary"); for (const auto & lb : ctx->logit_bias) { - candidates->data[lb.token].logit += lb.bias; + cur_p->data[lb.token].logit += lb.bias; } }, /* .reset = */ nullptr, @@ -911,10 +904,7 @@ static struct llama_constraint_i llama_constraint_logit_bias_i = { return llama_constraint_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); }, /* .free = */ [](struct llama_constraint * cnstr) { - if (cnstr->ctx) { - delete (llama_constraint_context_logit_bias *) cnstr->ctx; - } - delete cnstr; + delete (llama_constraint_context_logit_bias *) cnstr->ctx; } }; @@ -940,9 +930,15 @@ struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint } void llama_constraint_free_impl(struct llama_constraint * cnstr) { - if (cnstr->iface->free && cnstr) { + if (cnstr == nullptr) { + return; + } + + if (cnstr->iface->free) { cnstr->iface->free(cnstr); } + + delete cnstr; } void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token) { @@ -951,9 +947,9 @@ void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token t } } -void llama_constraint_apply_impl(struct llama_constraint & cnstr, struct llama_token_data_array * candidates) { +void llama_constraint_apply_impl(struct llama_constraint & cnstr, struct llama_token_data_array * cur_p) { GGML_ASSERT(cnstr.iface->apply); - cnstr.iface->apply(&cnstr, candidates); + cnstr.iface->apply(&cnstr, cur_p); } void llama_constraint_reset_impl(struct llama_constraint & cnstr) { @@ -962,7 +958,9 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) { } } +// // samplers +// struct llama_sampler * llama_sampler_init_impl(const struct 
llama_vocab & vocab, struct llama_sampler_params params) { auto * result = new llama_sampler { @@ -1050,9 +1048,9 @@ void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { } } -void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * candidates) { +void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { for (auto * cnstr : smpl.constraints) { - llama_constraint_apply_impl(*cnstr, candidates); + llama_constraint_apply_impl(*cnstr, cur_p); } } @@ -1068,16 +1066,16 @@ int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { return smpl.prev.size(); } -llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { - llama_constraint_softmax_impl(candidates); +llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { + llama_constraint_softmax_impl(cur_p); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; float sum_ti_bi = 0.0; float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) { + for (size_t i = 0; i < size_t(m - 1) && i < cur_p->size - 1; ++i) { float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p); + float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); sum_ti_bi += t_i * b_i; sum_ti_sq += t_i * t_i; } @@ -1088,14 +1086,14 @@ llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * c float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_constraint_top_k_impl(candidates, int(k), 1); - llama_token X = llama_sampler_sample_dist_impl(candidates, rng); + llama_constraint_top_k_impl(cur_p, int(k), 1); + llama_token X = llama_sampler_sample_dist_impl(cur_p, rng); // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + size_t X_idx = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { return candidate.id == X; })); - float observed_surprise = -log2f(candidates->data[X_idx].p); + float observed_surprise = -log2f(cur_p->data[X_idx].p); float e = observed_surprise - tau; // Update mu using the learning rate and error @@ -1104,30 +1102,30 @@ llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * c return X; } -llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu) { - llama_constraint_softmax_impl(candidates); +llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, float & mu) { + llama_constraint_softmax_impl(cur_p); // Truncate the words with surprise values greater than mu - candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { return 
-log2f(candidate.p) > mu; })); - if (candidates->size == 0) { - candidates->size = 1; + if (cur_p->size == 0) { + cur_p->size = 1; } // Normalize the probabilities of the remaining words - llama_constraint_softmax_impl(candidates); + llama_constraint_softmax_impl(cur_p); // Sample the next word X from the remaining words - llama_token X = llama_sampler_sample_dist_impl(candidates, rng); + llama_token X = llama_sampler_sample_dist_impl(cur_p, rng); // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { + size_t X_idx = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { return candidate.id == X; })); - float observed_surprise = -log2f(candidates->data[X_idx].p); + float observed_surprise = -log2f(cur_p->data[X_idx].p); float e = observed_surprise - tau; // Update mu using the learning rate and error @@ -1136,17 +1134,17 @@ llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array return X; } -llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates, bool probs) { +llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * cur_p, bool probs) { if (probs) { // if probs are needed, we apply softmax to get the probabilities - llama_constraint_softmax_impl(candidates); + llama_constraint_softmax_impl(cur_p); - // the candidates are sorted, so we can just return the first one - return candidates->data[0].id; + // the cur_p are sorted, so we can just return the first one + return cur_p->data[0].id; } // return the token with the highest logit - auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { + auto * max_iter = std::max_element(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; }); @@ -1155,20 +1153,20 @@ llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * candidates return result; } -llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * candidates, std::mt19937 & rng) { - llama_constraint_softmax_impl(candidates); +llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng) { + llama_constraint_softmax_impl(cur_p); std::vector probs; - probs.reserve(candidates->size); + probs.reserve(cur_p->size); - for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); + for (size_t i = 0; i < cur_p->size; ++i) { + probs.push_back(cur_p->data[i].p); } std::discrete_distribution<> dist(probs.begin(), probs.end()); const int idx = dist(rng); - llama_token result = candidates->data[idx].id; + llama_token result = cur_p->data[idx].id; return result; } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index f60d5b95f86da..e4f9108861592 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -10,19 +10,9 @@ struct llama_grammar; using llama_token_cnt = std::unordered_map; -// TODO: tmp exposed, until tests start using llama_constraint -void llama_constraint_softmax_impl (struct llama_token_data_array * candidates); -void llama_constraint_top_k_impl (struct llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_constraint_top_p_impl (struct llama_token_data_array * candidates, 
float p, size_t min_keep); -void llama_constraint_min_p_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); -void llama_constraint_tail_free_impl(struct llama_token_data_array * candidates, float z, size_t min_keep); -void llama_constraint_typical_impl (struct llama_token_data_array * candidates, float p, size_t min_keep); -void llama_constraint_entropy_impl (struct llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_constraint_temp_impl (struct llama_token_data_array * candidates, float temp); -void llama_constraint_grammar_impl (struct llama_token_data_array * candidates, const struct llama_grammar & grammar); - +// TODO: tmp exposed until test-sampling is fixed void llama_constraint_penalties_impl( - llama_token_data_array * candidates, + llama_token_data_array * cur_p, const llama_token_cnt & token_count, float penalty_repeat, float penalty_freq, @@ -30,6 +20,7 @@ void llama_constraint_penalties_impl( // constraints +struct llama_constraint * llama_constraint_init_softmax_impl (); struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); @@ -62,7 +53,7 @@ struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint void llama_constraint_free_impl(struct llama_constraint * cnstr); void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token); -void llama_constraint_apply_impl (struct llama_constraint & cnstr, struct llama_token_data_array * candidates); +void llama_constraint_apply_impl (struct llama_constraint & cnstr, struct llama_token_data_array * cur_p); void llama_constraint_reset_impl (struct llama_constraint & cnstr); // samplers @@ -101,7 +92,7 @@ void llama_sampler_reset_impl( struct llama_sampler & smp void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr); void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token); -void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * candidates); +void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p); llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); @@ -112,14 +103,14 @@ int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
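// A standalone sketch of the `mu` update these parameters describe, using made-up
// probabilities (none of this is part of the patch): `mu` starts at 2 * tau and is
// nudged by the learning rate `eta` whenever the observed surprise (-log2 of the
// sampled token's probability) deviates from the target `tau`.
#include <cstdio>
#include <cmath>

int main() {
    const float tau = 5.0f;       // target surprise
    const float eta = 0.1f;       // learning rate
    float       mu  = 2.0f * tau; // initialized to twice the target cross-entropy

    const float sampled_p[3] = { 0.02f, 0.10f, 0.001f }; // hypothetical probabilities of sampled tokens

    for (int i = 0; i < 3; ++i) {
        const float observed_surprise = -log2f(sampled_p[i]);
        const float e = observed_surprise - tau; // error vs. the target surprise
        mu -= eta * e;
        printf("step %d: surprise = %.3f, mu = %.3f\n", i, observed_surprise, mu);
    }
    return 0;
}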
-llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); +llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu); +llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, float & mu); -llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * candidates, bool probs); -llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * candidates, std::mt19937 & rng); +llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * cur_p, bool probs); +llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng); diff --git a/src/llama.cpp b/src/llama.cpp index 2b54a1ff337ac..28f406ce2d295 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20609,6 +20609,10 @@ int32_t llama_chat_apply_template( // sampling // +struct llama_constraint * llama_constraint_init_softmax() { + return llama_constraint_init_softmax_impl(); +} + struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) { return llama_constraint_init_top_k_impl(k, min_keep); } @@ -20675,8 +20679,8 @@ void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) llama_constraint_accept_impl(*cnstr, token); } -void llama_constraint_apply(struct llama_constraint * cnstr, llama_token_data_array * candidates) { - llama_constraint_apply_impl(*cnstr, candidates); +void llama_constraint_apply(struct llama_constraint * cnstr, llama_token_data_array * cur_p) { + llama_constraint_apply_impl(*cnstr, cur_p); } void llama_constraint_reset(struct llama_constraint * cnstr) { @@ -20727,21 +20731,21 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { llama_sampler_accept_impl(*smpl, token); } -void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * candidates) { +void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; + if (cur_p == nullptr) { + cur_p = &smpl->cur_p; } - llama_sampler_apply_impl(*smpl, 
candidates); + llama_sampler_apply_impl(*smpl, cur_p); } -llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * candidates) { +llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p) { time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; + if (cur_p == nullptr) { + cur_p = &smpl->cur_p; } const auto type = smpl->params.mirostat; @@ -20749,7 +20753,7 @@ llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_tok llama_token res; if (type == 1) { - res = llama_sampler_sample_mirostat_impl(candidates, + res = llama_sampler_sample_mirostat_impl(cur_p, smpl->rng, smpl->params.mirostat_tau, smpl->params.mirostat_eta, @@ -20757,7 +20761,7 @@ llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_tok smpl->vocab->n_vocab, smpl->mirostat_mu); } else if (type == 2) { - res = llama_sampler_sample_mirostat_v2_impl(candidates, + res = llama_sampler_sample_mirostat_v2_impl(cur_p, smpl->rng, smpl->params.mirostat_tau, smpl->params.mirostat_eta, @@ -20771,28 +20775,28 @@ llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_tok return res; } -llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * candidates, bool probs) { +llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs) { time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; + if (cur_p == nullptr) { + cur_p = &smpl->cur_p; } - auto res = llama_sampler_sample_greedy_impl(candidates, probs); + auto res = llama_sampler_sample_greedy_impl(cur_p, probs); smpl->n_sample++; return res; } -llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_data_array * candidates) { +llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_data_array * cur_p) { time_meas tm(smpl->t_sample_us); - if (candidates == nullptr) { - candidates = &smpl->cur_p; + if (cur_p == nullptr) { + cur_p = &smpl->cur_p; } - auto res = llama_sampler_sample_dist_impl(candidates, smpl->rng); + auto res = llama_sampler_sample_dist_impl(cur_p, smpl->rng); smpl->n_sample++; diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 16eeaa1c8e01b..0c9b46429caa7 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -11,119 +11,125 @@ #include #include -static void dump(const llama_token_data_array * candidates) { - for (size_t i = 0; i < candidates->size; i++) { - printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit); +static void dump(const llama_token_data_array * cur_p) { + for (size_t i = 0; i < cur_p->size; i++) { + printf("%d: %f (%f)\n", cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit); } } -#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0) +#define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0) + +#define TEST(__cnstr, __cur_p) do { \ + auto * cnstr = (__cnstr); \ + llama_constraint_apply(cnstr, (__cur_p)); \ + llama_constraint_free(cnstr); \ +} while(0) static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { const size_t n_vocab = probs.size(); - std::vector candidates; - candidates.reserve(n_vocab); + std::vector cur; + cur.reserve(n_vocab); for 
(llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { const float logit = logf(probs[token_id]); - candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); + cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - llama_constraint_softmax_impl(&candidates_p); - DUMP(&candidates_p); - llama_constraint_top_k_impl(&candidates_p, k, 1); - DUMP(&candidates_p); + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + TEST(llama_constraint_init_softmax(), &cur_p); + DUMP(&cur_p); + TEST(llama_constraint_init_top_k(k, 1), &cur_p); + DUMP(&cur_p); - GGML_ASSERT(candidates_p.size == expected_probs.size()); - for (size_t i = 0; i < candidates_p.size; i++) { - GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5); + GGML_ASSERT(cur_p.size == expected_probs.size()); + for (size_t i = 0; i < cur_p.size; i++) { + GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-5); } } static void test_top_p(const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); - std::vector candidates; - candidates.reserve(n_vocab); + std::vector cur; + cur.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { const float logit = logf(probs[token_id]); - candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); + cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - llama_constraint_softmax_impl(&candidates_p); - DUMP(&candidates_p); - llama_constraint_top_p_impl(&candidates_p, p, 1); - DUMP(&candidates_p); + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + TEST(llama_constraint_init_softmax(), &cur_p); + DUMP(&cur_p); + TEST(llama_constraint_init_top_p(p, 1), &cur_p); + DUMP(&cur_p); - GGML_ASSERT(candidates_p.size == expected_probs.size()); - for (size_t i = 0; i < candidates_p.size; i++) { - GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); + GGML_ASSERT(cur_p.size == expected_probs.size()); + for (size_t i = 0; i < cur_p.size; i++) { + GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); } } static void test_tfs(const std::vector & probs, const std::vector & expected_probs, float z) { const size_t n_vocab = probs.size(); - std::vector candidates; - candidates.reserve(n_vocab); + std::vector cur; + cur.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { const float logit = logf(probs[token_id]); - candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); + cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - DUMP(&candidates_p); - llama_constraint_tail_free_impl(&candidates_p, z, 1); - DUMP(&candidates_p); + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + DUMP(&cur_p); + TEST(llama_constraint_init_tail_free(z, 1), &cur_p); + DUMP(&cur_p); - GGML_ASSERT(candidates_p.size == expected_probs.size()); - for (size_t i = 0; i < candidates_p.size; i++) { - GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); + GGML_ASSERT(cur_p.size == expected_probs.size()); + for (size_t i = 0; i < cur_p.size; i++) { + GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); } } static void test_min_p(const std::vector & probs, const std::vector & 
expected_probs, float p) { const size_t n_vocab = probs.size(); - std::vector candidates; - candidates.reserve(n_vocab); + std::vector cur; + cur.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { const float logit = logf(probs[token_id]); - candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); + cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - DUMP(&candidates_p); - llama_constraint_min_p_impl(&candidates_p, p, 1); - DUMP(&candidates_p); - llama_constraint_softmax_impl(&candidates_p); + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + DUMP(&cur_p); + TEST(llama_constraint_init_min_p(p, 1), &cur_p); + DUMP(&cur_p); + TEST(llama_constraint_init_softmax(), &cur_p); - GGML_ASSERT(candidates_p.size == expected_probs.size()); - for (size_t i = 0; i < candidates_p.size; i++) { - GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); + GGML_ASSERT(cur_p.size == expected_probs.size()); + for (size_t i = 0; i < cur_p.size; i++) { + GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); } } static void test_typical(const std::vector & probs, const std::vector & expected_probs, float p) { const size_t n_vocab = probs.size(); - std::vector candidates; - candidates.reserve(n_vocab); + std::vector cur; + cur.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { const float logit = logf(probs[token_id]); - candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); + cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - DUMP(&candidates_p); - llama_constraint_typical_impl(&candidates_p, p, 1); - DUMP(&candidates_p); + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + DUMP(&cur_p); + TEST(llama_constraint_init_typical(p, 1), &cur_p); + DUMP(&cur_p); - GGML_ASSERT(candidates_p.size == expected_probs.size()); - for (size_t i = 0; i < candidates_p.size; i++) { - GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); + GGML_ASSERT(cur_p.size == expected_probs.size()); + for (size_t i = 0; i < cur_p.size; i++) { + GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); } } @@ -135,11 +141,11 @@ static void test_penalties( const size_t n_vocab = probs.size(); - std::vector candidates; - candidates.reserve(n_vocab); + std::vector cur; + cur.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { const float logit = logf(probs[token_id]); - candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); + cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } llama_token_cnt token_count; @@ -147,55 +153,55 @@ static void test_penalties( token_count[last_tokens[i]]++; } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - llama_constraint_softmax_impl(&candidates_p); - DUMP(&candidates_p); - llama_constraint_penalties_impl(&candidates_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); - llama_constraint_softmax_impl(&candidates_p); - DUMP(&candidates_p); + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + TEST(llama_constraint_init_softmax(), &cur_p); + DUMP(&cur_p); + llama_constraint_penalties_impl(&cur_p, token_count, repeat_penalty, alpha_frequency, alpha_presence); // TODO: avoid + 
TEST(llama_constraint_init_softmax(), &cur_p); + DUMP(&cur_p); - GGML_ASSERT(candidates_p.size == expected_probs.size()); - for (size_t i = 0; i < candidates_p.size; i++) { - GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); + GGML_ASSERT(cur_p.size == expected_probs.size()); + for (size_t i = 0; i < cur_p.size; i++) { + GGML_ASSERT(fabs(cur_p.data[i].p - expected_probs[i]) < 1e-3); } } static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p ) { - std::vector candidates; - candidates.reserve(n_vocab); + std::vector cur; + cur.reserve(n_vocab); for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { const float logit = logf(token_id); - candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); + cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; + llama_token_data_array cur_p = { cur.data(), cur.size(), false }; llama_token min_token_id = 0; const llama_token max_token_id = n_vocab-1; for (auto s : samplers_sequence) { switch (s){ - case 'k': llama_constraint_top_k_impl(&candidates_p, top_k, 1); break; + case 'k': TEST(llama_constraint_init_top_k(top_k, 1), &cur_p); break; case 'f': GGML_ABORT("tail_free test not implemented"); case 'y': GGML_ABORT("typical test not implemented"); - case 'p': llama_constraint_top_p_impl(&candidates_p, top_p, 1); break; - case 'm': llama_constraint_min_p_impl(&candidates_p, min_p, 1); break; + case 'p': TEST(llama_constraint_init_top_p(top_p, 1), &cur_p); break; + case 'm': TEST(llama_constraint_init_min_p(min_p, 1), &cur_p); break; case 't': GGML_ABORT("temperature test not implemented"); default : GGML_ABORT("Unknown sampler"); } - llama_constraint_softmax_impl(&candidates_p); // make sure tokens are sorted for tests + TEST(llama_constraint_init_softmax(), &cur_p); // make sure tokens are sorted for tests - const int size = candidates_p.size; + const int size = cur_p.size; if (s == 'k') { const int expected_size = std::min(size, top_k); min_token_id = std::max(min_token_id, (llama_token)(n_vocab - top_k)); GGML_ASSERT(size == expected_size); - GGML_ASSERT(candidates_p.data[0].id == max_token_id); - GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id); + GGML_ASSERT(cur_p.data[0].id == max_token_id); + GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id); } else if (s == 'p') { const int softmax_divisor = n_vocab * (n_vocab-1) / 2 - min_token_id * (min_token_id-1) / 2; const int softmax_numerator_target = ceilf(top_p * softmax_divisor); @@ -217,8 +223,8 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler } GGML_ASSERT(size == expected_size); - GGML_ASSERT(candidates_p.data[0].id == max_token_id); - GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id); + GGML_ASSERT(cur_p.data[0].id == max_token_id); + GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id); } else if (s == 'm') { int expected_size = ceilf((1.0f-min_p) * n_vocab); expected_size = std::max(expected_size, 1); @@ -230,8 +236,8 @@ static void test_sampler_queue(const size_t n_vocab, const std::string & sampler min_token_id = std::min(min_token_id, (llama_token)(n_vocab - 1)); GGML_ASSERT(size == expected_size); - GGML_ASSERT(candidates_p.data[0].id == max_token_id); - GGML_ASSERT(candidates_p.data[expected_size-1].id == min_token_id); + GGML_ASSERT(cur_p.data[0].id == max_token_id); + 
GGML_ASSERT(cur_p.data[expected_size-1].id == min_token_id); } else { GGML_ABORT("fatal error"); } From 1a0de0b781f727162f85a45c47a03275ae0f7f31 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 15:37:51 +0300 Subject: [PATCH 18/47] constraint : add name API ggml-ci --- common/sampling.cpp | 8 ++++--- common/sampling.h | 2 +- include/llama.h | 11 ++++----- src/llama-sampling.cpp | 51 +++++++++++++++++++++++++----------------- src/llama.cpp | 24 ++++++++++---------- 5 files changed, 53 insertions(+), 43 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 34371bc241167..f5edd87a661d9 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -198,9 +198,11 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context auto & grmr = gsmpl->grmr; auto & smpl = gsmpl->smpl; - auto * cur_p = llama_sampler_get_candidates(smpl); + const auto * logits = llama_get_logits_ith(ctx, idx); + + llama_sampler_set_logits(smpl, logits); - llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + auto * cur_p = llama_sampler_get_candidates(smpl); llama_constraint_apply(bias, cur_p); llama_constraint_apply(pnlt, cur_p); @@ -223,7 +225,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context } // if the token is not valid, sample again, first apply the grammar constraints and then sample - llama_sampler_set_logits(smpl, llama_get_logits_ith(ctx, idx)); + llama_sampler_set_logits(smpl, logits); llama_constraint_apply(bias, cur_p); llama_constraint_apply(pnlt, cur_p); diff --git a/common/sampling.h b/common/sampling.h index a04645a676a44..bab26493701a1 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -63,7 +63,7 @@ struct gpt_sampler_params { // gpt_sampler extends llama_sampler with additional functionality: // // - grammar support -// - custom sampler logic based on the paramerters +// - custom sampler logic based on the parameters // struct gpt_sampler; diff --git a/include/llama.h b/include/llama.h index 0f08c44c0825c..bf67f24c75854 100644 --- a/include/llama.h +++ b/include/llama.h @@ -386,11 +386,11 @@ extern "C" { double t_start_ms; double t_end_ms; double t_load_ms; - double t_sampling_ms; + double t_sampler_ms; double t_p_eval_ms; double t_eval_ms; - int32_t n_sampling; + int32_t n_sampler; int32_t n_p_eval; int32_t n_eval; }; @@ -1025,8 +1025,7 @@ extern "C" { // user code can implement the interface below in order to create custom llama_constraint struct llama_constraint_i { - // TODO: add name API - + const char * (*name) (const struct llama_constraint * cnstr); // can be NULL void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL void (*apply) ( struct llama_constraint * cnstr, llama_token_data_array * cur_p); // required void (*reset) ( struct llama_constraint * cnstr); // can be NULL @@ -1035,8 +1034,6 @@ extern "C" { // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph //void (*apply_ggml) (struct llama_constraint * cnstr, ...); - - // TODO: add API to get timing stats }; struct llama_constraint { @@ -1044,7 +1041,7 @@ extern "C" { llama_constraint_context_t ctx; }; - LLAMA_API struct llama_constraint * llama_constraint_init_softmax (); + LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); LLAMA_API struct 
llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 99e0edfd9d19c..733957fdd5ee6 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -428,6 +428,7 @@ void llama_constraint_penalties_impl( // softmax static struct llama_constraint_i llama_constraint_softmax_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "softmax"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) { llama_constraint_softmax_impl(cur_p); @@ -454,9 +455,10 @@ struct llama_constraint_context_top_k { }; static struct llama_constraint_i llama_constraint_top_k_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "top-k"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep); }, /* .reset = */ nullptr, @@ -466,7 +468,7 @@ static struct llama_constraint_i llama_constraint_top_k_i = { }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_top_k *) cnstr->ctx; - } + }, }; struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) { @@ -489,9 +491,10 @@ struct llama_constraint_context_top_p { }; static struct llama_constraint_i llama_constraint_top_p_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "top-p"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx; llama_constraint_top_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, @@ -501,7 +504,7 @@ static struct llama_constraint_i llama_constraint_top_p_i = { }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_top_p *) cnstr->ctx; - } + }, }; struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) { @@ -524,9 +527,10 @@ struct llama_constraint_context_min_p { }; static struct llama_constraint_i llama_constraint_min_p_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "min-p"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx; llama_constraint_min_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, @@ -536,7 +540,7 @@ static struct llama_constraint_i llama_constraint_min_p_i = { }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_min_p *) cnstr->ctx; - } + }, }; struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) { @@ -559,9 +563,10 @@ struct llama_constraint_context_tail_free { }; static struct llama_constraint_i llama_constraint_tail_free_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "tail-free"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; 
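// A hypothetical user-side constraint exercising the llama_constraint_i interface
// from include/llama.h above, including the new `name` callback; the names
// (my_even_only_*) and the filtering rule are purely illustrative and not part of
// llama.cpp. It follows the same vtable-plus-context pattern as the built-in
// constraints in this file.
#include <math.h>

#include "llama.h"

static struct llama_constraint_i my_even_only_i = {
    /* .name   = */ [](const struct llama_constraint * /*cnstr*/) { return "even-only"; },
    /* .accept = */ nullptr,
    /* .apply  = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) {
        // mask every token with an odd id
        for (size_t i = 0; i < cur_p->size; ++i) {
            if (cur_p->data[i].id % 2 != 0) {
                cur_p->data[i].logit = -INFINITY;
            }
        }
    },
    /* .reset  = */ nullptr,
    /* .copy   = */ nullptr,
    /* .free   = */ nullptr, // no context to release
};

static struct llama_constraint * my_even_only_init(void) {
    return new llama_constraint {
        /* .iface = */ &my_even_only_i,
        /* .ctx   = */ nullptr,
    };
}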
llama_constraint_tail_free_impl(cur_p, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, @@ -571,7 +576,7 @@ static struct llama_constraint_i llama_constraint_tail_free_i = { }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_tail_free *) cnstr->ctx; - } + }, }; struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) { @@ -594,9 +599,10 @@ struct llama_constraint_context_typical { }; static struct llama_constraint_i llama_constraint_typical_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "typical"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_typical *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_typical *) cnstr->ctx; llama_constraint_typical_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, @@ -606,7 +612,7 @@ static struct llama_constraint_i llama_constraint_typical_i = { }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_typical *) cnstr->ctx; - } + }, }; struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) { @@ -628,9 +634,10 @@ struct llama_constraint_context_temp { }; static struct llama_constraint_i llama_constraint_temp_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "temp"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_temp *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_temp *) cnstr->ctx; llama_constraint_temp_impl(cur_p, ctx->temp); }, /* .reset = */ nullptr, @@ -640,7 +647,7 @@ static struct llama_constraint_i llama_constraint_temp_i = { }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_temp *) cnstr->ctx; - } + }, }; struct llama_constraint * llama_constraint_init_temp_impl(float temp) { @@ -663,9 +670,10 @@ struct llama_constraint_context_temp_ext { }; static struct llama_constraint_i llama_constraint_temp_ext_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "temp-ext"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; if (ctx->delta > 0) { const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); const float temp_max = ctx->temp + ctx->delta; @@ -682,7 +690,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = { }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_temp_ext *) cnstr->ctx; - } + }, }; struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) { @@ -708,14 +716,15 @@ struct llama_constraint_context_grammar { }; static struct llama_constraint_i llama_constraint_grammar_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "grammar"; }, /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { - auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { llama_grammar_accept_impl(*ctx->grammar, token); } }, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = 
(llama_constraint_context_grammar *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { llama_constraint_grammar_impl(cur_p, *ctx->grammar); } @@ -747,14 +756,14 @@ static struct llama_constraint_i llama_constraint_grammar_i = { return result; }, /* .free = */ [](struct llama_constraint * cnstr) { - auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; if (ctx->grammar) { llama_grammar_free_impl(ctx->grammar); } delete ctx; - } + }, }; struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { @@ -799,6 +808,7 @@ struct llama_constraint_context_penalties { }; static struct llama_constraint_i llama_constraint_penalties_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "penalties"; }, /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; ctx->prev.push_back(token); @@ -855,7 +865,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = { }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_penalties *) cnstr->ctx; - } + }, }; struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { @@ -888,6 +898,7 @@ struct llama_constraint_context_logit_bias { }; static struct llama_constraint_i llama_constraint_logit_bias_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "logit-bias"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx; @@ -905,7 +916,7 @@ static struct llama_constraint_i llama_constraint_logit_bias_i = { }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_logit_bias *) cnstr->ctx; - } + }, }; struct llama_constraint * llama_constraint_init_logit_bias_impl( diff --git a/src/llama.cpp b/src/llama.cpp index 28f406ce2d295..712d9cfb58636 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20609,7 +20609,7 @@ int32_t llama_chat_apply_template( // sampling // -struct llama_constraint * llama_constraint_init_softmax() { +struct llama_constraint * llama_constraint_init_softmax(void) { return llama_constraint_init_softmax_impl(); } @@ -20849,22 +20849,22 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl) { const llama_timings timings = { - /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, - /*.t_end_ms =*/ 1.00 * ggml_time_ms(), - /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0), - /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, - /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - - /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0), - /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), - /*.n_eval =*/ std::max(1, ctx->n_eval), + /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, + /*.t_end_ms =*/ 1.00 * ggml_time_ms(), + /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, + /*.t_sampler_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0), + /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, + /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, + + /*.n_sampler =*/ std::max(0, smpl ? 
smpl->n_sample : 0), + /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), + /*.n_eval =*/ std::max(1, ctx->n_eval), }; LLAMA_LOG_INFO("\n"); LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling); + __func__, timings.t_sampler_ms, timings.n_sampler, timings.t_sampler_ms / timings.n_sampler, 1e3 / timings.t_sampler_ms * timings.n_sampler); LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", From 0e1378c8448e7616717e71e084bf0999da7f9b6c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 16:57:43 +0300 Subject: [PATCH 19/47] sampling : convert mirostat samplers to constraints ggml-ci --- common/sampling.cpp | 88 ++++++------- include/llama.h | 47 ++++--- src/llama-sampling.cpp | 280 ++++++++++++++++++++++++++++------------- src/llama-sampling.h | 59 +++++---- src/llama.cpp | 46 ++----- 5 files changed, 305 insertions(+), 215 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index f5edd87a661d9..b528d49291c66 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -47,9 +47,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st lparams.seed = params.seed; lparams.n_prev = params.n_prev; - lparams.mirostat = params.mirostat; - lparams.mirostat_tau = params.mirostat_tau; - lparams.mirostat_eta = params.mirostat_eta; auto * result = new gpt_sampler { /* .params = */ params, @@ -69,29 +66,39 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st /* .smpl = */ llama_sampler_init(model, lparams) }; - for (const auto & cnstr : params.constraints) { - switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TOP_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_MIN_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TFS_Z: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - default: - GGML_ASSERT(false && "unknown constraint type"); + if (params.mirostat == 0) { + for (const auto & cnstr : params.constraints) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TOP_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + 
break; + case GPT_CONSTRAINT_TYPE_MIN_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TFS_Z: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: + llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + default: + GGML_ASSERT(false && "unknown constraint type"); + } } + } else if (params.mirostat == 1) { + llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); + } else if (params.mirostat == 2) { + llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); + } else { + GGML_ASSERT(false && "unknown mirostat version"); } return result; @@ -153,7 +160,6 @@ static llama_token gpt_sampler_sample( struct llama_sampler * smpl, struct llama_token_data_array * cur_p, float temp, - int mirostat, int n_probs) { llama_token res = 0; @@ -167,24 +173,20 @@ static llama_token gpt_sampler_sample( // apply all sampling constraints and then sample llama_sampler_apply(smpl, cur_p); - if (mirostat != 0) { - res = llama_sampler_sample_mirostat(smpl, cur_p); - } else { - res = llama_sampler_sample_dist(smpl, cur_p); + res = llama_sampler_sample_dist(smpl, cur_p); - //{ - // const int n_top = 10; - // LOG("top %d candidates:\n", n_top); + //{ + // const int n_top = 10; + // LOG("top %d candidates:\n", n_top); - // for (int i = 0; i < n_top; i++) { - // const llama_token id = cur_p.data[i].id; - // (void)id; // To avoid a warning that id is unused when logging is disabled. - // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); - // } - //} + // for (int i = 0; i < n_top; i++) { + // const llama_token id = cur_p.data[i].id; + // (void)id; // To avoid a warning that id is unused when logging is disabled. 
+ // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); + // } + //} - //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); - } + //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); } return res; @@ -208,7 +210,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_constraint_apply(pnlt, cur_p); // first, sample the token without any grammar constraints - const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.mirostat, params.n_probs); + const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs); // check if it the sampled token fits the grammar { @@ -231,7 +233,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_constraint_apply(pnlt, cur_p); llama_constraint_apply(grmr, cur_p); - return gpt_sampler_sample(smpl, cur_p, params.temp, params.mirostat, params.n_probs); + return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs); } void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { diff --git a/include/llama.h b/include/llama.h index bf67f24c75854..8c2a2aff94abf 100644 --- a/include/llama.h +++ b/include/llama.h @@ -369,16 +369,18 @@ extern "C" { float bias; } llama_logit_bias; + enum llama_sampler_type { + LLAMA_SAMPLER_TYPE_GREEDY = 0, + LLAMA_SAMPLER_TYPE_DIST = 1, + }; + typedef struct llama_sampler_params { uint32_t seed; // the seed used to initialize the rng of the sampler int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API) - int32_t mirostat; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - float mirostat_tau; // target entropy - float mirostat_eta; // learning rate - - // TODO: add type of sampler: greedy, dist, mirostat, etc. + // TODO: will be used by the llama_decode_with_sampler() API in the future + enum llama_sampler_type type; } llama_sampler_params; // performance timing information @@ -1005,17 +1007,18 @@ extern "C" { // // - Samplers // The llama_sampler samples a token based on the candidate token probabilities. Before the actual sampling, the - // sampler can apply a sequence of constraints to the candidate tokens. + // sampler can apply a sequence of constraints in order to modify the probabilities of the candidates. // // The llama_sampler object contains the entire sampling information: // // - RNG state (seed and generator) // - Custom set of constraints (see llama_sampler_add_constraint) - // - Sampling method (greedy, dist, mirostat) + // - Sampling method (greedy, dist) // - Previous tokens // // In the future, it will be utilized offload the sampling to the backends (e.g. GPU). 
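As a usage sketch of the sampler/constraint split described in the llama.h comment above: at this point in the series the attach call is still named llama_sampler_add_constraint (it is renamed to llama_sampler_constraint_add in a later patch). The fragment below assumes a llama_model * model and a llama_token_data_array * cur_p already filled from the model's logits; the seed and constraint parameters are placeholder values.

    // Sketch only: build a sampler, attach built-in constraints, sample once.
    llama_sampler_params lparams = llama_sampler_default_params();
    lparams.seed = 1234;

    llama_sampler * smpl = llama_sampler_init(model, lparams);

    llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
    llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9f, 1));
    llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.8f));

    // alternatively, the mirostat v2 chain built by gpt_sampler_init:
    //llama_sampler_add_constraint(smpl, llama_constraint_init_temp(0.8f));
    //llama_sampler_add_constraint(smpl, llama_constraint_init_mirostat_v2(5.0f, 0.1f));

    llama_sampler_apply(smpl, cur_p);                      // run the constraint chain
    const llama_token id = llama_sampler_sample_dist(smpl, cur_p);
    llama_sampler_accept(smpl, id);                        // update prev tokens and constraint state

    llama_sampler_free(smpl);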
// + // TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model // constraints @@ -1041,14 +1044,23 @@ extern "C" { llama_constraint_context_t ctx; }; - LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); - LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); - LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); + LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); + LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + + LLAMA_API struct llama_constraint * llama_constraint_init_mirostat( + const struct llama_model * model, + float tau, + float eta); + + LLAMA_API struct llama_constraint * llama_constraint_init_mirostat_v2( + float tau, + float eta); LLAMA_API struct llama_constraint * llama_constraint_init_grammar( const struct llama_model * model, @@ -1095,9 +1107,8 @@ extern "C" { LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p); - LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p); - LLAMA_API llama_token llama_sampler_sample_greedy (struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs); - LLAMA_API llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p); + LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p); + LLAMA_API llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs); /// @details Get the number of accepted tokens so far (max of n_prev) LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 733957fdd5ee6..385f7bec15fd5 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -450,8 +450,8 @@ struct llama_constraint * llama_constraint_init_softmax_impl() { // top-k struct llama_constraint_context_top_k { - int32_t k; - size_t min_keep; + const int32_t k; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_top_k_i = { @@ -486,8 +486,8 @@ struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min // top-p struct 
llama_constraint_context_top_p { - float p; - size_t min_keep; + const float p; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_top_p_i = { @@ -522,8 +522,8 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k // min-p struct llama_constraint_context_min_p { - float p; - size_t min_keep; + const float p; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_min_p_i = { @@ -558,8 +558,8 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k // tail-free struct llama_constraint_context_tail_free { - float z; - size_t min_keep; + const float z; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_tail_free_i = { @@ -594,8 +594,8 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m // typical struct llama_constraint_context_typical { - float p; - size_t min_keep; + const float p; + const size_t min_keep; }; static struct llama_constraint_i llama_constraint_typical_i = { @@ -630,7 +630,7 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min // temp struct llama_constraint_context_temp { - float temp; + const float temp; }; static struct llama_constraint_i llama_constraint_temp_i = { @@ -664,9 +664,9 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) { // temp-ext struct llama_constraint_context_temp_ext { - float temp; - float delta; - float exponent; + const float temp; + const float delta; + const float exponent; }; static struct llama_constraint_i llama_constraint_temp_ext_i = { @@ -706,6 +706,176 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float return result; } +// mirostat + +struct llama_constraint_context_mirostat { + const struct llama_vocab * vocab; + + const float tau; + const float eta; + + const int32_t m; + + float mu; + + std::vector cur; +}; + +static struct llama_constraint_i llama_constraint_mirostat_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "mirostat"; }, + /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { + auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + + int32_t idx = -1; + for (size_t i = 0; i < ctx->cur.size(); ++i) { + if (ctx->cur[i].id == token) { + idx = i; + break; + } + } + + float observed_surprise = -log2f(ctx->cur[idx].p); + float e = observed_surprise - ctx->tau; + + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; + }, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { + auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + + llama_constraint_softmax_impl(cur_p); + + // Estimate s_hat using the most probable m tokens + float s_hat = 0.0; + float sum_ti_bi = 0.0; + float sum_ti_sq = 0.0; + for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) { + float t_i = logf(float(i + 2) / float(i + 1)); + float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); + sum_ti_bi += t_i * b_i; + sum_ti_sq += t_i * t_i; + } + s_hat = sum_ti_bi / sum_ti_sq; + + // Compute k from the estimated s_hat and target surprise value + float epsilon_hat = s_hat - 1; + float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat); + + llama_constraint_top_k_impl(cur_p, int(k), 1); + + // remember the order to be able to compute the distance later when accepting the token + ctx->cur.resize(cur_p->size); + for (size_t i = 0; i 
< cur_p->size; ++i) { + ctx->cur[i] = cur_p->data[i]; + } + }, + /* .reset = */ [](struct llama_constraint * cnstr) { + auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + ctx->mu = 0.0f; + }, + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx; + return llama_constraint_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m); + }, + /* .free = */ [](struct llama_constraint * cnstr) { + delete (llama_constraint_context_mirostat *) cnstr->ctx; + }, +}; + +struct llama_constraint * llama_constraint_init_mirostat_impl( + const struct llama_vocab & vocab, + float tau, + float eta, + int32_t m) { + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_mirostat_i, + /* .ctx = */ new llama_constraint_context_mirostat { + /*.vocab =*/ &vocab, + /*.tau =*/ tau, + /*.eta =*/ eta, + /*.m =*/ m, + /*.mu =*/ 0.0f, + /*.cur =*/ {}, + }, + }; + + return result; +} + +// mirostat v2 + +struct llama_constraint_context_mirostat_v2 { + const float tau; + const float eta; + + float mu; + + std::vector cur; +}; + +static struct llama_constraint_i llama_constraint_mirostat_v2_i = { + /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "mirostat-v2"; }, + /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { + auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + + int32_t idx = -1; + for (size_t i = 0; i < ctx->cur.size(); ++i) { + if (ctx->cur[i].id == token) { + idx = i; + break; + } + } + + float observed_surprise = -log2f(ctx->cur[idx].p); + float e = observed_surprise - ctx->tau; + + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; + }, + /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { + auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + + llama_constraint_softmax_impl(cur_p); + + // Truncate the words with surprise values greater than mu + cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { + return -log2f(candidate.p) > ctx->mu; + })); + + if (cur_p->size == 0) { + cur_p->size = 1; + } + + // Normalize the probabilities of the remaining words + llama_constraint_softmax_impl(cur_p); + }, + /* .reset = */ [](struct llama_constraint * cnstr) { + auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + ctx->mu = 0.0f; + }, + /* .copy = */ [](const struct llama_constraint * cnstr) { + const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx; + return llama_constraint_init_mirostat_v2_impl(ctx->tau, ctx->eta); + }, + /* .free = */ [](struct llama_constraint * cnstr) { + delete (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + }, +}; + +struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, float eta) { + struct llama_constraint * result = new llama_constraint { + /* .iface = */ &llama_constraint_mirostat_v2_i, + /* .ctx = */ new llama_constraint_context_mirostat_v2 { + /*.tau =*/ tau, + /*.eta =*/ eta, + /*.mu =*/ 0.0f, + /*.cur =*/ {}, + }, + }; + + return result; +} + // grammar struct llama_constraint_context_grammar { @@ -796,13 +966,13 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ struct llama_constraint_context_penalties { const struct llama_vocab * vocab; - int32_t penalty_last_n; - float penalty_repeat; - float penalty_freq; - float penalty_present; + 
const int32_t penalty_last_n; + const float penalty_repeat; + const float penalty_freq; + const float penalty_present; - bool penalize_nl; - bool ignore_eos; + const bool penalize_nl; + const bool ignore_eos; ring_buffer prev; }; @@ -980,7 +1150,6 @@ struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, /* .rng = */ std::mt19937(params.seed), - /* .mirostat_mu = */ 0.0f, /* .prev = */ { (size_t) params.n_prev }, /* .constraints = */ {}, /* .cur = */ {}, @@ -1011,7 +1180,6 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) /* .rng = */ smpl.rng, - /* .mirostat_mu = */ smpl.mirostat_mu, /* .prev = */ smpl.prev, /* .constraints = */ {}, /* .cur = */ {}, @@ -1077,74 +1245,6 @@ int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { return smpl.prev.size(); } -llama_token llama_sampler_sample_mirostat_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu) { - llama_constraint_softmax_impl(cur_p); - - // Estimate s_hat using the most probable m tokens - float s_hat = 0.0; - float sum_ti_bi = 0.0; - float sum_ti_sq = 0.0; - for (size_t i = 0; i < size_t(m - 1) && i < cur_p->size - 1; ++i) { - float t_i = logf(float(i + 2) / float(i + 1)); - float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p); - sum_ti_bi += t_i * b_i; - sum_ti_sq += t_i * t_i; - } - s_hat = sum_ti_bi / sum_ti_sq; - - // Compute k from the estimated s_hat and target surprise value - float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); - - // Sample the next word X using top-k sampling - llama_constraint_top_k_impl(cur_p, int(k), 1); - llama_token X = llama_sampler_sample_dist_impl(cur_p, rng); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - float observed_surprise = -log2f(cur_p->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - mu = mu - eta * e; - - return X; -} - -llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, float & mu) { - llama_constraint_softmax_impl(cur_p); - - // Truncate the words with surprise values greater than mu - cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { - return -log2f(candidate.p) > mu; - })); - - if (cur_p->size == 0) { - cur_p->size = 1; - } - - // Normalize the probabilities of the remaining words - llama_constraint_softmax_impl(cur_p); - - // Sample the next word X from the remaining words - llama_token X = llama_sampler_sample_dist_impl(cur_p, rng); - - // Compute error as the difference between observed surprise and target surprise value - size_t X_idx = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { - return candidate.id == X; - })); - - float observed_surprise = -log2f(cur_p->data[X_idx].p); - float e = observed_surprise - tau; - - // Update mu using the learning rate and error - mu = mu - eta * e; - - return X; -} - llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * cur_p, bool probs) { if (probs) { // if probs are needed, we apply softmax to get the 
probabilities diff --git a/src/llama-sampling.h b/src/llama-sampling.h index e4f9108861592..aad9f311a8f2d 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -20,16 +20,38 @@ void llama_constraint_penalties_impl( // constraints -struct llama_constraint * llama_constraint_init_softmax_impl (); -struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); -struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep); -struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_temp_impl (float t); -struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); - -struct llama_constraint * llama_constraint_init_grammar_impl ( +struct llama_constraint * llama_constraint_init_softmax_impl (); +struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep); +struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); +struct llama_constraint * llama_constraint_init_temp_impl (float t); +struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); + +/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. +/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + +struct llama_constraint * llama_constraint_init_mirostat_impl( + const struct llama_vocab & vocab, + float tau, + float eta, + int32_t m); + +/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. +/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. 
+/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. +/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. +/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. +struct llama_constraint * llama_constraint_init_mirostat_v2_impl( + float tau, + float eta); + +struct llama_constraint * llama_constraint_init_grammar_impl( const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); @@ -67,8 +89,6 @@ struct llama_sampler { std::mt19937 rng; - float mirostat_mu; - ring_buffer prev; std::vector constraints; @@ -97,20 +117,5 @@ void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_d llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); -/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. -/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. -/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. -/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampler_sample_mirostat_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, int32_t m, int32_t n_vocab, float & mu); - -/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. -/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. 
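For reference, the update implemented by the mirostat constraint introduced above, reconstructed from its apply/accept callbacks (the estimate runs over the m most probable candidates, N is the vocabulary size):

\[ \hat{s} = \frac{\sum_{j=1}^{m-1} t_j\, b_j}{\sum_{j=1}^{m-1} t_j^2}, \qquad t_j = \ln\frac{j+1}{j}, \qquad b_j = \ln\frac{p_j}{p_{j+1}}, \]
\[ \hat{\varepsilon} = \hat{s} - 1, \qquad k = \left( \frac{\hat{\varepsilon}\, 2^{\mu}}{1 - N^{-\hat{\varepsilon}}} \right)^{1/\hat{s}}, \]

after which the candidates are truncated to the top k, and once a token x is accepted the surprise target is tracked with

\[ \mu \leftarrow \mu - \eta \left( -\log_2 p(x) - \tau \right). \]

Mirostat v2 skips the k estimate: it drops candidates whose surprise -log2 p exceeds mu, renormalizes, and applies the same mu update.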
-/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. -llama_token llama_sampler_sample_mirostat_v2_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, float tau, float eta, float & mu); - llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * cur_p, bool probs); llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng); diff --git a/src/llama.cpp b/src/llama.cpp index 712d9cfb58636..8f6503152f7f8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17939,9 +17939,7 @@ struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_prev =*/ 256, - /*.mirostat =*/ 0, - /*.mirostat_tau =*/ 5.00f, - /*.mirostat_eta =*/ 0.10f, + /*.type =*/ LLAMA_SAMPLER_TYPE_GREEDY, }; return result; @@ -20641,6 +20639,14 @@ struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta return llama_constraint_init_temp_ext_impl(temp, delta, exponent); } +struct llama_constraint * llama_constraint_init_mirostat(const struct llama_model * model, float tau, float eta) { + return llama_constraint_init_mirostat_impl(model->vocab, tau, eta, 100); +} + +struct llama_constraint * llama_constraint_init_mirostat_v2(float tau, float eta) { + return llama_constraint_init_mirostat_v2_impl(tau, eta); +} + struct llama_constraint * llama_constraint_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); } @@ -20741,40 +20747,6 @@ void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * c llama_sampler_apply_impl(*smpl, cur_p); } -llama_token llama_sampler_sample_mirostat(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us); - - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; - } - - const auto type = smpl->params.mirostat; - - llama_token res; - - if (type == 1) { - res = llama_sampler_sample_mirostat_impl(cur_p, - smpl->rng, - smpl->params.mirostat_tau, - smpl->params.mirostat_eta, - 100, - smpl->vocab->n_vocab, - smpl->mirostat_mu); - } else if (type == 2) { - res = llama_sampler_sample_mirostat_v2_impl(cur_p, - smpl->rng, - smpl->params.mirostat_tau, - smpl->params.mirostat_eta, - smpl->mirostat_mu); - } else { - GGML_ABORT("invalid mirostat type: %d", type); - } - - smpl->n_sample++; - - return res; -} - llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs) { time_meas tm(smpl->t_sample_us); From 784a64404093afb9dae64a8f44803141d9789087 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 17:13:15 +0300 Subject: [PATCH 20/47] sampler : API to iterate constraints ggml-ci --- common/sampling.cpp | 39 ++++++++++------------- common/sampling.h | 8 ++--- examples/batched.swift/Sources/main.swift | 6 ++-- examples/batched/batched.cpp | 6 ++-- examples/infill/infill.cpp | 2 +- examples/main/main.cpp | 17 +++++----- include/llama.h | 9 ++++-- src/llama-sampling.cpp | 14 +++++++- 
src/llama-sampling.h | 4 ++- src/llama.cpp | 12 +++++-- 10 files changed, 69 insertions(+), 48 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index b528d49291c66..718001844ee63 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -12,7 +12,7 @@ struct gpt_sampler { struct llama_sampler * smpl; }; -std::string gpt_sampler_params::print_all() const { +std::string gpt_sampler_params::print() const { char result[1024]; snprintf(result, sizeof(result), @@ -26,17 +26,12 @@ std::string gpt_sampler_params::print_all() const { return std::string(result); } -std::string gpt_sampler_params::print_constraints() const { - std::string result = "CFG -> Penalties "; - if (mirostat == 0) { - for (const auto & cnstr : constraints) { - const auto name = gpt_constraint_type_to_str(cnstr); - if (!name.empty()) { - result += "-> " + name + " "; - } - } - } else { - result += "-> mirostat "; +std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { + std::string result = "\tlogits"; + + for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) { + const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i); + result += " -> " + std::string(cnstr->iface->name(cnstr)) + " "; } return result; @@ -70,33 +65,33 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st for (const auto & cnstr : params.constraints) { switch (cnstr) { case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_TOP_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_MIN_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_TFS_Z: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_TYPICAL_P: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); break; case GPT_CONSTRAINT_TYPE_TEMPERATURE: - llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; default: GGML_ASSERT(false && "unknown constraint type"); } } } else if (params.mirostat == 1) { - llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); } else if 
(params.mirostat == 2) { - llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); } else { GGML_ASSERT(false && "unknown mirostat version"); } diff --git a/common/sampling.h b/common/sampling.h index bab26493701a1..8ec7459994dcc 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -54,10 +54,7 @@ struct gpt_sampler_params { std::vector logit_bias; // logit biases to apply // print the parameters into a string - std::string print_all() const; - - // print the constraints into a string - std::string print_constraints() const; + std::string print() const; }; // gpt_sampler extends llama_sampler with additional functionality: @@ -100,6 +97,9 @@ llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_da // helpers +// print the constraints into a string +std::string gpt_sampler_print(const struct gpt_sampler * gsmpl); + // get a string representation of the last accepted tokens std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 380040e572ecc..6b9f3e0d55bc2 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -61,9 +61,9 @@ defer { llama_sampler_free(smpl) } -llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1)); -llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9, 1)); -llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4)); +llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1)); +llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1)); +llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4)); let n_ctx = llama_n_ctx(context) diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 0f35f6cd58775..cbab4b66b5dd0 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -70,9 +70,9 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_init(model, sparams); - llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); - llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); - llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp)); + llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); + llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); + llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp)); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 9f9f81a7f44ff..3895b586ecc69 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -301,7 +301,7 @@ int main(int argc, char ** argv) { LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); } } - LOG_TEE("sampling: \n%s\n", sparams.print_all().c_str()); + LOG_TEE("sampling: \n%s\n", 
sparams.print().c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("\n\n"); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 1b706efbc2fa6..85dea9782e152 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -457,8 +457,15 @@ int main(int argc, char ** argv) { } } } - LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str()); - LOG_TEE("sampling constr: \n%s\n", sparams.print_constraints().c_str()); + + smpl = gpt_sampler_init(model, sparams); + if (!smpl) { + fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); + exit(1); + } + + LOG_TEE("sampling params: \n%s\n", sparams.print().c_str()); + LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str()); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); // group-attention state @@ -525,12 +532,6 @@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - smpl = gpt_sampler_init(model, sparams); - if (!smpl) { - fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); - exit(1); - } - if (llama_model_has_encoder(model)) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); diff --git a/include/llama.h b/include/llama.h index 8c2a2aff94abf..813d854ef754c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1012,7 +1012,7 @@ extern "C" { // The llama_sampler object contains the entire sampling information: // // - RNG state (seed and generator) - // - Custom set of constraints (see llama_sampler_add_constraint) + // - Custom set of constraints (see llama_sampler_constraint_add) // - Sampling method (greedy, dist) // - Previous tokens // @@ -1083,7 +1083,7 @@ extern "C" { LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr); - // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint) + // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add) LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); @@ -1102,7 +1102,10 @@ extern "C" { LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl); // important: takes ownership of the constraint object and will free it in llama_sampler_free - LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr); + LLAMA_API void llama_sampler_constraint_add( struct llama_sampler * smpl, struct llama_constraint * cnstr); + LLAMA_API int llama_sampler_n_constraints (const struct llama_sampler * smpl); + LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i); + LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 385f7bec15fd5..81cc357db517b 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1215,10 +1215,22 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) { // TODO: should we reset the timings? 
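A small sketch of the constraint-iteration API added in this patch, assuming a fully constructed llama_sampler * smpl; this mirrors what gpt_sampler_print does in common/sampling.cpp.

    // Sketch: walk the attached constraints in order and log their names.
    for (int i = 0; i < llama_sampler_n_constraints(smpl); ++i) {
        const struct llama_constraint * cnstr = llama_sampler_constraint_get(smpl, i);
        printf("constraint %d: %s\n", i, cnstr->iface->name(cnstr));
    }

llama_sampler_constraint_get returns nullptr for an out-of-range index, so bounding the loop by llama_sampler_n_constraints keeps the access valid.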
} -void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { +void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { smpl.constraints.push_back(cnstr); } +int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl) { + return smpl.constraints.size(); +} + +struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith) { + if (ith < 0 || ith >= (int) smpl.constraints.size()) { + return nullptr; + } + + return smpl.constraints[ith]; +} + void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { smpl.prev.push_back(token); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index aad9f311a8f2d..bf5f596f7c07d 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -109,7 +109,9 @@ void llama_sampler_free_impl ( struct llama_sampler * smp struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); void llama_sampler_reset_impl( struct llama_sampler & smpl); -void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr); +void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr); +int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl); +struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith); void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token); void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p); diff --git a/src/llama.cpp b/src/llama.cpp index 8f6503152f7f8..6426073eb2f14 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20729,8 +20729,16 @@ llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smp return &smpl->cur_p; } -void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) { - llama_sampler_add_constraint_impl(*smpl, cnstr); +void llama_sampler_constraint_add(struct llama_sampler * smpl, struct llama_constraint * cnstr) { + llama_sampler_constraint_add_impl(*smpl, cnstr); +} + +int llama_sampler_n_constraints (const struct llama_sampler * smpl) { + return llama_sampler_n_constraints_impl(*smpl); +} + +struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i) { + return llama_sampler_constraint_get_impl(*smpl, i); } void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { From e7a11cac0ef293d6218691cb9a7333beaa2c2756 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 20:21:02 +0300 Subject: [PATCH 21/47] sampling : simplify new llama_sampler calls --- src/llama-sampling.cpp | 64 +++++++++++------------------------------- 1 file changed, 16 insertions(+), 48 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 81cc357db517b..91a2cb5f5e08f 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -439,12 +439,10 @@ static struct llama_constraint_i llama_constraint_softmax_i = { }; struct llama_constraint * llama_constraint_init_softmax_impl() { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_softmax_i, /* .ctx = */ nullptr, }; - - return result; } // top-k @@ -472,15 +470,13 @@ static struct llama_constraint_i llama_constraint_top_k_i = { }; struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t 
min_keep) { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_top_k_i, /* .ctx = */ new llama_constraint_context_top_k { /*.k =*/ k, /*.min_keep =*/ min_keep, }, }; - - return result; } // top-p @@ -508,15 +504,13 @@ static struct llama_constraint_i llama_constraint_top_p_i = { }; struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_top_p_i, /* .ctx = */ new llama_constraint_context_top_p { /*.p =*/ p, /*.min_keep =*/ min_keep, }, }; - - return result; } // min-p @@ -544,15 +538,13 @@ static struct llama_constraint_i llama_constraint_min_p_i = { }; struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_min_p_i, /* .ctx = */ new llama_constraint_context_min_p { /*.p =*/ p, /*.min_keep =*/ min_keep, }, }; - - return result; } // tail-free @@ -580,15 +572,13 @@ static struct llama_constraint_i llama_constraint_tail_free_i = { }; struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_tail_free_i, /* .ctx = */ new llama_constraint_context_tail_free { /*.z =*/ z, /*.min_keep =*/ min_keep, }, }; - - return result; } // typical @@ -616,15 +606,13 @@ static struct llama_constraint_i llama_constraint_typical_i = { }; struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_typical_i, /* .ctx = */ new llama_constraint_context_typical { /*.p =*/ p, /*.min_keep =*/ min_keep, }, }; - - return result; } // temp @@ -651,14 +639,12 @@ static struct llama_constraint_i llama_constraint_temp_i = { }; struct llama_constraint * llama_constraint_init_temp_impl(float temp) { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_temp_i, /* .ctx = */ new llama_constraint_context_temp { /*.temp =*/ temp, }, }; - - return result; } // temp-ext @@ -694,7 +680,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = { }; struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_temp_ext_i, /* .ctx = */ new llama_constraint_context_temp_ext { /*.temp =*/ temp, @@ -702,8 +688,6 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float /*.exponent =*/ exponent, }, }; - - return result; } // mirostat @@ -782,12 +766,8 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { }, }; -struct llama_constraint * llama_constraint_init_mirostat_impl( - const struct llama_vocab & vocab, - float tau, - float eta, - int32_t m) { - struct llama_constraint * result = new llama_constraint { +struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) { + return new llama_constraint { /* .iface = */ &llama_constraint_mirostat_i, /* .ctx = */ new llama_constraint_context_mirostat { /*.vocab 
=*/ &vocab, @@ -798,8 +778,6 @@ struct llama_constraint * llama_constraint_init_mirostat_impl( /*.cur =*/ {}, }, }; - - return result; } // mirostat v2 @@ -863,7 +841,7 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { }; struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, float eta) { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_mirostat_v2_i, /* .ctx = */ new llama_constraint_context_mirostat_v2 { /*.tau =*/ tau, @@ -872,8 +850,6 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa /*.cur =*/ {}, }, }; - - return result; } // grammar @@ -953,12 +929,10 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ }; } - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_grammar_i, /* .ctx = */ ctx, }; - - return result; } // penalties @@ -1039,10 +1013,10 @@ static struct llama_constraint_i llama_constraint_penalties_i = { }; struct llama_constraint * llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { - GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL); + GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL); GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL); - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_penalties_i, /* .ctx = */ new llama_constraint_context_penalties { /*.vocab =*/ &vocab, @@ -1055,8 +1029,6 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam /*.prev =*/ ring_buffer(penalty_last_n), }, }; - - return result; } // logit-bias @@ -1093,15 +1065,13 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl( const struct llama_vocab & vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { - struct llama_constraint * result = new llama_constraint { + return new llama_constraint { /* .iface = */ &llama_constraint_logit_bias_i, /* .ctx = */ new llama_constraint_context_logit_bias { /*.vocab =*/ &vocab, /*.logit_bias=*/ std::vector(logit_bias, logit_bias + n_logit_bias), }, }; - - return result; } //////////////////////////////////////// @@ -1144,7 +1114,7 @@ void llama_constraint_reset_impl(struct llama_constraint & cnstr) { // struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) { - auto * result = new llama_sampler { + return new llama_sampler { /* .params = */ params, /* .vocab = */ &vocab, @@ -1157,8 +1127,6 @@ struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, /* .t_sample_us = */ 0, /* .n_sample = */ 0, }; - - return result; } void llama_sampler_free_impl(struct llama_sampler * smpl) { From 8e80a1cf6ba2bf27a1ccf64847e066f70a495bc4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 21:23:35 +0300 Subject: [PATCH 22/47] sampling : simplify sample API ggml-ci --- common/sampling.cpp | 124 +++++++------------ common/sampling.h | 11 +- examples/batched/batched.cpp | 6 +- examples/gritlm/gritlm.cpp | 8 +- examples/passkey/passkey.cpp | 8 +- examples/save-load-state/save-load-state.cpp | 6 +- examples/simple/simple.cpp | 8 +- examples/speculative/speculative.cpp | 8 +- include/llama.h | 20 ++- 
src/llama-sampling.cpp | 83 ++++++------- src/llama-sampling.h | 16 ++- src/llama.cpp | 48 +++---- 12 files changed, 147 insertions(+), 199 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 718001844ee63..df2d1958ca1e0 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -40,8 +40,9 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { llama_sampler_params lparams = llama_sampler_default_params(); - lparams.seed = params.seed; - lparams.n_prev = params.n_prev; + lparams.seed = params.seed; + lparams.n_prev = params.n_prev; + lparams.type = params.temp <= 0.0f ? LLAMA_SAMPLER_TYPE_GREEDY : LLAMA_SAMPLER_TYPE_DIST; auto * result = new gpt_sampler { /* .params = */ params, @@ -61,39 +62,41 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st /* .smpl = */ llama_sampler_init(model, lparams) }; - if (params.mirostat == 0) { - for (const auto & cnstr : params.constraints) { - switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TOP_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_MIN_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TFS_Z: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); - break; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); - break; - default: - GGML_ASSERT(false && "unknown constraint type"); + if (params.temp > 0.0f) { + if (params.mirostat == 0) { + for (const auto & cnstr : params.constraints) { + switch (cnstr) { + case GPT_CONSTRAINT_TYPE_TOP_K: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TOP_P: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_MIN_P: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TFS_Z: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TYPICAL_P: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + break; + case GPT_CONSTRAINT_TYPE_TEMPERATURE: + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + break; + default: + GGML_ASSERT(false && "unknown constraint type"); + } } + } else if (params.mirostat == 1) { + llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); + } else if (params.mirostat == 2) { + 
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); + } else { + GGML_ASSERT(false && "unknown mirostat version"); } - } else if (params.mirostat == 1) { - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); - } else if (params.mirostat == 2) { - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); - } else { - GGML_ASSERT(false && "unknown mirostat version"); } return result; @@ -151,45 +154,11 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) { llama_print_timings(ctx, gsmpl ? gsmpl->smpl : nullptr); } -static llama_token gpt_sampler_sample( - struct llama_sampler * smpl, - struct llama_token_data_array * cur_p, - float temp, - int n_probs) { - llama_token res = 0; - - if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) { - // greedy sampling, with probs - res = llama_sampler_sample_greedy(smpl, cur_p, true); - } else if (temp == 0.0f) { - // greedy sampling, no probs - res = llama_sampler_sample_greedy(smpl, cur_p, false); - } else { - // apply all sampling constraints and then sample - llama_sampler_apply(smpl, cur_p); - - res = llama_sampler_sample_dist(smpl, cur_p); - - //{ - // const int n_top = 10; - // LOG("top %d candidates:\n", n_top); - - // for (int i = 0; i < n_top; i++) { - // const llama_token id = cur_p.data[i].id; - // (void)id; // To avoid a warning that id is unused when logging is disabled. 
- // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); - // } - //} - - //LOG("sampled token: %5d: '%s'\n", res, llama_token_to_piece(smpl, res).c_str()); - } - - return res; +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p) { + return llama_sampler_sample(gsmpl->smpl, cur_p); } llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) { - const auto & params = gsmpl->params; - auto & bias = gsmpl->bias; auto & pnlt = gsmpl->pnlt; auto & grmr = gsmpl->grmr; @@ -204,8 +173,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_constraint_apply(bias, cur_p); llama_constraint_apply(pnlt, cur_p); - // first, sample the token without any grammar constraints - const llama_token id = gpt_sampler_sample(smpl, nullptr, params.temp, params.n_probs); + llama_sampler_apply(smpl, cur_p); + + const llama_token id = llama_sampler_sample(smpl, cur_p); // check if it the sampled token fits the grammar { @@ -228,7 +198,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_constraint_apply(pnlt, cur_p); llama_constraint_apply(grmr, cur_p); - return gpt_sampler_sample(smpl, cur_p, params.temp, params.n_probs); + llama_sampler_apply(smpl, cur_p); + + return llama_sampler_sample(smpl, cur_p); } void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { @@ -237,14 +209,6 @@ void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_arra llama_constraint_apply(gsmpl->grmr, cur_p); } -llama_token gpt_sampler_sample_dist(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { - return llama_sampler_sample_dist(gsmpl->smpl, cur_p); -} - -llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p, bool probs) { - return llama_sampler_sample_greedy(gsmpl->smpl, cur_p, probs); -} - std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { auto & smpl = gsmpl->smpl; diff --git a/common/sampling.h b/common/sampling.h index 8ec7459994dcc..9bdeadf784398 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -73,15 +73,19 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl); void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); void gpt_sampler_reset (struct gpt_sampler * gsmpl); +void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); + void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits); llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p); + llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); -// common sampling implementation: +// extended sampling implementation: // // - set logits // - apply the configured sampling constraints @@ -90,11 +94,6 @@ void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); // llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx); -void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); - -llama_token gpt_sampler_sample_dist (struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); -llama_token gpt_sampler_sample_greedy(struct gpt_sampler * 
gsmpl, llama_token_data_array * cur_p, bool probs); - // helpers // print the constraints into a string diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index cbab4b66b5dd0..9f0a40873b63a 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -66,7 +66,7 @@ int main(int argc, char ** argv) { auto sparams = llama_sampler_default_params(); - sparams.seed = params.sparams.seed; + sparams.seed = params.sparams.seed; llama_sampler * smpl = llama_sampler_init(model, sparams); @@ -177,9 +177,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl, logits); - const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr); - - //const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); + const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); // is it an end of generation? -> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 978642cc35dfe..07475ecd30ed1 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -124,7 +124,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_sampler_set_logits(smpl, logits); - llama_token token = llama_sampler_sample_greedy(smpl, nullptr, false); + llama_token token = llama_sampler_sample(smpl, nullptr); if (token == eos_token) { break; } @@ -171,7 +171,11 @@ int main(int argc, char * argv[]) { // create generation context llama_context * ctx = llama_new_context_with_model(model, cparams); - llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); + auto sparams = llama_sampler_default_params(); + + sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; + + llama_sampler * smpl = llama_sampler_init(model, sparams); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index b287d8403db7b..b9800a9170ab7 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -83,7 +83,11 @@ int main(int argc, char ** argv) { return 1; } - llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); + auto sparams = llama_sampler_default_params(); + + sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; + + llama_sampler * smpl = llama_sampler_init(model, sparams); // tokenize the prompt std::vector tokens_list; @@ -221,7 +225,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); + const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); // is it an end of generation? 
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 01f66886e33f8..6f8c84137f1be 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -73,7 +73,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl, logits); - auto next_token = llama_sampler_sample_dist(smpl, nullptr); + auto next_token = llama_sampler_sample(smpl, nullptr); auto next_token_str = llama_token_to_piece(ctx, next_token); printf("%s", next_token_str.c_str()); @@ -130,7 +130,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl2, logits); - auto next_token = llama_sampler_sample_dist(smpl2, nullptr); + auto next_token = llama_sampler_sample(smpl2, nullptr); auto next_token_str = llama_token_to_piece(ctx2, next_token); printf("%s", next_token_str.c_str()); @@ -219,7 +219,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl3, logits); - auto next_token = llama_sampler_sample_dist(smpl3, nullptr); + auto next_token = llama_sampler_sample(smpl3, nullptr); auto next_token_str = llama_token_to_piece(ctx3, next_token); printf("%s", next_token_str.c_str()); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index ffaa609cb8b26..7193f1ee4a03a 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,7 +55,11 @@ int main(int argc, char ** argv) { return 1; } - llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params()); + auto sparams = llama_sampler_default_params(); + + sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; + + llama_sampler * smpl = llama_sampler_init(model, sparams); // tokenize the prompt @@ -117,7 +121,7 @@ int main(int argc, char ** argv) { llama_sampler_set_logits(smpl, logits); // sample the most likely token - const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false); + const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); // is it an end of generation? 
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index bb26f8eb522fa..0b49b2b06f953 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -182,6 +182,8 @@ int main(int argc, char ** argv) { // target model sampling context (reuse the llama_context's sampling instance) struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams); + struct llama_constraint * softmax = llama_constraint_init_softmax(); + // draft sequence data std::vector drafts(n_seq_dft); @@ -236,7 +238,7 @@ int main(int argc, char ** argv) { auto & dist_tgt = *gpt_sampler_get_candidates(smpl); gpt_sampler_apply_grammar(smpl, &dist_tgt); - gpt_sampler_sample_greedy(smpl, &dist_tgt, true); // applies softmax + llama_constraint_apply(softmax, &dist_tgt); float p_tgt = 0.0f; float p_dft = 0.0f; @@ -335,11 +337,10 @@ int main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = gpt_sampler_sample_dist(smpl, &dist_tgt); + token_id = gpt_sampler_sample(smpl, &dist_tgt); gpt_sampler_accept(smpl, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } - } else { // greedy verification @@ -615,6 +616,7 @@ int main(int argc, char ** argv) { gpt_sampler_free(drafts[s].smpl); } + llama_constraint_free(softmax); llama_batch_free(batch_dft); llama_free(ctx_tgt); diff --git a/include/llama.h b/include/llama.h index 813d854ef754c..763d32abb6237 100644 --- a/include/llama.h +++ b/include/llama.h @@ -370,8 +370,8 @@ extern "C" { } llama_logit_bias; enum llama_sampler_type { - LLAMA_SAMPLER_TYPE_GREEDY = 0, - LLAMA_SAMPLER_TYPE_DIST = 1, + LLAMA_SAMPLER_TYPE_GREEDY = 0, + LLAMA_SAMPLER_TYPE_DIST = 1, }; typedef struct llama_sampler_params { @@ -1092,10 +1092,12 @@ extern "C" { // samplers - LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); - LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); - LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); + LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); + LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits); @@ -1107,11 +1109,7 @@ extern "C" { LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i); - LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token); - LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p); - - LLAMA_API llama_token llama_sampler_sample_dist (struct llama_sampler * smpl, llama_token_data_array * cur_p); - LLAMA_API llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * cur_p, bool probs); + LLAMA_API 
llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p); /// @details Get the number of accepted tokens so far (max of n_prev) LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 91a2cb5f5e08f..dfd618c331051 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1183,6 +1183,20 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) { // TODO: should we reset the timings? } +void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { + smpl.prev.push_back(token); + + for (auto * cnstr : smpl.constraints) { + llama_constraint_accept_impl(*cnstr, token); + } +} + +void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { + for (auto * cnstr : smpl.constraints) { + llama_constraint_apply_impl(*cnstr, cur_p); + } +} + void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { smpl.constraints.push_back(cnstr); } @@ -1199,17 +1213,31 @@ struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_s return smpl.constraints[ith]; } -void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { - smpl.prev.push_back(token); +llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type) { + switch (type) { + case LLAMA_SAMPLER_TYPE_GREEDY: + { + llama_constraint_softmax_impl(cur_p); - for (auto * cnstr : smpl.constraints) { - llama_constraint_accept_impl(*cnstr, token); - } -} + return cur_p->data[0].id; + } + case LLAMA_SAMPLER_TYPE_DIST: + { + llama_constraint_softmax_impl(cur_p); -void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { - for (auto * cnstr : smpl.constraints) { - llama_constraint_apply_impl(*cnstr, cur_p); + std::vector probs(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + probs[i] = cur_p->data[i].p; + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + const int idx = dist(rng); + + return cur_p->data[idx].id; + } + default: + GGML_ABORT("invalid sampler type"); } } @@ -1224,40 +1252,3 @@ llama_token llama_sampler_prev_impl(const struct llama_sampler & smpl, int ith) int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { return smpl.prev.size(); } - -llama_token llama_sampler_sample_greedy_impl(llama_token_data_array * cur_p, bool probs) { - if (probs) { - // if probs are needed, we apply softmax to get the probabilities - llama_constraint_softmax_impl(cur_p); - - // the cur_p are sorted, so we can just return the first one - return cur_p->data[0].id; - } - - // return the token with the highest logit - auto * max_iter = std::max_element(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) { - return a.logit < b.logit; - }); - - llama_token result = max_iter->id; - - return result; -} - -llama_token llama_sampler_sample_dist_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng) { - llama_constraint_softmax_impl(cur_p); - - std::vector probs; - probs.reserve(cur_p->size); - - for (size_t i = 0; i < cur_p->size; ++i) { - probs.push_back(cur_p->data[i].p); - } - - std::discrete_distribution<> dist(probs.begin(), probs.end()); - - const int idx = dist(rng); - llama_token result = cur_p->data[idx].id; - - return result; -} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 
bf5f596f7c07d..acb1e04feb4dc 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -104,20 +104,18 @@ struct llama_sampler { mutable int32_t n_sample; }; -struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); -void llama_sampler_free_impl ( struct llama_sampler * smpl); -struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); -void llama_sampler_reset_impl( struct llama_sampler & smpl); +struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); +void llama_sampler_free_impl ( struct llama_sampler * smpl); +struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); +void llama_sampler_reset_impl ( struct llama_sampler & smpl); +void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token); +void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p); void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr); int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl); struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith); -void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token); -void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p); +llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type); llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); - -llama_token llama_sampler_sample_greedy_impl(struct llama_token_data_array * cur_p, bool probs); -llama_token llama_sampler_sample_dist_impl (struct llama_token_data_array * cur_p, std::mt19937 & rng); diff --git a/src/llama.cpp b/src/llama.cpp index 6426073eb2f14..2f5df2433f499 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17939,7 +17939,7 @@ struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_prev =*/ 256, - /*.type =*/ LLAMA_SAMPLER_TYPE_GREEDY, + /*.type =*/ LLAMA_SAMPLER_TYPE_DIST, }; return result; @@ -20713,6 +20713,20 @@ void llama_sampler_reset(struct llama_sampler * smpl) { llama_sampler_reset_impl(*smpl); } +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + llama_sampler_accept_impl(*smpl, token); +} + +void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + time_meas tm(smpl->t_sample_us); + + if (cur_p == nullptr) { + cur_p = &smpl->cur_p; + } + + llama_sampler_apply_impl(*smpl, cur_p); +} + void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) { const int n_vocab = smpl->vocab->n_vocab; @@ -20741,42 +20755,14 @@ struct llama_constraint * llama_sampler_constraint_get(const struct llama_sample return llama_sampler_constraint_get_impl(*smpl, i); } -void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { - llama_sampler_accept_impl(*smpl, token); -} - -void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us); - - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; - } - - llama_sampler_apply_impl(*smpl, cur_p); -} - -llama_token llama_sampler_sample_greedy(struct llama_sampler * smpl, llama_token_data_array * 
cur_p, bool probs) { - time_meas tm(smpl->t_sample_us); - - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; - } - - auto res = llama_sampler_sample_greedy_impl(cur_p, probs); - - smpl->n_sample++; - - return res; -} - -llama_token llama_sampler_sample_dist(struct llama_sampler * smpl, llama_token_data_array * cur_p) { +llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) { time_meas tm(smpl->t_sample_us); if (cur_p == nullptr) { cur_p = &smpl->cur_p; } - auto res = llama_sampler_sample_dist_impl(cur_p, smpl->rng); + auto res = llama_sampler_sample_impl(cur_p, smpl->rng, smpl->params.type); smpl->n_sample++; From 9b950671f46dd77c44965237374281b341e09f1e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 21:48:57 +0300 Subject: [PATCH 23/47] sampling : fix grammar apply --- common/sampling.cpp | 2 +- examples/speculative/speculative.cpp | 3 --- src/llama-sampling.cpp | 6 +++++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index df2d1958ca1e0..18d3be8456b5c 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -132,7 +132,7 @@ void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool appl llama_sampler_accept(gsmpl->smpl, token); } -void gpt_sampler_reset (struct gpt_sampler * gsmpl) { +void gpt_sampler_reset(struct gpt_sampler * gsmpl) { llama_constraint_reset(gsmpl->grmr); llama_sampler_reset(gsmpl->smpl); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 0b49b2b06f953..5cd14c49d6ebf 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -37,9 +37,6 @@ int main(int argc, char ** argv) { return 1; } - // for probabilities to be computed even with temp = 0 - params.sparams.n_probs = 16; - // max number of parallel drafting sequences (i.e. 
tree branches) const int n_seq_dft = params.n_parallel; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index dfd618c331051..4e44ec417af21 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -855,6 +855,8 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa // grammar struct llama_constraint_context_grammar { + const struct llama_vocab * vocab; + std::string grammar_str; std::string grammar_root; @@ -889,7 +891,7 @@ static struct llama_constraint_i llama_constraint_grammar_i = { /* .copy = */ [](const struct llama_constraint * cnstr) { const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx; - auto * result = llama_constraint_init_grammar_impl(*ctx_src->grammar->vocab, nullptr, nullptr); + auto * result = llama_constraint_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); auto * ctx_dst = (llama_constraint_context_grammar *) result->ctx; if (ctx_src->grammar) { @@ -917,12 +919,14 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ if (grammar_str != nullptr && grammar_str[0] != '\0') { *ctx = { + /*.vocab = */ &vocab, /*.grammar_str = */ grammar_str, /*.grammar_root = */ grammar_root, /*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), }; } else { *ctx = { + /*.vocab = */ &vocab, /*.grammar_str = */ {}, /*.grammar_root = */ {}, /*.grammar = */ nullptr, From b2b36e9e95249bbeaf2d833377777c5e32c39576 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 4 Sep 2024 22:16:30 +0300 Subject: [PATCH 24/47] example : fix build + fix speculative ggml-ci --- common/sampling.cpp | 12 +++++- common/sampling.h | 5 ++- examples/batched.swift/Sources/main.swift | 4 +- .../llama/src/main/cpp/llama-android.cpp | 2 +- .../llama.cpp.swift/LibLlama.swift | 6 ++- examples/speculative/speculative.cpp | 38 ++++++++++++------- 6 files changed, 45 insertions(+), 22 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 18d3be8456b5c..edc6cd05b3ac9 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -120,7 +120,7 @@ struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl) { /* .bias = */ llama_constraint_cp(gsmpl->bias), /* .pnlt = */ llama_constraint_cp(gsmpl->pnlt), /* .grmr = */ llama_constraint_cp(gsmpl->grmr), - /* .smpl = */ llama_sampler_cp(gsmpl->smpl) + /* .smpl = */ llama_sampler_cp (gsmpl->smpl) }; } @@ -158,7 +158,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_da return llama_sampler_sample(gsmpl->smpl, cur_p); } -llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx) { +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { auto & bias = gsmpl->bias; auto & pnlt = gsmpl->pnlt; auto & grmr = gsmpl->grmr; @@ -173,10 +173,18 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_constraint_apply(bias, cur_p); llama_constraint_apply(pnlt, cur_p); + if (grammar_first) { + llama_constraint_apply(grmr, cur_p); + } + llama_sampler_apply(smpl, cur_p); const llama_token id = llama_sampler_sample(smpl, cur_p); + if (grammar_first) { + return id; + } + // check if it the sampled token fits the grammar { llama_token_data single_token_data = { id, 1.0f, 0.0f }; diff --git a/common/sampling.h b/common/sampling.h index 9bdeadf784398..87673efa3bf2b 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -92,7 +92,10 @@ void gpt_print_timings(struct llama_context * ctx, struct 
gpt_sampler * gsmpl); // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx); +// if grammar_first is true, the grammar is applied before the constraints (slower) +// useful in cases where all the resulting candidates must fit the grammar +// +llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); // helpers diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 6b9f3e0d55bc2..6ff62ae067da3 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -141,9 +141,7 @@ while n_cur <= n_len { llama_sampler_set_logits(smpl, logits) - let new_token_id = llama_sampler_sample_dist(smpl, nil) - - // const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nil, false); + let new_token_id = llama_sampler_sample(smpl, nil) // is it an end of stream? -> mark the stream as finished if llama_token_is_eog(model, new_token_id) || n_cur == n_len { diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 666e89764834d..1a4908501c35c 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -399,7 +399,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( llama_sampler_set_logits(sampling, logits); // sample the most likely token - const auto new_token_id = llama_sampler_sample_greedy(sampling, nullptr, false); + const auto new_token_id = llama_sampler_sample(sampling, nullptr); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 930336b270f70..bd6513d3457ab 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -43,7 +43,9 @@ actor LlamaContext { self.tokens_list = [] self.batch = llama_batch_init(512, 0, 1) self.temporary_invalid_cchars = [] - self.sampling = llama_sampler_init(context, llama_sampler_default_params()) + var sparams = llama_sampler_default_params() + sparams.type = LLAMA_SAMPLER_TYPE_GREEDY + self.sampling = llama_sampler_init(context, sparams) } deinit { @@ -151,7 +153,7 @@ actor LlamaContext { llama_sampler_set_logits(sampling, logits); - new_token_id = llama_sampler_sample_greedy(sampling, nil, false) + new_token_id = llama_sampler_sample(sampling, nil) if llama_token_is_eog(model, new_token_id) || n_cur == n_len { print("\n") diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 5cd14c49d6ebf..d51c768493b7a 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -228,20 +228,13 @@ int main(int argc, char ** argv) { bool accept = false; if (params.sparams.temp > 0) { // stochastic verification - const float * logits = llama_get_logits_ith(ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]); - - gpt_sampler_set_logits(smpl, logits); + gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true); auto & dist_tgt = *gpt_sampler_get_candidates(smpl); - gpt_sampler_apply_grammar(smpl, 
&dist_tgt); - llama_constraint_apply(softmax, &dist_tgt); - float p_tgt = 0.0f; float p_dft = 0.0f; - // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); - while (active_seqs.size() > 0) { // randomly select a sequence to verify from active sequences std::uniform_int_distribution u_int_dist(0, active_seqs.size() - 1); @@ -259,9 +252,13 @@ int main(int argc, char ** argv) { } continue; } + LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); float r = u_dist(rng); llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true }; + + //GGML_ASSERT(dist_tgt.size <= dist_dft.size); + // acquire the token probabilities assigned by the draft and target models for (size_t i = 0; i < dist_tgt.size; i++) { if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) { @@ -291,7 +288,6 @@ int main(int argc, char ** argv) { // calculate residual probability GGML_ASSERT(dist_tgt.sorted); GGML_ASSERT(dist_dft.sorted); - float sum_probs = 0.0f; // sort dist by id std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) { @@ -301,10 +297,18 @@ int main(int argc, char ** argv) { return a.id < b.id; }); + float sum_probs = 0.0f; + for (size_t i = 0; i < dist_tgt.size; i++) { - dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p); + if (i < dist_dft.size) { + dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p); + } else { + dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p); + } + sum_probs += dist_tgt.data[i].p; } + for (size_t i = 0; i < dist_tgt.size; i++) { dist_tgt.data[i].p /= sum_probs; } @@ -334,7 +338,16 @@ int main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = gpt_sampler_sample(smpl, &dist_tgt); + std::vector probs(dist_tgt.size); + for (size_t i = 0; i < dist_tgt.size; ++i) { + probs[i] = dist_tgt.data[i].p; + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + const int idx = dist(rng); + + token_id = dist_tgt.data[idx].id; gpt_sampler_accept(smpl, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } @@ -467,7 +480,7 @@ int main(int argc, char ** argv) { continue; } - gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft); + gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true); const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl); @@ -512,7 +525,6 @@ int main(int argc, char ** argv) { } drafts[n_seq_cur].smpl = gpt_sampler_cp(drafts[s].smpl); - sa.push_back(n_seq_cur); n_seq_cur++; From 69551ffd609ffdcfd087110b2f09ccb46287432a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Sep 2024 10:18:04 +0300 Subject: [PATCH 25/47] sampling : remove top-k min_keep, fix mirostat init and state --- common/sampling.cpp | 2 +- examples/batched.swift/Sources/main.swift | 2 +- examples/batched/batched.cpp | 2 +- include/llama.h | 2 +- src/llama-sampling.cpp | 105 +++++++++++----------- src/llama-sampling.h | 2 +- src/llama.cpp | 4 +- tests/test-sampling.cpp | 30 +++---- 8 files changed, 76 insertions(+), 73 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index edc6cd05b3ac9..2887207f1e228 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -67,7 +67,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st for (const auto & cnstr : 
params.constraints) { switch (cnstr) { case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep)); + llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k)); break; case GPT_CONSTRAINT_TYPE_TOP_P: llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 6ff62ae067da3..a02fa4da9183d 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -61,7 +61,7 @@ defer { llama_sampler_free(smpl) } -llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1)); +llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40)); llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1)); llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4)); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 9f0a40873b63a..5896526abdc77 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -70,7 +70,7 @@ int main(int argc, char ** argv) { llama_sampler * smpl = llama_sampler_init(model, sparams); - llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep)); + llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k)); llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp)); diff --git a/include/llama.h b/include/llama.h index 763d32abb6237..02f7a849175bc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1045,7 +1045,7 @@ extern "C" { }; LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); - LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k, int32_t min_keep); + LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k); LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 4e44ec417af21..c07c509bccbda 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -49,7 +49,7 @@ static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) { } } -static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k, size_t min_keep) { +static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)cur_p->size) { // return; @@ -59,7 +59,6 @@ static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k = cur_p->size; } - k = std::max(k, (int) min_keep); k = std::min(k, (int) cur_p->size); // Sort scores in descending order @@ -449,7 +448,6 @@ struct llama_constraint * llama_constraint_init_softmax_impl() { struct llama_constraint_context_top_k { const int32_t k; - const size_t min_keep; }; static struct llama_constraint_i llama_constraint_top_k_i = { @@ -457,24 +455,23 @@ static struct llama_constraint_i llama_constraint_top_k_i = { /* .accept = */ 
nullptr, /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; - llama_constraint_top_k_impl(cur_p, ctx->k, ctx->min_keep); + llama_constraint_top_k_impl(cur_p, ctx->k); }, /* .reset = */ nullptr, /* .copy = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx; - return llama_constraint_init_top_k_impl(ctx->k, ctx->min_keep); + return llama_constraint_init_top_k_impl(ctx->k); }, /* .free = */ [](struct llama_constraint * cnstr) { delete (llama_constraint_context_top_k *) cnstr->ctx; }, }; -struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k, size_t min_keep) { +struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) { return new llama_constraint { /* .iface = */ &llama_constraint_top_k_i, /* .ctx = */ new llama_constraint_context_top_k { - /*.k =*/ k, - /*.min_keep =*/ min_keep, + /* .k = */ k, }, }; } @@ -507,8 +504,8 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k return new llama_constraint { /* .iface = */ &llama_constraint_top_p_i, /* .ctx = */ new llama_constraint_context_top_p { - /*.p =*/ p, - /*.min_keep =*/ min_keep, + /* .p = */ p, + /* .min_keep = */ min_keep, }, }; } @@ -541,8 +538,8 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k return new llama_constraint { /* .iface = */ &llama_constraint_min_p_i, /* .ctx = */ new llama_constraint_context_min_p { - /*.p =*/ p, - /*.min_keep =*/ min_keep, + /* .p = */ p, + /* .min_keep = */ min_keep, }, }; } @@ -575,8 +572,8 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m return new llama_constraint { /* .iface = */ &llama_constraint_tail_free_i, /* .ctx = */ new llama_constraint_context_tail_free { - /*.z =*/ z, - /*.min_keep =*/ min_keep, + /* .z = */ z, + /*. 
min_keep = */ min_keep, }, }; } @@ -609,8 +606,8 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min return new llama_constraint { /* .iface = */ &llama_constraint_typical_i, /* .ctx = */ new llama_constraint_context_typical { - /*.p =*/ p, - /*.min_keep =*/ min_keep, + /* .p = */ p, + /* .min_keep = */ min_keep, }, }; } @@ -642,7 +639,7 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) { return new llama_constraint { /* .iface = */ &llama_constraint_temp_i, /* .ctx = */ new llama_constraint_context_temp { - /*.temp =*/ temp, + /*.temp = */ temp, }, }; } @@ -683,9 +680,9 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float return new llama_constraint { /* .iface = */ &llama_constraint_temp_ext_i, /* .ctx = */ new llama_constraint_context_temp_ext { - /*.temp =*/ temp, - /*.delta =*/ delta, - /*.exponent =*/ exponent, + /* .temp = */ temp, + /* .delta = */ delta, + /* .exponent = */ exponent, }, }; } @@ -745,7 +742,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { float epsilon_hat = s_hat - 1; float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat); - llama_constraint_top_k_impl(cur_p, int(k), 1); + llama_constraint_top_k_impl(cur_p, std::max(int(k), 1)); // remember the order to be able to compute the distance later when accepting the token ctx->cur.resize(cur_p->size); @@ -755,7 +752,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { }, /* .reset = */ [](struct llama_constraint * cnstr) { auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; - ctx->mu = 0.0f; + ctx->mu = 2.0f*ctx->tau; }, /* .copy = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx; @@ -770,12 +767,12 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama return new llama_constraint { /* .iface = */ &llama_constraint_mirostat_i, /* .ctx = */ new llama_constraint_context_mirostat { - /*.vocab =*/ &vocab, - /*.tau =*/ tau, - /*.eta =*/ eta, - /*.m =*/ m, - /*.mu =*/ 0.0f, - /*.cur =*/ {}, + /* .vocab = */ &vocab, + /* .tau = */ tau, + /* .eta = */ eta, + /* .m = */ m, + /* .mu = */ 2.0f*tau, + /* .cur = */ {}, }, }; } @@ -826,10 +823,16 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { // Normalize the probabilities of the remaining words llama_constraint_softmax_impl(cur_p); + + // remember the order to be able to compute the distance later when accepting the token + ctx->cur.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + ctx->cur[i] = cur_p->data[i]; + } }, /* .reset = */ [](struct llama_constraint * cnstr) { auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; - ctx->mu = 0.0f; + ctx->mu = 2.0f*ctx->tau; }, /* .copy = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx; @@ -844,10 +847,10 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa return new llama_constraint { /* .iface = */ &llama_constraint_mirostat_v2_i, /* .ctx = */ new llama_constraint_context_mirostat_v2 { - /*.tau =*/ tau, - /*.eta =*/ eta, - /*.mu =*/ 0.0f, - /*.cur =*/ {}, + /* .tau = */ tau, + /* .eta = */ eta, + /* .mu = */ 2.0f*tau, + /* .cur = */ {}, }, }; } @@ -919,17 +922,17 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ if (grammar_str != nullptr && grammar_str[0] != '\0') 
{ *ctx = { - /*.vocab = */ &vocab, - /*.grammar_str = */ grammar_str, - /*.grammar_root = */ grammar_root, - /*.grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), + /* .vocab = */ &vocab, + /* .grammar_str = */ grammar_str, + /* .grammar_root = */ grammar_root, + /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), }; } else { *ctx = { - /*.vocab = */ &vocab, - /*.grammar_str = */ {}, - /*.grammar_root = */ {}, - /*.grammar = */ nullptr, + /* .vocab = */ &vocab, + /* .grammar_str = */ {}, + /* .grammar_root = */ {}, + /* .grammar = */ nullptr, }; } @@ -1023,14 +1026,14 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam return new llama_constraint { /* .iface = */ &llama_constraint_penalties_i, /* .ctx = */ new llama_constraint_context_penalties { - /*.vocab =*/ &vocab, - /*.penalty_last_n =*/ penalty_last_n, - /*.penalty_repeat =*/ penalty_repeat, - /*.penalty_freq =*/ penalty_freq, - /*.penalty_present =*/ penalty_present, - /*.penalize_nl =*/ penalize_nl, - /*.ignore_eos =*/ ignore_eos, - /*.prev =*/ ring_buffer(penalty_last_n), + /* .vocab = */ &vocab, + /* .penalty_last_n = */ penalty_last_n, + /* .penalty_repeat = */ penalty_repeat, + /* .penalty_freq = */ penalty_freq, + /* .penalty_present = */ penalty_present, + /* .penalize_nl = */ penalize_nl, + /* .ignore_eos = */ ignore_eos, + /* .prev = */ ring_buffer(penalty_last_n), }, }; } @@ -1072,8 +1075,8 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl( return new llama_constraint { /* .iface = */ &llama_constraint_logit_bias_i, /* .ctx = */ new llama_constraint_context_logit_bias { - /*.vocab =*/ &vocab, - /*.logit_bias=*/ std::vector(logit_bias, logit_bias + n_logit_bias), + /* .vocab = */ &vocab, + /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), }, }; } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index acb1e04feb4dc..1295bc823ba63 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -21,7 +21,7 @@ void llama_constraint_penalties_impl( // constraints struct llama_constraint * llama_constraint_init_softmax_impl (); -struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k, size_t min_keep); +struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k); struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep); diff --git a/src/llama.cpp b/src/llama.cpp index 2f5df2433f499..6a30daf396dc4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20611,8 +20611,8 @@ struct llama_constraint * llama_constraint_init_softmax(void) { return llama_constraint_init_softmax_impl(); } -struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep) { - return llama_constraint_init_top_k_impl(k, min_keep); +struct llama_constraint * llama_constraint_init_top_k(int32_t k) { + return llama_constraint_init_top_k_impl(k); } struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 0c9b46429caa7..74bb4a3a3a40c 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -19,7 +19,7 @@ static void dump(const llama_token_data_array * cur_p) { #define DUMP(__cur_p) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__cur_p)); printf("-\n"); } while(0) -#define 
TEST(__cnstr, __cur_p) do { \ +#define APPLY(__cnstr, __cur_p) do { \ auto * cnstr = (__cnstr); \ llama_constraint_apply(cnstr, (__cur_p)); \ llama_constraint_free(cnstr); \ @@ -36,9 +36,9 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector llama_token_data_array cur_p = { cur.data(), cur.size(), false }; DUMP(&cur_p); - TEST(llama_constraint_init_tail_free(z, 1), &cur_p); + APPLY(llama_constraint_init_tail_free(z, 1), &cur_p); DUMP(&cur_p); GGML_ASSERT(cur_p.size == expected_probs.size()); @@ -102,9 +102,9 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector Date: Thu, 5 Sep 2024 10:25:33 +0300 Subject: [PATCH 26/47] sampling : change _cp/copy to clone --- common/sampling.cpp | 10 +++---- common/sampling.h | 2 +- examples/speculative/speculative.cpp | 4 +-- include/llama.h | 17 +++++++++--- src/llama-grammar.cpp | 2 +- src/llama-grammar.h | 2 +- src/llama-sampling.cpp | 40 ++++++++++++++-------------- src/llama-sampling.h | 16 ++--------- src/llama.cpp | 8 +++--- 9 files changed, 50 insertions(+), 51 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 2887207f1e228..914b579a0054e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -114,13 +114,13 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { } } -struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl) { +struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { return new gpt_sampler { /* .params = */ gsmpl->params, - /* .bias = */ llama_constraint_cp(gsmpl->bias), - /* .pnlt = */ llama_constraint_cp(gsmpl->pnlt), - /* .grmr = */ llama_constraint_cp(gsmpl->grmr), - /* .smpl = */ llama_sampler_cp (gsmpl->smpl) + /* .bias = */ llama_constraint_clone(gsmpl->bias), + /* .pnlt = */ llama_constraint_clone(gsmpl->pnlt), + /* .grmr = */ llama_constraint_clone(gsmpl->grmr), + /* .smpl = */ llama_sampler_clone (gsmpl->smpl) }; } diff --git a/common/sampling.h b/common/sampling.h index 87673efa3bf2b..c260ef0553d8e 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -68,7 +68,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st void gpt_sampler_free(struct gpt_sampler * gsmpl); -struct gpt_sampler * gpt_sampler_cp(gpt_sampler * gsmpl); +struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl); void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); void gpt_sampler_reset (struct gpt_sampler * gsmpl); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index d51c768493b7a..9f596ec914b54 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -451,7 +451,7 @@ int main(int argc, char ** argv) { if (drafts[0].smpl) { gpt_sampler_free(drafts[0].smpl); } - drafts[0].smpl = gpt_sampler_cp(smpl); + drafts[0].smpl = gpt_sampler_clone(smpl); int n_seq_cur = 1; int n_past_cur = n_past_dft; @@ -523,7 +523,7 @@ int main(int argc, char ** argv) { if (drafts[n_seq_cur].smpl) { gpt_sampler_free(drafts[n_seq_cur].smpl); } - drafts[n_seq_cur].smpl = gpt_sampler_cp(drafts[s].smpl); + drafts[n_seq_cur].smpl = gpt_sampler_clone(drafts[s].smpl); sa.push_back(n_seq_cur); diff --git a/include/llama.h b/include/llama.h index 02f7a849175bc..0fc45bef362eb 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1032,7 +1032,7 @@ extern "C" { void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL void (*apply) ( struct llama_constraint * 
cnstr, llama_token_data_array * cur_p); // required void (*reset) ( struct llama_constraint * cnstr); // can be NULL - struct llama_constraint * (*copy) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL + struct llama_constraint * (*clone) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL void (*free) ( struct llama_constraint * cnstr); // can be NULL if ctx is NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph @@ -1053,11 +1053,22 @@ extern "C" { LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. LLAMA_API struct llama_constraint * llama_constraint_init_mirostat( const struct llama_model * model, float tau, float eta); + /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. + /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
LLAMA_API struct llama_constraint * llama_constraint_init_mirostat_v2( float tau, float eta); @@ -1081,7 +1092,7 @@ extern "C" { int32_t n_logit_bias, const llama_logit_bias * logit_bias); - LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr); + LLAMA_API struct llama_constraint * llama_constraint_clone(const struct llama_constraint * cnstr); // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add) LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); @@ -1094,7 +1105,7 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index a9813ebbfb228..09f756fbec727 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -1050,7 +1050,7 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) { delete grammar; } -struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar) { +struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar) { llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, }; // redirect elements in stacks to point to new rules diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 6b9a2af8dd725..419a616d644dc 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -131,7 +131,7 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar); +struct llama_grammar * llama_grammar_clone_impl(const struct llama_grammar & grammar); // TODO: move the API below as member functions of llama_grammar void llama_grammar_apply_impl( diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index c07c509bccbda..bf71f98f10834 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -433,7 +433,7 @@ static struct llama_constraint_i llama_constraint_softmax_i = { llama_constraint_softmax_impl(cur_p); }, /* .reset = */ nullptr, - /* .copy = */ nullptr, + /* .clone = */ nullptr, /* .free = */ nullptr, }; @@ -458,7 +458,7 @@ static struct llama_constraint_i llama_constraint_top_k_i = { llama_constraint_top_k_impl(cur_p, ctx->k); }, /* .reset = */ nullptr, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx; return llama_constraint_init_top_k_impl(ctx->k); }, @@ -491,7 +491,7 @@ static struct llama_constraint_i llama_constraint_top_p_i = { llama_constraint_top_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_top_p *) 
cnstr->ctx; return llama_constraint_init_top_p_impl(ctx->p, ctx->min_keep); }, @@ -525,7 +525,7 @@ static struct llama_constraint_i llama_constraint_min_p_i = { llama_constraint_min_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_min_p *) cnstr->ctx; return llama_constraint_init_min_p_impl(ctx->p, ctx->min_keep); }, @@ -559,7 +559,7 @@ static struct llama_constraint_i llama_constraint_tail_free_i = { llama_constraint_tail_free_impl(cur_p, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_tail_free *) cnstr->ctx; return llama_constraint_init_tail_free_impl(ctx->z, ctx->min_keep); }, @@ -593,7 +593,7 @@ static struct llama_constraint_i llama_constraint_typical_i = { llama_constraint_typical_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_typical *) cnstr->ctx; return llama_constraint_init_typical_impl(ctx->p, ctx->min_keep); }, @@ -626,7 +626,7 @@ static struct llama_constraint_i llama_constraint_temp_i = { llama_constraint_temp_impl(cur_p, ctx->temp); }, /* .reset = */ nullptr, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_temp *) cnstr->ctx; return llama_constraint_init_temp_impl(ctx->temp); }, @@ -667,7 +667,7 @@ static struct llama_constraint_i llama_constraint_temp_ext_i = { } }, /* .reset = */ nullptr, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_temp_ext *) cnstr->ctx; return llama_constraint_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); }, @@ -754,7 +754,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; ctx->mu = 2.0f*ctx->tau; }, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx; return llama_constraint_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m); }, @@ -834,7 +834,7 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; ctx->mu = 2.0f*ctx->tau; }, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx; return llama_constraint_init_mirostat_v2_impl(ctx->tau, ctx->eta); }, @@ -891,7 +891,7 @@ static struct llama_constraint_i llama_constraint_grammar_i = { llama_grammar_free_impl(ctx->grammar); ctx->grammar = grammar_new; }, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx; auto * result = llama_constraint_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); @@ -901,7 +901,7 @@ 
static struct llama_constraint_i llama_constraint_grammar_i = { ctx_dst->grammar_str = ctx_src->grammar_str; ctx_dst->grammar_root = ctx_src->grammar_root; - ctx_dst->grammar = llama_grammar_cp_impl(*ctx_src->grammar); + ctx_dst->grammar = llama_grammar_clone_impl(*ctx_src->grammar); } return result; @@ -998,7 +998,7 @@ static struct llama_constraint_i llama_constraint_penalties_i = { auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; ctx->prev.clear(); }, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr->ctx; auto * result = llama_constraint_init_penalties_impl( *ctx_src->vocab, @@ -1059,7 +1059,7 @@ static struct llama_constraint_i llama_constraint_logit_bias_i = { } }, /* .reset = */ nullptr, - /* .copy = */ [](const struct llama_constraint * cnstr) { + /* .clone = */ [](const struct llama_constraint * cnstr) { const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr->ctx; return llama_constraint_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); }, @@ -1083,8 +1083,8 @@ struct llama_constraint * llama_constraint_init_logit_bias_impl( //////////////////////////////////////// -struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint & cnstr) { - return cnstr.iface->copy ? cnstr.iface->copy(&cnstr) : nullptr; +struct llama_constraint * llama_constraint_clone_impl(const struct llama_constraint & cnstr) { + return cnstr.iface->clone ? cnstr.iface->clone(&cnstr) : nullptr; } void llama_constraint_free_impl(struct llama_constraint * cnstr) { @@ -1148,7 +1148,7 @@ void llama_sampler_free_impl(struct llama_sampler * smpl) { delete smpl; } -struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) { +struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) { auto * result = new llama_sampler { /* .params = */ smpl.params, /* .vocab = */ smpl.vocab, @@ -1163,7 +1163,7 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) /* .n_sample = */ 0, }; - // copy the constraints objects + // clone the constraints objects result->constraints.clear(); for (const auto & cnstr : smpl.constraints) { if (cnstr->ctx == nullptr) { @@ -1172,8 +1172,8 @@ struct llama_sampler * llama_sampler_cp_impl(const struct llama_sampler & smpl) /* .ctx = */ nullptr, }); } else { - GGML_ASSERT(cnstr->iface->copy); - result->constraints.push_back(cnstr->iface->copy(cnstr)); + GGML_ASSERT(cnstr->iface->clone); + result->constraints.push_back(cnstr->iface->clone(cnstr)); } } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 1295bc823ba63..453650b28fb7a 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -29,24 +29,12 @@ struct llama_constraint * llama_constraint_init_typical_impl (float p, size struct llama_constraint * llama_constraint_init_temp_impl (float t); struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); -/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. -/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. 
A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. -/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. -/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - struct llama_constraint * llama_constraint_init_mirostat_impl( const struct llama_vocab & vocab, float tau, float eta, int32_t m); -/// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. -/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. -/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. -/// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. -/// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
struct llama_constraint * llama_constraint_init_mirostat_v2_impl( float tau, float eta); @@ -70,7 +58,7 @@ struct llama_constraint * llama_constraint_init_penalties_impl( int32_t n_logit_bias, const llama_logit_bias * logit_bias); -struct llama_constraint * llama_constraint_cp_impl(const struct llama_constraint & cnstr); +struct llama_constraint * llama_constraint_clone_impl(const struct llama_constraint & cnstr); void llama_constraint_free_impl(struct llama_constraint * cnstr); @@ -106,7 +94,7 @@ struct llama_sampler { struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); void llama_sampler_free_impl ( struct llama_sampler * smpl); -struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl); +struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl); void llama_sampler_reset_impl ( struct llama_sampler & smpl); void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token); void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p); diff --git a/src/llama.cpp b/src/llama.cpp index 6a30daf396dc4..436c21f9dffe5 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20669,8 +20669,8 @@ LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( return llama_constraint_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); } -struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr) { - return llama_constraint_cp_impl(*cnstr); +struct llama_constraint * llama_constraint_clone(const struct llama_constraint * cnstr) { + return llama_constraint_clone_impl(*cnstr); } void llama_constraint_free(struct llama_constraint * cnstr) { @@ -20705,8 +20705,8 @@ void llama_sampler_free(struct llama_sampler * smpl) { llama_sampler_free_impl(smpl); } -struct llama_sampler * llama_sampler_cp(const struct llama_sampler * smpl) { - return llama_sampler_cp_impl(*smpl); +struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + return llama_sampler_clone_impl(*smpl); } void llama_sampler_reset(struct llama_sampler * smpl) { From 595711417ab8e3ee14144d5382cf252c52424d4c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Sep 2024 10:33:04 +0300 Subject: [PATCH 27/47] sampling : add name API + option to disable timings --- common/sampling.cpp | 2 +- include/llama.h | 9 ++++++--- src/llama-sampling.cpp | 8 ++++++++ src/llama-sampling.h | 7 ++++--- src/llama.cpp | 21 ++++++++++++++------- 5 files changed, 33 insertions(+), 14 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 914b579a0054e..0047ead3424b0 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -31,7 +31,7 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) { const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i); - result += " -> " + std::string(cnstr->iface->name(cnstr)) + " "; + result += std::string(" -> ") + llama_constraint_name(cnstr) + " "; } return result; diff --git a/include/llama.h b/include/llama.h index 0fc45bef362eb..a43b629054091 100644 --- a/include/llama.h +++ b/include/llama.h @@ -381,6 +381,8 @@ extern "C" { // TODO: will be used by the llama_decode_with_sampler() API in the future enum llama_sampler_type type; + + bool no_timing; // whether to measure performance timings } llama_sampler_params; // performance timing information @@ -1097,9 +1099,10 @@ extern "C" { // 
important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add) LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); - LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token); - LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * cur_p); - LLAMA_API void llama_constraint_reset (struct llama_constraint * cnstr); + LLAMA_API const char * llama_constraint_name (const struct llama_constraint * cnstr); + LLAMA_API void llama_constraint_accept( struct llama_constraint * cnstr, llama_token token); + LLAMA_API void llama_constraint_apply ( struct llama_constraint * cnstr, llama_token_data_array * cur_p); + LLAMA_API void llama_constraint_reset ( struct llama_constraint * cnstr); // samplers diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index bf71f98f10834..cf28baab5978f 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1190,6 +1190,14 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) { // TODO: should we reset the timings? } +const char * llama_constraint_name_impl(const struct llama_constraint & cnstr) { + if (!cnstr.iface) { + return "(null)"; + } + + return cnstr.iface->name(&cnstr); +} + void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { smpl.prev.push_back(token); diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 453650b28fb7a..18304b49a8ef1 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -62,9 +62,10 @@ struct llama_constraint * llama_constraint_clone_impl(const struct llama_constra void llama_constraint_free_impl(struct llama_constraint * cnstr); -void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token); -void llama_constraint_apply_impl (struct llama_constraint & cnstr, struct llama_token_data_array * cur_p); -void llama_constraint_reset_impl (struct llama_constraint & cnstr); +const char * llama_constraint_name_impl (const struct llama_constraint & cnstr); +void llama_constraint_accept_impl( struct llama_constraint & cnstr, llama_token token); +void llama_constraint_apply_impl ( struct llama_constraint & cnstr, struct llama_token_data_array * cur_p); +void llama_constraint_reset_impl ( struct llama_constraint & cnstr); // samplers diff --git a/src/llama.cpp b/src/llama.cpp index 436c21f9dffe5..2636f2316104b 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -148,10 +148,12 @@ static void zeros(std::ofstream & file, size_t n) { } struct time_meas { - time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {} + time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? 
-1 : ggml_time_us()), t_acc(t_acc) {} ~time_meas() { - t_acc += ggml_time_us() - t_start_us; + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; + } } const int64_t t_start_us; @@ -17937,9 +17939,10 @@ struct llama_context_params llama_context_default_params() { struct llama_sampler_params llama_sampler_default_params() { struct llama_sampler_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_prev =*/ 256, - /*.type =*/ LLAMA_SAMPLER_TYPE_DIST, + /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_prev =*/ 256, + /*.type =*/ LLAMA_SAMPLER_TYPE_DIST, + /*.no_timing =*/ false, // TODO: change to true and set explicitly in examples }; return result; @@ -20681,6 +20684,10 @@ void llama_constraint_free(struct llama_constraint * cnstr) { llama_constraint_free_impl(cnstr); } +const char * llama_constraint_name(const struct llama_constraint * cnstr) { + return llama_constraint_name_impl(*cnstr); +} + void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) { llama_constraint_accept_impl(*cnstr, token); } @@ -20718,7 +20725,7 @@ void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { } void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us); + time_meas tm(smpl->t_sample_us, smpl->params.no_timing); if (cur_p == nullptr) { cur_p = &smpl->cur_p; @@ -20756,7 +20763,7 @@ struct llama_constraint * llama_sampler_constraint_get(const struct llama_sample } llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us); + time_meas tm(smpl->t_sample_us, smpl->params.no_timing); if (cur_p == nullptr) { cur_p = &smpl->cur_p; From a2d8b27a4b29b44253c05e2a721c5a152a29fd50 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Sep 2024 10:38:31 +0300 Subject: [PATCH 28/47] llama : restore comments in llama.h ggml-ci --- include/llama.h | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/include/llama.h b/include/llama.h index a43b629054091..dd047e0aceb9c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1046,13 +1046,26 @@ extern "C" { llama_constraint_context_t ctx; }; + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); + + /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k); + + /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); + + /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); + + /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); + + /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. 
LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); + + /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. From 0b6dfcebb2f32504d266352c33c102d2234f4ff2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Sep 2024 16:49:14 +0300 Subject: [PATCH 29/47] llama : remove llama_constraint ggml-ci --- common/common.cpp | 24 +- common/sampling.cpp | 402 ++++++---- common/sampling.h | 54 +- examples/batched.swift/Sources/main.swift | 16 +- examples/batched/batched.cpp | 18 +- examples/gritlm/gritlm.cpp | 12 +- .../llama/src/main/cpp/llama-android.cpp | 8 +- .../llama.cpp.swift/LibLlama.swift | 12 +- examples/passkey/passkey.cpp | 13 +- examples/save-load-state/save-load-state.cpp | 42 +- examples/server/server.cpp | 24 +- examples/simple/simple.cpp | 13 +- examples/speculative/speculative.cpp | 6 +- include/llama.h | 162 ++-- src/llama-impl.h | 14 + src/llama-sampling.cpp | 739 +++++++++--------- src/llama-sampling.h | 114 ++- src/llama.cpp | 230 ++---- tests/test-sampling.cpp | 48 +- 19 files changed, 958 insertions(+), 993 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index f7095c7f3c1de..2a51649a5c49f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -841,15 +841,15 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.defrag_thold = std::stof(argv[i]); return true; } - if (arg == "--samplers" || arg == "--constraints") { + if (arg == "--samplers") { CHECK_ARG - const auto constraint_names = string_split(argv[i], ';'); - sparams.constraints = gpt_constraint_types_from_names(constraint_names, true); + const auto sampler_names = string_split(argv[i], ';'); + sparams.samplers = gpt_sampler_types_from_names(sampler_names, true); return true; } if (arg == "--sampling-seq") { CHECK_ARG - sparams.constraints = gpt_constraint_types_from_chars(argv[i]); + sparams.samplers = gpt_sampler_types_from_chars(argv[i]); return true; } if (arg == "--top-p") { @@ -1706,13 +1706,13 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { const auto & sparams = params.sparams; - std::string constraint_type_chars; - std::string constraint_type_names; - for (const auto & constraint : sparams.constraints) { - constraint_type_chars += gpt_constraint_type_to_chr(constraint); - constraint_type_names += gpt_constraint_type_to_str(constraint) + ";"; + std::string sampler_type_chars; + std::string sampler_type_names; + for (const auto & sampler : sparams.samplers) { + sampler_type_chars += gpt_sampler_type_to_chr(sampler); + sampler_type_names += gpt_sampler_type_to_str(sampler) + ";"; } - constraint_type_names.pop_back(); + sampler_type_names.pop_back(); struct option_info { LLAMA_COMMON_ATTRIBUTE_FORMAT(4, 5) @@ -1826,9 +1826,9 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param options.push_back({ "sampling" }); options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed }); options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for 
generation in the order, separated by \';\'\n" - "(default: %s)", constraint_type_names.c_str() }); + "(default: %s)", sampler_type_names.c_str() }); options.push_back({ "*", " --sampling-seq SEQUENCE", - "simplified sequence for samplers that will be used (default: %s)", constraint_type_chars.c_str() }); + "simplified sequence for samplers that will be used (default: %s)", sampler_type_chars.c_str() }); options.push_back({ "*", " --ignore-eos", "ignore end of stream token and continue generating (implies --logit-bias EOS-inf)" }); options.push_back({ "*", " --penalize-nl", "penalize newline tokens (default: %s)", sparams.penalize_nl ? "true" : "false" }); options.push_back({ "*", " --temp T", "temperature (default: %.1f)", (double)sparams.temp }); diff --git a/common/sampling.cpp b/common/sampling.cpp index 0047ead3424b0..de7c9b1b97395 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -2,14 +2,127 @@ #include "common.h" +// the ring buffer works similarly to std::deque, but with a fixed capacity +// TODO: deduplicate with llama-impl.h +template +struct ring_buffer { + ring_buffer(size_t cap) : capacity(cap), data(cap) {} + + T & front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + const T & front() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[first]; + } + + T & back() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + const T & back() const { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + return data[pos]; + } + + void push_back(const T & value) { + if (sz == capacity) { + // advance the start when buffer is full + first = (first + 1) % capacity; + } else { + sz++; + } + data[pos] = value; + pos = (pos + 1) % capacity; + } + + T pop_front() { + if (sz == 0) { + throw std::runtime_error("ring buffer is empty"); + } + T value = data[first]; + first = (first + 1) % capacity; + sz--; + return value; + } + + const T & rat(size_t i) const { + if (i >= sz) { + throw std::runtime_error("ring buffer: index out of bounds"); + } + return data[(first + sz - i - 1) % capacity]; + } + + std::vector to_vector() const { + std::vector result; + result.reserve(sz); + for (size_t i = 0; i < sz; i++) { + result.push_back(data[(first + i) % capacity]); + } + return result; + } + + void clear() { + // here only reset the status of the buffer + sz = 0; + first = 0; + pos = 0; + } + + bool empty() const { + return sz == 0; + } + + size_t size() const { + return sz; + } + + size_t capacity = 0; + size_t sz = 0; + size_t first = 0; + size_t pos = 0; + std::vector data; +}; + struct gpt_sampler { gpt_sampler_params params; - struct llama_constraint * bias; - struct llama_constraint * pnlt; - struct llama_constraint * grmr; + struct llama_sampler * bias; + struct llama_sampler * pnlt; + struct llama_sampler * grmr; + + struct llama_sampler * chain; + + ring_buffer prev; + + std::vector cur; + + llama_token_data_array cur_p; + + void set_logits(struct llama_context * ctx, int idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + cur.resize(n_vocab); - struct llama_sampler * smpl; + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false }; + } }; std::string gpt_sampler_params::print() const { @@ -29,28 
+142,26 @@ std::string gpt_sampler_params::print() const { std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { std::string result = "\tlogits"; - for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) { - const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i); - result += std::string(" -> ") + llama_constraint_name(cnstr) + " "; + for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { + const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); + result += std::string(" -> ") + llama_sampler_name(smpl) + " "; } return result; } struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { - llama_sampler_params lparams = llama_sampler_default_params(); + llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); - lparams.seed = params.seed; - lparams.n_prev = params.n_prev; - lparams.type = params.temp <= 0.0f ? LLAMA_SAMPLER_TYPE_GREEDY : LLAMA_SAMPLER_TYPE_DIST; + lparams.no_timing = false; auto * result = new gpt_sampler { /* .params = */ params, - /* .bias = */ llama_constraint_init_logit_bias( + /* .bias = */ llama_sampler_init_logit_bias( model, params.logit_bias.size(), params.logit_bias.data()), - /* .pnlt = */ llama_constraint_init_penalties( + /* .pnlt = */ llama_sampler_init_penalties( model, params.penalty_last_n, params.penalty_repeat, @@ -58,45 +169,53 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st params.penalty_present, params.penalize_nl, params.ignore_eos), - /* .grmr = */ llama_constraint_init_grammar(model, params.grammar.c_str(), "root"), - /* .smpl = */ llama_sampler_init(model, lparams) + /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), + /* .chain = */ llama_sampler_chain_init(lparams), + /* .prev = */ ring_buffer(params.n_prev), + /* .cur = */ {}, + /* .cur_p = */ {}, }; if (params.temp > 0.0f) { if (params.mirostat == 0) { - for (const auto & cnstr : params.constraints) { + for (const auto & cnstr : params.samplers) { switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k)); + case GPT_SAMPLER_TYPE_TOP_K: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); break; - case GPT_CONSTRAINT_TYPE_TOP_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep)); + case GPT_SAMPLER_TYPE_TOP_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); break; - case GPT_CONSTRAINT_TYPE_MIN_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep)); + case GPT_SAMPLER_TYPE_MIN_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); break; - case GPT_CONSTRAINT_TYPE_TFS_Z: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep)); + case GPT_SAMPLER_TYPE_TFS_Z: + llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep)); break; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: - llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep)); + case GPT_SAMPLER_TYPE_TYPICAL_P: + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); break; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: - llama_sampler_constraint_add(result->smpl, 
llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); + case GPT_SAMPLER_TYPE_TEMPERATURE: + llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); break; default: - GGML_ASSERT(false && "unknown constraint type"); + GGML_ASSERT(false && "unknown sampler type"); } } } else if (params.mirostat == 1) { - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); } else if (params.mirostat == 2) { - llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp)); - llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); + llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); } else { GGML_ASSERT(false && "unknown mirostat version"); } + llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); + llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); + } else { + llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); + llama_sampler_chain_add(result->chain, llama_sampler_init_greedy()); } return result; @@ -104,11 +223,11 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st void gpt_sampler_free(struct gpt_sampler * gsmpl) { if (gsmpl) { - llama_constraint_free(gsmpl->bias); - llama_constraint_free(gsmpl->pnlt); - llama_constraint_free(gsmpl->grmr); + llama_sampler_free(gsmpl->bias); + llama_sampler_free(gsmpl->pnlt); + llama_sampler_free(gsmpl->grmr); - llama_sampler_free(gsmpl->smpl); + llama_sampler_free(gsmpl->chain); delete gsmpl; } @@ -117,69 +236,66 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { return new gpt_sampler { /* .params = */ gsmpl->params, - /* .bias = */ llama_constraint_clone(gsmpl->bias), - /* .pnlt = */ llama_constraint_clone(gsmpl->pnlt), - /* .grmr = */ llama_constraint_clone(gsmpl->grmr), - /* .smpl = */ llama_sampler_clone (gsmpl->smpl) + /* .bias = */ llama_sampler_clone(gsmpl->bias), + /* .pnlt = */ llama_sampler_clone(gsmpl->pnlt), + /* .grmr = */ llama_sampler_clone(gsmpl->grmr), + /* .chain = */ llama_sampler_clone(gsmpl->chain), + /* .prev = */ gsmpl->prev, + /* .cur = */ gsmpl->cur, + /* .cur_p = */ gsmpl->cur_p, }; } void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar) { if (apply_grammar) { - llama_constraint_accept(gsmpl->grmr, token); + llama_sampler_accept(gsmpl->grmr, token); } - llama_sampler_accept(gsmpl->smpl, token); + llama_sampler_accept(gsmpl->chain, token); + + gsmpl->prev.push_back(token); } void gpt_sampler_reset(struct gpt_sampler * gsmpl) { - llama_constraint_reset(gsmpl->grmr); + llama_sampler_reset(gsmpl->grmr); - llama_sampler_reset(gsmpl->smpl); -} - -void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits) { - llama_sampler_set_logits(gsmpl->smpl, logits); + llama_sampler_reset(gsmpl->chain); } llama_token_data_array * 
gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { - return llama_sampler_get_candidates(gsmpl->smpl); + return &gsmpl->cur_p; } llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { - return llama_sampler_last(gsmpl->smpl); + return gsmpl->prev.rat(0); } -void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl) { - llama_print_timings(ctx, gsmpl ? gsmpl->smpl : nullptr); -} - -llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p) { - return llama_sampler_sample(gsmpl->smpl, cur_p); +void gpt_print_timings(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) { + llama_print_timings(ctx, gsmpl ? gsmpl->chain : nullptr); } llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { - auto & bias = gsmpl->bias; - auto & pnlt = gsmpl->pnlt; - auto & grmr = gsmpl->grmr; - auto & smpl = gsmpl->smpl; - - const auto * logits = llama_get_logits_ith(ctx, idx); + auto & bias = gsmpl->bias; + auto & pnlt = gsmpl->pnlt; + auto & grmr = gsmpl->grmr; + auto & chain = gsmpl->chain; - llama_sampler_set_logits(smpl, logits); + gsmpl->set_logits(ctx, idx); - auto * cur_p = llama_sampler_get_candidates(smpl); + auto & cur_p = gsmpl->cur_p; - llama_constraint_apply(bias, cur_p); - llama_constraint_apply(pnlt, cur_p); + llama_sampler_apply(bias, &cur_p); + llama_sampler_apply(pnlt, &cur_p); if (grammar_first) { - llama_constraint_apply(grmr, cur_p); + llama_sampler_apply(grmr, &cur_p); } - llama_sampler_apply(smpl, cur_p); + llama_sampler_apply(chain, &cur_p); + + const llama_token id = cur_p.data[cur_p.selected].id; - const llama_token id = llama_sampler_sample(smpl, cur_p); + GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration"); if (grammar_first) { return id; @@ -188,9 +304,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context // check if it the sampled token fits the grammar { llama_token_data single_token_data = { id, 1.0f, 0.0f }; - llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, LLAMA_TOKEN_NULL, false }; - llama_constraint_apply(grmr, &single_token_data_array); + llama_sampler_apply(grmr, &single_token_data_array); // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; @@ -199,28 +315,22 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context } } - // if the token is not valid, sample again, first apply the grammar constraints and then sample - llama_sampler_set_logits(smpl, logits); + // if the token is not valid, sample again, first apply the grammar samplers and then sample + gsmpl->set_logits(ctx, idx); - llama_constraint_apply(bias, cur_p); - llama_constraint_apply(pnlt, cur_p); - llama_constraint_apply(grmr, cur_p); + llama_sampler_apply(bias, &cur_p); + llama_sampler_apply(pnlt, &cur_p); + llama_sampler_apply(grmr, &cur_p); - llama_sampler_apply(smpl, cur_p); + llama_sampler_apply(chain, &cur_p); - return llama_sampler_sample(smpl, cur_p); -} - -void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p) { - GGML_ASSERT(cur_p != nullptr); + GGML_ASSERT(cur_p.data[cur_p.selected].id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling 
configuration"); - llama_constraint_apply(gsmpl->grmr, cur_p); + return cur_p.data[cur_p.selected].id; } std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { - auto & smpl = gsmpl->smpl; - - n = std::min(n, llama_sampler_n_prev(smpl)); + n = std::min(n, (int) gsmpl->prev.size()); if (n <= 0) { return ""; @@ -230,7 +340,7 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab for (int i = n - 1; i >= 0; i--) { - const llama_token id = llama_sampler_prev(smpl, i); + const llama_token id = gsmpl->prev.rat(i); GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen"); @@ -240,95 +350,95 @@ std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, return result; } -char gpt_constraint_type_to_chr(enum gpt_constraint_type cnstr) { +char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) { switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: return 'k'; - case GPT_CONSTRAINT_TYPE_TFS_Z: return 'f'; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: return 'y'; - case GPT_CONSTRAINT_TYPE_TOP_P: return 'p'; - case GPT_CONSTRAINT_TYPE_MIN_P: return 'm'; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: return 't'; + case GPT_SAMPLER_TYPE_TOP_K: return 'k'; + case GPT_SAMPLER_TYPE_TFS_Z: return 'f'; + case GPT_SAMPLER_TYPE_TYPICAL_P: return 'y'; + case GPT_SAMPLER_TYPE_TOP_P: return 'p'; + case GPT_SAMPLER_TYPE_MIN_P: return 'm'; + case GPT_SAMPLER_TYPE_TEMPERATURE: return 't'; default : return '?'; } } -std::string gpt_constraint_type_to_str(enum gpt_constraint_type cnstr) { +std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) { switch (cnstr) { - case GPT_CONSTRAINT_TYPE_TOP_K: return "top_k"; - case GPT_CONSTRAINT_TYPE_TFS_Z: return "tfs_z"; - case GPT_CONSTRAINT_TYPE_TYPICAL_P: return "typ_p"; - case GPT_CONSTRAINT_TYPE_TOP_P: return "top_p"; - case GPT_CONSTRAINT_TYPE_MIN_P: return "min_p"; - case GPT_CONSTRAINT_TYPE_TEMPERATURE: return "temperature"; + case GPT_SAMPLER_TYPE_TOP_K: return "top_k"; + case GPT_SAMPLER_TYPE_TFS_Z: return "tfs_z"; + case GPT_SAMPLER_TYPE_TYPICAL_P: return "typ_p"; + case GPT_SAMPLER_TYPE_TOP_P: return "top_p"; + case GPT_SAMPLER_TYPE_MIN_P: return "min_p"; + case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature"; default : return ""; } } -std::vector gpt_constraint_types_from_names(const std::vector & names, bool allow_alt_names) { - std::unordered_map constraint_canonical_name_map { - { "top_k", GPT_CONSTRAINT_TYPE_TOP_K }, - { "top_p", GPT_CONSTRAINT_TYPE_TOP_P }, - { "typ_p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "min_p", GPT_CONSTRAINT_TYPE_MIN_P }, - { "tfs_z", GPT_CONSTRAINT_TYPE_TFS_Z }, - { "temperature", GPT_CONSTRAINT_TYPE_TEMPERATURE }, +std::vector gpt_sampler_types_from_names(const std::vector & names, bool allow_alt_names) { + std::unordered_map sampler_canonical_name_map { + { "top_k", GPT_SAMPLER_TYPE_TOP_K }, + { "top_p", GPT_SAMPLER_TYPE_TOP_P }, + { "typ_p", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "min_p", GPT_SAMPLER_TYPE_MIN_P }, + { "tfs_z", GPT_SAMPLER_TYPE_TFS_Z }, + { "temperature", GPT_SAMPLER_TYPE_TEMPERATURE }, }; - // since constraints names are written multiple ways + // since samplers names are written multiple ways // make it ready for both system names and input names - std::unordered_map constraint_alt_name_map { - { "top-k", GPT_CONSTRAINT_TYPE_TOP_K }, - { "top-p", GPT_CONSTRAINT_TYPE_TOP_P }, - { "nucleus", 
GPT_CONSTRAINT_TYPE_TOP_P }, - { "typical-p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "typical", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "typ-p", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "typ", GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { "min-p", GPT_CONSTRAINT_TYPE_MIN_P }, - { "tfs-z", GPT_CONSTRAINT_TYPE_TFS_Z }, - { "tfs", GPT_CONSTRAINT_TYPE_TFS_Z }, - { "temp", GPT_CONSTRAINT_TYPE_TEMPERATURE }, + std::unordered_map sampler_alt_name_map { + { "top-k", GPT_SAMPLER_TYPE_TOP_K }, + { "top-p", GPT_SAMPLER_TYPE_TOP_P }, + { "nucleus", GPT_SAMPLER_TYPE_TOP_P }, + { "typical-p", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "typical", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "typ-p", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "typ", GPT_SAMPLER_TYPE_TYPICAL_P }, + { "min-p", GPT_SAMPLER_TYPE_MIN_P }, + { "tfs-z", GPT_SAMPLER_TYPE_TFS_Z }, + { "tfs", GPT_SAMPLER_TYPE_TFS_Z }, + { "temp", GPT_SAMPLER_TYPE_TEMPERATURE }, }; - std::vector constraints; - constraints.reserve(names.size()); + std::vector samplers; + samplers.reserve(names.size()); for (const auto & name : names) { - auto constraint = constraint_canonical_name_map.find(name); - if (constraint != constraint_canonical_name_map.end()) { - constraints.push_back(constraint->second); + auto sampler = sampler_canonical_name_map.find(name); + if (sampler != sampler_canonical_name_map.end()) { + samplers.push_back(sampler->second); } else { if (allow_alt_names) { - constraint = constraint_alt_name_map.find(name); - if (constraint != constraint_alt_name_map.end()) { - constraints.push_back(constraint->second); + sampler = sampler_alt_name_map.find(name); + if (sampler != sampler_alt_name_map.end()) { + samplers.push_back(sampler->second); } } } } - return constraints; + return samplers; } -std::vector gpt_constraint_types_from_chars(const std::string & chars) { - std::unordered_map constraint_name_map { - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TOP_K), GPT_CONSTRAINT_TYPE_TOP_K }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TFS_Z), GPT_CONSTRAINT_TYPE_TFS_Z }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TYPICAL_P), GPT_CONSTRAINT_TYPE_TYPICAL_P }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TOP_P), GPT_CONSTRAINT_TYPE_TOP_P }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_MIN_P), GPT_CONSTRAINT_TYPE_MIN_P }, - { gpt_constraint_type_to_chr(GPT_CONSTRAINT_TYPE_TEMPERATURE), GPT_CONSTRAINT_TYPE_TEMPERATURE } +std::vector gpt_sampler_types_from_chars(const std::string & chars) { + std::unordered_map sampler_name_map { + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_K), GPT_SAMPLER_TYPE_TOP_K }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TFS_Z), GPT_SAMPLER_TYPE_TFS_Z }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P), GPT_SAMPLER_TYPE_TYPICAL_P }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P), GPT_SAMPLER_TYPE_TOP_P }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P), GPT_SAMPLER_TYPE_MIN_P }, + { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE } }; - std::vector constraints; - constraints.reserve(chars.size()); + std::vector samplers; + samplers.reserve(chars.size()); for (const auto & c : chars) { - const auto constraint = constraint_name_map.find(c); - if (constraint != constraint_name_map.end()) { - constraints.push_back(constraint->second); + const auto sampler = sampler_name_map.find(c); + if (sampler != sampler_name_map.end()) { + samplers.push_back(sampler->second); } } - return constraints; + return samplers; } diff --git a/common/sampling.h b/common/sampling.h index 
c260ef0553d8e..5083f456f1f96 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -5,14 +5,14 @@ #include #include -enum gpt_constraint_type { - GPT_CONSTRAINT_TYPE_NONE = 0, - GPT_CONSTRAINT_TYPE_TOP_K = 1, - GPT_CONSTRAINT_TYPE_TOP_P = 2, - GPT_CONSTRAINT_TYPE_MIN_P = 3, - GPT_CONSTRAINT_TYPE_TFS_Z = 4, - GPT_CONSTRAINT_TYPE_TYPICAL_P = 5, - GPT_CONSTRAINT_TYPE_TEMPERATURE = 6, +enum gpt_sampler_type { + GPT_SAMPLER_TYPE_NONE = 0, + GPT_SAMPLER_TYPE_TOP_K = 1, + GPT_SAMPLER_TYPE_TOP_P = 2, + GPT_SAMPLER_TYPE_MIN_P = 3, + GPT_SAMPLER_TYPE_TFS_Z = 4, + GPT_SAMPLER_TYPE_TYPICAL_P = 5, + GPT_SAMPLER_TYPE_TEMPERATURE = 6, }; // sampling parameters @@ -21,7 +21,7 @@ struct gpt_sampler_params { int32_t n_prev = 64; // number of previous tokens to remember int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - int32_t min_keep = 0; // 0 = disabled, otherwise constraints should return at least min_keep tokens + int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float min_p = 0.05f; // 0.0 = disabled @@ -40,13 +40,13 @@ struct gpt_sampler_params { bool penalize_nl = false; // consider newlines as a repeatable token bool ignore_eos = false; - std::vector constraints = { - GPT_CONSTRAINT_TYPE_TOP_K, - GPT_CONSTRAINT_TYPE_TFS_Z, - GPT_CONSTRAINT_TYPE_TYPICAL_P, - GPT_CONSTRAINT_TYPE_TOP_P, - GPT_CONSTRAINT_TYPE_MIN_P, - GPT_CONSTRAINT_TYPE_TEMPERATURE + std::vector samplers = { + GPT_SAMPLER_TYPE_TOP_K, + GPT_SAMPLER_TYPE_TFS_Z, + GPT_SAMPLER_TYPE_TYPICAL_P, + GPT_SAMPLER_TYPE_TOP_P, + GPT_SAMPLER_TYPE_MIN_P, + GPT_SAMPLER_TYPE_TEMPERATURE }; std::string grammar; // optional BNF-like grammar to constrain sampling @@ -73,40 +73,36 @@ struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl); void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); void gpt_sampler_reset (struct gpt_sampler * gsmpl); -void gpt_sampler_apply_grammar(struct gpt_sampler * gsmpl, llama_token_data_array * cur_p); - -void gpt_sampler_set_logits(struct gpt_sampler * gsmpl, const float * logits); - llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); -llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p); +//llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p); llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); -void gpt_print_timings(struct llama_context * ctx, struct gpt_sampler * gsmpl); +void gpt_print_timings(const struct llama_context * ctx, const struct gpt_sampler * gsmpl); // extended sampling implementation: // // - set logits -// - apply the configured sampling constraints +// - apply the configured sampler chain // - check if the token fits the grammar (if any) // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // -// if grammar_first is true, the grammar is applied before the constraints (slower) +// if grammar_first is true, the grammar is applied before the samplers (slower) // useful in cases where all the resulting candidates must fit the grammar // llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); // helpers -// print the constraints into a string +// print the sampler chain into a string std::string gpt_sampler_print(const struct gpt_sampler * 
gsmpl); // get a string representation of the last accepted tokens std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n); -char gpt_constraint_type_to_chr(enum gpt_constraint_type cnstr); -std::string gpt_constraint_type_to_str(enum gpt_constraint_type cnstr); +char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr); +std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr); -std::vector gpt_constraint_types_from_names(const std::vector & names, bool allow_alt_names); -std::vector gpt_constraint_types_from_chars(const std::string & chars); +std::vector gpt_sampler_types_from_names(const std::vector & names, bool allow_alt_names); +std::vector gpt_sampler_types_from_chars(const std::string & chars); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index a02fa4da9183d..24f8a7027bc10 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -50,9 +50,9 @@ defer { llama_free(context) } -var sparams = llama_sampler_params() +var sparams = llama_sampler_chain_default_params() -let smpl = llama_sampler_init(model, sparams) +let smpl = llama_sampler_chain_init(sparams) guard smpl != nil else { print("Failed to initialize sampling") exit(1) @@ -61,9 +61,9 @@ defer { llama_sampler_free(smpl) } -llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40)); -llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1)); -llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4)); +llama_sampler_sampler_add(smpl, llama_sampler_init_top_k(40)); +llama_sampler_sampler_add(smpl, llama_sampler_init_top_p(0.9, 1)); +llama_sampler_sampler_add(smpl, llama_sampler_init_temp (0.4)); let n_ctx = llama_n_ctx(context) @@ -137,11 +137,9 @@ while n_cur <= n_len { continue } - var logits = llama_get_logits_ith(context, i_batch[i]) + let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) - llama_sampler_set_logits(smpl, logits) - - let new_token_id = llama_sampler_sample(smpl, nil) + llama_sampler_accept(smpl, new_token_id) // is it an end of stream? 
-> mark the stream as finished if llama_token_is_eog(model, new_token_id) || n_cur == n_len { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 5896526abdc77..b6e98fcc36335 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -64,15 +64,13 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_new_context_with_model(model, ctx_params); - auto sparams = llama_sampler_default_params(); + auto sparams = llama_sampler_chain_default_params(); - sparams.seed = params.sparams.seed; + llama_sampler * smpl = llama_sampler_chain_init(sparams); - llama_sampler * smpl = llama_sampler_init(model, sparams); - - llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k)); - llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep)); - llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k)); + llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep)); + llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp)); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -173,11 +171,9 @@ int main(int argc, char ** argv) { continue; } - const auto * logits = llama_get_logits_ith(ctx, i_batch[i]); - - llama_sampler_set_logits(smpl, logits); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); - const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); + llama_sampler_accept(smpl, new_token_id); // is it an end of generation? -> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 07475ecd30ed1..b402abbb80256 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -120,11 +120,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_decode(ctx, bat); - const auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); + llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); + llama_sampler_accept(smpl, token); - llama_sampler_set_logits(smpl, logits); - - llama_token token = llama_sampler_sample(smpl, nullptr); if (token == eos_token) { break; } @@ -171,11 +169,9 @@ int main(int argc, char * argv[]) { // create generation context llama_context * ctx = llama_new_context_with_model(model, cparams); - auto sparams = llama_sampler_default_params(); - - sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; + auto sparams = llama_sampler_chain_default_params(); - llama_sampler * smpl = llama_sampler_init(model, sparams); + llama_sampler * smpl = llama_sampler_chain_init(sparams); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 1a4908501c35c..9217937512d75 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -394,12 +394,10 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V"); - const auto * 
logits = llama_get_logits_ith(context, batch->n_tokens - 1); - - llama_sampler_set_logits(sampling, logits); - // sample the most likely token - const auto new_token_id = llama_sampler_sample(sampling, nullptr); + const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1); + + llama_sampler_accept(sampling, new_token_id); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index bd6513d3457ab..73cabc6c7444c 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -43,9 +43,8 @@ actor LlamaContext { self.tokens_list = [] self.batch = llama_batch_init(512, 0, 1) self.temporary_invalid_cchars = [] - var sparams = llama_sampler_default_params() - sparams.type = LLAMA_SAMPLER_TYPE_GREEDY - self.sampling = llama_sampler_init(context, sparams) + var sparams = llama_sampler_chain_default_params() + self.sampling = llama_sampler_chain_init(sparams) } deinit { @@ -148,12 +147,9 @@ actor LlamaContext { func completion_loop() -> String { var new_token_id: llama_token = 0 - let n_vocab = llama_n_vocab(model) - let logits = llama_get_logits_ith(context, batch.n_tokens - 1) + new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) - llama_sampler_set_logits(sampling, logits); - - new_token_id = llama_sampler_sample(sampling, nil) + llama_sampler_accept(sampling, new_token_id) if llama_token_is_eog(model, new_token_id) || n_cur == n_len { print("\n") diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index b9800a9170ab7..92c71c5a1a35d 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -83,11 +83,11 @@ int main(int argc, char ** argv) { return 1; } - auto sparams = llama_sampler_default_params(); + auto sparams = llama_sampler_chain_default_params(); - sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; + llama_sampler * smpl = llama_sampler_chain_init(sparams); - llama_sampler * smpl = llama_sampler_init(model, sparams); + llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); // tokenize the prompt std::vector tokens_list; @@ -220,12 +220,9 @@ int main(int argc, char ** argv) { while (n_cur <= n_len) { // sample the next token { - const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); - llama_sampler_set_logits(smpl, logits); - - // sample the most likely token - const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); + llama_sampler_accept(smpl, new_token_id); // is it an end of generation? 
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 6f8c84137f1be..133a010e4757a 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -38,10 +38,12 @@ int main(int argc, char ** argv) { return 1; } - llama_sampler_params sparams = llama_sampler_default_params(); - sparams.seed = params.sparams.seed; + auto sparams = llama_sampler_chain_default_params(); - llama_sampler * smpl = llama_sampler_init(model, sparams); + llama_sampler * smpl = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl, llama_sampler_init_softmax()); + llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed)); // tokenize prompt auto tokens = llama_tokenize(ctx, params.prompt, true); @@ -69,13 +71,11 @@ int main(int argc, char ** argv) { printf("\nfirst run: %s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { - const auto * logits = llama_get_logits(ctx); - - llama_sampler_set_logits(smpl, logits); - - auto next_token = llama_sampler_sample(smpl, nullptr); + auto next_token = llama_sampler_sample(smpl, ctx, -1); auto next_token_str = llama_token_to_piece(ctx, next_token); + llama_sampler_accept(smpl, next_token); + printf("%s", next_token_str.c_str()); result0 += next_token_str; @@ -96,7 +96,10 @@ int main(int argc, char ** argv) { // make new context auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); - llama_sampler * smpl2 = llama_sampler_init(model, sparams); + llama_sampler * smpl2 = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl2, llama_sampler_init_softmax()); + llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed)); printf("\nsecond run: %s", params.prompt.c_str()); @@ -126,13 +129,11 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { - const auto * logits = llama_get_logits(ctx2); - - llama_sampler_set_logits(smpl2, logits); - - auto next_token = llama_sampler_sample(smpl2, nullptr); + auto next_token = llama_sampler_sample(smpl2, ctx2, -1); auto next_token_str = llama_token_to_piece(ctx2, next_token); + llama_sampler_accept(smpl2, next_token); + printf("%s", next_token_str.c_str()); result1 += next_token_str; @@ -157,7 +158,10 @@ int main(int argc, char ** argv) { // make new context auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); - llama_sampler * smpl3 = llama_sampler_init(model, sparams); + llama_sampler * smpl3 = llama_sampler_chain_init(sparams); + + llama_sampler_chain_add(smpl3, llama_sampler_init_softmax()); + llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed)); printf("\nsingle seq run: %s", params.prompt.c_str()); @@ -215,13 +219,11 @@ int main(int argc, char ** argv) { // third run with seq 1 instead of 0 for (auto i = 0; i < params.n_predict; i++) { - const auto * logits = llama_get_logits(ctx3); - - llama_sampler_set_logits(smpl3, logits); - - auto next_token = llama_sampler_sample(smpl3, nullptr); + auto next_token = llama_sampler_sample(smpl3, ctx3, -1); auto next_token_str = llama_token_to_piece(ctx3, next_token); + llama_sampler_accept(smpl3, next_token); + printf("%s", next_token_str.c_str()); result2 += next_token_str; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 03e512e0343e4..1095f43b206bb 100644 --- 
a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1027,17 +1027,17 @@ struct server_context { } { - const auto & constraints = data.find("samplers"); - if (constraints != data.end() && constraints->is_array()) { - std::vector constraint_names; - for (const auto & name : *constraints) { + const auto & samplers = data.find("samplers"); + if (samplers != data.end() && samplers->is_array()) { + std::vector sampler_names; + for (const auto & name : *samplers) { if (name.is_string()) { - constraint_names.emplace_back(name); + sampler_names.emplace_back(name); } } - slot.sparams.constraints = gpt_constraint_types_from_names(constraint_names, false); + slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false); } else { - slot.sparams.constraints = default_sparams.constraints; + slot.sparams.samplers = default_sparams.samplers; } } @@ -1253,10 +1253,10 @@ struct server_context { } json get_formated_generation(const server_slot & slot) const { - std::vector constraints; - constraints.reserve(slot.sparams.constraints.size()); - for (const auto & constraint : slot.sparams.constraints) { - constraints.emplace_back(gpt_constraint_type_to_str(constraint)); + std::vector samplers; + samplers.reserve(slot.sparams.samplers.size()); + for (const auto & sampler : slot.sparams.samplers) { + samplers.emplace_back(gpt_sampler_type_to_str(sampler)); } return json { @@ -1290,7 +1290,7 @@ struct server_context { {"n_probs", slot.sparams.n_probs}, {"min_keep", slot.sparams.min_keep}, {"grammar", slot.sparams.grammar}, - {"samplers", constraints}, + {"samplers", samplers}, }; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 7193f1ee4a03a..e5dfeb2f4b4f8 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,11 +55,9 @@ int main(int argc, char ** argv) { return 1; } - auto sparams = llama_sampler_default_params(); + auto sparams = llama_sampler_chain_default_params(); - sparams.type = LLAMA_SAMPLER_TYPE_GREEDY; - - llama_sampler * smpl = llama_sampler_init(model, sparams); + llama_sampler * smpl = llama_sampler_chain_init(sparams); // tokenize the prompt @@ -116,12 +114,9 @@ int main(int argc, char ** argv) { while (n_cur <= n_predict) { // sample the next token { - const auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - - llama_sampler_set_logits(smpl, logits); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); - // sample the most likely token - const llama_token new_token_id = llama_sampler_sample(smpl, nullptr); + llama_sampler_accept(smpl, new_token_id); // is it an end of generation? 
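// A minimal sketch of composing several samplers into a single chain with the API
// introduced by this patch; ctx is assumed to be an existing llama_context, and the
// cutoff/temperature values and the seed below are arbitrary illustration values:

auto chain_params = llama_sampler_chain_default_params();

llama_sampler * chain = llama_sampler_chain_init(chain_params);

llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.95f, 1));
llama_sampler_chain_add(chain, llama_sampler_init_temp (0.80f));
llama_sampler_chain_add(chain, llama_sampler_init_dist (1234));

// per decoded batch: sample from the last position, then report the token back to the chain
const llama_token id = llama_sampler_sample(chain, ctx, -1);
llama_sampler_accept(chain, id);

// freeing the chain also frees the samplers that were added to it
llama_sampler_free(chain);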
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 9f596ec914b54..037d5d34bb54d 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -179,7 +179,7 @@ int main(int argc, char ** argv) { // target model sampling context (reuse the llama_context's sampling instance) struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams); - struct llama_constraint * softmax = llama_constraint_init_softmax(); + struct llama_sampler * softmax = llama_sampler_init_softmax(); // draft sequence data std::vector drafts(n_seq_dft); @@ -255,7 +255,7 @@ int main(int argc, char ** argv) { LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size()); float r = u_dist(rng); - llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true }; + llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true }; //GGML_ASSERT(dist_tgt.size <= dist_dft.size); @@ -625,7 +625,7 @@ int main(int argc, char ** argv) { gpt_sampler_free(drafts[s].smpl); } - llama_constraint_free(softmax); + llama_sampler_free(softmax); llama_batch_free(batch_dft); llama_free(ctx_tgt); diff --git a/include/llama.h b/include/llama.h index dd047e0aceb9c..f735413894912 100644 --- a/include/llama.h +++ b/include/llama.h @@ -216,6 +216,7 @@ extern "C" { // TODO: consider SoA llama_token_data * data; size_t size; + int64_t selected; bool sorted; } llama_token_data_array; @@ -369,21 +370,9 @@ extern "C" { float bias; } llama_logit_bias; - enum llama_sampler_type { - LLAMA_SAMPLER_TYPE_GREEDY = 0, - LLAMA_SAMPLER_TYPE_DIST = 1, - }; - - typedef struct llama_sampler_params { - uint32_t seed; // the seed used to initialize the rng of the sampler - - int32_t n_prev; // size of ring buffer to keep previous accepted tokens (needed for llama_sampler_prev_ API) - - // TODO: will be used by the llama_decode_with_sampler() API in the future - enum llama_sampler_type type; - + typedef struct llama_sampler_chain_params { bool no_timing; // whether to measure performance timings - } llama_sampler_params; + } llama_sampler_chain_params; // performance timing information struct llama_timings { @@ -412,7 +401,7 @@ extern "C" { // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172) LLAMA_API struct llama_model_params llama_model_default_params(void); LLAMA_API struct llama_context_params llama_context_default_params(void); - LLAMA_API struct llama_sampler_params llama_sampler_default_params(void); + LLAMA_API struct llama_sampler_chain_params llama_sampler_chain_default_params(void); LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); // Initialize the llama + ggml backend @@ -1003,70 +992,73 @@ extern "C" { // // Sampling API // - // - Constraints - // The llama_constraint object works on a set of candidate tokens (llama_token_data_array), by modifying their - // logits and probabilities inplace. The interface is abstracted so that users can implement custom constraints. - // - // - Samplers - // The llama_sampler samples a token based on the candidate token probabilities. Before the actual sampling, the - // sampler can apply a sequence of constraints in order to modify the probabilities of the candidates. 
- // - // The llama_sampler object contains the entire sampling information: - // - // - RNG state (seed and generator) - // - Custom set of constraints (see llama_sampler_constraint_add) - // - Sampling method (greedy, dist) - // - Previous tokens - // // In the future, it will be utilized offload the sampling to the backends (e.g. GPU). // // TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model - // constraints - - struct llama_constraint; + typedef void * llama_sampler_context_t; - typedef void * llama_constraint_context_t; - - // user code can implement the interface below in order to create custom llama_constraint - struct llama_constraint_i { - const char * (*name) (const struct llama_constraint * cnstr); // can be NULL - void (*accept)( struct llama_constraint * cnstr, llama_token token); // can be NULL - void (*apply) ( struct llama_constraint * cnstr, llama_token_data_array * cur_p); // required - void (*reset) ( struct llama_constraint * cnstr); // can be NULL - struct llama_constraint * (*clone) (const struct llama_constraint * cnstr); // can be NULL if ctx is NULL - void (*free) ( struct llama_constraint * cnstr); // can be NULL if ctx is NULL + // user code can implement the interface below in order to create custom llama_sampler + struct llama_sampler_i { + const char * (*name) (const struct llama_sampler * smpl); // can be NULL + void (*accept)( struct llama_sampler * smpl, llama_token token); // can be NULL + void (*apply) ( struct llama_sampler * smpl, llama_token_data_array * cur_p); // required + void (*reset) ( struct llama_sampler * smpl); // can be NULL + struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL + void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph - //void (*apply_ggml) (struct llama_constraint * cnstr, ...); + //void (*apply_ggml) (struct llama_sampler * smpl, ...); }; - struct llama_constraint { - struct llama_constraint_i * iface; - llama_constraint_context_t ctx; + struct llama_sampler { + struct llama_sampler_i * iface; + llama_sampler_context_t ctx; }; + LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); + LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); + LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); + LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); + // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) + LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); + + // llama_sampler_chain is a type of llama_sampler that can contain multiple llama_samplers + + LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); + + // important: takes ownership of the sampler object and will free it when llama_sampler_free is called + LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); + LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); + LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); + + // available samplers: + + LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void); + LLAMA_API 
struct llama_sampler * llama_sampler_init_dist (uint32_t seed); + /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API struct llama_constraint * llama_constraint_init_softmax (void); + LLAMA_API struct llama_sampler * llama_sampler_init_softmax (void); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API struct llama_constraint * llama_constraint_init_top_k (int32_t k); + LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API struct llama_constraint * llama_constraint_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, int32_t min_keep); /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API struct llama_constraint * llama_constraint_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, int32_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API struct llama_constraint * llama_constraint_init_tail_free (float z, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, int32_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API struct llama_constraint * llama_constraint_init_typical (float p, int32_t min_keep); - LLAMA_API struct llama_constraint * llama_constraint_init_temp (float t); + LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. - LLAMA_API struct llama_constraint * llama_constraint_init_temp_ext (float t, float delta, float exponent); + LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext (float t, float delta, float exponent); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. @@ -1074,7 +1066,7 @@ extern "C" { /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
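// The llama_sampler_i interface above can also be implemented by user code. A minimal
// sketch of a custom sampler follows (a hypothetical "ban one token" sampler, not one
// provided by the library; assumes the usual llama.h and <cmath> includes):

struct ban_token_ctx {
    llama_token token; // the single token this sampler suppresses
};

static struct llama_sampler * ban_token_init(llama_token token);

static struct llama_sampler_i ban_token_i = {
    /* .name   = */ [](const struct llama_sampler * /*smpl*/) { return "ban-token"; },
    /* .accept = */ nullptr,
    /* .apply  = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) {
        const auto * ctx = (const ban_token_ctx *) smpl->ctx;
        for (size_t i = 0; i < cur_p->size; ++i) {
            if (cur_p->data[i].id == ctx->token) {
                cur_p->data[i].logit = -INFINITY;
            }
        }
    },
    /* .reset  = */ nullptr,
    /* .clone  = */ [](const struct llama_sampler * smpl) {
        const auto * ctx = (const ban_token_ctx *) smpl->ctx;
        return ban_token_init(ctx->token);
    },
    /* .free   = */ [](struct llama_sampler * smpl) {
        delete (ban_token_ctx *) smpl->ctx;
    },
};

static struct llama_sampler * ban_token_init(llama_token token) {
    return new llama_sampler {
        /* .iface = */ &ban_token_i,
        /* .ctx   = */ new ban_token_ctx { token },
    };
}

// usage: llama_sampler_chain_add(chain, ban_token_init(some_token_id));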
- LLAMA_API struct llama_constraint * llama_constraint_init_mirostat( + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat( const struct llama_model * model, float tau, float eta); @@ -1084,16 +1076,16 @@ extern "C" { /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API struct llama_constraint * llama_constraint_init_mirostat_v2( + LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2( float tau, float eta); - LLAMA_API struct llama_constraint * llama_constraint_init_grammar( + LLAMA_API struct llama_sampler * llama_sampler_init_grammar( const struct llama_model * model, const char * grammar_str, const char * grammar_root); - LLAMA_API struct llama_constraint * llama_constraint_init_penalties( + LLAMA_API struct llama_sampler * llama_sampler_init_penalties( const struct llama_model * model, int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) float penalty_repeat, // 1.0 = disabled @@ -1102,57 +1094,14 @@ extern "C" { bool penalize_nl, // consider newlines as a repeatable token bool ignore_eos); // ignore the end-of-sequence token - LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( const struct llama_model * model, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - LLAMA_API struct llama_constraint * llama_constraint_clone(const struct llama_constraint * cnstr); - - // important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add) - LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr); - - LLAMA_API const char * llama_constraint_name (const struct llama_constraint * cnstr); - LLAMA_API void llama_constraint_accept( struct llama_constraint * cnstr, llama_token token); - LLAMA_API void llama_constraint_apply ( struct llama_constraint * cnstr, llama_token_data_array * cur_p); - LLAMA_API void llama_constraint_reset ( struct llama_constraint * cnstr); - - // samplers - - LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_model * model, struct llama_sampler_params params); - LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_clone (const struct llama_sampler * smpl); - LLAMA_API void llama_sampler_reset ( struct llama_sampler * smpl); - LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); - LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); - - LLAMA_API void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits); - - LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl); - - // important: takes ownership of the constraint object and will free it in llama_sampler_free - LLAMA_API void 
llama_sampler_constraint_add( struct llama_sampler * smpl, struct llama_constraint * cnstr); - LLAMA_API int llama_sampler_n_constraints (const struct llama_sampler * smpl); - LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i); - - - LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p); - - /// @details Get the number of accepted tokens so far (max of n_prev) - LLAMA_API int llama_sampler_n_prev(const struct llama_sampler * smpl); - - /// @details Get the ith accepted token - /// @param ith [0, n_prev), ith == 0 is the last accepted token. - /// returns LLAMA_TOKEN_NULL if ith is out of bounds - LLAMA_API llama_token llama_sampler_prev(const struct llama_sampler * smpl, int32_t ith); - - /// @details Get the last accepted token - /// Same as llama_sampler_prev(smpl, 0) - /// returns LLAMA_TOKEN_NULL if there are no accepted tokens - LLAMA_API llama_token llama_sampler_last(const struct llama_sampler * smpl); + LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); // TODO: extend in the future - //LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t i); //LLAMA_API void llama_decode_with_sampler(struct llama_context * ctx, struct llama_sampler * smpl, struct llama_batch batch, ...); // @@ -1172,8 +1121,9 @@ extern "C" { // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl); - LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * smpl); + // note: requires llama_sampler_chain. how to prevent misuse? + LLAMA_API void llama_print_timings(const struct llama_context * ctx, const struct llama_sampler * chain); + LLAMA_API void llama_reset_timings( struct llama_context * ctx, struct llama_sampler * chain); // Print system information LLAMA_API const char * llama_print_system_info(void); diff --git a/src/llama-impl.h b/src/llama-impl.h index 6d388655d01a8..fa2e09e1f688e 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -32,6 +32,20 @@ void llama_log_callback_default(ggml_log_level level, const char * text, void * // helpers // +struct time_meas { + time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? 
-1 : ggml_time_us()), t_acc(t_acc) {} + + ~time_meas() { + if (t_start_us >= 0) { + t_acc += ggml_time_us() - t_start_us; + } + } + + const int64_t t_start_us; + + int64_t & t_acc; +}; + static void replace_all(std::string & s, const std::string & search, const std::string & replace) { if (search.empty()) { return; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index cf28baab5978f..735992faa94bb 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include static void llama_log_softmax(float * array, size_t size) { @@ -24,7 +25,7 @@ static void llama_log_softmax(float * array, size_t size) { } } -static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) { +static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { GGML_ASSERT(cur_p->size > 0); // Sort the logits in descending order @@ -49,7 +50,7 @@ static void llama_constraint_softmax_impl(llama_token_data_array * cur_p) { } } -static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t k) { +static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)cur_p->size) { // return; @@ -125,12 +126,12 @@ static void llama_constraint_top_k_impl(llama_token_data_array * cur_p, int32_t cur_p->size = k; } -static void llama_constraint_top_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { +static void llama_sampler_top_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -151,7 +152,7 @@ static void llama_constraint_top_p_impl(llama_token_data_array * cur_p, float p, cur_p->size = last_idx; } -static void llama_constraint_min_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { +static void llama_sampler_min_p_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { if (p <= 0.0f || !cur_p->size) { return; } @@ -206,12 +207,12 @@ static void llama_constraint_min_p_impl(llama_token_data_array * cur_p, float p, } } -static void llama_constraint_tail_free_impl(llama_token_data_array * cur_p, float z, size_t min_keep) { +static void llama_sampler_tail_free_impl(llama_token_data_array * cur_p, float z, size_t min_keep) { if (z >= 1.0f || cur_p->size <= 2) { return; } - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Compute the first and second derivatives std::vector first_derivatives(cur_p->size - 1); @@ -260,7 +261,7 @@ static void llama_constraint_tail_free_impl(llama_token_data_array * cur_p, floa cur_p->size = last_idx; } -static void llama_constraint_typical_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { +static void llama_sampler_typical_impl(llama_token_data_array * cur_p, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -268,7 +269,7 @@ static void llama_constraint_typical_impl(llama_token_data_array * cur_p, float } // Compute the softmax of logits and calculate entropy - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); float entropy = 0.0f; for (size_t i = 0; i < cur_p->size; ++i) { @@ -318,7 +319,7 @@ static void 
llama_constraint_typical_impl(llama_token_data_array * cur_p, float cur_p->sorted = false; } -static void llama_constraint_entropy_impl(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val) { +static void llama_sampler_entropy_impl(llama_token_data_array * cur_p, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if (cur_p->size <= 1) { return; @@ -327,7 +328,7 @@ static void llama_constraint_entropy_impl(llama_token_data_array * cur_p, float // Calculate maximum possible entropy float max_entropy = -logf(1.0f / cur_p->size); - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -381,17 +382,17 @@ static void llama_constraint_entropy_impl(llama_token_data_array * cur_p, float #endif } -static void llama_constraint_temp_impl(llama_token_data_array * cur_p, float temp) { +static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { for (size_t i = 0; i < cur_p->size; ++i) { cur_p->data[i].logit /= temp; } } -static void llama_constraint_grammar_impl(llama_token_data_array * cur_p, const struct llama_grammar & grammar) { +static void llama_sampler_grammar_impl(llama_token_data_array * cur_p, const struct llama_grammar & grammar) { llama_grammar_apply_impl(grammar, cur_p); } -void llama_constraint_penalties_impl( +void llama_sampler_penalties_impl( llama_token_data_array * cur_p, const llama_token_cnt & token_count, float penalty_repeat, @@ -421,56 +422,124 @@ void llama_constraint_penalties_impl( } // -// constraints +// samplers // +// greedy + +static struct llama_sampler_i llama_sampler_greedy_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "greedy"; }, + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + cur_p->selected = 0; + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) { + cur_p->selected = i; + } + } + }, + /* .reset = */ nullptr, + /* .clone = */ nullptr, + /* .free = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_greedy_impl() { + return new llama_sampler { + /* .iface = */ &llama_sampler_greedy_i, + /* .ctx = */ nullptr, + }; +} + +// dist + +struct llama_sampler_context_dist { + const uint32_t seed; + + std::mt19937 rng; +}; + +static struct llama_sampler_i llama_sampler_dist_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "dist"; }, + /* .accept = */ nullptr, + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_dist *) smpl->ctx; + std::vector probs; + probs.reserve(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + probs.push_back(cur_p->data[i].p); + } + + std::discrete_distribution dist(probs.begin(), probs.end()); + + cur_p->selected = dist(ctx->rng); + }, + /* .reset = */ nullptr, + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_dist *) smpl->ctx; + return llama_sampler_init_dist_impl(ctx->seed); + }, + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_dist *) smpl->ctx; + }, +}; + +struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) { + return new llama_sampler { + /* .iface = */ &llama_sampler_dist_i, + /* .ctx = */ new llama_sampler_context_dist { + /* .seed = */ seed, + /* .rng = */ 
std::mt19937(seed), + }, + }; +} + // softmax -static struct llama_constraint_i llama_constraint_softmax_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "softmax"; }, +static struct llama_sampler_i llama_sampler_softmax_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "softmax"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * /*cnstr*/, llama_token_data_array * cur_p) { - llama_constraint_softmax_impl(cur_p); + /* .apply = */ [](struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { + llama_sampler_softmax_impl(cur_p); }, /* .reset = */ nullptr, /* .clone = */ nullptr, /* .free = */ nullptr, }; -struct llama_constraint * llama_constraint_init_softmax_impl() { - return new llama_constraint { - /* .iface = */ &llama_constraint_softmax_i, +struct llama_sampler * llama_sampler_init_softmax_impl() { + return new llama_sampler { + /* .iface = */ &llama_sampler_softmax_i, /* .ctx = */ nullptr, }; } // top-k -struct llama_constraint_context_top_k { +struct llama_sampler_context_top_k { const int32_t k; }; -static struct llama_constraint_i llama_constraint_top_k_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "top-k"; }, +static struct llama_sampler_i llama_sampler_top_k_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-k"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_top_k *) cnstr->ctx; - llama_constraint_top_k_impl(cur_p, ctx->k); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_top_k *) smpl->ctx; + llama_sampler_top_k_impl(cur_p, ctx->k); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_top_k *) cnstr->ctx; - return llama_constraint_init_top_k_impl(ctx->k); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_top_k *) smpl->ctx; + return llama_sampler_init_top_k_impl(ctx->k); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_top_k *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_top_k *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) { - return new llama_constraint { - /* .iface = */ &llama_constraint_top_k_i, - /* .ctx = */ new llama_constraint_context_top_k { +struct llama_sampler * llama_sampler_init_top_k_impl(int32_t k) { + return new llama_sampler { + /* .iface = */ &llama_sampler_top_k_i, + /* .ctx = */ new llama_sampler_context_top_k { /* .k = */ k, }, }; @@ -478,32 +547,32 @@ struct llama_constraint * llama_constraint_init_top_k_impl(int32_t k) { // top-p -struct llama_constraint_context_top_p { +struct llama_sampler_context_top_p { const float p; const size_t min_keep; }; -static struct llama_constraint_i llama_constraint_top_p_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "top-p"; }, +static struct llama_sampler_i llama_sampler_top_p_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-p"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_top_p *) cnstr->ctx; - llama_constraint_top_p_impl(cur_p, ctx->p, 
ctx->min_keep); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_top_p *) smpl->ctx; + llama_sampler_top_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_top_p *) cnstr->ctx; - return llama_constraint_init_top_p_impl(ctx->p, ctx->min_keep); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_top_p *) smpl->ctx; + return llama_sampler_init_top_p_impl(ctx->p, ctx->min_keep); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_top_p *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_top_p *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_keep) { - return new llama_constraint { - /* .iface = */ &llama_constraint_top_p_i, - /* .ctx = */ new llama_constraint_context_top_p { +struct llama_sampler * llama_sampler_init_top_p_impl(float p, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_top_p_i, + /* .ctx = */ new llama_sampler_context_top_p { /* .p = */ p, /* .min_keep = */ min_keep, }, @@ -512,32 +581,32 @@ struct llama_constraint * llama_constraint_init_top_p_impl(float p, size_t min_k // min-p -struct llama_constraint_context_min_p { +struct llama_sampler_context_min_p { const float p; const size_t min_keep; }; -static struct llama_constraint_i llama_constraint_min_p_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "min-p"; }, +static struct llama_sampler_i llama_sampler_min_p_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "min-p"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_min_p *) cnstr->ctx; - llama_constraint_min_p_impl(cur_p, ctx->p, ctx->min_keep); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_min_p *) smpl->ctx; + llama_sampler_min_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_min_p *) cnstr->ctx; - return llama_constraint_init_min_p_impl(ctx->p, ctx->min_keep); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_min_p *) smpl->ctx; + return llama_sampler_init_min_p_impl(ctx->p, ctx->min_keep); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_min_p *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_min_p *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_keep) { - return new llama_constraint { - /* .iface = */ &llama_constraint_min_p_i, - /* .ctx = */ new llama_constraint_context_min_p { +struct llama_sampler * llama_sampler_init_min_p_impl(float p, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_min_p_i, + /* .ctx = */ new llama_sampler_context_min_p { /* .p = */ p, /* .min_keep = */ min_keep, }, @@ -546,32 +615,32 @@ struct llama_constraint * llama_constraint_init_min_p_impl(float p, size_t min_k // tail-free -struct llama_constraint_context_tail_free { +struct 
llama_sampler_context_tail_free { const float z; const size_t min_keep; }; -static struct llama_constraint_i llama_constraint_tail_free_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "tail-free"; }, +static struct llama_sampler_i llama_sampler_tail_free_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "tail-free"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_tail_free *) cnstr->ctx; - llama_constraint_tail_free_impl(cur_p, ctx->z, ctx->min_keep); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_tail_free *) smpl->ctx; + llama_sampler_tail_free_impl(cur_p, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_tail_free *) cnstr->ctx; - return llama_constraint_init_tail_free_impl(ctx->z, ctx->min_keep); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_tail_free *) smpl->ctx; + return llama_sampler_init_tail_free_impl(ctx->z, ctx->min_keep); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_tail_free *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_tail_free *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t min_keep) { - return new llama_constraint { - /* .iface = */ &llama_constraint_tail_free_i, - /* .ctx = */ new llama_constraint_context_tail_free { +struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_tail_free_i, + /* .ctx = */ new llama_sampler_context_tail_free { /* .z = */ z, /*. 
min_keep = */ min_keep, }, @@ -580,32 +649,32 @@ struct llama_constraint * llama_constraint_init_tail_free_impl(float z, size_t m // typical -struct llama_constraint_context_typical { +struct llama_sampler_context_typical { const float p; const size_t min_keep; }; -static struct llama_constraint_i llama_constraint_typical_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "typical"; }, +static struct llama_sampler_i llama_sampler_typical_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "typical"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_typical *) cnstr->ctx; - llama_constraint_typical_impl(cur_p, ctx->p, ctx->min_keep); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_typical *) smpl->ctx; + llama_sampler_typical_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_typical *) cnstr->ctx; - return llama_constraint_init_typical_impl(ctx->p, ctx->min_keep); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_typical *) smpl->ctx; + return llama_sampler_init_typical_impl(ctx->p, ctx->min_keep); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_typical *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_typical *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min_keep) { - return new llama_constraint { - /* .iface = */ &llama_constraint_typical_i, - /* .ctx = */ new llama_constraint_context_typical { +struct llama_sampler * llama_sampler_init_typical_impl(float p, size_t min_keep) { + return new llama_sampler { + /* .iface = */ &llama_sampler_typical_i, + /* .ctx = */ new llama_sampler_context_typical { /* .p = */ p, /* .min_keep = */ min_keep, }, @@ -614,31 +683,31 @@ struct llama_constraint * llama_constraint_init_typical_impl(float p, size_t min // temp -struct llama_constraint_context_temp { +struct llama_sampler_context_temp { const float temp; }; -static struct llama_constraint_i llama_constraint_temp_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "temp"; }, +static struct llama_sampler_i llama_sampler_temp_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_temp *) cnstr->ctx; - llama_constraint_temp_impl(cur_p, ctx->temp); + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_temp *) smpl->ctx; + llama_sampler_temp_impl(cur_p, ctx->temp); }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_temp *) cnstr->ctx; - return llama_constraint_init_temp_impl(ctx->temp); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_temp *) smpl->ctx; + return llama_sampler_init_temp_impl(ctx->temp); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_temp *) cnstr->ctx; + /* .free = 
*/ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_temp *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_temp_impl(float temp) { - return new llama_constraint { - /* .iface = */ &llama_constraint_temp_i, - /* .ctx = */ new llama_constraint_context_temp { +struct llama_sampler * llama_sampler_init_temp_impl(float temp) { + return new llama_sampler { + /* .iface = */ &llama_sampler_temp_i, + /* .ctx = */ new llama_sampler_context_temp { /*.temp = */ temp, }, }; @@ -646,40 +715,40 @@ struct llama_constraint * llama_constraint_init_temp_impl(float temp) { // temp-ext -struct llama_constraint_context_temp_ext { +struct llama_sampler_context_temp_ext { const float temp; const float delta; const float exponent; }; -static struct llama_constraint_i llama_constraint_temp_ext_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "temp-ext"; }, +static struct llama_sampler_i llama_sampler_temp_ext_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp-ext"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_temp_ext *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_temp_ext *) smpl->ctx; if (ctx->delta > 0) { const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); const float temp_max = ctx->temp + ctx->delta; - llama_constraint_entropy_impl(cur_p, temp_min, temp_max, ctx->exponent); + llama_sampler_entropy_impl(cur_p, temp_min, temp_max, ctx->exponent); } else { - llama_constraint_temp_impl(cur_p, ctx->temp); + llama_sampler_temp_impl(cur_p, ctx->temp); } }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_temp_ext *) cnstr->ctx; - return llama_constraint_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_temp_ext *) smpl->ctx; + return llama_sampler_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_temp_ext *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_temp_ext *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float delta, float exponent) { - return new llama_constraint { - /* .iface = */ &llama_constraint_temp_ext_i, - /* .ctx = */ new llama_constraint_context_temp_ext { +struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta, float exponent) { + return new llama_sampler { + /* .iface = */ &llama_sampler_temp_ext_i, + /* .ctx = */ new llama_sampler_context_temp_ext { /* .temp = */ temp, /* .delta = */ delta, /* .exponent = */ exponent, @@ -689,7 +758,7 @@ struct llama_constraint * llama_constraint_init_temp_ext_impl(float temp, float // mirostat -struct llama_constraint_context_mirostat { +struct llama_sampler_context_mirostat { const struct llama_vocab * vocab; const float tau; @@ -702,10 +771,10 @@ struct llama_constraint_context_mirostat { std::vector cur; }; -static struct llama_constraint_i llama_constraint_mirostat_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "mirostat"; }, - /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { - auto 
* ctx = (llama_constraint_context_mirostat *) cnstr->ctx; +static struct llama_sampler_i llama_sampler_mirostat_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; int32_t idx = -1; for (size_t i = 0; i < ctx->cur.size(); ++i) { @@ -721,10 +790,10 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { // Update mu using the learning rate and error ctx->mu = ctx->mu - ctx->eta * e; }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -742,7 +811,7 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { float epsilon_hat = s_hat - 1; float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat); - llama_constraint_top_k_impl(cur_p, std::max(int(k), 1)); + llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); // remember the order to be able to compute the distance later when accepting the token ctx->cur.resize(cur_p->size); @@ -750,23 +819,23 @@ static struct llama_constraint_i llama_constraint_mirostat_i = { ctx->cur[i] = cur_p->data[i]; } }, - /* .reset = */ [](struct llama_constraint * cnstr) { - auto * ctx = (llama_constraint_context_mirostat *) cnstr->ctx; + /* .reset = */ [](struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; }, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_mirostat *) cnstr->ctx; - return llama_constraint_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx; + return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_mirostat *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_mirostat *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) { - return new llama_constraint { - /* .iface = */ &llama_constraint_mirostat_i, - /* .ctx = */ new llama_constraint_context_mirostat { +struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) { + return new llama_sampler { + /* .iface = */ &llama_sampler_mirostat_i, + /* .ctx = */ new llama_sampler_context_mirostat { /* .vocab = */ &vocab, /* .tau = */ tau, /* .eta = */ eta, @@ -779,7 +848,7 @@ struct llama_constraint * llama_constraint_init_mirostat_impl(const struct llama // mirostat v2 -struct llama_constraint_context_mirostat_v2 { +struct llama_sampler_context_mirostat_v2 { const float tau; const float eta; @@ -788,10 +857,10 @@ struct llama_constraint_context_mirostat_v2 { std::vector cur; }; -static struct llama_constraint_i llama_constraint_mirostat_v2_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { 
return "mirostat-v2"; }, - /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { - auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; +static struct llama_sampler_i llama_sampler_mirostat_v2_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; int32_t idx = -1; for (size_t i = 0; i < ctx->cur.size(); ++i) { @@ -807,10 +876,10 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { // Update mu using the learning rate and error ctx->mu = ctx->mu - ctx->eta * e; }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // Truncate the words with surprise values greater than mu cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) { @@ -822,7 +891,7 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { } // Normalize the probabilities of the remaining words - llama_constraint_softmax_impl(cur_p); + llama_sampler_softmax_impl(cur_p); // remember the order to be able to compute the distance later when accepting the token ctx->cur.resize(cur_p->size); @@ -830,23 +899,23 @@ static struct llama_constraint_i llama_constraint_mirostat_v2_i = { ctx->cur[i] = cur_p->data[i]; } }, - /* .reset = */ [](struct llama_constraint * cnstr) { - auto * ctx = (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + /* .reset = */ [](struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; }, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx = (const llama_constraint_context_mirostat_v2 *) cnstr->ctx; - return llama_constraint_init_mirostat_v2_impl(ctx->tau, ctx->eta); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx; + return llama_sampler_init_mirostat_v2_impl(ctx->tau, ctx->eta); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_mirostat_v2 *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_mirostat_v2 *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, float eta) { - return new llama_constraint { - /* .iface = */ &llama_constraint_mirostat_v2_i, - /* .ctx = */ new llama_constraint_context_mirostat_v2 { +struct llama_sampler * llama_sampler_init_mirostat_v2_impl(float tau, float eta) { + return new llama_sampler { + /* .iface = */ &llama_sampler_mirostat_v2_i, + /* .ctx = */ new llama_sampler_context_mirostat_v2 { /* .tau = */ tau, /* .eta = */ eta, /* .mu = */ 2.0f*tau, @@ -857,7 +926,7 @@ struct llama_constraint * llama_constraint_init_mirostat_v2_impl(float tau, floa // grammar -struct llama_constraint_context_grammar { +struct llama_sampler_context_grammar { const struct llama_vocab * vocab; std::string grammar_str; @@ -866,22 +935,22 @@ struct llama_constraint_context_grammar { struct llama_grammar * grammar; }; -static struct llama_constraint_i llama_constraint_grammar_i = 
{ - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "grammar"; }, - /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { - const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; +static struct llama_sampler_i llama_sampler_grammar_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "grammar"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; if (ctx->grammar) { llama_grammar_accept_impl(*ctx->grammar, token); } }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; if (ctx->grammar) { - llama_constraint_grammar_impl(cur_p, *ctx->grammar); + llama_sampler_grammar_impl(cur_p, *ctx->grammar); } }, - /* .reset = */ [](struct llama_constraint * cnstr) { - auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + /* .reset = */ [](struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; if (!ctx->grammar) { return; } @@ -891,12 +960,12 @@ static struct llama_constraint_i llama_constraint_grammar_i = { llama_grammar_free_impl(ctx->grammar); ctx->grammar = grammar_new; }, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx_src = (const llama_constraint_context_grammar *) cnstr->ctx; + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx_src = (const llama_sampler_context_grammar *) smpl->ctx; - auto * result = llama_constraint_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); + auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); - auto * ctx_dst = (llama_constraint_context_grammar *) result->ctx; + auto * ctx_dst = (llama_sampler_context_grammar *) result->ctx; if (ctx_src->grammar) { ctx_dst->grammar_str = ctx_src->grammar_str; ctx_dst->grammar_root = ctx_src->grammar_root; @@ -906,8 +975,8 @@ static struct llama_constraint_i llama_constraint_grammar_i = { return result; }, - /* .free = */ [](struct llama_constraint * cnstr) { - const auto * ctx = (llama_constraint_context_grammar *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; if (ctx->grammar) { llama_grammar_free_impl(ctx->grammar); @@ -917,8 +986,8 @@ static struct llama_constraint_i llama_constraint_grammar_i = { }, }; -struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { - auto * ctx = new llama_constraint_context_grammar; +struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { + auto * ctx = new llama_sampler_context_grammar; if (grammar_str != nullptr && grammar_str[0] != '\0') { *ctx = { @@ -936,15 +1005,15 @@ struct llama_constraint * llama_constraint_init_grammar_impl(const struct llama_ }; } - return new llama_constraint { - /* .iface = */ &llama_constraint_grammar_i, + return new llama_sampler { + /* .iface = */ &llama_sampler_grammar_i, /* .ctx = */ ctx, }; } // penalties -struct llama_constraint_context_penalties { +struct llama_sampler_context_penalties { const struct llama_vocab * vocab; const int32_t penalty_last_n; @@ -958,16 
+1027,16 @@ struct llama_constraint_context_penalties { ring_buffer prev; }; -static struct llama_constraint_i llama_constraint_penalties_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "penalties"; }, - /* .accept = */ [](struct llama_constraint * cnstr, llama_token token) { - auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; +static struct llama_sampler_i llama_sampler_penalties_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "penalties"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_context_penalties *) smpl->ctx; ctx->prev.push_back(token); }, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_penalties *) smpl->ctx; - GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' constraint must be applied on the full vocabulary"); + GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' sampler must be applied on the full vocabulary"); if (ctx->ignore_eos) { cur_p->data[ctx->vocab->special_eos_id].logit = -INFINITY; @@ -981,26 +1050,26 @@ static struct llama_constraint_i llama_constraint_penalties_i = { const float nl_logit = !ctx->penalize_nl ? cur_p->data[ctx->vocab->linefeed_id].logit : -INFINITY; // Create a frequency map to count occurrences of each token in last_tokens - // TODO: optimize this by maintaining the token count in the constraint context + // TODO: optimize this by maintaining the token count in the sampler context llama_token_cnt token_count; for (int i = 0; i < ctx->penalty_last_n; ++i) { token_count[ctx->prev.rat(i)]++; } - llama_constraint_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); + llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); if (!ctx->penalize_nl) { // restore the logit of the newline token if it was penalized cur_p->data[ctx->vocab->linefeed_id].logit = nl_logit; } }, - /* .reset = */ [](struct llama_constraint * cnstr) { - auto * ctx = (llama_constraint_context_penalties *) cnstr->ctx; + /* .reset = */ [](struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_context_penalties *) smpl->ctx; ctx->prev.clear(); }, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx_src = (const llama_constraint_context_penalties *) cnstr->ctx; - auto * result = llama_constraint_init_penalties_impl( + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx_src = (const llama_sampler_context_penalties *) smpl->ctx; + auto * result = llama_sampler_init_penalties_impl( *ctx_src->vocab, ctx_src->penalty_last_n, ctx_src->penalty_repeat, @@ -1009,23 +1078,23 @@ static struct llama_constraint_i llama_constraint_penalties_i = { ctx_src->penalize_nl, ctx_src->ignore_eos); - auto * ctx_dst = (llama_constraint_context_penalties *) result->ctx; + auto * ctx_dst = (llama_sampler_context_penalties *) result->ctx; ctx_dst->prev = ctx_src->prev; return result; }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_penalties *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_penalties *) smpl->ctx; }, }; -struct llama_constraint * 
llama_constraint_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { +struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL); GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL); - return new llama_constraint { - /* .iface = */ &llama_constraint_penalties_i, - /* .ctx = */ new llama_constraint_context_penalties { + return new llama_sampler { + /* .iface = */ &llama_sampler_penalties_i, + /* .ctx = */ new llama_sampler_context_penalties { /* .vocab = */ &vocab, /* .penalty_last_n = */ penalty_last_n, /* .penalty_repeat = */ penalty_repeat, @@ -1040,230 +1109,166 @@ struct llama_constraint * llama_constraint_init_penalties_impl(const struct llam // logit-bias -struct llama_constraint_context_logit_bias { +struct llama_sampler_context_logit_bias { const struct llama_vocab * vocab; std::vector logit_bias; }; -static struct llama_constraint_i llama_constraint_logit_bias_i = { - /* .name = */ [](const struct llama_constraint * /*cnstr*/) { return "logit-bias"; }, +static struct llama_sampler_i llama_sampler_logit_bias_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "logit-bias"; }, /* .accept = */ nullptr, - /* .apply = */ [](struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - auto * ctx = (llama_constraint_context_logit_bias *) cnstr->ctx; + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_context_logit_bias *) smpl->ctx; - GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' constraint must be applied on the full vocabulary"); + GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' sampler must be applied on the full vocabulary"); for (const auto & lb : ctx->logit_bias) { cur_p->data[lb.token].logit += lb.bias; } }, /* .reset = */ nullptr, - /* .clone = */ [](const struct llama_constraint * cnstr) { - const auto * ctx_src = (const llama_constraint_context_logit_bias *) cnstr->ctx; - return llama_constraint_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * ctx_src = (const llama_sampler_context_logit_bias *) smpl->ctx; + return llama_sampler_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); }, - /* .free = */ [](struct llama_constraint * cnstr) { - delete (llama_constraint_context_logit_bias *) cnstr->ctx; + /* .free = */ [](struct llama_sampler * smpl) { + delete (llama_sampler_context_logit_bias *) smpl->ctx; }, }; -struct llama_constraint * llama_constraint_init_logit_bias_impl( +struct llama_sampler * llama_sampler_init_logit_bias_impl( const struct llama_vocab & vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { - return new llama_constraint { - /* .iface = */ &llama_constraint_logit_bias_i, - /* .ctx = */ new llama_constraint_context_logit_bias { + return new llama_sampler { + /* .iface = */ &llama_sampler_logit_bias_i, + /* .ctx = */ new llama_sampler_context_logit_bias { /* .vocab = */ &vocab, /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), 
}, }; } -//////////////////////////////////////// +// sampler chain -struct llama_constraint * llama_constraint_clone_impl(const struct llama_constraint & cnstr) { - return cnstr.iface->clone ? cnstr.iface->clone(&cnstr) : nullptr; -} +static struct llama_sampler_i llama_sampler_chain_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token /*token*/) { + auto * chain = (llama_sampler_chain *) smpl->ctx; -void llama_constraint_free_impl(struct llama_constraint * cnstr) { - if (cnstr == nullptr) { - return; - } + chain->n_sample++; + }, + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; - if (cnstr->iface->free) { - cnstr->iface->free(cnstr); - } + time_meas tm(chain->t_sample_us, chain->params.no_timing); - delete cnstr; -} + for (auto * smpl : chain->samplers) { + llama_sampler_apply_impl(*smpl, cur_p); + } + }, + /* .reset = */ [](struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; -void llama_constraint_accept_impl(struct llama_constraint & cnstr, llama_token token) { - if (cnstr.iface->accept) { - cnstr.iface->accept(&cnstr, token); - } -} + for (auto * smpl : chain->samplers) { + llama_sampler_reset_impl(*smpl); + } -void llama_constraint_apply_impl(struct llama_constraint & cnstr, struct llama_token_data_array * cur_p) { - GGML_ASSERT(cnstr.iface->apply); - cnstr.iface->apply(&cnstr, cur_p); -} + chain->t_sample_us = 0; + chain->n_sample = 0; + }, + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; -void llama_constraint_reset_impl(struct llama_constraint & cnstr) { - if (cnstr.iface->reset) { - cnstr.iface->reset(&cnstr); - } -} + auto * result = llama_sampler_chain_init_impl(chain_src->params); -// -// samplers -// + auto * chain_dst = (llama_sampler_chain *) result->ctx; + for (auto * smpl : chain_src->samplers) { + llama_sampler_chain_add_impl(*chain_dst, llama_sampler_clone_impl(*smpl)); + } -struct llama_sampler * llama_sampler_init_impl(const struct llama_vocab & vocab, struct llama_sampler_params params) { - return new llama_sampler { - /* .params = */ params, - /* .vocab = */ &vocab, + return result; + }, + /* .free = */ [](struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_free_impl(smpl); + } - /* .rng = */ std::mt19937(params.seed), + delete chain; + }, +}; - /* .prev = */ { (size_t) params.n_prev }, - /* .constraints = */ {}, - /* .cur = */ {}, - /* .cur_p = */ {}, - /* .t_sample_us = */ 0, - /* .n_sample = */ 0, +struct llama_sampler * llama_sampler_chain_init_impl(struct llama_sampler_chain_params params) { + return new llama_sampler { + /* .iface = */ &llama_sampler_chain_i, + /* .ctx = */ new llama_sampler_chain { + /* .params = */ params, + /* .samplers = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + }, }; } -void llama_sampler_free_impl(struct llama_sampler * smpl) { - if (smpl == nullptr) { - return; - } - - for (auto * cnstr : smpl->constraints) { - llama_constraint_free_impl(cnstr); - } - - delete smpl; +void llama_sampler_chain_add_impl(struct llama_sampler_chain & chain, struct llama_sampler * smpl) { + chain.samplers.push_back(smpl); } -struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) { - auto * result = new llama_sampler { - /* .params = */ 
smpl.params, - /* .vocab = */ smpl.vocab, - - /* .rng = */ smpl.rng, - - /* .prev = */ smpl.prev, - /* .constraints = */ {}, - /* .cur = */ {}, - /* .cur_p = */ {}, - /* .t_sample_us = */ 0, - /* .n_sample = */ 0, - }; - - // clone the constraints objects - result->constraints.clear(); - for (const auto & cnstr : smpl.constraints) { - if (cnstr->ctx == nullptr) { - result->constraints.push_back(new llama_constraint { - /* .iface = */ cnstr->iface, - /* .ctx = */ nullptr, - }); - } else { - GGML_ASSERT(cnstr->iface->clone); - result->constraints.push_back(cnstr->iface->clone(cnstr)); - } +struct llama_sampler * llama_sampler_chain_get_impl(const struct llama_sampler_chain & chain, int32_t i) { + if (i < 0 || i >= (int32_t) chain.samplers.size()) { + return nullptr; } - return result; + return chain.samplers[i]; } -void llama_sampler_reset_impl(struct llama_sampler & smpl) { - smpl.prev.clear(); +int llama_sampler_chain_n_impl(const struct llama_sampler_chain & chain) { + return chain.samplers.size(); +} - for (auto * cnstr : smpl.constraints) { - llama_constraint_reset_impl(*cnstr); - } - // TODO: should we reset the timings? -} +//////////////////////////////////////// -const char * llama_constraint_name_impl(const struct llama_constraint & cnstr) { - if (!cnstr.iface) { +const char * llama_sampler_name_impl(const struct llama_sampler & smpl) { + if (!smpl.iface) { return "(null)"; } - return cnstr.iface->name(&cnstr); + return smpl.iface->name(&smpl); } void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { - smpl.prev.push_back(token); - - for (auto * cnstr : smpl.constraints) { - llama_constraint_accept_impl(*cnstr, token); + if (smpl.iface->accept) { + smpl.iface->accept(&smpl, token); } } void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { - for (auto * cnstr : smpl.constraints) { - llama_constraint_apply_impl(*cnstr, cur_p); - } + GGML_ASSERT(smpl.iface->apply); + smpl.iface->apply(&smpl, cur_p); } -void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) { - smpl.constraints.push_back(cnstr); -} - -int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl) { - return smpl.constraints.size(); -} - -struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith) { - if (ith < 0 || ith >= (int) smpl.constraints.size()) { - return nullptr; +void llama_sampler_reset_impl(struct llama_sampler & smpl) { + if (smpl.iface->reset) { + smpl.iface->reset(&smpl); } - - return smpl.constraints[ith]; } -llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type) { - switch (type) { - case LLAMA_SAMPLER_TYPE_GREEDY: - { - llama_constraint_softmax_impl(cur_p); - - return cur_p->data[0].id; - } - case LLAMA_SAMPLER_TYPE_DIST: - { - llama_constraint_softmax_impl(cur_p); - - std::vector probs(cur_p->size); - for (size_t i = 0; i < cur_p->size; ++i) { - probs[i] = cur_p->data[i].p; - } - - std::discrete_distribution<> dist(probs.begin(), probs.end()); - - const int idx = dist(rng); - - return cur_p->data[idx].id; - } - default: - GGML_ABORT("invalid sampler type"); - } +struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) { + return smpl.iface->clone ? 
smpl.iface->clone(&smpl) : nullptr; } -llama_token llama_sampler_prev_impl(const struct llama_sampler & smpl, int ith) { - if (ith < 0 || ith >= (int) smpl.prev.size()) { - return LLAMA_TOKEN_NULL; +void llama_sampler_free_impl(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; } - return smpl.prev.rat(ith); -} + if (smpl->iface->free) { + smpl->iface->free(smpl); + } -int llama_sampler_n_prev_impl(const struct llama_sampler & smpl) { - return smpl.prev.size(); + delete smpl; } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 18304b49a8ef1..3f14ec621f5c1 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -2,49 +2,76 @@ #include "llama-grammar.h" -#include #include struct llama_vocab; struct llama_grammar; +// samplers + +const char * llama_sampler_name_impl (const struct llama_sampler & smpl); +void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token); +void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p); +void llama_sampler_reset_impl ( struct llama_sampler & smpl); +struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl); +void llama_sampler_free_impl ( struct llama_sampler * smpl); + +// sampler chain + +struct llama_sampler_chain { + llama_sampler_chain_params params; + + std::vector samplers; + + // timing + + mutable int64_t t_sample_us; + + mutable int32_t n_sample; +}; + +struct llama_sampler * llama_sampler_chain_init_impl( struct llama_sampler_chain_params params); +void llama_sampler_chain_add_impl ( struct llama_sampler_chain & chain, struct llama_sampler * smpl); +struct llama_sampler * llama_sampler_chain_get_impl (const struct llama_sampler_chain & chain, int32_t i); +int llama_sampler_chain_n_impl (const struct llama_sampler_chain & chain); + using llama_token_cnt = std::unordered_map; // TODO: tmp exposed until test-sampling is fixed -void llama_constraint_penalties_impl( +void llama_sampler_penalties_impl( llama_token_data_array * cur_p, const llama_token_cnt & token_count, float penalty_repeat, float penalty_freq, float penalty_present); -// constraints - -struct llama_constraint * llama_constraint_init_softmax_impl (); -struct llama_constraint * llama_constraint_init_top_k_impl (int32_t k); -struct llama_constraint * llama_constraint_init_top_p_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_min_p_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_tail_free_impl (float z, size_t min_keep); -struct llama_constraint * llama_constraint_init_typical_impl (float p, size_t min_keep); -struct llama_constraint * llama_constraint_init_temp_impl (float t); -struct llama_constraint * llama_constraint_init_temp_ext_impl (float t, float delta, float exponent); - -struct llama_constraint * llama_constraint_init_mirostat_impl( +struct llama_sampler * llama_sampler_init_greedy_impl (); +struct llama_sampler * llama_sampler_init_dist_impl (uint32_t seed); +struct llama_sampler * llama_sampler_init_softmax_impl (); +struct llama_sampler * llama_sampler_init_top_k_impl (int32_t k); +struct llama_sampler * llama_sampler_init_top_p_impl (float p, size_t min_keep); +struct llama_sampler * llama_sampler_init_min_p_impl (float p, size_t min_keep); +struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep); +struct llama_sampler * llama_sampler_init_typical_impl (float p, size_t min_keep); +struct llama_sampler * llama_sampler_init_temp_impl (float t); +struct 
llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta, float exponent); + +struct llama_sampler * llama_sampler_init_mirostat_impl( const struct llama_vocab & vocab, float tau, float eta, int32_t m); -struct llama_constraint * llama_constraint_init_mirostat_v2_impl( +struct llama_sampler * llama_sampler_init_mirostat_v2_impl( float tau, float eta); -struct llama_constraint * llama_constraint_init_grammar_impl( +struct llama_sampler * llama_sampler_init_grammar_impl( const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); -struct llama_constraint * llama_constraint_init_penalties_impl( +struct llama_sampler * llama_sampler_init_penalties_impl( const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, @@ -53,58 +80,7 @@ struct llama_constraint * llama_constraint_init_penalties_impl( bool penalize_nl, bool ignore_eos); - LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias_impl( + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl( const struct llama_vocab & vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias); - -struct llama_constraint * llama_constraint_clone_impl(const struct llama_constraint & cnstr); - -void llama_constraint_free_impl(struct llama_constraint * cnstr); - -const char * llama_constraint_name_impl (const struct llama_constraint & cnstr); -void llama_constraint_accept_impl( struct llama_constraint & cnstr, llama_token token); -void llama_constraint_apply_impl ( struct llama_constraint & cnstr, struct llama_token_data_array * cur_p); -void llama_constraint_reset_impl ( struct llama_constraint & cnstr); - -// samplers - -struct llama_sampler { - llama_sampler_params params; - - const struct llama_vocab * vocab; - - // state - - std::mt19937 rng; - - ring_buffer prev; - - std::vector constraints; - - std::vector cur; - - llama_token_data_array cur_p; - - // timing - - mutable int64_t t_sample_us; - - mutable int32_t n_sample; -}; - -struct llama_sampler * llama_sampler_init_impl (const struct llama_vocab & vocab, struct llama_sampler_params params); -void llama_sampler_free_impl ( struct llama_sampler * smpl); -struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl); -void llama_sampler_reset_impl ( struct llama_sampler & smpl); -void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token); -void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p); - -void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr); -int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl); -struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith); - -llama_token llama_sampler_sample_impl(struct llama_token_data_array * cur_p, std::mt19937 & rng, enum llama_sampler_type type); - -llama_token llama_sampler_prev_impl (const struct llama_sampler & smpl, int ith); -int llama_sampler_n_prev_impl(const struct llama_sampler & smpl); diff --git a/src/llama.cpp b/src/llama.cpp index 2636f2316104b..df12de7add9c1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -147,21 +147,6 @@ static void zeros(std::ofstream & file, size_t n) { } } -struct time_meas { - time_meas(int64_t & t_acc, bool disable = false) : t_start_us(disable ? 
-1 : ggml_time_us()), t_acc(t_acc) {} - - ~time_meas() { - if (t_start_us >= 0) { - t_acc += ggml_time_us() - t_start_us; - } - } - - const int64_t t_start_us; - - int64_t & t_acc; -}; - - LLAMA_ATTRIBUTE_FORMAT(1, 2) static std::string format(const char * fmt, ...) { va_list ap; @@ -17937,11 +17922,8 @@ struct llama_context_params llama_context_default_params() { return result; } -struct llama_sampler_params llama_sampler_default_params() { - struct llama_sampler_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_prev =*/ 256, - /*.type =*/ LLAMA_SAMPLER_TYPE_DIST, +struct llama_sampler_chain_params llama_sampler_chain_default_params() { + struct llama_sampler_chain_params result = { /*.no_timing =*/ false, // TODO: change to true and set explicitly in examples }; @@ -20610,188 +20592,138 @@ int32_t llama_chat_apply_template( // sampling // -struct llama_constraint * llama_constraint_init_softmax(void) { - return llama_constraint_init_softmax_impl(); +const char * llama_sampler_name(const struct llama_sampler * smpl) { + return llama_sampler_name_impl(*smpl); } -struct llama_constraint * llama_constraint_init_top_k(int32_t k) { - return llama_constraint_init_top_k_impl(k); +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + llama_sampler_accept_impl(*smpl, token); } -struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep) { - return llama_constraint_init_top_p_impl(p, min_keep); +void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + llama_sampler_apply_impl(*smpl, cur_p); } -struct llama_constraint * llama_constraint_init_min_p(float p, int32_t min_keep) { - return llama_constraint_init_min_p_impl(p, min_keep); +void llama_sampler_reset(struct llama_sampler * smpl) { + llama_sampler_reset_impl(*smpl); } -struct llama_constraint * llama_constraint_init_tail_free(float z, int32_t min_keep) { - return llama_constraint_init_tail_free_impl(z, min_keep); +struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + return llama_sampler_clone_impl(*smpl); } -struct llama_constraint * llama_constraint_init_typical(float p, int32_t min_keep) { - return llama_constraint_init_typical_impl(p, min_keep); -} +void llama_sampler_free(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; + } -struct llama_constraint * llama_constraint_init_temp(float temp) { - return llama_constraint_init_temp_impl(temp); + llama_sampler_free_impl(smpl); } -struct llama_constraint * llama_constraint_init_temp_ext(float temp, float delta, float exponent) { - return llama_constraint_init_temp_ext_impl(temp, delta, exponent); +struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { + return llama_sampler_chain_init_impl(params); } -struct llama_constraint * llama_constraint_init_mirostat(const struct llama_model * model, float tau, float eta) { - return llama_constraint_init_mirostat_impl(model->vocab, tau, eta, 100); +void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { + llama_sampler_chain_add_impl(*(struct llama_sampler_chain *) chain->ctx, smpl); } -struct llama_constraint * llama_constraint_init_mirostat_v2(float tau, float eta) { - return llama_constraint_init_mirostat_v2_impl(tau, eta); +struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { + return llama_sampler_chain_get_impl(*(const struct llama_sampler_chain *) chain->ctx, i); } -struct llama_constraint * 
llama_constraint_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { - return llama_constraint_init_grammar_impl(model->vocab, grammar_str, grammar_root); +int llama_sampler_chain_n(const struct llama_sampler * chain) { + return llama_sampler_chain_n_impl(*(const struct llama_sampler_chain *) chain->ctx); } -struct llama_constraint * llama_constraint_init_penalties( - const struct llama_model * model, - int32_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos) { - return llama_constraint_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); +struct llama_sampler * llama_sampler_init_greedy(void) { + return llama_sampler_init_greedy_impl(); } -LLAMA_API struct llama_constraint * llama_constraint_init_logit_bias( - const struct llama_model * model, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias) { - return llama_constraint_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); +struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { + return llama_sampler_init_dist_impl(seed); } -struct llama_constraint * llama_constraint_clone(const struct llama_constraint * cnstr) { - return llama_constraint_clone_impl(*cnstr); +struct llama_sampler * llama_sampler_init_softmax(void) { + return llama_sampler_init_softmax_impl(); } -void llama_constraint_free(struct llama_constraint * cnstr) { - if (cnstr == nullptr) { - return; - } - - llama_constraint_free_impl(cnstr); +struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + return llama_sampler_init_top_k_impl(k); } -const char * llama_constraint_name(const struct llama_constraint * cnstr) { - return llama_constraint_name_impl(*cnstr); +struct llama_sampler * llama_sampler_init_top_p(float p, int32_t min_keep) { + return llama_sampler_init_top_p_impl(p, min_keep); } -void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token) { - llama_constraint_accept_impl(*cnstr, token); +struct llama_sampler * llama_sampler_init_min_p(float p, int32_t min_keep) { + return llama_sampler_init_min_p_impl(p, min_keep); } -void llama_constraint_apply(struct llama_constraint * cnstr, llama_token_data_array * cur_p) { - llama_constraint_apply_impl(*cnstr, cur_p); +struct llama_sampler * llama_sampler_init_tail_free(float z, int32_t min_keep) { + return llama_sampler_init_tail_free_impl(z, min_keep); } -void llama_constraint_reset(struct llama_constraint * cnstr) { - llama_constraint_reset_impl(*cnstr); +struct llama_sampler * llama_sampler_init_typical(float p, int32_t min_keep) { + return llama_sampler_init_typical_impl(p, min_keep); } -struct llama_sampler * llama_sampler_init(const struct llama_model * model, struct llama_sampler_params params) { - return llama_sampler_init_impl(model->vocab, params); +struct llama_sampler * llama_sampler_init_temp(float temp) { + return llama_sampler_init_temp_impl(temp); } -void llama_sampler_free(struct llama_sampler * smpl) { - if (smpl == nullptr) { - return; - } - - llama_sampler_free_impl(smpl); +struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { + return llama_sampler_init_temp_ext_impl(temp, delta, exponent); } -struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { - return llama_sampler_clone_impl(*smpl); +struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, float tau, float eta) 
{ + return llama_sampler_init_mirostat_impl(model->vocab, tau, eta, 100); } -void llama_sampler_reset(struct llama_sampler * smpl) { - llama_sampler_reset_impl(*smpl); -} - -void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { - llama_sampler_accept_impl(*smpl, token); +struct llama_sampler * llama_sampler_init_mirostat_v2(float tau, float eta) { + return llama_sampler_init_mirostat_v2_impl(tau, eta); } -void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us, smpl->params.no_timing); - - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; - } - - llama_sampler_apply_impl(*smpl, cur_p); +struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { + return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root); } -void llama_sampler_set_logits(struct llama_sampler * smpl, const float * logits) { - const int n_vocab = smpl->vocab->n_vocab; - - smpl->cur.resize(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - smpl->cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } - - smpl->cur_p = { smpl->cur.data(), smpl->cur.size(), false }; -} - -llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl) { - return &smpl->cur_p; -} - -void llama_sampler_constraint_add(struct llama_sampler * smpl, struct llama_constraint * cnstr) { - llama_sampler_constraint_add_impl(*smpl, cnstr); +struct llama_sampler * llama_sampler_init_penalties( + const struct llama_model * model, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos) { + return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); } -int llama_sampler_n_constraints (const struct llama_sampler * smpl) { - return llama_sampler_n_constraints_impl(*smpl); +LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( + const struct llama_model * model, + int32_t n_logit_bias, + const llama_logit_bias * logit_bias) { + return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); } -struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i) { - return llama_sampler_constraint_get_impl(*smpl, i); -} +llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); -llama_token llama_sampler_sample(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - time_meas tm(smpl->t_sample_us, smpl->params.no_timing); + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - if (cur_p == nullptr) { - cur_p = &smpl->cur_p; + // TODO: do not allocate each time + std::vector cur(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } - auto res = llama_sampler_sample_impl(cur_p, smpl->rng, smpl->params.type); + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - smpl->n_sample++; + llama_sampler_apply(smpl, &cur_p); - return res; -} - -int llama_sampler_n_prev(const struct llama_sampler * smpl) { - return llama_sampler_n_prev_impl(*smpl); -} - -llama_token llama_sampler_prev(const struct llama_sampler * smpl, int32_t ith) { - return llama_sampler_prev_impl(*smpl, ith); + 
return cur_p.data[cur_p.selected].id; } -llama_token llama_sampler_last(const struct llama_sampler * smpl) { - return llama_sampler_prev_impl(*smpl, 0); -} - -//llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i) { -// GGML_ABORT("not implemented"); -//} - // // model split // @@ -20820,7 +20752,9 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int return 0; } -void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl) { +void llama_print_timings(const struct llama_context * ctx, const struct llama_sampler * chain) { + auto * smpl = chain ? (const struct llama_sampler_chain *) chain->ctx : nullptr; + const llama_timings timings = { /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, /*.t_end_ms =*/ 1.00 * ggml_time_ms(), @@ -20845,13 +20779,15 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampler * smpl LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); } -void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * smpl) { +void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * chain) { ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; - if (smpl) { - smpl->t_sample_us = smpl->n_sample = 0; + if (chain) { + auto * smpl = (struct llama_sampler_chain *) chain->ctx; + + smpl->t_sample_us = smpl->n_sample = 0; } } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 74bb4a3a3a40c..adc1ff4e6da7d 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -21,8 +21,8 @@ static void dump(const llama_token_data_array * cur_p) { #define APPLY(__cnstr, __cur_p) do { \ auto * cnstr = (__cnstr); \ - llama_constraint_apply(cnstr, (__cur_p)); \ - llama_constraint_free(cnstr); \ + llama_sampler_apply(cnstr, (__cur_p)); \ + llama_sampler_free(cnstr); \ } while(0) static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { @@ -35,10 +35,10 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array cur_p = { cur.data(), cur.size(), false }; + llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false }; DUMP(&cur_p); - APPLY(llama_constraint_init_tail_free(z, 1), &cur_p); + APPLY(llama_sampler_init_tail_free(z, 1), &cur_p); DUMP(&cur_p); GGML_ASSERT(cur_p.size == expected_probs.size()); @@ -100,11 +100,11 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector Date: Thu, 5 Sep 2024 17:08:46 +0300 Subject: [PATCH 30/47] sampling : fix cloning of samplers with null ctx ggml-ci --- src/llama-grammar.cpp | 1 + src/llama-sampling.cpp | 13 ++++++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 09f756fbec727..353cb398a27a5 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -3,6 +3,7 @@ #include "llama-vocab.h" #include "llama-sampling.h" +#include #include #include diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 735992faa94bb..8ff52dd2decd1 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1258,7 +1258,18 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) { } struct llama_sampler * 
llama_sampler_clone_impl(const struct llama_sampler & smpl) { - return smpl.iface->clone ? smpl.iface->clone(&smpl) : nullptr; + if (smpl.iface->clone) { + return smpl.iface->clone(&smpl); + } + + if (smpl.ctx == nullptr) { + return new llama_sampler { + /* .iface = */ smpl.iface, + /* .ctx = */ nullptr, + }; + } + + GGML_ABORT("the sampler does not support cloning"); } void llama_sampler_free_impl(struct llama_sampler * smpl) { From bd8835283422ce4459c3f2dc0cfdf827c3de5878 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Sep 2024 18:10:09 +0300 Subject: [PATCH 31/47] ios : try to fix build --- examples/batched.swift/Sources/main.swift | 6 +++--- examples/llama.swiftui/llama.cpp.swift/LibLlama.swift | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 24f8a7027bc10..e9acdc7ac86aa 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -61,9 +61,9 @@ defer { llama_sampler_free(smpl) } -llama_sampler_sampler_add(smpl, llama_sampler_init_top_k(40)); -llama_sampler_sampler_add(smpl, llama_sampler_init_top_p(0.9, 1)); -llama_sampler_sampler_add(smpl, llama_sampler_init_temp (0.4)); +llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40)); +llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); +llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4)); let n_ctx = llama_n_ctx(context) diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 73cabc6c7444c..24cef348ce734 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -24,7 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama actor LlamaContext { private var model: OpaquePointer private var context: OpaquePointer - private var sampling: OpaquePointer + private var sampling: llama_sampler private var batch: llama_batch private var tokens_list: [llama_token] var is_done: Bool = false From 82a89df960571b9f2a3afcc853066531c8f833f0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 5 Sep 2024 18:07:47 +0300 Subject: [PATCH 32/47] sampling : improve mirostat implementation ggml-ci --- common/sampling.cpp | 22 +++---- common/sampling.h | 2 +- include/llama.h | 2 + src/llama-sampling.cpp | 133 ++++++++++++++++++++-------------------- src/llama-sampling.h | 2 + src/llama.cpp | 8 +-- tests/test-sampling.cpp | 14 ++--- 7 files changed, 95 insertions(+), 88 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index de7c9b1b97395..cf3ee98d4c744 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -121,7 +121,7 @@ struct gpt_sampler { cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; } - cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false }; + cur_p = { cur.data(), cur.size(), -1, false }; } }; @@ -202,17 +202,17 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st GGML_ASSERT(false && "unknown sampler type"); } } + llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); + llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.mirostat_tau, params.mirostat_eta)); + 
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta)); } else if (params.mirostat == 2) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); } else { GGML_ASSERT(false && "unknown mirostat version"); } - llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); - llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } else { llama_sampler_chain_add(result->chain, llama_sampler_init_softmax()); llama_sampler_chain_add(result->chain, llama_sampler_init_greedy()); @@ -246,8 +246,8 @@ struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { }; } -void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar) { - if (apply_grammar) { +void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) { + if (accept_grammar) { llama_sampler_accept(gsmpl->grmr, token); } @@ -293,9 +293,9 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_sampler_apply(chain, &cur_p); - const llama_token id = cur_p.data[cur_p.selected].id; + GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); - GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration"); + const llama_token id = cur_p.data[cur_p.selected].id; if (grammar_first) { return id; @@ -304,7 +304,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context // check if it the sampled token fits the grammar { llama_token_data single_token_data = { id, 1.0f, 0.0f }; - llama_token_data_array single_token_data_array = { &single_token_data, 1, LLAMA_TOKEN_NULL, false }; + llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false }; llama_sampler_apply(grmr, &single_token_data_array); @@ -324,7 +324,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_sampler_apply(chain, &cur_p); - GGML_ASSERT(cur_p.data[cur_p.selected].id != LLAMA_TOKEN_NULL && "null token in the sampling history - check your sampling configuration"); + GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); return cur_p.data[cur_p.selected].id; } diff --git a/common/sampling.h b/common/sampling.h index 5083f456f1f96..d88038204c89f 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -70,7 +70,7 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl); struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl); -void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool apply_grammar); +void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar); void gpt_sampler_reset (struct gpt_sampler * gsmpl); llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); diff --git a/include/llama.h b/include/llama.h index f735413894912..50c89c10fa2b2 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1068,6 +1068,7 @@ extern "C" { /// @param mu Maximum cross-entropy. 
This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. LLAMA_API struct llama_sampler * llama_sampler_init_mirostat( const struct llama_model * model, + uint32_t seed, float tau, float eta); @@ -1077,6 +1078,7 @@ extern "C" { /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. LLAMA_API struct llama_sampler * llama_sampler_init_mirostat_v2( + uint32_t seed, float tau, float eta); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 8ff52dd2decd1..0084fe0b7dc4c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -11,6 +11,17 @@ #include #include +static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng, std::vector & probs) { + probs.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + probs[i] = cur_p->data[i].p; + } + + std::discrete_distribution dist(probs.begin(), probs.end()); + + return dist(rng); +} + static void llama_log_softmax(float * array, size_t size) { float max_l = *std::max_element(array, array + size); float sum = 0.f; @@ -456,6 +467,8 @@ struct llama_sampler_context_dist { const uint32_t seed; std::mt19937 rng; + + std::vector probs; // work array }; static struct llama_sampler_i llama_sampler_dist_i = { @@ -463,15 +476,7 @@ static struct llama_sampler_i llama_sampler_dist_i = { /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_context_dist *) smpl->ctx; - std::vector probs; - probs.reserve(cur_p->size); - for (size_t i = 0; i < cur_p->size; ++i) { - probs.push_back(cur_p->data[i].p); - } - - std::discrete_distribution dist(probs.begin(), probs.end()); - - cur_p->selected = dist(ctx->rng); + cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs); }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { @@ -489,6 +494,7 @@ struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) { /* .ctx = */ new llama_sampler_context_dist { /* .seed = */ seed, /* .rng = */ std::mt19937(seed), + /* .probs = */ {}, }, }; } @@ -761,6 +767,8 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta, struct llama_sampler_context_mirostat { const struct llama_vocab * vocab; + const uint32_t seed; + const float tau; const float eta; @@ -768,28 +776,14 @@ struct llama_sampler_context_mirostat { float mu; - std::vector cur; + std::mt19937 rng; + + std::vector probs; }; static struct llama_sampler_i llama_sampler_mirostat_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; }, - /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; - - int32_t idx = -1; - for (size_t i = 0; i < ctx->cur.size(); ++i) { - if (ctx->cur[i].id == token) { - idx = i; - break; - } - } - - float observed_surprise = -log2f(ctx->cur[idx].p); - float e = observed_surprise - ctx->tau; - - // Update mu using the learning rate and error - ctx->mu = ctx->mu - ctx->eta * e; - }, + /* 
.accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; @@ -812,36 +806,44 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat); llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); + llama_sampler_softmax_impl(cur_p); - // remember the order to be able to compute the distance later when accepting the token - ctx->cur.resize(cur_p->size); - for (size_t i = 0; i < cur_p->size; ++i) { - ctx->cur[i] = cur_p->data[i]; - } + const int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + + cur_p->selected = idx; + + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; + + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; }, /* .reset = */ [](struct llama_sampler * smpl) { auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; + ctx->rng = std::mt19937(ctx->seed); }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx; - return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->tau, ctx->eta, ctx->m); + return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_mirostat *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, float tau, float eta, int32_t m) { +struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) { return new llama_sampler { /* .iface = */ &llama_sampler_mirostat_i, /* .ctx = */ new llama_sampler_context_mirostat { /* .vocab = */ &vocab, + /* .seed = */ seed, /* .tau = */ tau, /* .eta = */ eta, /* .m = */ m, /* .mu = */ 2.0f*tau, - /* .cur = */ {}, + /* .rng = */ std::mt19937(seed), + /* .probs = */ {}, }, }; } @@ -849,33 +851,21 @@ struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab // mirostat v2 struct llama_sampler_context_mirostat_v2 { + const uint32_t seed; + const float tau; const float eta; float mu; - std::vector cur; + std::mt19937 rng; + + std::vector probs; }; static struct llama_sampler_i llama_sampler_mirostat_v2_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; }, - /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; - - int32_t idx = -1; - for (size_t i = 0; i < ctx->cur.size(); ++i) { - if (ctx->cur[i].id == token) { - idx = i; - break; - } - } - - float observed_surprise = -log2f(ctx->cur[idx].p); - float e = observed_surprise - ctx->tau; - - // Update mu using the learning rate and error - ctx->mu = ctx->mu - ctx->eta * e; - }, + /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; @@ -893,33 +883,40 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { // Normalize the probabilities of the remaining words llama_sampler_softmax_impl(cur_p); - // remember the order to be able to compute the distance later when accepting the token - ctx->cur.resize(cur_p->size); - for (size_t i = 0; i < cur_p->size; ++i) { - ctx->cur[i] = cur_p->data[i]; - } + const 
int idx = llama_sample_dist(cur_p, ctx->rng, ctx->probs); + + cur_p->selected = idx; + + float observed_surprise = -log2f(cur_p->data[idx].p); + float e = observed_surprise - ctx->tau; + + // Update mu using the learning rate and error + ctx->mu = ctx->mu - ctx->eta * e; }, /* .reset = */ [](struct llama_sampler * smpl) { auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; + ctx->rng = std::mt19937(ctx->seed); }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx; - return llama_sampler_init_mirostat_v2_impl(ctx->tau, ctx->eta); + return llama_sampler_init_mirostat_v2_impl(ctx->seed, ctx->tau, ctx->eta); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_mirostat_v2 *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_mirostat_v2_impl(float tau, float eta) { +struct llama_sampler * llama_sampler_init_mirostat_v2_impl(uint32_t seed, float tau, float eta) { return new llama_sampler { /* .iface = */ &llama_sampler_mirostat_v2_i, /* .ctx = */ new llama_sampler_context_mirostat_v2 { - /* .tau = */ tau, - /* .eta = */ eta, - /* .mu = */ 2.0f*tau, - /* .cur = */ {}, + /* .seed = */ seed, + /* .tau = */ tau, + /* .eta = */ eta, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed), + /* .probs = */ {}, }, }; } @@ -1154,9 +1151,15 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl( static struct llama_sampler_i llama_sampler_chain_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, - /* .accept = */ [](struct llama_sampler * smpl, llama_token /*token*/) { + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { auto * chain = (llama_sampler_chain *) smpl->ctx; + time_meas tm(chain->t_sample_us, chain->params.no_timing); + + for (auto * smpl : chain->samplers) { + llama_sampler_accept_impl(*smpl, token); + } + chain->n_sample++; }, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 3f14ec621f5c1..0088060c8d971 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -58,11 +58,13 @@ struct llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta struct llama_sampler * llama_sampler_init_mirostat_impl( const struct llama_vocab & vocab, + uint32_t seed, float tau, float eta, int32_t m); struct llama_sampler * llama_sampler_init_mirostat_v2_impl( + uint32_t seed, float tau, float eta); diff --git a/src/llama.cpp b/src/llama.cpp index df12de7add9c1..ce920965834eb 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20676,12 +20676,12 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa return llama_sampler_init_temp_ext_impl(temp, delta, exponent); } -struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, float tau, float eta) { - return llama_sampler_init_mirostat_impl(model->vocab, tau, eta, 100); +struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta) { + return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, 100); } -struct llama_sampler * llama_sampler_init_mirostat_v2(float tau, float eta) { - return llama_sampler_init_mirostat_v2_impl(tau, eta); +struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { + return llama_sampler_init_mirostat_v2_impl(seed, tau, eta); } struct llama_sampler * 
llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index adc1ff4e6da7d..cc4882d37579a 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -35,7 +35,7 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & probs, const std::vector cur.emplace_back(llama_token_data{token_id, logit, 0.0f}); } - llama_token_data_array cur_p = { cur.data(), cur.size(), LLAMA_TOKEN_NULL, false }; + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; DUMP(&cur_p); APPLY(llama_sampler_init_tail_free(z, 1), &cur_p); DUMP(&cur_p); @@ -100,7 +100,7 @@ static void test_min_p(const std::vector & probs, const std::vector & probs, const std::vector Date: Thu, 5 Sep 2024 18:29:53 +0300 Subject: [PATCH 33/47] swift : fix example --- examples/llama.swiftui/llama.cpp.swift/LibLlama.swift | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 24cef348ce734..92f61fe83081d 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -24,7 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama actor LlamaContext { private var model: OpaquePointer private var context: OpaquePointer - private var sampling: llama_sampler + private var sampling: UnsafeMutablePointer private var batch: llama_batch private var tokens_list: [llama_token] var is_done: Bool = false @@ -43,8 +43,11 @@ actor LlamaContext { self.tokens_list = [] self.batch = llama_batch_init(512, 0, 1) self.temporary_invalid_cchars = [] - var sparams = llama_sampler_chain_default_params() + let sparams = llama_sampler_chain_default_params() self.sampling = llama_sampler_chain_init(sparams) + llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4)) + llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax()) + llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234)) } deinit { From 8c972b69c167a10423c38c2e7f89f11e56d2cddf Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 11:58:11 +0300 Subject: [PATCH 34/47] grammar : restore llama_grammar_accept signature ggml-ci --- examples/gbnf-validator/gbnf-validator.cpp | 12 ++++++------ src/llama-grammar.cpp | 19 ++++++++++--------- src/llama-grammar.h | 5 +++-- tests/test-grammar-integration.cpp | 18 +++++++++--------- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index f439c0c5648a8..7493af9d3aec3 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -12,24 +12,24 @@ static bool llama_grammar_validate(struct llama_grammar * grammar, const std::st const auto cpts = unicode_cpts_from_utf8(input_str); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); + llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); size_t pos = 0; for (const auto & cpt : cpts) { - const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy + const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy - cur_stacks = llama_grammar_accept(rules, prev_stacks, 
cpt); + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); - if (cur_stacks.empty()) { + if (stacks_cur.empty()) { error_pos = pos; error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'"; - cur_stacks = prev_stacks; + stacks_cur = stacks_prev; return false; } ++pos; } - for (const auto & stack : cur_stacks) { + for (const auto & stack : stacks_cur) { if (stack.empty()) { return true; } diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 353cb398a27a5..74e9f64b393b2 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -822,12 +822,13 @@ llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) return grammar->stacks; } -llama_grammar_stacks llama_grammar_accept( +void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, - const uint32_t chr) { - llama_grammar_stacks result; - result.reserve(stacks.size()); + const uint32_t chr, + llama_grammar_stacks & stacks_new) { + stacks_new.clear(); + stacks_new.reserve(stacks.size()); for (const auto & stack : stacks) { if (stack.empty()) { @@ -843,11 +844,9 @@ llama_grammar_stacks llama_grammar_accept( if (!llama_grammar_is_end_of_sequence(pos)) { new_stack.push_back(pos); } - llama_grammar_advance_stack(rules, new_stack, result); + llama_grammar_advance_stack(rules, new_stack, stacks_new); } } - - return result; } llama_grammar_candidates llama_grammar_reject_candidates_for_stack( @@ -1127,9 +1126,11 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; + llama_grammar_stacks stacks_new; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_stacks new_stacks = llama_grammar_accept(grammar.rules, grammar.stacks, *it); - grammar.stacks = std::move(new_stacks); + llama_grammar_accept(grammar.rules, grammar.stacks, *it, stacks_new); + grammar.stacks = std::move(stacks_new); } grammar.partial_utf8 = decoded.second; diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 419a616d644dc..f529ce351e416 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -65,10 +65,11 @@ const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar // be positioned at a character range (see `llama_grammar_advance_stack`), and // produces the N possible stacks if the given char is accepted at those // positions -llama_grammar_stacks llama_grammar_accept( +void llama_grammar_accept( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, - uint32_t chr); + uint32_t chr, + llama_grammar_stacks & stacks_new); std::vector llama_grammar_reject_candidates_for_stack( const llama_grammar_rules & rules, diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 788b02a6a5cd8..5cc0cdb04751f 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -33,20 +33,20 @@ static bool match_string(const std::string & input, llama_grammar * grammar) { const auto cpts = unicode_cpts_from_utf8(input); const llama_grammar_rules & rules = llama_grammar_get_rules (grammar); - llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); + llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); for (const auto & cpt : cpts) { - const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy + const llama_grammar_stacks stacks_prev = 
llama_grammar_get_stacks(grammar); // copy - cur_stacks = llama_grammar_accept(rules, prev_stacks, cpt); + llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur); - if (cur_stacks.empty()) { + if (stacks_cur.empty()) { // no stacks means that the grammar failed to match at this point return false; } } - for (const auto & stack : cur_stacks) { + for (const auto & stack : stacks_cur) { if (stack.empty()) { // An empty stack means that the grammar has been completed return true; @@ -63,9 +63,9 @@ static void test(const std::string & test_desc, const std::string & grammar_str, auto * grammar = build_grammar(grammar_str); // Save the original grammar stacks so that we can reset after every new string we want to test - const llama_grammar_stacks original_stacks = llama_grammar_get_stacks(grammar); + const llama_grammar_stacks stacks_org = llama_grammar_get_stacks(grammar); - llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar); + llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar); fprintf(stderr, " 🔵 Valid strings:\n"); @@ -102,7 +102,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, assert(matched); // Reset the grammar stacks - cur_stacks = original_stacks; + stacks_cur = stacks_org; } fprintf(stderr, " 🟠 Invalid strings:\n"); @@ -122,7 +122,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, assert(!matched); // Reset the grammar stacks - cur_stacks = original_stacks; + stacks_cur = stacks_org; } // Clean up allocated memory From 809bdcf76778ad7af0df257dc730aedb0f18a321 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 12:06:00 +0300 Subject: [PATCH 35/47] sampling : allow passing m to mirostat sampler --- common/sampling.cpp | 2 +- include/llama.h | 3 ++- src/llama.cpp | 4 ++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index cf3ee98d4c744..9964501da7ccd 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -206,7 +206,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta, 100)); } else if (params.mirostat == 2) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); diff --git a/include/llama.h b/include/llama.h index 50c89c10fa2b2..02c565a3d6bc5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1070,7 +1070,8 @@ extern "C" { const struct llama_model * model, uint32_t seed, float tau, - float eta); + float eta, + int32_t m); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. 
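A rough usage sketch of the seeded mirostat constructor declared above (not taken from the patch itself): `model` and `ctx` are assumed to be an already initialized llama_model / llama_context, index -1 is assumed to request the logits of the last decoded token, and the values 0.8 / 1234 / 5.0 / 0.1 / 100 are customary defaults rather than anything required by the API.

    struct llama_sampler * smpl = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(smpl, llama_sampler_init_temp(0.8f));
    llama_sampler_chain_add(smpl, llama_sampler_init_mirostat(model, 1234, 5.0f, 0.1f, 100));

    // sample from the last token's logits; mu is updated inside apply, so no extra bookkeeping is needed
    const llama_token id = llama_sampler_sample(smpl, ctx, -1);

    // propagate the accepted token to the samplers in the chain (updates counters/stateful samplers)
    llama_sampler_accept(smpl, id);

    llama_sampler_free(smpl);
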
diff --git a/src/llama.cpp b/src/llama.cpp index ce920965834eb..db50a0332e4c4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20676,8 +20676,8 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa return llama_sampler_init_temp_ext_impl(temp, delta, exponent); } -struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta) { - return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, 100); +struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) { + return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m); } struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { From b448c753b939a515fa0afa74ad05e89a8efca1e9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 12:39:43 +0300 Subject: [PATCH 36/47] sampling : remove redundant indirection calls ggml-ci --- include/llama.h | 12 +- src/llama-sampling.cpp | 341 ++++++++++++++++++++++------------------- src/llama-sampling.h | 32 +--- src/llama.cpp | 112 +------------- 4 files changed, 193 insertions(+), 304 deletions(-) diff --git a/include/llama.h b/include/llama.h index 02c565a3d6bc5..29c216f2d9581 100644 --- a/include/llama.h +++ b/include/llama.h @@ -992,9 +992,9 @@ extern "C" { // // Sampling API // - // In the future, it will be utilized offload the sampling to the backends (e.g. GPU). + // In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). // - // TODO: in the future, the entire API should be changed to accept llama_vocab, instead of llama_model + // TODO: in the future, the entire API that uses llama_model should start using llama_vocab typedef void * llama_sampler_context_t; @@ -1045,16 +1045,16 @@ extern "C" { LLAMA_API struct llama_sampler * llama_sampler_init_top_k (int32_t k); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_top_p (float p, size_t min_keep); /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_min_p (float p, size_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_tail_free (float z, size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, int32_t min_keep); + LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. 
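A minimal sketch of a chain built from the constructors declared in the header above (illustrative, not from the patch): `ctx` is assumed to be an existing llama_context, index -1 is assumed to refer to the logits of the last decoded token, and the parameter values (40, 0.9, 0.4, seed 1234) are only examples; softmax is placed before dist so that the probabilities it samples from are populated.

    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_top_p(0.9f, 1));
    llama_sampler_chain_add(chain, llama_sampler_init_temp (0.4f));
    llama_sampler_chain_add(chain, llama_sampler_init_softmax());
    llama_sampler_chain_add(chain, llama_sampler_init_dist (1234));

    // sample from the last token's logits, then report the accepted token back to the chain
    const llama_token id = llama_sampler_sample(chain, ctx, -1);
    llama_sampler_accept(chain, id);

    llama_sampler_free(chain);   // frees every sampler that was added to the chain
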
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 0084fe0b7dc4c..92d39222d3c03 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -432,6 +432,167 @@ void llama_sampler_penalties_impl( cur_p->sorted = false; } +// llama_sampler API + +const char * llama_sampler_name(const struct llama_sampler * smpl) { + if (!smpl->iface) { + return "(null)"; + } + + return smpl->iface->name(smpl); +} + +void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { + if (smpl->iface->accept) { + smpl->iface->accept(smpl, token); + } +} + +void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) { + GGML_ASSERT(smpl->iface->apply); + smpl->iface->apply(smpl, cur_p); +} + +void llama_sampler_reset(struct llama_sampler * smpl) { + if (smpl->iface->reset) { + smpl->iface->reset(smpl); + } +} + +struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { + if (smpl->iface->clone) { + return smpl->iface->clone(smpl); + } + + if (smpl->ctx == nullptr) { + return new llama_sampler { + /* .iface = */ smpl->iface, + /* .ctx = */ nullptr, + }; + } + + GGML_ABORT("the sampler does not support cloning"); +} + +void llama_sampler_free(struct llama_sampler * smpl) { + if (smpl == nullptr) { + return; + } + + if (smpl->iface->free) { + smpl->iface->free(smpl); + } + + delete smpl; +} + +llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { + const auto * logits = llama_get_logits_ith(ctx, idx); + + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + + // TODO: do not allocate each time + std::vector cur(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } + + llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; + + llama_sampler_apply(smpl, &cur_p); + + return cur_p.data[cur_p.selected].id; +} + +// sampler chain + +static struct llama_sampler_i llama_sampler_chain_i = { + /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, + /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_timing); + + for (auto * smpl : chain->samplers) { + llama_sampler_accept(smpl, token); + } + + chain->n_sample++; + }, + /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + time_meas tm(chain->t_sample_us, chain->params.no_timing); + + for (auto * smpl : chain->samplers) { + llama_sampler_apply(smpl, cur_p); + } + }, + /* .reset = */ [](struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_reset(smpl); + } + + chain->t_sample_us = 0; + chain->n_sample = 0; + }, + /* .clone = */ [](const struct llama_sampler * smpl) { + const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; + + auto * result = llama_sampler_chain_init(chain_src->params); + + for (auto * smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + } + + return result; + }, + /* .free = */ [](struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto * smpl : chain->samplers) { + llama_sampler_free(smpl); + } + + delete chain; + }, +}; + +struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params 
params) { + return new llama_sampler { + /* .iface = */ &llama_sampler_chain_i, + /* .ctx = */ new llama_sampler_chain { + /* .params = */ params, + /* .samplers = */ {}, + /* .t_sample_us = */ 0, + /* .n_sample = */ 0, + }, + }; +} + +void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { + auto * p = (llama_sampler_chain *) chain->ctx; + p->samplers.push_back(smpl); +} + +struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + if (i < 0 || i >= (int32_t) p->samplers.size()) { + return nullptr; + } + + return p->samplers[i]; +} + +int llama_sampler_chain_n(const struct llama_sampler * chain) { + const auto * p = (const llama_sampler_chain *) chain->ctx; + + return p->samplers.size(); +} + // // samplers // @@ -454,7 +615,7 @@ static struct llama_sampler_i llama_sampler_greedy_i = { /* .free = */ nullptr, }; -struct llama_sampler * llama_sampler_init_greedy_impl() { +struct llama_sampler * llama_sampler_init_greedy() { return new llama_sampler { /* .iface = */ &llama_sampler_greedy_i, /* .ctx = */ nullptr, @@ -481,14 +642,14 @@ static struct llama_sampler_i llama_sampler_dist_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_dist *) smpl->ctx; - return llama_sampler_init_dist_impl(ctx->seed); + return llama_sampler_init_dist(ctx->seed); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_dist *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_dist_impl(uint32_t seed) { +struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { return new llama_sampler { /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_context_dist { @@ -512,7 +673,7 @@ static struct llama_sampler_i llama_sampler_softmax_i = { /* .free = */ nullptr, }; -struct llama_sampler * llama_sampler_init_softmax_impl() { +struct llama_sampler * llama_sampler_init_softmax() { return new llama_sampler { /* .iface = */ &llama_sampler_softmax_i, /* .ctx = */ nullptr, @@ -535,14 +696,14 @@ static struct llama_sampler_i llama_sampler_top_k_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_top_k *) smpl->ctx; - return llama_sampler_init_top_k_impl(ctx->k); + return llama_sampler_init_top_k(ctx->k); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_top_k *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_top_k_impl(int32_t k) { +struct llama_sampler * llama_sampler_init_top_k(int32_t k) { return new llama_sampler { /* .iface = */ &llama_sampler_top_k_i, /* .ctx = */ new llama_sampler_context_top_k { @@ -568,14 +729,14 @@ static struct llama_sampler_i llama_sampler_top_p_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_top_p *) smpl->ctx; - return llama_sampler_init_top_p_impl(ctx->p, ctx->min_keep); + return llama_sampler_init_top_p(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_top_p *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_top_p_impl(float p, size_t min_keep) { +struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_top_p_i, /* .ctx = */ new llama_sampler_context_top_p { @@ -602,14 +763,14 @@ 
static struct llama_sampler_i llama_sampler_min_p_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_min_p *) smpl->ctx; - return llama_sampler_init_min_p_impl(ctx->p, ctx->min_keep); + return llama_sampler_init_min_p(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_min_p *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_min_p_impl(float p, size_t min_keep) { +struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_min_p_i, /* .ctx = */ new llama_sampler_context_min_p { @@ -636,14 +797,14 @@ static struct llama_sampler_i llama_sampler_tail_free_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_tail_free *) smpl->ctx; - return llama_sampler_init_tail_free_impl(ctx->z, ctx->min_keep); + return llama_sampler_init_tail_free(ctx->z, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_tail_free *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep) { +struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_tail_free_i, /* .ctx = */ new llama_sampler_context_tail_free { @@ -670,14 +831,14 @@ static struct llama_sampler_i llama_sampler_typical_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_typical *) smpl->ctx; - return llama_sampler_init_typical_impl(ctx->p, ctx->min_keep); + return llama_sampler_init_typical(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_typical *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_typical_impl(float p, size_t min_keep) { +struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_typical_i, /* .ctx = */ new llama_sampler_context_typical { @@ -703,14 +864,14 @@ static struct llama_sampler_i llama_sampler_temp_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_temp *) smpl->ctx; - return llama_sampler_init_temp_impl(ctx->temp); + return llama_sampler_init_temp(ctx->temp); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_temp *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_temp_impl(float temp) { +struct llama_sampler * llama_sampler_init_temp(float temp) { return new llama_sampler { /* .iface = */ &llama_sampler_temp_i, /* .ctx = */ new llama_sampler_context_temp { @@ -744,14 +905,14 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_temp_ext *) smpl->ctx; - return llama_sampler_init_temp_ext_impl(ctx->temp, ctx->delta, ctx->exponent); + return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_temp_ext *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_temp_ext_impl(float temp, float delta, float exponent) { +struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float 
exponent) { return new llama_sampler { /* .iface = */ &llama_sampler_temp_ext_i, /* .ctx = */ new llama_sampler_context_temp_ext { @@ -900,14 +1061,14 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx; - return llama_sampler_init_mirostat_v2_impl(ctx->seed, ctx->tau, ctx->eta); + return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_context_mirostat_v2 *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_mirostat_v2_impl(uint32_t seed, float tau, float eta) { +struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { return new llama_sampler { /* .iface = */ &llama_sampler_mirostat_v2_i, /* .ctx = */ new llama_sampler_context_mirostat_v2 { @@ -1146,143 +1307,3 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl( }, }; } - -// sampler chain - -static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "chain"; }, - /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_timing); - - for (auto * smpl : chain->samplers) { - llama_sampler_accept_impl(*smpl, token); - } - - chain->n_sample++; - }, - /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - time_meas tm(chain->t_sample_us, chain->params.no_timing); - - for (auto * smpl : chain->samplers) { - llama_sampler_apply_impl(*smpl, cur_p); - } - }, - /* .reset = */ [](struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_reset_impl(*smpl); - } - - chain->t_sample_us = 0; - chain->n_sample = 0; - }, - /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * chain_src = (const llama_sampler_chain *) smpl->ctx; - - auto * result = llama_sampler_chain_init_impl(chain_src->params); - - auto * chain_dst = (llama_sampler_chain *) result->ctx; - for (auto * smpl : chain_src->samplers) { - llama_sampler_chain_add_impl(*chain_dst, llama_sampler_clone_impl(*smpl)); - } - - return result; - }, - /* .free = */ [](struct llama_sampler * smpl) { - auto * chain = (llama_sampler_chain *) smpl->ctx; - - for (auto * smpl : chain->samplers) { - llama_sampler_free_impl(smpl); - } - - delete chain; - }, -}; - -struct llama_sampler * llama_sampler_chain_init_impl(struct llama_sampler_chain_params params) { - return new llama_sampler { - /* .iface = */ &llama_sampler_chain_i, - /* .ctx = */ new llama_sampler_chain { - /* .params = */ params, - /* .samplers = */ {}, - /* .t_sample_us = */ 0, - /* .n_sample = */ 0, - }, - }; -} - -void llama_sampler_chain_add_impl(struct llama_sampler_chain & chain, struct llama_sampler * smpl) { - chain.samplers.push_back(smpl); -} - -struct llama_sampler * llama_sampler_chain_get_impl(const struct llama_sampler_chain & chain, int32_t i) { - if (i < 0 || i >= (int32_t) chain.samplers.size()) { - return nullptr; - } - - return chain.samplers[i]; -} - -int llama_sampler_chain_n_impl(const struct llama_sampler_chain & chain) { - return chain.samplers.size(); -} - - -//////////////////////////////////////// - -const char * llama_sampler_name_impl(const struct llama_sampler & smpl) { - if 
(!smpl.iface) { - return "(null)"; - } - - return smpl.iface->name(&smpl); -} - -void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) { - if (smpl.iface->accept) { - smpl.iface->accept(&smpl, token); - } -} - -void llama_sampler_apply_impl(struct llama_sampler & smpl, struct llama_token_data_array * cur_p) { - GGML_ASSERT(smpl.iface->apply); - smpl.iface->apply(&smpl, cur_p); -} - -void llama_sampler_reset_impl(struct llama_sampler & smpl) { - if (smpl.iface->reset) { - smpl.iface->reset(&smpl); - } -} - -struct llama_sampler * llama_sampler_clone_impl(const struct llama_sampler & smpl) { - if (smpl.iface->clone) { - return smpl.iface->clone(&smpl); - } - - if (smpl.ctx == nullptr) { - return new llama_sampler { - /* .iface = */ smpl.iface, - /* .ctx = */ nullptr, - }; - } - - GGML_ABORT("the sampler does not support cloning"); -} - -void llama_sampler_free_impl(struct llama_sampler * smpl) { - if (smpl == nullptr) { - return; - } - - if (smpl->iface->free) { - smpl->iface->free(smpl); - } - - delete smpl; -} diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 0088060c8d971..05bb294a10d2f 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -7,15 +7,6 @@ struct llama_vocab; struct llama_grammar; -// samplers - -const char * llama_sampler_name_impl (const struct llama_sampler & smpl); -void llama_sampler_accept_impl( struct llama_sampler & smpl, llama_token token); -void llama_sampler_apply_impl ( struct llama_sampler & smpl, struct llama_token_data_array * cur_p); -void llama_sampler_reset_impl ( struct llama_sampler & smpl); -struct llama_sampler * llama_sampler_clone_impl (const struct llama_sampler & smpl); -void llama_sampler_free_impl ( struct llama_sampler * smpl); - // sampler chain struct llama_sampler_chain { @@ -30,11 +21,6 @@ struct llama_sampler_chain { mutable int32_t n_sample; }; -struct llama_sampler * llama_sampler_chain_init_impl( struct llama_sampler_chain_params params); -void llama_sampler_chain_add_impl ( struct llama_sampler_chain & chain, struct llama_sampler * smpl); -struct llama_sampler * llama_sampler_chain_get_impl (const struct llama_sampler_chain & chain, int32_t i); -int llama_sampler_chain_n_impl (const struct llama_sampler_chain & chain); - using llama_token_cnt = std::unordered_map; // TODO: tmp exposed until test-sampling is fixed @@ -45,17 +31,6 @@ void llama_sampler_penalties_impl( float penalty_freq, float penalty_present); -struct llama_sampler * llama_sampler_init_greedy_impl (); -struct llama_sampler * llama_sampler_init_dist_impl (uint32_t seed); -struct llama_sampler * llama_sampler_init_softmax_impl (); -struct llama_sampler * llama_sampler_init_top_k_impl (int32_t k); -struct llama_sampler * llama_sampler_init_top_p_impl (float p, size_t min_keep); -struct llama_sampler * llama_sampler_init_min_p_impl (float p, size_t min_keep); -struct llama_sampler * llama_sampler_init_tail_free_impl(float z, size_t min_keep); -struct llama_sampler * llama_sampler_init_typical_impl (float p, size_t min_keep); -struct llama_sampler * llama_sampler_init_temp_impl (float t); -struct llama_sampler * llama_sampler_init_temp_ext_impl (float t, float delta, float exponent); - struct llama_sampler * llama_sampler_init_mirostat_impl( const struct llama_vocab & vocab, uint32_t seed, @@ -63,11 +38,6 @@ struct llama_sampler * llama_sampler_init_mirostat_impl( float eta, int32_t m); -struct llama_sampler * llama_sampler_init_mirostat_v2_impl( - uint32_t seed, - float tau, - float eta); - struct llama_sampler * 
llama_sampler_init_grammar_impl( const struct llama_vocab & vocab, const char * grammar_str, @@ -82,7 +52,7 @@ struct llama_sampler * llama_sampler_init_penalties_impl( bool penalize_nl, bool ignore_eos); - LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl( +LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl( const struct llama_vocab & vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias); diff --git a/src/llama.cpp b/src/llama.cpp index db50a0332e4c4..f5e01004f1894 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20592,102 +20592,17 @@ int32_t llama_chat_apply_template( // sampling // -const char * llama_sampler_name(const struct llama_sampler * smpl) { - return llama_sampler_name_impl(*smpl); -} - -void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) { - llama_sampler_accept_impl(*smpl, token); -} - -void llama_sampler_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { - llama_sampler_apply_impl(*smpl, cur_p); -} - -void llama_sampler_reset(struct llama_sampler * smpl) { - llama_sampler_reset_impl(*smpl); -} - -struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) { - return llama_sampler_clone_impl(*smpl); -} - -void llama_sampler_free(struct llama_sampler * smpl) { - if (smpl == nullptr) { - return; - } - - llama_sampler_free_impl(smpl); -} - -struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { - return llama_sampler_chain_init_impl(params); -} - -void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { - llama_sampler_chain_add_impl(*(struct llama_sampler_chain *) chain->ctx, smpl); -} - -struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { - return llama_sampler_chain_get_impl(*(const struct llama_sampler_chain *) chain->ctx, i); -} - -int llama_sampler_chain_n(const struct llama_sampler * chain) { - return llama_sampler_chain_n_impl(*(const struct llama_sampler_chain *) chain->ctx); -} - -struct llama_sampler * llama_sampler_init_greedy(void) { - return llama_sampler_init_greedy_impl(); -} - -struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { - return llama_sampler_init_dist_impl(seed); -} - -struct llama_sampler * llama_sampler_init_softmax(void) { - return llama_sampler_init_softmax_impl(); -} - -struct llama_sampler * llama_sampler_init_top_k(int32_t k) { - return llama_sampler_init_top_k_impl(k); -} - -struct llama_sampler * llama_sampler_init_top_p(float p, int32_t min_keep) { - return llama_sampler_init_top_p_impl(p, min_keep); -} - -struct llama_sampler * llama_sampler_init_min_p(float p, int32_t min_keep) { - return llama_sampler_init_min_p_impl(p, min_keep); -} - -struct llama_sampler * llama_sampler_init_tail_free(float z, int32_t min_keep) { - return llama_sampler_init_tail_free_impl(z, min_keep); -} - -struct llama_sampler * llama_sampler_init_typical(float p, int32_t min_keep) { - return llama_sampler_init_typical_impl(p, min_keep); -} - -struct llama_sampler * llama_sampler_init_temp(float temp) { - return llama_sampler_init_temp_impl(temp); -} - -struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { - return llama_sampler_init_temp_ext_impl(temp, delta, exponent); -} - +// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) { return 
llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m); } -struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { - return llama_sampler_init_mirostat_v2_impl(seed, tau, eta); -} - +// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root); } +// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp struct llama_sampler * llama_sampler_init_penalties( const struct llama_model * model, int32_t penalty_last_n, @@ -20699,31 +20614,14 @@ struct llama_sampler * llama_sampler_init_penalties( return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); } -LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( +// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp +struct llama_sampler * llama_sampler_init_logit_bias( const struct llama_model * model, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); } -llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); - - const int n_vocab = llama_n_vocab(llama_get_model(ctx)); - - // TODO: do not allocate each time - std::vector cur(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } - - llama_token_data_array cur_p = { cur.data(), cur.size(), -1, false }; - - llama_sampler_apply(smpl, &cur_p); - - return cur_p.data[cur_p.selected].id; -} - // // model split // From 5ab52c1f64d5d632a0c4a0507c123221bf4b2899 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 12:52:19 +0300 Subject: [PATCH 37/47] sampling : remove _context suffix [no ci] --- src/llama-sampling.cpp | 146 ++++++++++++++++++++--------------------- 1 file changed, 73 insertions(+), 73 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 92d39222d3c03..2aa4981cebc18 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -624,7 +624,7 @@ struct llama_sampler * llama_sampler_init_greedy() { // dist -struct llama_sampler_context_dist { +struct llama_sampler_dist { const uint32_t seed; std::mt19937 rng; @@ -636,23 +636,23 @@ static struct llama_sampler_i llama_sampler_dist_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "dist"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_context_dist *) smpl->ctx; + auto * ctx = (llama_sampler_dist *) smpl->ctx; cur_p->selected = llama_sample_dist(cur_p, ctx->rng, ctx->probs); }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_dist *) smpl->ctx; + const auto * ctx = (const llama_sampler_dist *) smpl->ctx; return llama_sampler_init_dist(ctx->seed); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_dist *) smpl->ctx; + delete (llama_sampler_dist *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { return new llama_sampler { /* .iface = */ 
&llama_sampler_dist_i, - /* .ctx = */ new llama_sampler_context_dist { + /* .ctx = */ new llama_sampler_dist { /* .seed = */ seed, /* .rng = */ std::mt19937(seed), /* .probs = */ {}, @@ -682,7 +682,7 @@ struct llama_sampler * llama_sampler_init_softmax() { // top-k -struct llama_sampler_context_top_k { +struct llama_sampler_top_k { const int32_t k; }; @@ -690,23 +690,23 @@ static struct llama_sampler_i llama_sampler_top_k_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-k"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_context_top_k *) smpl->ctx; + const auto * ctx = (llama_sampler_top_k *) smpl->ctx; llama_sampler_top_k_impl(cur_p, ctx->k); }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_top_k *) smpl->ctx; + const auto * ctx = (const llama_sampler_top_k *) smpl->ctx; return llama_sampler_init_top_k(ctx->k); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_top_k *) smpl->ctx; + delete (llama_sampler_top_k *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { return new llama_sampler { /* .iface = */ &llama_sampler_top_k_i, - /* .ctx = */ new llama_sampler_context_top_k { + /* .ctx = */ new llama_sampler_top_k { /* .k = */ k, }, }; @@ -714,7 +714,7 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) { // top-p -struct llama_sampler_context_top_p { +struct llama_sampler_top_p { const float p; const size_t min_keep; }; @@ -723,23 +723,23 @@ static struct llama_sampler_i llama_sampler_top_p_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "top-p"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_context_top_p *) smpl->ctx; + const auto * ctx = (llama_sampler_top_p *) smpl->ctx; llama_sampler_top_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_top_p *) smpl->ctx; + const auto * ctx = (const llama_sampler_top_p *) smpl->ctx; return llama_sampler_init_top_p(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_top_p *) smpl->ctx; + delete (llama_sampler_top_p *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_top_p_i, - /* .ctx = */ new llama_sampler_context_top_p { + /* .ctx = */ new llama_sampler_top_p { /* .p = */ p, /* .min_keep = */ min_keep, }, @@ -748,7 +748,7 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { // min-p -struct llama_sampler_context_min_p { +struct llama_sampler_min_p { const float p; const size_t min_keep; }; @@ -757,23 +757,23 @@ static struct llama_sampler_i llama_sampler_min_p_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "min-p"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_context_min_p *) smpl->ctx; + const auto * ctx = (llama_sampler_min_p *) smpl->ctx; llama_sampler_min_p_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_min_p *) 
smpl->ctx; + const auto * ctx = (const llama_sampler_min_p *) smpl->ctx; return llama_sampler_init_min_p(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_min_p *) smpl->ctx; + delete (llama_sampler_min_p *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_min_p_i, - /* .ctx = */ new llama_sampler_context_min_p { + /* .ctx = */ new llama_sampler_min_p { /* .p = */ p, /* .min_keep = */ min_keep, }, @@ -782,7 +782,7 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { // tail-free -struct llama_sampler_context_tail_free { +struct llama_sampler_tail_free { const float z; const size_t min_keep; }; @@ -791,23 +791,23 @@ static struct llama_sampler_i llama_sampler_tail_free_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "tail-free"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_context_tail_free *) smpl->ctx; + const auto * ctx = (llama_sampler_tail_free *) smpl->ctx; llama_sampler_tail_free_impl(cur_p, ctx->z, ctx->min_keep); }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_tail_free *) smpl->ctx; + const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx; return llama_sampler_init_tail_free(ctx->z, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_tail_free *) smpl->ctx; + delete (llama_sampler_tail_free *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_tail_free_i, - /* .ctx = */ new llama_sampler_context_tail_free { + /* .ctx = */ new llama_sampler_tail_free { /* .z = */ z, /*. 
min_keep = */ min_keep, }, @@ -816,7 +816,7 @@ struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) { // typical -struct llama_sampler_context_typical { +struct llama_sampler_typical { const float p; const size_t min_keep; }; @@ -825,23 +825,23 @@ static struct llama_sampler_i llama_sampler_typical_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "typical"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_context_typical *) smpl->ctx; + const auto * ctx = (llama_sampler_typical *) smpl->ctx; llama_sampler_typical_impl(cur_p, ctx->p, ctx->min_keep); }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_typical *) smpl->ctx; + const auto * ctx = (const llama_sampler_typical *) smpl->ctx; return llama_sampler_init_typical(ctx->p, ctx->min_keep); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_typical *) smpl->ctx; + delete (llama_sampler_typical *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { return new llama_sampler { /* .iface = */ &llama_sampler_typical_i, - /* .ctx = */ new llama_sampler_context_typical { + /* .ctx = */ new llama_sampler_typical { /* .p = */ p, /* .min_keep = */ min_keep, }, @@ -850,7 +850,7 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { // temp -struct llama_sampler_context_temp { +struct llama_sampler_temp { const float temp; }; @@ -858,23 +858,23 @@ static struct llama_sampler_i llama_sampler_temp_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_context_temp *) smpl->ctx; + const auto * ctx = (llama_sampler_temp *) smpl->ctx; llama_sampler_temp_impl(cur_p, ctx->temp); }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_temp *) smpl->ctx; + const auto * ctx = (const llama_sampler_temp *) smpl->ctx; return llama_sampler_init_temp(ctx->temp); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_temp *) smpl->ctx; + delete (llama_sampler_temp *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_temp(float temp) { return new llama_sampler { /* .iface = */ &llama_sampler_temp_i, - /* .ctx = */ new llama_sampler_context_temp { + /* .ctx = */ new llama_sampler_temp { /*.temp = */ temp, }, }; @@ -882,7 +882,7 @@ struct llama_sampler * llama_sampler_init_temp(float temp) { // temp-ext -struct llama_sampler_context_temp_ext { +struct llama_sampler_temp_ext { const float temp; const float delta; const float exponent; @@ -892,7 +892,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "temp-ext"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_context_temp_ext *) smpl->ctx; + const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx; if (ctx->delta > 0) { const float temp_min = std::max(0.0f, ctx->temp - ctx->delta); const float temp_max = ctx->temp + ctx->delta; @@ -904,18 +904,18 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = { }, /* .reset = */ nullptr, /* .clone = */ [](const 
struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_temp_ext *) smpl->ctx; + const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx; return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_temp_ext *) smpl->ctx; + delete (llama_sampler_temp_ext *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { return new llama_sampler { /* .iface = */ &llama_sampler_temp_ext_i, - /* .ctx = */ new llama_sampler_context_temp_ext { + /* .ctx = */ new llama_sampler_temp_ext { /* .temp = */ temp, /* .delta = */ delta, /* .exponent = */ exponent, @@ -925,7 +925,7 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa // mirostat -struct llama_sampler_context_mirostat { +struct llama_sampler_mirostat { const struct llama_vocab * vocab; const uint32_t seed; @@ -946,7 +946,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; llama_sampler_softmax_impl(cur_p); @@ -980,23 +980,23 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { ctx->mu = ctx->mu - ctx->eta * e; }, /* .reset = */ [](struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_context_mirostat *) smpl->ctx; + auto * ctx = (llama_sampler_mirostat *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; ctx->rng = std::mt19937(ctx->seed); }, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_mirostat *) smpl->ctx; + const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_mirostat *) smpl->ctx; + delete (llama_sampler_mirostat *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) { return new llama_sampler { /* .iface = */ &llama_sampler_mirostat_i, - /* .ctx = */ new llama_sampler_context_mirostat { + /* .ctx = */ new llama_sampler_mirostat { /* .vocab = */ &vocab, /* .seed = */ seed, /* .tau = */ tau, @@ -1011,7 +1011,7 @@ struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab // mirostat v2 -struct llama_sampler_context_mirostat_v2 { +struct llama_sampler_mirostat_v2 { const uint32_t seed; const float tau; @@ -1028,7 +1028,7 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "mirostat-v2"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; llama_sampler_softmax_impl(cur_p); @@ -1055,23 +1055,23 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { ctx->mu = ctx->mu - ctx->eta * e; }, /* .reset = */ [](struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_context_mirostat_v2 *) smpl->ctx; + auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx; ctx->mu = 2.0f*ctx->tau; ctx->rng = 
std::mt19937(ctx->seed); }, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx = (const llama_sampler_context_mirostat_v2 *) smpl->ctx; + const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_mirostat_v2 *) smpl->ctx; + delete (llama_sampler_mirostat_v2 *) smpl->ctx; }, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { return new llama_sampler { /* .iface = */ &llama_sampler_mirostat_v2_i, - /* .ctx = */ new llama_sampler_context_mirostat_v2 { + /* .ctx = */ new llama_sampler_mirostat_v2 { /* .seed = */ seed, /* .tau = */ tau, /* .eta = */ eta, @@ -1084,7 +1084,7 @@ struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, // grammar -struct llama_sampler_context_grammar { +struct llama_sampler_grammar { const struct llama_vocab * vocab; std::string grammar_str; @@ -1096,19 +1096,19 @@ struct llama_sampler_context_grammar { static struct llama_sampler_i llama_sampler_grammar_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "grammar"; }, /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; + const auto * ctx = (llama_sampler_grammar *) smpl->ctx; if (ctx->grammar) { llama_grammar_accept_impl(*ctx->grammar, token); } }, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; + const auto * ctx = (llama_sampler_grammar *) smpl->ctx; if (ctx->grammar) { llama_sampler_grammar_impl(cur_p, *ctx->grammar); } }, /* .reset = */ [](struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; + auto * ctx = (llama_sampler_grammar *) smpl->ctx; if (!ctx->grammar) { return; } @@ -1119,11 +1119,11 @@ static struct llama_sampler_i llama_sampler_grammar_i = { ctx->grammar = grammar_new; }, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx_src = (const llama_sampler_context_grammar *) smpl->ctx; + const auto * ctx_src = (const llama_sampler_grammar *) smpl->ctx; auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); - auto * ctx_dst = (llama_sampler_context_grammar *) result->ctx; + auto * ctx_dst = (llama_sampler_grammar *) result->ctx; if (ctx_src->grammar) { ctx_dst->grammar_str = ctx_src->grammar_str; ctx_dst->grammar_root = ctx_src->grammar_root; @@ -1134,7 +1134,7 @@ static struct llama_sampler_i llama_sampler_grammar_i = { return result; }, /* .free = */ [](struct llama_sampler * smpl) { - const auto * ctx = (llama_sampler_context_grammar *) smpl->ctx; + const auto * ctx = (llama_sampler_grammar *) smpl->ctx; if (ctx->grammar) { llama_grammar_free_impl(ctx->grammar); @@ -1145,7 +1145,7 @@ static struct llama_sampler_i llama_sampler_grammar_i = { }; struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { - auto * ctx = new llama_sampler_context_grammar; + auto * ctx = new llama_sampler_grammar; if (grammar_str != nullptr && grammar_str[0] != '\0') { *ctx = { @@ -1171,7 +1171,7 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab // penalties -struct llama_sampler_context_penalties { +struct llama_sampler_penalties { const struct llama_vocab * vocab; const 
int32_t penalty_last_n; @@ -1188,11 +1188,11 @@ struct llama_sampler_context_penalties { static struct llama_sampler_i llama_sampler_penalties_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "penalties"; }, /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { - auto * ctx = (llama_sampler_context_penalties *) smpl->ctx; + auto * ctx = (llama_sampler_penalties *) smpl->ctx; ctx->prev.push_back(token); }, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_context_penalties *) smpl->ctx; + auto * ctx = (llama_sampler_penalties *) smpl->ctx; GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' sampler must be applied on the full vocabulary"); @@ -1222,11 +1222,11 @@ static struct llama_sampler_i llama_sampler_penalties_i = { } }, /* .reset = */ [](struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_context_penalties *) smpl->ctx; + auto * ctx = (llama_sampler_penalties *) smpl->ctx; ctx->prev.clear(); }, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx_src = (const llama_sampler_context_penalties *) smpl->ctx; + const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx; auto * result = llama_sampler_init_penalties_impl( *ctx_src->vocab, ctx_src->penalty_last_n, @@ -1236,13 +1236,13 @@ static struct llama_sampler_i llama_sampler_penalties_i = { ctx_src->penalize_nl, ctx_src->ignore_eos); - auto * ctx_dst = (llama_sampler_context_penalties *) result->ctx; + auto * ctx_dst = (llama_sampler_penalties *) result->ctx; ctx_dst->prev = ctx_src->prev; return result; }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_penalties *) smpl->ctx; + delete (llama_sampler_penalties *) smpl->ctx; }, }; @@ -1252,7 +1252,7 @@ struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_voca return new llama_sampler { /* .iface = */ &llama_sampler_penalties_i, - /* .ctx = */ new llama_sampler_context_penalties { + /* .ctx = */ new llama_sampler_penalties { /* .vocab = */ &vocab, /* .penalty_last_n = */ penalty_last_n, /* .penalty_repeat = */ penalty_repeat, @@ -1267,7 +1267,7 @@ struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_voca // logit-bias -struct llama_sampler_context_logit_bias { +struct llama_sampler_logit_bias { const struct llama_vocab * vocab; std::vector logit_bias; @@ -1277,7 +1277,7 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = { /* .name = */ [](const struct llama_sampler * /*smpl*/) { return "logit-bias"; }, /* .accept = */ nullptr, /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { - auto * ctx = (llama_sampler_context_logit_bias *) smpl->ctx; + auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' sampler must be applied on the full vocabulary"); @@ -1287,11 +1287,11 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = { }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx_src = (const llama_sampler_context_logit_bias *) smpl->ctx; + const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx; return llama_sampler_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); }, /* .free = */ [](struct llama_sampler * smpl) { - delete (llama_sampler_context_logit_bias *) smpl->ctx; + delete 
(llama_sampler_logit_bias *) smpl->ctx; }, }; @@ -1301,7 +1301,7 @@ struct llama_sampler * llama_sampler_init_logit_bias_impl( const llama_logit_bias * logit_bias) { return new llama_sampler { /* .iface = */ &llama_sampler_logit_bias_i, - /* .ctx = */ new llama_sampler_context_logit_bias { + /* .ctx = */ new llama_sampler_logit_bias { /* .vocab = */ &vocab, /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), }, From 757a9bf8686a7a2f478bdf1aca17e8e7f77fb309 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 13:47:27 +0300 Subject: [PATCH 38/47] llama : add new llama_perf API ggml-ci --- common/common.cpp | 2 +- common/sampling.cpp | 13 ++- common/sampling.h | 6 +- examples/batched-bench/batched-bench.cpp | 3 +- examples/batched/batched.cpp | 4 +- examples/embedding/embedding.cpp | 4 +- examples/eval-callback/eval-callback.cpp | 3 +- examples/gritlm/gritlm.cpp | 4 + examples/imatrix/imatrix.cpp | 3 +- examples/infill/infill.cpp | 7 +- examples/llama-bench/llama-bench.cpp | 2 +- examples/llava/llava-cli.cpp | 4 +- examples/llava/minicpmv-cli.cpp | 2 +- examples/lookahead/lookahead.cpp | 3 +- examples/lookup/lookup.cpp | 5 +- examples/main/main.cpp | 7 +- examples/parallel/parallel.cpp | 2 +- examples/passkey/passkey.cpp | 3 +- examples/perplexity/perplexity.cpp | 5 +- examples/retrieval/retrieval.cpp | 4 +- examples/simple/simple.cpp | 8 +- examples/speculative/speculative.cpp | 8 +- include/llama.h | 40 ++++----- src/llama-sampling.cpp | 4 +- src/llama.cpp | 104 ++++++++++++++--------- 25 files changed, 149 insertions(+), 101 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 2a51649a5c49f..6394301318c4b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2533,7 +2533,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { } llama_kv_cache_clear(lctx); llama_synchronize(lctx); - llama_reset_timings(lctx, nullptr); + llama_perf_reset(lctx, LLAMA_PERF_TYPE_CONTEXT); } iparams.model = model; diff --git a/common/sampling.cpp b/common/sampling.cpp index 9964501da7ccd..553aefbf4d542 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -153,7 +153,7 @@ std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); - lparams.no_timing = false; + lparams.no_perf = false; // TODO: control via params auto * result = new gpt_sampler { /* .params = */ params, @@ -270,8 +270,15 @@ llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { return gsmpl->prev.rat(0); } -void gpt_print_timings(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) { - llama_print_timings(ctx, gsmpl ? 
gsmpl->chain : nullptr); +void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) { + // TODO: measure grammar performance + + if (gsmpl) { + llama_perf_print(gsmpl->chain, LLAMA_PERF_TYPE_SAMPLER_CHAIN); + } + if (ctx) { + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + } } llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { diff --git a/common/sampling.h b/common/sampling.h index d88038204c89f..fa691cda23499 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -62,6 +62,8 @@ struct gpt_sampler_params { // - grammar support // - custom sampler logic based on the parameters // +// TODO: measure grammar performance +// struct gpt_sampler; struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params); @@ -75,11 +77,9 @@ void gpt_sampler_reset (struct gpt_sampler * gsmpl); llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); -//llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_token_data_array * cur_p); - llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); -void gpt_print_timings(const struct llama_context * ctx, const struct gpt_sampler * gsmpl); +void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl); // extended sampling implementation: // diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index b02ef74cfcdab..b043c74cc4954 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -210,7 +210,8 @@ int main(int argc, char ** argv) { } } - llama_print_timings(ctx, nullptr); + LOG_TEE("\n"); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); llama_batch_free(batch); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index b6e98fcc36335..7ab7eed79c307 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -231,7 +231,9 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx, smpl); + LOG_TEE("\n"); + llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); fprintf(stderr, "\n"); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 4b288b46092ef..e5e0872b1ba4a 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -307,8 +307,10 @@ int main(int argc, char ** argv) { if (notArray) fprintf(stdout, "\n}\n"); } + LOG_TEE("\n"); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + // clean up - llama_print_timings(ctx, nullptr); llama_batch_free(batch); llama_free(ctx); llama_free_model(model); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 166ca4b7da6bd..aea15c864ea93 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -181,7 +181,8 @@ int main(int argc, char ** argv) { return 1; } - llama_print_timings(ctx, nullptr); + LOG_TEE("\n"); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); llama_free(ctx); llama_free_model(model); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index b402abbb80256..4e801c69d2f06 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -171,8 
+171,12 @@ int main(int argc, char * argv[]) { auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; + llama_sampler * smpl = llama_sampler_chain_init(sparams); + llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); + // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic { diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 1c7f5350555e9..107f8c8859dcf 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -638,7 +638,8 @@ int main(int argc, char ** argv) { g_collector.save_imatrix(); - llama_print_timings(ctx, nullptr); + LOG_TEE("\n"); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); llama_free(ctx); llama_free_model(model); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 3895b586ecc69..1ebc0b324bc82 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -81,7 +81,7 @@ static void write_logfile( yaml_dump_string_multiline(logfile, "output", output.c_str()); yaml_dump_vector_int(logfile, "output_tokens", output_tokens); - llama_dump_timing_info_yaml(logfile, ctx); + llama_perf_dump_yaml(logfile, ctx); fclose(logfile); } @@ -93,7 +93,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - gpt_print_timings(*g_ctx, *g_smpl); + gpt_perf_print(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -634,7 +634,8 @@ int main(int argc, char ** argv) { fflush(stdout); } - gpt_print_timings(ctx, smpl); + LOG_TEE("\n"); + gpt_perf_print(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 1385634ddcc9c..d7db5af722a60 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1630,7 +1630,7 @@ int main(int argc, char ** argv) { fflush(p_err->fout); } - llama_print_timings(ctx, nullptr); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); llama_free(ctx); diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 63a75c4a34ca2..4d7ccc91fc4b4 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -310,7 +310,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama, nullptr); + llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); @@ -327,7 +327,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama, nullptr); + llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 15f258b91169f..237da9429ecc6 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -319,7 +319,7 @@ int main(int argc, char ** argv) { } } printf("\n"); - llama_print_timings(ctx_llava->ctx_llama, nullptr); + llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT); ctx_llava->model = NULL; llava_free(ctx_llava); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 
8b461555bb594..c2e931c651008 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -467,7 +467,8 @@ int main(int argc, char ** argv) { LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept); - gpt_print_timings(ctx, smpl); + LOG_TEE("\n"); + gpt_perf_print(ctx, smpl); gpt_sampler_free(smpl); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index da3583f3c00f2..071400b7e7f7e 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -238,8 +238,9 @@ int main(int argc, char ** argv){ LOG_TEE("n_accept = %d\n", n_accept); LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); - LOG_TEE("\ntarget:\n"); - gpt_print_timings(ctx, smpl); + LOG_TEE("\ntarget:\n\n"); + llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); gpt_sampler_free(smpl); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 85dea9782e152..42058d41de35d 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -93,7 +93,7 @@ static void write_logfile( yaml_dump_string_multiline(logfile, "output", output.c_str()); yaml_dump_vector_int(logfile, "output_tokens", output_tokens); - llama_dump_timing_info_yaml(logfile, ctx); + llama_perf_dump_yaml(logfile, ctx); fclose(logfile); } @@ -106,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - gpt_print_timings(*g_ctx, *g_smpl); + gpt_perf_print(*g_ctx, *g_smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -929,7 +929,8 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - gpt_print_timings(ctx, smpl); + LOG_TEE("\n"); + gpt_perf_print(ctx, smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); gpt_sampler_free(smpl); diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 7422042db268f..c331c0f28dc7e 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -414,7 +414,7 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); // TODO: print sampling/grammar timings for all clients - llama_print_timings(ctx, nullptr); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); llama_batch_free(batch); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 92c71c5a1a35d..ff8d0302f8f0a 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -259,7 +259,8 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx, nullptr); + LOG_TEE("\n"); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); fprintf(stderr, "\n"); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 987236ab683de..2ca43f1256765 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -76,7 +76,7 @@ static void write_logfile( fprintf(logfile, "ppl_value: %f\n", results.ppl_value); yaml_dump_vector_float(logfile, "probs", results.probs); - llama_dump_timing_info_yaml(logfile, ctx); + llama_perf_dump_yaml(logfile, ctx); fclose(logfile); } @@ -2048,7 +2048,8 @@ int main(int argc, char ** argv) { results = perplexity(ctx, params, n_ctx); } - llama_print_timings(ctx, nullptr); + 
LOG_TEE("\n"); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); write_logfile(ctx, params, model, results); llama_free(ctx); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 6089344a02bb7..7eb94765041a2 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -293,8 +293,10 @@ int main(int argc, char ** argv) { } } + LOG_TEE("\n"); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); + // clean up - llama_print_timings(ctx, nullptr); llama_batch_free(query_batch); llama_free(ctx); llama_free_model(model); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index e5dfeb2f4b4f8..8a0ad43ad31b8 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -57,8 +57,12 @@ int main(int argc, char ** argv) { auto sparams = llama_sampler_chain_default_params(); + sparams.no_perf = false; + llama_sampler * smpl = llama_sampler_chain_init(sparams); + llama_sampler_chain_add(smpl, llama_sampler_init_greedy()); + // tokenize the prompt std::vector tokens_list; @@ -153,7 +157,9 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx, nullptr); + LOG_TEE("\n"); + llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN); + llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT); fprintf(stderr, "\n"); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 037d5d34bb54d..55c6bda70e8e1 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -613,12 +613,12 @@ int main(int argc, char ** argv) { LOG_TEE("n_accept = %d\n", n_accept); LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); - LOG_TEE("\ndraft:\n"); + LOG_TEE("\ndraft:\n\n"); // TODO: print sampling/grammar timings for all drafts - gpt_print_timings(ctx_dft, nullptr); + llama_perf_print(ctx_dft, LLAMA_PERF_TYPE_CONTEXT); - LOG_TEE("\ntarget:\n"); - gpt_print_timings(ctx_tgt, smpl); + LOG_TEE("\ntarget:\n\n"); + gpt_perf_print(ctx_tgt, smpl); gpt_sampler_free(smpl); for (int s = 0; s < n_seq_dft; ++s) { diff --git a/include/llama.h b/include/llama.h index 29c216f2d9581..be01dadfb663d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -342,6 +342,7 @@ extern "C" { bool embeddings; // if true, extract embeddings (together with logits) bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU bool flash_attn; // whether to use flash attention [EXPERIMENTAL] + //bool no_perf; // whether to measure performance timings, TODO: implement // Abort callback // if it returns true, execution of llama_decode() will be aborted @@ -371,23 +372,9 @@ extern "C" { } llama_logit_bias; typedef struct llama_sampler_chain_params { - bool no_timing; // whether to measure performance timings + bool no_perf; // whether to measure performance timings } llama_sampler_chain_params; - // performance timing information - struct llama_timings { - double t_start_ms; - double t_end_ms; - double t_load_ms; - double t_sampler_ms; - double t_p_eval_ms; - double t_eval_ms; - - int32_t n_sampler; - int32_t n_p_eval; - int32_t n_eval; - }; - // used in chat template typedef struct llama_chat_message { const char * role; @@ -1121,13 +1108,6 @@ extern "C" { // Returns the split_prefix length. 
LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); - // Performance information - LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - - // note: requires llama_sampler_chain. how to prevent misuse? - LLAMA_API void llama_print_timings(const struct llama_context * ctx, const struct llama_sampler * chain); - LLAMA_API void llama_reset_timings( struct llama_context * ctx, struct llama_sampler * chain); - // Print system information LLAMA_API const char * llama_print_system_info(void); @@ -1135,7 +1115,21 @@ extern "C" { // If this is not called, or NULL is supplied, everything is output on stderr. LLAMA_API void llama_log_set(ggml_log_callback log_callback, void * user_data); - LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx); + // + // Performance utils + // + // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements. + // + + enum llama_perf_type { + LLAMA_PERF_TYPE_CONTEXT = 0, + LLAMA_PERF_TYPE_SAMPLER_CHAIN = 1, + }; + + LLAMA_API void llama_perf_print(const void * ctx, enum llama_perf_type type); + LLAMA_API void llama_perf_reset( void * ctx, enum llama_perf_type type); + + LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx); #ifdef __cplusplus } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 2aa4981cebc18..15d5b5f8a44c8 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -511,7 +511,7 @@ static struct llama_sampler_i llama_sampler_chain_i = { /* .accept = */ [](struct llama_sampler * smpl, llama_token token) { auto * chain = (llama_sampler_chain *) smpl->ctx; - time_meas tm(chain->t_sample_us, chain->params.no_timing); + time_meas tm(chain->t_sample_us, chain->params.no_perf); for (auto * smpl : chain->samplers) { llama_sampler_accept(smpl, token); @@ -522,7 +522,7 @@ static struct llama_sampler_i llama_sampler_chain_i = { /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * chain = (llama_sampler_chain *) smpl->ctx; - time_meas tm(chain->t_sample_us, chain->params.no_timing); + time_meas tm(chain->t_sample_us, chain->params.no_perf); for (auto * smpl : chain->samplers) { llama_sampler_apply(smpl, cur_p); diff --git a/src/llama.cpp b/src/llama.cpp index f5e01004f1894..c67f3638d337d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17924,7 +17924,7 @@ struct llama_context_params llama_context_default_params() { struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { - /*.no_timing =*/ false, // TODO: change to true and set explicitly in examples + /*.no_perf =*/ true, }; return result; @@ -20650,45 +20650,6 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int return 0; } -void llama_print_timings(const struct llama_context * ctx, const struct llama_sampler * chain) { - auto * smpl = chain ? (const struct llama_sampler_chain *) chain->ctx : nullptr; - - const llama_timings timings = { - /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, - /*.t_end_ms =*/ 1.00 * ggml_time_ms(), - /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sampler_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0), - /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, - /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - - /*.n_sampler =*/ std::max(0, smpl ? 
smpl->n_sample : 0), - /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), - /*.n_eval =*/ std::max(1, ctx->n_eval), - }; - - LLAMA_LOG_INFO("\n"); - LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); - LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_sampler_ms, timings.n_sampler, timings.t_sampler_ms / timings.n_sampler, 1e3 / timings.t_sampler_ms * timings.n_sampler); - LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); - LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_eval_ms, timings.n_eval, timings.t_eval_ms / timings.n_eval, 1e3 / timings.t_eval_ms * timings.n_eval); - LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); -} - -void llama_reset_timings(struct llama_context * ctx, struct llama_sampler * chain) { - ctx->t_start_us = ggml_time_us(); - ctx->t_eval_us = ctx->n_eval = 0; - ctx->t_p_eval_us = ctx->n_p_eval = 0; - - if (chain) { - auto * smpl = (struct llama_sampler_chain *) chain->ctx; - - smpl->t_sample_us = smpl->n_sample = 0; - } -} - const char * llama_print_system_info(void) { static std::string s; @@ -20717,7 +20678,68 @@ const char * llama_print_system_info(void) { return s.c_str(); } -void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { +void llama_perf_print(const void * ctx, enum llama_perf_type type) { + switch (type) { + case LLAMA_PERF_TYPE_CONTEXT: + { + const auto * p = (const struct llama_context *) ctx; + + const double t_start_ms = 1e-3 * p->t_start_us; + const double t_end_ms = 1.00 * ggml_time_ms(); + const double t_load_ms = 1e-3 * p->t_load_us; + const double t_p_eval_ms = 1e-3 * p->t_p_eval_us; + const double t_eval_ms = 1e-3 * p->t_eval_us; + + const int32_t n_p_eval = std::max(0, p->n_p_eval); + const int32_t n_eval = std::max(1, p->n_eval); + + LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, t_load_ms); + LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, t_p_eval_ms, n_p_eval, t_p_eval_ms / n_p_eval, 1e3 / t_p_eval_ms * n_p_eval); + LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, t_eval_ms, n_eval, t_eval_ms / n_eval, 1e3 / t_eval_ms * n_eval); + LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - t_start_ms), (n_p_eval + n_eval)); + } break; + case LLAMA_PERF_TYPE_SAMPLER_CHAIN: + { + const auto * smpl = (const struct llama_sampler *) ctx; + const auto * p = (const struct llama_sampler_chain *) smpl->ctx; + + const double t_sampler_ms = 1e-3 * p->t_sample_us; + + const int32_t n_sampler = std::max(0, p->n_sample); + + LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, t_sampler_ms, n_sampler, t_sampler_ms / n_sampler, 1e3 / t_sampler_ms * n_sampler); + } break; + default: + GGML_ABORT("invalid perf type"); + } +} + +void llama_perf_reset(void * ctx, enum llama_perf_type type) { + switch (type) { + case LLAMA_PERF_TYPE_CONTEXT: + { + auto * p = (struct llama_context *) ctx; + + p->t_start_us = ggml_time_us(); + 
p->t_eval_us = p->n_eval = 0; + p->t_p_eval_us = p->n_p_eval = 0; + } break; + case LLAMA_PERF_TYPE_SAMPLER_CHAIN: + { + auto * smpl = (struct llama_sampler *) ctx; + auto * p = (struct llama_sampler_chain *) smpl->ctx; + + p->t_sample_us = p->n_sample = 0; + } break; + default: + GGML_ABORT("invalid perf type"); + } +} + +void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) { fprintf(stream, "\n"); fprintf(stream, "###########\n"); fprintf(stream, "# Timings #\n"); From befcfe7a31dec28c7284bde9ff82847ca6578de9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 14:02:17 +0300 Subject: [PATCH 39/47] common : simplify gpt_sampler ggml-ci --- common/sampling.cpp | 49 ++++++++++++++++++--------------------------- 1 file changed, 19 insertions(+), 30 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 553aefbf4d542..a4baf9db60084 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -98,10 +98,7 @@ struct ring_buffer { struct gpt_sampler { gpt_sampler_params params; - struct llama_sampler * bias; - struct llama_sampler * pnlt; struct llama_sampler * grmr; - struct llama_sampler * chain; ring_buffer prev; @@ -140,11 +137,11 @@ std::string gpt_sampler_params::print() const { } std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { - std::string result = "\tlogits"; + std::string result = "\tlogits "; for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); - result += std::string(" -> ") + llama_sampler_name(smpl) + " "; + result += std::string("-> ") + llama_sampler_name(smpl) + " "; } return result; @@ -157,18 +154,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st auto * result = new gpt_sampler { /* .params = */ params, - /* .bias = */ llama_sampler_init_logit_bias( - model, - params.logit_bias.size(), - params.logit_bias.data()), - /* .pnlt = */ llama_sampler_init_penalties( - model, - params.penalty_last_n, - params.penalty_repeat, - params.penalty_freq, - params.penalty_present, - params.penalize_nl, - params.ignore_eos), /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), /* .chain = */ llama_sampler_chain_init(lparams), /* .prev = */ ring_buffer(params.n_prev), @@ -176,6 +161,22 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st /* .cur_p = */ {}, }; + llama_sampler_chain_add(result->chain, + llama_sampler_init_logit_bias( + model, + params.logit_bias.size(), + params.logit_bias.data())); + + llama_sampler_chain_add(result->chain, + llama_sampler_init_penalties( + model, + params.penalty_last_n, + params.penalty_repeat, + params.penalty_freq, + params.penalty_present, + params.penalize_nl, + params.ignore_eos)); + if (params.temp > 0.0f) { if (params.mirostat == 0) { for (const auto & cnstr : params.samplers) { @@ -223,8 +224,6 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st void gpt_sampler_free(struct gpt_sampler * gsmpl) { if (gsmpl) { - llama_sampler_free(gsmpl->bias); - llama_sampler_free(gsmpl->pnlt); llama_sampler_free(gsmpl->grmr); llama_sampler_free(gsmpl->chain); @@ -236,8 +235,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { return new gpt_sampler { /* .params = */ gsmpl->params, - /* .bias = */ llama_sampler_clone(gsmpl->bias), - /* .pnlt = */ llama_sampler_clone(gsmpl->pnlt), /* .grmr = */ llama_sampler_clone(gsmpl->grmr), /* .chain = 
*/ llama_sampler_clone(gsmpl->chain), /* .prev = */ gsmpl->prev, @@ -282,8 +279,6 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * } llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { - auto & bias = gsmpl->bias; - auto & pnlt = gsmpl->pnlt; auto & grmr = gsmpl->grmr; auto & chain = gsmpl->chain; @@ -291,9 +286,6 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context auto & cur_p = gsmpl->cur_p; - llama_sampler_apply(bias, &cur_p); - llama_sampler_apply(pnlt, &cur_p); - if (grammar_first) { llama_sampler_apply(grmr, &cur_p); } @@ -325,10 +317,7 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context // if the token is not valid, sample again, first apply the grammar samplers and then sample gsmpl->set_logits(ctx, idx); - llama_sampler_apply(bias, &cur_p); - llama_sampler_apply(pnlt, &cur_p); - llama_sampler_apply(grmr, &cur_p); - + llama_sampler_apply(grmr, &cur_p); llama_sampler_apply(chain, &cur_p); GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); From 9ce9210ef1b1ba29dd94e339c2e165a3465ded9f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 6 Sep 2024 14:57:44 +0300 Subject: [PATCH 40/47] batched.swift : fix build [no ci] --- examples/batched.swift/Sources/main.swift | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index e9acdc7ac86aa..fbb9a92b349a9 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -199,9 +199,10 @@ if n_parallel > 1 { let t_main_end = ggml_time_us() -print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n") +print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n") -llama_print_timings(context, smpl) +llama_perf_print(context, LLAMA_PERF_TYPE_CONTEXT) +llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN) private func tokenize(text: String, add_bos: Bool) -> [llama_token] { let utf8Count = text.utf8.count From 4a4530b7ff15da59d08eb0f91cf01c8254391310 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Sep 2024 12:21:45 +0300 Subject: [PATCH 41/47] examples : add missing samplers --- examples/batched.swift/Sources/main.swift | 1 + examples/batched/batched.cpp | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index fbb9a92b349a9..c17357b7ceee0 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -64,6 +64,7 @@ defer { llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40)); llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4)); +llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234)); let n_ctx = llama_n_ctx(context) diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 7ab7eed79c307..f321f61047ad5 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -71,6 +71,7 @@ int main(int 
argc, char ** argv) { llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k)); llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp)); + llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed)); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); From 4b272356242e50484fe4fba3656e759275c92756 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Sep 2024 12:22:27 +0300 Subject: [PATCH 42/47] style : rearrange code + add comments and TODOs ggml-ci --- common/sampling.cpp | 71 ++++++++++++++++++++++---------------------- common/sampling.h | 41 +++++++++++++++++++------ include/llama.h | 49 +++++++++++++++++++++++++++--- src/llama-sampling.h | 2 ++ 4 files changed, 115 insertions(+), 48 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index a4baf9db60084..5f27d5006044f 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -136,17 +136,6 @@ std::string gpt_sampler_params::print() const { return std::string(result); } -std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { - std::string result = "\tlogits "; - - for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { - const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); - result += std::string("-> ") + llama_sampler_name(smpl) + " "; - } - - return result; -} - struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) { llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); @@ -232,17 +221,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) { } } -struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { - return new gpt_sampler { - /* .params = */ gsmpl->params, - /* .grmr = */ llama_sampler_clone(gsmpl->grmr), - /* .chain = */ llama_sampler_clone(gsmpl->chain), - /* .prev = */ gsmpl->prev, - /* .cur = */ gsmpl->cur, - /* .cur_p = */ gsmpl->cur_p, - }; -} - void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) { if (accept_grammar) { llama_sampler_accept(gsmpl->grmr, token); @@ -259,12 +237,15 @@ void gpt_sampler_reset(struct gpt_sampler * gsmpl) { llama_sampler_reset(gsmpl->chain); } -llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { - return &gsmpl->cur_p; -} - -llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { - return gsmpl->prev.rat(0); +struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) { + return new gpt_sampler { + /* .params = */ gsmpl->params, + /* .grmr = */ llama_sampler_clone(gsmpl->grmr), + /* .chain = */ llama_sampler_clone(gsmpl->chain), + /* .prev = */ gsmpl->prev, + /* .cur = */ gsmpl->cur, + /* .cur_p = */ gsmpl->cur_p, + }; } void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) { @@ -279,12 +260,11 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * } llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) { - auto & grmr = gsmpl->grmr; - auto & chain = gsmpl->chain; - gsmpl->set_logits(ctx, idx); - auto & cur_p = gsmpl->cur_p; + auto & grmr = gsmpl->grmr; + auto & chain = gsmpl->chain; + auto & cur_p = gsmpl->cur_p; // initialized by set_logits if (grammar_first) { llama_sampler_apply(grmr, &cur_p); @@ -307,24 +287,45 @@ llama_token 
gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context llama_sampler_apply(grmr, &single_token_data_array); - // check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY const bool is_valid = single_token_data_array.data[0].logit != -INFINITY; if (is_valid) { return id; } } - // if the token is not valid, sample again, first apply the grammar samplers and then sample + // resampling: + // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain gsmpl->set_logits(ctx, idx); llama_sampler_apply(grmr, &cur_p); llama_sampler_apply(chain, &cur_p); - GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration"); + GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration"); return cur_p.data[cur_p.selected].id; } +// helpers + +llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) { + return &gsmpl->cur_p; +} + +llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) { + return gsmpl->prev.rat(0); +} + +std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) { + std::string result = "\tlogits "; + + for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) { + const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i); + result += std::string("-> ") + llama_sampler_name(smpl) + " "; + } + + return result; +} + std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) { n = std::min(n, (int) gsmpl->prev.size()); diff --git a/common/sampling.h b/common/sampling.h index fa691cda23499..654e0c513904d 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -61,24 +61,41 @@ struct gpt_sampler_params { // // - grammar support // - custom sampler logic based on the parameters +// - history of the last accepted tokens +// - performance metrics +// +// This goal is to have a common implementation of the sampling logic shared across the examples. +// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more +// complex (top-k, top-p, etc). +// +// Another example is related to the grammar. In general, the grammar constraints applied on the full +// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled +// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the +// grammar constraints are applied to the full vocabulary and the token is resampled. +// +// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can +// be moved into the core llama library. +// +// For convenience, the gpt_sampler also maintains a container with the current candidate tokens. +// This can be used to access the probabilities of the rest of the non-sampled tokens. 
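A minimal usage sketch of the two-step grammar strategy described in the comment block above (sample with the chain first, constrain the full vocabulary only when the sampled token is rejected). This is an illustration, not part of the diff; it assumes the gpt_sampler API declared below in this header, a default-constructed gpt_sampler_params, and the common include path, with the decode loop omitted:

    // sketch only: how an example program might drive gpt_sampler
    #include "sampling.h"

    static void generate_n(const llama_model * model, llama_context * ctx, int n_predict) {
        gpt_sampler_params sparams;                             // assumed defaults: logit bias, penalties, chain + grammar
        gpt_sampler * smpl = gpt_sampler_init(model, sparams);

        for (int i = 0; i < n_predict; ++i) {
            // fast path: sample from the chain, then check only the sampled token against the grammar;
            // gpt_sampler_sample falls back to constraining the full vocabulary if that token is rejected
            const llama_token id = gpt_sampler_sample(smpl, ctx, /*idx =*/ -1);

            gpt_sampler_accept(smpl, id, /*accept_grammar =*/ true);

            // ... llama_decode() the accepted token, stop on EOS ...
        }

        gpt_perf_print(ctx, smpl);
        gpt_sampler_free(smpl);
    }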
// // TODO: measure grammar performance // + struct gpt_sampler; +// llama_sampler API overloads + struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params); void gpt_sampler_free(struct gpt_sampler * gsmpl); -struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl); - -void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar); -void gpt_sampler_reset (struct gpt_sampler * gsmpl); - -llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); - -llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); +// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar +void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar); +void gpt_sampler_reset (struct gpt_sampler * gsmpl); +struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl); +// arguments can be nullptr to skip printing void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl); // extended sampling implementation: @@ -89,12 +106,18 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * // - if not: resample by first applying the grammar constraints and then sampling again (slower path) // // if grammar_first is true, the grammar is applied before the samplers (slower) -// useful in cases where all the resulting candidates must fit the grammar +// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar // llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); // helpers +// access the internal list of current candidate tokens +llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl); + +// get the last accepted token +llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl); + // print the sampler chain into a string std::string gpt_sampler_print(const struct gpt_sampler * gsmpl); diff --git a/include/llama.h b/include/llama.h index be01dadfb663d..5441d98f05f28 100644 --- a/include/llama.h +++ b/include/llama.h @@ -206,6 +206,7 @@ extern "C" { LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs }; + // TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979) typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -216,7 +217,7 @@ extern "C" { // TODO: consider SoA llama_token_data * data; size_t size; - int64_t selected; + int64_t selected; // this is the index in the data array (i.e. not the token id) bool sorted; } llama_token_data_array; @@ -979,9 +980,38 @@ extern "C" { // // Sampling API // - // In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). + // Sample usage: + // + // // prepare the sampling chain at the start + // auto sparams = llama_sampler_chain_default_params(); + // + // llama_sampler * smpl = llama_sampler_chain_init(sparams); + // + // llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50)); + // llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); + // llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8)); + // llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed)); + // + // ... + // + // // decoding loop: + // while (...) { + // ... 
+ // + // llama_decode(ctx, batch); + // + // // sample from the logits of the last token in the batch + // const llama_token id = llama_sampler_sample(smpl, ctx, -1); + // + // ... + // } + // + // llama_sampler_free(smpl); + // + // + // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). + // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab // - // TODO: in the future, the entire API that uses llama_model should start using llama_vocab typedef void * llama_sampler_context_t; @@ -1003,6 +1033,7 @@ extern "C" { llama_sampler_context_t ctx; }; + // mirror of llama_sampler_i: LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @@ -1011,7 +1042,8 @@ extern "C" { // important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add) LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl); - // llama_sampler_chain is a type of llama_sampler that can contain multiple llama_samplers + // llama_sampler_chain + // a type of llama_sampler that can chain multiple samplers one after another LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params); @@ -1089,6 +1121,15 @@ extern "C" { int32_t n_logit_bias, const llama_logit_bias * logit_bias); + // Shorthand for: + // + // const auto * logits = llama_get_logits_ith(ctx, idx); + // llama_token_data_array cur_p = { ... init from logits ... }; + // llama_sampler_apply(smpl, &cur_p); + // return cur_p.data[cur_p.selected].id; + // + // At this point, this is mostly a convenience function. + // LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx); // TODO: extend in the future diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 05bb294a10d2f..ddc84a3900666 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -1,5 +1,7 @@ #pragma once +// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? 
+ #include "llama-grammar.h" #include From 19c36962f7d3219564ad36f86d81b372d6f95228 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Sep 2024 12:49:56 +0300 Subject: [PATCH 43/47] batched.swift : fix build --- examples/batched.swift/Sources/main.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index c17357b7ceee0..4bc2bbf2c1570 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -202,8 +202,8 @@ let t_main_end = ggml_time_us() print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n") -llama_perf_print(context, LLAMA_PERF_TYPE_CONTEXT) -llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN) +llama_perf_print(UnsafeRawPointer(context), LLAMA_PERF_TYPE_CONTEXT) +llama_perf_print(UnsafeRawPointer(smpl), LLAMA_PERF_TYPE_SAMPLER_CHAIN) private func tokenize(text: String, add_bos: Bool) -> [llama_token] { let utf8Count = text.utf8.count From 0e6d170a506e951f716afee816d1388b75102887 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Sep 2024 14:16:21 +0300 Subject: [PATCH 44/47] sampling : avoid llama_model in few samplers ggml-ci --- common/sampling.cpp | 8 ++- include/llama.h | 11 ++-- src/llama-sampling.cpp | 138 +++++++++++++++++++++++++++++++---------- src/llama-sampling.h | 21 ------- src/llama.cpp | 25 -------- 5 files changed, 117 insertions(+), 86 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 5f27d5006044f..c81b4d233b04e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -152,13 +152,15 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st llama_sampler_chain_add(result->chain, llama_sampler_init_logit_bias( - model, + llama_n_vocab(model), params.logit_bias.size(), params.logit_bias.data())); llama_sampler_chain_add(result->chain, llama_sampler_init_penalties( - model, + llama_n_vocab (model), + llama_token_eos(model), + llama_token_nl (model), params.penalty_last_n, params.penalty_repeat, params.penalty_freq, @@ -196,7 +198,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(model, params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); } else if (params.mirostat == 2) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); diff --git a/include/llama.h b/include/llama.h index 5441d98f05f28..8bfb9e3b1532b 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1003,12 +1003,13 @@ extern "C" { // // sample from the logits of the last token in the batch // const llama_token id = llama_sampler_sample(smpl, ctx, -1); // + // // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.) + // llama_sampler_accept(smpl, id); // ... 
// } // // llama_sampler_free(smpl); // - // // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). // TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab // @@ -1086,7 +1087,7 @@ extern "C" { /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. LLAMA_API struct llama_sampler * llama_sampler_init_mirostat( - const struct llama_model * model, + int32_t n_vocab, uint32_t seed, float tau, float eta, @@ -1108,7 +1109,9 @@ extern "C" { const char * grammar_root); LLAMA_API struct llama_sampler * llama_sampler_init_penalties( - const struct llama_model * model, + int32_t n_vocab, // llama_n_vocab() + llama_token special_eos_id, // llama_token_eos() + llama_token linefeed_id, // llama_token_nl() int32_t penalty_last_n, // last n tokens to penalize (0 = disable penalty, -1 = context size) float penalty_repeat, // 1.0 = disabled float penalty_freq, // 0.0 = disabled @@ -1117,7 +1120,7 @@ extern "C" { bool ignore_eos); // ignore the end-of-sequence token LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( - const struct llama_model * model, + int32_t n_vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 15d5b5f8a44c8..e53b3d3a77edc 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -3,6 +3,7 @@ #include "llama-vocab.h" #include "llama-grammar.h" +#include #include #include #include @@ -926,7 +927,7 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa // mirostat struct llama_sampler_mirostat { - const struct llama_vocab * vocab; + const int32_t n_vocab; const uint32_t seed; @@ -964,7 +965,7 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { // Compute k from the estimated s_hat and target surprise value float epsilon_hat = s_hat - 1; - float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->vocab->n_vocab, -epsilon_hat)), 1 / s_hat); + float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat); llama_sampler_top_k_impl(cur_p, std::max(int(k), 1)); llama_sampler_softmax_impl(cur_p); @@ -986,25 +987,25 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; - return llama_sampler_init_mirostat_impl(*ctx->vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); + return llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_mirostat *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_mirostat_impl(const struct llama_vocab & vocab, uint32_t seed, float tau, float eta, int32_t m) { +struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { return new llama_sampler { /* .iface = */ &llama_sampler_mirostat_i, /* .ctx = */ new llama_sampler_mirostat { - /* 
.vocab = */ &vocab, - /* .seed = */ seed, - /* .tau = */ tau, - /* .eta = */ eta, - /* .m = */ m, - /* .mu = */ 2.0f*tau, - /* .rng = */ std::mt19937(seed), - /* .probs = */ {}, + /* .n_vocab = */ n_vocab, + /* .seed = */ seed, + /* .tau = */ tau, + /* .eta = */ eta, + /* .m = */ m, + /* .mu = */ 2.0f*tau, + /* .rng = */ std::mt19937(seed), + /* .probs = */ {}, }, }; } @@ -1172,7 +1173,9 @@ struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab // penalties struct llama_sampler_penalties { - const struct llama_vocab * vocab; + const int32_t n_vocab; + const llama_token special_eos_id; + const llama_token linefeed_id; const int32_t penalty_last_n; const float penalty_repeat; @@ -1194,10 +1197,21 @@ static struct llama_sampler_i llama_sampler_penalties_i = { /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * cur_p) { auto * ctx = (llama_sampler_penalties *) smpl->ctx; - GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'penalties' sampler must be applied on the full vocabulary"); - if (ctx->ignore_eos) { - cur_p->data[ctx->vocab->special_eos_id].logit = -INFINITY; + assert(ctx->special_eos_id >= 0); + + // optimistically check if the candidates are not yet sorted/shuffled/truncated + if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) { + cur_p->data[ctx->special_eos_id].logit = -INFINITY; + } else { + // else, search for the special EOS token + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].id == ctx->special_eos_id) { + cur_p->data[i].logit = -INFINITY; + break; + } + } + } } if ((ctx->penalty_last_n == 0) || @@ -1205,7 +1219,29 @@ static struct llama_sampler_i llama_sampler_penalties_i = { return; } - const float nl_logit = !ctx->penalize_nl ? 
cur_p->data[ctx->vocab->linefeed_id].logit : -INFINITY; + bool nl_found = false; + size_t nl_idx = 0; + float nl_logit = -INFINITY; + if (!ctx->penalize_nl) { + assert(ctx->linefeed_id >= 0); + + // optimistically check if the candidates are not yet sorted/shuffled/truncated + if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) { + nl_found = true; + nl_idx = ctx->linefeed_id; + nl_logit = cur_p->data[ctx->linefeed_id].logit; + } else { + // else, search for the linefeed token + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].id == ctx->linefeed_id) { + nl_found = true; + nl_idx = i; + nl_logit = cur_p->data[i].logit; + break; + } + } + } + } // Create a frequency map to count occurrences of each token in last_tokens // TODO: optimize this by maintaining the token count in the sampler context @@ -1216,9 +1252,9 @@ static struct llama_sampler_i llama_sampler_penalties_i = { llama_sampler_penalties_impl(cur_p, token_count, ctx->penalty_repeat, ctx->penalty_freq, ctx->penalty_present); - if (!ctx->penalize_nl) { + if (!ctx->penalize_nl && nl_found) { // restore the logit of the newline token if it was penalized - cur_p->data[ctx->vocab->linefeed_id].logit = nl_logit; + cur_p->data[nl_idx].logit = nl_logit; } }, /* .reset = */ [](struct llama_sampler * smpl) { @@ -1227,8 +1263,10 @@ static struct llama_sampler_i llama_sampler_penalties_i = { }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx; - auto * result = llama_sampler_init_penalties_impl( - *ctx_src->vocab, + auto * result = llama_sampler_init_penalties( + ctx_src->n_vocab, + ctx_src->special_eos_id, + ctx_src->linefeed_id, ctx_src->penalty_last_n, ctx_src->penalty_repeat, ctx_src->penalty_freq, @@ -1246,14 +1284,30 @@ static struct llama_sampler_i llama_sampler_penalties_i = { }, }; -struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_vocab & vocab, int32_t penalty_last_n, float penalty_repeat, float penalty_freq, float penalty_present, bool penalize_nl, bool ignore_eos) { - GGML_ASSERT(penalize_nl || vocab.linefeed_id != LLAMA_TOKEN_NULL); - GGML_ASSERT(!ignore_eos || vocab.special_eos_id != LLAMA_TOKEN_NULL); +struct llama_sampler * llama_sampler_init_penalties( + int32_t n_vocab, + llama_token special_eos_id, + llama_token linefeed_id, + int32_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present, + bool penalize_nl, + bool ignore_eos) { + if (linefeed_id == LLAMA_TOKEN_NULL) { + penalize_nl = false; + } + + if (special_eos_id == LLAMA_TOKEN_NULL) { + ignore_eos = true; + } return new llama_sampler { /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { - /* .vocab = */ &vocab, + /* .n_vocab = */ n_vocab, + /* .special_eos_id = */ special_eos_id, + /* .linefeed_id = */ linefeed_id, /* .penalty_last_n = */ penalty_last_n, /* .penalty_repeat = */ penalty_repeat, /* .penalty_freq = */ penalty_freq, @@ -1268,9 +1322,11 @@ struct llama_sampler * llama_sampler_init_penalties_impl(const struct llama_voca // logit-bias struct llama_sampler_logit_bias { - const struct llama_vocab * vocab; + const int32_t n_vocab; + + const std::vector logit_bias; - std::vector logit_bias; + std::vector to_search; }; static struct llama_sampler_i llama_sampler_logit_bias_i = { @@ -1279,31 +1335,47 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = { /* .apply = */ [](struct llama_sampler * smpl, llama_token_data_array * 
cur_p) { auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; - GGML_ASSERT(cur_p->size == ctx->vocab->n_vocab && cur_p->sorted == false && "the 'logit_bias' sampler must be applied on the full vocabulary"); + ctx->to_search.clear(); + // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id) for (const auto & lb : ctx->logit_bias) { - cur_p->data[lb.token].logit += lb.bias; + if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) { + cur_p->data[lb.token].logit += lb.bias; + } else { + ctx->to_search.push_back(lb); + } + } + + // search for the remaining candidates that were not found in the previous step + for (size_t i = 0; i < cur_p->size; ++i) { + for (const auto & lb : ctx->to_search) { + if (cur_p->data[i].id == lb.token) { + cur_p->data[i].logit += lb.bias; + break; + } + } } }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx; - return llama_sampler_init_logit_bias_impl(*ctx_src->vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); + return llama_sampler_init_logit_bias(ctx_src->n_vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_logit_bias *) smpl->ctx; }, }; -struct llama_sampler * llama_sampler_init_logit_bias_impl( - const struct llama_vocab & vocab, +struct llama_sampler * llama_sampler_init_logit_bias( + int32_t n_vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { return new llama_sampler { /* .iface = */ &llama_sampler_logit_bias_i, /* .ctx = */ new llama_sampler_logit_bias { - /* .vocab = */ &vocab, + /* .n_vocab = */ n_vocab, /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + /* .to_search = */ {}, }, }; } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index ddc84a3900666..137c0025ce0d8 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -33,28 +33,7 @@ void llama_sampler_penalties_impl( float penalty_freq, float penalty_present); -struct llama_sampler * llama_sampler_init_mirostat_impl( - const struct llama_vocab & vocab, - uint32_t seed, - float tau, - float eta, - int32_t m); - struct llama_sampler * llama_sampler_init_grammar_impl( const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); - -struct llama_sampler * llama_sampler_init_penalties_impl( - const struct llama_vocab & vocab, - int32_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos); - -LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias_impl( - const struct llama_vocab & vocab, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias); diff --git a/src/llama.cpp b/src/llama.cpp index c67f3638d337d..6bbaf9fc9bae7 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -20592,36 +20592,11 @@ int32_t llama_chat_apply_template( // sampling // -// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp -struct llama_sampler * llama_sampler_init_mirostat(const struct llama_model * model, uint32_t seed, float tau, float eta, int32_t m) { - return llama_sampler_init_mirostat_impl(model->vocab, seed, tau, eta, m); -} - // TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { return 
llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root); } -// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp -struct llama_sampler * llama_sampler_init_penalties( - const struct llama_model * model, - int32_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present, - bool penalize_nl, - bool ignore_eos) { - return llama_sampler_init_penalties_impl(model->vocab, penalty_last_n, penalty_repeat, penalty_freq, penalty_present, penalize_nl, ignore_eos); -} - -// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp -struct llama_sampler * llama_sampler_init_logit_bias( - const struct llama_model * model, - int32_t n_logit_bias, - const llama_logit_bias * logit_bias) { - return llama_sampler_init_logit_bias_impl(model->vocab, n_logit_bias, logit_bias); -} - // // model split // From 8a82f388cdc32aa677f272054328683819b81d91 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Sep 2024 14:38:00 +0300 Subject: [PATCH 45/47] sampling : fix state cloning ggml-ci --- src/llama-sampling.cpp | 88 ++++++++++++++++++++++++++++++------------ 1 file changed, 63 insertions(+), 25 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index e53b3d3a77edc..02b93b64c6575 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -643,7 +643,16 @@ static struct llama_sampler_i llama_sampler_dist_i = { /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_dist *) smpl->ctx; - return llama_sampler_init_dist(ctx->seed); + auto * result = llama_sampler_init_dist(ctx->seed); + + // copy the state + { + auto * result_ctx = (llama_sampler_dist *) result->ctx; + + result_ctx->rng = ctx->rng; + } + + return result; }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_dist *) smpl->ctx; @@ -987,7 +996,17 @@ static struct llama_sampler_i llama_sampler_mirostat_i = { }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx; - return llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); + auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m); + + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; + } + + return result; }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_mirostat *) smpl->ctx; @@ -1062,7 +1081,18 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = { }, /* .clone = */ [](const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx; - return llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); + + auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta); + + // copy the state + { + auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx; + + result_ctx->mu = ctx->mu; + result_ctx->rng = ctx->rng; + } + + return result; }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_mirostat_v2 *) smpl->ctx; @@ -1120,16 +1150,20 @@ static struct llama_sampler_i llama_sampler_grammar_i = { ctx->grammar = grammar_new; }, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx_src = (const llama_sampler_grammar *) smpl->ctx; + const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; + + auto * result = 
llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr); - auto * result = llama_sampler_init_grammar_impl(*ctx_src->vocab, nullptr, nullptr); + // copy the state + { + auto * result_ctx = (llama_sampler_grammar *) result->ctx; - auto * ctx_dst = (llama_sampler_grammar *) result->ctx; - if (ctx_src->grammar) { - ctx_dst->grammar_str = ctx_src->grammar_str; - ctx_dst->grammar_root = ctx_src->grammar_root; + if (ctx->grammar) { + result_ctx->grammar_str = ctx->grammar_str; + result_ctx->grammar_root = ctx->grammar_root; - ctx_dst->grammar = llama_grammar_clone_impl(*ctx_src->grammar); + result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar); + } } return result; @@ -1262,20 +1296,24 @@ static struct llama_sampler_i llama_sampler_penalties_i = { ctx->prev.clear(); }, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx_src = (const llama_sampler_penalties *) smpl->ctx; + const auto * ctx = (const llama_sampler_penalties *) smpl->ctx; auto * result = llama_sampler_init_penalties( - ctx_src->n_vocab, - ctx_src->special_eos_id, - ctx_src->linefeed_id, - ctx_src->penalty_last_n, - ctx_src->penalty_repeat, - ctx_src->penalty_freq, - ctx_src->penalty_present, - ctx_src->penalize_nl, - ctx_src->ignore_eos); - - auto * ctx_dst = (llama_sampler_penalties *) result->ctx; - ctx_dst->prev = ctx_src->prev; + ctx->n_vocab, + ctx->special_eos_id, + ctx->linefeed_id, + ctx->penalty_last_n, + ctx->penalty_repeat, + ctx->penalty_freq, + ctx->penalty_present, + ctx->penalize_nl, + ctx->ignore_eos); + + // copy the state + { + auto * result_ctx = (llama_sampler_penalties *) result->ctx; + + result_ctx->prev = ctx->prev; + } return result; }, @@ -1358,8 +1396,8 @@ static struct llama_sampler_i llama_sampler_logit_bias_i = { }, /* .reset = */ nullptr, /* .clone = */ [](const struct llama_sampler * smpl) { - const auto * ctx_src = (const llama_sampler_logit_bias *) smpl->ctx; - return llama_sampler_init_logit_bias(ctx_src->n_vocab, ctx_src->logit_bias.size(), ctx_src->logit_bias.data()); + const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx; + return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data()); }, /* .free = */ [](struct llama_sampler * smpl) { delete (llama_sampler_logit_bias *) smpl->ctx; From 2387dbea7d3417218faf7507eb9ff4eece396717 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Sep 2024 14:50:43 +0300 Subject: [PATCH 46/47] sampling : fix repeat penalty out-of-bounds access ggml-ci --- examples/server/server.cpp | 8 +++----- src/llama-sampling.cpp | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1095f43b206bb..f45b59983f05b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2323,10 +2323,10 @@ struct server_context { slot.release(); slot.i_batch = -1; continue; // continue loop of slots - } else { - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; } + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots } @@ -2347,8 +2347,6 @@ struct server_context { const auto * cur_p = gpt_sampler_get_candidates(slot.smpl); - // TODO: this logic might have been broken during https://github.com/ggerganov/llama.cpp/pull/8643 - // fix if necessary for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) { result.probs.push_back({ cur_p->data[i].id, diff --git 
a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 02b93b64c6575..61f4cbb9217e8 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1280,7 +1280,7 @@ static struct llama_sampler_i llama_sampler_penalties_i = { // Create a frequency map to count occurrences of each token in last_tokens // TODO: optimize this by maintaining the token count in the sampler context llama_token_cnt token_count; - for (int i = 0; i < ctx->penalty_last_n; ++i) { + for (int i = 0; i < std::min(ctx->penalty_last_n, ctx->prev.size()); ++i) { token_count[ctx->prev.rat(i)]++; } From 4ac186aece126b25e5eaf2334d97708b3bf60ffe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 7 Sep 2024 15:14:37 +0300 Subject: [PATCH 47/47] llama : update doc [no ci] --- include/llama.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/include/llama.h b/include/llama.h index 8bfb9e3b1532b..6334fc30d413c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -990,7 +990,10 @@ extern "C" { // llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50)); // llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1)); // llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8)); - // llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed)); + // + // // typically, the chain should end with a sampler such as "greedy", "dist" or "mirostat" + // // this sampler will be responsible to select the actual token + // llama_sampler_chain_add(smpl, llama_sampler_init_dist(seed)); // // ... //
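The llama_sampler_sample shorthand documented earlier in the series elides how the candidate array is initialized from the logits. A sketch of that step, assuming <vector> and the llama_token_data / llama_token_data_array definitions above; an application that wants to inspect or reorder the candidates can build the array itself and call llama_sampler_apply directly:

    // sketch only: rough equivalent of llama_sampler_sample(smpl, ctx, idx),
    // with the candidate array built explicitly by the caller
    #include "llama.h"
    #include <vector>

    static llama_token sample_explicit(llama_sampler * smpl, llama_context * ctx, const llama_model * model, int32_t idx) {
        const float * logits = llama_get_logits_ith(ctx, idx);
        const int n_vocab = llama_n_vocab(model);

        std::vector<llama_token_data> cur;
        cur.reserve(n_vocab);
        for (llama_token id = 0; id < n_vocab; ++id) {
            cur.push_back(llama_token_data{ id, logits[id], 0.0f });
        }

        llama_token_data_array cur_p = { cur.data(), cur.size(), /*selected =*/ -1, /*sorted =*/ false };

        llama_sampler_apply(smpl, &cur_p);

        // selected is an index into data, not a token id
        return cur_p.data[cur_p.selected].id;
    }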