Skip to content

Commit

Permalink
sampler : API to iterate constraints
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov committed Sep 4, 2024
1 parent 23f0802 commit 762e955
Show file tree
Hide file tree
Showing 9 changed files with 68 additions and 47 deletions.
39 changes: 17 additions & 22 deletions common/sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct gpt_sampler {
struct llama_sampler * smpl;
};

std::string gpt_sampler_params::print_all() const {
std::string gpt_sampler_params::print() const {
char result[1024];

snprintf(result, sizeof(result),
Expand All @@ -26,17 +26,12 @@ std::string gpt_sampler_params::print_all() const {
return std::string(result);
}

std::string gpt_sampler_params::print_constraints() const {
std::string result = "CFG -> Penalties ";
if (mirostat == 0) {
for (const auto & cnstr : constraints) {
const auto name = gpt_constraint_type_to_str(cnstr);
if (!name.empty()) {
result += "-> " + name + " ";
}
}
} else {
result += "-> mirostat ";
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl) {
std::string result = "\tlogits";

for (int i = 0; i < llama_sampler_n_constraints(gsmpl->smpl); i++) {
const auto * cnstr = llama_sampler_constraint_get(gsmpl->smpl, i);
result += " -> " + std::string(cnstr->iface->name(cnstr)) + " ";
}

return result;
Expand Down Expand Up @@ -70,33 +65,33 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
for (const auto & cnstr : params.constraints) {
switch (cnstr) {
case GPT_CONSTRAINT_TYPE_TOP_K:
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_k (params.top_k, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TOP_P:
llama_sampler_add_constraint(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_top_p (params.top_p, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_MIN_P:
llama_sampler_add_constraint(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_min_p (params.min_p, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TFS_Z:
llama_sampler_add_constraint(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_tail_free(params.tfs_z, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TYPICAL_P:
llama_sampler_add_constraint(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_typical (params.typ_p, params.min_keep));
break;
case GPT_CONSTRAINT_TYPE_TEMPERATURE:
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
default:
GGML_ASSERT(false && "unknown constraint type");
}
}
} else if (params.mirostat == 1) {
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat(model, params.mirostat_tau, params.mirostat_eta));
} else if (params.mirostat == 2) {
llama_sampler_add_constraint(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_add_constraint(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_temp(params.temp));
llama_sampler_constraint_add(result->smpl, llama_constraint_init_mirostat_v2(params.mirostat_tau, params.mirostat_eta));
} else {
GGML_ASSERT(false && "unknown mirostat version");
}
Expand Down
8 changes: 4 additions & 4 deletions common/sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,7 @@ struct gpt_sampler_params {
std::vector<llama_logit_bias> logit_bias; // logit biases to apply

// print the parameters into a string
std::string print_all() const;

// print the constraints into a string
std::string print_constraints() const;
std::string print() const;
};

// gpt_sampler extends llama_sampler with additional functionality:
Expand Down Expand Up @@ -100,6 +97,9 @@ llama_token gpt_sampler_sample_greedy(struct gpt_sampler * gsmpl, llama_token_da

// helpers

// print the constraints into a string
std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);

// get a string representation of the last accepted tokens
std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);

Expand Down
6 changes: 3 additions & 3 deletions examples/batched.swift/Sources/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ defer {
llama_sampler_free(smpl)
}

llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9, 1));
llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4));
llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(40, 1));
llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(0.9, 1));
llama_sampler_constraint_add(smpl, llama_constraint_init_temp (0.4));

let n_ctx = llama_n_ctx(context)

Expand Down
6 changes: 3 additions & 3 deletions examples/batched/batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ int main(int argc, char ** argv) {

llama_sampler * smpl = llama_sampler_init(model, sparams);

llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp));
llama_sampler_constraint_add(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
llama_sampler_constraint_add(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
llama_sampler_constraint_add(smpl, llama_constraint_init_temp (params.sparams.temp));

if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
Expand Down
17 changes: 9 additions & 8 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,15 @@ int main(int argc, char ** argv) {
}
}
}
LOG_TEE("sampling params: \n%s\n", sparams.print_all().c_str());
LOG_TEE("sampling constr: \n%s\n", sparams.print_constraints().c_str());

smpl = gpt_sampler_init(model, sparams);
if (!smpl) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}

LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);

// group-attention state
Expand Down Expand Up @@ -525,12 +532,6 @@ int main(int argc, char ** argv) {
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
}

smpl = gpt_sampler_init(model, sparams);
if (!smpl) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
}

