grammar : timing
ggml-ci
ggerganov committed Aug 27, 2024
1 parent 597e947 commit 179c0f4
Showing 5 changed files with 44 additions and 39 deletions.
include/llama.h: 5 changes (3 additions, 2 deletions)

@@ -401,12 +401,13 @@ extern "C" {
         double t_load_ms;
         double t_sampling_ms;
         double t_grammar_ms;
+        double t_accept_ms;
         double t_p_eval_ms;
         double t_eval_ms;
 
         int32_t n_sampling;
-        int32_t n_grammar_sample;
-        int32_t n_grammar_accept;
+        int32_t n_grammar;
+        int32_t n_accept;
         int32_t n_p_eval;
         int32_t n_eval;
     };
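Note: the public struct swaps the grammar-object counters (n_grammar_sample / n_grammar_accept) for sampler-level ones (n_grammar / n_accept) and gains a t_accept_ms slot. A minimal consumer sketch follows; the timings_view copy and report() helper are illustrative assumptions, not part of this commit:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // Abridged copy of the reshaped timings struct, limited to the fields
    // this commit touches.
    struct timings_view {
        double  t_grammar_ms;
        double  t_accept_ms;
        int32_t n_grammar;
        int32_t n_accept;
    };

    // Hypothetical consumer: print per-run averages, guarding the divisor
    // the same way llama_print_timings guards n_eval with std::max(1, ...).
    static void report(const timings_view & t) {
        std::printf("grammar: %.2f ms over %d runs (%.2f ms/run)\n",
                    t.t_grammar_ms, t.n_grammar,
                    t.t_grammar_ms / std::max<int32_t>(1, t.n_grammar));
        std::printf("accept:  %.2f ms over %d runs (%.2f ms/run)\n",
                    t.t_accept_ms, t.n_accept,
                    t.t_accept_ms / std::max<int32_t>(1, t.n_accept));
    }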
src/llama-grammar.cpp: 6 changes (3 additions, 3 deletions)

@@ -961,7 +961,7 @@ struct llama_grammar * llama_grammar_init_impl(
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
+    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
 }
 
 struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
@@ -1039,15 +1039,15 @@ struct llama_grammar * llama_grammar_init_impl(const struct llama_vocab * vocab,
     // Important: vec_rules has to be moved here, not copied, because stacks contains
     // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
     // then the pointers would be invalidated when the local vec_rules goes out of scope.
-    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, 0, 0, 0 };
+    return new llama_grammar { vocab, std::move(vec_rules), std::move(stacks), {}, };
 }
 
 void llama_grammar_free_impl(struct llama_grammar * grammar) {
     delete grammar;
 }
 
 struct llama_grammar * llama_grammar_cp_impl(const struct llama_grammar & grammar) {
-    llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, 0, 0, 0 };
+    llama_grammar * result = new llama_grammar { grammar.vocab, grammar.rules, grammar.stacks, grammar.partial_utf8, };
 
     // redirect elements in stacks to point to new rules
     for (size_t is = 0; is < result->stacks.size(); is++) {
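Note: the comment preserved in both constructors matters just as much now that the trailing 0, 0, 0 initializers are gone: vec_rules must still be moved, never copied. A standalone sketch of the pitfall it guards against, with hypothetical types not from this codebase:

    #include <utility>
    #include <vector>

    struct grammar_like {
        std::vector<int> rules;
        const int *      stack_entry; // points into rules, as stacks point into vec_rules
    };

    grammar_like make() {
        std::vector<int> rules = {1, 2, 3};
        const int * entry = &rules[0];
        // std::move transfers the vector's heap buffer, so `entry` stays valid.
        // A copy would allocate a fresh buffer, and `entry` would dangle as
        // soon as the local `rules` is destroyed on return.
        return grammar_like { std::move(rules), entry };
    }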
src/llama-grammar.h: 5 changes (0 additions, 5 deletions)

@@ -114,11 +114,6 @@ struct llama_grammar {

     // buffer for partially generated UTF-8 sequence from accepted tokens
     llama_partial_utf8 partial_utf8;
-
-    mutable int64_t t_total_us;
-
-    mutable int32_t n_sample;
-    mutable int32_t n_accept;
 };
 
 //
src/llama-sampling.h: 8 changes (6 additions, 2 deletions)

@@ -34,9 +34,13 @@ struct llama_sampling {
     // mirostat sampler state
     float mirostat_mu;
 
-    mutable int64_t t_total_us = 0;
+    mutable int64_t t_sample_us = 0;
+    mutable int64_t t_grammar_us = 0;
+    mutable int64_t t_accept_us = 0;
 
-    mutable int32_t n_sample = 0;
+    mutable int32_t n_sample = 0;
+    mutable int32_t n_grammar = 0;
+    mutable int32_t n_accept = 0;
 };
 
 //
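Note: all of the llama.cpp changes below lean on time_meas, an RAII helper that adds the elapsed time of its enclosing scope to whichever mutable counter it is handed. A rough sketch of the idea using std::chrono; the real helper is ggml-timer based, so this is an assumption-laden stand-in:

    #include <chrono>
    #include <cstdint>

    // RAII accumulator: construction records the start time, destruction adds
    // the elapsed microseconds to the referenced counter. Declaring one at the
    // top of a function charges that function's runtime to the chosen bucket.
    struct time_meas {
        explicit time_meas(int64_t & t_acc)
            : t_acc(t_acc), t_start(std::chrono::steady_clock::now()) {}

        ~time_meas() {
            using namespace std::chrono;
            t_acc += duration_cast<microseconds>(steady_clock::now() - t_start).count();
        }

        int64_t & t_acc;
        std::chrono::steady_clock::time_point t_start;
    };

With this shape, moving a function from tm(smpl->t_total_us) to tm(smpl->t_sample_us), tm(smpl->t_grammar_us), or tm(smpl->t_accept_us) is a one-word change, which is essentially all the hunks below do.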
src/llama.cpp: 59 changes (32 additions, 27 deletions)

@@ -20095,43 +20095,43 @@ void llama_sampling_set_logit_bias(struct llama_sampling * smpl, int32_t n_logit
 }
 
 void llama_sampling_softmax(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     llama_sampling_softmax_impl(candidates);
 }
 
 void llama_sampling_top_k(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     llama_sampling_top_k_impl(candidates, smpl->params.top_k, smpl->params.min_keep);
 }
 
 void llama_sampling_top_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     llama_sampling_top_p_impl(candidates, smpl->params.top_p, smpl->params.min_keep);
 }
 
 void llama_sampling_min_p(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     llama_sampling_min_p_impl(candidates, smpl->params.min_p, smpl->params.min_keep);
 }
 
 void llama_sampling_tail_free(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     llama_sampling_tail_free_impl(candidates, smpl->params.tfs_z, smpl->params.min_keep);
 }
 
 void llama_sampling_typical(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     llama_sampling_typical_impl(candidates, smpl->params.typical_p, smpl->params.min_keep);
 }
 
 void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     if (smpl->params.dynatemp_range > 0) {
         const float dynatemp_min = std::max(0.0f, smpl->params.temp - smpl->params.dynatemp_range);
@@ -20144,17 +20144,19 @@ void llama_sampling_temp(struct llama_sampling * smpl, llama_token_data_array *
 }
 
 void llama_sampling_grammar(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us); // TODO: measure grammar time separately from sampling
+    time_meas tm(smpl->t_grammar_us);
 
     if (smpl->grammar) {
         llama_sampling_grammar_impl(candidates, *smpl->grammar);
     }
+
+    smpl->n_grammar++;
 }
 
 void llama_sampling_penalties(
         struct llama_sampling * smpl,
         llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     const size_t penalty_last_n = std::min<size_t>(smpl->params.penalty_last_n, smpl->prev.size());
 
@@ -20181,13 +20183,13 @@ void llama_sampling_cfg(
         struct llama_sampling * smpl,
         float * logits,
         float * logits_guidance) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     llama_sampling_cfg_impl(*smpl, logits, logits_guidance);
 }
 
 llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     const auto type = smpl->params.mirostat;
 
@@ -20217,7 +20219,7 @@ llama_token llama_sampling_sample_mirostat(struct llama_sampling * smpl, llama_t
 }
 
 llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     auto res = llama_sampling_sample_greedy_impl(candidates);
 
@@ -20227,7 +20229,7 @@ llama_token llama_sampling_sample_greedy(struct llama_sampling * smpl, llama_tok
 }
 
 llama_token llama_sampling_sample(struct llama_sampling * smpl, llama_token_data_array * candidates) {
-    time_meas tm(smpl->t_total_us);
+    time_meas tm(smpl->t_sample_us);
 
     auto res = llama_sampling_sample_impl(candidates, smpl->rng);
 
@@ -20240,9 +20242,11 @@ void llama_sampling_accept(
         struct llama_sampling * smpl,
         llama_token token,
         bool apply_grammar) {
-    time_meas tm(smpl->t_total_us); // TODO: measure grammar time separately from sampling
+    time_meas tm(smpl->t_accept_us);
 
     llama_sampling_accept_impl(*smpl, token, apply_grammar);
+
+    smpl->n_accept++;
 }
 
 llama_token llama_sampling_prev(const struct llama_sampling * smpl, int32_t ith) {
@@ -20286,24 +20290,27 @@ void llama_print_timings(struct llama_context * ctx, struct llama_sampling * smp
         /*.t_start_ms =*/ 1e-3 * ctx->t_start_us,
         /*.t_end_ms =*/ 1.00 * ggml_time_ms(),
         /*.t_load_ms =*/ 1e-3 * ctx->t_load_us,
-        /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_total_us : 0.0),
-        /*.t_grammar_ms =*/ 1e-3 * (smpl && smpl->grammar ? smpl->grammar->t_total_us : 0.0),
+        /*.t_sampling_ms =*/ 1e-3 * (smpl ? smpl->t_sample_us : 0.0),
+        /*.t_grammar_ms =*/ 1e-3 * (smpl ? smpl->t_grammar_us : 0.0),
+        /*.t_accept_ms =*/ 1e-3 * (smpl ? smpl->t_accept_us : 0.0),
         /*.t_p_eval_ms =*/ 1e-3 * ctx->t_p_eval_us,
         /*.t_eval_ms =*/ 1e-3 * ctx->t_eval_us,
 
-        /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0),
-        /*.n_grammar_sample =*/ std::max(0, smpl && smpl->grammar ? smpl->grammar->n_sample : 0),
-        /*.n_grammar_accept =*/ std::max(0, smpl && smpl->grammar ? smpl->grammar->n_accept : 0),
-        /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
-        /*.n_eval =*/ std::max(1, ctx->n_eval),
+        /*.n_sampling =*/ std::max(0, smpl ? smpl->n_sample : 0),
+        /*.n_grammar =*/ std::max(0, smpl ? smpl->n_grammar : 0),
+        /*.n_accept =*/ std::max(0, smpl ? smpl->n_accept : 0),
+        /*.n_p_eval =*/ std::max(0, ctx->n_p_eval),
+        /*.n_eval =*/ std::max(1, ctx->n_eval),
     };
 
     LLAMA_LOG_INFO("\n");
     LLAMA_LOG_INFO("%s:        load time = %10.2f ms\n", __func__, timings.t_load_ms);
     LLAMA_LOG_INFO("%s:    sampling time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, timings.t_sampling_ms, timings.n_sampling, timings.t_sampling_ms / timings.n_sampling, 1e3 / timings.t_sampling_ms * timings.n_sampling);
     LLAMA_LOG_INFO("%s:     grammar time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
-            __func__, timings.t_grammar_ms, timings.n_grammar_sample, timings.t_grammar_ms / timings.n_grammar_sample, 1e3 / timings.t_grammar_ms * timings.n_grammar_sample);
+            __func__, timings.t_grammar_ms, timings.n_grammar, timings.t_grammar_ms / timings.n_grammar, 1e3 / timings.t_grammar_ms * timings.n_grammar);
+    //LLAMA_LOG_INFO("%s:      accept time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
+    //        __func__, timings.t_accept_ms, timings.n_accept, timings.t_accept_ms / timings.n_accept, 1e3 / timings.t_accept_ms * timings.n_accept);
     LLAMA_LOG_INFO("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
             __func__, timings.t_p_eval_ms, timings.n_p_eval, timings.t_p_eval_ms / timings.n_p_eval, 1e3 / timings.t_p_eval_ms * timings.n_p_eval);
     LLAMA_LOG_INFO("%s:        eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)\n",
@@ -20317,11 +20324,9 @@ void llama_reset_timings(struct llama_context * ctx, struct llama_sampling * smp
     ctx->t_p_eval_us = ctx->n_p_eval = 0;
 
     if (smpl) {
-        smpl->t_total_us = smpl->n_sample = 0;
-
-        if (smpl->grammar) {
-            smpl->grammar->t_total_us = smpl->grammar->n_sample = smpl->grammar->n_accept = 0;
-        }
+        smpl->t_sample_us = smpl->n_sample = 0;
+        smpl->t_grammar_us = smpl->n_grammar = 0;
+        smpl->t_accept_us = smpl->n_accept = 0;
     }
 }

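Note: every log line above derives both averages from a single (time, count) pair: t_ms / n gives ms per token, and 1e3 / t_ms * n gives tokens per second. A tiny check of that arithmetic with made-up numbers, not from a real run:

    #include <cstdio>

    int main() {
        const double t_sampling_ms = 250.0; // hypothetical: 250 ms of sampling
        const int    n_sampling    = 100;   // across 100 sampled tokens

        std::printf("%.2f ms per token\n", t_sampling_ms / n_sampling);            // 2.50
        std::printf("%.2f tokens per second\n", 1e3 / t_sampling_ms * n_sampling); // 400.00
        return 0;
    }

This is also why n_eval is clamped with std::max(1, ...) while the pure counters are clamped at 0: the eval average is always printed, so its divisor must never be zero. The new accept report is left commented out for now.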
