llama : move encoder output from llama_batch to llama_context, add is_encoding flag.

sszymczy committed Jun 14, 2024
1 parent 2bd023d commit b6694e2
Showing 3 changed files with 35 additions and 76 deletions.
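
The caller-facing effect, as a minimal sketch (this assumes the llama_encode() wrapper this branch provides around llama_encode_internal(); the helper run_t5() and its buffers are illustrative, and error handling/sampling are omitted): the encoder output now lives in llama_context, so callers run the encoder once and then decode without threading an enc_output pointer through llama_batch.

    // hypothetical caller sketch; using token 0 as the decoder start token
    // follows examples/main/main.cpp below
    static int run_t5(llama_context * ctx, llama_token * enc_tokens, int32_t n_enc_tokens) {
        // encoder pass: the result is stored inside the context (encoder_output)
        if (llama_encode(ctx, llama_batch_get_one(enc_tokens, n_enc_tokens, 0, 0))) {
            return 1;
        }
        // decoder pass: reads the stored encoder output; no extra batch fields needed
        llama_token dec_start = 0;
        if (llama_decode(ctx, llama_batch_get_one(&dec_start, 1, 0, 0))) {
            return 1;
        }
        return 0;
    }
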
8 changes: 1 addition & 7 deletions examples/main/main.cpp
@@ -509,12 +509,6 @@ int main(int argc, char ** argv) {
return 1;
}

-int32_t n_enc_output = enc_input_size;
-const int n_embd = llama_n_embd(model);
-size_t enc_output_size = sizeof(float) * n_embd * enc_input_size;
-float * enc_output = (float*) malloc(enc_output_size);
-memcpy(enc_output, llama_get_embeddings(ctx), enc_output_size);
-
embd_inp.clear();
embd_inp.push_back(0);

@@ -663,7 +657,7 @@ int main(int argc, char ** argv) {

LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());

-if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0, n_enc_output, enc_output))) {
+if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
96 changes: 33 additions & 63 deletions llama.cpp
@@ -2446,6 +2446,11 @@ struct llama_context {
// populated only when pooling_type != LLAMA_POOLING_TYPE_NONE
std::map<llama_seq_id, std::vector<float>> embd_seq;

+// whether we are computing encoder output or decoder output
+bool is_encoding = false;
+// output of the encoder part of the encoder-decoder models
+std::vector<float> encoder_output;
+
// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
ggml_backend_sched_t sched = nullptr;
@@ -6913,11 +6918,10 @@ static struct ggml_tensor * llm_build_inp_enc_output(
struct ggml_context * ctx,
struct llama_context & lctx,
const llama_hparams & hparams,
-const llama_batch & batch,
const llm_build_cb & cb) {

const int64_t n_embd = hparams.n_embd;
-lctx.inp_enc_output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, batch.n_enc_output == 0 ? 512 : batch.n_enc_output);
+lctx.inp_enc_output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd);
ggml_set_input(lctx.inp_enc_output);
cb(lctx.inp_enc_output, "enc_output", -1);

@@ -11694,7 +11698,7 @@ struct llm_build_context {
return gf;
}

-struct ggml_cgraph * build_t5(bool is_encoding = false) {
+struct ggml_cgraph * build_t5() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

// mutable variable, needed during the last layer of the computation to skip unused tokens
@@ -11703,14 +11707,14 @@
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
-const int32_t n_enc_output = batch.n_enc_output == 0 ? 512 : batch.n_enc_output;
+const int32_t n_enc_output = lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd;

struct ggml_tensor * cur;
struct ggml_tensor * inpL;

inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

-if (is_encoding) {
+if (lctx.is_encoding) {
struct ggml_tensor * pos_bias = llm_build_inp_rel_pos_bias(ctx0, lctx, batch, model.enc_rel_attn_b, n_tokens, n_tokens, false, cb);

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
@@ -11817,7 +11821,7 @@ struct llm_build_context {
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
} else {
-struct ggml_tensor * enc_output = llm_build_inp_enc_output(ctx0, lctx, hparams, batch, cb);
+struct ggml_tensor * enc_output = llm_build_inp_enc_output(ctx0, lctx, hparams, cb);
struct ggml_tensor * pos_bias = llm_build_inp_rel_pos_bias(ctx0, lctx, batch, model.rel_attn_b, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), true, cb);

struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -12055,8 +12059,7 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
static struct ggml_cgraph * llama_build_graph(
llama_context & lctx,
const llama_batch & batch,
-bool worst_case,
-bool is_encoding = false) {
+bool worst_case) {
const auto & model = lctx.model;

// this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
@@ -12225,7 +12228,7 @@ static struct ggml_cgraph * llama_build_graph(
} break;
case LLM_ARCH_T5:
{
-result = llm.build_t5(is_encoding);
+result = llm.build_t5();
} break;
default:
GGML_ASSERT(false);
@@ -12289,14 +12292,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}

if (lctx.inp_pos_bucket) {
-const int64_t n_tokens = batch.enc_output ? GGML_PAD(batch.n_tokens, GGML_KQ_MASK_PAD) : batch.n_tokens;
+const int64_t n_tokens = lctx.is_encoding ? batch.n_tokens : GGML_PAD(batch.n_tokens, GGML_KQ_MASK_PAD);

const int64_t query_length = lctx.inp_pos_bucket->ne[0];
const int64_t key_length = lctx.inp_pos_bucket->ne[1];

int64_t num_buckets = hparams.n_rel_attn_bkts;
const int64_t max_distance = 128; // TODO move to haprams
-bool bidirectional = batch.enc_output == NULL;
+bool bidirectional = lctx.is_encoding;

if (bidirectional) {
num_buckets >>= 1;
@@ -12325,10 +12328,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
free(pos_bucket);
}

-if (batch.enc_output && lctx.inp_enc_output) {
-const int64_t n_embd = hparams.n_embd;
-
-ggml_backend_tensor_set(lctx.inp_enc_output, batch.enc_output, 0, batch.n_enc_output*n_embd*ggml_element_size(lctx.inp_enc_output));
+if (!lctx.is_encoding && lctx.inp_enc_output) {
+ggml_backend_tensor_set(lctx.inp_enc_output, lctx.encoder_output.data(), 0, lctx.encoder_output.size() * ggml_element_size(lctx.inp_enc_output));
}

if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
@@ -12545,7 +12546,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

// Make sure enough space is available for outputs.
// Returns max number of outputs for which space was reserved.
-static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs, bool is_encoding = false) {
+static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) {
const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams;

@@ -12557,7 +12558,7 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs, bool

// TODO: use a per-batch flag for logits presence instead
const bool has_logits = cparams.causal_attn;
-const bool has_embd = is_encoding || (cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));
+const bool has_embd = lctx.is_encoding || (cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE));

const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0;
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0;
@@ -12636,7 +12637,6 @@ static void llama_graph_compute(
// fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
}
-

// decode a batch of tokens by evaluating the transformer
//
// - lctx: llama context
@@ -12650,6 +12650,7 @@ static int llama_decode_internal(
llama_context & lctx,
llama_batch batch_all) { // TODO: rename back to batch

+lctx.is_encoding = false;
const uint32_t n_tokens_all = batch_all.n_tokens;

if (n_tokens_all == 0) {
@@ -12732,8 +12733,6 @@ static int llama_decode_internal(
/* .all_pos_0 = */ batch_all.all_pos_0 + (llama_pos) cur_token*batch_all.all_pos_1,
/* .all_pos_1 = */ batch_all.all_pos_1,
/* .all_seq_id = */ batch_all.all_seq_id,
-batch_all.n_enc_output,
-batch_all.enc_output,
};

// count the outputs in this u_batch
@@ -12983,6 +12982,7 @@ static int llama_encode_internal(
llama_context & lctx,
llama_batch batch_all) { // TODO: rename back to batch

+lctx.is_encoding = true;
const uint32_t n_tokens_all = batch_all.n_tokens;

if (n_tokens_all == 0) {
@@ -13022,7 +13022,7 @@ static int llama_encode_internal(
n_outputs = n_tokens_all;

// reserve output buffer
-if (llama_output_reserve(lctx, n_outputs, true) < n_outputs) {
+if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
return -2;
};
@@ -13090,7 +13090,7 @@ static int llama_encode_internal(
ggml_backend_sched_reset(lctx.sched);
ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);

-ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false, true);
+ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);

// the output is always the last tensor in the graph
struct ggml_tensor * embd = gf->nodes[gf->n_nodes - 2];
@@ -13144,42 +13144,16 @@ static int llama_encode_internal(
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(lctx.sched, embd);
GGML_ASSERT(backend_embd != nullptr);

-switch (cparams.pooling_type) {
-case LLAMA_POOLING_TYPE_NONE:
-{
-// extract token embeddings
-GGML_ASSERT(lctx.embd != nullptr);
-float * embd_out = lctx.embd + n_outputs_prev*n_embd;
-const int32_t n_outputs_new = lctx.n_outputs;
-
-if (n_outputs_new) {
-GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
-GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
-ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
-}
-} break;
-case LLAMA_POOLING_TYPE_CLS:
-case LLAMA_POOLING_TYPE_MEAN:
-{
-GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0);
-
-// extract sequence embeddings
-auto & embd_seq_out = lctx.embd_seq;
-embd_seq_out.clear();
+// extract token embeddings
+GGML_ASSERT(lctx.embd != nullptr);
+const int32_t n_outputs_new = lctx.n_outputs;
+lctx.encoder_output.resize((n_outputs_prev + n_outputs_new)*n_embd);
+float * embd_out = lctx.encoder_output.data() + n_outputs_prev*n_embd;

-for (uint32_t i = 0; i < n_tokens; i++) {
-const llama_seq_id seq_id = u_batch.seq_id[i][0];
-if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-continue;
-}
-embd_seq_out[seq_id].resize(n_embd);
-ggml_backend_tensor_get_async(backend_embd, embd, embd_seq_out[seq_id].data(), (n_embd*seq_id)*sizeof(float), n_embd*sizeof(float));
-}
-} break;
-case LLAMA_POOLING_TYPE_UNSPECIFIED:
-{
-GGML_ASSERT(false && "unknown pooling type");
-} break;
+if (n_outputs_new) {
+GGML_ASSERT( n_outputs_prev + n_outputs_new <= n_outputs);
+GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
+ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
}
}
n_outputs_prev += lctx.n_outputs;
@@ -19101,9 +19075,7 @@ struct llama_batch llama_batch_get_one(
llama_token * tokens,
int32_t n_tokens,
llama_pos pos_0,
-llama_seq_id seq_id,
-int32_t n_enc_output,
-float * enc_output) {
+llama_seq_id seq_id) {
return {
/*n_tokens =*/ n_tokens,
/*tokens =*/ tokens,
@@ -19115,13 +19087,11 @@ struct llama_batch llama_batch_get_one(
/*all_pos_0 =*/ pos_0,
/*all_pos_1 =*/ 1,
/*all_seq_id =*/ seq_id,
-/*n_enc_output =*/ n_enc_output,
-/*enc_output =*/ enc_output,
};
}

struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
-llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, 0, nullptr };
+llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0 };

if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens_alloc * embd);
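
A note on sizing: the context stores the encoder output as a flat vector of n_enc_tokens * n_embd floats, so build_t5() and llm_build_inp_enc_output() recover the token count by dividing the vector size by n_embd, falling back to 512 when the vector is still empty (e.g. while building the worst-case graph before any encoder pass has run). A small hedged sketch of that expression (the helper name enc_output_len is illustrative, not part of the patch):

    #include <cstdint>
    #include <vector>

    // mirrors the ternary used in build_t5 above
    static int32_t enc_output_len(const std::vector<float> & encoder_output, int64_t n_embd) {
        return encoder_output.empty() ? 512 : (int32_t) (encoder_output.size() / n_embd);
    }
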
7 changes: 1 addition & 6 deletions llama.h
@@ -224,9 +224,6 @@ extern "C" {
llama_pos all_pos_0; // used if pos == NULL
llama_pos all_pos_1; // used if pos == NULL
llama_seq_id all_seq_id; // used if seq_id == NULL
-
-int32_t n_enc_output;
-float * enc_output;
} llama_batch;

enum llama_model_kv_override_type {
@@ -753,9 +750,7 @@ extern "C" {
llama_token * tokens,
int32_t n_tokens,
llama_pos pos_0,
-llama_seq_id seq_id,
-int32_t n_enc_output = 0,
-float * enc_output = NULL);
+llama_seq_id seq_id);

// Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
// Each token can be assigned up to n_seq_max sequence ids
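
With the encoder fields gone, llama_batch_get_one() is back to the upstream four-argument form, so decoder-only callers need no changes. A hedged sketch of a call site (the variables are illustrative; assumes an initialized llama_context * ctx):

    std::vector<llama_token> tokens = { 1, 2, 3 }; // tokenized prompt (placeholder ids)
    int n_past = 0;
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size(),
                                            n_past, /*seq_id =*/ 0);
    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode failed\n");
    }
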