if (llama_model_has_encoder(model)) {
int enc_input_size = embd_inp.size();
llama_token * enc_input_buf = embd_inp.data();
Expand Down
9 changes: 6 additions & 3 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -1010,7 +1010,7 @@ extern "C" {
// The llama_sampler object contains the entire sampling information:
//
// - RNG state (seed and generator)
// - Custom set of constraints (see llama_sampler_add_constraint)
// - Custom set of constraints (see llama_sampler_constraint_add)
// - Sampling method (greedy, dist)
// - Previous tokens
//
Expand Down Expand Up @@ -1081,7 +1081,7 @@ extern "C" {

LLAMA_API struct llama_constraint * llama_constraint_cp(const struct llama_constraint * cnstr);

// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_add_constraint)
// important: do not call if the constraint has been added to a llama_sampler (via llama_sampler_constraint_add)
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);

LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);
Expand All @@ -1100,7 +1100,10 @@ extern "C" {
LLAMA_API llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smpl);

// important: takes ownership of the constraint object and will free it in llama_sampler_free
LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);
LLAMA_API void llama_sampler_constraint_add( struct llama_sampler * smpl, struct llama_constraint * cnstr);
LLAMA_API int llama_sampler_n_constraints (const struct llama_sampler * smpl);
LLAMA_API struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i);


LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
LLAMA_API void llama_sampler_apply (struct llama_sampler * smpl, llama_token_data_array * cur_p);
Expand Down
14 changes: 13 additions & 1 deletion src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1215,10 +1215,22 @@ void llama_sampler_reset_impl(struct llama_sampler & smpl) {
// TODO: should we reset the timings?
}

void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
// attach a constraint to the end of the sampler's constraint chain
// note: the sampler takes ownership of cnstr and releases it in llama_sampler_free
void llama_sampler_constraint_add_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr) {
    smpl.constraints.emplace_back(cnstr);
}

// number of constraints currently attached to the sampler
int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl) {
    // explicit cast: size() is unsigned (size_t) while the API reports the
    // count as int - avoids an implicit narrowing conversion warning and
    // matches the (int) cast used in llama_sampler_constraint_get_impl
    return (int) smpl.constraints.size();
}

// fetch the ith constraint of the sampler, or nullptr if ith is out of range
struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith) {
    const int n = (int) smpl.constraints.size();

    return (ith >= 0 && ith < n) ? smpl.constraints[ith] : nullptr;
}

void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token) {
smpl.prev.push_back(token);

Expand Down
4 changes: 3 additions & 1 deletion src/llama-sampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ void llama_sampler_free_impl ( struct llama_sampler * smp
struct llama_sampler * llama_sampler_cp_impl (const struct llama_sampler & smpl);
void llama_sampler_reset_impl( struct llama_sampler & smpl);

void llama_sampler_add_constraint_impl(struct llama_sampler & smpl, struct llama_constraint * cnstr);
void llama_sampler_constraint_add_impl( struct llama_sampler & smpl, struct llama_constraint * cnstr);
int llama_sampler_n_constraints_impl (const struct llama_sampler & smpl);
struct llama_constraint * llama_sampler_constraint_get_impl(const struct llama_sampler & smpl, int ith);

void llama_sampler_accept_impl(struct llama_sampler & smpl, llama_token token);
void llama_sampler_apply_impl (struct llama_sampler & smpl, struct llama_token_data_array * cur_p);
Expand Down
12 changes: 10 additions & 2 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20699,8 +20699,16 @@ llama_token_data_array * llama_sampler_get_candidates(struct llama_sampler * smp
return &smpl->cur_p;
}

void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
llama_sampler_add_constraint_impl(*smpl, cnstr);
// public C API wrapper: forwards to the internal implementation
// important: takes ownership of the constraint object - it is freed in llama_sampler_free
void llama_sampler_constraint_add(struct llama_sampler * smpl, struct llama_constraint * cnstr) {
    llama_sampler_constraint_add_impl(*smpl, cnstr);
}

// public C API wrapper: returns the number of constraints attached to the sampler
int llama_sampler_n_constraints (const struct llama_sampler * smpl) {
    return llama_sampler_n_constraints_impl(*smpl);
}

// public C API wrapper: returns the i-th constraint, or nullptr if i is out of range
struct llama_constraint * llama_sampler_constraint_get(const struct llama_sampler * smpl, int32_t i) {
    return llama_sampler_constraint_get_impl(*smpl, i);
}

void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
Expand Down

0 comments on commit 762e955

Please sign in to comment.