llama : add comments [no ci]
ggerganov committed Aug 30, 2024
1 parent 5dde421 commit 584ef0e
Showing 4 changed files with 30 additions and 19 deletions.
24 changes: 18 additions & 6 deletions include/llama.h
@@ -1036,23 +1036,33 @@ extern "C" {

     LLAMA_API void llama_sampling_free(struct llama_sampling * smpl);

+    // Copies the internal state of the sampler (rng, prev, params, grammar, etc.)
     LLAMA_API struct llama_sampling * llama_sampling_cp(const struct llama_sampling * smpl);

+    // - clear prev token
+    // - reset grammar state
     LLAMA_API void llama_sampling_reset(struct llama_sampling * smpl);

-    LLAMA_API void llama_sampling_set_rng_seed  (struct llama_sampling * smpl, uint32_t seed);
+    // Sampling parameter mutation
+    // TODO: not sure if we want to keep these. Maybe it's better to keep llama_sampling immutable
     LLAMA_API void llama_sampling_set_grammar   (struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root);
     LLAMA_API void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit_bias, const llama_logit_bias * logit_bias);

+    // Set the logits from which to sample.
+    // This call initializes the internal token candidates array.
+    // The internal candidates are implicitly used by the sampling API below when no candidates are provided.
     LLAMA_API void llama_sampling_set_logits(
             struct llama_sampling * smpl,
             const float * logits);

     /// @details Returns the current candidate tokens.
     LLAMA_API llama_token_data_array * llama_sampling_get_candidates(
             struct llama_sampling * smpl);

+    // The llama_sampling_ API below uses the parameters passed during the creation of the llama_sampling object.
+    // Each function can accept an array of token candidates. If the candidates are not provided, the internal
+    // candidates are used. The internal candidates are initialized by llama_sampling_set_logits().

     /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
     LLAMA_API void llama_sampling_softmax(
             struct llama_sampling * smpl,
@@ -1115,17 +1125,22 @@ extern "C" {
             struct llama_sampling * smpl,
             llama_token_data_array * candidates);

-    /// @details Sample a token using the configured samplers.
+    /// @details Sample a token using the configured samplers (see "llama_sampling_params.samplers").
     LLAMA_API llama_token llama_sampling_sample(
             struct llama_sampling * smpl,
             llama_token_data_array * candidates);

-    /// @details Accepts the sampled token into the sampling context
+    /// @details Accepts the sampled token into the sampling context.
+    /// - adds it to "prev" tokens
+    /// - updates the grammar state (if apply_grammar is true)
     LLAMA_API void llama_sampling_accept(
             struct llama_sampling * smpl,
             llama_token token,
             bool apply_grammar);

+    /// @details Get the number of accepted tokens so far (max of n_prev)
+    LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);
+
     /// @details Get the ith accepted token
     /// @param ith [0, n_prev), ith == 0 is the last accepted token.
     ///            returns LLAMA_TOKEN_NULL if ith is out of bounds
@@ -1138,9 +1153,6 @@ extern "C" {
     /// returns LLAMA_TOKEN_NULL if there are no accepted tokens
     LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);

-    /// @details Get the number of accepted tokens (max of n_prev)
-    LLAMA_API int llama_sampling_n_prev(const struct llama_sampling * smpl);
-
     //
     // Model split
     //
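
Taken together, the new header comments describe the intended call sequence: initialize the internal candidates from raw logits, sample, then accept. Below is a minimal sketch of that loop, assuming a llama_sampling object and a llama_context with freshly decoded logits already exist; their creation, and the llama_get_logits accessor, are part of the surrounding llama.cpp API and not of this diff:

    // Hedged sketch of the documented API flow; smpl and ctx are assumed to exist.
    const float * logits = llama_get_logits(ctx); // logits of the last decoded token

    // initializes the internal token candidates array (see llama_sampling_set_logits above)
    llama_sampling_set_logits(smpl, logits);

    // passing NULL for the candidates uses the internal candidates initialized above
    const llama_token token = llama_sampling_sample(smpl, NULL);

    // record the token in "prev" and advance the grammar state
    llama_sampling_accept(smpl, token, /*apply_grammar=*/true);

    // newest-first history: llama_sampling_last(smpl) == llama_sampling_prev(smpl, 0)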
7 changes: 5 additions & 2 deletions src/llama-sampling.cpp
@@ -186,10 +186,12 @@ void llama_sampling_top_k_impl(llama_token_data_array * candidates, int32_t k, s
         int ib = nbuckets - 1;
         for ( ; ib >= 0; --ib) {
             nhave += histo[ib];
-            if (nhave >= k) break;
+            if (nhave >= k) {
+                break;
+            }
         }
         std::vector<llama_token_data> tmp_tokens(nhave);
-        auto ptr = tmp_tokens.data();
+        auto * ptr = tmp_tokens.data();
         std::vector<llama_token_data*> bucket_ptrs;
         bucket_ptrs.reserve(nbuckets - ib);
         for (int j = nbuckets - 1; j >= ib; --j) {
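
The style changes above touch the bucket-based top-k path: logits are histogrammed into buckets, and buckets are drained from the highest downward until at least k candidates have been gathered. The following is a self-contained sketch of that idea over plain floats; it illustrates the technique only and is not the llama.cpp implementation, which operates on llama_token_data and avoids the full sort by only sorting the boundary bucket:

    #include <algorithm>
    #include <cstdio>
    #include <functional>
    #include <vector>

    // Illustrative bucket-based top-k (hypothetical helper, not a llama.cpp API):
    // bucket by value, then walk buckets from the top until >= k elements are gathered.
    static std::vector<float> top_k_bucketed(const std::vector<float> & vals, int k) {
        const int nbuckets = 128;

        const float lo = *std::min_element(vals.begin(), vals.end());
        const float hi = *std::max_element(vals.begin(), vals.end());
        const float scale = hi > lo ? (nbuckets - 1)/(hi - lo) : 0.0f;

        std::vector<std::vector<float>> buckets(nbuckets);
        for (float v : vals) {
            buckets[(int) ((v - lo)*scale)].push_back(v);
        }

        // drain buckets from the highest until at least k values are collected
        std::vector<float> out;
        for (int ib = nbuckets - 1; ib >= 0 && (int) out.size() < k; --ib) {
            out.insert(out.end(), buckets[ib].begin(), buckets[ib].end());
        }

        // sort what was gathered and trim to exactly k
        // (the real code saves work here by sorting only the boundary bucket)
        std::sort(out.begin(), out.end(), std::greater<float>());
        if ((int) out.size() > k) {
            out.resize(k);
        }
        return out;
    }

    int main() {
        const std::vector<float> logits = { 0.1f, 2.5f, -1.0f, 3.3f, 0.7f, 2.9f };
        for (float v : top_k_bucketed(logits, 3)) {
            printf("%.2f\n", v);
        }
        return 0;
    }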
@@ -573,6 +575,7 @@ llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array
     size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
         return candidate.id == X;
     }));
+
     float observed_surprise = -log2f(candidates->data[X_idx].p);
     float e = observed_surprise - tau;

6 changes: 3 additions & 3 deletions src/llama-sampling.h
@@ -98,10 +98,10 @@ llama_token llama_sampling_sample_mirostat_impl (struct llama_token_data_array
 /// @param mu Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
 llama_token llama_sampling_sample_mirostat_v2_impl(struct llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float & mu);

-llama_token llama_sampling_sample_greedy_impl (struct llama_token_data_array * candidates);
-llama_token llama_sampling_sample_dist_impl   (struct llama_token_data_array * candidates, std::mt19937 & rng);
+llama_token llama_sampling_sample_greedy_impl(struct llama_token_data_array * candidates);
+llama_token llama_sampling_sample_dist_impl  (struct llama_token_data_array * candidates, std::mt19937 & rng);

 void llama_sampling_accept_impl(struct llama_sampling & smpl, llama_token token, bool apply_grammar);

-llama_token llama_sampling_prev_impl(const struct llama_sampling & smpl, int ith);
+llama_token llama_sampling_prev_impl  (const struct llama_sampling & smpl, int ith);
 int         llama_sampling_n_prev_impl(const struct llama_sampling & smpl);
12 changes: 4 additions & 8 deletions src/llama.cpp
@@ -20106,10 +20106,6 @@ void llama_sampling_reset(struct llama_sampling * smpl) {
     llama_sampling_reset_impl(*smpl);
 }

-void llama_sampling_set_rng_seed(struct llama_sampling * smpl, uint32_t seed) {
-    llama_sampling_set_rng_seed_impl(*smpl, seed);
-}
-
 void llama_sampling_set_grammar(struct llama_sampling * smpl, const char * grammar_str, const char * grammar_root) {
     llama_sampling_set_grammar_impl(*smpl, grammar_str, grammar_root);
 }
@@ -20414,6 +20410,10 @@ void llama_sampling_accept(
     smpl->n_accept++;
 }

+int llama_sampling_n_prev(const struct llama_sampling * smpl) {
+    return llama_sampling_n_prev_impl(*smpl);
+}
+
 llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) {
     return llama_sampling_prev_impl(*smpl, ith);
 }
@@ -20422,10 +20422,6 @@ llama_token llama_sampling_last(const struct llama_sampling * smpl) {
     return llama_sampling_prev_impl(*smpl, 0);
 }

-int llama_sampling_n_prev(const struct llama_sampling * smpl) {
-    return llama_sampling_n_prev_impl(*smpl);
-}
-
 //
 // model split
 //
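
The move above places llama_sampling_n_prev next to the other history accessors, and the llama_sampling_last wrapper makes its relationship to llama_sampling_prev explicit: last is simply prev with ith == 0. A short hedged sketch of walking the accepted-token history, assuming smpl exists and has accepted some tokens:

    // Newest-first walk over the accepted tokens; n is capped by the sampler's
    // n_prev parameter, per the new header comment.
    const int n = llama_sampling_n_prev(smpl);
    for (int i = 0; i < n; ++i) {
        const llama_token t = llama_sampling_prev(smpl, i); // i == 0 is the last accepted token
        if (t == LLAMA_TOKEN_NULL) {
            break; // defensive: out of bounds / nothing accepted yet
        }
        // ... use t ...
    }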
