diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index d895c9acdb596..b87ec8840b135 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -552,6 +552,13 @@ extern "C" {
         GGML_TENSOR_FLAG_PARAM  = 4,
     };
 
+    // flag (used on GGML_OP_CPY nodes) indicating whether the node is associated with the K or V cache
+    enum ggml_kv_cache_flag {
+        GGML_KV_CACHE_FLAG_NONE = 0,
+        GGML_KV_CACHE_FLAG_K    = 1,
+        GGML_KV_CACHE_FLAG_V    = 2
+    };
+
     // ggml object
     struct ggml_object {
         size_t offs;
@@ -586,6 +593,8 @@ extern "C" {
         // op params - allocated as int32_t for alignment
        int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 
+        enum ggml_kv_cache_flag kv_cache_flag;
+
         int32_t flags;
 
         struct ggml_tensor * grad;
@@ -601,7 +610,7 @@ extern "C" {
 
         void * extra; // extra things e.g. for ggml-cuda.cu
 
-        // char padding[4];
+        char padding[1];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index bc91ac3a726ab..7bc3c079cc6f3 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3638,6 +3638,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.nb           =*/ { 0, 0, 0, 0 },
         /*.op           =*/ GGML_OP_NONE,
         /*.op_params    =*/ { 0 },
+        /*.kv_cache_flag=*/ GGML_KV_CACHE_FLAG_NONE,
         /*.flags        =*/ 0,
         /*.grad         =*/ NULL,
         /*.src          =*/ { NULL },
@@ -3646,7 +3647,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         /*.data         =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
         /*.extra        =*/ NULL,
-        ///*.padding      =*/ { 0 },
+        /*.padding      =*/ { 0 },
     };
 
 #ifdef __clang__
diff --git a/src/llama.cpp b/src/llama.cpp
index 4a309b999205a..d0a237d2decd8 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -7794,7 +7794,9 @@ static void llm_build_kv_store(
     cb(k_cache_view, "k_cache_view", il);
 
     // note: storing RoPE-ed version of K in the KV cache
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+    ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
+    ggml_build_forward_expand(graph, tmp);
 
     assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);
 
@@ -7812,8 +7814,9 @@ static void llm_build_kv_store(
         v_cur = ggml_transpose(ctx, v_cur);
     }
     cb(v_cache_view, "v_cache_view", il);
-
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+    tmp = ggml_cpy(ctx, v_cur, v_cache_view);
+    tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
+    ggml_build_forward_expand(graph, tmp);
 }
 
 static struct ggml_tensor * llm_build_norm(
@@ -14606,30 +14609,20 @@ static int llama_decode_internal(
 
         if(ggml_use_cached_graph(lctx.sched)) {
 
-            // If using flash attention, find mask node so it can be skipped when updating
-            // KV cache paramaters in cached graph nodes below
-            void * flash_attn_mask_node = nullptr;
-            if(cparams.flash_attn) {
-                for (int i = 0; i < gf->n_nodes; i++) {
-                    ggml_tensor * node = gf->nodes[i];
-                    if (node->op == GGML_OP_FLASH_ATTN_EXT) {
-                        flash_attn_mask_node = node->src[3];
-                        break;
-                    }
-                }
-            }
-
             // Temporarily store KV cache parameters that will need updated in cached graph.
             const struct llama_hparams & hparams = model.hparams;
             const int64_t n_layer = hparams.n_layer;
             const int64_t kv_head = kv_self.head;
             std::vector<void *> kv_cache_ptrs;
+            std::vector<void *> k_cache_ptrs;
+            std::vector<void *> v_cache_ptrs;
             for (int il = 0; il < n_layer; ++il) {
                 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
                 const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
                 ggml_tensor * tmp_tensor = kv_self.k_l[il];
                 size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
                 kv_cache_ptrs.push_back(static_cast<char *>(tmp_tensor->data) + tmp_offset);
+                k_cache_ptrs.push_back(static_cast<char *>(tmp_tensor->data) + tmp_offset);
                 tmp_tensor = kv_self.v_l[il];
                 if (cparams.flash_attn) {
                     tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
@@ -14637,17 +14630,21 @@ static int llama_decode_internal(
                     tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
                 }
                 kv_cache_ptrs.push_back(static_cast<char *>(tmp_tensor->data) + tmp_offset);
+                v_cache_ptrs.push_back(static_cast<char *>(tmp_tensor->data) + tmp_offset);
             }
 
             // Update KV cache parameters in cached graph.
-            int copy_op_count = 0;
+            int k_count = 0;
+            int v_count = 0;
             if(gf != nullptr && gf->nodes != nullptr){
                 for (int i = 0; i < gf->n_nodes; i++) {
                     ggml_tensor * node = gf->nodes[i];
                     if (node->op == GGML_OP_CPY) {
-                        if (node != flash_attn_mask_node) {
-                            node->src[1]->data = kv_cache_ptrs[copy_op_count];
-                            copy_op_count++;
+                        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
+                            node->src[1]->data = k_cache_ptrs[k_count++];
+                        }
+                        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
+                            node->src[1]->data = v_cache_ptrs[v_count++];
                         }
                     }
                 }
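For context, the flag-based refresh that this patch switches to can be read in isolation as the sketch below. It assumes only the `ggml.h` of this branch (where the `ggml_cgraph` fields `n_nodes`/`nodes` and the new `kv_cache_flag` are visible) plus `<vector>`; the helper name `update_cached_kv_ptrs` is illustrative and not part of the patch. It replaces the earlier scheme of counting `GGML_OP_CPY` nodes and special-casing the flash-attention mask copy: because the K/V copies are tagged at graph-build time in `llm_build_kv_store`, the update no longer depends on the ordering of copy nodes in the cached graph.

```cpp
// Sketch: re-point the destination views of the tagged K/V cache copies in a
// cached graph at the offsets computed for the current kv_head. k_cache_ptrs
// and v_cache_ptrs hold one pointer per layer, in layer order, exactly as
// built in llama_decode_internal above.
#include "ggml.h"
#include <vector>

static void update_cached_kv_ptrs(struct ggml_cgraph * gf,
                                  const std::vector<void *> & k_cache_ptrs,
                                  const std::vector<void *> & v_cache_ptrs) {
    int k_count = 0;
    int v_count = 0;
    for (int i = 0; i < gf->n_nodes; i++) {
        struct ggml_tensor * node = gf->nodes[i];
        if (node->op != GGML_OP_CPY) {
            continue;
        }
        // src[1] of a ggml_cpy node is the destination view of the cache tensor
        if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
            node->src[1]->data = k_cache_ptrs[k_count++]; // view into kv_self.k_l[il]
        } else if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
            node->src[1]->data = v_cache_ptrs[v_count++]; // view into kv_self.v_l[il]
        }
    }
}
```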