
Commit

sampling : remove n_valid from the state
ggml-ci
ggerganov committed Aug 17, 2024
1 parent 5adba1c commit 25a46f2
Showing 5 changed files with 39 additions and 64 deletions.
12 changes: 6 additions & 6 deletions common/common.cpp
@@ -249,16 +249,14 @@ void gpt_params_handle_model_default(gpt_params & params) {
 }
 
 bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
-    bool invalid_param = false;
-    std::string arg;
-    const std::string arg_prefix = "--";
-    auto & sparams = params.sparams;
-
     for (int i = 1; i < argc; i++) {
-        arg = argv[i];
+        const std::string arg_prefix = "--";
+
+        std::string arg = argv[i];
         if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
             std::replace(arg.begin(), arg.end(), '_', '-');
         }
+        bool invalid_param = false;
         if (!gpt_params_find_arg(argc, argv, arg, params, i, invalid_param)) {
             throw std::invalid_argument("error: unknown argument: " + arg);
         }
@@ -275,6 +273,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {

     gpt_params_handle_hf_token(params);
 
+    auto & sparams = params.sparams;
+
     if (params.escape) {
         string_process_escapes(params.prompt);
         string_process_escapes(params.input_prefix);
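The common.cpp hunks are a scoping cleanup rather than a behavior change: the per-argument locals (arg_prefix, arg, invalid_param) move inside the parse loop, and sparams is only taken after parsing has finished. A minimal, self-contained sketch of the normalization step the loop body performs (an illustration, not code from the repository):

    // standalone illustration: normalize "--some_flag" style arguments to "--some-flag"
    // before they are matched, mirroring the loop body shown in the diff above
    #include <algorithm>
    #include <cstdio>
    #include <string>

    int main(int argc, char ** argv) {
        for (int i = 1; i < argc; i++) {
            const std::string arg_prefix = "--";

            std::string arg = argv[i];
            if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
                std::replace(arg.begin(), arg.end(), '_', '-');
            }

            std::printf("argument %d normalized to: %s\n", i, arg.c_str());
        }

        return 0;
    }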
12 changes: 4 additions & 8 deletions common/sampling.cpp
@@ -40,8 +40,6 @@ struct llama_sampling_context * llama_sampling_init(const struct gpt_sampling_pa
         llama_sampling_set_logit_bias(result->smpl, params.logit_bias.size(), params.logit_bias.data());
     }
 
-    result->n_valid = 0;
-
     return result;
 }

@@ -55,7 +53,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
     llama_sampling_reset(ctx->smpl);
 
     ctx->cur.clear();
-    ctx->n_valid = 0;
+    ctx->org.clear();
 }
 
 void llama_sampling_cp(llama_sampling_context * src, llama_sampling_context * dst) {
@@ -294,11 +292,11 @@ static llama_token llama_sampling_sample(

     llama_token id = 0;
 
-    if (temp < 0.0) {
+    if (temp < 0.0f || (temp == 0.0f && params.n_probs > 0)) {
         // greedy sampling, with probs
         llama_sampling_softmax(smpl, cur_p);
         id = cur_p->data[0].id;
-    } else if (temp == 0.0) {
+    } else if (temp == 0.0f) {
         // greedy sampling, no probs
         id = llama_sampling_sample_greedy(smpl, cur_p);
     } else {
@@ -325,8 +323,6 @@ static llama_token llama_sampling_sample(
         }
     }
 
-    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p->size;
-
     return id;
 }

@@ -341,7 +337,7 @@ llama_token llama_sampling_sample(
         return llama_sampling_sample(ctx_sampling, &cur_p);
     }
 
-    // TODO: this lofic is confusing, try to figure out a better way to handle this
+    // TODO: this logic is confusing, try to figure out a better way to handle this
 
     // store the original candidates
     ctx_sampling->org = ctx_sampling->cur;
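The key behavioral piece in sampling.cpp is the widened first branch: with the n_valid bookkeeping gone, a zero temperature still takes the softmax path whenever the caller asked for token probabilities (params.n_probs > 0), so the probabilities read out afterwards are actually populated. A simplified, standalone sketch of that dispatch, using a plain struct instead of the library types (the temperature > 0 branch is stubbed out):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    struct candidate { int id; float logit; float p; };

    // sort by logit and turn logits into normalized probabilities
    static void softmax(std::vector<candidate> & cur) {
        std::sort(cur.begin(), cur.end(),
                  [](const candidate & a, const candidate & b) { return a.logit > b.logit; });
        float sum = 0.0f;
        for (auto & c : cur) { c.p = std::exp(c.logit - cur[0].logit); sum += c.p; }
        for (auto & c : cur) { c.p /= sum; }
    }

    static int sample_token(std::vector<candidate> & cur, float temp, int n_probs) {
        if (temp < 0.0f || (temp == 0.0f && n_probs > 0)) {
            softmax(cur);   // greedy pick, but probabilities are now available
            return cur[0].id;
        }
        if (temp == 0.0f) {
            // greedy pick, probabilities never computed
            return std::max_element(cur.begin(), cur.end(),
                   [](const candidate & a, const candidate & b) { return a.logit < b.logit; })->id;
        }
        softmax(cur);       // placeholder for real temperature-based sampling
        return cur[0].id;
    }

    int main() {
        std::vector<candidate> cur = { { 0, 1.0f, 0.0f }, { 1, 2.5f, 0.0f }, { 2, 0.5f, 0.0f } };
        std::printf("picked token %d\n", sample_token(cur, 0.0f, /*n_probs =*/ 5));
        return 0;
    }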
44 changes: 21 additions & 23 deletions common/sampling.h
@@ -17,27 +17,27 @@ enum class llama_sampler_type : char {

 // sampling parameters
 typedef struct gpt_sampling_params {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampling_context
     int32_t n_prev = 64; // number of previous tokens to remember
     int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t min_keep = 0; // 0 = disabled, otherwise samplers should return at least min_keep tokens
     int32_t top_k = 40; // <= 0 to use vocab size
     float top_p = 0.95f; // 1.0 = disabled
     float min_p = 0.05f; // 0.0 = disabled
     float tfs_z = 1.00f; // 1.0 = disabled
     float typical_p = 1.00f; // 1.0 = disabled
     float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float dynatemp_range = 0.00f; // 0.0 = disabled
     float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
     int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float penalty_repeat = 1.00f; // 1.0 = disabled
     float penalty_freq = 0.00f; // 0.0 = disabled
     float penalty_present = 0.00f; // 0.0 = disabled
     int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float mirostat_tau = 5.00f; // target entropy
     float mirostat_eta = 0.10f; // learning rate
     bool penalize_nl = false; // consider newlines as a repeatable token
     bool ignore_eos = false;

     std::vector<llama_sampler_type> samplers_sequence = {
         llama_sampler_type::TOP_K,
@@ -68,8 +68,6 @@ struct llama_sampling_context {

     std::vector<llama_token_data> cur;
     std::vector<llama_token_data> org;
-
-    size_t n_valid; // Number of correct top tokens with correct probabilities.
 };
 
 // Create a new sampling context instance.
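For callers, the visible sampling state is now just the candidate vectors cur and org; everything configurable lives in the params struct above. A hypothetical usage sketch (the include path and the commented-out init call are assumptions about the surrounding llama.cpp build, not part of this diff):

    #include "sampling.h" // assumes llama.cpp's common/ directory is on the include path

    int main() {
        gpt_sampling_params sparams; // every field starts at the defaults listed above

        sparams.temp    = 0.0f;      // greedy decoding ...
        sparams.n_probs = 5;         // ... but still report probabilities for the top 5 tokens
        sparams.top_k   = 40;
        sparams.seed    = 1234;

        // a context built from these params no longer carries n_valid; after sampling,
        // callers read probabilities straight out of the candidate array
        // struct llama_sampling_context * ctx = llama_sampling_init(sparams);

        return 0;
    }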
34 changes: 7 additions & 27 deletions examples/server/server.cpp
@@ -2350,35 +2350,15 @@ struct server_context {
                 metrics.on_prompt_eval(slot);
             }
 
-            llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
             result.tok = id;
 
-            const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-            if (n_probs > 0) {
-                const size_t n_valid = slot.ctx_sampling->n_valid;
-
-                // Make sure at least n_probs top tokens are at the front of the vector:
-                // TODO: decide to how to handle this after the refactoring
-                //if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                //    llama_sampling_top_k(slot.ctx_sampling->smpl, &cur_p, n_probs, 0);
-                //}
-
-                if (slot.sparams.temp == 0.0f) {
-                    // With greedy sampling the probabilities have possibly not been calculated.
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id,
-                            i == 0 ? 1.0f : 0.0f
-                        });
-                    }
-                } else {
-                    for (size_t i = 0; i < n_probs; ++i) {
-                        result.probs.push_back({
-                            cur_p.data[i].id,
-                            i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                        });
-                    }
-                }
+            const llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
 
+            for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                result.probs.push_back({
+                    cur_p.data[i].id,
+                    i >= cur_p.size ? 0.0f : cur_p.data[i].p,
+                });
+            }
 
             if (!process_token(result, slot)) {
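The replacement loop in server.cpp drops the n_valid special-casing: the candidate array produced by the sampler is already ordered, so the server copies out the first n_probs entries and reports probability 0 for any requested slot beyond the candidates that remain. A standalone sketch of the same idea with plain types (unlike the server code it also guards the id lookup, using -1 as a stand-in when no candidate exists):

    #include <cstdio>
    #include <vector>

    struct token_prob { int id; float p; };

    int main() {
        // pretend these came back from the sampler, already sorted by probability
        const std::vector<token_prob> cur = { { 42, 0.7f }, { 7, 0.2f }, { 13, 0.1f } };
        const size_t n_probs = 5; // how many probabilities the client asked for

        std::vector<token_prob> probs;
        for (size_t i = 0; i < n_probs; ++i) {
            probs.push_back({
                i < cur.size() ? cur[i].id : -1,   // -1 marks "no candidate left"
                i < cur.size() ? cur[i].p  : 0.0f, // filtered-out slots report probability 0
            });
        }

        for (const auto & tp : probs) {
            std::printf("token %3d -> %.2f\n", tp.id, tp.p);
        }

        return 0;
    }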
1 change: 1 addition & 0 deletions include/llama.h
@@ -210,6 +210,7 @@ extern "C" {
     } llama_token_data;
 
     typedef struct llama_token_data_array {
+        // TODO: consider SoA
         llama_token_data * data;
        size_t size;
         bool sorted;
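The only llama.h change is the new TODO on llama_token_data_array. A hypothetical illustration of what that note points at (not actual llama.cpp code): the existing array-of-structs layout keeps id, logit and p together per token, while a structure-of-arrays variant would keep each field in its own contiguous buffer.

    #include <cstddef>
    #include <cstdint>

    typedef int32_t llama_token;

    // current layout (array of structs): one record per candidate token
    typedef struct llama_token_data {
        llama_token id;
        float       logit;
        float       p;
    } llama_token_data;

    // hypothetical structure-of-arrays layout hinted at by the TODO: passes that
    // touch only one field (e.g. a softmax over logits) stay in one contiguous buffer
    typedef struct llama_token_data_array_soa {
        llama_token * id;
        float       * logit;
        float       * p;
        size_t        size;
        bool          sorted;
    } llama_token_data_array_soa;

    int main() { return 0; }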
