llama : proper handling of batches, support for multiple sequences
sszymczy committed Jun 17, 2024
1 parent 7821068 commit 205fee3
Showing 1 changed file with 186 additions and 98 deletions.
284 changes: 186 additions & 98 deletions llama.cpp
@@ -2450,6 +2450,7 @@ struct llama_context {
bool is_encoding = false;
// output of the encoder part of the encoder-decoder models
std::vector<float> encoder_output;
std::vector<std::set<llama_seq_id> > encoder_output_seq_ids;

// memory buffers used to evaluate the model
std::vector<uint8_t> buf_compute_meta;
@@ -2472,6 +2473,7 @@ struct llama_context {
struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch]
struct ggml_tensor * inp_pos_bucket; // I32 []
struct ggml_tensor * inp_enc_output; // F32 []
struct ggml_tensor * inp_cross_KQ_mask; // F32 []

// control vectors
struct llama_control_vector cvec;
@@ -6878,56 +6880,6 @@ static struct ggml_tensor * llm_build_inp_embd(
return inpL;
}

static struct ggml_tensor * llm_build_inp_rel_pos_bias(
struct ggml_context * ctx,
struct llama_context & lctx,
const llama_batch & batch,
struct ggml_tensor * rel_attn_b,
int64_t query_length,
int64_t key_length,
bool causal,
const llm_build_cb & cb) {

lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, query_length, key_length);
ggml_set_input(lctx.inp_pos_bucket);
cb(lctx.inp_pos_bucket, "pos_bucket", -1);

struct ggml_tensor * pos_bucket = ggml_dup(ctx, lctx.inp_pos_bucket);
cb(pos_bucket, "pos_bucket", -1);

struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0);
cb(pos_bucket_1d, "pos_bucket_1d", -1);

struct ggml_tensor * pos_bias = ggml_get_rows(ctx, rel_attn_b, pos_bucket_1d);
cb(pos_bias, "pos_bias", -1);

if (!causal) {
pos_bias = ggml_view_3d(ctx, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0], 0);
} else {
pos_bias = ggml_view_3d(ctx, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], 1, ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0] * batch.all_pos_0);
}
cb(pos_bias, "pos_bias", -1);

pos_bias = ggml_permute(ctx, pos_bias, 2, 0, 1, 3);
cb(pos_bias, "pos_bias", -1);

return pos_bias;
}

static struct ggml_tensor * llm_build_inp_enc_output(
struct ggml_context * ctx,
struct llama_context & lctx,
const llama_hparams & hparams,
const llm_build_cb & cb) {

const int64_t n_embd = hparams.n_embd;
lctx.inp_enc_output = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd);
ggml_set_input(lctx.inp_enc_output);
cb(lctx.inp_enc_output, "enc_output", -1);

return lctx.inp_enc_output;
}

static void llm_build_kv_store(
struct ggml_context * ctx,
const llama_hparams & hparams,
@@ -7465,6 +7417,9 @@ struct llm_build_context {
lctx.inp_s_copy = nullptr;
lctx.inp_s_mask = nullptr;
lctx.inp_s_seq = nullptr;
lctx.inp_pos_bucket = nullptr;
lctx.inp_enc_output = nullptr;
lctx.inp_cross_KQ_mask = nullptr;
}

void free() {
@@ -7663,6 +7618,52 @@ struct llm_build_context {
return lctx.inp_s_seq;
}

struct ggml_tensor * llm_build_inp_rel_pos_bias(
struct ggml_tensor * rel_attn_b,
bool causal) {

if (causal) {
lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens);
} else {
lctx.inp_pos_bucket = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_tokens);
}

ggml_set_input(lctx.inp_pos_bucket);
cb(lctx.inp_pos_bucket, "pos_bucket", -1);

struct ggml_tensor * pos_bucket = ggml_dup(ctx0, lctx.inp_pos_bucket);
cb(pos_bucket, "pos_bucket", -1);

struct ggml_tensor * pos_bucket_1d = ggml_view_1d(ctx0, pos_bucket, pos_bucket->ne[0] * pos_bucket->ne[1], 0);
cb(pos_bucket_1d, "pos_bucket_1d", -1);

struct ggml_tensor * pos_bias = ggml_get_rows(ctx0, rel_attn_b, pos_bucket_1d);
cb(pos_bias, "pos_bias", -1);

pos_bias = ggml_view_3d(ctx0, pos_bias, pos_bias->ne[0], lctx.inp_pos_bucket->ne[0], lctx.inp_pos_bucket->ne[1], ggml_element_size(pos_bias) * pos_bias->ne[0], ggml_element_size(pos_bias) * pos_bias->ne[0] * lctx.inp_pos_bucket->ne[0], 0);
cb(pos_bias, "pos_bias", -1);

pos_bias = ggml_permute(ctx0, pos_bias, 2, 0, 1, 3);
cb(pos_bias, "pos_bias", -1);

return pos_bias;
}

struct ggml_tensor * llm_build_inp_enc_output() {
const int64_t n_embd = hparams.n_embd;
lctx.inp_enc_output = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd);
ggml_set_input(lctx.inp_enc_output);
cb(lctx.inp_enc_output, "enc_output", -1);
return lctx.inp_enc_output;
}

struct ggml_tensor * llm_build_inp_cross_KQ_mask() {
lctx.inp_cross_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, lctx.encoder_output.size() == 0 ? 512 : lctx.encoder_output.size() / n_embd, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
ggml_set_input(lctx.inp_cross_KQ_mask);
cb(lctx.inp_cross_KQ_mask, "enc_mask", -1);
return lctx.inp_cross_KQ_mask;
}
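
Aside: the pos-bucket pipeline above gathers one learned bias vector per (query, key) pair from rel_attn_b and permutes the result so it can be added directly to the attention scores. A plain-array restatement of that effect (illustration only; the names and toy sizes below are made up, not part of this diff):

// --- illustration, not part of the commit ---
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <vector>

int main() {
    const int n_head = 2, n_buckets = 4, query_len = 3, key_len = 3;

    // rel_attn_b: one learned bias value per (bucket, head) pair
    std::vector<std::vector<float> > rel_attn_b(n_buckets, std::vector<float>(n_head));
    for (int b = 0; b < n_buckets; ++b)
        for (int h = 0; h < n_head; ++h)
            rel_attn_b[b][h] = 0.1f*b + 0.01f*h;

    // inp_pos_bucket: a bucket index per (query, key) pair; filled at eval
    // time by llama_set_inputs (stand-in bucketing used here)
    std::vector<std::vector<int> > pos_bucket(query_len, std::vector<int>(key_len));
    for (int q = 0; q < query_len; ++q)
        for (int k = 0; k < key_len; ++k)
            pos_bucket[q][k] = std::min(abs(k - q), n_buckets - 1);

    // pos_bias[h][q][k] == rel_attn_b[pos_bucket[q][k]][h] -- the combined
    // effect of ggml_get_rows + ggml_view_3d + ggml_permute in the helper
    for (int h = 0; h < n_head; ++h)
        for (int q = 0; q < query_len; ++q) {
            for (int k = 0; k < key_len; ++k)
                printf("%5.2f ", rel_attn_b[pos_bucket[q][k]][h]);
            printf("\n");
        }
    return 0;
}
// --- end illustration ---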

struct ggml_cgraph * build_llama() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

@@ -11715,7 +11716,7 @@ struct llm_build_context {
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);

if (lctx.is_encoding) {
struct ggml_tensor * pos_bias = llm_build_inp_rel_pos_bias(ctx0, lctx, batch, model.enc_rel_attn_b, n_tokens, n_tokens, false, cb);
struct ggml_tensor * pos_bias = llm_build_inp_rel_pos_bias(model.enc_rel_attn_b, false);

// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * enc_KQ_mask = build_inp_KQ_mask(false);
@@ -11821,10 +11822,11 @@ struct llm_build_context {
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
} else {
struct ggml_tensor * enc_output = llm_build_inp_enc_output(ctx0, lctx, hparams, cb);
struct ggml_tensor * pos_bias = llm_build_inp_rel_pos_bias(ctx0, lctx, batch, model.rel_attn_b, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), true, cb);
struct ggml_tensor * enc_output = llm_build_inp_enc_output();
struct ggml_tensor * pos_bias = llm_build_inp_rel_pos_bias(model.rel_attn_b, true);

struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
struct ggml_tensor * enc_KQ_mask = llm_build_inp_cross_KQ_mask();

for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
@@ -11866,7 +11868,7 @@ struct llm_build_context {

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

struct ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);

struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
cb(kq, "kq", il);
@@ -11923,7 +11925,7 @@ struct llm_build_context {
struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
cb(kq, "kq", il);

kq = ggml_soft_max(ctx0, kq);
kq = ggml_soft_max_ext(ctx0, kq, enc_KQ_mask, 1.0f, hparams.f_max_alibi_bias);
cb(kq, "kq_soft_max_ext", il);

struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_enc_output)));
@@ -12291,47 +12293,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
}

if (lctx.inp_pos_bucket) {
const int64_t n_tokens = lctx.is_encoding ? batch.n_tokens : GGML_PAD(batch.n_tokens, GGML_KQ_MASK_PAD);

const int64_t query_length = lctx.inp_pos_bucket->ne[0];
const int64_t key_length = lctx.inp_pos_bucket->ne[1];

int64_t num_buckets = hparams.n_rel_attn_bkts;
const int64_t max_distance = 128; // TODO move to hparams
bool bidirectional = lctx.is_encoding;

if (bidirectional) {
num_buckets >>= 1;
}

int64_t max_exact = num_buckets >> 1;

uint32_t * pos_bucket = (uint32_t *) malloc(sizeof(uint32_t) * n_tokens * n_tokens);
for (int y = 0; y < query_length; y++) {
for (int x = 0; x < key_length; x++) {
int32_t relative_position = x - y;
int32_t relative_buckets = 0;
if (bidirectional) {
relative_buckets += (relative_position > 0) * num_buckets;
relative_position = abs(relative_position);
} else {
relative_position = -std::min<int32_t>(relative_position, 0);
}
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
relative_buckets += (relative_position < max_exact ? relative_position : relative_position_if_large);
pos_bucket[x + y * key_length] = relative_buckets;
}
}
ggml_backend_tensor_set(lctx.inp_pos_bucket, pos_bucket, 0, n_tokens * n_tokens * ggml_element_size(lctx.inp_pos_bucket));
free(pos_bucket);
}

if (!lctx.is_encoding && lctx.inp_enc_output) {
ggml_backend_tensor_set(lctx.inp_enc_output, lctx.encoder_output.data(), 0, lctx.encoder_output.size() * ggml_element_size(lctx.inp_enc_output));
}

if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
const int64_t n_tokens = batch.n_tokens;
@@ -12369,7 +12330,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

if (lctx.inp_KQ_mask) {
// NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
if (cparams.causal_attn) {
if (cparams.causal_attn && !lctx.is_encoding) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;

@@ -12409,7 +12370,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
} else {
// when using kv cache, the mask needs to match the kv cache size
const int64_t n_tokens = batch.n_tokens;
const int64_t n_stride = hparams.causal_attn ? kv_self.n : n_tokens;
const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));

@@ -12542,6 +12503,124 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
}
}
}

if (lctx.inp_pos_bucket) {
int64_t num_buckets = hparams.n_rel_attn_bkts;
const int64_t max_distance = 128; // TODO move to hparams
bool bidirectional = lctx.is_encoding;

if (bidirectional) {
num_buckets >>= 1;
}

int64_t max_exact = num_buckets >> 1;

if (!lctx.is_encoding) {
const int64_t n_kv = kv_self.n;
const int64_t n_tokens = batch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));

int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;

for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_pos pos = batch.pos[j];
const llama_seq_id seq_id = batch.seq_id[j][0];

for (int i = 0; i < n_kv; ++i) {
int32_t f;
if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
f = 0;
} else {
int32_t relative_position = lctx.kv_self.cells[i].pos - pos;
int32_t relative_buckets = 0;
if (bidirectional) {
relative_buckets += (relative_position > 0) * num_buckets;
relative_position = abs(relative_position);
} else {
relative_position = -std::min<int32_t>(relative_position, 0);
}
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
relative_buckets += (relative_position < max_exact ? relative_position : relative_position_if_large);
f = relative_buckets;
}
data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
}
}
}
} else {
const int64_t n_tokens = batch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));

int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;

for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
const llama_seq_id seq_id = batch.seq_id[j][0];

for (int i = 0; i < n_tokens; ++i) {
int32_t f = 0;
for (int s = 0; s < batch.n_seq_id[i]; ++s) {
if (batch.seq_id[i][s] == seq_id) {
int32_t relative_position = batch.pos[i] - batch.pos[j];
int32_t relative_buckets = 0;
if (bidirectional) {
relative_buckets += (relative_position > 0) * num_buckets;
relative_position = abs(relative_position);
} else {
relative_position = -std::min<int32_t>(relative_position, 0);
}
int32_t relative_position_if_large = floorf(max_exact + logf(1.0 * relative_position / max_exact) * (num_buckets - max_exact) / log(1.0 * max_distance / max_exact));
relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
relative_buckets += (relative_position < max_exact ? relative_position : relative_position_if_large);
f = relative_buckets;
break;
}
}

data[h*(n_tokens*n_tokens) + j*n_tokens + i] = f;
}
}
}
}
}
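
The bucketing inlined twice above is the T5 relative-position scheme: small offsets get exact buckets, larger ones are binned logarithmically out to max_distance, and bidirectional (encoder) mode splits the buckets between keys before and after the query. The same mapping as a standalone sketch (hypothetical helper, factored out here purely for reference):

// --- illustration, not part of the commit ---
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

// mirrors the inline computation in llama_set_inputs above
static int32_t relative_position_bucket(int32_t relative_position, int64_t num_buckets, int64_t max_distance, bool bidirectional) {
    int32_t relative_buckets = 0;
    if (bidirectional) {
        num_buckets >>= 1;                                            // half the buckets for each direction
        relative_buckets += (relative_position > 0) * num_buckets;
        relative_position = abs(relative_position);
    } else {
        relative_position = -std::min<int32_t>(relative_position, 0); // causal: future keys collapse to 0
    }
    const int64_t max_exact = num_buckets >> 1;
    int32_t relative_position_if_large = floorf(max_exact + logf(1.0f*relative_position/max_exact) * (num_buckets - max_exact) / logf(1.0f*max_distance/max_exact));
    relative_position_if_large = std::min<int32_t>(relative_position_if_large, num_buckets - 1);
    relative_buckets += relative_position < max_exact ? relative_position : relative_position_if_large;
    return relative_buckets;
}

int main() {
    // causal buckets for keys 0..7 seen from a query at position 7
    for (int key_pos = 0; key_pos <= 7; ++key_pos)
        printf("%d ", relative_position_bucket(key_pos - 7, 32, 128, false));
    printf("\n"); // prints: 7 6 5 4 3 2 1 0
    return 0;
}
// --- end illustration ---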

if (!lctx.is_encoding && lctx.inp_enc_output) {
ggml_backend_tensor_set(lctx.inp_enc_output, lctx.encoder_output.data(), 0, lctx.encoder_output.size() * ggml_element_size(lctx.inp_enc_output));
}

if (!lctx.is_encoding && lctx.inp_cross_KQ_mask) {
const int64_t n_encoder_output = lctx.encoder_output.size() / hparams.n_embd;
const int64_t n_tokens = batch.n_tokens;

GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cross_KQ_mask->buffer));

float * data = (float *) lctx.inp_cross_KQ_mask->data;

for (int h = 0; h < 1; ++h) {
for (int j = 0; j < n_tokens; ++j) {
for (int i = 0; i < n_encoder_output; ++i) {
float f = -INFINITY;
for (int s = 0; s < batch.n_seq_id[j]; ++s) {
const llama_seq_id seq_id = batch.seq_id[j][s];
if (lctx.encoder_output_seq_ids[i].find(seq_id) != lctx.encoder_output_seq_ids[i].end()) {
f = 0.0f;
}
}
data[h*(n_encoder_output*n_tokens) + j*n_encoder_output + i] = f;
}
}

for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
for (int j = 0; j < n_encoder_output; ++j) {
data[h*(n_encoder_output*n_tokens) + i*n_encoder_output + j] = -INFINITY;
}
}
}
}
}
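
Net effect of inp_cross_KQ_mask: a decoder token can attend to an encoder output only when the two share a sequence id, which is what makes it safe to batch several independent sequences through a single encode/decode pass. A toy version of that rule with plain std containers (values made up):

// --- illustration, not part of the commit ---
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

typedef int32_t llama_seq_id; // as in llama.h

int main() {
    // per-encoder-output sequence ids, as recorded in
    // lctx.encoder_output_seq_ids: outputs 0-1 came from sequence 0,
    // output 2 from sequence 1
    std::vector<std::set<llama_seq_id> > enc_seq_ids(3);
    enc_seq_ids[0].insert(0);
    enc_seq_ids[1].insert(0);
    enc_seq_ids[2].insert(1);

    const llama_seq_id dec_seq[2] = {0, 1}; // one decoder token per sequence

    // row j of the cross mask: 0.0f where decoder token j and encoder
    // output i share a sequence id, -INFINITY everywhere else
    for (int j = 0; j < 2; ++j) {
        for (size_t i = 0; i < enc_seq_ids.size(); ++i)
            printf("%6s", enc_seq_ids[i].count(dec_seq[j]) ? "0.0" : "-inf");
        printf("\n");
    }
    return 0;
}
// --- end illustration ---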

// Make sure enough space is available for outputs.
@@ -13155,6 +13234,15 @@ static int llama_encode_internal(
GGML_ASSERT((n_outputs_prev + n_outputs_new)*n_embd <= (int64_t) lctx.embd_size);
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float));
}

// extract output embeddings mask
lctx.encoder_output_seq_ids.resize(n_outputs_prev + n_outputs_new);
for (int i = 0; i < n_outputs_new; i++) {
for (int s = 0; s < u_batch.n_seq_id[i]; s++) {
llama_seq_id seq_id = u_batch.seq_id[i][s];
lctx.encoder_output_seq_ids[i].insert(seq_id);
}
}
}
n_outputs_prev += lctx.n_outputs;
}