diff --git a/src/llama.cpp b/src/llama.cpp
index d8852cfe494af..71b7ef622019e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2081,7 +2081,7 @@ struct llama_hparams {
     bool use_par_res;
 
     uint32_t n_vocab;
-    uint32_t n_ctx_train; // context size the model was trained on
+    uint32_t n_ctx_train; // context size the model was trained on
     uint32_t n_embd;
     uint32_t n_head;
     uint32_t n_head_kv;
@@ -2665,7 +2665,7 @@ struct llama_context {
     struct ggml_tensor * inp_s_seq;  // I32 [n_kv, n_batch]
 
     // KQ mask per layer, used by sliding window attention (gemma 2)
-    std::vector<struct ggml_tensor *> inp_KQ_mask_l;
+    struct ggml_tensor * inp_KQ_mask_SWA;
 
     // control vectors
     struct llama_control_vector cvec;
@@ -7794,6 +7794,7 @@ struct llm_build_context {
         lctx.inp_s_copy  = nullptr;
         lctx.inp_s_mask  = nullptr;
         lctx.inp_s_seq   = nullptr;
+        lctx.inp_KQ_mask_SWA = nullptr;
     }
 
     void free() {
@@ -7946,15 +7947,18 @@ struct llm_build_context {
         return lctx.inp_out_ids;
     }
 
-    struct ggml_tensor * build_inp_KQ_mask(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    struct ggml_tensor * build_inp_KQ_mask(bool causal = true, bool sliding_window = false) {
+        struct ggml_tensor * KQ_mask = causal
+            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
+            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        cb(KQ_mask, "KQ_mask", -1);
+        ggml_set_input(KQ_mask);
+        if (sliding_window) {
+            lctx.inp_KQ_mask_SWA = KQ_mask;
         } else {
-            lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+            lctx.inp_KQ_mask = KQ_mask;
         }
-        cb(lctx.inp_KQ_mask, "KQ_mask", -1);
-        ggml_set_input(lctx.inp_KQ_mask);
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
+        return flash_attn ? ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16) : KQ_mask;
     }
 
     struct ggml_tensor * build_inp_mean() {
@@ -11038,14 +11042,12 @@ struct llm_build_context {
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
        // gemma 2 requires different mask for layers using sliding window (SWA)
-        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask();
-        struct ggml_tensor * KQ_mask_SWA  = build_inp_KQ_mask();
-        lctx.inp_KQ_mask_l.clear();
+        struct ggml_tensor * KQ_mask_full = build_inp_KQ_mask(true, false);
+        struct ggml_tensor * KQ_mask_SWA  = build_inp_KQ_mask(true, true);
 
         for (int il = 0; il < n_layer; ++il) {
             // (il % 2) layers use SWA
             struct ggml_tensor * KQ_mask = (il % 2 == 0) ? KQ_mask_SWA : KQ_mask_full;
-            lctx.inp_KQ_mask_l.push_back(KQ_mask);
 
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
@@ -12685,15 +12687,15 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask->buffer));
 
-        float * data = (float *) lctx.inp_KQ_mask->data;
+        float * data     = (float *) lctx.inp_KQ_mask->data;
         float * data_swa = nullptr;
         const llama_pos n_keep_swa = hparams.n_sliding - batch.n_tokens;
 
         if (lctx.model.arch == LLM_ARCH_GEMMA2) {
-            GGML_ASSERT(!lctx.inp_KQ_mask_l.empty() && "gemma 2 requires different KQ mask per layer");
+            GGML_ASSERT(lctx.inp_KQ_mask_SWA);
             GGML_ASSERT(hparams.n_sliding > 0);
-            data_swa = (float *) lctx.inp_KQ_mask_l[0]->data;
-            data     = (float *) lctx.inp_KQ_mask_l[1]->data;
+            data     = (float *) lctx.inp_KQ_mask->data;
+            data_swa = (float *) lctx.inp_KQ_mask_SWA->data;
             // because layer masks are alternate for gemma 2, we only need to take first 2 layers
         }
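
Note on the mask semantics (illustrative, not part of the patch): the diff above wires up two F32 masks, the usual full causal mask (lctx.inp_KQ_mask) and a second mask for the sliding-window layers (lctx.inp_KQ_mask_SWA), but it does not show how the two buffers end up differing. The sketch below is a minimal, self-contained approximation of that difference for a single contiguous sequence; the names data, data_swa and n_sliding mirror the patch, while fill_masks, the position math and the exact window boundary are simplifying assumptions rather than llama.cpp internals.

// Sketch only: full causal mask vs sliding-window (SWA) mask for one sequence.
#include <cmath>
#include <cstdint>

static void fill_masks(float * data, float * data_swa,
                       int32_t n_kv, int32_t n_tokens, int32_t n_sliding) {
    for (int32_t i = 0; i < n_tokens; ++i) {
        // assume the batch occupies the last n_tokens cells of the KV cache
        const int32_t pos = n_kv - n_tokens + i;
        for (int32_t j = 0; j < n_kv; ++j) {
            // full mask: plain causal masking (key j must not be in the future)
            const bool causal_ok = j <= pos;
            data[i*n_kv + j] = causal_ok ? 0.0f : -INFINITY;
            // SWA mask: additionally hide keys that fall outside the window,
            // i.e. keys n_sliding or more positions behind the query
            const bool in_window = causal_ok && (pos - j < n_sliding);
            data_swa[i*n_kv + j] = in_window ? 0.0f : -INFINITY;
        }
    }
}

With this reading, the even layers of gemma 2 (il % 2 == 0) attend through data_swa, so with gemma 2's 4096-token window a key far enough behind the query is masked there even though the full causal mask would still allow it; the odd layers keep using data unchanged.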