Skip to content

Commit

Permalink
llama : add infill sampler (#9896)
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov authored Oct 15, 2024
1 parent 223c25a commit 755a9b2
Show file tree
Hide file tree
Showing 9 changed files with 294 additions and 23 deletions.
4 changes: 2 additions & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ enum common_sampler_type {
COMMON_SAMPLER_TYPE_TYPICAL_P = 5,
COMMON_SAMPLER_TYPE_TEMPERATURE = 6,
COMMON_SAMPLER_TYPE_XTC = 7,

COMMON_SAMPLER_TYPE_INFILL = 8,
};

// dimensionality reduction methods, used by cvector-generator
Expand Down Expand Up @@ -136,7 +136,7 @@ struct common_sampler_params {
COMMON_SAMPLER_TYPE_TOP_P,
COMMON_SAMPLER_TYPE_MIN_P,
COMMON_SAMPLER_TYPE_XTC,
COMMON_SAMPLER_TYPE_TEMPERATURE
COMMON_SAMPLER_TYPE_TEMPERATURE,
};

std::string grammar; // optional BNF-like grammar to constrain sampling
Expand Down
9 changes: 8 additions & 1 deletion common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
case COMMON_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
Expand Down Expand Up @@ -376,6 +379,7 @@ char common_sampler_type_to_chr(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_MIN_P: return 'm';
case COMMON_SAMPLER_TYPE_TEMPERATURE: return 't';
case COMMON_SAMPLER_TYPE_XTC: return 'x';
case COMMON_SAMPLER_TYPE_INFILL: return 'i';
default : return '?';
}
}
Expand All @@ -389,6 +393,7 @@ std::string common_sampler_type_to_str(enum common_sampler_type cnstr) {
case COMMON_SAMPLER_TYPE_MIN_P: return "min_p";
case COMMON_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case COMMON_SAMPLER_TYPE_XTC: return "xtc";
case COMMON_SAMPLER_TYPE_INFILL: return "infill";
default : return "";
}
}
Expand All @@ -402,6 +407,7 @@ std::vector<common_sampler_type> common_sampler_types_from_names(const std::vect
{ "tfs_z", COMMON_SAMPLER_TYPE_TFS_Z },
{ "temperature", COMMON_SAMPLER_TYPE_TEMPERATURE },
{ "xtc", COMMON_SAMPLER_TYPE_XTC },
{ "infill", COMMON_SAMPLER_TYPE_INFILL },
};

// since samplers names are written multiple ways
Expand Down Expand Up @@ -448,7 +454,8 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TOP_P), COMMON_SAMPLER_TYPE_TOP_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_MIN_P), COMMON_SAMPLER_TYPE_MIN_P },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_TEMPERATURE), COMMON_SAMPLER_TYPE_TEMPERATURE },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC }
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_XTC), COMMON_SAMPLER_TYPE_XTC },
{ common_sampler_type_to_chr(COMMON_SAMPLER_TYPE_INFILL), COMMON_SAMPLER_TYPE_INFILL },
};

std::vector<common_sampler_type> samplers;
Expand Down
34 changes: 17 additions & 17 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,30 +569,30 @@ int main(int argc, char ** argv) {
if (!params.ctx_shift){
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
break;
} else {
if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
}
}

if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
}

const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;
const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;

LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);

llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);

n_past -= n_discard;
n_past -= n_discard;

LOG_DBG("after swap: n_past = %d\n", n_past);
LOG_DBG("after swap: n_past = %d\n", n_past);

LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());

LOG_DBG("clear session path\n");
path_session.clear();
}
LOG_DBG("clear session path\n");
path_session.clear();
}
} else {
// context extension via Self-Extend
Expand Down
28 changes: 28 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,12 @@ extern "C" {
int32_t lstrip,
bool special);

// check if token0 is contained as a prefix in token1
LLAMA_API bool llama_token_is_prefix(
const struct llama_model * model,
llama_token token0,
llama_token token1);

/// @details Convert the provided tokens into text (inverse of llama_tokenize()).
/// @param text The char pointer must be large enough to hold the resulting text.
/// @return Returns the number of chars/bytes on success, no more than text_len_max.
Expand Down Expand Up @@ -1148,6 +1154,28 @@ extern "C" {
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);

// this sampler is meant to be used for fill-in-the-middle infilling
// it's supposed to be used after top_k + top_p sampling
//
// 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
// 2. combine probs of tokens that have the same prefix
//
// example:
//
// - before:
// "hel": 0.5
// "hell": 0.2
// "hello": 0.1
// "dummy": 0.1
//
// - after:
// "hel": 0.8
// "dummy": 0.1
//
// 3. discard non-EOG tokens with low prob
// 4. if no tokens are left -> pick EOT
//
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);

// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
Expand Down
201 changes: 201 additions & 0 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1739,6 +1739,207 @@ struct llama_sampler * llama_sampler_init_logit_bias(
};
}

// infill

//#define GGML_DEBUG_SAMPLER_INFILL

struct llama_sampler_infill {
const struct llama_vocab * vocab;
};

static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
return "infill";
}

static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_infill *) smpl->ctx;

llama_sampler_softmax_impl(cur_p);

#if defined(GGML_DEBUG_SAMPLER_INFILL)
#define LOG_DBG_CUR LLAMA_LOG_DEBUG
#else
#define LOG_DBG_CUR(...)
#endif

for (size_t i = 0; i < cur_p->size; ++i) {
LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}

float p_txt_sum = 0.0f;
float p_eog_sum = 0.0f;

for (size_t i = 0; i < cur_p->size; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_eog_sum += cur_p->data[i].p;
} else {
p_txt_sum += cur_p->data[i].p;
}
}

const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);

LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);

if (3*p_eog_sum*cur_p->size > p_txt_sum) {
LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);

// keep just the EOG tokens
const auto size_org = cur_p->size;

cur_p->size = 0;

float p_sum = 0.0f;

for (size_t i = 0; i < size_org; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_sum += cur_p->data[i].p;

cur_p->data[cur_p->size++] = cur_p->data[i];
}
}

// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;
}

return;
}

size_t n_combined = 0; GGML_UNUSED(n_combined);

// combine tokens with common prefix
for (size_t i = 0; i < cur_p->size; ++i) {
for (size_t j = 0; j < cur_p->size; ++j) {
if (cur_p->data[i].logit == -INFINITY) {
break;
}

if (i == j || cur_p->data[j].logit == -INFINITY) {
continue;
}

if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
if (cur_p->data[i].p > cur_p->data[j].p) {
cur_p->data[i].p += cur_p->data[j].p;
cur_p->data[j].logit = -INFINITY;
cur_p->data[j].p = 0.0f;
} else {
cur_p->data[j].p += cur_p->data[i].p;
cur_p->data[i].logit = -INFINITY;
cur_p->data[i].p = 0.0f;
}

n_combined++;
}
}
}

size_t n_non_eog = 0;

size_t size_org = cur_p->size;

float p_sum = 0.0f;
float thold = 0.2f;

cur_p->size = 0;

LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);

for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);

if (cur_p->data[i].p < thold && !is_eog) {
continue;
}

if (!is_eog) {
++n_non_eog;
}

p_sum += cur_p->data[i].p;

// keep this token
cur_p->data[cur_p->size++] = cur_p->data[i];
}

LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);

// if no non-EOG tokens are left -> reduce cur_p to single EOT token
if (n_non_eog == 0) {
cur_p->size = 1;
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
cur_p->data[0].logit = 1.0f;

return;
}

// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;

LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}

size_org = cur_p->size;
p_sum = 0.0f;
thold = 1.0/(n_non_eog + 1);

cur_p->size = 0;

LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);

for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);

if (cur_p->data[i].p < thold && !is_eog) {
continue;
}

p_sum += cur_p->data[i].p;

cur_p->data[cur_p->size++] = cur_p->data[i];
}

// normalize probs
for (size_t i = 0; i < cur_p->size; ++i) {
cur_p->data[i].p /= p_sum;

LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}

#undef LOG_DBG_CUR
}

static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
return llama_sampler_init_infill_impl(*ctx->vocab);
}

static void llama_sampler_infill_free(struct llama_sampler * smpl) {
delete (llama_sampler_infill *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_infill_i = {
/* .name = */ llama_sampler_infill_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_infill_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_infill_clone,
/* .free = */ llama_sampler_infill_free,
};

struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab) {
return new llama_sampler {
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ &vocab,
},
};
}

// utils

uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
Expand Down
5 changes: 3 additions & 2 deletions src/llama-sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

#include "llama-grammar.h"

#include <unordered_map>

struct llama_vocab;
struct llama_grammar;

Expand All @@ -27,3 +25,6 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);

struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab);
17 changes: 17 additions & 0 deletions src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
return 0;
}

bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1) {
char text_buf_0[128];
char text_buf_1[128];

const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, sizeof(text_buf_0) - 1, 0, false);
const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, sizeof(text_buf_1) - 1, 0, false);

if (len0 <= 0 || len1 <= 0) {
return false;
}

return len0 <= len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
}

int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
Expand Down
Loading

0 comments on commit 755a9b2

Please sign in to comment.