From 97ab66453ffc6a9cc21cd0e8d4f83044f53b2832 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 5 Aug 2024 10:08:25 +0300
Subject: [PATCH] llama : add llama_sampling and combine it with llama_grammar

ggml-ci
---
 Makefile                                      |   6 -
 common/CMakeLists.txt                         |   2 -
 common/common.cpp                             |   2 +-
 common/grammar-parser.cpp                     | 536 ----------
 common/grammar-parser.h                       |  29 -
 common/sampling.cpp                           | 158 ++-
 common/sampling.h                             |  17 +-
 examples/batched-bench/batched-bench.cpp      |   2 +-
 examples/batched.swift/Sources/main.swift     |  14 +-
 examples/batched/batched.cpp                  |  14 +-
 examples/embedding/embedding.cpp              |   2 +-
 examples/eval-callback/eval-callback.cpp      |   2 +-
 examples/gbnf-validator/gbnf-validator.cpp    |  30 +-
 examples/gritlm/gritlm.cpp                    |  29 +-
 examples/imatrix/imatrix.cpp                  |   2 +-
 examples/infill/infill.cpp                    |  18 +-
 examples/llama-bench/llama-bench.cpp          |   2 +-
 .../llama/src/main/cpp/llama-android.cpp      |   3 +-
 .../llama.cpp.swift/LibLlama.swift            |   5 +-
 examples/llava/llava-cli.cpp                  |   8 +-
 examples/lookahead/lookahead.cpp              |   8 +-
 examples/lookup/lookup.cpp                    |   6 +-
 examples/main/main.cpp                        |  24 +-
 examples/parallel/parallel.cpp                |   9 +-
 examples/passkey/passkey.cpp                  |   7 +-
 examples/perplexity/perplexity.cpp            |   2 +-
 examples/quantize-stats/quantize-stats.cpp    |   2 +-
 examples/retrieval/retrieval.cpp              |   2 +-
 examples/save-load-state/save-load-state.cpp  |  14 +-
 examples/server/server.cpp                    |  10 +-
 examples/simple/simple.cpp                    |   6 +-
 examples/speculative/speculative.cpp          |  28 +-
 include/llama.h                               | 331 +++----
 src/llama-grammar.cpp                         | 930 ++++++++++++++----
 src/llama-grammar.h                           | 143 ++-
 src/llama-impl.h                              |   9 +-
 src/llama-sampling.cpp                        | 227 ++---
 src/llama-sampling.h                          |  67 +-
 src/llama-vocab.h                             |   5 +-
 src/llama.cpp                                 | 333 ++++---
 tests/test-grammar-integration.cpp            |  32 +-
 tests/test-grammar-parser.cpp                 |  10 +-
 tests/test-json-schema-to-grammar.cpp         |  10 +-
 tests/test-llama-grammar.cpp                  |  13 +-
 tests/test-sampling.cpp                       |  56 +-
 45 files changed, 1611 insertions(+), 1554 deletions(-)
 delete mode 100644 common/grammar-parser.cpp
 delete mode 100644 common/grammar-parser.h

diff --git a/Makefile b/Makefile
index 649671ed6a72e2..2f9a715038194d 100644
--- a/Makefile
+++ b/Makefile
@@ -923,7 +923,6 @@ OBJ_COMMON = \
 	common/ngram-cache.o \
 	common/sampling.o \
 	common/train.o \
-	common/grammar-parser.o \
 	common/build-info.o \
 	common/json-schema-to-grammar.o
 
@@ -1163,11 +1162,6 @@ common/console.o: \
 	common/console.h
 	$(CXX) $(CXXFLAGS) -c $< -o $@
 
-common/grammar-parser.o: \
-	common/grammar-parser.cpp \
-	common/grammar-parser.h
-	$(CXX) $(CXXFLAGS) -c $< -o $@
-
 common/json-schema-to-grammar.o: \
 	common/json-schema-to-grammar.cpp \
 	common/json-schema-to-grammar.h
diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt
index 761971d6881f38..2c72793b89dbe2 100644
--- a/common/CMakeLists.txt
+++ b/common/CMakeLists.txt
@@ -58,8 +58,6 @@ add_library(${TARGET} STATIC
     sampling.cpp
     console.h
     console.cpp
-    grammar-parser.h
-    grammar-parser.cpp
     json.hpp
     json-schema-to-grammar.cpp
     train.h
diff --git a/common/common.cpp b/common/common.cpp
index 560e20d080d0f1..77992aec9b135c 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2159,7 +2159,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
-        llama_reset_timings(lctx);
+        llama_reset_timings(lctx, nullptr);
     }
 
     iparams.model = model;
diff --git a/common/grammar-parser.cpp b/common/grammar-parser.cpp
deleted file mode 100644
index
a518b766dc33e5..00000000000000 --- a/common/grammar-parser.cpp +++ /dev/null @@ -1,536 +0,0 @@ -#include "grammar-parser.h" -#include -#include -#include -#include -#include -#include - -namespace grammar_parser { - // NOTE: assumes valid utf8 (but checks for overrun) - // copied from llama.cpp - static std::pair decode_utf8(const char * src) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; - uint8_t first_byte = static_cast(*src); - uint8_t highbits = first_byte >> 4; - int len = lookup[highbits]; - uint8_t mask = (1 << (8 - len)) - 1; - uint32_t value = first_byte & mask; - const char * end = src + len; // may overrun! - const char * pos = src + 1; - for ( ; pos < end && *pos; pos++) { - value = (value << 6) + (static_cast(*pos) & 0x3F); - } - return std::make_pair(value, pos); - } - - static uint32_t get_symbol_id(parse_state & state, const char * src, size_t len) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - auto result = state.symbol_ids.emplace(std::string(src, len), next_id); - return result.first->second; - } - - static uint32_t generate_symbol_id(parse_state & state, const std::string & base_name) { - uint32_t next_id = static_cast(state.symbol_ids.size()); - state.symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; - return next_id; - } - - static void add_rule( - parse_state & state, - uint32_t rule_id, - const std::vector & rule) { - if (state.rules.size() <= rule_id) { - state.rules.resize(rule_id + 1); - } - state.rules[rule_id] = rule; - } - - static bool is_digit_char(char c) { - return '0' <= c && c <= '9'; - } - - static bool is_word_char(char c) { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); - } - - static std::pair parse_hex(const char * src, int size) { - const char * pos = src; - const char * end = src + size; - uint32_t value = 0; - for ( ; pos < end && *pos; pos++) { - value <<= 4; - char c = *pos; - if ('a' <= c && c <= 'f') { - value += c - 'a' + 10; - } else if ('A' <= c && c <= 'F') { - value += c - 'A' + 10; - } else if ('0' <= c && c <= '9') { - value += c - '0'; - } else { - break; - } - } - if (pos != end) { - throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); - } - return std::make_pair(value, pos); - } - - static const char * parse_space(const char * src, bool newline_ok) { - const char * pos = src; - while (*pos == ' ' || *pos == '\t' || *pos == '#' || - (newline_ok && (*pos == '\r' || *pos == '\n'))) { - if (*pos == '#') { - while (*pos && *pos != '\r' && *pos != '\n') { - pos++; - } - } else { - pos++; - } - } - return pos; - } - - static const char * parse_name(const char * src) { - const char * pos = src; - while (is_word_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting name at ") + src); - } - return pos; - } - - static const char * parse_int(const char * src) { - const char * pos = src; - while (is_digit_char(*pos)) { - pos++; - } - if (pos == src) { - throw std::runtime_error(std::string("expecting integer at ") + src); - } - return pos; - } - - static std::pair parse_char(const char * src) { - if (*src == '\\') { - switch (src[1]) { - case 'x': return parse_hex(src + 2, 2); - case 'u': return parse_hex(src + 2, 4); - case 'U': return parse_hex(src + 2, 8); - case 't': return std::make_pair('\t', src + 2); - case 'r': return std::make_pair('\r', src + 2); - case 'n': return std::make_pair('\n', src + 2); - case '\\': - case '"': - case '[': - case ']': - 
return std::make_pair(src[1], src + 2); - default: - throw std::runtime_error(std::string("unknown escape at ") + src); - } - } else if (*src) { - return decode_utf8(src); - } - throw std::runtime_error("unexpected end of input"); - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested); - - static const char * parse_sequence( - parse_state & state, - const char * src, - const std::string & rule_name, - std::vector & out_elements, - bool is_nested) { - size_t last_sym_start = out_elements.size(); - const char * pos = src; - - auto handle_repetitions = [&](int min_times, int max_times) { - - if (last_sym_start == out_elements.size()) { - throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); - } - - // apply transformation to previous symbol (last_sym_start to end) according to - // the following rewrite rules: - // S{m,n} --> S S S (m times) S'(n-m) - // S'(x) ::= S S'(x-1) | - // (... n-m definitions of these S' rules ...) - // S'(1) ::= S | - // S{m,} --> S S S (m times) S' - // S' ::= S S' | - // S* --> S{0,} - // --> S' ::= S S' | - // S+ --> S{1,} - // --> S S' - // S' ::= S S' | - // S? --> S{0,1} - // --> S' - // S' ::= S | - - std::vector previous_elements(out_elements.begin() + last_sym_start, out_elements.end()); - if (min_times == 0) { - out_elements.resize(last_sym_start); - } else { - // Repeat the previous elements (min_times - 1) times - for (int i = 1; i < min_times; i++) { - out_elements.insert(out_elements.end(), previous_elements.begin(), previous_elements.end()); - } - } - - uint32_t last_rec_rule_id = 0; - auto n_opt = max_times < 0 ? 1 : max_times - min_times; - - std::vector rec_rule(previous_elements); - for (int i = 0; i < n_opt; i++) { - rec_rule.resize(previous_elements.size()); - uint32_t rec_rule_id = generate_symbol_id(state, rule_name); - if (i > 0 || max_times < 0) { - rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? rec_rule_id : last_rec_rule_id}); - } - rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); - rec_rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rec_rule_id, rec_rule); - last_rec_rule_id = rec_rule_id; - } - if (n_opt > 0) { - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); - } - }; - - while (*pos) { - if (*pos == '"') { // literal string - pos++; - last_sym_start = out_elements.size(); - while (*pos != '"') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '[') { // char range(s) - pos++; - enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; - if (*pos == '^') { - pos++; - start_type = LLAMA_GRETYPE_CHAR_NOT; - } - last_sym_start = out_elements.size(); - while (*pos != ']') { - if (!*pos) { - throw std::runtime_error("unexpected end of input"); - } - auto char_pair = parse_char(pos); - pos = char_pair.second; - enum llama_gretype type = last_sym_start < out_elements.size() - ? 
LLAMA_GRETYPE_CHAR_ALT - : start_type; - - out_elements.push_back({type, char_pair.first}); - if (pos[0] == '-' && pos[1] != ']') { - if (!pos[1]) { - throw std::runtime_error("unexpected end of input"); - } - auto endchar_pair = parse_char(pos + 1); - pos = endchar_pair.second; - out_elements.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); - } - } - pos = parse_space(pos + 1, is_nested); - } else if (is_word_char(*pos)) { // rule reference - const char * name_end = parse_name(pos); - uint32_t ref_rule_id = get_symbol_id(state, pos, name_end - pos); - pos = parse_space(name_end, is_nested); - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); - } else if (*pos == '(') { // grouping - // parse nested alternates into synthesized rule - pos = parse_space(pos + 1, true); - uint32_t sub_rule_id = generate_symbol_id(state, rule_name); - pos = parse_alternates(state, pos, rule_name, sub_rule_id, true); - last_sym_start = out_elements.size(); - // output reference to synthesized rule - out_elements.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); - if (*pos != ')') { - throw std::runtime_error(std::string("expecting ')' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '.') { // any char - last_sym_start = out_elements.size(); - out_elements.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); - pos = parse_space(pos + 1, is_nested); - } else if (*pos == '*') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, -1); - } else if (*pos == '+') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(1, -1); - } else if (*pos == '?') { - pos = parse_space(pos + 1, is_nested); - handle_repetitions(0, 1); - } else if (*pos == '{') { - pos = parse_space(pos + 1, is_nested); - - if (!is_digit_char(*pos)) { - throw std::runtime_error(std::string("expecting an int at ") + pos); - } - const char * int_end = parse_int(pos); - int min_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - - int max_times = -1; - - if (*pos == '}') { - max_times = min_times; - pos = parse_space(pos + 1, is_nested); - } else if (*pos == ',') { - pos = parse_space(pos + 1, is_nested); - - if (is_digit_char(*pos)) { - const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); - pos = parse_space(int_end, is_nested); - } - - if (*pos != '}') { - throw std::runtime_error(std::string("expecting '}' at ") + pos); - } - pos = parse_space(pos + 1, is_nested); - } else { - throw std::runtime_error(std::string("expecting ',' at ") + pos); - } - handle_repetitions(min_times, max_times); - } else { - break; - } - } - return pos; - } - - const char * parse_alternates( - parse_state & state, - const char * src, - const std::string & rule_name, - uint32_t rule_id, - bool is_nested) { - std::vector rule; - const char * pos = parse_sequence(state, src, rule_name, rule, is_nested); - while (*pos == '|') { - rule.push_back({LLAMA_GRETYPE_ALT, 0}); - pos = parse_space(pos + 1, true); - pos = parse_sequence(state, pos, rule_name, rule, is_nested); - } - rule.push_back({LLAMA_GRETYPE_END, 0}); - add_rule(state, rule_id, rule); - return pos; - } - - static const char * parse_rule(parse_state & state, const char * src) { - const char * name_end = parse_name(src); - const char * pos = parse_space(name_end, false); - size_t name_len = name_end - src; - uint32_t rule_id = get_symbol_id(state, src, name_len); - const std::string name(src, name_len); - - if (!(pos[0] == ':' 
&& pos[1] == ':' && pos[2] == '=')) { - throw std::runtime_error(std::string("expecting ::= at ") + pos); - } - pos = parse_space(pos + 3, true); - - pos = parse_alternates(state, pos, name, rule_id, false); - - if (*pos == '\r') { - pos += pos[1] == '\n' ? 2 : 1; - } else if (*pos == '\n') { - pos++; - } else if (*pos) { - throw std::runtime_error(std::string("expecting newline or end at ") + pos); - } - return parse_space(pos, true); - } - - parse_state parse(const char * src) { - try { - parse_state state; - const char * pos = parse_space(src, true); - while (*pos) { - pos = parse_rule(state, pos); - } - // Validate the state to ensure that all rules are defined - for (const auto & rule : state.rules) { - for (const auto & elem : rule) { - if (elem.type == LLAMA_GRETYPE_RULE_REF) { - // Ensure that the rule at that location exists - if (elem.value >= state.rules.size() || state.rules[elem.value].empty()) { - // Get the name of the rule that is missing - for (const auto & kv : state.symbol_ids) { - if (kv.second == elem.value) { - throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); - } - } - } - } - } - } - return state; - } catch (const std::exception & err) { - fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); - return parse_state(); - } - } - - static void print_grammar_char(FILE * file, uint32_t c) { - if (0x20 <= c && c <= 0x7f) { - fprintf(file, "%c", static_cast(c)); - } else { - // cop out of encoding UTF-8 - fprintf(file, "", c); - } - } - - static bool is_char_element(llama_grammar_element elem) { - switch (elem.type) { - case LLAMA_GRETYPE_CHAR: return true; - case LLAMA_GRETYPE_CHAR_NOT: return true; - case LLAMA_GRETYPE_CHAR_ALT: return true; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; - case LLAMA_GRETYPE_CHAR_ANY: return true; - default: return false; - } - } - - static void print_rule_binary(FILE * file, const std::vector & rule) { - for (auto elem : rule) { - switch (elem.type) { - case LLAMA_GRETYPE_END: fprintf(file, "END"); break; - case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; - case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; - case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; - case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; - case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; - case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; - } - switch (elem.type) { - case LLAMA_GRETYPE_END: - case LLAMA_GRETYPE_ALT: - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "(%u) ", elem.value); - break; - case LLAMA_GRETYPE_CHAR: - case LLAMA_GRETYPE_CHAR_NOT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "(\""); - print_grammar_char(file, elem.value); - fprintf(file, "\") "); - break; - } - } - fprintf(file, "\n"); - } - - static void print_rule( - FILE * file, - uint32_t rule_id, - const std::vector & rule, - const std::map & symbol_id_names) { - if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { - throw std::runtime_error( - "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); - } - fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - for (size_t i = 0, end = rule.size() - 1; i < end; i++) { - llama_grammar_element elem = rule[i]; - switch (elem.type) { - case LLAMA_GRETYPE_END: - throw std::runtime_error( - "unexpected end of rule: " + std::to_string(rule_id) + "," + - 
std::to_string(i)); - case LLAMA_GRETYPE_ALT: - fprintf(file, "| "); - break; - case LLAMA_GRETYPE_RULE_REF: - fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); - break; - case LLAMA_GRETYPE_CHAR: - fprintf(file, "["); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_NOT: - fprintf(file, "[^"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - fprintf(file, "-"); - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ALT: - if (i == 0 || !is_char_element(rule[i - 1])) { - throw std::runtime_error( - "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + - std::to_string(rule_id) + "," + std::to_string(i)); - } - print_grammar_char(file, elem.value); - break; - case LLAMA_GRETYPE_CHAR_ANY: - fprintf(file, "."); - break; - } - if (is_char_element(elem)) { - switch (rule[i + 1].type) { - case LLAMA_GRETYPE_CHAR_ALT: - case LLAMA_GRETYPE_CHAR_RNG_UPPER: - case LLAMA_GRETYPE_CHAR_ANY: - break; - default: - fprintf(file, "] "); - } - } - } - fprintf(file, "\n"); - } - - void print_grammar(FILE * file, const parse_state & state) { - try { - std::map symbol_id_names; - for (const auto & kv : state.symbol_ids) { - symbol_id_names[kv.second] = kv.first; - } - for (size_t i = 0, end = state.rules.size(); i < end; i++) { - // fprintf(file, "%zu: ", i); - // print_rule_binary(file, state.rules[i]); - print_rule(file, uint32_t(i), state.rules[i], symbol_id_names); - // fprintf(file, "\n"); - } - } catch (const std::exception & err) { - fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); - } - } - - std::vector parse_state::c_rules() { - std::vector ret; - ret.reserve(rules.size()); - for (const auto & rule : rules) { - ret.push_back(rule.data()); - } - return ret; - } -} diff --git a/common/grammar-parser.h b/common/grammar-parser.h deleted file mode 100644 index 9037d72728a42e..00000000000000 --- a/common/grammar-parser.h +++ /dev/null @@ -1,29 +0,0 @@ -// Implements a parser for an extended Backus-Naur form (BNF), producing the -// binary context-free grammar format specified by llama.h. Supports character -// ranges, grouping, and repetition operators. 
As an example, a grammar for -// arithmetic might look like: -// -// root ::= expr -// expr ::= term ([-+*/] term)* -// term ::= num | "(" space expr ")" space -// num ::= [0-9]+ space -// space ::= [ \t\n]* - -#pragma once -#include "llama.h" -#include -#include -#include -#include - -namespace grammar_parser { - struct parse_state { - std::map symbol_ids; - std::vector> rules; - - std::vector c_rules(); - }; - - parse_state parse(const char * src); - void print_grammar(FILE * file, const parse_state & state); -} diff --git a/common/sampling.cpp b/common/sampling.cpp index 079e405168dff2..f4f659b81f944d 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -1,99 +1,53 @@ -#define LLAMA_API_INTERNAL #include "sampling.h" -#include -struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { - struct llama_sampling_context * result = new llama_sampling_context(); +#include - result->params = params; - result->grammar = nullptr; +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) { + auto result = llama_sampling_init(params, llama_sampling_init(model, params.grammar.c_str(), "root")); - // if there is a grammar, parse it - if (!params.grammar.empty()) { - result->parsed_grammar = grammar_parser::parse(params.grammar.c_str()); + result->owned = true; - // will be empty (default) if there are parse errors - if (result->parsed_grammar.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); - delete result; - return nullptr; - } - - // Ensure that there is a "root" node. - if (result->parsed_grammar.symbol_ids.find("root") == result->parsed_grammar.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); - delete result; - return nullptr; - } + return result; +} - std::vector grammar_rules(result->parsed_grammar.c_rules()); +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl) { + struct llama_sampling_context * result = new llama_sampling_context(); - struct llama_grammar * grammar = llama_grammar_init( - grammar_rules.data(), - grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root")); - if (grammar == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); - } - result->grammar = grammar; - } + result->params = params; + result->owned = false; + result->smpl = smpl; result->prev.resize(params.n_prev); result->n_valid = 0; - llama_sampling_set_rng_seed(result, params.seed); + llama_sampling_set_rng_seed(result->smpl, params.seed); return result; } void llama_sampling_free(struct llama_sampling_context * ctx) { - if (ctx->grammar != NULL) { - llama_grammar_free(ctx->grammar); + if (ctx->owned) { + llama_sampling_free(ctx->smpl); } delete ctx; } void llama_sampling_reset(llama_sampling_context * ctx) { - if (ctx->grammar != NULL) { - llama_grammar_free(ctx->grammar); - ctx->grammar = NULL; - } - - if (!ctx->parsed_grammar.rules.empty()) { - std::vector grammar_rules(ctx->parsed_grammar.c_rules()); - - struct llama_grammar * grammar = llama_grammar_init( - grammar_rules.data(), - grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root")); - if (grammar == nullptr) { - throw std::runtime_error("Failed to initialize llama_grammar"); - } - ctx->grammar = grammar; - } + llama_sampling_reset(ctx->smpl, ctx->params.grammar.c_str(), "root"); std::fill(ctx->prev.begin(), ctx->prev.end(), 0); 
ctx->cur.clear(); ctx->n_valid = 0; } -void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = std::random_device{}(); - } - ctx->rng.seed(seed); -} - void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) { - if (dst->grammar) { - llama_grammar_free(dst->grammar); - dst->grammar = nullptr; - } - - if (src->grammar) { - dst->grammar = llama_grammar_copy(src->grammar); + if (dst->smpl) { + llama_sampling_free(dst->smpl); } + dst->smpl = llama_sampling_cp(src->smpl); dst->prev = src->prev; } @@ -228,10 +182,13 @@ std::vector llama_sampling_types_from_chars(const std::strin // no reasons to expose this function in header static void sampler_queue( - struct llama_context * ctx_main, - const llama_sampling_params & params, + struct llama_sampling_context * ctx_sampling, llama_token_data_array & cur_p, size_t min_keep) { + llama_sampling * smpl = ctx_sampling->smpl; + + const llama_sampling_params & params = ctx_sampling->params; + const float temp = params.temp; const float dynatemp_range = params.dynatemp_range; const float dynatemp_exponent = params.dynatemp_exponent; @@ -244,18 +201,18 @@ static void sampler_queue( for (auto sampler_type : samplers_sequence) { switch (sampler_type) { - case llama_sampler_type::TOP_K : llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); break; - case llama_sampler_type::TFS_Z : llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); break; - case llama_sampler_type::TYPICAL_P: llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); break; - case llama_sampler_type::TOP_P : llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep); break; - case llama_sampler_type::MIN_P : llama_sample_min_p (ctx_main, &cur_p, min_p, min_keep); break; + case llama_sampler_type::TOP_K : llama_sampling_top_k (smpl, &cur_p, top_k, min_keep); break; + case llama_sampler_type::TFS_Z : llama_sampling_tail_free(smpl, &cur_p, tfs_z, min_keep); break; + case llama_sampler_type::TYPICAL_P: llama_sampling_typical (smpl, &cur_p, typical_p, min_keep); break; + case llama_sampler_type::TOP_P : llama_sampling_top_p (smpl, &cur_p, top_p, min_keep); break; + case llama_sampler_type::MIN_P : llama_sampling_min_p (smpl, &cur_p, min_p, min_keep); break; case llama_sampler_type::TEMPERATURE: if (dynatemp_range > 0) { float dynatemp_min = std::max(0.0f, temp - dynatemp_range); float dynatemp_max = std::max(0.0f, temp + dynatemp_range); - llama_sample_entropy(ctx_main, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); + llama_sampling_entropy(smpl, &cur_p, dynatemp_min, dynatemp_max, dynatemp_exponent); } else { - llama_sample_temp(ctx_main, &cur_p, temp); + llama_sampling_temp(smpl, &cur_p, temp); } break; default : break; @@ -269,42 +226,44 @@ static llama_token llama_sampling_sample_impl( struct llama_context * ctx_cfg, const int idx, bool is_resampling) { + llama_sampling * smpl = ctx_sampling->smpl; + const llama_sampling_params & params = ctx_sampling->params; - const float temp = params.temp; - const int mirostat = params.mirostat; - const float mirostat_tau = params.mirostat_tau; - const float mirostat_eta = params.mirostat_eta; + const float temp = params.temp; + const int mirostat = params.mirostat; + const float mirostat_tau = params.mirostat_tau; + const float mirostat_eta = params.mirostat_eta; std::vector original_logits; auto cur_p = llama_sampling_prepare(ctx_sampling, ctx_main, ctx_cfg, idx, /* apply_grammar= */ is_resampling, &original_logits); - if 
(ctx_sampling->grammar != NULL && !is_resampling) { + if (!is_resampling) { GGML_ASSERT(!original_logits.empty()); } llama_token id = 0; if (temp < 0.0) { // greedy sampling, with probs - llama_sample_softmax(ctx_main, &cur_p); + llama_sampling_softmax(smpl, &cur_p); id = cur_p.data[0].id; } else if (temp == 0.0) { // greedy sampling, no probs - id = llama_sample_token_greedy(ctx_main, &cur_p); + id = llama_sampling_sample_greedy(smpl, &cur_p); } else { if (mirostat == 1) { const int mirostat_m = 100; - llama_sample_temp(ctx_main, &cur_p, temp); - id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); + llama_sampling_temp(smpl, &cur_p, temp); + id = llama_sampling_sample_mirostat(smpl, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu); } else if (mirostat == 2) { - llama_sample_temp(ctx_main, &cur_p, temp); - id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); + llama_sampling_temp(smpl, &cur_p, temp); + id = llama_sampling_sample_mirostat_v2(smpl, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); } else { // temperature sampling size_t min_keep = std::max(1, params.min_keep); - sampler_queue(ctx_main, params, cur_p, min_keep); + sampler_queue(ctx_sampling, cur_p, min_keep); - id = llama_sample_token_with_rng(ctx_main, &cur_p, ctx_sampling->rng); + id = llama_sampling_sample(smpl, &cur_p); //{ // const int n_top = 10; @@ -313,15 +272,15 @@ static llama_token llama_sampling_sample_impl( // for (int i = 0; i < n_top; i++) { // const llama_token id = cur_p.data[i].id; // (void)id; // To avoid a warning that id is unused when logging is disabled. - // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p); + // LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(smpl, id).c_str(), cur_p.data[i].p); // } //} - //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str()); + //LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(smpl, id).c_str()); } } - if (ctx_sampling->grammar != NULL && !is_resampling) { + if (!is_resampling) { // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); @@ -330,7 +289,7 @@ static llama_token llama_sampling_sample_impl( llama_token_data_array single_token_data_array = { &single_token_data, 1, false }; // Apply grammar constraints to the single token - llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array); + llama_sampling_grammar(ctx_sampling->smpl, &single_token_data_array); // Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY bool is_valid = single_token_data_array.data[0].logit != -INFINITY; @@ -358,6 +317,8 @@ static llama_token_data_array llama_sampling_prepare_impl( const int idx, bool apply_grammar, std::vector * original_logits) { + llama_sampling * smpl = ctx_sampling->smpl; + const llama_sampling_params & params = ctx_sampling->params; const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); @@ -375,7 +336,7 @@ static llama_token_data_array llama_sampling_prepare_impl( // Get a pointer to the logits float * logits = llama_get_logits_ith(ctx_main, idx); - if (ctx_sampling->grammar != NULL && !apply_grammar) { + if (!apply_grammar) { GGML_ASSERT(original_logits != NULL); // Only make a copy of the original logits if we are not applying grammar checks, not sure if I actually have to do this. 
*original_logits = {logits, logits + n_vocab}; @@ -388,7 +349,7 @@ static llama_token_data_array llama_sampling_prepare_impl( if (ctx_cfg) { float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx); - llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale); + llama_sampling_apply_guidance(smpl, logits, logits_guidance, params.cfg_scale); } cur.resize(n_vocab); @@ -405,7 +366,7 @@ static llama_token_data_array llama_sampling_prepare_impl( if (penalty_tokens_used_size) { const float nl_logit = logits[llama_token_nl(llama_get_model(ctx_main))]; - llama_sample_repetition_penalties(ctx_main, &cur_p, + llama_sampling_repetition_penalties(smpl, &cur_p, penalty_tokens.data() + penalty_tokens.size() - penalty_tokens_used_size, penalty_tokens_used_size, penalty_repeat, penalty_freq, penalty_present); @@ -420,8 +381,8 @@ static llama_token_data_array llama_sampling_prepare_impl( } // apply grammar checks before sampling logic - if (apply_grammar && ctx_sampling->grammar != NULL) { - llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p); + if (apply_grammar) { + llama_sampling_grammar(ctx_sampling->smpl, &cur_p); } return cur_p; @@ -443,18 +404,17 @@ llama_token_data_array llama_sampling_prepare( const int idx, bool apply_grammar, std::vector * original_logits) { - return llama_sampling_prepare_impl(ctx_sampling,ctx_main, ctx_cfg, idx, apply_grammar, original_logits); + return llama_sampling_prepare_impl(ctx_sampling, ctx_main, ctx_cfg, idx, apply_grammar, original_logits); } void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, llama_token id, bool apply_grammar) { ctx_sampling->prev.erase(ctx_sampling->prev.begin()); ctx_sampling->prev.push_back(id); - if (ctx_sampling->grammar != NULL && apply_grammar) { - llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id); + if (apply_grammar) { + llama_sampling_accept(ctx_sampling->smpl, id); } } diff --git a/common/sampling.h b/common/sampling.h index eeaa53b8bcd008..6723895964b70f 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -2,8 +2,6 @@ #include "llama.h" -#include "grammar-parser.h" - #include #include #include @@ -73,23 +71,22 @@ struct llama_sampling_context { // mirostat sampler state float mirostat_mu; - llama_grammar * grammar; + bool owned; - // internal - grammar_parser::parse_state parsed_grammar; + llama_sampling * smpl; // TODO: replace with ring-buffer std::vector prev; std::vector cur; - size_t n_valid; // Number of correct top tokens with correct probabilities. - std::mt19937 rng; + size_t n_valid; // Number of correct top tokens with correct probabilities. }; #include "common.h" // Create a new sampling context instance. 
-struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params); +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model); +struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl); void llama_sampling_free(struct llama_sampling_context * ctx); @@ -98,9 +95,6 @@ void llama_sampling_free(struct llama_sampling_context * ctx); // - reset grammar void llama_sampling_reset(llama_sampling_context * ctx); -// Set the sampler seed -void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed); - // Copy the sampler context void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst); @@ -155,6 +149,5 @@ llama_token_data_array llama_sampling_prepare( void llama_sampling_accept( struct llama_sampling_context * ctx_sampling, - struct llama_context * ctx_main, llama_token id, bool apply_grammar); diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 25e7c775a0095d..afd3f11d24daee 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -200,7 +200,7 @@ int main(int argc, char ** argv) { } } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_batch_free(batch); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 616494d2d841d0..2bc5fce7dfb6ee 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -44,6 +44,8 @@ context_params.n_threads = 8 context_params.n_threads_batch = 8 let context = llama_new_context_with_model(model, context_params) +let smpl = llama_get_sampling(context) + guard context != nil else { print("Failed to initialize context") exit(1) @@ -144,13 +146,13 @@ while n_cur <= n_len { let top_p: Float = 0.9 let temp: Float = 0.4 - llama_sample_top_k(context, &candidates_p, top_k, 1) - llama_sample_top_p(context, &candidates_p, top_p, 1) - llama_sample_temp(context, &candidates_p, temp) + llama_sampling_top_k(smpl, &candidates_p, top_k, 1) + llama_sampling_top_p(smpl, &candidates_p, top_p, 1) + llama_sampling_temp(smpl, &candidates_p, temp) - let new_token_id = llama_sample_token(context, &candidates_p) + let new_token_id = llama_sampling_sample(smpl, &candidates_p) - // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p); // is it an end of stream? 
-> mark the stream as finished if llama_token_is_eog(model, new_token_id) || n_cur == n_len { @@ -212,7 +214,7 @@ let t_main_end = ggml_time_us() print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n") -llama_print_timings(context) +llama_print_timings(context, smpl) private func tokenize(text: String, add_bos: Bool) -> [llama_token] { let utf8Count = text.utf8.count diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 53fbfb0a8cf2ae..142c825fcd9252 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -64,6 +64,7 @@ int main(int argc, char ** argv) { ctx_params.n_batch = std::max(n_predict, n_parallel); llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_sampling * smpl = llama_get_sampling(ctx); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -180,13 +181,13 @@ int main(int argc, char ** argv) { const float top_p = 0.9f; const float temp = 0.4f; - llama_sample_top_k(ctx, &candidates_p, top_k, 1); - llama_sample_top_p(ctx, &candidates_p, top_p, 1); - llama_sample_temp (ctx, &candidates_p, temp); + llama_sampling_top_k(smpl, &candidates_p, top_k, 1); + llama_sampling_top_p(smpl, &candidates_p, top_p, 1); + llama_sampling_temp (smpl, &candidates_p, temp); - const llama_token new_token_id = llama_sample_token(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample(smpl, &candidates_p); - //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + //const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p); // is it an end of generation? 
-> mark the stream as finished if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -244,12 +245,13 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, smpl); fprintf(stderr, "\n"); llama_batch_free(batch); + llama_sampling_free(smpl); llama_free(ctx); llama_free_model(model); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index cd7b448a619fa3..f4dc6d4c0d873d 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -258,7 +258,7 @@ int main(int argc, char ** argv) { } // clean up - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_batch_free(batch); llama_free(ctx); llama_free_model(model); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index ef35ba2c03942e..6de7c3dd787207 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -183,7 +183,7 @@ int main(int argc, char ** argv) { return 1; } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); llama_free_model(model); diff --git a/examples/gbnf-validator/gbnf-validator.cpp b/examples/gbnf-validator/gbnf-validator.cpp index 48a705e15cea99..9467fe4e2898e8 100644 --- a/examples/gbnf-validator/gbnf-validator.cpp +++ b/examples/gbnf-validator/gbnf-validator.cpp @@ -1,8 +1,7 @@ -#define LLAMA_API_INTERNAL - -#include "grammar-parser.h" #include "ggml.h" #include "llama.h" +#include "llama-vocab.h" // TMP +#include "llama-grammar.h" #include "unicode.h" #include @@ -85,27 +84,8 @@ int main(int argc, char** argv) { grammar_str = buffer.str(); } - // Parse the GBNF grammar - auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); - - // will be empty (default) if there are parse errors - if (parsed_grammar.rules.empty()) { - fprintf(stdout, "%s: failed to parse grammar\n", __func__); - return 1; - } - - // Ensure that there is a "root" node. 
- if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) { - fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__); - return 1; - } - - std::vector grammar_rules(parsed_grammar.c_rules()); - - // Create the LLAMA grammar - auto grammar = llama_grammar_init( - grammar_rules.data(), - grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + llama_vocab vocab; // TMP + llama_grammar * grammar = llama_grammar_init_impl(vocab, grammar_str.c_str(), "root"); if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); } @@ -131,7 +111,7 @@ int main(int argc, char** argv) { } // Clean up - llama_grammar_free(grammar); + llama_grammar_free_impl(grammar); return 0; } diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2c61c2e1eb3bc8..98acb29d25d412 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -9,7 +9,7 @@ static std::vector> encode(llama_context * ctx, const std::vector & sentences, const std::string & instruction) { std::vector> result; - const llama_model * mdl = llama_get_model(ctx); + const llama_model * model = llama_get_model(ctx); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); @@ -18,16 +18,16 @@ static std::vector> encode(llama_context * ctx, const std::ve const std::string input_string = instruction + sentences[i]; - std::vector inputs = llama_tokenize(mdl, input_string, true, false); + std::vector inputs = llama_tokenize(model, input_string, true, false); const int32_t n_toks = inputs.size(); // GritLM seems to have EOS = "" // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 - // inputs.push_back(llama_token_eos(mdl)); + // inputs.push_back(llama_token_eos(model)); // we want to ignore instruction tokens for mean pooling - const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size(); + const int32_t n_inst = llama_tokenize(model, instruction, true, false).size(); #ifdef GRIT_DEBUG // debug tokens - should be matching as referenced in the GritLM sample @@ -51,7 +51,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_decode(ctx, batch); // get embedding dimensions - uint64_t n_embd = llama_n_embd(mdl); + uint64_t n_embd = llama_n_embd(model); // allocate embedding output std::vector emb_unorm(n_embd, 0.0f); @@ -95,8 +95,9 @@ static std::vector> encode(llama_context * ctx, const std::ve static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) { std::string result; - const llama_model * mdl = llama_get_model(ctx); - llama_token eos_token = llama_token_eos(mdl); + const llama_model * model = llama_get_model(ctx); + llama_sampling * smpl = llama_get_sampling(ctx); + llama_token eos_token = llama_token_eos(model); llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, false); @@ -104,7 +105,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); - std::vector inputs = llama_tokenize(mdl, prompt, false, true); + std::vector inputs = llama_tokenize(model, prompt, false, true); int32_t i_current_token = 0; while (true) { @@ -118,14 +119,14 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo llama_decode(ctx, bat); auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1); - auto candidates = std::vector(llama_n_vocab(mdl)); + auto candidates = std::vector(llama_n_vocab(model)); auto n_candidates = 
(int32_t)candidates.size(); for (int32_t token = 0; token < n_candidates; token++) { candidates[token] = llama_token_data{ token, logits[token], 0.0f }; } auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false }; - llama_token token = llama_sample_token_greedy(ctx, &candidates_p); + llama_token token = llama_sampling_sample_greedy(smpl, &candidates_p); if (token == eos_token) { break; } @@ -167,10 +168,10 @@ int main(int argc, char * argv[]) { llama_backend_init(); - llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); + llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); // create generation context - llama_context * ctx = llama_new_context_with_model(mdl, cparams); + llama_context * ctx = llama_new_context_with_model(model, cparams); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic @@ -191,7 +192,7 @@ int main(int argc, char * argv[]) { const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); - const int n_embd = llama_n_embd(mdl); + const int n_embd = llama_n_embd(model); const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); @@ -212,7 +213,7 @@ int main(int argc, char * argv[]) { } llama_free(ctx); - llama_free_model(mdl); + llama_free_model(model); llama_backend_free(); return 0; diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 58814b96e7d497..720ec6ae087866 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -638,7 +638,7 @@ int main(int argc, char ** argv) { g_collector.save_imatrix(); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); llama_free_model(model); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 92d630b15fdf1b..13b3c362214c09 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -2,7 +2,6 @@ #include "console.h" #include "llama.h" -#include "grammar-parser.h" #include #include @@ -34,6 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; +static llama_sampling_context ** g_ctx_sampling; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -93,7 +93,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx); + llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -171,11 +171,13 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - llama_model * model; - llama_context * ctx; + llama_model * model = nullptr; + llama_context * ctx = nullptr; + llama_sampling_context * ctx_sampling = nullptr; g_model = &model; g_ctx = &ctx; + g_ctx_sampling = &ctx_sampling; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -349,7 +351,7 @@ int main(int argc, char ** argv) { std::vector embd; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx)); while (n_remain != 0 || params.interactive) { // predict @@ -423,7 +425,7 @@ int 
main(int argc, char ** argv) { if ((int) embd_inp.size() <= n_consumed && !is_interacting) { const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(ctx_sampling, id, true); LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); @@ -444,7 +446,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false); + llama_sampling_accept(ctx_sampling, embd_inp[n_consumed], false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -638,7 +640,7 @@ int main(int argc, char ** argv) { fflush(stdout); } - llama_print_timings(ctx); + llama_print_timings(ctx, ctx_sampling->smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 42918bfc79f224..20052901730bf7 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1463,7 +1463,7 @@ int main(int argc, char ** argv) { fflush(p_err->fout); } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); } diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 2aafe23167557f..af3b356bc04b7f 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -385,6 +385,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( jobject intvar_ncur ) { const auto context = reinterpret_cast(context_pointer); + const auto sampling = reinterpret_cast(llama_get_sampling(context)); const auto batch = reinterpret_cast(batch_pointer); const auto model = llama_get_model(context); @@ -405,7 +406,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // sample the most likely token - const auto new_token_id = llama_sample_token_greedy(context, &candidates_p); + const auto new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 58c32ca533bb10..bfd273072e627b 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama actor LlamaContext { private var model: OpaquePointer private var context: OpaquePointer + private var sampling: OpaquePointer private var batch: llama_batch private var tokens_list: [llama_token] var is_done: Bool = false @@ -42,12 +43,14 @@ actor LlamaContext { self.tokens_list = [] self.batch = llama_batch_init(512, 0, 1) self.temporary_invalid_cchars = [] + self.sampling = llama_get_sampling(context) } deinit { llama_batch_free(batch) llama_free(context) llama_free_model(model) + llama_sampling_free(sampling) llama_backend_free() } @@ -156,7 +159,7 @@ actor LlamaContext { candidates.withUnsafeMutableBufferPointer() { buffer in var candidates_p = 
llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false) - new_token_id = llama_sample_token_greedy(context, &candidates_p) + new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p) } if llama_token_is_eog(model, new_token_id) || n_cur == n_len { diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 8c7dd2ae3d0dc3..55566d022735fb 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -44,7 +44,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling, struct llama_context * ctx_llama, int * n_past) { const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL); - llama_sampling_accept(ctx_sampling, ctx_llama, id, true); + llama_sampling_accept(ctx_sampling, id, true); static std::string ret; if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { ret = ""; @@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_ LOG_TEE("\n"); - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, llama_get_sampling(ctx_llava->ctx_llama)); if (!ctx_sampling) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); @@ -310,7 +310,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); @@ -327,7 +327,7 @@ int main(int argc, char ** argv) { // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); - llama_print_timings(ctx_llava->ctx_llama); + llama_print_timings(ctx_llava->ctx_llama, nullptr); llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 81cf1629c5b6ae..5617d45c45c77a 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -118,7 +118,7 @@ int main(int argc, char ** argv) { llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx)); // verification n-grams std::vector ngrams_cur(G); @@ -161,7 +161,7 @@ int main(int argc, char ** argv) { { id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(ctx_sampling, id, true); { const std::string token_str = llama_token_to_piece(ctx, id); @@ -286,7 +286,7 @@ int main(int argc, char ** argv) { // sample the next token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(ctx_sampling, id, true); // print { @@ -468,7 +468,7 @@ int main(int argc, char ** argv) { LOG_TEE("n_predict = %d\n", n_predict); LOG_TEE("n_accept = %d\n", n_accept); - llama_print_timings(ctx); + llama_print_timings(ctx, ctx_sampling->smpl); llama_kv_cache_view_free(&kvc_view); llama_sampling_free(ctx_sampling); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index d53a9828c2ea23..6753bedee8d844 100644 --- a/examples/lookup/lookup.cpp +++ 
b/examples/lookup/lookup.cpp @@ -106,7 +106,7 @@ int main(int argc, char ** argv){ bool has_eos = false; - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx)); std::vector draft; @@ -132,7 +132,7 @@ int main(int argc, char ** argv){ // sample from the target model llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft); - llama_sampling_accept(ctx_sampling, ctx, id, true); + llama_sampling_accept(ctx_sampling, id, true); const std::string token_str = llama_token_to_piece(ctx, id); @@ -241,7 +241,7 @@ int main(int argc, char ** argv){ LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx); + llama_print_timings(ctx, ctx_sampling->smpl); llama_sampling_free(ctx_sampling); llama_batch_free(batch_tgt); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6e0635a66cd067..b2ac2e54eaad40 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -33,6 +33,7 @@ static llama_context ** g_ctx; static llama_model ** g_model; +static llama_sampling_context ** g_ctx_sampling; static gpt_params * g_params; static std::vector * g_input_tokens; static std::ostringstream * g_output_ss; @@ -105,7 +106,7 @@ static void sigint_handler(int signo) { } else { console::cleanup(); printf("\n"); - llama_print_timings(*g_ctx); + llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl); write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); _exit(130); } @@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v static std::string chat_add_and_format(struct llama_model * model, std::vector & chat_msgs, std::string role, std::string content) { llama_chat_msg new_msg{role, content}; - auto formatted = llama_chat_format_single( - model, g_params->chat_template, chat_msgs, new_msg, role == "user"); + auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user"); chat_msgs.push_back({role, content}); LOG("formatted: %s\n", formatted.c_str()); return formatted; @@ -198,12 +198,16 @@ int main(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); - llama_model * model; - llama_context * ctx; - llama_context * ctx_guidance = NULL; + llama_model * model = nullptr; + llama_context * ctx = nullptr; + llama_context * ctx_guidance = nullptr; + llama_sampling_context * ctx_sampling = nullptr; + std::vector chat_msgs; + g_model = &model; g_ctx = &ctx; + g_ctx_sampling = &ctx_sampling; // load the model and apply lora adapter, if any LOG("%s: load the model and apply lora adapter, if any\n", __func__); @@ -531,7 +535,7 @@ int main(int argc, char ** argv) { antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true)); } - struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx)); if (!ctx_sampling) { fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__); exit(1); @@ -734,7 +738,7 @@ int main(int argc, char ** argv) { const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); - llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true); + llama_sampling_accept(ctx_sampling, id, /* apply_grammar= */ true); LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); @@ -755,7 
+759,7 @@ int main(int argc, char ** argv) { // push the prompt in the sampling context in order to apply repetition penalties later // for the prompt, we don't apply grammar rules - llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false); + llama_sampling_accept(ctx_sampling, embd_inp[n_consumed], /* apply_grammar= */ false); ++n_consumed; if ((int) embd.size() >= params.n_batch) { @@ -979,7 +983,7 @@ int main(int argc, char ** argv) { llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); } - llama_print_timings(ctx); + llama_print_timings(ctx, ctx_sampling->smpl); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); if (ctx_guidance) { llama_free(ctx_guidance); } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 621a1c95906226..1e7fdbeff37732 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -161,7 +161,7 @@ int main(int argc, char ** argv) { for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; client.id = i; - client.ctx_sampling = llama_sampling_init(params.sparams); + client.ctx_sampling = llama_sampling_init(params.sparams, model); } std::vector tokens_system; @@ -343,7 +343,7 @@ int main(int argc, char ** argv) { const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i); - llama_sampling_accept(client.ctx_sampling, ctx, id, true); + llama_sampling_accept(client.ctx_sampling, id, true); if (client.n_decoded == 1) { // start measuring generation time after the first token to make sure all concurrent clients @@ -371,7 +371,7 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); + llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); @@ -413,7 +413,8 @@ int main(int argc, char ** argv) { LOG_TEE("\n"); - llama_print_timings(ctx); + // TODO: print sampling/grammar timings for all clients + llama_print_timings(ctx, nullptr); llama_batch_free(batch); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index d03215cd1e0a94..de5dd625eee410 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -80,12 +80,13 @@ int main(int argc, char ** argv) { GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); llama_context * ctx = llama_new_context_with_model(model, ctx_params); - if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); return 1; } + llama_sampling * smpl = llama_get_sampling(ctx); + // tokenize the prompt std::vector tokens_list; tokens_list = ::llama_tokenize(ctx, params.prompt, true); @@ -230,7 +231,7 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // sample the most likely token - const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p); // is it an end of generation? 
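As a rough sketch (not part of the patch itself), the flow that main, lookahead, lookup, llava and parallel converge on after this change looks like the fragment below: the sampling context is created from the llama_context's own llama_sampling instance, accept() no longer needs the llama_context, and the sampler is handed to the timings printer. Names such as params, ctx and the loop condition done are assumed caller state, not taken from the patch.

    struct llama_sampling_context * ctx_sampling =
        llama_sampling_init(params.sparams, llama_get_sampling(ctx));

    while (!done) {
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, /* ctx_cfg = */ NULL);
        llama_sampling_accept(ctx_sampling, id, /* apply_grammar= */ true);
        // decode the token, update the loop condition on end-of-generation, etc.
    }

    llama_print_timings(ctx, ctx_sampling->smpl);
    llama_sampling_free(ctx_sampling);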
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { @@ -267,7 +268,7 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); fprintf(stderr, "\n"); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 372684f092de29..fbf3872a01e437 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -2054,7 +2054,7 @@ int main(int argc, char ** argv) { results = perplexity(ctx, params, n_ctx); } - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); write_logfile(ctx, params, model, results); llama_free(ctx); diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 68cf8d3595e878..25c2de60cbce7a 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -1,7 +1,7 @@ -#define LLAMA_API_INTERNAL #include "common.h" #include "ggml.h" #include "llama.h" +#include "llama-impl.h" #include #include diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 65b19ce71cbe3e..e1d7729dd45bde 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -293,7 +293,7 @@ int main(int argc, char ** argv) { } // clean up - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); llama_free(ctx); llama_free_model(model); llama_backend_free(); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 3ea7c790d2bf78..937cd6389f8359 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -38,6 +38,8 @@ int main(int argc, char ** argv) { return 1; } + llama_sampling * smpl = llama_get_sampling(ctx); + // tokenize prompt auto tokens = llama_tokenize(ctx, params.prompt, true); @@ -73,7 +75,7 @@ int main(int argc, char ** argv) { candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx, &candidates_p); + auto next_token = llama_sampling_sample(smpl, &candidates_p); auto next_token_str = llama_token_to_piece(ctx, next_token); printf("%s", next_token_str.c_str()); @@ -96,6 +98,8 @@ int main(int argc, char ** argv) { // make new context auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + llama_sampling * smpl2 = llama_get_sampling(ctx2); + printf("\nsecond run: %s", params.prompt.c_str()); // load state (rng, logits, embedding and kv_cache) from file @@ -132,7 +136,7 @@ int main(int argc, char ** argv) { candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx2, &candidates_p); + auto next_token = llama_sampling_sample(smpl2, &candidates_p); auto next_token_str = llama_token_to_piece(ctx2, next_token); printf("%s", next_token_str.c_str()); @@ -157,7 +161,9 @@ int main(int argc, char ** argv) { } // make new context - auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); + auto * ctx3 = llama_new_context_with_model(model, 
llama_context_params_from_gpt_params(params)); + + llama_sampling * smpl3 = llama_get_sampling(ctx3); printf("\nsingle seq run: %s", params.prompt.c_str()); @@ -223,7 +229,7 @@ int main(int argc, char ** argv) { candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f}); } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - auto next_token = llama_sample_token(ctx3, &candidates_p); + auto next_token = llama_sampling_sample(smpl3, &candidates_p); auto next_token_str = llama_token_to_piece(ctx3, next_token); printf("%s", next_token_str.c_str()); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 360f571e428676..7fb02b2abd8024 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3,7 +3,6 @@ #include "common.h" #include "json-schema-to-grammar.h" #include "llama.h" -#include "grammar-parser.h" #ifndef NDEBUG // crash the server in debug mode, otherwise send an http 500 error @@ -1098,7 +1097,8 @@ struct server_context { if (slot.ctx_sampling != nullptr) { llama_sampling_free(slot.ctx_sampling); } - slot.ctx_sampling = llama_sampling_init(slot.sparams); + + slot.ctx_sampling = llama_sampling_init(slot.sparams, model); if (slot.ctx_sampling == nullptr) { // for now, the only error that may happen here is invalid grammar send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); @@ -2168,7 +2168,7 @@ struct server_context { // push the prompt into the sampling context (do not apply grammar) for (int i = 0; i < slot.n_past; ++i) { - llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + llama_sampling_accept(slot.ctx_sampling, slot.cache_tokens[i], false); } } } @@ -2400,7 +2400,7 @@ struct server_context { completion_token_output result; const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i); - llama_sampling_accept(slot.ctx_sampling, ctx, id, true); + llama_sampling_accept(slot.ctx_sampling, id, true); slot.n_decoded += 1; if (slot.n_decoded == 1) { @@ -2418,7 +2418,7 @@ struct server_context { // Make sure at least n_probs top tokens are at the front of the vector: if (slot.sparams.temp == 0.0f && n_probs > n_valid) { - llama_sample_top_k(ctx, &cur_p, n_probs, 0); + llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0); } if (slot.sparams.temp == 0.0f) { diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 69a92cf7dc0c01..6c8aaacca468b9 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -55,6 +55,8 @@ int main(int argc, char ** argv) { return 1; } + llama_sampling * smpl = llama_get_sampling(ctx); + // tokenize the prompt std::vector tokens_list; @@ -123,7 +125,7 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // sample the most likely token - const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p); // is it an end of generation? 
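For the low-level examples (simple, passkey, save-load-state) the only change is where the sampler comes from: the greedy and plain sampling calls now take the llama_sampling handle obtained from the context instead of the context itself. A minimal sketch of the per-step pattern these hunks follow, assuming the usual logits accessor llama_get_logits_ith and an index i_batch into the last decoded batch (both assumed here, not shown in the hunks):

    llama_sampling * smpl = llama_get_sampling(ctx);

    const float * logits  = llama_get_logits_ith(ctx, i_batch); // assumed accessor
    const int     n_vocab = llama_n_vocab(model);

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; ++token_id) {
        candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    // sample the most likely token through the sampling instance
    const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);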
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { @@ -160,7 +162,7 @@ int main(int argc, char ** argv) { LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n", __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); - llama_print_timings(ctx); + llama_print_timings(ctx, nullptr); fprintf(stderr, "\n"); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index b051a18f169c26..261fcd6afbeaed 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -178,8 +178,8 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; - // target model sampling context - struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams); + // target model sampling context (reuse the llama_context's sampling instance) + struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx_tgt)); // draft sequence data std::vector drafts(n_seq_dft); @@ -190,7 +190,8 @@ int main(int argc, char ** argv) { } for (int s = 0; s < n_seq_dft; ++s) { - drafts[s].ctx_sampling = llama_sampling_init(params.sparams); + // allocate llama_sampling for each draft sequence + drafts[s].ctx_sampling = llama_sampling_init(params.sparams, model_dft); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); @@ -234,8 +235,10 @@ int main(int argc, char ** argv) { // stochastic verification llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL); - llama_sample_softmax(ctx_tgt, &dist_tgt); - float p_tgt = 0, p_dft = 0; + llama_sampling_softmax(ctx_sampling->smpl, &dist_tgt); + + float p_tgt = 0.0f; + float p_dft = 0.0f; // GGML_ASSERT(dist_tgt.size() == dist_dft.size()); @@ -277,7 +280,7 @@ int main(int argc, char ** argv) { accept = true; token_id = drafts[s].tokens[i_dft]; token_str = llama_token_to_piece(ctx_tgt, token_id); - llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + llama_sampling_accept(ctx_sampling, token_id, true); LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str()); break; @@ -331,8 +334,8 @@ int main(int argc, char ** argv) { // all drafted tokens were rejected // sample from the target model LOG("all drafted tokens were rejected, sampling from residual distribution\n"); - token_id = llama_sample_token(ctx_tgt, &dist_tgt); - llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + token_id = llama_sampling_sample(ctx_sampling->smpl, &dist_tgt); + llama_sampling_accept(ctx_sampling, token_id, true); token_str = llama_token_to_piece(ctx_tgt, token_id); } @@ -343,7 +346,7 @@ int main(int argc, char ** argv) { LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); - llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true); + llama_sampling_accept(ctx_sampling, token_id, true); //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); @@ -518,7 +521,7 @@ int main(int argc, char ** argv) { const int s = sa[is]; - llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); + llama_sampling_accept(drafts[s].ctx_sampling, id, true); drafts[s].tokens.push_back(id); // save cur_p.data into drafts[s].dists @@ -593,10 +596,11 @@ int 
main(int argc, char ** argv) { LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted); LOG_TEE("\ndraft:\n"); - llama_print_timings(ctx_dft); + // TODO: print sampling/grammar timings for all drafts + llama_print_timings(ctx_dft, nullptr); LOG_TEE("\ntarget:\n"); - llama_print_timings(ctx_tgt); + llama_print_timings(ctx_tgt, ctx_sampling->smpl); llama_sampling_free(ctx_sampling); for (int s = 0; s < n_seq_dft; ++s) { diff --git a/include/llama.h b/include/llama.h index 66c266298e86f0..f0974b536547b6 100644 --- a/include/llama.h +++ b/include/llama.h @@ -53,6 +53,7 @@ extern "C" { // TODO: show sample usage // + // struct llama_vocab; // TODO: add in the future struct llama_model; struct llama_context; @@ -355,53 +356,22 @@ extern "C" { void * kv_overrides; // pointer to vector containing overrides } llama_model_quantize_params; - // grammar types - struct llama_grammar; - - // grammar element type - enum llama_gretype { - // end of rule definition - LLAMA_GRETYPE_END = 0, - - // start of alternate definition for rule - LLAMA_GRETYPE_ALT = 1, - - // non-terminal element: reference to rule - LLAMA_GRETYPE_RULE_REF = 2, - - // terminal element: character (code point) - LLAMA_GRETYPE_CHAR = 3, - - // inverse char(s) ([^a], [^a-b] [^abc]) - LLAMA_GRETYPE_CHAR_NOT = 4, - - // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to - // be an inclusive range ([a-z]) - LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, - - // modifies a preceding LLAMA_GRETYPE_CHAR or - // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) - LLAMA_GRETYPE_CHAR_ALT = 6, - - // any character (.) - LLAMA_GRETYPE_CHAR_ANY = 7, - }; - - typedef struct llama_grammar_element { - enum llama_gretype type; - uint32_t value; // Unicode code point or rule ID - } llama_grammar_element; + // sampling types + struct llama_sampling; // performance timing information struct llama_timings { double t_start_ms; double t_end_ms; double t_load_ms; - double t_sample_ms; + double t_sampling_ms; + double t_grammar_ms; double t_p_eval_ms; double t_eval_ms; - int32_t n_sample; + int32_t n_sampling; + int32_t n_grammar_sample; + int32_t n_grammar_accept; int32_t n_p_eval; int32_t n_eval; }; @@ -452,23 +422,23 @@ extern "C" { LLAMA_API bool llama_supports_mlock (void); LLAMA_API bool llama_supports_gpu_offload(void); - LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); - LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - - LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); - LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_n_layer (const struct llama_model * model); + LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); + LLAMA_API struct llama_sampling * llama_get_sampling( struct llama_context * ctx); + + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + LLAMA_API enum llama_vocab_type 
llama_vocab_type (const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + // Get the model's RoPE frequency scaling factor LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); @@ -996,121 +966,100 @@ extern "C" { int32_t length); // - // Grammar + // Sampling functions // - /// Initialize a llama_grammar. - /// - /// @param rules The rule elements of the grammar to initialize. - /// @param n_rules The number of rules. - /// @param start_rule_index The index of the root rule (the starting point of the grammar). - /// @return The initialized llama_grammar or nullptr if initialization failed. - LLAMA_API struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); - - LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + // TODO: args become llama_sampling_params + // TODO: llama_model should become llama_vocab + LLAMA_API struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root); - LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar); + LLAMA_API void llama_sampling_free(struct llama_sampling * smpl); - /// @details Apply constraints from grammar - LLAMA_API void llama_grammar_sample( - const struct llama_grammar * grammar, - const struct llama_context * ctx, - llama_token_data_array * candidates); - LLAMA_API DEPRECATED(void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar), - "use llama_grammar_sample instead"); + LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl); - /// @details Accepts the sampled token into the grammar - LLAMA_API void llama_grammar_accept_token( - struct llama_grammar * grammar, - struct llama_context * ctx, - llama_token token); - - // - // Sampling functions - // + LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root); // Sets the current rng seed. - LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); - - /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present); - - /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 - /// @param logits Logits extracted from the original generation context. - /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. - /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. 
- LLAMA_API void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale); + LLAMA_API void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. - LLAMA_API void llama_sample_softmax( - struct llama_context * ctx, - llama_token_data_array * candidates); + LLAMA_API void llama_sampling_softmax( + struct llama_sampling * smpl, + llama_token_data_array * candidates); /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_k( - struct llama_context * ctx, - llama_token_data_array * candidates, - int32_t k, - size_t min_keep); + LLAMA_API void llama_sampling_top_k( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + int32_t k, + size_t min_keep); /// @details Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 - LLAMA_API void llama_sample_top_p( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); + LLAMA_API void llama_sampling_top_p( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float p, + size_t min_keep); /// @details Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841 - LLAMA_API void llama_sample_min_p( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); + LLAMA_API void llama_sampling_min_p( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float p, + size_t min_keep); /// @details Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. - LLAMA_API void llama_sample_tail_free( - struct llama_context * ctx, - llama_token_data_array * candidates, - float z, - size_t min_keep); + LLAMA_API void llama_sampling_tail_free( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float z, + size_t min_keep); /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. - LLAMA_API void llama_sample_typical( - struct llama_context * ctx, - llama_token_data_array * candidates, - float p, - size_t min_keep); + LLAMA_API void llama_sampling_typical( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float p, + size_t min_keep); /// @details Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772. - LLAMA_API void llama_sample_entropy( - struct llama_context * ctx, - llama_token_data_array * candidates_p, - float min_temp, - float max_temp, - float exponent_val); + LLAMA_API void llama_sampling_entropy( + struct llama_sampling * smpl, + llama_token_data_array * candidates_p, + float min_temp, + float max_temp, + float exponent_val); + + LLAMA_API void llama_sampling_temp( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float temp); - LLAMA_API void llama_sample_temp( - struct llama_context * ctx, - llama_token_data_array * candidates, - float temp); + /// @details Apply constraints from grammar + LLAMA_API void llama_sampling_grammar( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. 
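Each former llama_sample_* entry point is renamed one-to-one to llama_sampling_* and takes the llama_sampling handle instead of the llama_context. For orientation only (not part of the patch), a hand-rolled truncation pipeline on the renamed API might look like the sketch below; the k, p and temperature values are arbitrary examples and candidates is an assumed caller-built vector of llama_token_data:

    llama_sampling * smpl = llama_get_sampling(ctx);

    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };

    llama_sampling_top_k(smpl, &cur_p, /* k        = */ 40,    /* min_keep = */ 1);
    llama_sampling_top_p(smpl, &cur_p, /* p        = */ 0.95f, /* min_keep = */ 1);
    llama_sampling_temp (smpl, &cur_p, /* temp     = */ 0.80f);

    const llama_token id = llama_sampling_sample(smpl, &cur_p); // uses the sampler's RNG
    llama_sampling_accept(smpl, id);                            // advances the grammar, if one is attached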
+ /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + LLAMA_API void llama_sampling_repetition_penalties( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present); + + /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 + /// @param logits Logits extracted from the original generation context. + /// @param logits_guidance Logits extracted from a separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context. + /// @param scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance. + LLAMA_API void llama_sampling_apply_guidance( + struct llama_sampling * smpl, + float * logits, + float * logits_guidance, + float scale); /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. @@ -1118,36 +1067,41 @@ extern "C" { /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param m The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - int32_t m, - float * mu); + LLAMA_API llama_token llama_sampling_sample_mirostat( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float tau, + float eta, + int32_t m, + float * mu); /// @details Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text. /// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. /// @param eta The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// @param mu Maximum cross-entropy. 
This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. - LLAMA_API llama_token llama_sample_token_mirostat_v2( - struct llama_context * ctx, - llama_token_data_array * candidates, - float tau, - float eta, - float * mu); + LLAMA_API llama_token llama_sampling_sample_mirostat_v2( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + float tau, + float eta, + float * mu); /// @details Selects the token with the highest probability. - /// Does not compute the token probabilities. Use llama_sample_softmax() instead. - LLAMA_API llama_token llama_sample_token_greedy( - struct llama_context * ctx, - llama_token_data_array * candidates); + /// Does not compute the token probabilities. Use llama_sampling_softmax() instead. + LLAMA_API llama_token llama_sampling_sample_greedy( + struct llama_sampling * smpl, + llama_token_data_array * candidates); - /// @details Randomly selects a token from the candidates based on their probabilities using the RNG of ctx. - LLAMA_API llama_token llama_sample_token( - struct llama_context * ctx, - llama_token_data_array * candidates); + /// @details Randomly selects a token from the candidates based on their probabilities + LLAMA_API llama_token llama_sampling_sample( + struct llama_sampling * smpl, + llama_token_data_array * candidates); + + /// @details Accepts the sampled token into the grammar + LLAMA_API void llama_sampling_accept( + struct llama_sampling * smpl, + llama_token token); // // Model split @@ -1166,8 +1120,8 @@ extern "C" { // Performance information LLAMA_API struct llama_timings llama_get_timings(struct llama_context * ctx); - LLAMA_API void llama_print_timings(struct llama_context * ctx); - LLAMA_API void llama_reset_timings(struct llama_context * ctx); + LLAMA_API void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl); + LLAMA_API void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl); // Print system information LLAMA_API const char * llama_print_system_info(void); @@ -1182,59 +1136,4 @@ extern "C" { } #endif -// Internal API to be implemented by llama.cpp and used by tests/benchmarks only -#ifdef LLAMA_API_INTERNAL - -#include -#include -#include - -struct ggml_tensor; - -const std::vector> & llama_internal_get_tensor_map( - struct llama_context * ctx -); - -struct llama_partial_utf8 { - uint32_t value; // bit value so far (unshifted) - int n_remain; // num bytes remaining; -1 indicates invalid sequence -}; - -struct llama_grammar_candidate { - size_t index; - const uint32_t * code_points; - llama_partial_utf8 partial_utf8; -}; - -using llama_grammar_rule = std::vector< llama_grammar_element>; -using llama_grammar_stack = std::vector; - -using llama_grammar_rules = std::vector; -using llama_grammar_stacks = std::vector; -using llama_grammar_candidates = std::vector; - -const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); - llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); - -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks); - -std::vector llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - const llama_grammar_candidates & candidates); - -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - 
llama_partial_utf8 partial_start); - -// Randomly selects a token from the candidates based on their probabilities using given std::mt19937. -// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences. -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng); - -#endif // LLAMA_API_INTERNAL - #endif // LLAMA_H diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index b123d733100ce8..37703839dae9fc 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -4,75 +4,530 @@ #include "llama-sampling.h" #include +#include -// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as -// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. -std::pair, llama_partial_utf8> decode_utf8( - const std::string & src, - llama_partial_utf8 partial_start) { - static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; - const char * pos = src.c_str(); - std::vector code_points; +// +// helpers +// - // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. - code_points.reserve(src.size() + 1); - uint32_t value = partial_start.value; - int n_remain = partial_start.n_remain; +// NOTE: assumes valid utf8 (but checks for overrun) +// TODO: deduplicate +static std::pair decode_utf8(const char * src) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + uint8_t first_byte = static_cast(*src); + uint8_t highbits = first_byte >> 4; + int len = lookup[highbits]; + uint8_t mask = (1 << (8 - len)) - 1; + uint32_t value = first_byte & mask; + const char * end = src + len; // may overrun! 
+ const char * pos = src + 1; + for ( ; pos < end && *pos; pos++) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + } + return std::make_pair(value, pos); +} - // continue previous decode, if applicable - while (*pos != 0 && n_remain > 0) { - uint8_t next_byte = static_cast(*pos); - if ((next_byte >> 6) != 2) { - // invalid sequence, abort - code_points.push_back(0); - return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); +static bool is_digit_char(char c) { + return '0' <= c && c <= '9'; +} + +static bool is_word_char(char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || is_digit_char(c); +} + +static std::pair parse_hex(const char * src, int size) { + const char * pos = src; + const char * end = src + size; + uint32_t value = 0; + for ( ; pos < end && *pos; pos++) { + value <<= 4; + char c = *pos; + if ('a' <= c && c <= 'f') { + value += c - 'a' + 10; + } else if ('A' <= c && c <= 'F') { + value += c - 'A' + 10; + } else if ('0' <= c && c <= '9') { + value += c - '0'; + } else { + break; } - value = (value << 6) + (next_byte & 0x3F); - ++pos; - --n_remain; } + if (pos != end) { + throw std::runtime_error("expecting " + std::to_string(size) + " hex chars at " + src); + } + return std::make_pair(value, pos); +} - if (partial_start.n_remain > 0 && n_remain == 0) { - code_points.push_back(value); +static const char * parse_space(const char * src, bool newline_ok) { + const char * pos = src; + while (*pos == ' ' || *pos == '\t' || *pos == '#' || + (newline_ok && (*pos == '\r' || *pos == '\n'))) { + if (*pos == '#') { + while (*pos && *pos != '\r' && *pos != '\n') { + pos++; + } + } else { + pos++; + } } + return pos; +} - // decode any subsequent utf-8 sequences, which may end in an incomplete one - while (*pos != 0) { - uint8_t first_byte = static_cast(*pos); - uint8_t highbits = first_byte >> 4; - n_remain = lookup[highbits] - 1; +static const char * parse_name(const char * src) { + const char * pos = src; + while (is_word_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting name at ") + src); + } + return pos; +} - if (n_remain < 0) { - // invalid sequence, abort - code_points.clear(); - code_points.push_back(0); - return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); +static const char * parse_int(const char * src) { + const char * pos = src; + while (is_digit_char(*pos)) { + pos++; + } + if (pos == src) { + throw std::runtime_error(std::string("expecting integer at ") + src); + } + return pos; +} + +static std::pair parse_char(const char * src) { + if (*src == '\\') { + switch (src[1]) { + case 'x': return parse_hex(src + 2, 2); + case 'u': return parse_hex(src + 2, 4); + case 'U': return parse_hex(src + 2, 8); + case 't': return std::make_pair('\t', src + 2); + case 'r': return std::make_pair('\r', src + 2); + case 'n': return std::make_pair('\n', src + 2); + case '\\': + case '"': + case '[': + case ']': + return std::make_pair(src[1], src + 2); + default: + throw std::runtime_error(std::string("unknown escape at ") + src); } + } else if (*src) { + return decode_utf8(src); + } + throw std::runtime_error("unexpected end of input"); +} - uint8_t mask = (1 << (7 - n_remain)) - 1; - value = first_byte & mask; +static void print_grammar_char(FILE * file, uint32_t c) { + if (0x20 <= c && c <= 0x7f) { + fprintf(file, "%c", static_cast(c)); + } else { + // cop out of encoding UTF-8 + fprintf(file, "", c); + } +} - ++pos; - while (*pos != 0 && n_remain > 0) { - value 
= (value << 6) + (static_cast(*pos) & 0x3F); - ++pos; - --n_remain; +static bool is_char_element(llama_grammar_element elem) { + switch (elem.type) { + case LLAMA_GRETYPE_CHAR: return true; + case LLAMA_GRETYPE_CHAR_NOT: return true; + case LLAMA_GRETYPE_CHAR_ALT: return true; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: return true; + case LLAMA_GRETYPE_CHAR_ANY: return true; + default: return false; + } +} + +static void print_rule_binary(FILE * file, const llama_grammar_rule & rule) { + for (auto elem : rule) { + switch (elem.type) { + case LLAMA_GRETYPE_END: fprintf(file, "END"); break; + case LLAMA_GRETYPE_ALT: fprintf(file, "ALT"); break; + case LLAMA_GRETYPE_RULE_REF: fprintf(file, "RULE_REF"); break; + case LLAMA_GRETYPE_CHAR: fprintf(file, "CHAR"); break; + case LLAMA_GRETYPE_CHAR_NOT: fprintf(file, "CHAR_NOT"); break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: fprintf(file, "CHAR_RNG_UPPER"); break; + case LLAMA_GRETYPE_CHAR_ALT: fprintf(file, "CHAR_ALT"); break; + case LLAMA_GRETYPE_CHAR_ANY: fprintf(file, "CHAR_ANY"); break; } - if (n_remain == 0) { - code_points.push_back(value); + switch (elem.type) { + case LLAMA_GRETYPE_END: + case LLAMA_GRETYPE_ALT: + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "(%u) ", elem.value); + break; + case LLAMA_GRETYPE_CHAR: + case LLAMA_GRETYPE_CHAR_NOT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "(\""); + print_grammar_char(file, elem.value); + fprintf(file, "\") "); + break; } } - code_points.push_back(0); + fprintf(file, "\n"); +} - return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); +static void print_rule( + FILE * file, + uint32_t rule_id, + const llama_grammar_rule & rule, + const std::map & symbol_id_names) { + if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { + throw std::runtime_error( + "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); + } + fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); + for (size_t i = 0, end = rule.size() - 1; i < end; i++) { + llama_grammar_element elem = rule[i]; + switch (elem.type) { + case LLAMA_GRETYPE_END: + throw std::runtime_error( + "unexpected end of rule: " + std::to_string(rule_id) + "," + + std::to_string(i)); + case LLAMA_GRETYPE_ALT: + fprintf(file, "| "); + break; + case LLAMA_GRETYPE_RULE_REF: + fprintf(file, "%s ", symbol_id_names.at(elem.value).c_str()); + break; + case LLAMA_GRETYPE_CHAR: + fprintf(file, "["); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_NOT: + fprintf(file, "[^"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + fprintf(file, "-"); + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ALT: + if (i == 0 || !is_char_element(rule[i - 1])) { + throw std::runtime_error( + "LLAMA_GRETYPE_CHAR_ALT without preceding char: " + + std::to_string(rule_id) + "," + std::to_string(i)); + } + print_grammar_char(file, elem.value); + break; + case LLAMA_GRETYPE_CHAR_ANY: + fprintf(file, "."); + break; + } + if (is_char_element(elem)) { + switch (rule[i + 1].type) { + case LLAMA_GRETYPE_CHAR_ALT: + case LLAMA_GRETYPE_CHAR_RNG_UPPER: + case LLAMA_GRETYPE_CHAR_ANY: + break; + default: + fprintf(file, "] "); + } + } + } + fprintf(file, "\n"); } -const 
llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { - return grammar->rules; +// +// implementation +// + +uint32_t llama_grammar_parser::get_symbol_id(const char * src, size_t len) { + uint32_t next_id = static_cast(symbol_ids.size()); + auto result = symbol_ids.emplace(std::string(src, len), next_id); + return result.first->second; } -llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { - return grammar->stacks; +uint32_t llama_grammar_parser::generate_symbol_id(const std::string & base_name) { + uint32_t next_id = static_cast(symbol_ids.size()); + symbol_ids[base_name + '_' + std::to_string(next_id)] = next_id; + return next_id; +} + +void llama_grammar_parser::add_rule(uint32_t rule_id, const llama_grammar_rule & rule) { + if (rules.size() <= rule_id) { + rules.resize(rule_id + 1); + } + rules[rule_id] = rule; +} + +const char * llama_grammar_parser::parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested) { + llama_grammar_rule rule; + const char * pos = parse_sequence(src, rule_name, rule, is_nested); + while (*pos == '|') { + rule.push_back({LLAMA_GRETYPE_ALT, 0}); + pos = parse_space(pos + 1, true); + pos = parse_sequence(pos, rule_name, rule, is_nested); + } + rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule(rule_id, rule); + return pos; +} + +const char * llama_grammar_parser::parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested) { + size_t last_sym_start = rule.size(); + const char * pos = src; + + auto handle_repetitions = [&](int min_times, int max_times) { + + if (last_sym_start == rule.size()) { + throw std::runtime_error(std::string("expecting preceding item to */+/?/{ at ") + pos); + } + + // apply transformation to previous symbol (last_sym_start to end) according to + // the following rewrite rules: + // S{m,n} --> S S S (m times) S'(n-m) + // S'(x) ::= S S'(x-1) | + // (... n-m definitions of these S' rules ...) + // S'(1) ::= S | + // S{m,} --> S S S (m times) S' + // S' ::= S S' | + // S* --> S{0,} + // --> S' ::= S S' | + // S+ --> S{1,} + // --> S S' + // S' ::= S S' | + // S? --> S{0,1} + // --> S' + // S' ::= S | + + llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + if (min_times == 0) { + rule.resize(last_sym_start); + } else { + // Repeat the previous elements (min_times - 1) times + for (int i = 1; i < min_times; i++) { + rule.insert(rule.end(), prev_rule.begin(), prev_rule.end()); + } + } + + uint32_t last_rec_rule_id = 0; + auto n_opt = max_times < 0 ? 1 : max_times - min_times; + + llama_grammar_rule rec_rule(prev_rule); + for (int i = 0; i < n_opt; i++) { + rec_rule.resize(prev_rule.size()); + uint32_t rec_rule_id = generate_symbol_id( rule_name); + if (i > 0 || max_times < 0) { + rec_rule.push_back({LLAMA_GRETYPE_RULE_REF, max_times < 0 ? 
rec_rule_id : last_rec_rule_id}); + } + rec_rule.push_back({LLAMA_GRETYPE_ALT, 0}); + rec_rule.push_back({LLAMA_GRETYPE_END, 0}); + add_rule( rec_rule_id, rec_rule); + last_rec_rule_id = rec_rule_id; + } + if (n_opt > 0) { + rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); + } + }; + + while (*pos) { + if (*pos == '"') { // literal string + pos++; + last_sym_start = rule.size(); + while (*pos != '"') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR, char_pair.first}); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '[') { // char range(s) + pos++; + enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; + if (*pos == '^') { + pos++; + start_type = LLAMA_GRETYPE_CHAR_NOT; + } + last_sym_start = rule.size(); + while (*pos != ']') { + if (!*pos) { + throw std::runtime_error("unexpected end of input"); + } + auto char_pair = parse_char(pos); + pos = char_pair.second; + enum llama_gretype type = last_sym_start < rule.size() + ? LLAMA_GRETYPE_CHAR_ALT + : start_type; + + rule.push_back({type, char_pair.first}); + if (pos[0] == '-' && pos[1] != ']') { + if (!pos[1]) { + throw std::runtime_error("unexpected end of input"); + } + auto endchar_pair = parse_char(pos + 1); + pos = endchar_pair.second; + rule.push_back({LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair.first}); + } + } + pos = parse_space(pos + 1, is_nested); + } else if (is_word_char(*pos)) { // rule reference + const char * name_end = parse_name(pos); + uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); + pos = parse_space(name_end, is_nested); + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); + } else if (*pos == '(') { // grouping + // parse nested alternates into synthesized rule + pos = parse_space(pos + 1, true); + uint32_t sub_rule_id = generate_symbol_id(rule_name); + pos = parse_alternates(pos, rule_name, sub_rule_id, true); + last_sym_start = rule.size(); + // output reference to synthesized rule + rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); + if (*pos != ')') { + throw std::runtime_error(std::string("expecting ')' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '.') { // any char + last_sym_start = rule.size(); + rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); + pos = parse_space(pos + 1, is_nested); + } else if (*pos == '*') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, -1); + } else if (*pos == '+') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(1, -1); + } else if (*pos == '?') { + pos = parse_space(pos + 1, is_nested); + handle_repetitions(0, 1); + } else if (*pos == '{') { + pos = parse_space(pos + 1, is_nested); + + if (!is_digit_char(*pos)) { + throw std::runtime_error(std::string("expecting an int at ") + pos); + } + const char * int_end = parse_int(pos); + int min_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + + int max_times = -1; + + if (*pos == '}') { + max_times = min_times; + pos = parse_space(pos + 1, is_nested); + } else if (*pos == ',') { + pos = parse_space(pos + 1, is_nested); + + if (is_digit_char(*pos)) { + const char * int_end = parse_int(pos); + max_times = std::stoul(std::string(pos, int_end - pos)); + pos = parse_space(int_end, is_nested); + } + + if (*pos != '}') { + throw std::runtime_error(std::string("expecting '}' at ") + pos); + } + pos = parse_space(pos + 1, is_nested); + } else 
{ + throw std::runtime_error(std::string("expecting ',' at ") + pos); + } + handle_repetitions(min_times, max_times); + } else { + break; + } + } + return pos; + } + +const char * llama_grammar_parser::parse_rule(const char * src) { + const char * name_end = parse_name(src); + const char * pos = parse_space(name_end, false); + size_t name_len = name_end - src; + uint32_t rule_id = get_symbol_id(src, name_len); + const std::string name(src, name_len); + + if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) { + throw std::runtime_error(std::string("expecting ::= at ") + pos); + } + pos = parse_space(pos + 3, true); + + pos = parse_alternates(pos, name, rule_id, false); + + if (*pos == '\r') { + pos += pos[1] == '\n' ? 2 : 1; + } else if (*pos == '\n') { + pos++; + } else if (*pos) { + throw std::runtime_error(std::string("expecting newline or end at ") + pos); + } + return parse_space(pos, true); + } + +bool llama_grammar_parser::parse(const char * src) { + try { + const char * pos = parse_space(src, true); + while (*pos) { + pos = parse_rule(pos); + } + // Validate the state to ensure that all rules are defined + for (const auto & rule : rules) { + for (const auto & elem : rule) { + if (elem.type == LLAMA_GRETYPE_RULE_REF) { + // Ensure that the rule at that location exists + if (elem.value >= rules.size() || rules[elem.value].empty()) { + // Get the name of the rule that is missing + for (const auto & kv : symbol_ids) { + if (kv.second == elem.value) { + throw std::runtime_error("Undefined rule identifier '" + kv.first + "'"); + } + } + } + } + } + } + } catch (const std::exception & err) { + fprintf(stderr, "%s: error parsing grammar: %s\n", __func__, err.what()); + rules.clear(); + return false; + } + + return true; +} + +void llama_grammar_parser::print(FILE * file) { + try { + std::map symbol_id_names; + for (const auto & kv : symbol_ids) { + symbol_id_names[kv.second] = kv.first; + } + for (size_t i = 0, end = rules.size(); i < end; i++) { + // fprintf(file, "%zu: ", i); + // print_rule_binary(file, rules[i]); + print_rule(file, uint32_t(i), rules[i], symbol_id_names); + // fprintf(file, "\n"); + } + } catch (const std::exception & err) { + fprintf(stderr, "\n%s: error printing grammar: %s\n", __func__, err.what()); + } +} + +llama_grammar_stack llama_grammar_parser::c_rules() const { + llama_grammar_stack ret; + ret.reserve(rules.size()); + for (const auto & rule : rules) { + ret.push_back(rule.data()); + } + return ret; } // returns true iff pos points to the end of one of the definitions of a rule @@ -89,7 +544,6 @@ static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) static std::pair llama_grammar_match_char( const llama_grammar_element * pos, const uint32_t chr) { - bool found = false; bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR || pos->type == LLAMA_GRETYPE_CHAR_ANY; @@ -225,36 +679,6 @@ static void llama_grammar_advance_stack( } } -// takes a set of possible pushdown stacks on a grammar, which are required to -// be positioned at a character range (see `llama_grammar_advance_stack`), and -// produces the N possible stacks if the given char is accepted at those -// positions -void llama_grammar_accept( - const llama_grammar_rules & rules, - const llama_grammar_stacks & stacks, - const uint32_t chr, - llama_grammar_stacks & new_stacks) { - new_stacks.clear(); - - for (const auto & stack : stacks) { - if (stack.empty()) { - continue; - } - - auto match = llama_grammar_match_char(stack.back(), chr); - if (match.first) { - const 
llama_grammar_element * pos = match.second; - - // update top of stack to next element, if any - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(pos)) { - new_stack.push_back(pos); - } - llama_grammar_advance_stack(rules, new_stack, new_stacks); - } - } -} - static llama_grammar_candidates llama_grammar_reject_candidates( const llama_grammar_rules & rules, const llama_grammar_stacks & stacks, @@ -273,61 +697,6 @@ static llama_grammar_candidates llama_grammar_reject_candidates( return rejects; } -llama_grammar_candidates llama_grammar_reject_candidates_for_stack( - const llama_grammar_rules & rules, - const llama_grammar_stack & stack, - const llama_grammar_candidates & candidates) { - - llama_grammar_candidates rejects; - rejects.reserve(candidates.size()); - - if (stack.empty()) { - for (const auto & tok : candidates) { - if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { - rejects.push_back(tok); - } - } - return rejects; - } - - const llama_grammar_element * stack_pos = stack.back(); - - llama_grammar_candidates next_candidates; - next_candidates.reserve(candidates.size()); - - for (const auto & tok : candidates) { - if (*tok.code_points == 0) { - // reached end of full codepoints in token, reject iff it ended in a partial sequence - // that cannot satisfy this position in grammar - if (tok.partial_utf8.n_remain != 0 && - !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { - rejects.push_back(tok); - } - } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { - next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); - } else { - rejects.push_back(tok); - } - } - - const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; - - // update top of stack to next element, if any - llama_grammar_stack stack_after(stack.begin(), stack.end() - 1); - if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { - stack_after.push_back(stack_pos_after); - } - llama_grammar_stacks next_stacks; - llama_grammar_advance_stack(rules, stack_after, next_stacks); - - auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); - for (const auto & tok : next_rejects) { - rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); - } - - return rejects; -} - static bool llama_grammar_detect_left_recursion( const llama_grammar_rules & rules, size_t rule_index, @@ -380,14 +749,11 @@ static bool llama_grammar_detect_left_recursion( return false; } -// -// grammar - external -// - struct llama_grammar * llama_grammar_init_impl( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { + const struct llama_vocab & vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index) { const llama_grammar_element * pos; // copy rule definitions into vectors @@ -438,22 +804,100 @@ struct llama_grammar * llama_grammar_init_impl( // Important: vec_rules has to be moved here, not copied, because stacks contains // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar // then the pointers would be invalidated when the local vec_rules goes out of scope. 
- return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; + return new llama_grammar{ vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 }; +} + +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { + llama_grammar_parser parser; + + // if there is a grammar, parse it + if (!parser.parse(grammar_str)) { + return nullptr; + } + + // will be empty (default) if there are parse errors + if (parser.rules.empty()) { + fprintf(stderr, "%s: failed to parse grammar\n", __func__); + return nullptr; + } + + // Ensure that there is a "root" node. + if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { + fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + return nullptr; + } + + std::vector grammar_rules(parser.c_rules()); + + const size_t n_rules = grammar_rules.size(); + const size_t start_rule_index = parser.symbol_ids.at(grammar_root); + + const llama_grammar_element * pos; + + // copy rule definitions into vectors + llama_grammar_rules vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = grammar_rules[i]; pos->type != LLAMA_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); + } + + // Check for left recursion + std::vector rules_visited(n_rules); + std::vector rules_in_progress(n_rules); + std::vector rules_may_be_empty(n_rules); + for (size_t i = 0; i < n_rules; i++) { + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + return nullptr; + } + } + + // loop over alternates of start rule to build initial stacks + llama_grammar_stacks stacks; + pos = vec_rules[start_rule_index].data(); + do { + llama_grammar_stack stack; + if (!llama_grammar_is_end_of_sequence(pos)) { + // if alternate is nonempty, add to stack + stack.push_back(pos); + } + llama_grammar_advance_stack(vec_rules, stack, stacks); + while (!llama_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == LLAMA_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + // Important: vec_rules has to be moved here, not copied, because stacks contains + // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar + // then the pointers would be invalidated when the local vec_rules goes out of scope. 
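The string-based llama_grammar_init_impl overload added here is what backs the new public constructor llama_sampling_init(model, grammar_str, grammar_root): the GBNF text is parsed, checked for a root rule and for left recursion, and the initial pushdown stacks are built. A minimal sketch of driving it from user code, with a toy GBNF grammar that is illustrative only and not taken from the patch:

    static const char * grammar_str = "root ::= \"yes\" | \"no\"\n";

    llama_sampling * smpl = llama_sampling_init(model, grammar_str, "root");
    if (smpl == NULL) {
        fprintf(stderr, "failed to parse grammar\n");
        return 1;
    }

    // per step: mask out candidates the grammar rejects, sample, then advance the grammar
    llama_sampling_grammar(smpl, &candidates_p);
    const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);
    llama_sampling_accept(smpl, new_token_id);

    llama_sampling_free(smpl);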
+ return new llama_grammar{ vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 }; } void llama_grammar_free_impl(struct llama_grammar * grammar) { delete grammar; } -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar) { - llama_grammar * result = new llama_grammar{ grammar->rules, grammar->stacks, grammar->partial_utf8 }; +struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar) { + llama_grammar * result = new llama_grammar{ grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 }; // redirect elements in stacks to point to new rules for (size_t is = 0; is < result->stacks.size(); is++) { for (size_t ie = 0; ie < result->stacks[is].size(); ie++) { - for (size_t ir0 = 0; ir0 < grammar->rules.size(); ir0++) { - for (size_t ir1 = 0; ir1 < grammar->rules[ir0].size(); ir1++) { - if (grammar->stacks[is][ie] == &grammar->rules[ir0][ir1]) { + for (size_t ir0 = 0; ir0 < grammar.rules.size(); ir0++) { + for (size_t ir1 = 0; ir1 < grammar.rules[ir0].size(); ir1++) { + if (grammar.stacks[is][ie] == &grammar.rules[ir0][ir1]) { result->stacks[is][ie] = &result->rules[ir0][ir1]; } } @@ -464,14 +908,9 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram return result; } -void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) { - GGML_ASSERT(grammar); - GGML_ASSERT(vocab); - - int64_t t_start_sample_us = ggml_time_us(); - +void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_data_array * candidates) { bool allow_eog = false; - for (const auto & stack : grammar->stacks) { + for (const auto & stack : grammar.stacks) { if (stack.empty()) { allow_eog = true; break; @@ -486,33 +925,29 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc for (size_t i = 0; i < candidates->size; ++i) { const llama_token id = candidates->data[i].id; - const std::string & piece = vocab->cache_token_to_piece.at(id); + const std::string & piece = grammar.vocab.cache_token_to_piece.at(id); - if (llama_token_is_eog_impl(*vocab, id)) { + if (llama_token_is_eog_impl(grammar.vocab, id)) { if (!allow_eog) { candidates->data[i].logit = -INFINITY; } } else if (piece.empty() || piece[0] == 0) { candidates->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8)); + candidates_decoded.push_back(decode_utf8(piece, grammar.partial_utf8)); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } - const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + const auto rejects = llama_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); for (const auto & reject : rejects) { candidates->data[reject.index].logit = -INFINITY; } - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) { - const int64_t t_start_sample_us = ggml_time_us(); - - if (llama_token_is_eog_impl(*vocab, token)) { - for (const auto & stack : grammar->stacks) { +void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { + if (llama_token_is_eog_impl(grammar.vocab, token)) { + for (const auto & stack : grammar.stacks) 
{ if (stack.empty()) { return; } @@ -520,20 +955,169 @@ void llama_grammar_accept_token_impl(struct llama_grammar * grammar, const struc GGML_ABORT("fatal error"); } - const std::string & piece = vocab->cache_token_to_piece.at(token); + const std::string & piece = grammar.vocab.cache_token_to_piece.at(token); // Note terminating 0 in decoded string - const auto decoded = decode_utf8(piece, grammar->partial_utf8); + const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto & code_points = decoded.first; llama_grammar_stacks tmp_new_stacks; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { - llama_grammar_accept(grammar->rules, grammar->stacks, *it, tmp_new_stacks); - grammar->stacks = tmp_new_stacks; + llama_grammar_accept(grammar.rules, grammar.stacks, *it, tmp_new_stacks); + grammar.stacks = tmp_new_stacks; } - grammar->partial_utf8 = decoded.second; - GGML_ASSERT(!grammar->stacks.empty()); + grammar.partial_utf8 = decoded.second; + GGML_ASSERT(!grammar.stacks.empty()); +} + +//////////////////// + +const llama_grammar_rules & llama_grammar_get_rules(const struct llama_grammar * grammar) { + return grammar->rules; +} - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; +llama_grammar_stacks & llama_grammar_get_stacks(struct llama_grammar * grammar) { + return grammar->stacks; +} + +std::pair, llama_partial_utf8> decode_utf8( + const std::string & src, + llama_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; + const char * pos = src.c_str(); + std::vector code_points; + + // common english strings have the same number of codepoints and bytes. `+ 1` for the terminating 0. + code_points.reserve(src.size() + 1); + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one + while (*pos != 0) { + uint8_t first_byte = static_cast(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), llama_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); + } + } + code_points.push_back(0); + + return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); +} + +void llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const uint32_t chr, + llama_grammar_stacks & new_stacks) { + new_stacks.clear(); + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = llama_grammar_match_char(stack.back(), chr); + if (match.first) { + const llama_grammar_element * pos = match.second; + + // update top of stack to next element, if any + 
llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + llama_grammar_advance_stack(rules, new_stack, new_stacks); + } + } +} + +llama_grammar_candidates llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates) { + + llama_grammar_candidates rejects; + rejects.reserve(candidates.size()); + + if (stack.empty()) { + for (const auto & tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } + return rejects; + } + + const llama_grammar_element * stack_pos = stack.back(); + + llama_grammar_candidates next_candidates; + next_candidates.reserve(candidates.size()); + + for (const auto & tok : candidates) { + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && + !llama_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); + } + } else if (llama_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.index, tok.code_points + 1, tok.partial_utf8 }); + } else { + rejects.push_back(tok); + } + } + + const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + llama_grammar_stack stack_after(stack.begin(), stack.end() - 1); + if (!llama_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + llama_grammar_stacks next_stacks; + llama_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = llama_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (const auto & tok : next_rejects) { + rejects.push_back({ tok.index, tok.code_points - 1, tok.partial_utf8 }); + } + + return rejects; } diff --git a/src/llama-grammar.h b/src/llama-grammar.h index 695ea0632bb84c..2fbb954088fc5c 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -2,38 +2,153 @@ #include "llama-impl.h" +#include + struct llama_vocab; -struct llama_sampling; + +// grammar element type +enum llama_gretype { + // end of rule definition + LLAMA_GRETYPE_END = 0, + + // start of alternate definition for rule + LLAMA_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + LLAMA_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + LLAMA_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b] [^abc]) + LLAMA_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding LLAMA_GRETYPE_CHAR or LLAMA_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + LLAMA_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding LLAMA_GRETYPE_CHAR or + // LLAMA_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + LLAMA_GRETYPE_CHAR_ALT = 6, + + // any character (.) 
+ LLAMA_GRETYPE_CHAR_ANY = 7, +}; + +typedef struct llama_grammar_element { + enum llama_gretype type; + uint32_t value; // Unicode code point or rule ID +} llama_grammar_element; + +struct llama_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct llama_grammar_candidate { + size_t index; + const uint32_t * code_points; + llama_partial_utf8 partial_utf8; +}; + +using llama_grammar_rule = std::vector< llama_grammar_element>; +using llama_grammar_stack = std::vector; + +using llama_grammar_rules = std::vector; +using llama_grammar_stacks = std::vector; +using llama_grammar_candidates = std::vector; + +const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar); + llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar); + +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. +std::pair, llama_partial_utf8> decode_utf8( + const std::string & src, + llama_partial_utf8 partial_start); + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `llama_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +void llama_grammar_accept( + const llama_grammar_rules & rules, + const llama_grammar_stacks & stacks, + const uint32_t chr, + llama_grammar_stacks & new_stacks); + +std::vector llama_grammar_reject_candidates_for_stack( + const llama_grammar_rules & rules, + const llama_grammar_stack & stack, + const llama_grammar_candidates & candidates); + +struct llama_grammar_parser { + std::map symbol_ids; + + llama_grammar_rules rules; + + llama_grammar_stack c_rules() const; + + uint32_t get_symbol_id(const char * src, size_t len); + uint32_t generate_symbol_id(const std::string & base_name); + + void add_rule(uint32_t rule_id, const llama_grammar_rule & rule); + + const char * parse_alternates( + const char * src, + const std::string & rule_name, + uint32_t rule_id, + bool is_nested); + + const char * parse_sequence( + const char * src, + const std::string & rule_name, + llama_grammar_rule & rule, + bool is_nested); + + const char * parse_rule(const char * src); + + bool parse(const char * src); + void print(FILE * file); +}; struct llama_grammar { - const llama_grammar_rules rules; + const llama_vocab & vocab; + + const llama_grammar_rules rules; // TODO: shared ptr llama_grammar_stacks stacks; // buffer for partially generated UTF-8 sequence from accepted tokens llama_partial_utf8 partial_utf8; + + mutable int64_t t_total_us; + + mutable int32_t n_sample; + mutable int32_t n_accept; }; // // internal API // +// TODO: temporary until the tests are fixed struct llama_grammar * llama_grammar_init_impl( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index); + const struct llama_vocab & vocab, + const llama_grammar_element ** rules, + size_t n_rules, + size_t start_rule_index); + +struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); void llama_grammar_free_impl(struct llama_grammar * grammar); -struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * grammar); +struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar & grammar); -void 
llama_grammar_sample_impl( - const struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, +// TODO: move the API below as member functions of llama_grammar +void llama_grammar_apply_impl( + const struct llama_grammar & grammar, llama_token_data_array * candidates); -void llama_grammar_accept_token_impl( - struct llama_grammar * grammar, - const struct llama_vocab * vocab, - const struct llama_sampling * smpl, +void llama_grammar_accept_impl( + struct llama_grammar & grammar, llama_token token); diff --git a/src/llama-impl.h b/src/llama-impl.h index 399b134a7f9bc6..f71c1ec02d0037 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -1,8 +1,11 @@ #pragma once -#define LLAMA_API_INTERNAL #include "llama.h" +#include +#include +#include + #ifdef __GNUC__ #ifdef __MINGW32__ #define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) @@ -39,3 +42,7 @@ static void replace_all(std::string & s, const std::string & search, const std:: pos += replace.length(); } } + +const std::vector> & llama_internal_get_tensor_map( + struct llama_context * ctx +); diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 8910f6d6542e91..bf2e88e71def19 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1,5 +1,8 @@ #include "llama-sampling.h" +#include "llama-vocab.h" +#include "llama-grammar.h" + #include #include #include @@ -21,19 +24,66 @@ static void llama_log_softmax(float * array, size_t size) { } } -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) { +llama_sampling::llama_sampling(uint32_t n_vocab) : n_vocab(n_vocab) { +} + +llama_sampling::llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) : n_vocab(vocab.n_vocab) { + if (grammar_str != nullptr && grammar_str[0] != '\0') { + grammar = llama_grammar_init_impl(vocab, grammar_str, grammar_root); + } +} + +llama_sampling::~llama_sampling() { + if (grammar) { + llama_grammar_free_impl(grammar); + } +} + +struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { + return new llama_sampling(vocab, grammar_str, grammar_root); +} + +void llama_sampling_free_impl(struct llama_sampling * sampling) { + delete sampling; +} + +struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl) { + auto * result = new llama_sampling(smpl.n_vocab); + + if (smpl.grammar) { + result->grammar = llama_grammar_copy_impl(*smpl.grammar); + } + + return result; +} + +void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root) { + // TODO: this is dumb, need to fix + const struct llama_vocab * vocab = nullptr; + + if (smpl.grammar) { + vocab = &smpl.grammar->vocab; + + llama_grammar_free_impl(smpl.grammar); + smpl.grammar = nullptr; + } + + if (grammar_str != nullptr && grammar_str[0] != '\0') { + smpl.grammar = llama_grammar_init_impl(*vocab, grammar_str, grammar_root); + } +} + +void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed) { if (seed == LLAMA_DEFAULT_SEED) { seed = time(NULL); } - smpl->rng.seed(seed); + smpl.rng.seed(seed); } -void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { +void llama_sampling_softmax_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); - const int64_t t_start_sample_us = ggml_time_us(); - 
// Sort the logits in descending order if (!candidates->sorted) { std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { @@ -44,28 +94,24 @@ void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_ar float max_l = candidates->data[0].logit; float cum_sum = 0.0f; + for (size_t i = 0; i < candidates->size; ++i) { float p = expf(candidates->data[i].logit - max_l); candidates->data[i].p = p; cum_sum += p; } + for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].p /= cum_sum; } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { +void llama_sampling_top_k_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, int32_t k, size_t min_keep) { // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast // if (k >= (int32_t)candidates->size) { // return; // } - const int64_t t_start_sample_us = ggml_time_us(); - if (k <= 0) { k = candidates->size; } @@ -133,20 +179,14 @@ void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_arra candidates->sorted = true; } candidates->size = k; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_top_p_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { if (p >= 1.0f) { return; } - llama_sample_softmax_impl(smpl, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampling_softmax_impl(smpl, candidates); // Compute the cumulative probabilities float cum_sum = 0.0f; @@ -165,19 +205,13 @@ void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_arra // Resize the output vector to keep only the top-p tokens candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_min_p_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float p, size_t min_keep) { if (p <= 0.0f || !candidates->size) { return; } - const int64_t t_start_sample_us = ggml_time_us(); - bool min_p_applied = false; // if the candidates aren't sorted, try the unsorted implementation first @@ -226,19 +260,14 @@ void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_arra // Resize the output vector to keep only the matching tokens candidates->size = i; } - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { +void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep) { if (z >= 1.0f || candidates->size <= 2) { return; } - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampling_softmax_impl(smpl, candidates); // Compute the first and second derivatives std::vector first_derivatives(candidates->size - 1); @@ -285,13 +314,9 @@ void 
llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_ // Resize the output vector to keep only the tokens above the tail location candidates->size = last_idx; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { +void llama_sampling_typical_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr if (p >= 1.0f) { @@ -299,9 +324,7 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar } // Compute the softmax of logits and calculate entropy - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); - - const int64_t t_start_sample_us = ggml_time_us(); + llama_sampling_softmax_impl(smpl, candidates); float entropy = 0.0f; for (size_t i = 0; i < candidates->size; ++i) { @@ -349,15 +372,9 @@ void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_ar std::copy(new_candidates.begin(), new_candidates.end(), candidates->data); candidates->size = new_candidates.size(); candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { - const int64_t t_start_sample_us = ggml_time_us(); - +void llama_sampling_entropy_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) { // no need to do anything if there is only one (or zero) candidates if(candidates->size <= 1) { return; @@ -366,7 +383,7 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar // Calculate maximum possible entropy float max_entropy = -logf(1.0f / candidates->size); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + llama_sampling_softmax_impl(smpl, candidates); // Calculate entropy of the softmax probabilities float entropy = 0.0f; @@ -416,38 +433,32 @@ void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_ar LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f); } #endif - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) { - const int64_t t_start_sample_us = ggml_time_us(); - +void llama_sampling_temp_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, float temp) { for (size_t i = 0; i < candidates->size; ++i) { candidates->data[i].logit /= temp; } +} - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; +void llama_sampling_grammar_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) { + if (smpl.grammar) { + llama_grammar_apply_impl(*smpl.grammar, candidates); } } -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, +void llama_sampling_repetition_penalties_impl( + struct llama_sampling & /*smpl*/, llama_token_data_array * candidates, const llama_token * last_tokens, size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { + float penalty_repeat, + float penalty_freq, + float penalty_present) { if 
(penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) { return; } - const int64_t t_start_sample_us = ggml_time_us(); - // Create a frequency map to count occurrences of each token in last_tokens std::unordered_map token_count; for (size_t i = 0; i < penalty_last_n; ++i) { @@ -475,43 +486,30 @@ void llama_sample_repetition_penalties_impl( } candidates->sorted = false; - - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } } -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, +void llama_sampling_apply_guidance_impl( + struct llama_sampling & smpl, float * logits, float * logits_guidance, float scale) { - GGML_ASSERT(smpl); - - const auto t_start_sample_us = ggml_time_us(); - const auto n_vocab = smpl->n_vocab; + const auto n_vocab = smpl.n_vocab; llama_log_softmax(logits, n_vocab); llama_log_softmax(logits_guidance, n_vocab); - for (int i = 0; i < n_vocab; ++i) { + for (uint32_t i = 0; i < n_vocab; ++i) { auto & l = logits[i]; const auto & g = logits_guidance[i]; l = scale * (l - g) + g; } - - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; } -llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - GGML_ASSERT(smpl); - - const int32_t n_vocab = float(smpl->n_vocab); +llama_token llama_sampling_sample_mirostat_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { + const int32_t n_vocab = float(smpl.n_vocab); - int64_t t_start_sample_us = ggml_time_us(); - - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); + llama_sampling_softmax_impl(smpl, candidates); // Estimate s_hat using the most probable m tokens float s_hat = 0.0; @@ -530,10 +528,8 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat); // Sample the next word X using top-k sampling - llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1); - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); + llama_sampling_top_k_impl(smpl, candidates, int(k), 1); + llama_token X = llama_sampling_sample_impl(smpl, candidates); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -545,15 +541,11 @@ llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama // Update mu using the learning rate and error *mu = *mu - eta * e; - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; return X; } -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { - int64_t t_start_sample_us; - t_start_sample_us = ggml_time_us(); - - llama_sample_softmax_impl(smpl, candidates); +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { + llama_sampling_softmax_impl(smpl, candidates); // Truncate the words with surprise values greater than mu candidates->size = std::distance(candidates->data, std::find_if(candidates->data, 
candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -564,16 +556,11 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll candidates->size = 1; } - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } - // Normalize the probabilities of the remaining words - llama_sample_softmax_impl(smpl, candidates); + llama_sampling_softmax_impl(smpl, candidates); // Sample the next word X from the remaining words - llama_token X = llama_sample_token_impl(smpl, candidates); - t_start_sample_us = ggml_time_us(); + llama_token X = llama_sampling_sample_impl(smpl, candidates); // Compute error as the difference between observed surprise and target surprise value size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) { @@ -585,33 +572,22 @@ llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, ll // Update mu using the learning rate and error *mu = *mu - eta * e; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - } return X; } -llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - const int64_t t_start_sample_us = ggml_time_us(); - +llama_token llama_sampling_sample_greedy_impl(struct llama_sampling & /*smpl*/, llama_token_data_array * candidates) { // Find max element auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) { return a.logit < b.logit; }); llama_token result = max_iter->id; - if (smpl) { - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - } + return result; } -llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) { - GGML_ASSERT(smpl); - - const int64_t t_start_sample_us = ggml_time_us(); - llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates); +llama_token llama_sampling_sample_with_rng_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng) { + llama_sampling_softmax_impl(smpl, candidates); std::vector probs; probs.reserve(candidates->size); @@ -624,12 +600,17 @@ llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama llama_token result = candidates->data[idx].id; - smpl->t_sample_us += ggml_time_us() - t_start_sample_us; - smpl->n_sample++; - return result; } -llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng); +llama_token llama_sampling_sample_impl(struct llama_sampling & smpl, llama_token_data_array * candidates) { + return llama_sampling_sample_with_rng_impl(smpl, candidates, smpl.rng); +} + +void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token) { + // TODO: implement token storing in history + + if (smpl.grammar) { + llama_grammar_accept_impl(*smpl.grammar, token); + } } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index f7f8e3ef706bc8..8ee593c723c6ba 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -1,40 +1,54 @@ #pragma once #include "llama-impl.h" +#include "llama-grammar.h" + +struct llama_vocab; +struct llama_grammar; struct llama_sampling { - llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {} + llama_sampling(uint32_t n_vocab); + 
llama_sampling(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); + ~llama_sampling(); + + const uint32_t n_vocab; std::mt19937 rng; - int32_t n_vocab = 0; + struct llama_grammar * grammar = nullptr; - mutable int64_t t_sample_us = 0; - mutable int32_t n_sample = 0; + mutable int64_t t_total_us = 0; - void reset_timings() const { - t_sample_us = 0; - n_sample = 0; - } + mutable int32_t n_sample = 0; }; // // internal API // -void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed); +struct llama_sampling * llama_sampling_init_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root); + +void llama_sampling_free_impl(struct llama_sampling * sampling); + +struct llama_sampling * llama_sampling_cp_impl(const struct llama_sampling & smpl); + +void llama_sampling_reset_impl(struct llama_sampling & smpl, const char * grammar_str, const char * grammar_root); + +// TODO: move the API below as member functions of llama_sampling +void llama_sampling_set_rng_seed_impl(struct llama_sampling & smpl, uint32_t seed); -void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); -void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep); -void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep); -void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); -void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp); +void llama_sampling_softmax_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); +void llama_sampling_top_k_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep); +void llama_sampling_top_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_min_p_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_tail_free_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float z, size_t min_keep); +void llama_sampling_typical_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float p, size_t min_keep); +void llama_sampling_entropy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val); +void llama_sampling_temp_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float temp); +void llama_sampling_grammar_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); -void llama_sample_repetition_penalties_impl( - struct llama_sampling * smpl, +void llama_sampling_repetition_penalties_impl( + struct llama_sampling & smpl, llama_token_data_array * candidates, const llama_token * last_tokens, size_t penalty_last_n, @@ -42,15 +56,16 @@ void llama_sample_repetition_penalties_impl( float penalty_freq, 
float penalty_present); -void llama_sample_apply_guidance_impl( - struct llama_sampling * smpl, +void llama_sampling_apply_guidance_impl( + struct llama_sampling & smpl, float * logits, float * logits_guidance, float scale); -llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); -llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); -llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); -llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng); -llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates); +llama_token llama_sampling_sample_mirostat_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu); +llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_sampling & smpl, llama_token_data_array * candidates, float tau, float eta, float * mu); +llama_token llama_sampling_sample_greedy_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); +llama_token llama_sampling_sample_with_rng_impl (struct llama_sampling & smpl, llama_token_data_array * candidates, std::mt19937 & rng); +llama_token llama_sampling_sample_impl (struct llama_sampling & smpl, llama_token_data_array * candidates); +void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token); diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 7adfc16da3af38..f48bbe1bff22f1 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -18,6 +18,8 @@ struct llama_vocab { tattr attr; }; + uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab + enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; @@ -62,8 +64,6 @@ struct llama_vocab { int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; }; -const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx); - // // internal API // @@ -76,6 +76,7 @@ std::vector llama_tokenize_internal( bool add_special, bool parse_special = false); +// TODO: move the API below as member functions of llama_vocab llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch); const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token); diff --git a/src/llama.cpp b/src/llama.cpp index 97dd1b3fea4b94..241689e6a44fcc 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1,6 +1,5 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-grammar.h" #include "llama-sampling.h" #include "unicode.h" @@ -148,6 +147,19 @@ static void zeros(std::ofstream & file, size_t n) { } } +struct time_meas { + time_meas(int64_t & t_acc) : t_start_us(ggml_time_us()), t_acc(t_acc) {} + + ~time_meas() { + t_acc += ggml_time_us() - t_start_us; + } + + const int64_t t_start_us; + + int64_t & t_acc; +}; + + LLAMA_ATTRIBUTE_FORMAT(1, 2) static std::string format(const char * fmt, ...) 
{ va_list ap; @@ -2641,7 +2653,7 @@ struct llama_model { struct llama_context { llama_context(const llama_model & model) : model(model) - , sampling(llama_n_vocab(&model)) + , sampling(model.vocab, nullptr, nullptr) // by default, no grammar , t_start_us(model.t_start_us) , t_load_us(model.t_load_us) {} @@ -2675,16 +2687,16 @@ struct llama_context { bool has_evaluated_once = false; - int64_t t_start_us; - int64_t t_load_us; - int64_t t_p_eval_us = 0; - int64_t t_eval_us = 0; + mutable int64_t t_start_us; + mutable int64_t t_load_us; + mutable int64_t t_p_eval_us = 0; + mutable int64_t t_eval_us = 0; - int64_t t_compute_start_us = 0; - int64_t n_queued_tokens = 0; + mutable int64_t t_compute_start_us = 0; + mutable int64_t n_queued_tokens = 0; - int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) - int32_t n_eval = 0; // number of eval calls + mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) + mutable int32_t n_eval = 0; // number of eval calls // host buffer for the model output (logits and embeddings) ggml_backend_buffer_t buf_output = nullptr; @@ -5477,6 +5489,7 @@ static void llm_load_vocab( const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); + vocab.n_vocab = n_vocab; vocab.id_to_token.resize(n_vocab); for (uint32_t i = 0; i < n_vocab; i++) { @@ -16565,8 +16578,9 @@ struct llama_context * llama_new_context_with_model( ctx->abort_callback = params.abort_callback; ctx->abort_callback_data = params.abort_callback_data; - ctx->sampling.rng = std::mt19937(params.seed); - ctx->logits_all = params.logits_all; + llama_sampling_set_rng_seed_impl(ctx->sampling, params.seed); + + ctx->logits_all = params.logits_all; uint32_t kv_size = cparams.n_ctx; ggml_type type_k = params.type_k; @@ -16842,14 +16856,6 @@ void llama_free(struct llama_context * ctx) { delete ctx; } -const struct llama_model * llama_get_model(const struct llama_context * ctx) { - return &ctx->model; -} - -const struct llama_vocab * llama_get_vocab(const struct llama_context * ctx) { - return &ctx->model.vocab; -} - uint32_t llama_n_ctx(const struct llama_context * ctx) { return ctx->cparams.n_ctx; } @@ -16870,6 +16876,34 @@ enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { return model->vocab.type; } +int32_t llama_n_vocab(const struct llama_model * model) { + return model->hparams.n_vocab; +} + +int32_t llama_n_ctx_train(const struct llama_model * model) { + return model->hparams.n_ctx_train; +} + +int32_t llama_n_embd(const struct llama_model * model) { + return model->hparams.n_embd; +} + +int32_t llama_n_layer(const struct llama_model * model) { + return model->hparams.n_layer; +} + +const struct llama_model * llama_get_model(const struct llama_context * ctx) { + return &ctx->model; +} + +struct llama_sampling * llama_get_sampling(struct llama_context * ctx) { + return &ctx->sampling; +} + +enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { + return ctx->cparams.pooling_type; +} + enum llama_rope_type llama_rope_type(const struct llama_model * model) { switch (model->arch) { // these models do not use RoPE @@ -16929,26 +16963,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { return LLAMA_ROPE_TYPE_NONE; } -enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { - return ctx->cparams.pooling_type; -} - -int32_t llama_n_vocab(const struct llama_model * model) { - return model->hparams.n_vocab; -} - -int32_t 
llama_n_ctx_train(const struct llama_model * model) { - return model->hparams.n_ctx_train; -} - -int32_t llama_n_embd(const struct llama_model * model) { - return model->hparams.n_embd; -} - -int32_t llama_n_layer(const struct llama_model * model) { - return model->hparams.n_layer; -} - float llama_rope_freq_scale_train(const struct llama_model * model) { return model->hparams.rope_freq_scale_train; } @@ -18891,124 +18905,164 @@ int32_t llama_chat_apply_template( } // -// grammar +// sampling // -struct llama_grammar * llama_grammar_init( - const llama_grammar_element ** rules, - size_t n_rules, - size_t start_rule_index) { - return llama_grammar_init_impl(rules, n_rules, start_rule_index); +struct llama_sampling * llama_sampling_init(const struct llama_model * model, const char * grammar_str, const char * grammar_root) { + return llama_sampling_init_impl(model->vocab, grammar_str, grammar_root); } -void llama_grammar_free(struct llama_grammar * grammar) { - llama_grammar_free_impl(grammar); -} +void llama_sampling_free(struct llama_sampling * smpl) { + if (smpl == nullptr) { + return; + } -struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) { - return llama_grammar_copy_impl(grammar); + llama_sampling_free_impl(smpl); } -void llama_grammar_sample( - const struct llama_grammar * grammar, - const struct llama_context * ctx, - llama_token_data_array * candidates) { - llama_grammar_sample_impl(grammar, &ctx->model.vocab, &ctx->sampling, candidates); +struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl) { + return llama_sampling_cp_impl(*smpl); } -void llama_sample_grammar( - struct llama_context * ctx, - llama_token_data_array * candidates, - const struct llama_grammar * grammar) { - llama_grammar_sample(grammar, ctx, candidates); +void llama_sampling_reset(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) { + llama_sampling_reset_impl(*smpl, grammar_str, grammar_root); } -void llama_grammar_accept_token( - struct llama_grammar * grammar, - struct llama_context * ctx, - llama_token token) { - llama_grammar_accept_token_impl(grammar, &ctx->model.vocab, &ctx->sampling, token); +void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) { + llama_sampling_set_rng_seed_impl(*smpl, seed); } -// -// sampling -// +void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_total_us); -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - llama_set_rng_seed_impl(&ctx->sampling, seed); + llama_sampling_softmax_impl(*smpl, candidates); } -void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { - llama_sample_softmax_impl(ctx ? &ctx->sampling : nullptr, candidates); +void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) { + time_meas tm(smpl->t_total_us); + + llama_sampling_top_k_impl(*smpl, candidates, k, min_keep); } -void llama_sample_top_k(struct llama_context * ctx, llama_token_data_array * candidates, int32_t k, size_t min_keep) { - llama_sample_top_k_impl(ctx ? 
&ctx->sampling : nullptr, candidates, k, min_keep); +void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { + time_meas tm(smpl->t_total_us); + + llama_sampling_top_p_impl(*smpl, candidates, p, min_keep); } -void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_top_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { + time_meas tm(smpl->t_total_us); + + llama_sampling_min_p_impl(*smpl, candidates, p, min_keep); } -void llama_sample_min_p(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_min_p_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) { + time_meas tm(smpl->t_total_us); + + llama_sampling_tail_free_impl(*smpl, candidates, z, min_keep); } -void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * candidates, float z, size_t min_keep) { - llama_sample_tail_free_impl(ctx ? &ctx->sampling : nullptr, candidates, z, min_keep); +void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) { + time_meas tm(smpl->t_total_us); + + llama_sampling_typical_impl(*smpl, candidates, p, min_keep); } -void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { - llama_sample_typical_impl(ctx ? &ctx->sampling : nullptr, candidates, p, min_keep); +void llama_sampling_entropy(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { + time_meas tm(smpl->t_total_us); + + llama_sampling_entropy_impl(*smpl, candidates_p, min_temp, max_temp, exponent_val); } -void llama_sample_entropy(struct llama_context * ctx, llama_token_data_array * candidates_p, float min_temp, float max_temp, float exponent_val) { - llama_sample_entropy_impl(ctx ? &ctx->sampling : nullptr, candidates_p, min_temp, max_temp, exponent_val); +void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates_p, float temp) { + time_meas tm(smpl->t_total_us); + + llama_sampling_temp_impl(*smpl, candidates_p, temp); } -void llama_sample_temp(struct llama_context * ctx, llama_token_data_array * candidates_p, float temp) { - llama_sample_temp_impl(ctx ? &ctx->sampling : nullptr, candidates_p, temp); +void llama_sampling_grammar( + struct llama_sampling * smpl, + llama_token_data_array * candidates) { + time_meas tm(smpl->t_total_us); // TODO: measure grammar time separately from sampling + + llama_sampling_grammar_impl(*smpl, candidates); } -void llama_sample_repetition_penalties( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t penalty_last_n, - float penalty_repeat, - float penalty_freq, - float penalty_present) { - llama_sample_repetition_penalties_impl(ctx ? 
&ctx->sampling : nullptr, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); +void llama_sampling_repetition_penalties( + struct llama_sampling * smpl, + llama_token_data_array * candidates, + const llama_token * last_tokens, + size_t penalty_last_n, + float penalty_repeat, + float penalty_freq, + float penalty_present) { + time_meas tm(smpl->t_total_us); + + llama_sampling_repetition_penalties_impl(*smpl, candidates, last_tokens, penalty_last_n, penalty_repeat, penalty_freq, penalty_present); } -void llama_sample_apply_guidance( - struct llama_context * ctx, - float * logits, - float * logits_guidance, - float scale) { - llama_sample_apply_guidance_impl(&ctx->sampling, logits, logits_guidance, scale); +void llama_sampling_apply_guidance( + struct llama_sampling * smpl, + float * logits, + float * logits_guidance, + float scale) { + time_meas tm(smpl->t_total_us); + + llama_sampling_apply_guidance_impl(*smpl, logits, logits_guidance, scale); } -llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { - return llama_sample_token_mirostat_impl(&ctx->sampling, candidates, tau, eta, m, mu); +llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) { + time_meas tm(smpl->t_total_us); + + auto res = llama_sampling_sample_mirostat_impl(*smpl, candidates, tau, eta, m, mu); + + smpl->n_sample++; + + return res; } -llama_token llama_sample_token_mirostat_v2(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, float * mu) { - return llama_sample_token_mirostat_v2_impl(ctx ? &ctx->sampling : nullptr, candidates, tau, eta, mu); +llama_token llama_sampling_sample_mirostat_v2(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) { + time_meas tm(smpl->t_total_us); + + auto res = llama_sampling_sample_mirostat_v2_impl(*smpl, candidates, tau, eta, mu); + + smpl->n_sample++; + + return res; } -llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_greedy_impl(ctx ? 
&ctx->sampling : nullptr, candidates); +llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_total_us); + + auto res = llama_sampling_sample_greedy_impl(*smpl, candidates); + + smpl->n_sample++; + + return res; } -llama_token llama_sample_token_with_rng(struct llama_context * ctx, llama_token_data_array * candidates, std::mt19937 & rng) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, rng); +llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) { + time_meas tm(smpl->t_total_us); + + auto res = llama_sampling_sample_impl(*smpl, candidates); + + smpl->n_sample++; + + return res; } -llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - return llama_sample_token_with_rng_impl(&ctx->sampling, candidates, ctx->sampling.rng); +void llama_sampling_accept( + struct llama_sampling * smpl, + llama_token token) { + time_meas tm(smpl->t_total_us); // TODO: measure grammar time separately from sampling + + llama_sampling_accept_impl(*smpl, token); } +// +// model split +// + int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { @@ -19033,30 +19087,29 @@ int llama_split_prefix(char * dest, size_t maxlen, const char * split_path, int return 0; } -struct llama_timings llama_get_timings(struct llama_context * ctx) { - struct llama_timings result = { - /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, - /*.t_end_ms =*/ 1.00 * ggml_time_ms(), - /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, - /*.t_sample_ms =*/ 1e-3 * ctx->sampling.t_sample_us, - /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, - /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, - - /*.n_sample =*/ std::max(1, ctx->sampling.n_sample), - /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), - /*.n_eval =*/ std::max(1, ctx->n_eval), +void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smpl) { + const llama_timings timings = { + /*.t_start_ms =*/ 1e-3 * ctx->t_start_us, + /*.t_end_ms =*/ 1.00 * ggml_time_ms(), + /*.t_load_ms =*/ 1e-3 * ctx->t_load_us, + /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_total_us : ctx->sampling.t_total_us), + /*.t_grammar_ms =*/ 1e-3 * (smpl && smpl->grammar ? smpl->grammar->t_total_us : 0.0), + /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us, + /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us, + + /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : ctx->sampling.n_sample), + /*.n_grammar_sample =*/ std::max(0, smpl && smpl->grammar ? smpl->grammar->n_sample : 0), + /*.n_grammar_accept =*/ std::max(0, smpl && smpl->grammar ? 
smpl->grammar->n_accept : 0), + /*.n_p_eval =*/ std::max(0, ctx->n_p_eval), + /*.n_eval =*/ std::max(1, ctx->n_eval), }; - return result; -} - -void llama_print_timings(struct llama_context * ctx) { - const llama_timings timings = llama_get_timings(ctx); - LLAMA_LOG_INFO("\n"); LLAMA_LOG_INFO("%s: load time = %10.2f ms\n", __func__, timings.t_load_ms); - LLAMA_LOG_INFO("%s: sample time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", - __func__, timings.t_sample_ms, timings.n_sample, timings.t_sample_ms / timings.n_sample, 1e3 / timings.t_sample_ms * timings.n_sample); + LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling); + LLAMA_LOG_INFO("%s: grammar time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", + __func__, timings.t_grammar_ms, timings.n_grammar_sample, timings.t_grammar_ms / timings.n_grammar_sample, 1e3 / timings.t_grammar_ms * timings.n_grammar_sample); LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n", __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval); LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", @@ -19064,12 +19117,18 @@ void llama_print_timings(struct llama_context * ctx) { LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (timings.t_end_ms - timings.t_start_ms), (timings.n_p_eval + timings.n_eval)); } -void llama_reset_timings(struct llama_context * ctx) { +void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smpl) { ctx->t_start_us = ggml_time_us(); ctx->t_eval_us = ctx->n_eval = 0; ctx->t_p_eval_us = ctx->n_p_eval = 0; - ctx->sampling.reset_timings(); + if (smpl) { + smpl->t_total_us = smpl->n_sample = 0; + + if (smpl->grammar) { + smpl->grammar->t_total_us = smpl->grammar->n_sample = smpl->grammar->n_accept = 0; + } + } } const char * llama_print_system_info(void) { @@ -19111,21 +19170,15 @@ void llama_dump_timing_info_yaml(FILE * stream, const llama_context * ctx) { 1.0e-3 * ctx->t_eval_us / ctx->n_eval); fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); - fprintf(stream, "mst_sample: %.2f # ms / token during sampling\n", - 1.0e-3 * ctx->sampling.t_sample_us / ctx->sampling.n_sample); fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "n_sample: %d # number of sampled tokens\n", ctx->sampling.n_sample); fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "t_sample_us: %" PRId64 " # total microseconds spent sampling\n", ctx->sampling.t_sample_us); fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", 1.0e6 * ctx->n_eval / ctx->t_eval_us); fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", 1.0e6 * ctx->n_p_eval 
/ ctx->t_p_eval_us); - fprintf(stream, "ts_sample: %.2f # tokens / second during sampling\n", - 1.0e6 * ctx->sampling.n_sample / ctx->sampling.t_sample_us); } // For internal test use diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 68f971bfe2cf3d..c5a4bc4954f74a 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -2,33 +2,23 @@ #undef NDEBUG #endif -#define LLAMA_API_INTERNAL - #include "ggml.h" #include "llama.h" -#include "grammar-parser.h" -#include "json-schema-to-grammar.h" +#include "llama-vocab.h" // TMP +#include "llama-grammar.h" #include "unicode.h" +#include "json-schema-to-grammar.h" + #include #include #include using json = nlohmann::ordered_json; -static llama_grammar* build_grammar(const std::string & grammar_str) { - auto parsed_grammar = grammar_parser::parse(grammar_str.c_str()); - - // Ensure we parsed correctly - assert(!parsed_grammar.rules.empty()); - - // Ensure we have a root node - assert(!(parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end())); - - std::vector grammar_rules(parsed_grammar.c_rules()); - llama_grammar* grammar = llama_grammar_init( - grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); +llama_vocab vocab; // TMP - return grammar; +static llama_grammar * build_grammar(const std::string & grammar_str) { + return llama_grammar_init_impl(vocab, grammar_str.c_str(), "root"); } static bool test_build_grammar_fails(const std::string & grammar_str) { @@ -143,7 +133,7 @@ static void test(const std::string & test_desc, const std::string & grammar_str, } // Clean up allocated memory - llama_grammar_free(grammar); + llama_grammar_free_impl(grammar); } static void test_grammar(const std::string & test_desc, const std::string & grammar_str, const std::vector & passing_strings, const std::vector & failing_strings) { test(test_desc + ". 
Grammar: " + grammar_str, grammar_str, passing_strings, failing_strings); @@ -683,7 +673,8 @@ static void test_failure_missing_root() { term ::= number number ::= [0-9]+)"""; - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + llama_grammar_parser parsed_grammar; + parsed_grammar.parse(grammar_str.c_str()); // Ensure we parsed correctly assert(!parsed_grammar.rules.empty()); @@ -705,7 +696,8 @@ static void test_failure_missing_reference() { fprintf(stderr, " Expected error: "); - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_str.c_str()); + llama_grammar_parser parsed_grammar; + parsed_grammar.parse(grammar_str.c_str()); // Ensure we did NOT parsed correctly assert(parsed_grammar.rules.empty()); diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 5df5abb25394c8..259172d999c789 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -3,7 +3,7 @@ #endif #include "llama.h" -#include "grammar-parser.h" +#include "llama-grammar.h" #include @@ -22,7 +22,8 @@ static const char * type_str(llama_gretype type) { static void verify_parsing(const char *grammar_bytes, const std::vector> expected, const std::vector &expected_rules) { uint32_t index = 0; - grammar_parser::parse_state parsed_grammar = grammar_parser::parse(grammar_bytes); + llama_grammar_parser parsed_grammar; + parsed_grammar.parse(grammar_bytes); std::map symbol_names; for (auto it = parsed_grammar.symbol_ids.begin(); it != parsed_grammar.symbol_ids.end(); ++it) { @@ -129,9 +130,10 @@ static void verify_parsing(const char *grammar_bytes, const std::vector #include #include #include -#include "json-schema-to-grammar.h" -#include "grammar-parser.h" - static std::string trim(const std::string & source) { std::string s(source); s.erase(0,s.find_first_not_of(" \n\r\t")); @@ -40,7 +41,8 @@ struct TestCase { } void verify_expectation_parseable() const { try { - auto state = grammar_parser::parse(expected_grammar.c_str()); + llama_grammar_parser state; + state.parse(expected_grammar.c_str()); if (state.symbol_ids.find("root") == state.symbol_ids.end()) { throw std::runtime_error("Grammar failed to parse:\n" + expected_grammar); } diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 1f3a267b39f9ba..1f4a1f1bc34325 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -2,16 +2,16 @@ #undef NDEBUG #endif -#define LLAMA_API_INTERNAL #include "llama.h" -#include "grammar-parser.h" +#include "llama-vocab.h" // TMP +#include "llama-grammar.h" #include #include int main() { - grammar_parser::parse_state parsed_grammar; + llama_grammar_parser parsed_grammar; std::vector> expected = { {"expr", 2}, @@ -117,7 +117,8 @@ int main() llama_grammar * grammar = NULL; std::vector grammar_rules(parsed_grammar.c_rules()); - grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); + llama_vocab vocab; // TMP + grammar = llama_grammar_init_impl(vocab, grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); if (grammar == nullptr) { throw std::runtime_error("Failed to initialize llama_grammar"); @@ -403,6 +404,8 @@ int main() delete[] candidate.code_points; candidate.code_points = nullptr; } - llama_grammar_free(grammar); + + llama_grammar_free_impl(grammar); + return 0; } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index de858bd3b87e3d..45dcb21ad0aab7 100644 --- a/tests/test-sampling.cpp +++ 
diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp
index de858bd3b87e3d..45dcb21ad0aab7 100644
--- a/tests/test-sampling.cpp
+++ b/tests/test-sampling.cpp
@@ -1,5 +1,6 @@
 #include "ggml.h"
 #include "llama.h"
+#include "llama-sampling.h"
 
 #ifdef NDEBUG
 #undef NDEBUG
@@ -20,6 +21,8 @@ static void dump(const llama_token_data_array * candidates) {
 static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
     const size_t n_vocab = probs.size();
 
+    llama_sampling smpl(n_vocab);
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@@ -28,9 +31,9 @@ static void test_top_k(const std::vector<float> & probs, const std::vector
 
 static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
 
+    llama_sampling smpl(n_vocab);
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@@ -49,9 +54,9 @@ static void test_top_p(const std::vector<float> & probs, const std::vector
 
 static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
     const size_t n_vocab = probs.size();
 
+    llama_sampling smpl(n_vocab);
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@@ -71,7 +78,7 @@ static void test_tfs(const std::vector<float> & probs, const std::vector
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
     DUMP(&candidates_p);
-    llama_sample_tail_free(nullptr, &candidates_p, z, 1);
+    llama_sampling_tail_free_impl(smpl, &candidates_p, z, 1);
     DUMP(&candidates_p);
 
     GGML_ASSERT(candidates_p.size == expected_probs.size());
@@ -82,6 +89,8 @@ static void test_tfs(const std::vector<float> & probs, const std::vector
 
 static void test_min_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
 
+    llama_sampling smpl(n_vocab);
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@@ -91,9 +100,9 @@ static void test_min_p(const std::vector<float> & probs, const std::vector
 
 static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
     const size_t n_vocab = probs.size();
 
+    llama_sampling smpl(n_vocab);
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@@ -112,7 +123,7 @@ static void test_typical(const std::vector<float> & probs, const std::vector
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@@ -136,10 +149,10 @@
     }
 
     llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-    llama_sample_softmax(nullptr, &candidates_p);
+    llama_sampling_softmax_impl(smpl, &candidates_p);
     DUMP(&candidates_p);
-    llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
-    llama_sample_softmax(nullptr, &candidates_p);
+    llama_sampling_repetition_penalties_impl(smpl, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
+    llama_sampling_softmax_impl(smpl, &candidates_p);
     DUMP(&candidates_p);
 
     GGML_ASSERT(candidates_p.size == expected_probs.size());
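Every sampler test above follows the same substitution: the llama_sample_* free functions, which took a llama_context * (always nullptr in these tests), become llama_sampling_*_impl functions operating on a llama_sampling state sized to the vocabulary. A minimal sketch of the new calling convention (not part of the patch; the helper name and input probabilities are illustrative):

// Sketch: constructor and *_impl signatures are taken from the test hunks above.
#include "llama.h"
#include "llama-sampling.h"

#include <cmath>
#include <vector>

static void top_k_then_softmax(const std::vector<float> & probs, int k) {
    const size_t n_vocab = probs.size();

    llama_sampling smpl(n_vocab); // replaces the old nullptr llama_context * argument

    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token token_id = 0; token_id < (llama_token) n_vocab; token_id++) {
        const float logit = logf(probs[token_id]);
        candidates.push_back(llama_token_data{token_id, logit, 0.0f});
    }

    llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

    llama_sampling_top_k_impl  (smpl, &candidates_p, k, 1); // was llama_sample_top_k  (nullptr, ...)
    llama_sampling_softmax_impl(smpl, &candidates_p);       // was llama_sample_softmax(nullptr, ...)
}

The trailing 1 is the min_keep argument, carried over unchanged from the old API.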
@@ -148,9 +161,10 @@
     }
 }
 
-static void test_sampler_queue(
-    const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
+static void test_sampler_queue(const size_t n_vocab, const std::string & samplers_sequence, const int top_k, const float top_p, const float min_p
 ) {
+    llama_sampling smpl(n_vocab);
+
     std::vector<llama_token_data> candidates;
     candidates.reserve(n_vocab);
     for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
@@ -165,16 +179,16 @@ static void test_sampler_queue(
 
     for (auto s : samplers_sequence) {
         switch (s){
-            case 'k': llama_sample_top_k    (nullptr, &candidates_p, top_k, 1); break;
-            case 'f': GGML_ABORT("tail_free test not implemented"); break;
-            case 'y': GGML_ABORT("typical test not implemented"); break;
-            case 'p': llama_sample_top_p    (nullptr, &candidates_p, top_p, 1); break;
-            case 'm': llama_sample_min_p    (nullptr, &candidates_p, min_p, 1); break;
-            case 't': GGML_ABORT("temperature test not implemented"); break;
-            default : GGML_ABORT("Unknown sampler"); break;
+            case 'k': llama_sampling_top_k_impl(smpl, &candidates_p, top_k, 1); break;
+            case 'f': GGML_ABORT("tail_free test not implemented");
+            case 'y': GGML_ABORT("typical test not implemented");
+            case 'p': llama_sampling_top_p_impl(smpl, &candidates_p, top_p, 1); break;
+            case 'm': llama_sampling_min_p_impl(smpl, &candidates_p, min_p, 1); break;
+            case 't': GGML_ABORT("temperature test not implemented");
+            default : GGML_ABORT("Unknown sampler");
         }
 
-        llama_sample_softmax(nullptr, &candidates_p); // make sure tokens are sorted for tests
+        llama_sampling_softmax_impl(smpl, &candidates_p); // make sure tokens are sorted for tests
 
         const int size = candidates_p.size;