Skip to content

Commit

Permalink
llama : move vocab, grammar and sampling into separate files (#8508)
Browse files Browse the repository at this point in the history
* llama : move sampling code into llama-sampling

ggml-ci

* llama : move grammar code into llama-grammar

ggml-ci

* cont

ggml-ci

* cont : pre-fetch rules

* cont

ggml-ci

* llama : deprecate llama_sample_grammar

* llama : move tokenizers into llama-vocab

ggml-ci

* make : update llama.cpp deps [no ci]

* llama : redirect external API to internal APIs

ggml-ci

* llama : suffix the internal APIs with "_impl"

ggml-ci

* llama : clean-up
  • Loading branch information
ggerganov authored Jul 23, 2024
1 parent 751fcfc commit 938943c
Show file tree
Hide file tree
Showing 18 changed files with 3,656 additions and 3,103 deletions.
32 changes: 31 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -876,6 +876,9 @@ OBJ_GGML += \

OBJ_LLAMA = \
src/llama.o \
src/llama-vocab.o \
src/llama-grammar.o \
src/llama-sampling.o \
src/unicode.o \
src/unicode-data.o

Expand Down Expand Up @@ -1055,6 +1058,10 @@ src/unicode-data.o: \

src/llama.o: \
src/llama.cpp \
src/llama-impl.h \
src/llama-vocab.h \
src/llama-grammar.h \
src/llama-sampling.h \
src/unicode.h \
include/llama.h \
ggml/include/ggml-cuda.h \
Expand All @@ -1064,6 +1071,29 @@ src/llama.o: \
ggml/include/ggml-backend.h
$(CXX) $(CXXFLAGS) -c $< -o $@

src/llama-vocab.o: \
src/llama-vocab.cpp \
src/llama-vocab.h \
src/llama-impl.h \
include/llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@

src/llama-grammar.o: \
src/llama-grammar.cpp \
src/llama-grammar.h \
src/llama-impl.h \
src/llama-vocab.h \
src/llama-sampling.h \
include/llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@

src/llama-sampling.o: \
src/llama-sampling.cpp \
src/llama-sampling.h \
src/llama-impl.h \
include/llama.h
$(CXX) $(CXXFLAGS) -c $< -o $@

$(LIB_LLAMA): \
$(OBJ_LLAMA) \
$(LIB_GGML)
Expand Down Expand Up @@ -1439,7 +1469,7 @@ run-benchmark-matmult: llama-benchmark-matmult
.PHONY: run-benchmark-matmult swift

tests/test-llama-grammar: tests/test-llama-grammar.cpp \
$(OBJ_GGML) $(OBJ_COMMON) src/unicode.o src/unicode-data.o
$(OBJ_ALL)
$(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<)
$(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS)

Expand Down
3 changes: 3 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import PackageDescription

var sources = [
"src/llama.cpp",
"src/llama-vocab.cpp",
"src/llama-grammar.cpp",
"src/llama-sampling.cpp",
"src/unicode.cpp",
"src/unicode-data.cpp",
"ggml/src/ggml.c",
Expand Down
6 changes: 3 additions & 3 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ static llama_token llama_sampling_sample_impl(
llama_token_data_array single_token_data_array = { &single_token_data, 1, false };

// Apply grammar constraints to the single token
llama_sample_grammar(ctx_main, &single_token_data_array, ctx_sampling->grammar);
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &single_token_data_array);

// Check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
Expand Down Expand Up @@ -421,7 +421,7 @@ static llama_token_data_array llama_sampling_prepare_impl(

// apply grammar checks before sampling logic
if (apply_grammar && ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
llama_grammar_sample(ctx_sampling->grammar, ctx_main, &cur_p);
}

return cur_p;
Expand Down Expand Up @@ -455,6 +455,6 @@ void llama_sampling_accept(
ctx_sampling->prev.push_back(id);

if (ctx_sampling->grammar != NULL && apply_grammar) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
llama_grammar_accept_token(ctx_sampling->grammar, ctx_main, id);
}
}
15 changes: 10 additions & 5 deletions examples/gbnf-validator/gbnf-validator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,25 @@ static bool llama_sample_grammar_string(struct llama_grammar * grammar, const st
auto decoded = decode_utf8(input_str, {});
const auto & code_points = decoded.first;

const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);

size_t pos = 0;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
auto prev_stacks = grammar->stacks;
llama_grammar_accept(grammar->rules, prev_stacks, *it, grammar->stacks);
if (grammar->stacks.empty()) {
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy

llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);

if (cur_stacks.empty()) {
error_pos = pos;
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
grammar->stacks = prev_stacks;
cur_stacks = prev_stacks;
return false;
}
++pos;
}

for (const auto & stack : grammar->stacks) {
for (const auto & stack : cur_stacks) {
if (stack.empty()) {
return true;
}
Expand Down
76 changes: 46 additions & 30 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -906,10 +906,10 @@ extern "C" {
LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding

// Returns -1 if unknown, 1 for true or 0 for false.
LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);
LLAMA_API int32_t llama_add_bos_token(const struct llama_model * model);

// Returns -1 if unknown, 1 for true or 0 for false.
LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);
LLAMA_API int32_t llama_add_eos_token(const struct llama_model * model);

// Codellama infill tokens
LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
Expand Down Expand Up @@ -965,6 +965,10 @@ extern "C" {
bool remove_special,
bool unparse_special);

//
// Chat templates
//

/// Apply chat template. Inspired by hf apply_chat_template() on python.
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
Expand Down Expand Up @@ -1003,6 +1007,23 @@ extern "C" {

LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);

/// @details Apply constraints from grammar
LLAMA_API void llama_grammar_sample(
const struct llama_grammar * grammar,
const struct llama_context * ctx,
llama_token_data_array * candidates);
LLAMA_API DEPRECATED(void llama_sample_grammar(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar),
"use llama_grammar_sample instead");

/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(
struct llama_grammar * grammar,
struct llama_context * ctx,
llama_token token);

//
// Sampling functions
//
Expand Down Expand Up @@ -1084,12 +1105,6 @@ extern "C" {
llama_token_data_array * candidates,
float temp);

/// @details Apply constraints from grammar
LLAMA_API void llama_sample_grammar(
struct llama_context * ctx,
llama_token_data_array * candidates,
const struct llama_grammar * grammar);

/// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
/// @param tau The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
Expand Down Expand Up @@ -1127,12 +1142,6 @@ extern "C" {
struct llama_context * ctx,
llama_token_data_array * candidates);

/// @details Accepts the sampled token into the grammar
LLAMA_API void llama_grammar_accept_token(
struct llama_context * ctx,
struct llama_grammar * grammar,
llama_token token);

//
// Model split
//
Expand Down Expand Up @@ -1175,38 +1184,45 @@ extern "C" {

struct ggml_tensor;

const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
struct llama_context * ctx
);

struct llama_partial_utf8 {
uint32_t value; // bit value so far (unshifted)
int n_remain; // num bytes remaining; -1 indicates invalid sequence
};

struct llama_grammar {
const std::vector<std::vector<llama_grammar_element>> rules;
std::vector<std::vector<const llama_grammar_element *>> stacks;

// buffer for partially generated UTF-8 sequence from accepted tokens
llama_partial_utf8 partial_utf8;
};

struct llama_grammar_candidate {
size_t index;
const uint32_t * code_points;
llama_partial_utf8 partial_utf8;
};

const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal_get_tensor_map(
struct llama_context * ctx
);
using llama_grammar_rule = std::vector< llama_grammar_element>;
using llama_grammar_stack = std::vector<const llama_grammar_element *>;

using llama_grammar_rules = std::vector<llama_grammar_rule>;
using llama_grammar_stacks = std::vector<llama_grammar_stack>;
using llama_grammar_candidates = std::vector<llama_grammar_candidate>;

const llama_grammar_rules & llama_grammar_get_rules (const struct llama_grammar * grammar);
llama_grammar_stacks & llama_grammar_get_stacks( struct llama_grammar * grammar);

void llama_grammar_accept(
const std::vector<std::vector<llama_grammar_element>> & rules,
const std::vector<std::vector<const llama_grammar_element *>> & stacks,
const uint32_t chr,
std::vector<std::vector<const llama_grammar_element *>> & new_stacks);
const llama_grammar_rules & rules,
const llama_grammar_stacks & stacks,
const uint32_t chr,
llama_grammar_stacks & new_stacks);

std::vector<llama_grammar_candidate> llama_grammar_reject_candidates_for_stack(
const llama_grammar_rules & rules,
const llama_grammar_stack & stack,
const llama_grammar_candidates & candidates);

std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const std::string & src,
llama_partial_utf8 partial_start);
llama_partial_utf8 partial_start);

// Randomly selects a token from the candidates based on their probabilities using given std::mt19937.
// This is a temporary workaround in order to fix race conditions when sampling with multiple sequences.
Expand Down
3 changes: 3 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ endif()
add_library(llama
../include/llama.h
llama.cpp
llama-vocab.cpp
llama-grammar.cpp
llama-sampling.cpp
unicode.h
unicode.cpp
unicode-data.cpp
Expand Down
Loading

0 comments on commit 938943c

Please sign in to comment.