Skip to content

Commit

Permalink
cont : remove llama_sampling_context
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Aug 29, 2024
1 parent 97731bf commit 9dd2061
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 164 deletions.
48 changes: 18 additions & 30 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,8 @@ std::string gpt_sampling_params::print_samplers() const {

return result;
}
struct llama_sampling_context * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) {
struct llama_sampling_context * result = new llama_sampling_context();

result->params = params;
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params) {
struct llama_sampling * result = nullptr;

{
auto lparams = llama_sampling_default_params();
Expand Down Expand Up @@ -66,35 +64,25 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_model * m
lparams.samplers[i] = params.samplers[i];
}

result->smpl = llama_sampling_init(model, lparams);
result = llama_sampling_init(model, lparams);

llama_sampling_set_grammar (result->smpl, params.grammar.c_str(), "root");
llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
llama_sampling_set_grammar (result, params.grammar.c_str(), "root");
llama_sampling_set_logit_bias(result, params.logit_bias.size(), params.logit_bias.data());
}

return result;
}

void llama_sampling_free(struct llama_sampling_context * ctx) {
llama_sampling_free(ctx->smpl);

delete ctx;
}

void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
if (dst->smpl) {
llama_sampling_free(dst->smpl);
void llama_sampling_cp(llama_sampling * src, llama_sampling * dst) {
if (dst) {
llama_sampling_free(dst);
}

dst->smpl = llama_sampling_cp(src->smpl);
}

llama_token llama_sampling_last(llama_sampling_context * ctx) {
return llama_sampling_prev(ctx->smpl, 0);
dst = llama_sampling_cp(src);
}

std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n) {
n = std::min(n, llama_sampling_n_prev(ctx_sampling->smpl));
std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx_main, int n) {
n = std::min(n, llama_sampling_n_prev(smpl));

if (n <= 0) {
return "";
Expand All @@ -104,7 +92,7 @@ std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama
result.reserve(8*n); // 8 is the average length of a token [citation needed], TODO: compute this from the vocab

for (int i = n - 1; i >= 0; i--) {
const llama_token id = llama_sampling_prev(ctx_sampling->smpl, i);
const llama_token id = llama_sampling_prev(smpl, i);

GGML_ASSERT(id != LLAMA_TOKEN_NULL && "null token in the sampling history - should not happen");

Expand Down Expand Up @@ -206,14 +194,14 @@ std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::strin
}

llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_sampling * smpl,
struct llama_context * ctx,
int idx) {
llama_sampling_set_logits(ctx_sampling->smpl, llama_get_logits_ith(ctx_main, idx));
llama_sampling_set_logits(smpl, llama_get_logits_ith(ctx, idx));

auto * cur_p = llama_sampling_get_candidates(ctx_sampling->smpl);
auto * cur_p = llama_sampling_get_candidates(smpl);

llama_sampling_grammar(ctx_sampling->smpl, cur_p);
llama_sampling_grammar(smpl, cur_p);

return llama_sampling_sample(ctx_sampling->smpl, cur_p);
return llama_sampling_sample(smpl, cur_p);
}
41 changes: 13 additions & 28 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

// sampling parameters
typedef struct gpt_sampling_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling

int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
Expand All @@ -30,7 +30,7 @@ typedef struct gpt_sampling_params {
bool penalize_nl = false; // consider newlines as a repeatable token
bool ignore_eos = false;

std::vector<llama_sampler_type> samplers = {
std::vector<enum llama_sampler_type> samplers = {
LLAMA_SAMPLER_TYPE_TOP_K,
LLAMA_SAMPLER_TYPE_TFS_Z,
LLAMA_SAMPLER_TYPE_TYPICAL_P,
Expand All @@ -50,36 +50,21 @@ typedef struct gpt_sampling_params {
std::string print_samplers() const;
} gpt_sampling_params;

// general sampler context
// TODO: move to llama.h
struct llama_sampling_context {
// parameters that will be used for sampling
gpt_sampling_params params;
// overload of llama_sampling_init using gpt_sampling_params
struct llama_sampling * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params);

llama_sampling * smpl;
};
void llama_sampling_cp(llama_sampling * src, llama_sampling * dst);

// Create a new sampling context instance.
struct llama_sampling_context * llama_sampling_init(const struct llama_model * model, const struct gpt_sampling_params & params);
// get a string representation of the last accepted tokens
std::string llama_sampling_prev_str(llama_sampling * smpl, llama_context * ctx, int n);

void llama_sampling_free(struct llama_sampling_context * ctx);
char llama_sampling_type_to_chr(enum llama_sampler_type sampler_type);
std::string llama_sampling_type_to_str(enum llama_sampler_type sampler_type);

// Copy the sampler context
void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst);

// Get the last accepted token
llama_token llama_sampling_last(llama_sampling_context * ctx);

// Get a string representation of the last accepted tokens
std::string llama_sampling_prev_str(llama_sampling_context * ctx_sampling, llama_context * ctx_main, int n);

char llama_sampling_type_to_chr(llama_sampler_type sampler_type);
std::string llama_sampling_type_to_str(llama_sampler_type sampler_type);

std::vector<llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);
std::vector<enum llama_sampler_type> llama_sampling_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
std::vector<enum llama_sampler_type> llama_sampling_types_from_chars(const std::string & names_string);

llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_sampling * smpl,
struct llama_context * ctx,
int idx = -1);
28 changes: 14 additions & 14 deletions examples/infill/infill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

