Skip to content

Commit

Permalink
style : rearrange code + add comments and TODOs
Browse files Browse the repository at this point in the history
ggml-ci
ggerganov committed Sep 7, 2024
1 parent 4a4530b commit 4b27235
Showing 4 changed files with 115 additions and 48 deletions.
71 changes: 36 additions & 35 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
@@ -136,17 +136,6 @@ std::string gpt_sampler_params::print() const {
return std::string(result);
}

std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
std::string result = "\tlogits ";

for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
}

return result;
}

struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params) {
llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

@@ -232,17 +221,6 @@ void gpt_sampler_free(struct gpt_sampler * gsmpl) {
}
}

struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
return new gpt_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
};
}

void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar) {
if (accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
@@ -259,12 +237,15 @@ void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
llama_sampler_reset(gsmpl->chain);
}

llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
return &gsmpl->cur_p;
}

llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
return gsmpl->prev.rat(0);
struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl) {
return new gpt_sampler {
/* .params = */ gsmpl->params,
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
/* .chain = */ llama_sampler_clone(gsmpl->chain),
/* .prev = */ gsmpl->prev,
/* .cur = */ gsmpl->cur,
/* .cur_p = */ gsmpl->cur_p,
};
}

void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl) {
@@ -279,12 +260,11 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
}

llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;

gsmpl->set_logits(ctx, idx);

auto & cur_p = gsmpl->cur_p;
auto & grmr = gsmpl->grmr;
auto & chain = gsmpl->chain;
auto & cur_p = gsmpl->cur_p; // initialized by set_logits

if (grammar_first) {
llama_sampler_apply(grmr, &cur_p);
@@ -307,24 +287,45 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context

llama_sampler_apply(grmr, &single_token_data_array);

// check if the token is valid according to the grammar by seeing if its logit has been set to -INFINITY
const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
if (is_valid) {
return id;
}
}

// if the token is not valid, sample again, first apply the grammar samplers and then sample
// resampling:
// if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
gsmpl->set_logits(ctx, idx);

llama_sampler_apply(grmr, &cur_p);
llama_sampler_apply(chain, &cur_p);

GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");
GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

return cur_p.data[cur_p.selected].id;
}

// helpers

llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
return &gsmpl->cur_p;
}

llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl) {
return gsmpl->prev.rat(0);
}

std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
std::string result = "\tlogits ";

for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
}

return result;
}

std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx_main, int n) {
n = std::min(n, (int) gsmpl->prev.size());

41 changes: 32 additions & 9 deletions common/sampling.h
Original file line number Diff line number Diff line change
@@ -61,24 +61,41 @@ struct gpt_sampler_params {
//
// - grammar support
// - custom sampler logic based on the parameters
// - history of the last accepted tokens
// - performance metrics
//
// This goal is to have a common implementation of the sampling logic shared across the examples.
// For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
// complex (top-k, top-p, etc).
//
// Another example is related to the grammar. In general, the grammar constraints applied on the full
// vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
// token in order to verify if it fits the grammar. And only if the token doesn't fit the grammar, the
// grammar constraints are applied to the full vocabulary and the token is resampled.
//
// The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can
// be moved into the core llama library.
//
// For convenience, the gpt_sampler also maintains a container with the current candidate tokens.
// This can be used to access the probabilities of the rest of the non-sampled tokens.
//
// TODO: measure grammar performance
//

struct gpt_sampler;

// llama_sampler API overloads

struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);

void gpt_sampler_free(struct gpt_sampler * gsmpl);

struct gpt_sampler * gpt_sampler_clone(gpt_sampler * gsmpl);

void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
void gpt_sampler_reset (struct gpt_sampler * gsmpl);

llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);

llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
// if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
void gpt_sampler_reset (struct gpt_sampler * gsmpl);
struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);

// arguments can be nullptr to skip printing
void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler * gsmpl);

// extended sampling implementation:
@@ -89,12 +106,18 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
// if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates must fit the grammar
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

// helpers

// access the internal list of current candidate tokens
llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);

// get the last accepted token
llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);

// print the sampler chain into a string
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);

49 changes: 45 additions & 4 deletions include/llama.h
Original file line number Diff line number Diff line change
@@ -206,6 +206,7 @@ extern "C" {
LLAMA_SPLIT_MODE_ROW = 2, // split rows across GPUs
};

// TODO: simplify (https://github.com/ggerganov/llama.cpp/pull/9294#pullrequestreview-2286561979)
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
@@ -216,7 +217,7 @@ extern "C" {
// TODO: consider SoA
llama_token_data * data;
size_t size;
int64_t selected;
int64_t selected; // this is the index in the data array (i.e. not the token id)
bool sorted;
} llama_token_data_array;

@@ -979,9 +980,38 @@ extern "C" {
//
// Sampling API
//
// In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
// Sample usage:
//
// // prepare the sampling chain at the start
// auto sparams = llama_sampler_chain_default_params();
//
// llama_sampler * smpl = llama_sampler_chain_init(sparams);
//
// llama_sampler_chain_add(smpl, llama_sampler_init_top_k(50));
// llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
// llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.8));
// llama_sampler_chain_add(smpl, llama_sampler_init_dist (seed));
//
// ...
//
// // decoding loop:
// while (...) {
// ...
//
// llama_decode(ctx, batch);
//
// // sample from the logits of the last token in the batch
// const llama_token id = llama_sampler_sample(smpl, ctx, -1);
//
// ...
// }
//
// llama_sampler_free(smpl);
//
//
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
// TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
//
// TODO: in the future, the entire API that uses llama_model should start using llama_vocab

typedef void * llama_sampler_context_t;

@@ -1003,6 +1033,7 @@ extern "C" {
llama_sampler_context_t ctx;
};

// mirror of llama_sampler_i:
LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p);
@@ -1011,7 +1042,8 @@ extern "C" {
// important: do not free if the sampler has been added to a llama_sampler_chain (via llama_sampler_chain_add)
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);

// llama_sampler_chain is a type of llama_sampler that can contain multiple llama_samplers
// llama_sampler_chain
// a type of llama_sampler that can chain multiple samplers one after another

LLAMA_API struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params);

@@ -1089,6 +1121,15 @@ extern "C" {
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);

// Shorthand for:
//
// const auto * logits = llama_get_logits_ith(ctx, idx);
// llama_token_data_array cur_p = { ... init from logits ... };
// llama_sampler_apply(smpl, &cur_p);
// return cur_p.data[cur_p.selected].id;
//
// At this point, this is mostly a convenience function.
//
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx);

// TODO: extend in the future
2 changes: 2 additions & 0 deletions src/llama-sampling.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?

#include "llama-grammar.h"

#include <unordered_map>

0 comments on commit 4b27235

Please sign in to comment.