diff --git a/src/llama.cpp b/src/llama.cpp index c09471aaf696f..04a01b253058e 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -10460,11 +10460,11 @@ static struct ggml_tensor * llm_build_mamba2( // (ab)using the KV cache to store the states struct ggml_tensor * conv = llm_build_rs(ctx, graph, conv_states_all, state_copy, rs_zero, - hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); + hparams.n_embd_k_s(il), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); struct ggml_tensor * ssm = llm_build_rs(ctx, graph, ssm_states_all, state_copy, rs_zero, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + hparams.n_embd_v_s(il), kv.size, kv_head, n_kv, n_seqs, true); ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -10808,7 +10808,7 @@ struct llm_build_context { norm_rms_eps (hparams.f_norm_rms_eps), n_tokens (ubatch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), - n_kv_hybrid (worst_case ? kv_hybrid.size : kv_self.n), + n_kv_hybrid (worst_case ? kv_hybrid.size : kv_hybrid.n), n_outputs (worst_case ? n_tokens : lctx.n_outputs), n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), @@ -11036,8 +11036,8 @@ struct llm_build_context { return lctx.inp_cls; } - struct ggml_tensor * build_inp_s_copy() { - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + struct ggml_tensor * build_inp_s_copy(bool hybrid = false) { + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, hybrid ? n_kv_hybrid : n_kv); cb(lctx.inp_s_copy, "inp_s_copy", -1); ggml_set_input(lctx.inp_s_copy); return lctx.inp_s_copy; @@ -14686,7 +14686,7 @@ struct llm_build_context { // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb); - struct ggml_tensor * state_copy = build_inp_s_copy(); + struct ggml_tensor * state_copy = build_inp_s_copy(/* hybrid */true); const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -14710,7 +14710,8 @@ struct llm_build_context { if (hparams.recurrent_layer(il)) { // ssm layer cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy, - rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, true); + rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, + /* hybrid */ true); cb(cur, "mamba_out", il); } else { // attention layer // @@ -17813,8 +17814,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) } } - if (kv_self.recurrent) { - const int64_t n_kv = kv_self.n; + const bool hybrid = llama_model_is_hybrid(&lctx.model); + auto& kv_hybrid = lctx.kv_hybrid; + if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) { + auto& kv_recurrent = hybrid ? kv_hybrid : lctx.kv_self; + const int64_t n_kv = kv_recurrent.n; if (lctx.inp_s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); @@ -17822,14 +17826,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; + const uint32_t cell_id = i + kv_recurrent.head; + llama_kv_cell & kv_cell = kv_recurrent.cells[cell_id]; if (kv_cell.src < 0) { - GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source - kv_cell.src = kv_self.rs_z; + GGML_ASSERT(kv_recurrent.rs_z >= 0); // Need a valid zero-ed cell as a source + kv_cell.src = kv_recurrent.rs_z; } - if ((uint32_t) kv_cell.src >= kv_self.size) { + if ((uint32_t) kv_cell.src >= kv_recurrent.size) { // ignore out-of-bound sources kv_cell.src = cell_id; } @@ -18135,7 +18139,7 @@ static int llama_decode_internal( } lctx.sbatch.from_batch(batch, n_embd, - /* simple_split */ !kv_self.recurrent, + /* simple_split */ !(kv_self.recurrent || (hybrid && kv_hybrid.recurrent)), /* logits_all */ n_outputs == n_tokens_all); // reserve output buffer @@ -18146,7 +18150,7 @@ static int llama_decode_internal( while (lctx.sbatch.n_tokens > 0) { llama_ubatch ubatch; - if (kv_self.recurrent) { + if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) { if (embd_pooled) { // Pooled embeddings cannot be split across ubatches (yet) ubatch = lctx.sbatch.split_seq(n_ubatch);