static llama_context ** g_ctx;
static llama_model ** g_model;
static llama_sampling_context ** g_ctx_sampling;
static llama_sampling ** g_smpl;
static gpt_params * g_params;
static std::vector<llama_token> * g_input_tokens;
static std::ostringstream * g_output_ss;
Expand Down Expand Up @@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
} else {
console::cleanup();
printf("\n");
llama_print_timings(*g_ctx, (*g_ctx_sampling)->smpl);
llama_print_timings(*g_ctx, *g_smpl);
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
_exit(130);
}
Expand Down Expand Up @@ -167,11 +167,11 @@ int main(int argc, char ** argv) {

llama_model * model = nullptr;
llama_context * ctx = nullptr;
llama_sampling_context * ctx_sampling = nullptr;
llama_sampling * smpl = nullptr;

g_model = &model;
g_ctx = &ctx;
g_ctx_sampling = &ctx_sampling;
g_smpl = &smpl;

// load the model and apply lora adapter, if any
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
Expand Down Expand Up @@ -345,7 +345,7 @@ int main(int argc, char ** argv) {

std::vector<llama_token> embd;

ctx_sampling = llama_sampling_init(model, sparams);
smpl = llama_sampling_init(model, sparams);

while (n_remain != 0 || params.interactive) {
// predict
Expand Down Expand Up @@ -417,11 +417,11 @@ int main(int argc, char ** argv) {
embd.clear();

if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx);
const llama_token id = llama_sampling_sample(smpl, ctx);

llama_sampling_accept(ctx_sampling->smpl, id, true);
llama_sampling_accept(smpl, id, true);

// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev.to_vector()).c_str());
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());

embd.push_back(id);

Expand All @@ -440,7 +440,7 @@ int main(int argc, char ** argv) {

// push the prompt in the sampling context in order to apply repetition penalties later
// for the prompt, we don't apply grammar rules
llama_sampling_accept(ctx_sampling->smpl, embd_inp[n_consumed], false);
llama_sampling_accept(smpl, embd_inp[n_consumed], false);

++n_consumed;
if ((int) embd.size() >= params.n_batch) {
Expand Down Expand Up @@ -472,7 +472,7 @@ int main(int argc, char ** argv) {
// if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
if ((llama_sampling_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
if (is_interacting && !params.interactive_first) {
// print an eot token
printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
Expand Down Expand Up @@ -538,7 +538,7 @@ int main(int argc, char ** argv) {
is_interacting = false;
}
// deal with end of generation tokens in interactive mode
else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
else if (llama_token_is_eog(model, llama_sampling_last(smpl))) {
LOG("found EOS token\n");

if (params.interactive) {
Expand Down Expand Up @@ -611,7 +611,7 @@ int main(int argc, char ** argv) {

if (n_past > 0) {
if (is_interacting) {
llama_sampling_reset(ctx_sampling->smpl);
llama_sampling_reset(smpl);
}
is_interacting = false;
}
Expand All @@ -634,13 +634,13 @@ int main(int argc, char ** argv) {
fflush(stdout);
}

llama_print_timings(ctx, ctx_sampling->smpl);
llama_print_timings(ctx, smpl);
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);

llama_free(ctx);
llama_free_model(model);

llama_sampling_free(ctx_sampling);
llama_sampling_free(smpl);
llama_backend_free();

#ifndef LOG_DISABLE_LOGS
Expand Down
14 changes: 7 additions & 7 deletions examples/llava/llava-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
return true;
}

static const char * sample(struct llama_sampling_context * ctx_sampling,
static const char * sample(struct llama_sampling * smpl,
struct llama_context * ctx_llama,
int * n_past) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama);
llama_sampling_accept(ctx_sampling->smpl, id, true);
const llama_token id = llama_sampling_sample(smpl, ctx_llama);
llama_sampling_accept(smpl, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>";
Expand Down Expand Up @@ -191,15 +191,15 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_

LOG_TEE("\n");

struct llama_sampling_context * ctx_sampling = llama_sampling_init(ctx_llava->model, params->sparams);
if (!ctx_sampling) {
struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
if (!smpl) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}

std::string response = "";
for (int i = 0; i < max_tgt_len; i++) {
const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0) break;
if (strstr(tmp, "###")) break; // Yi-VL behavior
Expand All @@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
fflush(stdout);
}

llama_sampling_free(ctx_sampling);
llama_sampling_free(smpl);
printf("\n");
}

Expand Down
28 changes: 14 additions & 14 deletions examples/llava/minicpmv-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
}

static const char * sample(struct llama_sampling_context * ctx_sampling,
static const char * sample(struct llama_sampling * smpl,
struct llama_context * ctx_llama,
int * n_past) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama);
llama_sampling_accept(ctx_sampling->smpl, id, true);
const llama_token id = llama_sampling_sample(smpl, ctx_llama);
llama_sampling_accept(smpl, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>";
Expand Down Expand Up @@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
return ctx_llava;
}

static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
static struct llama_sampling * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
std::string user_prompt = prompt;
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
if (!is_first) {
Expand All @@ -238,13 +238,13 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla

LOG_TEE("\n");

struct llama_sampling_context * ctx_sampling = llama_sampling_init(ctx_llava->model, params->sparams);
return ctx_sampling;
struct llama_sampling * smpl = llama_sampling_init(ctx_llava->model, params->sparams);
return smpl;
}

static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){
static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling * smpl, int &n_past){

const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
return tmp;
}

Expand Down Expand Up @@ -278,12 +278,12 @@ int main(int argc, char ** argv) {
if (!params.prompt.empty()) {
LOG_TEE("<user>%s\n", params.prompt.c_str());
LOG_TEE("<assistant>");
auto ctx_sampling = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
auto smpl = llama_init(ctx_llava, &params, params.prompt.c_str(), n_past, true);
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
std::string response = "";
bool have_tmp = false;
for (int i = 0; i < max_tgt_len; i++) {
auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
auto tmp = llama_loop(ctx_llava, smpl, n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0){
if(!have_tmp)continue;
Expand All @@ -296,26 +296,26 @@ int main(int argc, char ** argv) {

fflush(stdout);
}
llama_sampling_free(ctx_sampling);
llama_sampling_free(smpl);
}else {
while (true) {
LOG_TEE("<user>");
std::string prompt;
std::getline(std::cin, prompt);
LOG_TEE("<assistant>");
auto ctx_sampling = llama_init(ctx_llava, &params, prompt, n_past, true);
auto smpl = llama_init(ctx_llava, &params, prompt, n_past, true);
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
std::string response = "";
for (int i = 0; i < max_tgt_len; i++) {
auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
auto tmp = llama_loop(ctx_llava, smpl, n_past);
response += tmp;
if (strcmp(tmp, "</s>") == 0) break;
if (strstr(tmp, "###")) break; // Yi-VL behavior
printf("%s", tmp);// mistral llava-1.6
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
fflush(stdout);
}
llama_sampling_free(ctx_sampling);
llama_sampling_free(smpl);
}
}
printf("\n");
Expand Down
Loading

0 comments on commit 9dd2061

Please sign in to comment.