llama : rename batch_all to batch
This commit addresses the TODO in the code to rename the `batch_all`
parameter to `batch` in `llama_decode_internal`.
danbev committed Aug 6, 2024
1 parent c21a896 · commit 436872f
Showing 1 changed file with 16 additions and 16 deletions.
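
For context, `llama_batch` is the batch type passed to `llama_decode_internal`. At the time of this commit it was declared in `llama.h` roughly as follows (a paraphrased sketch with abbreviated comments, not a verbatim copy of the header):

typedef struct llama_batch {
    int32_t n_tokens;

    llama_token  *  token;    // token ids, when the input is tokens
    float        *  embd;     // embeddings, when the input is embeddings
    llama_pos    *  pos;      // position of each token in its sequence
    int32_t      *  n_seq_id; // number of sequence ids per token
    llama_seq_id ** seq_id;   // sequence ids per token
    int8_t       *  logits;   // per-token flag: compute output for this token

    // transition helpers, used when the corresponding arrays above are NULL
    llama_pos    all_pos_0;  // start position (if pos == NULL)
    llama_pos    all_pos_1;  // position stride (if pos == NULL)
    llama_seq_id all_seq_id; // shared sequence id (if seq_id == NULL)
} llama_batch;

The rename below is purely mechanical: every `batch_all.` access becomes `batch.`, with no change in behavior.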
src/llama.cpp: 32 changes (16 additions & 16 deletions)
@@ -14447,10 +14447,10 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   batch) {
 
     lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch_all.n_tokens;
+    const uint32_t n_tokens_all = batch.n_tokens;
 
     if (n_tokens_all == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
@@ -14461,7 +14461,7 @@ static int llama_decode_internal(
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
 
@@ -14492,9 +14492,9 @@ static int llama_decode_internal(
     const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
     // count outputs
-    if (batch_all.logits && !embd_pooled) {
+    if (batch.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch_all.logits[i] != 0;
+            n_outputs += batch.logits[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
 
@@ -14510,10 +14510,10 @@ static int llama_decode_internal(
     };
 
     // set output mappings
-    if (batch_all.logits) {
+    if (batch.logits) {
         int32_t i_logits = 0;
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.logits[i]) {
+            if (batch.logits[i]) {
                 lctx.output_ids[i] = i_logits++;
             }
         }
 
@@ -14527,15 +14527,15 @@ static int llama_decode_internal(
         const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
         llama_batch u_batch = {
             /* .n_tokens   = */ (int32_t) n_tokens,
-            /* .token      = */ batch_all.token    ? batch_all.token    + cur_token        : nullptr,
-            /* .embd       = */ batch_all.embd     ? batch_all.embd     + cur_token*n_embd : nullptr,
-            /* .pos        = */ batch_all.pos      ? batch_all.pos      + cur_token        : nullptr,
-            /* .n_seq_id   = */ batch_all.n_seq_id ? batch_all.n_seq_id + cur_token        : nullptr,
-            /* .seq_id     = */ batch_all.seq_id   ? batch_all.seq_id   + cur_token        : nullptr,
-            /* .logits     = */ batch_all.logits   ? batch_all.logits   + cur_token        : nullptr,
-            /* .all_pos_0  = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
-            /* .all_pos_1  = */ batch_all.all_pos_1,
-            /* .all_seq_id = */ batch_all.all_seq_id,
+            /* .token      = */ batch.token    ? batch.token    + cur_token        : nullptr,
+            /* .embd       = */ batch.embd     ? batch.embd     + cur_token*n_embd : nullptr,
+            /* .pos        = */ batch.pos      ? batch.pos      + cur_token        : nullptr,
+            /* .n_seq_id   = */ batch.n_seq_id ? batch.n_seq_id + cur_token        : nullptr,
+            /* .seq_id     = */ batch.seq_id   ? batch.seq_id   + cur_token        : nullptr,
+            /* .logits     = */ batch.logits   ? batch.logits   + cur_token        : nullptr,
+            /* .all_pos_0  = */ batch.all_pos_0 + (llama_pos) cur_token*batch.all_pos_1,
+            /* .all_pos_1  = */ batch.all_pos_1,
+            /* .all_seq_id = */ batch.all_seq_id,
         };
 
         // count the outputs in this u_batch
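
A note on the last hunk: each `u_batch` is a non-owning view into the parent batch, produced by advancing every per-token array pointer by `cur_token`; no token data is copied. A minimal standalone sketch of that slicing pattern (toy types, names, and values for illustration, not llama.cpp code):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// hypothetical stand-in for llama_batch: non-owning views over caller arrays
struct toy_batch {
    int32_t        n_tokens;
    const int32_t *token;  // stands in for llama_token * token
    const int8_t  *logits; // per-token output flags
};

int main() {
    std::vector<int32_t> tokens = {10, 11, 12, 13, 14, 15, 16};
    std::vector<int8_t>  logits = { 0,  0,  0,  0,  0,  0,  1};
    const toy_batch batch = { (int32_t) tokens.size(), tokens.data(), logits.data() };

    const uint32_t n_ubatch = 3; // micro-batch capacity, like cparams.n_ubatch
    for (uint32_t cur_token = 0; cur_token < (uint32_t) batch.n_tokens; cur_token += n_ubatch) {
        const uint32_t n_tokens = std::min(n_ubatch, (uint32_t) batch.n_tokens - cur_token);
        // slice: alias the parent arrays at offset cur_token, copy nothing
        const toy_batch u_batch = {
            (int32_t) n_tokens,
            batch.token  ? batch.token  + cur_token : nullptr,
            batch.logits ? batch.logits + cur_token : nullptr,
        };
        printf("u_batch of %d tokens starting at token id %d\n",
               u_batch.n_tokens, u_batch.token[0]);
    }
    return 0;
}

Compiled and run as-is, this prints micro-batches of 3, 3, and 1 tokens; only the pointers advance, exactly as in the `u_batch` initializer above.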
